[Libguestfs] [PATCH libnbd v3] lib: Atomically update h->state when leaving the locked region.

Richard W.M. Jones rjones at redhat.com
Sat Jun 8 18:05:09 UTC 2019


Split h->state into:

 - h->public_state = the state on entry to the locked region
   This is also the publicly visible state, which is read atomically.

 - h->state = the real current state of the handle

When we leave the locked region we update h->public_state with
h->state, so that from outside the lock the handle appears to move
atomically from its previous state to the final state without going
through any intermediate states.

Some calls to ‘get_state’ become calls to ‘get_next_state’ if they need
the real state.  Others which need to see the publicly visible state
are changed to ‘get_public_state’.

All calls to ‘set_state’ become ‘set_next_state’ because that is the
real state that gets updated.

The purpose of this patch is to make it easier to reason about the
state in lockless code.
---
 generator/generator | 87 ++++++++++++++++++++++++++++-----------------
 lib/connect.c       | 10 +++---
 lib/disconnect.c    |  8 ++---
 lib/handle.c        |  2 +-
 lib/internal.h      | 17 +++++++--
 lib/is-state.c      | 28 ++++++++-------
 lib/rw.c            |  4 +--
 7 files changed, 97 insertions(+), 59 deletions(-)

diff --git a/generator/generator b/generator/generator
index 468292f..c1f4c29 100755
--- a/generator/generator
+++ b/generator/generator
@@ -811,9 +811,10 @@ type call = {
   longdesc : string;       (* long description *)
   (* List of permitted states for making this call.  [[]] = Any state. *)
   permitted_states : permitted_state list;
-  (* Most functions must take a lock.  The only known exception is
-   * for a function which {b only} reads from the atomic [h->state]
-   * field and does nothing else with the handle.
+  (* Most functions must take a lock.  The only known exceptions are:
+   * - functions which return a constant (eg. [nbd_supports_uri])
+   * - functions which {b only} read from the atomic
+   *   [get_public_state] and do nothing else with the handle.
    *)
   is_locked : bool;
   (* Most functions can call set_error.  For functions which are
@@ -2449,11 +2450,11 @@ let generate_lib_states_c () =
       pr "  enum state next_state = %s;\n" state_enum;
       pr "\n";
       pr "  r = _enter_%s (h, &next_state, blocked);\n" state_enum;
-      pr "  if (get_state (h) != next_state) {\n";
+      pr "  if (get_next_state (h) != next_state) {\n";
       pr "    debug (h, \"transition: %%s -> %%s\",\n";
       pr "           \"%s\",\n" display_name;
       pr "           nbd_internal_state_short_string (next_state));\n";
-      pr "    set_state (h, next_state);\n";
+      pr "    set_next_state (h, next_state);\n";
       pr "  }\n";
       pr "  return r;\n";
       pr "}\n";
@@ -2468,7 +2469,7 @@ let generate_lib_states_c () =
   pr "  bool blocked;\n";
   pr "\n";
   pr "  /* Validate and handle the external event. */\n";
-  pr "  switch (get_state (h))\n";
+  pr "  switch (get_next_state (h))\n";
   pr "  {\n";
   List.iter (
     fun ({ parsed = { display_name; state_enum; events } } as state) ->
@@ -2480,7 +2481,7 @@ let generate_lib_states_c () =
           fun (e, next_state) ->
             pr "    case %s:\n" (c_string_of_external_event e);
             if state != next_state then (
-              pr "      set_state (h, %s);\n" next_state.parsed.state_enum;
+              pr "      set_next_state (h, %s);\n" next_state.parsed.state_enum;
               pr "      debug (h, \"event %%s: %%s -> %%s\",\n";
               pr "             \"%s\", \"%s\", \"%s\");\n"
                  (string_of_external_event e)
@@ -2496,7 +2497,7 @@ let generate_lib_states_c () =
   pr "  }\n";
   pr "\n";
   pr "  set_error (EINVAL, \"external event %%d is invalid in state %%s\",\n";
-  pr "             ev, nbd_internal_state_short_string (get_state (h)));\n";
+  pr "             ev, nbd_internal_state_short_string (get_next_state (h)));\n";
   pr "  return -1;\n";
   pr "\n";
   pr " ok:\n";
@@ -2504,7 +2505,7 @@ let generate_lib_states_c () =
   pr "    blocked = true;\n";
   pr "\n";
   pr "    /* Run a single step. */\n";
-  pr "    switch (get_state (h))\n";
+  pr "    switch (get_next_state (h))\n";
   pr "    {\n";
   List.iter (
     fun { parsed = { state_enum } } ->
@@ -2530,7 +2531,7 @@ let generate_lib_states_c () =
   pr "{\n";
   pr "  int r = 0;\n";
   pr "\n";
-  pr "  switch (get_state (h))\n";
+  pr "  switch (get_next_state (h))\n";
   pr "  {\n";
   List.iter (
     fun ({ parsed = { state_enum; events } }) ->
@@ -2576,7 +2577,7 @@ let generate_lib_states_c () =
   pr "const char *\n";
   pr "nbd_unlocked_connection_state (struct nbd_handle *h)\n";
   pr "{\n";
-  pr "  switch (get_state (h))\n";
+  pr "  switch (get_next_state (h))\n";
   pr "  {\n";
   List.iter (
     fun ({ comment; parsed = { display_name; state_enum } }) ->
@@ -2842,6 +2843,33 @@ let permitted_state_text permitted_states =
 let generate_lib_api_c () =
   let print_wrapper (name, {args; ret; permitted_states;
                             is_locked; may_set_error}) =
+    if permitted_states <> [] then (
+      pr "static inline bool\n";
+      pr "%s_in_permitted_state (struct nbd_handle *h)\n" name;
+      pr "{\n";
+      pr "  const enum state state = get_public_state (h);\n";
+      pr "\n";
+      let tests =
+        List.map (
+          function
+          | Created -> "nbd_internal_is_state_created (state)"
+          | Connecting -> "nbd_internal_is_state_connecting (state)"
+          | Connected -> "nbd_internal_is_state_ready (state) || nbd_internal_is_state_processing (state)"
+          | Closed -> "nbd_internal_is_state_closed (state)"
+          | Dead -> "nbd_internal_is_state_dead (state)"
+        ) permitted_states in
+      pr "  if (!(%s)) {\n" (String.concat " ||\n        " tests);
+      pr "    set_error (nbd_internal_is_state_created (state) ? ENOTCONN : EINVAL,\n";
+      pr "               \"invalid state: %%s: the handle must be %%s\",\n";
+      pr "               nbd_internal_state_short_string (state),\n";
+      pr "               \"%s\");\n" (permitted_state_text permitted_states);
+      pr "    return false;\n";
+      pr "  }\n";
+      pr "  return true;\n";
+      pr "}\n";
+      pr "\n"
+    );
+
     let ret_c_type, errcode =
       match ret with
       | RBool
@@ -2873,35 +2901,30 @@ let generate_lib_api_c () =
     );
     if permitted_states <> [] then (
       pr "  /* We can check the state outside the handle lock because the\n";
-      pr "   * the state is atomic.\n";
+      pr "   * the state is atomic.  However to avoid TOCTTOU we must also\n";
+      pr "   * check again after we acquire the lock.\n";
       pr "   */\n";
-      pr "  enum state state = get_state (h);\n";
-      let tests =
-        List.map (
-          function
-          | Created -> "nbd_internal_is_state_created (state)"
-          | Connecting -> "nbd_internal_is_state_connecting (state)"
-          | Connected -> "nbd_internal_is_state_ready (state) || nbd_internal_is_state_processing (state)"
-          | Closed -> "nbd_internal_is_state_closed (state)"
-          | Dead -> "nbd_internal_is_state_dead (state)"
-        ) permitted_states in
-      pr "  if (!(%s)) {\n" (String.concat " ||\n        " tests);
-      pr "    set_error (nbd_internal_is_state_created (state) ? ENOTCONN : EINVAL,\n";
-      pr "               \"invalid state: %%s: the handle must be %%s\",\n";
-      pr "               nbd_internal_state_short_string (state),\n";
-      pr "               \"%s\");\n" (permitted_state_text permitted_states);
-      pr "    return %s;\n" errcode;
-      pr "  }\n";
-      pr "\n"
+      pr "  if (!%s_in_permitted_state (h)) return %s;\n" name errcode;
     );
     if is_locked then
       pr "  pthread_mutex_lock (&h->lock);\n";
+    if permitted_states <> [] then (
+      pr "  if (!%s_in_permitted_state (h)) {\n" name;
+      pr "    ret = %s;\n" errcode;
+      pr "    goto out;\n";
+      pr "  }\n"
+    );
     pr "  ret = nbd_unlocked_%s (h" name;
     let argnames = List.flatten (List.map name_of_arg args) in
     List.iter (pr ", %s") argnames;
     pr ");\n";
-    if is_locked then
-      pr "  pthread_mutex_unlock (&h->lock);\n";
+    if permitted_states <> [] then
+      pr " out:\n";
+    if is_locked then (
+      pr "  if (h->public_state != get_next_state (h))\n";
+      pr "    h->public_state = get_next_state (h);\n";
+      pr "  pthread_mutex_unlock (&h->lock);\n"
+    );
     pr "  return ret;\n";
     pr "}\n";
     pr "\n";
diff --git a/lib/connect.c b/lib/connect.c
index 46c434f..96ed1ca 100644
--- a/lib/connect.c
+++ b/lib/connect.c
@@ -38,16 +38,16 @@
 static int
 error_unless_ready (struct nbd_handle *h)
 {
-  if (nbd_internal_is_state_ready (get_state (h)))
+  if (nbd_internal_is_state_ready (get_next_state (h)))
     return 0;
 
   /* Why did it fail? */
-  if (nbd_internal_is_state_closed (get_state (h))) {
+  if (nbd_internal_is_state_closed (get_next_state (h))) {
     set_error (0, "connection is closed");
     return -1;
   }
 
-  if (nbd_internal_is_state_dead (get_state (h)))
+  if (nbd_internal_is_state_dead (get_next_state (h)))
     /* Don't set the error here, keep the error set when
      * the connection died.
      */
@@ -55,14 +55,14 @@ error_unless_ready (struct nbd_handle *h)
 
   /* Should probably never happen. */
   set_error (0, "connection in an unexpected state (%s)",
-             nbd_internal_state_short_string (get_state (h)));
+             nbd_internal_state_short_string (get_next_state (h)));
   return -1;
 }
 
 static int
 wait_until_connected (struct nbd_handle *h)
 {
-  while (nbd_internal_is_state_connecting (get_state (h))) {
+  while (nbd_internal_is_state_connecting (get_next_state (h))) {
     if (nbd_unlocked_poll (h, -1) == -1)
       return -1;
   }
diff --git a/lib/disconnect.c b/lib/disconnect.c
index 423edaf..95e9a37 100644
--- a/lib/disconnect.c
+++ b/lib/disconnect.c
@@ -29,14 +29,14 @@
 int
 nbd_unlocked_shutdown (struct nbd_handle *h)
 {
-  if (nbd_internal_is_state_ready (get_state (h)) ||
-      nbd_internal_is_state_processing (get_state (h))) {
+  if (nbd_internal_is_state_ready (get_next_state (h)) ||
+      nbd_internal_is_state_processing (get_next_state (h))) {
     if (nbd_unlocked_aio_disconnect (h, 0) == -1)
       return -1;
   }
 
-  while (!nbd_internal_is_state_closed (get_state (h)) &&
-         !nbd_internal_is_state_dead (get_state (h))) {
+  while (!nbd_internal_is_state_closed (get_next_state (h)) &&
+         !nbd_internal_is_state_dead (get_next_state (h))) {
     if (nbd_unlocked_poll (h, -1) == -1)
       return -1;
   }
diff --git a/lib/handle.c b/lib/handle.c
index e40b274..300eac7 100644
--- a/lib/handle.c
+++ b/lib/handle.c
@@ -57,7 +57,7 @@ nbd_create (void)
   s = getenv ("LIBNBD_DEBUG");
   h->debug = s && strcmp (s, "1") == 0;
 
-  h->state = STATE_START;
+  h->state = h->public_state = STATE_START;
   h->pid = -1;
 
   h->export_name = strdup ("");
diff --git a/lib/internal.h b/lib/internal.h
index 61ddbde..503bf34 100644
--- a/lib/internal.h
+++ b/lib/internal.h
@@ -80,7 +80,17 @@ struct nbd_handle {
   /* Linked list of close callbacks. */
   struct close_callback *close_callbacks;
 
-  _Atomic enum state state;     /* State machine. */
+  /* State machine.
+   *
+   * The actual current state is ‘state’.  ‘public_state’ is updated
+   * before we release the lock.
+   *
+   * Note don't access these fields directly, use the SET_NEXT_STATE
+   * macro in generator/states* code, or the set_next_state,
+   * get_next_state and get_public_state macros in regular code.
+   */
+  _Atomic enum state public_state;
+  enum state state;
 
   bool structured_replies;      /* If we negotiated NBD_OPT_STRUCTURED_REPLY */
 
@@ -292,8 +302,9 @@ extern const char *nbd_internal_state_short_string (enum state state);
 extern enum state_group nbd_internal_state_group (enum state state);
 extern enum state_group nbd_internal_state_group_parent (enum state_group group);
 
-#define set_state(h,next_state) ((h)->state) = (next_state)
-#define get_state(h) ((h)->state)
+#define set_next_state(h,next_state) ((h)->state) = (next_state)
+#define get_next_state(h) ((h)->state)
+#define get_public_state(h) ((h)->public_state)
 
 /* utils.c */
 extern void nbd_internal_hexdump (const void *data, size_t len, FILE *fp);
diff --git a/lib/is-state.c b/lib/is-state.c
index c941ab4..b2c20df 100644
--- a/lib/is-state.c
+++ b/lib/is-state.c
@@ -98,44 +98,48 @@ nbd_internal_is_state_closed (enum state state)
   return state == STATE_CLOSED;
 }
 
-/* NB: is_locked = false, may_set_error = false. */
+/* The nbd_unlocked_aio_is_* calls are the public APIs
+ * for reading the state of the handle.
+ *
+ * They all have: is_locked = false, may_set_error = false.
+ *
+ * They all read the public state, not the real state.  Therefore you
+ * SHOULD NOT call these functions from elsewhere in the library (use
+ * nbd_internal_is_* instead).
+ */
+
 int
 nbd_unlocked_aio_is_created (struct nbd_handle *h)
 {
-  return nbd_internal_is_state_created (get_state (h));
+  return nbd_internal_is_state_created (get_public_state (h));
 }
 
-/* NB: is_locked = false, may_set_error = false. */
 int
 nbd_unlocked_aio_is_connecting (struct nbd_handle *h)
 {
-  return nbd_internal_is_state_connecting (get_state (h));
+  return nbd_internal_is_state_connecting (get_public_state (h));
 }
 
-/* NB: is_locked = false, may_set_error = false. */
 int
 nbd_unlocked_aio_is_ready (struct nbd_handle *h)
 {
-  return nbd_internal_is_state_ready (get_state (h));
+  return nbd_internal_is_state_ready (get_public_state (h));
 }
 
-/* NB: is_locked = false, may_set_error = false. */
 int
 nbd_unlocked_aio_is_processing (struct nbd_handle *h)
 {
-  return nbd_internal_is_state_processing (get_state (h));
+  return nbd_internal_is_state_processing (get_public_state (h));
 }
 
-/* NB: is_locked = false, may_set_error = false. */
 int
 nbd_unlocked_aio_is_dead (struct nbd_handle *h)
 {
-  return nbd_internal_is_state_dead (get_state (h));
+  return nbd_internal_is_state_dead (get_public_state (h));
 }
 
-/* NB: is_locked = false, may_set_error = false. */
 int
 nbd_unlocked_aio_is_closed (struct nbd_handle *h)
 {
-  return nbd_internal_is_state_closed (get_state (h));
+  return nbd_internal_is_state_closed (get_public_state (h));
 }
diff --git a/lib/rw.c b/lib/rw.c
index b38d95b..ad9c8a0 100644
--- a/lib/rw.c
+++ b/lib/rw.c
@@ -201,7 +201,7 @@ nbd_internal_command_common (struct nbd_handle *h,
    * be handled automatically on a future cycle around to READY.
    */
   if (h->cmds_to_issue != NULL) {
-    assert (nbd_internal_is_state_processing (get_state (h)));
+    assert (nbd_internal_is_state_processing (get_next_state (h)));
     prev_cmd = h->cmds_to_issue;
     while (prev_cmd->next)
       prev_cmd = prev_cmd->next;
@@ -209,7 +209,7 @@ nbd_internal_command_common (struct nbd_handle *h,
   }
   else {
     h->cmds_to_issue = cmd;
-    if (nbd_internal_is_state_ready (get_state (h)) &&
+    if (nbd_internal_is_state_ready (get_next_state (h)) &&
         nbd_internal_run (h, cmd_issue) == -1)
       return -1;
   }
-- 
2.21.0




More information about the Libguestfs mailing list