[Date Prev][Date Next]   [Thread Prev][Thread Next]   [Thread Index] [Date Index] [Author Index]

[Libguestfs] [PATCH 2/2] common/mlstdutils: Add with_openfile function.



This safe wrapper around Unix.openfile ensures that exceptions
escaping cannot leave unclosed files.

There are only a few places in the code where this wrapper can be used
currently.  There are other occurrences of Unix.openfile but they are
not suitable for replacement.
---
 common/mlstdutils/std_utils.ml  |  4 ++++
 common/mlstdutils/std_utils.mli |  6 ++++++
 daemon/devsparts.ml             |  5 ++---
 daemon/inspect_fs_windows.ml    | 18 ++++++++----------
 4 files changed, 20 insertions(+), 13 deletions(-)

diff --git a/common/mlstdutils/std_utils.ml b/common/mlstdutils/std_utils.ml
index ee6bea5af..32944ed27 100644
--- a/common/mlstdutils/std_utils.ml
+++ b/common/mlstdutils/std_utils.ml
@@ -662,6 +662,10 @@ let with_open_out filename f =
   let chan = open_out filename in
   protect ~f:(fun () -> f chan) ~finally:(fun () -> close_out chan)
 
+let with_openfile filename flags perms f =
+  let fd = Unix.openfile filename flags perms in
+  protect ~f:(fun () -> f fd) ~finally:(fun () -> Unix.close fd)
+
 let read_whole_file path =
   let buf = Buffer.create 16384 in
   with_open_in path (
diff --git a/common/mlstdutils/std_utils.mli b/common/mlstdutils/std_utils.mli
index 7af6c2111..178762819 100644
--- a/common/mlstdutils/std_utils.mli
+++ b/common/mlstdutils/std_utils.mli
@@ -399,6 +399,12 @@ val with_open_out : string -> (out_channel -> 'a) -> 'a
     return or if the function [f] throws an exception, so this is
     both safer and more concise than the regular function. *)
 
+val with_openfile : string -> Unix.open_flag list -> Unix.file_perm -> (Unix.file_desc -> 'a) -> 'a
+(** [with_openfile filename flags perms f] calls function [f] with
+    [filename] opened by {!Unix.openfile}.  The file is always closed
+    either on normal return or if [f] throws an exception, so this is
+    both safer and more concise than the regular function. *)
+
 val read_whole_file : string -> string
 (** Read in the whole file as a string. *)
 
diff --git a/daemon/devsparts.ml b/daemon/devsparts.ml
index 7395de923..0eb7c1282 100644
--- a/daemon/devsparts.ml
+++ b/daemon/devsparts.ml
@@ -49,9 +49,8 @@ let map_block_devices ~return_md f =
     List.filter (
       fun dev ->
         try
-          let fd = openfile ("/dev/" ^ dev) [O_RDONLY; O_CLOEXEC] 0 in
-          close fd;
-          true
+          with_openfile ("/dev/" ^ dev) [O_RDONLY; O_CLOEXEC] 0
+                        (fun _ -> true)
         with _ -> false
     ) devs in
 
diff --git a/daemon/inspect_fs_windows.ml b/daemon/inspect_fs_windows.ml
index 7c42fc5d7..112cc2f92 100644
--- a/daemon/inspect_fs_windows.ml
+++ b/daemon/inspect_fs_windows.ml
@@ -429,16 +429,14 @@ and extract_guid_from_registry_blob blob =
           (data4 &^ 0xffffffffffff_L)
 
 and pread device size offset =
-  let fd = Unix.openfile device [Unix.O_RDONLY; Unix.O_CLOEXEC] 0 in
-  let ret =
-    protect ~f:(
-      fun () ->
-        ignore (Unix.lseek fd offset Unix.SEEK_SET);
-        let ret = Bytes.create size in
-        if Unix.read fd ret 0 size < size then
-          failwithf "pread: %s: short read" device;
-        ret
-    ) ~finally:(fun () -> Unix.close fd) in
+  let ret = with_openfile device [Unix.O_RDONLY; Unix.O_CLOEXEC] 0 (
+    fun fd ->
+      ignore (Unix.lseek fd offset Unix.SEEK_SET);
+      let ret = Bytes.create size in
+      if Unix.read fd ret 0 size < size then
+        failwithf "pread: %s: short read" device;
+      ret
+  ) in
   Bytes.to_string ret
 
 (* Get the hostname. *)
-- 
2.13.2


[Date Prev][Date Next]   [Thread Prev][Thread Next]   [Thread Index] [Date Index] [Author Index]