[Libguestfs] [libnbd PATCH v2 07/10] rust: async: Create an async friendly handle type

Tage Johansson tage.j.lists at posteo.net
Mon Jul 24 08:44:41 UTC 2023


Create another handle type: `AsyncHandle`, which makes use of Rust's
builtin asynchronous functions (see
<https://doc.rust-lang.org/std/keyword.async.html>) and runs on top of
the Tokio runtime (see <https://docs.rs/tokio>). For every asynchronous
command, like `aio_connect()`, a corresponding `async` method is created
on the handle. In this case it would be:
    async fn connect(...) -> Result<(), ...>
When called, it will poll the file descriptor until the command is
complete, and then return with a result. All the synchronous
counterparts (like `nbd_connect()`) are excluded from this handle type
as they are unnecessary and since they might interfere with the polling
made by the Tokio runtime. For more details about how the asynchronous
commands are executed, please see the comments in
rust/src/async_handle.rs.
---
 generator/Rust.ml        | 227 +++++++++++++++++++++++++++++++++++++++
 generator/Rust.mli       |   2 +
 generator/generator.ml   |   1 +
 rust/Cargo.toml          |   1 +
 rust/Makefile.am         |   1 +
 rust/src/async_handle.rs | 222 ++++++++++++++++++++++++++++++++++++++
 rust/src/lib.rs          |   4 +
 7 files changed, 458 insertions(+)
 create mode 100644 rust/src/async_handle.rs

diff --git a/generator/Rust.ml b/generator/Rust.ml
index a4b5257..adc7ffa 100644
--- a/generator/Rust.ml
+++ b/generator/Rust.ml
@@ -588,3 +588,230 @@ let generate_rust_bindings () =
   pr "impl Handle {\n";
   List.iter print_rust_handle_method handle_calls;
   pr "}\n\n"
+
+(*********************************************************)
+(* The rest of the file concerns the asynchronous API.   *)
+(*                                                       *)
+(* See the comments in rust/src/async_handle.rs for more *)
+(* information about how it works.                       *)
+(*********************************************************)
+
+let excluded_handle_calls : NameSet.t =
+  NameSet.of_list
+    [
+      "aio_get_fd";
+      "aio_get_direction";
+      "aio_notify_read";
+      "aio_notify_write";
+      "clear_debug_callback";
+      "get_debug";
+      "poll";
+      "poll2";
+      "set_debug";
+      "set_debug_callback";
+    ]
+
+(* A mapping with names as keys. *)
+module NameMap = Map.Make (String)
+
+(* Strip "aio_" from the beginning of a string. *)
+let strip_aio name : string =
+  if String.starts_with ~prefix:"aio_" name then
+    String.sub name 4 (String.length name - 4)
+  else failwithf "Asynchronous call %s must begin with aio_" name
+
+(* A map with all asynchronous handle calls. The keys are names with "aio_"
+   stripped, the values are a tuple with the actual name (with "aio_"), the
+   [call] and the [async_kind]. *)
+let async_handle_calls : ((string * call) * async_kind) NameMap.t =
+  handle_calls
+  |> List.filter (fun (n, _) -> not (NameSet.mem n excluded_handle_calls))
+  |> List.filter_map (fun (name, call) ->
+         call.async_kind
+         |> Option.map (fun async_kind ->
+                (strip_aio name, ((name, call), async_kind))))
+  |> List.to_seq |> NameMap.of_seq
+
+(* A mapping with all synchronous (not asynchronous) handle calls. Excluded
+   are also all synchronous calls that have an asynchronous counterpart. So if
+   "foo" is the name of a handle call and an asynchronous call "aio_foo"
+   exists, then "foo" will not be in this map. *)
+let sync_handle_calls : call NameMap.t =
+  handle_calls
+  |> List.filter (fun (n, _) -> not (NameSet.mem n excluded_handle_calls))
+  |> List.filter (fun (name, _) ->
+         (not (NameMap.mem name async_handle_calls))
+         && not
+              (String.starts_with ~prefix:"aio_" name
+              && NameMap.mem (strip_aio name) async_handle_calls))
+  |> List.to_seq |> NameMap.of_seq
+
+(* Get the Rust type for an argument in the asynchronous API. Like
+   [rust_arg_type] but no static lifetime on some closures and buffers. *)
+let rust_async_arg_type : arg -> string = function
+  | Closure { cbargs; cbcount; cblifetime } ->
+      let lifetime =
+        match cblifetime with CBCommand -> None | CBHandle -> Some "'static"
+      in
+      "impl " ^ rust_closure_trait ~lifetime cbargs cbcount
+  | BytesPersistIn _ -> "&[u8]"
+  | BytesPersistOut _ -> "&mut [u8]"
+  | x -> rust_arg_type x
+
+(* Get the Rust type for an optional argument in the asynchronous API. Like
+   [rust_optarg_type] but no static lifetime on some closures. *)
+let rust_async_optarg_type : optarg -> string = function
+  | OClosure x -> sprintf "Option<%s>" (rust_async_arg_type (Closure x))
+  | x -> rust_optarg_type x
+
+(* A string of the argument list for a method on the handle, with both
+   mandatory and optional arguments. *)
+let rust_async_handle_call_args { args; optargs } : string =
+  let rust_args_names =
+    List.map rust_arg_name args @ List.map rust_optarg_name optargs
+  and rust_args_types =
+    List.map rust_async_arg_type args
+    @ List.map rust_async_optarg_type optargs
+  in
+  String.concat ", "
+    (List.map2 (sprintf "%s: %s") rust_args_names rust_args_types)
+
+(* Print the Rust function for a non-asynchronous handle call. *)
+let print_rust_sync_handle_call (name : string) (call : call) =
+  print_rust_handle_call_comment call;
+  pr "pub fn %s(&self, %s) -> %s\n" name
+    (rust_async_handle_call_args call)
+    (rust_ret_type call);
+  print_ffi_call name "self.data.handle.handle" call;
+  pr "\n"
+
+(* Print the Rust function for an asynchronous handle call with a completion
+   callback. (Note that "callback" might be abbreviated with "cb" in the
+   following code.) *)
+let print_rust_async_handle_call_with_completion_cb name (aio_name, call) =
+  (* An array of all optional arguments. Useful because we need to deal with
+     the index of the completion callback. *)
+  let optargs = Array.of_list call.optargs in
+  (* The index of the completion callback in [optargs] *)
+  let completion_cb_index =
+    Array.find_map
+      (fun (i, optarg) ->
+        match optarg with
+        | OClosure { cbname } ->
+            if cbname = "completion" then Some i else None
+        | _ -> None)
+      (Array.mapi (fun x y -> (x, y)) optargs)
+  in
+  let completion_cb_index =
+    match completion_cb_index with
+    | Some x -> x
+    | None ->
+        failwithf
+          "The handle call %s is claimed to have a completion callback among \
+           its optional arguments by the async_kind field, but so does not \
+           seem to be the case."
+          aio_name
+  in
+  let optargs_before_completion_cb =
+    Array.to_list (Array.sub optargs 0 completion_cb_index)
+  and optargs_after_completion_cb =
+    Array.to_list
+      (Array.sub optargs (completion_cb_index + 1)
+         (Array.length optargs - (completion_cb_index + 1)))
+  in
+  (* All optional arguments excluding the completion callback. *)
+  let optargs_without_completion_cb =
+    optargs_before_completion_cb @ optargs_after_completion_cb
+  in
+  print_rust_handle_call_comment call;
+  pr "pub async fn %s(&self, %s) -> Result<(), Arc<Error>> {\n" name
+    (rust_async_handle_call_args
+       { call with optargs = optargs_without_completion_cb });
+  pr "    // A oneshot channel to notify when the call is completed.\n";
+  pr "    let (tx, rx) = oneshot::channel::<Result<(), Arc<Error>>>();\n";
+  (* Completion callback: *)
+  pr "    let %s = Some(|err: &mut i32| {\n"
+    (rust_optarg_name (Array.get optargs completion_cb_index));
+  pr "      let errno = if *err == 0 {\n";
+  pr "        tx.send(Ok(())).ok();\n";
+  pr "        return 1;\n";
+  pr "      } else { *err };\n";
+  pr "      // Spawn a task, which waits for the result of the next\n";
+  pr "      // aio_notify_* call.\n";
+  pr "      let mut res_rx = self.data.result_channel.subscribe();\n";
+  pr "      tokio::spawn(async move {\n";
+  pr "         let err = if let Ok(Err(err)) = res_rx.recv().await {\n";
+  pr "          err\n";
+  pr "        } else {\n";
+  pr "          Arc::new(Error::Recoverable(ErrorKind::from_errno(errno)))\n";
+  pr "        };\n";
+  pr "        tx.send(Err(err)).ok();\n";
+  pr "      });\n";
+  pr "      1\n";
+  pr "    });\n";
+  (* End of completion callback. *)
+  print_ffi_call aio_name "self.data.handle.handle" call;
+  pr "?;\n";
+  pr "    self.data.poll_notifier.notify_one();\n";
+  pr "    rx.await.unwrap()\n";
+  pr "}\n\n"
+
+(* Print a Rust function for an asynchronous handle call which signals
+   completion by changing state. The predicate is a call like
+   "aio_is_connecting" which should get the value (like false) for the call to
+   be complete. *)
+let print_rust_async_handle_call_changing_state name (aio_name, call)
+    (predicate, value) =
+  let value = if value then "true" else "false" in
+  print_rust_handle_call_comment call;
+  pr "pub async fn %s(&self, %s) -> Result<(), Arc<Error>>\n" name
+    (rust_async_handle_call_args call);
+  pr "{\n";
+  pr "    let mut res_rx = self.data.result_channel.subscribe();\n";
+  print_ffi_call aio_name "self.data.handle.handle" call;
+  pr "?;\n";
+  pr "    self.data.poll_notifier.notify_one();\n";
+  pr "    while self.data.handle.%s() != %s {\n" predicate value;
+  pr "      match res_rx.recv().await {\n";
+  pr "        Ok(Ok(())) | Err(RecvError::Lagged(_)) => (),\n";
+  pr "        Ok(err @ Err(_)) => return err,\n";
+  pr "        Err(RecvError::Closed) => unreachable!(),\n";
+  pr "      }\n";
+  pr "    }\n";
+  pr "    Ok(())\n";
+  pr "}\n\n"
+
+(* Print an impl with all handle calls. *)
+let print_rust_async_handle_impls () =
+  pr "impl AsyncHandle {\n";
+  NameMap.iter print_rust_sync_handle_call sync_handle_calls;
+  NameMap.iter
+    (fun name (call, async_kind) ->
+      match async_kind with
+      | WithCompletionCallback ->
+          print_rust_async_handle_call_with_completion_cb name call
+      | ChangesState (predicate, value) ->
+          print_rust_async_handle_call_changing_state name call
+            (predicate, value))
+    async_handle_calls;
+  pr "}\n\n"
+
+let print_rust_async_imports () =
+  pr "use crate::{*, types::*};\n";
+  pr "use os_str_bytes::OsStringBytes as _;\n";
+  pr "use os_socketaddr::OsSocketAddr;\n";
+  pr "use std::ffi::*;\n";
+  pr "use std::mem;\n";
+  pr "use std::net::SocketAddr;\n";
+  pr "use std::os::fd::{AsRawFd, OwnedFd};\n";
+  pr "use std::path::PathBuf;\n";
+  pr "use std::ptr;\n";
+  pr "use std::sync::Arc;\n";
+  pr "use tokio::sync::{oneshot, broadcast::error::RecvError};\n";
+  pr "\n"
+
+let generate_rust_async_bindings () =
+  generate_header CStyle ~copyright:"Tage Johansson";
+  pr "\n";
+  print_rust_async_imports ();
+  print_rust_async_handle_impls ()
diff --git a/generator/Rust.mli b/generator/Rust.mli
index 450e4ca..0960170 100644
--- a/generator/Rust.mli
+++ b/generator/Rust.mli
@@ -18,3 +18,5 @@
 
 (* Print all flag-structs, enums, constants and handle calls in Rust code. *)
 val generate_rust_bindings : unit -> unit
+
+val generate_rust_async_bindings : unit -> unit
diff --git a/generator/generator.ml b/generator/generator.ml
index 67b9502..dc9eb80 100644
--- a/generator/generator.ml
+++ b/generator/generator.ml
@@ -63,3 +63,4 @@ let () =
   output_to "golang/wrappers.h" GoLang.generate_golang_wrappers_h;
 
   output_to ~formatter:(Some Rustfmt) "rust/src/bindings.rs" Rust.generate_rust_bindings;
+  output_to ~formatter:(Some Rustfmt) "rust/src/async_bindings.rs" Rust.generate_rust_async_bindings;
diff --git a/rust/Cargo.toml b/rust/Cargo.toml
index f74c3ac..c3e27b3 100644
--- a/rust/Cargo.toml
+++ b/rust/Cargo.toml
@@ -45,6 +45,7 @@ thiserror = "1.0.40"
 log = { version = "0.4.19", optional = true }
 libc = "0.2.147"
 byte-strings = "0.3.1"
+tokio = { version = "1.29.1", default-features = false, features = ["rt", "sync", "net", "macros"] }
 
 [features]
 default = ["log"]
diff --git a/rust/Makefile.am b/rust/Makefile.am
index cc17bb9..a6fd9b1 100644
--- a/rust/Makefile.am
+++ b/rust/Makefile.am
@@ -18,6 +18,7 @@
 include $(top_srcdir)/subdir-rules.mk
 
 generator_built = \
+	src/async_bindings.rs \
 	src/bindings.rs \
 	$(NULL)
 
diff --git a/rust/src/async_handle.rs b/rust/src/async_handle.rs
new file mode 100644
index 0000000..346c4ef
--- /dev/null
+++ b/rust/src/async_handle.rs
@@ -0,0 +1,222 @@
+// nbd client library in userspace
+// Copyright Tage Johansson
+//
+// This library is free software; you can redistribute it and/or
+// modify it under the terms of the GNU Lesser General Public
+// License as published by the Free Software Foundation; either
+// version 2 of the License, or (at your option) any later version.
+//
+// This library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
+// Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public
+// License along with this library; if not, write to the Free Software
+// Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
+
+// This module implements an asynchronous handle working on top of the
+// [Tokio](https://tokio.rs) runtime. When the handle is created,
+// a "polling task" is spawned on the Tokio runtime. The purpose of that
+// "polling task" is to call `aio_notify_*` when appropriate. It shares a
+// reference to the handle as well as some channels with the handle in the
+// [HandleData] struct. The "polling task" is sleeping when no command is in
+// flight, but wakes up as soon as any command is issued.
+//
+// The commands are implemented as
+// [`async fn`s](https://doc.rust-lang.org/std/keyword.async.html)
+// in async_bindings.rs. There are two types of commands: Those with a
+// completion callback and those that wait for the handle to reach some state
+// before completing.
+//
+// The asynchronous function for a command with a completion callback will
+// setup a [oneshot channel](tokio::sync::oneshot). When the completion
+// callback is called, it will send the result of the command on the channel.
+// The asynchronous function itself will wait for this result and return when
+// it is received.
+//
+// The asynchronous function for a command which is finished when the handle
+// enters a certain state will use a while loop to wait for this state to be
+// reached. Each iteration in the loop will block until the next call to
+// `aio_notify_*`, so there is no busy waiting.
+//
+// Both for sending the result to commands with a completion callback, and for
+// waking up the while loop in state changing commands, an
+// [SPMC (single-producer, multiple-consumer) channel](tokio::sync::broadcast)
+// is used. (See [HandleData.result_channel].) The "polling task" will send
+// every result of calls to `aio_notify_*` on this channel.
+
+use crate::sys;
+use crate::Handle;
+use crate::{Error, Result};
+use crate::{AIO_DIRECTION_BOTH, AIO_DIRECTION_READ, AIO_DIRECTION_WRITE};
+use std::sync::Arc;
+use tokio::io::{unix::AsyncFd, Interest, Ready as IoReady};
+use tokio::sync::{broadcast, Notify};
+use tokio::task;
+
+/// An upper bound of the number of items in the "result_channel".
+///
+/// Shall usually never be more than one unless many commands are run concurrently.
+const RESULT_CHANNEL_CAPACITY: usize = 4;
+
+/// An NBD handle using Rust's `async` functionality on top of the
+/// [Tokio](https://docs.rs/tokio/) runtime.
+#[derive(Debug)]
+pub struct AsyncHandle {
+    /// Data shared both by this struct and the polling task.
+    pub(crate) data: Arc<HandleData>,
+
+    /// A task whose sole purpose is to poll the NBD handle.
+    polling_task: tokio::task::AbortHandle,
+}
+
+#[derive(Debug)]
+pub(crate) struct HandleData {
+    /// The underlying handle.
+    pub handle: Handle,
+
+    /// For every call to an `aio_notify_*` method (`aio_notify_read()` or
+    /// `aio_notify_write()`), the result is sent on this channel.
+    pub result_channel: broadcast::Sender<Result<(), Arc<Error>>>,
+
+    /// A notifier used by commands to notify the polling task when a new
+    /// asynchronous command is issued.
+    pub poll_notifier: Notify,
+}
+
+impl AsyncHandle {
+    pub fn new() -> Result<Self> {
+        let handle_data = Arc::new(HandleData {
+            handle: Handle::new()?,
+            poll_notifier: Notify::new(),
+            result_channel: broadcast::channel(RESULT_CHANNEL_CAPACITY).0,
+        });
+
+        let handle_data_2 = handle_data.clone();
+        let polling_task = task::spawn(async move {
+            // The polling task should never finish without an error. If the
+            // handle is dropped, the task is aborted so it'll not return in
+            // that case either.
+            let Err(err) = polling_task(&handle_data_2).await else {
+                unreachable!()
+            };
+            // Send the error as the last thing on the result channel.
+            let err = Arc::new(err);
+            handle_data_2.result_channel.send(Err(err)).ok();
+        })
+        .abort_handle();
+        Ok(Self {
+            data: handle_data,
+            polling_task,
+        })
+    }
+
+    /// Get the underlying C pointer to the handle.
+    pub(crate) fn raw_handle(&self) -> *mut sys::nbd_handle {
+        self.data.handle.raw_handle()
+    }
+}
+
+/// Get the read/write direction that the handle wants on the file descriptor.
+fn get_fd_interest(handle: &Handle) -> Option<Interest> {
+    match handle.aio_get_direction() {
+        0 => None,
+        AIO_DIRECTION_READ => Some(Interest::READABLE),
+        AIO_DIRECTION_WRITE => Some(Interest::WRITABLE),
+        AIO_DIRECTION_BOTH => Some(Interest::READABLE | Interest::WRITABLE),
+        _ => unreachable!(),
+    }
+}
+
+/// A task that will run as long as the handle is alive. It will poll the
+/// file descriptor when new data is available.
+async fn polling_task(handle_data: &HandleData) -> Result<()> {
+    let HandleData {
+        handle,
+        result_channel,
+        poll_notifier,
+    } = handle_data;
+    // XXX: Might the file descriptor ever be changed?
+    let fd = handle.aio_get_fd()?;
+    let fd = AsyncFd::new(fd)?;
+
+    // The following loop does approximately the following things:
+    //
+    // 1. Determine what Libnbd wants to do next on the file descriptor,
+    //    (read/write/both/none), and store that in [next_fd_interest].
+    // 2. Wait for either:
+    //   a) That interest to be available on the file descriptor in which case:
+    //     I.   Call the correct `aio_notify_*` method.
+    //     II.  Execute step 1.
+    //     III. Send the result of the call to `aio_notify_*` on
+    //          [result_channel] to notify pending commands that some progress
+    //          has been made.
+    //     IV.  Resume execution from step 2.
+    //   b) A notification was received on [poll_notifier] signaling that a new
+    //      command was registered and that the interest on the file descriptor
+    //      might have changed. Resume execution from step 1.
+    let mut next_fd_interest = get_fd_interest(handle);
+    loop {
+        let Some(fd_interest) = next_fd_interest else {
+            // The handle does not wait for any data of the file descriptor,
+            // so we wait until some command is issued.
+            poll_notifier.notified().await;
+            next_fd_interest = get_fd_interest(handle);
+            continue;
+        };
+
+        let res = tokio::select! {
+            ready_guard = fd.ready(fd_interest) => {
+                // Some new data is available.
+                let mut ready_guard = ready_guard?;
+                let readyness = ready_guard.ready();
+                if readyness.is_readable() && fd_interest.is_readable() {
+                    let res = match handle.aio_notify_read() {
+                        res @ Ok(_)
+                        | res @ Err(Error::Recoverable(_)) => res,
+                        err @ Err(Error::Fatal(_)) => return err,
+                    };
+                    next_fd_interest = get_fd_interest(handle);
+                    // We do only know that the read blocked if the next
+                    // interest contains a read as well.
+                    if next_fd_interest.map_or(false, Interest::is_readable) {
+                        ready_guard.clear_ready_matching(IoReady::READABLE);
+                    } else {
+                        ready_guard.retain_ready();
+                    }
+                    res
+                }
+                else if readyness.is_writable() && fd_interest.is_writable() {
+                    let res = match handle.aio_notify_write() {
+                        res @ Ok(_)
+                        | res @ Err(Error::Recoverable(_)) => res,
+                        err @ Err(Error::Fatal(_)) => return err,
+                    };
+                    next_fd_interest = get_fd_interest(handle);
+                    if next_fd_interest.map_or(false, Interest::is_writable) {
+                        ready_guard.clear_ready_matching(IoReady::WRITABLE);
+                    } else {
+                        ready_guard.retain_ready();
+                    }
+                    res
+                } else {
+                    continue;
+                }
+            }
+            _ = poll_notifier.notified() => {
+                // Someone issued a command so the interest might have changed.
+                next_fd_interest = get_fd_interest(handle);
+                continue;
+            }
+        };
+
+        result_channel.send(res.map_err(Arc::new)).ok();
+    }
+}
+
+impl Drop for AsyncHandle {
+    fn drop(&mut self) {
+        self.polling_task.abort();
+    }
+}
diff --git a/rust/src/lib.rs b/rust/src/lib.rs
index a6f3131..eb3f6cb 100644
--- a/rust/src/lib.rs
+++ b/rust/src/lib.rs
@@ -17,11 +17,15 @@
 
 #![deny(warnings)]
 
+mod async_bindings;
+mod async_handle;
 mod bindings;
 mod error;
 mod handle;
 pub mod types;
 mod utils;
+pub use async_bindings::*;
+pub use async_handle::AsyncHandle;
 pub use bindings::*;
 pub use error::{Error, ErrorKind, FatalErrorKind, Result};
 pub use handle::Handle;
-- 
2.41.0



More information about the Libguestfs mailing list