From 55c6a3b5cd92e1e0b17c396108ea2209fda0b3ee Mon Sep 17 00:00:00 2001 From: Dennis Kempin Date: Thu, 17 Mar 2022 11:26:01 -0700 Subject: [PATCH] Refactoring: Move common/cros_async to cros_async This runs the script added in https://crrev.com/c/3533607 BUG=b:22320646 TEST=presubmit Change-Id: I2e7efdb35508d45281f046e64c24aa43e27f2000 Reviewed-on: https://chromium-review.googlesource.com/c/chromiumos/platform/crosvm/+/3533608 Reviewed-by: Daniel Verkamp Tested-by: kokoro --- Cargo.toml | 3 +- base/Cargo.toml | 2 +- cros_async/Cargo.toml | 36 + cros_async/DEPRECATED.md | 4 + cros_async/src/audio_streams_async.rs | 68 + cros_async/src/blocking.rs | 9 + cros_async/src/blocking/block_on.rs | 202 +++ cros_async/src/blocking/pool.rs | 524 ++++++ cros_async/src/complete.rs | 91 + cros_async/src/event.rs | 85 + cros_async/src/executor.rs | 344 ++++ cros_async/src/fd_executor.rs | 634 +++++++ cros_async/src/io_ext.rs | 479 +++++ cros_async/src/lib.rs | 540 ++++++ cros_async/src/mem.rs | 98 ++ cros_async/src/poll_source.rs | 450 +++++ cros_async/src/queue.rs | 66 + cros_async/src/select.rs | 92 + cros_async/src/sync.rs | 12 + cros_async/src/sync/cv.rs | 1179 +++++++++++++ cros_async/src/sync/mu.rs | 2305 +++++++++++++++++++++++++ cros_async/src/sync/spin.rs | 284 +++ cros_async/src/sync/waiter.rs | 288 +++ cros_async/src/timer.rs | 126 ++ cros_async/src/uring_executor.rs | 1151 ++++++++++++ cros_async/src/uring_source.rs | 643 +++++++ cros_async/src/waker.rs | 70 + devices/Cargo.toml | 2 +- disk/Cargo.toml | 2 +- net_util/Cargo.toml | 2 +- vm_memory/Cargo.toml | 2 +- 31 files changed, 9787 insertions(+), 6 deletions(-) create mode 100644 cros_async/Cargo.toml create mode 100644 cros_async/DEPRECATED.md create mode 100644 cros_async/src/audio_streams_async.rs create mode 100644 cros_async/src/blocking.rs create mode 100644 cros_async/src/blocking/block_on.rs create mode 100644 cros_async/src/blocking/pool.rs create mode 100644 cros_async/src/complete.rs create mode 100644 cros_async/src/event.rs create mode 100644 cros_async/src/executor.rs create mode 100644 cros_async/src/fd_executor.rs create mode 100644 cros_async/src/io_ext.rs create mode 100644 cros_async/src/lib.rs create mode 100644 cros_async/src/mem.rs create mode 100644 cros_async/src/poll_source.rs create mode 100644 cros_async/src/queue.rs create mode 100644 cros_async/src/select.rs create mode 100644 cros_async/src/sync.rs create mode 100644 cros_async/src/sync/cv.rs create mode 100644 cros_async/src/sync/mu.rs create mode 100644 cros_async/src/sync/spin.rs create mode 100644 cros_async/src/sync/waiter.rs create mode 100644 cros_async/src/timer.rs create mode 100644 cros_async/src/uring_executor.rs create mode 100644 cros_async/src/uring_source.rs create mode 100644 cros_async/src/waker.rs diff --git a/Cargo.toml b/Cargo.toml index 7b54f9aa3e..62b35dd86c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -45,6 +45,7 @@ members = [ "arch", "base", "bit_field", + "cros_async", "crosvm-fuzz", "crosvm_control", "crosvm_plugin", @@ -192,7 +193,7 @@ assertions = { path = "common/assertions" } audio_streams = { path = "common/audio_streams" } base = { path = "base" } sys_util_core = { path = "common/sys_util_core" } -cros_async = { path = "common/cros_async" } +cros_async = { path = "cros_async" } cros_fuzz = { path = "common/cros-fuzz" } # ignored by ebuild data_model = { path = "common/data_model" } libcras = { path = "libcras_stub" } # ignored by ebuild diff --git a/base/Cargo.toml b/base/Cargo.toml index 08d0d5342a..8deacd0d6a 100644 --- 
a/base/Cargo.toml +++ b/base/Cargo.toml @@ -9,7 +9,7 @@ chromeos = ["sys_util/chromeos"] [dependencies] audio_streams = { path = "../common/audio_streams" } -cros_async = { path = "../common/cros_async" } +cros_async = { path = "../cros_async" } data_model = { path = "../common/data_model" } libc = "*" remain = "0.2" diff --git a/cros_async/Cargo.toml b/cros_async/Cargo.toml new file mode 100644 index 0000000000..7f3e924a7a --- /dev/null +++ b/cros_async/Cargo.toml @@ -0,0 +1,36 @@ +[package] +name = "cros_async" +version = "0.1.0" +authors = ["The Chromium OS Authors"] +edition = "2021" + +[dependencies] +async-trait = "0.1.36" +async-task = "4" +data_model = { path = "../common/data_model" } # provided by ebuild +intrusive-collections = "0.9" +io_uring = { path = "../common/io_uring" } # provided by ebuild +libc = "*" +once_cell = "1.7.2" +paste = "1.0" +pin-utils = "0.1.0-alpha.4" +remain = "0.2" +slab = "0.4" +sync = { path = "../common/sync" } # provided by ebuild +sys_util = { path = "../common/sys_util" } # provided by ebuild +thiserror = "1.0.20" +audio_streams = { path = "../common/audio_streams" } # provided by ebuild +anyhow = "1.0" + +[dependencies.futures] +version = "*" +default-features = false +features = ["alloc"] + +[dev-dependencies] +futures = { version = "*", features = ["executor"] } +futures-executor = { version = "0.3", features = ["thread-pool"] } +futures-util = "0.3" +tempfile = "3" + + diff --git a/cros_async/DEPRECATED.md b/cros_async/DEPRECATED.md new file mode 100644 index 0000000000..77e18642d1 --- /dev/null +++ b/cros_async/DEPRECATED.md @@ -0,0 +1,4 @@ +Use crosvm/cros_async instead. + +Code in this directory is not used by crosvm, it is only used in ChromeOS and will move to a +separate ChromeOS repository soon. diff --git a/cros_async/src/audio_streams_async.rs b/cros_async/src/audio_streams_async.rs new file mode 100644 index 0000000000..7350066585 --- /dev/null +++ b/cros_async/src/audio_streams_async.rs @@ -0,0 +1,68 @@ +// Copyright 2022 The Chromium OS Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +//! Implements the interface required by `audio_streams` using the cros_async Executor. +//! +//! It implements the `AudioStreamsExecutor` trait for `Executor`, so it can be passed into +//! the audio_streams API. +#[cfg(unix)] +use std::os::unix::net::UnixStream; + +use std::{io::Result, time::Duration}; + +use super::{AsyncWrapper, IntoAsync, IoSourceExt, TimerAsync}; +use async_trait::async_trait; +use audio_streams::async_api::{ + AsyncStream, AudioStreamsExecutor, ReadAsync, ReadWriteAsync, WriteAsync, +}; + +/// A wrapper around IoSourceExt that is compatible with the audio_streams traits. 
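A minimal sketch of what the module docs above describe: with the `AudioStreamsExecutor` trait in scope, a cros_async `Executor` can be handed to audio_streams code and drive an async delay. `pause_briefly` is a hypothetical helper added here only for illustration, not part of the patch.

```rust
use std::time::Duration;

use audio_streams::async_api::AudioStreamsExecutor;
use cros_async::Executor;

// Hypothetical helper: `delay` resolves to the trait implementation added in
// audio_streams_async.rs, which sleeps via TimerAsync under the hood.
async fn pause_briefly(ex: &Executor) -> std::io::Result<()> {
    ex.delay(Duration::from_millis(10)).await
}
```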
+pub struct IoSourceWrapper { + source: Box + Send>, +} + +#[async_trait(?Send)] +impl ReadAsync for IoSourceWrapper { + async fn read_to_vec<'a>( + &'a self, + file_offset: Option, + vec: Vec, + ) -> Result<(usize, Vec)> { + self.source + .read_to_vec(file_offset, vec) + .await + .map_err(Into::into) + } +} + +#[async_trait(?Send)] +impl WriteAsync for IoSourceWrapper { + async fn write_from_vec<'a>( + &'a self, + file_offset: Option, + vec: Vec, + ) -> Result<(usize, Vec)> { + self.source + .write_from_vec(file_offset, vec) + .await + .map_err(Into::into) + } +} + +#[async_trait(?Send)] +impl ReadWriteAsync for IoSourceWrapper {} + +#[async_trait(?Send)] +impl AudioStreamsExecutor for super::Executor { + #[cfg(unix)] + fn async_unix_stream(&self, stream: UnixStream) -> Result { + return Ok(Box::new(IoSourceWrapper { + source: self.async_from(AsyncWrapper::new(stream))?, + })); + } + + async fn delay(&self, dur: Duration) -> Result<()> { + TimerAsync::sleep(self, dur).await.map_err(Into::into) + } +} diff --git a/cros_async/src/blocking.rs b/cros_async/src/blocking.rs new file mode 100644 index 0000000000..f6430a78b7 --- /dev/null +++ b/cros_async/src/blocking.rs @@ -0,0 +1,9 @@ +// Copyright 2021 The Chromium OS Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +mod block_on; +mod pool; + +pub use block_on::*; +pub use pool::*; diff --git a/cros_async/src/blocking/block_on.rs b/cros_async/src/blocking/block_on.rs new file mode 100644 index 0000000000..2cb0b1e24e --- /dev/null +++ b/cros_async/src/blocking/block_on.rs @@ -0,0 +1,202 @@ +// Copyright 2020 The Chromium OS Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +use std::{ + future::Future, + ptr, + sync::{ + atomic::{AtomicI32, Ordering}, + Arc, + }, + task::{Context, Poll}, +}; + +use futures::{ + pin_mut, + task::{waker_ref, ArcWake}, +}; + +// Randomly generated values to indicate the state of the current thread. +const WAITING: i32 = 0x25de_74d1; +const WOKEN: i32 = 0x72d3_2c9f; + +const FUTEX_WAIT_PRIVATE: libc::c_int = libc::FUTEX_WAIT | libc::FUTEX_PRIVATE_FLAG; +const FUTEX_WAKE_PRIVATE: libc::c_int = libc::FUTEX_WAKE | libc::FUTEX_PRIVATE_FLAG; + +thread_local!(static PER_THREAD_WAKER: Arc = Arc::new(Waker(AtomicI32::new(WAITING)))); + +#[repr(transparent)] +struct Waker(AtomicI32); + +extern "C" { + #[cfg_attr(target_os = "android", link_name = "__errno")] + #[cfg_attr(target_os = "linux", link_name = "__errno_location")] + fn errno_location() -> *mut libc::c_int; +} + +impl ArcWake for Waker { + fn wake_by_ref(arc_self: &Arc) { + let state = arc_self.0.swap(WOKEN, Ordering::Release); + if state == WAITING { + // The thread hasn't already been woken up so wake it up now. Safe because this doesn't + // modify any memory and we check the return value. + let res = unsafe { + libc::syscall( + libc::SYS_futex, + &arc_self.0, + FUTEX_WAKE_PRIVATE, + libc::INT_MAX, // val + ptr::null() as *const libc::timespec, // timeout + ptr::null() as *const libc::c_int, // uaddr2 + 0_i32, // val3 + ) + }; + if res < 0 { + panic!("unexpected error from FUTEX_WAKE_PRIVATE: {}", unsafe { + *errno_location() + }); + } + } + } +} + +/// Run a future to completion on the current thread. +/// +/// This method will block the current thread until `f` completes. Useful when you need to call an +/// async fn from a non-async context. 
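A minimal usage sketch of `block_on` as documented above: calling an async fn from synchronous code. `answer` is a hypothetical future used only for illustration.

```rust
use cros_async::block_on;

// Hypothetical async computation; any future works here.
async fn answer() -> u32 {
    6 * 7
}

fn main() {
    // Drives the future to completion on the calling thread.
    let value = block_on(answer());
    assert_eq!(value, 42);
}
```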
+pub fn block_on(f: F) -> F::Output { + pin_mut!(f); + + PER_THREAD_WAKER.with(|thread_waker| { + let waker = waker_ref(thread_waker); + let mut cx = Context::from_waker(&waker); + + loop { + if let Poll::Ready(t) = f.as_mut().poll(&mut cx) { + return t; + } + + let state = thread_waker.0.swap(WAITING, Ordering::Acquire); + if state == WAITING { + // If we weren't already woken up then wait until we are. Safe because this doesn't + // modify any memory and we check the return value. + let res = unsafe { + libc::syscall( + libc::SYS_futex, + &thread_waker.0, + FUTEX_WAIT_PRIVATE, + state, + ptr::null() as *const libc::timespec, // timeout + ptr::null() as *const libc::c_int, // uaddr2 + 0_i32, // val3 + ) + }; + + if res < 0 { + // Safe because libc guarantees that this is a valid pointer. + match unsafe { *errno_location() } { + libc::EAGAIN | libc::EINTR => {} + e => panic!("unexpected error from FUTEX_WAIT_PRIVATE: {}", e), + } + } + + // Clear the state to prevent unnecessary extra loop iterations and also to allow + // nested usage of `block_on`. + thread_waker.0.store(WAITING, Ordering::Release); + } + } + }) +} + +#[cfg(test)] +mod test { + use super::*; + + use std::{ + future::Future, + pin::Pin, + sync::{ + mpsc::{channel, Sender}, + Arc, + }, + task::{Context, Poll, Waker}, + thread, + time::Duration, + }; + + use super::super::super::sync::SpinLock; + + struct TimerState { + fired: bool, + waker: Option, + } + struct Timer { + state: Arc>, + } + + impl Future for Timer { + type Output = (); + + fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll { + let mut state = self.state.lock(); + if state.fired { + return Poll::Ready(()); + } + + state.waker = Some(cx.waker().clone()); + Poll::Pending + } + } + + fn start_timer(dur: Duration, notify: Option>) -> Timer { + let state = Arc::new(SpinLock::new(TimerState { + fired: false, + waker: None, + })); + + let thread_state = Arc::clone(&state); + thread::spawn(move || { + thread::sleep(dur); + let mut ts = thread_state.lock(); + ts.fired = true; + if let Some(waker) = ts.waker.take() { + waker.wake(); + } + drop(ts); + + if let Some(tx) = notify { + tx.send(()).expect("Failed to send completion notification"); + } + }); + + Timer { state } + } + + #[test] + fn it_works() { + block_on(start_timer(Duration::from_millis(100), None)); + } + + #[test] + fn nested() { + async fn inner() { + block_on(start_timer(Duration::from_millis(100), None)); + } + + block_on(inner()); + } + + #[test] + fn ready_before_poll() { + let (tx, rx) = channel(); + + let timer = start_timer(Duration::from_millis(50), Some(tx)); + + rx.recv() + .expect("Failed to receive completion notification"); + + // We know the timer has already fired so the poll should complete immediately. + block_on(timer); + } +} diff --git a/cros_async/src/blocking/pool.rs b/cros_async/src/blocking/pool.rs new file mode 100644 index 0000000000..4bf45020ee --- /dev/null +++ b/cros_async/src/blocking/pool.rs @@ -0,0 +1,524 @@ +// Copyright 2021 The Chromium OS Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. 
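The futex-based per-thread waker in block_on.rs can be dense on first read. The following is a simplified sketch of the same WAITING/WOKEN handshake using `std::thread::park`/`unpark` in place of the raw `futex(2)` calls; it is illustrative only and not how this crate implements it.

```rust
use std::future::Future;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use std::task::{Context, Poll, Wake, Waker};
use std::thread::{self, Thread};

// Stand-in for the WAITING/WOKEN atomic plus futex: `woken` is the flag,
// park/unpark replace FUTEX_WAIT/FUTEX_WAKE.
struct ThreadWaker {
    woken: AtomicBool,
    thread: Thread,
}

impl Wake for ThreadWaker {
    fn wake(self: Arc<Self>) {
        self.wake_by_ref();
    }

    fn wake_by_ref(self: &Arc<Self>) {
        // Only unpark if the thread hasn't already been flagged as woken,
        // mirroring the WAITING -> WOKEN transition.
        if !self.woken.swap(true, Ordering::Release) {
            self.thread.unpark();
        }
    }
}

fn simple_block_on<F: Future>(f: F) -> F::Output {
    let mut f = Box::pin(f);
    let waker_state = Arc::new(ThreadWaker {
        woken: AtomicBool::new(false),
        thread: thread::current(),
    });
    let waker = Waker::from(waker_state.clone());
    let mut cx = Context::from_waker(&waker);

    loop {
        if let Poll::Ready(v) = f.as_mut().poll(&mut cx) {
            return v;
        }
        // Sleep until the waker fires, tolerating spurious unparks.
        while !waker_state.woken.swap(false, Ordering::Acquire) {
            thread::park();
        }
    }
}
```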
+ +use std::{ + collections::VecDeque, + mem, + sync::{ + mpsc::{channel, Receiver, Sender}, + Arc, + }, + thread::{ + JoinHandle, {self}, + }, + time::{Duration, Instant}, +}; + +use async_task::{Runnable, Task}; +use slab::Slab; +use sync::{Condvar, Mutex}; +use sys_util::{error, warn}; + +const DEFAULT_SHUTDOWN_TIMEOUT: Duration = Duration::from_secs(10); + +struct State { + tasks: VecDeque, + num_threads: usize, + num_idle: usize, + num_notified: usize, + worker_threads: Slab>, + exited_threads: Option>, + exit: Sender, + shutting_down: bool, +} + +fn run_blocking_thread(idx: usize, inner: Arc, exit: Sender) { + let mut state = inner.state.lock(); + while !state.shutting_down { + if let Some(runnable) = state.tasks.pop_front() { + drop(state); + runnable.run(); + state = inner.state.lock(); + continue; + } + + // No more tasks so wait for more work. + state.num_idle += 1; + + let (guard, result) = inner + .condvar + .wait_timeout_while(state, inner.keepalive, |s| { + !s.shutting_down && s.num_notified == 0 + }); + state = guard; + + // If `state.num_notified > 0` then this was a real wakeup. + if state.num_notified > 0 { + state.num_notified -= 1; + continue; + } + + // Only decrement the idle count if we timed out. Otherwise, it was decremented when new + // work was added to `state.tasks`. + if result.timed_out() { + state.num_idle = state + .num_idle + .checked_sub(1) + .expect("`num_idle` underflow on timeout"); + break; + } + } + + state.num_threads -= 1; + + // If we're shutting down then the BlockingPool will take care of joining all the threads. + // Otherwise, we need to join the last worker thread that exited here. + let last_exited_thread = if let Some(exited_threads) = state.exited_threads.as_mut() { + exited_threads + .try_recv() + .map(|idx| state.worker_threads.remove(idx)) + .ok() + } else { + None + }; + + // Drop the lock before trying to join the last exited thread. + drop(state); + + if let Some(handle) = last_exited_thread { + let _ = handle.join(); + } + + if let Err(e) = exit.send(idx) { + error!("Failed to send thread exit event on channel: {}", e); + } +} + +struct Inner { + state: Mutex, + condvar: Condvar, + max_threads: usize, + keepalive: Duration, +} + +impl Inner { + fn schedule(self: &Arc, runnable: Runnable) { + let mut state = self.state.lock(); + + // If we're shutting down then nothing is going to run this task. + if state.shutting_down { + return; + } + + state.tasks.push_back(runnable); + + if state.num_idle == 0 { + // There are no idle threads. Spawn a new one if possible. + if state.num_threads < self.max_threads { + state.num_threads += 1; + let exit = state.exit.clone(); + let entry = state.worker_threads.vacant_entry(); + let idx = entry.key(); + let inner = self.clone(); + entry.insert( + thread::Builder::new() + .name(format!("blockingPool{}", idx)) + .spawn(move || run_blocking_thread(idx, inner, exit)) + .unwrap(), + ); + } + } else { + // We have idle threads, wake one up. + state.num_idle -= 1; + state.num_notified += 1; + self.condvar.notify_one(); + } + } +} + +#[derive(Debug, thiserror::Error)] +#[error("{0} BlockingPool threads did not exit in time and will be detached")] +pub struct ShutdownTimedOut(usize); + +/// A thread pool for running work that may block. +/// +/// It is generally discouraged to do any blocking work inside an async function. However, this is +/// sometimes unavoidable when dealing with interfaces that don't provide async variants. 
In this +/// case callers may use the `BlockingPool` to run the blocking work on a different thread and +/// `await` for its result to finish, which will prevent blocking the main thread of the +/// application. +/// +/// Since the blocking work is sent to another thread, users should be careful when using the +/// `BlockingPool` for latency-sensitive operations. Additionally, the `BlockingPool` is intended to +/// be used for work that will eventually complete on its own. Users who want to spawn a thread +/// should just use `thread::spawn` directly. +/// +/// There is no way to cancel work once it has been picked up by one of the worker threads in the +/// `BlockingPool`. Dropping or shutting down the pool will block up to a timeout (default 10 +/// seconds) to wait for any active blocking work to finish. Any threads running tasks that have not +/// completed by that time will be detached. +/// +/// # Examples +/// +/// Spawn a task to run in the `BlockingPool` and await on its result. +/// +/// ```edition2018 +/// use cros_async::BlockingPool; +/// +/// # async fn do_it() { +/// let pool = BlockingPool::default(); +/// +/// let res = pool.spawn(move || { +/// // Do some CPU-intensive or blocking work here. +/// +/// 42 +/// }).await; +/// +/// assert_eq!(res, 42); +/// # } +/// # cros_async::block_on(do_it()); +/// ``` +pub struct BlockingPool { + inner: Arc, +} + +impl BlockingPool { + /// Create a new `BlockingPool`. + /// + /// The `BlockingPool` will never spawn more than `max_threads` threads to do work, regardless + /// of the number of tasks that are added to it. This value should be set relatively low (for + /// example, the number of CPUs on the machine) if the pool is intended to run CPU intensive + /// work or it should be set relatively high (128 or more) if the pool is intended to be used + /// for various IO operations that cannot be completed asynchronously. The default value is 256. + /// + /// Worker threads are spawned on demand when new work is added to the pool and will + /// automatically exit after being idle for some time so there is no overhead for setting + /// `max_threads` to a large value when there is little to no work assigned to the + /// `BlockingPool`. `keepalive` determines the idle duration after which the worker thread will + /// exit. The default value is 10 seconds. + pub fn new(max_threads: usize, keepalive: Duration) -> BlockingPool { + let (exit, exited_threads) = channel(); + BlockingPool { + inner: Arc::new(Inner { + state: Mutex::new(State { + tasks: VecDeque::new(), + num_threads: 0, + num_idle: 0, + num_notified: 0, + worker_threads: Slab::new(), + exited_threads: Some(exited_threads), + exit, + shutting_down: false, + }), + condvar: Condvar::new(), + max_threads, + keepalive, + }), + } + } + + /// Like new but with pre-allocating capacity for up to `max_threads`. + pub fn with_capacity(max_threads: usize, keepalive: Duration) -> BlockingPool { + let (exit, exited_threads) = channel(); + BlockingPool { + inner: Arc::new(Inner { + state: Mutex::new(State { + tasks: VecDeque::new(), + num_threads: 0, + num_idle: 0, + num_notified: 0, + worker_threads: Slab::with_capacity(max_threads), + exited_threads: Some(exited_threads), + exit, + shutting_down: false, + }), + condvar: Condvar::new(), + max_threads, + keepalive, + }), + } + } + + /// Spawn a task to run in the `BlockingPool`. + /// + /// Callers may `await` the returned `Task` to be notified when the work is completed. 
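As a concrete illustration of the sizing guidance in `BlockingPool::new` above (the numbers here are arbitrary, not recommendations from the patch):

```rust
use std::time::Duration;

use cros_async::BlockingPool;

// A small pool for CPU-bound work: few threads, short keepalive.
fn cpu_pool() -> BlockingPool {
    BlockingPool::new(8, Duration::from_secs(5))
}

// A larger pool for blocking I/O that cannot be made async.
fn io_pool() -> BlockingPool {
    BlockingPool::new(256, Duration::from_secs(10))
}
```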
+ /// + /// # Panics + /// + /// `await`ing a `Task` after dropping the `BlockingPool` or calling `BlockingPool::shutdown` + /// will panic if the work was not completed before the pool was shut down. + pub fn spawn(&self, f: F) -> Task + where + F: FnOnce() -> R + Send + 'static, + R: Send + 'static, + { + let raw = Arc::downgrade(&self.inner); + let schedule = move |runnable| { + if let Some(i) = raw.upgrade() { + i.schedule(runnable); + } + }; + + let (runnable, task) = async_task::spawn(async move { f() }, schedule); + runnable.schedule(); + + task + } + + /// Shut down the `BlockingPool`. + /// + /// If `deadline` is provided then this will block until either all worker threads exit or the + /// deadline is exceeded. If `deadline` is not given then this will block indefinitely until all + /// worker threads exit. Any work that was added to the `BlockingPool` but not yet picked up by + /// a worker thread will not complete and `await`ing on the `Task` for that work will panic. + pub fn shutdown(&self, deadline: Option) -> Result<(), ShutdownTimedOut> { + let mut state = self.inner.state.lock(); + + if state.shutting_down { + // We've already shut down this BlockingPool. + return Ok(()); + } + + state.shutting_down = true; + let exited_threads = state.exited_threads.take().expect("exited_threads missing"); + let unfinished_tasks = std::mem::take(&mut state.tasks); + let mut worker_threads = mem::replace(&mut state.worker_threads, Slab::new()); + drop(state); + + self.inner.condvar.notify_all(); + + // Cancel any unfinished work after releasing the lock. + drop(unfinished_tasks); + + // Now wait for all worker threads to exit. + if let Some(deadline) = deadline { + let mut now = Instant::now(); + while now < deadline && !worker_threads.is_empty() { + if let Ok(idx) = exited_threads.recv_timeout(deadline - now) { + let _ = worker_threads.remove(idx).join(); + } + now = Instant::now(); + } + + // Any threads that have not yet joined will just be detached. + if !worker_threads.is_empty() { + return Err(ShutdownTimedOut(worker_threads.len())); + } + + Ok(()) + } else { + // Block indefinitely until all worker threads exit. + for handle in worker_threads.drain() { + let _ = handle.join(); + } + + Ok(()) + } + } +} + +impl Default for BlockingPool { + fn default() -> BlockingPool { + BlockingPool::new(256, Duration::from_secs(10)) + } +} + +impl Drop for BlockingPool { + fn drop(&mut self) { + if let Err(e) = self.shutdown(Some(Instant::now() + DEFAULT_SHUTDOWN_TIMEOUT)) { + warn!("{}", e); + } + } +} + +#[cfg(test)] +mod test { + use std::{ + sync::{Arc, Barrier}, + thread, + time::{Duration, Instant}, + }; + + use futures::{stream::FuturesUnordered, StreamExt}; + use sync::{Condvar, Mutex}; + + use super::super::super::{block_on, BlockingPool}; + + #[test] + fn blocking_sleep() { + let pool = BlockingPool::default(); + + let res = block_on(pool.spawn(|| 42)); + assert_eq!(res, 42); + } + + #[test] + fn fast_tasks_with_short_keepalive() { + let pool = BlockingPool::new(256, Duration::from_millis(1)); + + let streams = FuturesUnordered::new(); + for _ in 0..2 { + for _ in 0..256 { + let task = pool.spawn(|| ()); + streams.push(task); + } + + thread::sleep(Duration::from_millis(1)); + } + + block_on(streams.collect::>()); + + // The test passes if there are no panics, which would happen if one of the worker threads + // triggered an underflow on `pool.inner.state.num_idle`. 
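A minimal sketch of the bounded shutdown described above; the five-second deadline is arbitrary.

```rust
use std::time::{Duration, Instant};

use cros_async::BlockingPool;

fn shutdown_with_deadline(pool: &BlockingPool) {
    // Give outstanding blocking work up to five seconds; workers still running
    // after the deadline are detached and reported via ShutdownTimedOut.
    if let Err(e) = pool.shutdown(Some(Instant::now() + Duration::from_secs(5))) {
        eprintln!("BlockingPool shutdown timed out: {}", e);
    }
}
```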
+ } + + #[test] + fn more_tasks_than_threads() { + let pool = BlockingPool::new(4, Duration::from_secs(10)); + + let stream = (0..19) + .map(|_| pool.spawn(|| thread::sleep(Duration::from_millis(5)))) + .collect::>(); + + let results = block_on(stream.collect::>()); + assert_eq!(results.len(), 19); + } + + #[test] + fn shutdown() { + let pool = BlockingPool::default(); + + let stream = (0..19) + .map(|_| pool.spawn(|| thread::sleep(Duration::from_millis(5)))) + .collect::>(); + + let results = block_on(stream.collect::>()); + assert_eq!(results.len(), 19); + + pool.shutdown(Some(Instant::now() + Duration::from_secs(10))) + .unwrap(); + let state = pool.inner.state.lock(); + assert_eq!(state.num_threads, 0); + } + + #[test] + fn keepalive_timeout() { + // Set the keepalive to a very low value so that threads will exit soon after they run out + // of work. + let pool = BlockingPool::new(7, Duration::from_millis(1)); + + let stream = (0..19) + .map(|_| pool.spawn(|| thread::sleep(Duration::from_millis(5)))) + .collect::>(); + + let results = block_on(stream.collect::>()); + assert_eq!(results.len(), 19); + + // Wait for all threads to exit. + let deadline = Instant::now() + Duration::from_secs(10); + while Instant::now() < deadline { + thread::sleep(Duration::from_millis(100)); + let state = pool.inner.state.lock(); + if state.num_threads == 0 { + break; + } + } + + { + let state = pool.inner.state.lock(); + assert_eq!(state.num_threads, 0); + assert_eq!(state.num_idle, 0); + } + } + + #[test] + #[should_panic] + fn shutdown_with_pending_work() { + let pool = BlockingPool::new(1, Duration::from_secs(10)); + + let mu = Arc::new(Mutex::new(false)); + let cv = Arc::new(Condvar::new()); + + // First spawn a thread that blocks the pool. + let task_mu = mu.clone(); + let task_cv = cv.clone(); + pool.spawn(move || { + let mut ready = task_mu.lock(); + while !*ready { + ready = task_cv.wait(ready); + } + }) + .detach(); + + // This task will never finish because we will shut down the pool first. + let unfinished = pool.spawn(|| 5); + + // Spawn a thread to unblock the work we started earlier once it sees that the pool is + // shutting down. + let inner = pool.inner.clone(); + thread::spawn(move || { + let mut state = inner.state.lock(); + while !state.shutting_down { + state = inner.condvar.wait(state); + } + + *mu.lock() = true; + cv.notify_all(); + }); + pool.shutdown(None).unwrap(); + + // This should panic. + assert_eq!(block_on(unfinished), 5); + } + + #[test] + fn unfinished_worker_thread() { + let pool = BlockingPool::default(); + + let ready = Arc::new(Mutex::new(false)); + let cv = Arc::new(Condvar::new()); + let barrier = Arc::new(Barrier::new(2)); + + let thread_ready = ready.clone(); + let thread_barrier = barrier.clone(); + let thread_cv = cv.clone(); + + let task = pool.spawn(move || { + thread_barrier.wait(); + let mut ready = thread_ready.lock(); + while !*ready { + ready = thread_cv.wait(ready); + } + }); + + // Wait to shut down the pool until after the worker thread has started. + barrier.wait(); + pool.shutdown(Some(Instant::now() + Duration::from_millis(5))) + .unwrap_err(); + + let num_threads = pool.inner.state.lock().num_threads; + assert_eq!(num_threads, 1); + + // Now wake up the blocked task so we don't leak the thread. 
+ *ready.lock() = true; + cv.notify_all(); + + block_on(task); + + let deadline = Instant::now() + Duration::from_secs(10); + while Instant::now() < deadline { + thread::sleep(Duration::from_millis(100)); + let state = pool.inner.state.lock(); + if state.num_threads == 0 { + break; + } + } + + { + let state = pool.inner.state.lock(); + assert_eq!(state.num_threads, 0); + assert_eq!(state.num_idle, 0); + } + } +} diff --git a/cros_async/src/complete.rs b/cros_async/src/complete.rs new file mode 100644 index 0000000000..9fb9273f41 --- /dev/null +++ b/cros_async/src/complete.rs @@ -0,0 +1,91 @@ +// Copyright 2020 The Chromium OS Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +// Need non-snake case so the macro can re-use type names for variables. +#![allow(non_snake_case)] + +use std::{ + future::Future, + pin::Pin, + task::{Context, Poll}, +}; + +use futures::future::{maybe_done, MaybeDone}; +use pin_utils::unsafe_pinned; + +// Macro-generate future combinators to allow for running different numbers of top-level futures in +// this FutureList. Generates the implementation of `FutureList` for the completion types. For an +// explicit example this is modeled after, see `UnitFutures`. +macro_rules! generate { + ($( + $(#[$doc:meta])* + ($Complete:ident, <$($Fut:ident),*>), + )*) => ($( + #[must_use = "Combinations of futures don't do anything unless run in an executor."] + pub(crate) struct $Complete<$($Fut: Future),*> { + $($Fut: MaybeDone<$Fut>,)* + } + + impl<$($Fut),*> $Complete<$($Fut),*> + where $( + $Fut: Future, + )* + { + // Safety: + // * No Drop impl + // * No Unpin impl + // * Not #[repr(packed)] + $( + unsafe_pinned!($Fut: MaybeDone<$Fut>); + )* + + pub(crate) fn new($($Fut: $Fut),*) -> $Complete<$($Fut),*> { + $( + let $Fut = maybe_done($Fut); + )* + $Complete { + $($Fut),* + } + } + } + + impl<$($Fut),*> Future for $Complete<$($Fut),*> + where $( + $Fut: Future, + )* + { + type Output = ($($Fut::Output),*); + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll { + let mut complete = true; + $( + complete &= self.as_mut().$Fut().poll(cx).is_ready(); + )* + + if complete { + $( + let $Fut = self.as_mut().$Fut().take_output().unwrap(); + )* + Poll::Ready(($($Fut), *)) + } else { + Poll::Pending + } + } + } + )*) +} + +generate! { + /// _Future for the [`complete2`] function. + (Complete2, <_Fut1, _Fut2>), + + /// _Future for the [`complete3`] function. + (Complete3, <_Fut1, _Fut2, _Fut3>), + + /// _Future for the [`complete4`] function. + (Complete4, <_Fut1, _Fut2, _Fut3, _Fut4>), + + /// _Future for the [`complete5`] function. + (Complete5, <_Fut1, _Fut2, _Fut3, _Fut4, _Fut5>), +} diff --git a/cros_async/src/event.rs b/cros_async/src/event.rs new file mode 100644 index 0000000000..318ad4ed0f --- /dev/null +++ b/cros_async/src/event.rs @@ -0,0 +1,85 @@ +// Copyright 2020 The Chromium OS Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +use sys_util::EventFd; + +use super::{AsyncResult, Executor, IntoAsync, IoSourceExt}; + +/// An async version of `sys_util::EventFd`. 
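A minimal sketch of `EventAsync` in use, assuming the crate-level re-exports: one thread signals a `sys_util::EventFd` while a future awaits the value.

```rust
use cros_async::{EventAsync, Executor};
use sys_util::EventFd;

fn wait_for_signal() {
    let ex = Executor::new().unwrap();
    let event = EventFd::new().unwrap();
    let signaller = event.try_clone().unwrap();

    // Simulate another component (e.g. a device backend) signalling the event.
    std::thread::spawn(move || signaller.write(1).unwrap());

    let val = ex
        .run_until(async {
            let event = EventAsync::new(event, &ex).unwrap();
            event.next_val().await.unwrap()
        })
        .unwrap();
    assert_eq!(val, 1);
}
```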
+pub struct EventAsync { + io_source: Box>, +} + +impl EventAsync { + pub fn new(event: EventFd, ex: &Executor) -> AsyncResult { + ex.async_from(event) + .map(|io_source| EventAsync { io_source }) + } + + #[cfg(test)] + pub(crate) fn new_poll(event: EventFd, ex: &super::FdExecutor) -> AsyncResult { + super::executor::async_poll_from(event, ex).map(|io_source| EventAsync { io_source }) + } + + #[cfg(test)] + pub(crate) fn new_uring(event: EventFd, ex: &super::URingExecutor) -> AsyncResult { + super::executor::async_uring_from(event, ex).map(|io_source| EventAsync { io_source }) + } + + /// Gets the next value from the eventfd. + #[allow(dead_code)] + pub async fn next_val(&self) -> AsyncResult { + self.io_source.read_u64().await + } +} + +impl IntoAsync for EventFd {} + +#[cfg(test)] +mod tests { + use super::*; + + use super::super::{uring_executor::use_uring, Executor, FdExecutor, URingExecutor}; + + #[test] + fn next_val_reads_value() { + async fn go(event: EventFd, ex: &Executor) -> u64 { + let event_async = EventAsync::new(event, ex).unwrap(); + event_async.next_val().await.unwrap() + } + + let eventfd = EventFd::new().unwrap(); + eventfd.write(0xaa).unwrap(); + let ex = Executor::new().unwrap(); + let val = ex.run_until(go(eventfd, &ex)).unwrap(); + assert_eq!(val, 0xaa); + } + + #[test] + fn next_val_reads_value_poll_and_ring() { + if !use_uring() { + return; + } + + async fn go(event_async: EventAsync) -> u64 { + event_async.next_val().await.unwrap() + } + + let eventfd = EventFd::new().unwrap(); + eventfd.write(0xaa).unwrap(); + let uring_ex = URingExecutor::new().unwrap(); + let val = uring_ex + .run_until(go(EventAsync::new_uring(eventfd, &uring_ex).unwrap())) + .unwrap(); + assert_eq!(val, 0xaa); + + let eventfd = EventFd::new().unwrap(); + eventfd.write(0xaa).unwrap(); + let poll_ex = FdExecutor::new().unwrap(); + let val = poll_ex + .run_until(go(EventAsync::new_poll(eventfd, &poll_ex).unwrap())) + .unwrap(); + assert_eq!(val, 0xaa); + } +} diff --git a/cros_async/src/executor.rs b/cros_async/src/executor.rs new file mode 100644 index 0000000000..7f6cfce930 --- /dev/null +++ b/cros_async/src/executor.rs @@ -0,0 +1,344 @@ +// Copyright 2020 The Chromium OS Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +use std::future::Future; + +use async_task::Task; + +use super::{ + poll_source::Error as PollError, uring_executor::use_uring, AsyncResult, FdExecutor, IntoAsync, + IoSourceExt, PollSource, URingExecutor, UringSource, +}; + +pub(crate) fn async_uring_from<'a, F: IntoAsync + Send + 'a>( + f: F, + ex: &URingExecutor, +) -> AsyncResult + Send + 'a>> { + Ok(UringSource::new(f, ex).map(|u| Box::new(u) as Box + Send>)?) +} + +/// Creates a concrete `IoSourceExt` using the fd_executor. +pub(crate) fn async_poll_from<'a, F: IntoAsync + Send + 'a>( + f: F, + ex: &FdExecutor, +) -> AsyncResult + Send + 'a>> { + Ok(PollSource::new(f, ex).map(|u| Box::new(u) as Box + Send>)?) +} + +/// An executor for scheduling tasks that poll futures to completion. +/// +/// All asynchronous operations must run within an executor, which is capable of spawning futures as +/// tasks. This executor also provides a mechanism for performing asynchronous I/O operations. +/// +/// The returned type is a cheap, clonable handle to the underlying executor. Cloning it will only +/// create a new reference, not a new executor. 
+/// +/// # Examples +/// +/// Concurrently wait for multiple files to become readable/writable and then read/write the data. +/// +/// ``` +/// use std::cmp::min; +/// use std::error::Error; +/// use std::fs::{File, OpenOptions}; +/// +/// use cros_async::{AsyncResult, Executor, IoSourceExt, complete3}; +/// const CHUNK_SIZE: usize = 32; +/// +/// // Write all bytes from `data` to `f`. +/// async fn write_file(f: &dyn IoSourceExt, mut data: Vec) -> AsyncResult<()> { +/// while data.len() > 0 { +/// let (count, mut buf) = f.write_from_vec(None, data).await?; +/// +/// data = buf.split_off(count); +/// } +/// +/// Ok(()) +/// } +/// +/// // Transfer `len` bytes of data from `from` to `to`. +/// async fn transfer_data( +/// from: Box>, +/// to: Box>, +/// len: usize, +/// ) -> AsyncResult { +/// let mut rem = len; +/// +/// while rem > 0 { +/// let buf = vec![0u8; min(rem, CHUNK_SIZE)]; +/// let (count, mut data) = from.read_to_vec(None, buf).await?; +/// +/// if count == 0 { +/// // End of file. Return the number of bytes transferred. +/// return Ok(len - rem); +/// } +/// +/// data.truncate(count); +/// write_file(&*to, data).await?; +/// +/// rem = rem.saturating_sub(count); +/// } +/// +/// Ok(len) +/// } +/// +/// # fn do_it() -> Result<(), Box> { +/// let ex = Executor::new()?; +/// +/// let (rx, tx) = sys_util::pipe(true)?; +/// let zero = File::open("/dev/zero")?; +/// let zero_bytes = CHUNK_SIZE * 7; +/// let zero_to_pipe = transfer_data( +/// ex.async_from(zero)?, +/// ex.async_from(tx.try_clone()?)?, +/// zero_bytes, +/// ); +/// +/// let rand = File::open("/dev/urandom")?; +/// let rand_bytes = CHUNK_SIZE * 19; +/// let rand_to_pipe = transfer_data(ex.async_from(rand)?, ex.async_from(tx)?, rand_bytes); +/// +/// let null = OpenOptions::new().write(true).open("/dev/null")?; +/// let null_bytes = zero_bytes + rand_bytes; +/// let pipe_to_null = transfer_data(ex.async_from(rx)?, ex.async_from(null)?, null_bytes); +/// +/// ex.run_until(complete3( +/// async { assert_eq!(pipe_to_null.await.unwrap(), null_bytes) }, +/// async { assert_eq!(zero_to_pipe.await.unwrap(), zero_bytes) }, +/// async { assert_eq!(rand_to_pipe.await.unwrap(), rand_bytes) }, +/// ))?; +/// +/// # Ok(()) +/// # } +/// +/// # do_it().unwrap(); +/// ``` + +#[derive(Clone)] +pub enum Executor { + Uring(URingExecutor), + Fd(FdExecutor), +} + +impl Executor { + /// Create a new `Executor`. + pub fn new() -> AsyncResult { + if use_uring() { + Ok(URingExecutor::new().map(Executor::Uring)?) + } else { + Ok(FdExecutor::new() + .map(Executor::Fd) + .map_err(PollError::Executor)?) + } + } + + /// Create a new `Box>` associated with `self`. Callers may then use the + /// returned `IoSourceExt` to directly start async operations without needing a separate + /// reference to the executor. + pub fn async_from<'a, F: IntoAsync + Send + 'a>( + &self, + f: F, + ) -> AsyncResult + Send + 'a>> { + match self { + Executor::Uring(ex) => async_uring_from(f, ex), + Executor::Fd(ex) => async_poll_from(f, ex), + } + } + + /// Spawn a new future for this executor to run to completion. Callers may use the returned + /// `Task` to await on the result of `f`. Dropping the returned `Task` will cancel `f`, + /// preventing it from being polled again. To drop a `Task` without canceling the future + /// associated with it use `Task::detach`. To cancel a task gracefully and wait until it is + /// fully destroyed, use `Task::cancel`. 
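A short sketch of the detach pattern mentioned above, for fire-and-forget work whose result is never awaited:

```rust
use cros_async::Executor;

fn start_background_work(ex: &Executor) {
    // Without detach(), dropping the returned Task would cancel the future.
    ex.spawn(async {
        // ... background work ...
    })
    .detach();
}
```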
+ /// + /// # Examples + /// + /// ``` + /// # use cros_async::AsyncResult; + /// # fn example_spawn() -> AsyncResult<()> { + /// # use std::thread; + /// + /// # use cros_async::Executor; + /// use futures::executor::block_on; + /// + /// # let ex = Executor::new()?; + /// + /// # // Spawn a thread that runs the executor. + /// # let ex2 = ex.clone(); + /// # thread::spawn(move || ex2.run()); + /// + /// let task = ex.spawn(async { 7 + 13 }); + /// + /// let result = block_on(task); + /// assert_eq!(result, 20); + /// # Ok(()) + /// # } + /// + /// # example_spawn().unwrap(); + /// ``` + pub fn spawn(&self, f: F) -> Task + where + F: Future + Send + 'static, + F::Output: Send + 'static, + { + match self { + Executor::Uring(ex) => ex.spawn(f), + Executor::Fd(ex) => ex.spawn(f), + } + } + + /// Spawn a thread-local task for this executor to drive to completion. Like `spawn` but without + /// requiring `Send` on `F` or `F::Output`. This method should only be called from the same + /// thread where `run()` or `run_until()` is called. + /// + /// # Panics + /// + /// `Executor::run` and `Executor::run_util` will panic if they try to poll a future that was + /// added by calling `spawn_local` from a different thread. + /// + /// # Examples + /// + /// ``` + /// # use cros_async::AsyncResult; + /// # fn example_spawn_local() -> AsyncResult<()> { + /// # use cros_async::Executor; + /// + /// # let ex = Executor::new()?; + /// + /// let task = ex.spawn_local(async { 7 + 13 }); + /// + /// let result = ex.run_until(task)?; + /// assert_eq!(result, 20); + /// # Ok(()) + /// # } + /// + /// # example_spawn_local().unwrap(); + /// ``` + pub fn spawn_local(&self, f: F) -> Task + where + F: Future + 'static, + F::Output: 'static, + { + match self { + Executor::Uring(ex) => ex.spawn_local(f), + Executor::Fd(ex) => ex.spawn_local(f), + } + } + + /// Run the provided closure on a dedicated thread where blocking is allowed. + /// + /// Callers may `await` on the returned `Task` to wait for the result of `f`. Dropping or + /// canceling the returned `Task` may not cancel the operation if it was already started on a + /// worker thread. + /// + /// # Panics + /// + /// `await`ing the `Task` after the `Executor` is dropped will panic if the work was not already + /// completed. + /// + /// # Examples + /// + /// ```edition2018 + /// # use cros_async::Executor; + /// + /// # async fn do_it(ex: &Executor) { + /// let res = ex.spawn_blocking(move || { + /// // Do some CPU-intensive or blocking work here. + /// + /// 42 + /// }).await; + /// + /// assert_eq!(res, 42); + /// # } + /// + /// # let ex = Executor::new().unwrap(); + /// # ex.run_until(do_it(&ex)).unwrap(); + /// ``` + pub fn spawn_blocking(&self, f: F) -> Task + where + F: FnOnce() -> R + Send + 'static, + R: Send + 'static, + { + match self { + Executor::Uring(ex) => ex.spawn_blocking(f), + Executor::Fd(ex) => ex.spawn_blocking(f), + } + } + + /// Run the executor indefinitely, driving all spawned futures to completion. This method will + /// block the current thread and only return in the case of an error. + /// + /// # Panics + /// + /// Once this method has been called on a thread, it may only be called on that thread from that + /// point on. Attempting to call it from another thread will panic. 
+ /// + /// # Examples + /// + /// ``` + /// # use cros_async::AsyncResult; + /// # fn example_run() -> AsyncResult<()> { + /// use std::thread; + /// + /// use cros_async::Executor; + /// use futures::executor::block_on; + /// + /// let ex = Executor::new()?; + /// + /// // Spawn a thread that runs the executor. + /// let ex2 = ex.clone(); + /// thread::spawn(move || ex2.run()); + /// + /// let task = ex.spawn(async { 7 + 13 }); + /// + /// let result = block_on(task); + /// assert_eq!(result, 20); + /// # Ok(()) + /// # } + /// + /// # example_run().unwrap(); + /// ``` + pub fn run(&self) -> AsyncResult<()> { + match self { + Executor::Uring(ex) => ex.run()?, + Executor::Fd(ex) => ex.run().map_err(PollError::Executor)?, + } + + Ok(()) + } + + /// Drive all futures spawned in this executor until `f` completes. This method will block the + /// current thread only until `f` is complete and there may still be unfinished futures in the + /// executor. + /// + /// # Panics + /// + /// Once this method has been called on a thread, from then onwards it may only be called on + /// that thread. Attempting to call it from another thread will panic. + /// + /// # Examples + /// + /// ``` + /// # use cros_async::AsyncResult; + /// # fn example_run_until() -> AsyncResult<()> { + /// use cros_async::Executor; + /// + /// let ex = Executor::new()?; + /// + /// let task = ex.spawn_local(async { 7 + 13 }); + /// + /// let result = ex.run_until(task)?; + /// assert_eq!(result, 20); + /// # Ok(()) + /// # } + /// + /// # example_run_until().unwrap(); + /// ``` + pub fn run_until(&self, f: F) -> AsyncResult { + match self { + Executor::Uring(ex) => Ok(ex.run_until(f)?), + Executor::Fd(ex) => Ok(ex.run_until(f).map_err(PollError::Executor)?), + } + } +} diff --git a/cros_async/src/fd_executor.rs b/cros_async/src/fd_executor.rs new file mode 100644 index 0000000000..9439749f27 --- /dev/null +++ b/cros_async/src/fd_executor.rs @@ -0,0 +1,634 @@ +// Copyright 2020 The Chromium OS Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +//! The executor runs all given futures to completion. Futures register wakers associated with file +//! descriptors. The wakers will be called when the FD becomes readable or writable depending on +//! the situation. +//! +//! `FdExecutor` is meant to be used with the `futures-rs` crate that provides combinators and +//! utility functions to combine futures. + +use std::{ + fs::File, + future::Future, + io, mem, + os::unix::io::{AsRawFd, FromRawFd, RawFd}, + pin::Pin, + sync::{ + atomic::{AtomicI32, Ordering}, + Arc, Weak, + }, + task::{Context, Poll, Waker}, +}; + +use async_task::Task; +use futures::task::noop_waker; +use pin_utils::pin_mut; +use remain::sorted; +use slab::Slab; +use sync::Mutex; +use sys_util::{add_fd_flags, warn, EpollContext, EpollEvents, EventFd, WatchingEvents}; +use thiserror::Error as ThisError; + +use super::{ + queue::RunnableQueue, + waker::{new_waker, WakerToken, WeakWake}, + BlockingPool, +}; + +#[sorted] +#[derive(Debug, ThisError)] +pub enum Error { + /// Failed to clone the EventFd for waking the executor. + #[error("Failed to clone the EventFd for waking the executor: {0}")] + CloneEventFd(sys_util::Error), + /// Failed to create the EventFd for waking the executor. + #[error("Failed to create the EventFd for waking the executor: {0}")] + CreateEventFd(sys_util::Error), + /// Creating a context to wait on FDs failed. 
+ #[error("An error creating the fd waiting context: {0}")] + CreatingContext(sys_util::Error), + /// Failed to copy the FD for the polling context. + #[error("Failed to copy the FD for the polling context: {0}")] + DuplicatingFd(sys_util::Error), + /// The Executor is gone. + #[error("The FDExecutor is gone")] + ExecutorGone, + /// PollContext failure. + #[error("PollContext failure: {0}")] + PollContextError(sys_util::Error), + /// An error occurred when setting the FD non-blocking. + #[error("An error occurred setting the FD non-blocking: {0}.")] + SettingNonBlocking(sys_util::Error), + /// Failed to submit the waker to the polling context. + #[error("An error adding to the Aio context: {0}")] + SubmittingWaker(sys_util::Error), + /// A Waker was canceled, but the operation isn't running. + #[error("Unknown waker")] + UnknownWaker, +} +pub type Result = std::result::Result; + +impl From for io::Error { + fn from(e: Error) -> Self { + use Error::*; + match e { + CloneEventFd(e) => e.into(), + CreateEventFd(e) => e.into(), + DuplicatingFd(e) => e.into(), + ExecutorGone => io::Error::new(io::ErrorKind::Other, e), + CreatingContext(e) => e.into(), + PollContextError(e) => e.into(), + SettingNonBlocking(e) => e.into(), + SubmittingWaker(e) => e.into(), + UnknownWaker => io::Error::new(io::ErrorKind::Other, e), + } + } +} + +// A poll operation that has been submitted and is potentially being waited on. +struct OpData { + file: File, + waker: Option, +} + +// The current status of a submitted operation. +enum OpStatus { + Pending(OpData), + Completed, +} + +// An IO source previously registered with an FdExecutor. Used to initiate asynchronous IO with the +// associated executor. +pub struct RegisteredSource { + source: F, + ex: Weak, +} + +impl RegisteredSource { + // Start an asynchronous operation to wait for this source to become readable. The returned + // future will not be ready until the source is readable. + pub fn wait_readable(&self) -> Result { + let ex = self.ex.upgrade().ok_or(Error::ExecutorGone)?; + + let token = + ex.add_operation(self.source.as_raw_fd(), WatchingEvents::empty().set_read())?; + + Ok(PendingOperation { + token: Some(token), + ex: self.ex.clone(), + }) + } + + // Start an asynchronous operation to wait for this source to become writable. The returned + // future will not be ready until the source is writable. + pub fn wait_writable(&self) -> Result { + let ex = self.ex.upgrade().ok_or(Error::ExecutorGone)?; + + let token = + ex.add_operation(self.source.as_raw_fd(), WatchingEvents::empty().set_write())?; + + Ok(PendingOperation { + token: Some(token), + ex: self.ex.clone(), + }) + } +} + +impl RegisteredSource { + // Consume this RegisteredSource and return the inner IO source. + pub fn into_source(self) -> F { + self.source + } +} + +impl AsRef for RegisteredSource { + fn as_ref(&self) -> &F { + &self.source + } +} + +impl AsMut for RegisteredSource { + fn as_mut(&mut self) -> &mut F { + &mut self.source + } +} + +/// A token returned from `add_operation` that can be used to cancel the waker before it completes. +/// Used to manage getting the result from the underlying executor for a completed operation. +/// Dropping a `PendingOperation` will get the result from the executor. 
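A condensed sketch of how this machinery is consumed, written as if it lived inside this module; note that `register_source` is `pub(crate)`, so external callers go through `IoSourceExt` instead.

```rust
// Hypothetical helper, crate-internal only.
async fn wait_until_readable(ex: &FdExecutor, rx: std::fs::File) -> Result<()> {
    // Registering sets O_NONBLOCK and associates the fd with this executor.
    let source = ex.register_source(rx)?;
    // wait_readable() arms an epoll-backed waker; awaiting the returned
    // PendingOperation resolves once the fd is readable (or the executor is gone).
    source.wait_readable()?.await
}
```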
+pub struct PendingOperation { + token: Option, + ex: Weak, +} + +impl Future for PendingOperation { + type Output = Result<()>; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll { + let token = self + .token + .as_ref() + .expect("PendingOperation polled after returning Poll::Ready"); + if let Some(ex) = self.ex.upgrade() { + if ex.is_ready(token, cx) { + self.token = None; + Poll::Ready(Ok(())) + } else { + Poll::Pending + } + } else { + Poll::Ready(Err(Error::ExecutorGone)) + } + } +} + +impl Drop for PendingOperation { + fn drop(&mut self) { + if let Some(token) = self.token.take() { + if let Some(ex) = self.ex.upgrade() { + let _ = ex.cancel_operation(token); + } + } + } +} + +// This function exists to guarantee that non-epoll futures will not starve until an epoll future is +// ready to be polled again. The mechanism is very similar to the self-pipe trick used by C programs +// to reliably mix select / poll with signal handling. This is how it works: +// +// * RawExecutor::new creates an eventfd, dupes it, and spawns this async function with the duped fd. +// * The first time notify_task is polled it tries to read from the eventfd and if that fails, waits +// for the fd to become readable. +// * Meanwhile the RawExecutor keeps the original fd for the eventfd. +// * Whenever RawExecutor::wake is called it will write to the eventfd if it determines that the +// executor thread is currently blocked inside an io_epoll_enter call. This can happen when a +// non-epoll future becomes ready to poll. +// * The write to the eventfd causes the fd to become readable, which then allows the epoll() call +// to return with at least one readable fd. +// * The executor then polls the non-epoll future that became ready, any epoll futures that +// completed, and the notify_task function, which then queues up another read on the eventfd and +// the process can repeat. +async fn notify_task(notify: EventFd, raw: Weak) { + add_fd_flags(notify.as_raw_fd(), libc::O_NONBLOCK) + .expect("Failed to set notify EventFd as non-blocking"); + + loop { + match notify.read() { + Ok(_) => {} + Err(e) if e.errno() == libc::EWOULDBLOCK => {} + Err(e) => panic!("Unexpected error while reading notify EventFd: {}", e), + } + + if let Some(ex) = raw.upgrade() { + let token = ex + .add_operation(notify.as_raw_fd(), WatchingEvents::empty().set_read()) + .expect("Failed to add notify EventFd to PollCtx"); + + // We don't want to hold an active reference to the executor in the .await below. + mem::drop(ex); + + let op = PendingOperation { + token: Some(token), + ex: raw.clone(), + }; + + match op.await { + Ok(()) => {} + Err(Error::ExecutorGone) => break, + Err(e) => panic!("Unexpected error while waiting for notify EventFd: {}", e), + } + } else { + // The executor is gone so we should also exit. + break; + } + } +} + +// Indicates that the executor is either within or about to make a PollContext::wait() call. When a +// waker sees this value, it will write to the notify EventFd, which will cause the +// PollContext::wait() call to return. +const WAITING: i32 = 0x1d5b_c019u32 as i32; + +// Indicates that the executor is processing any futures that are ready to run. +const PROCESSING: i32 = 0xd474_77bcu32 as i32; + +// Indicates that one or more futures may be ready to make progress. 
+const WOKEN: i32 = 0x3e4d_3276u32 as i32; + +struct RawExecutor { + queue: RunnableQueue, + poll_ctx: EpollContext, + ops: Mutex>, + blocking_pool: BlockingPool, + state: AtomicI32, + notify: EventFd, +} + +impl RawExecutor { + fn new(notify: EventFd) -> Result { + Ok(RawExecutor { + queue: RunnableQueue::new(), + poll_ctx: EpollContext::new().map_err(Error::CreatingContext)?, + ops: Mutex::new(Slab::with_capacity(64)), + blocking_pool: Default::default(), + state: AtomicI32::new(PROCESSING), + notify, + }) + } + + fn add_operation(&self, fd: RawFd, events: WatchingEvents) -> Result { + let duped_fd = unsafe { + // Safe because duplicating an FD doesn't affect memory safety, and the dup'd FD + // will only be added to the poll loop. + File::from_raw_fd(dup_fd(fd)?) + }; + let mut ops = self.ops.lock(); + let entry = ops.vacant_entry(); + let next_token = entry.key(); + self.poll_ctx + .add_fd_with_events(&duped_fd, events, next_token) + .map_err(Error::SubmittingWaker)?; + entry.insert(OpStatus::Pending(OpData { + file: duped_fd, + waker: None, + })); + Ok(WakerToken(next_token)) + } + + fn wake(&self) { + let oldstate = self.state.swap(WOKEN, Ordering::Release); + if oldstate == WAITING { + if let Err(e) = self.notify.write(1) { + warn!("Failed to notify executor that a future is ready: {}", e); + } + } + } + + fn spawn(self: &Arc, f: F) -> Task + where + F: Future + Send + 'static, + F::Output: Send + 'static, + { + let raw = Arc::downgrade(self); + let schedule = move |runnable| { + if let Some(r) = raw.upgrade() { + r.queue.push_back(runnable); + r.wake(); + } + }; + let (runnable, task) = async_task::spawn(f, schedule); + runnable.schedule(); + task + } + + fn spawn_local(self: &Arc, f: F) -> Task + where + F: Future + 'static, + F::Output: 'static, + { + let raw = Arc::downgrade(self); + let schedule = move |runnable| { + if let Some(r) = raw.upgrade() { + r.queue.push_back(runnable); + r.wake(); + } + }; + let (runnable, task) = async_task::spawn_local(f, schedule); + runnable.schedule(); + task + } + + fn spawn_blocking(self: &Arc, f: F) -> Task + where + F: FnOnce() -> R + Send + 'static, + R: Send + 'static, + { + self.blocking_pool.spawn(f) + } + + fn run(&self, cx: &mut Context, done: F) -> Result { + let events = EpollEvents::new(); + pin_mut!(done); + + loop { + self.state.store(PROCESSING, Ordering::Release); + for runnable in self.queue.iter() { + runnable.run(); + } + + if let Poll::Ready(val) = done.as_mut().poll(cx) { + return Ok(val); + } + + let oldstate = self.state.compare_exchange( + PROCESSING, + WAITING, + Ordering::Acquire, + Ordering::Acquire, + ); + if let Err(oldstate) = oldstate { + debug_assert_eq!(oldstate, WOKEN); + // One or more futures have become runnable. + continue; + } + + let events = self + .poll_ctx + .wait(&events) + .map_err(Error::PollContextError)?; + + // Set the state back to PROCESSING to prevent any tasks woken up by the loop below from + // writing to the eventfd. + self.state.store(PROCESSING, Ordering::Release); + for e in events.iter() { + let token = e.token(); + let mut ops = self.ops.lock(); + + // The op could have been canceled and removed by another thread so ignore it if it + // doesn't exist. 
+ if let Some(op) = ops.get_mut(token) { + let (file, waker) = match mem::replace(op, OpStatus::Completed) { + OpStatus::Pending(OpData { file, waker }) => (file, waker), + OpStatus::Completed => panic!("poll operation completed more than once"), + }; + + mem::drop(ops); + + self.poll_ctx + .delete(&file) + .map_err(Error::PollContextError)?; + + if let Some(waker) = waker { + waker.wake(); + } + } + } + } + } + + fn is_ready(&self, token: &WakerToken, cx: &mut Context) -> bool { + let mut ops = self.ops.lock(); + + let op = ops + .get_mut(token.0) + .expect("`is_ready` called on unknown operation"); + match op { + OpStatus::Pending(data) => { + data.waker = Some(cx.waker().clone()); + false + } + OpStatus::Completed => { + ops.remove(token.0); + true + } + } + } + + // Remove the waker for the given token if it hasn't fired yet. + fn cancel_operation(&self, token: WakerToken) -> Result<()> { + match self.ops.lock().remove(token.0) { + OpStatus::Pending(data) => self + .poll_ctx + .delete(&data.file) + .map_err(Error::PollContextError), + OpStatus::Completed => Ok(()), + } + } +} + +impl WeakWake for RawExecutor { + fn wake_by_ref(weak_self: &Weak) { + if let Some(arc_self) = weak_self.upgrade() { + RawExecutor::wake(&arc_self); + } + } +} + +impl Drop for RawExecutor { + fn drop(&mut self) { + // Wake up the notify_task. We set the state to WAITING here so that wake() will write to + // the eventfd. + self.state.store(WAITING, Ordering::Release); + self.wake(); + + // Wake up any futures still waiting on poll operations as they are just going to get an + // ExecutorGone error now. + for op in self.ops.get_mut().drain() { + match op { + OpStatus::Pending(mut data) => { + if let Some(waker) = data.waker.take() { + waker.wake(); + } + + if let Err(e) = self.poll_ctx.delete(&data.file) { + warn!("Failed to remove file from EpollCtx: {}", e); + } + } + OpStatus::Completed => {} + } + } + + // Now run the executor one more time to drive any remaining futures to completion. 
+ let waker = noop_waker(); + let mut cx = Context::from_waker(&waker); + if let Err(e) = self.run(&mut cx, async {}) { + warn!("Failed to drive FdExecutor to completion: {}", e); + } + } +} + +#[derive(Clone)] +pub struct FdExecutor { + raw: Arc, +} + +impl FdExecutor { + pub fn new() -> Result { + let notify = EventFd::new().map_err(Error::CreateEventFd)?; + let raw = notify + .try_clone() + .map_err(Error::CloneEventFd) + .and_then(RawExecutor::new) + .map(Arc::new)?; + + raw.spawn(notify_task(notify, Arc::downgrade(&raw))) + .detach(); + + Ok(FdExecutor { raw }) + } + + pub fn spawn(&self, f: F) -> Task + where + F: Future + Send + 'static, + F::Output: Send + 'static, + { + self.raw.spawn(f) + } + + pub fn spawn_local(&self, f: F) -> Task + where + F: Future + 'static, + F::Output: 'static, + { + self.raw.spawn_local(f) + } + + pub fn spawn_blocking(&self, f: F) -> Task + where + F: FnOnce() -> R + Send + 'static, + R: Send + 'static, + { + self.raw.spawn_blocking(f) + } + + pub fn run(&self) -> Result<()> { + let waker = new_waker(Arc::downgrade(&self.raw)); + let mut cx = Context::from_waker(&waker); + + self.raw.run(&mut cx, super::empty::<()>()) + } + + pub fn run_until(&self, f: F) -> Result { + let waker = new_waker(Arc::downgrade(&self.raw)); + let mut ctx = Context::from_waker(&waker); + + self.raw.run(&mut ctx, f) + } + + pub(crate) fn register_source(&self, f: F) -> Result> { + add_fd_flags(f.as_raw_fd(), libc::O_NONBLOCK).map_err(Error::SettingNonBlocking)?; + Ok(RegisteredSource { + source: f, + ex: Arc::downgrade(&self.raw), + }) + } +} + +// Used to `dup` the FDs passed to the executor so there is a guarantee they aren't closed while +// waiting in TLS to be added to the main polling context. +unsafe fn dup_fd(fd: RawFd) -> Result { + let ret = libc::fcntl(fd, libc::F_DUPFD_CLOEXEC, 0); + if ret < 0 { + Err(Error::DuplicatingFd(sys_util::Error::last())) + } else { + Ok(ret) + } +} + +#[cfg(test)] +mod test { + use std::{ + cell::RefCell, + io::{Read, Write}, + rc::Rc, + }; + + use futures::future::Either; + + use super::*; + + #[test] + fn test_it() { + async fn do_test(ex: &FdExecutor) { + let (r, _w) = sys_util::pipe(true).unwrap(); + let done = Box::pin(async { 5usize }); + let source = ex.register_source(r).unwrap(); + let pending = source.wait_readable().unwrap(); + match futures::future::select(pending, done).await { + Either::Right((5, pending)) => std::mem::drop(pending), + _ => panic!("unexpected select result"), + } + } + + let ex = FdExecutor::new().unwrap(); + ex.run_until(do_test(&ex)).unwrap(); + + // Example of starting the framework and running a future: + async fn my_async(x: Rc>) { + x.replace(4); + } + + let x = Rc::new(RefCell::new(0)); + super::super::run_one_poll(my_async(x.clone())).unwrap(); + assert_eq!(*x.borrow(), 4); + } + + #[test] + fn drop_before_completion() { + const VALUE: u64 = 0x66ae_cb65_12fb_d260; + + async fn write_value(mut tx: File) { + let buf = VALUE.to_ne_bytes(); + tx.write_all(&buf[..]).expect("Failed to write to pipe"); + } + + async fn check_op(op: PendingOperation) { + let err = op.await.expect_err("Task completed successfully"); + match err { + Error::ExecutorGone => {} + e => panic!("Unexpected error from task: {}", e), + } + } + + let (mut rx, tx) = sys_util::pipe(true).expect("Pipe failed"); + + let ex = FdExecutor::new().unwrap(); + + let source = ex.register_source(tx.try_clone().unwrap()).unwrap(); + let op = source.wait_writable().unwrap(); + + ex.spawn_local(write_value(tx)).detach(); + 
ex.spawn_local(check_op(op)).detach(); + + // Now drop the executor. It should still run until the write to the pipe is complete. + mem::drop(ex); + + let mut buf = 0u64.to_ne_bytes(); + rx.read_exact(&mut buf[..]) + .expect("Failed to read from pipe"); + + assert_eq!(u64::from_ne_bytes(buf), VALUE); + } +} diff --git a/cros_async/src/io_ext.rs b/cros_async/src/io_ext.rs new file mode 100644 index 0000000000..010b716e3f --- /dev/null +++ b/cros_async/src/io_ext.rs @@ -0,0 +1,479 @@ +// Copyright 2020 The Chromium OS Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +//! # `IoSourceExt` +//! +//! User functions to asynchronously access files. +//! Using `IoSource` directly is inconvenient and requires dealing with state +//! machines for the backing uring, future libraries, etc. `IoSourceExt` instead +//! provides users with a future that can be `await`ed from async context. +//! +//! Each member of `IoSourceExt` returns a future for the supported operation. One or more +//! operation can be pending at a time. +//! +//! Operations can only access memory in a `Vec` or an implementor of `BackingMemory`. See the +//! `URingExecutor` documentation for an explaination of why. + +use std::{ + fs::File, + io, + ops::{Deref, DerefMut}, + os::unix::io::{AsRawFd, RawFd}, + sync::Arc, +}; + +use async_trait::async_trait; +use remain::sorted; +use sys_util::net::UnixSeqpacket; +use thiserror::Error as ThisError; + +use super::{BackingMemory, MemRegion}; + +#[sorted] +#[derive(ThisError, Debug)] +pub enum Error { + /// An error with a polled(FD) source. + #[error("An error with a poll source: {0}")] + Poll(#[from] super::poll_source::Error), + /// An error with a uring source. + #[error("An error with a uring source: {0}")] + Uring(#[from] super::uring_executor::Error), +} +pub type Result = std::result::Result; + +impl From for io::Error { + fn from(e: Error) -> Self { + use Error::*; + match e { + Poll(e) => e.into(), + Uring(e) => e.into(), + } + } +} + +/// Ergonomic methods for async reads. +#[async_trait(?Send)] +pub trait ReadAsync { + /// Reads from the iosource at `file_offset` and fill the given `vec`. + async fn read_to_vec<'a>( + &'a self, + file_offset: Option, + vec: Vec, + ) -> Result<(usize, Vec)>; + + /// Reads to the given `mem` at the given offsets from the file starting at `file_offset`. + async fn read_to_mem<'a>( + &'a self, + file_offset: Option, + mem: Arc, + mem_offsets: &'a [MemRegion], + ) -> Result; + + /// Wait for the FD of `self` to be readable. + async fn wait_readable(&self) -> Result<()>; + + /// Reads a single u64 from the current offset. + async fn read_u64(&self) -> Result; +} + +/// Ergonomic methods for async writes. +#[async_trait(?Send)] +pub trait WriteAsync { + /// Writes from the given `vec` to the file starting at `file_offset`. + async fn write_from_vec<'a>( + &'a self, + file_offset: Option, + vec: Vec, + ) -> Result<(usize, Vec)>; + + /// Writes from the given `mem` from the given offsets to the file starting at `file_offset`. + async fn write_from_mem<'a>( + &'a self, + file_offset: Option, + mem: Arc, + mem_offsets: &'a [MemRegion], + ) -> Result; + + /// See `fallocate(2)`. Note this op is synchronous when using the Polled backend. + async fn fallocate(&self, file_offset: u64, len: u64, mode: u32) -> Result<()>; + + /// Sync all completed write operations to the backing storage. + async fn fsync(&self) -> Result<()>; +} + +/// Subtrait for general async IO. 
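+///
+/// # Example
+///
+/// A minimal sketch (not compiled as a doctest; assumes a Linux host with `/dev/zero`) of
+/// driving a read through the boxed source returned by `Executor::async_from`:
+///
+/// ```ignore
+/// use cros_async::Executor;
+///
+/// let ex = Executor::new().unwrap();
+/// let file = std::fs::File::open("/dev/zero").unwrap();
+/// let source = ex.async_from(file).unwrap();
+/// ex.run_until(async {
+///     // read_to_vec takes ownership of the buffer and returns it with the byte count.
+///     let (len, buf) = source.read_to_vec(None, vec![0u8; 16]).await.unwrap();
+///     assert_eq!(len, 16);
+///     assert!(buf.iter().all(|&b| b == 0));
+/// })
+/// .unwrap();
+/// ```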
+#[async_trait(?Send)] +pub trait IoSourceExt: ReadAsync + WriteAsync { + /// Yields the underlying IO source. + fn into_source(self: Box) -> F; + + /// Provides a mutable ref to the underlying IO source. + fn as_source_mut(&mut self) -> &mut F; + + /// Provides a ref to the underlying IO source. + fn as_source(&self) -> &F; +} + +/// Marker trait signifying that the implementor is suitable for use with +/// cros_async. Examples of this include File, and sys_util::net::UnixSeqpacket. +/// +/// (Note: it'd be really nice to implement a TryFrom for any implementors, and +/// remove our factory functions. Unfortunately +/// makes that too painful.) +pub trait IntoAsync: AsRawFd {} + +impl IntoAsync for File {} +impl IntoAsync for UnixSeqpacket {} +impl IntoAsync for &UnixSeqpacket {} + +/// Simple wrapper struct to implement IntoAsync on foreign types. +pub struct AsyncWrapper(T); + +impl AsyncWrapper { + /// Create a new `AsyncWrapper` that wraps `val`. + pub fn new(val: T) -> Self { + AsyncWrapper(val) + } + + /// Consumes the `AsyncWrapper`, returning the inner struct. + pub fn into_inner(self) -> T { + self.0 + } +} + +impl Deref for AsyncWrapper { + type Target = T; + + fn deref(&self) -> &T { + &self.0 + } +} + +impl DerefMut for AsyncWrapper { + fn deref_mut(&mut self) -> &mut T { + &mut self.0 + } +} + +impl AsRawFd for AsyncWrapper { + fn as_raw_fd(&self) -> RawFd { + self.0.as_raw_fd() + } +} + +impl IntoAsync for AsyncWrapper {} + +#[cfg(test)] +mod tests { + use std::{ + fs::{File, OpenOptions}, + future::Future, + os::unix::io::AsRawFd, + pin::Pin, + sync::Arc, + task::{Context, Poll, Waker}, + thread, + }; + + use sync::Mutex; + + use super::{ + super::{ + executor::{async_poll_from, async_uring_from}, + mem::VecIoWrapper, + uring_executor::use_uring, + Executor, FdExecutor, MemRegion, PollSource, URingExecutor, UringSource, + }, + *, + }; + + struct State { + should_quit: bool, + waker: Option, + } + + impl State { + fn wake(&mut self) { + self.should_quit = true; + let waker = self.waker.take(); + + if let Some(waker) = waker { + waker.wake(); + } + } + } + + struct Quit { + state: Arc>, + } + + impl Future for Quit { + type Output = (); + + fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<()> { + let mut state = self.state.lock(); + if state.should_quit { + return Poll::Ready(()); + } + + state.waker = Some(cx.waker().clone()); + Poll::Pending + } + } + + #[test] + fn await_uring_from_poll() { + if !use_uring() { + return; + } + // Start a uring operation and then await the result from an FdExecutor. + async fn go(source: UringSource) { + let v = vec![0xa4u8; 16]; + let (len, vec) = source.read_to_vec(None, v).await.unwrap(); + assert_eq!(len, 16); + assert!(vec.iter().all(|&b| b == 0)); + } + + let state = Arc::new(Mutex::new(State { + should_quit: false, + waker: None, + })); + + let uring_ex = URingExecutor::new().unwrap(); + let f = File::open("/dev/zero").unwrap(); + let source = UringSource::new(f, &uring_ex).unwrap(); + + let quit = Quit { + state: state.clone(), + }; + let handle = thread::spawn(move || uring_ex.run_until(quit)); + + let poll_ex = FdExecutor::new().unwrap(); + poll_ex.run_until(go(source)).unwrap(); + + state.lock().wake(); + handle.join().unwrap().unwrap(); + } + + #[test] + fn await_poll_from_uring() { + if !use_uring() { + return; + } + // Start a poll operation and then await the result from a URingExecutor. 
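+        // This is the mirror image of `await_uring_from_poll` above, with the roles of the two
+        // executors swapped.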
+ async fn go(source: PollSource) { + let v = vec![0x2cu8; 16]; + let (len, vec) = source.read_to_vec(None, v).await.unwrap(); + assert_eq!(len, 16); + assert!(vec.iter().all(|&b| b == 0)); + } + + let state = Arc::new(Mutex::new(State { + should_quit: false, + waker: None, + })); + + let poll_ex = FdExecutor::new().unwrap(); + let f = File::open("/dev/zero").unwrap(); + let source = PollSource::new(f, &poll_ex).unwrap(); + + let quit = Quit { + state: state.clone(), + }; + let handle = thread::spawn(move || poll_ex.run_until(quit)); + + let uring_ex = URingExecutor::new().unwrap(); + uring_ex.run_until(go(source)).unwrap(); + + state.lock().wake(); + handle.join().unwrap().unwrap(); + } + + #[test] + fn readvec() { + if !use_uring() { + return; + } + async fn go(async_source: Box>) { + let v = vec![0x55u8; 32]; + let v_ptr = v.as_ptr(); + let ret = async_source.read_to_vec(None, v).await.unwrap(); + assert_eq!(ret.0, 32); + let ret_v = ret.1; + assert_eq!(v_ptr, ret_v.as_ptr()); + assert!(ret_v.iter().all(|&b| b == 0)); + } + + let f = File::open("/dev/zero").unwrap(); + let uring_ex = URingExecutor::new().unwrap(); + let uring_source = async_uring_from(f, &uring_ex).unwrap(); + uring_ex.run_until(go(uring_source)).unwrap(); + + let f = File::open("/dev/zero").unwrap(); + let poll_ex = FdExecutor::new().unwrap(); + let poll_source = async_poll_from(f, &poll_ex).unwrap(); + poll_ex.run_until(go(poll_source)).unwrap(); + } + + #[test] + fn writevec() { + if !use_uring() { + return; + } + async fn go(async_source: Box>) { + let v = vec![0x55u8; 32]; + let v_ptr = v.as_ptr(); + let ret = async_source.write_from_vec(None, v).await.unwrap(); + assert_eq!(ret.0, 32); + let ret_v = ret.1; + assert_eq!(v_ptr, ret_v.as_ptr()); + } + + let f = OpenOptions::new().write(true).open("/dev/null").unwrap(); + let ex = URingExecutor::new().unwrap(); + let uring_source = async_uring_from(f, &ex).unwrap(); + ex.run_until(go(uring_source)).unwrap(); + + let f = OpenOptions::new().write(true).open("/dev/null").unwrap(); + let poll_ex = FdExecutor::new().unwrap(); + let poll_source = async_poll_from(f, &poll_ex).unwrap(); + poll_ex.run_until(go(poll_source)).unwrap(); + } + + #[test] + fn readmem() { + if !use_uring() { + return; + } + async fn go(async_source: Box>) { + let mem = Arc::new(VecIoWrapper::from(vec![0x55u8; 8192])); + let ret = async_source + .read_to_mem( + None, + Arc::::clone(&mem), + &[ + MemRegion { offset: 0, len: 32 }, + MemRegion { + offset: 200, + len: 56, + }, + ], + ) + .await + .unwrap(); + assert_eq!(ret, 32 + 56); + let vec: Vec = match Arc::try_unwrap(mem) { + Ok(v) => v.into(), + Err(_) => panic!("Too many vec refs"), + }; + assert!(vec.iter().take(32).all(|&b| b == 0)); + assert!(vec.iter().skip(32).take(168).all(|&b| b == 0x55)); + assert!(vec.iter().skip(200).take(56).all(|&b| b == 0)); + assert!(vec.iter().skip(256).all(|&b| b == 0x55)); + } + + let f = File::open("/dev/zero").unwrap(); + let ex = URingExecutor::new().unwrap(); + let uring_source = async_uring_from(f, &ex).unwrap(); + ex.run_until(go(uring_source)).unwrap(); + + let f = File::open("/dev/zero").unwrap(); + let poll_ex = FdExecutor::new().unwrap(); + let poll_source = async_poll_from(f, &poll_ex).unwrap(); + poll_ex.run_until(go(poll_source)).unwrap(); + } + + #[test] + fn writemem() { + if !use_uring() { + return; + } + async fn go(async_source: Box>) { + let mem = Arc::new(VecIoWrapper::from(vec![0x55u8; 8192])); + let ret = async_source + .write_from_mem( + None, + Arc::::clone(&mem), + &[MemRegion { offset: 
0, len: 32 }], + ) + .await + .unwrap(); + assert_eq!(ret, 32); + } + + let f = OpenOptions::new().write(true).open("/dev/null").unwrap(); + let ex = URingExecutor::new().unwrap(); + let uring_source = async_uring_from(f, &ex).unwrap(); + ex.run_until(go(uring_source)).unwrap(); + + let f = OpenOptions::new().write(true).open("/dev/null").unwrap(); + let poll_ex = FdExecutor::new().unwrap(); + let poll_source = async_poll_from(f, &poll_ex).unwrap(); + poll_ex.run_until(go(poll_source)).unwrap(); + } + + #[test] + fn read_u64s() { + if !use_uring() { + return; + } + async fn go(async_source: File, ex: URingExecutor) -> u64 { + let source = async_uring_from(async_source, &ex).unwrap(); + source.read_u64().await.unwrap() + } + + let f = File::open("/dev/zero").unwrap(); + let ex = URingExecutor::new().unwrap(); + let val = ex.run_until(go(f, ex.clone())).unwrap(); + assert_eq!(val, 0); + } + + #[test] + fn read_eventfds() { + if !use_uring() { + return; + } + use sys_util::EventFd; + + async fn go(source: Box>) -> u64 { + source.read_u64().await.unwrap() + } + + let eventfd = EventFd::new().unwrap(); + eventfd.write(0x55).unwrap(); + let ex = URingExecutor::new().unwrap(); + let uring_source = async_uring_from(eventfd, &ex).unwrap(); + let val = ex.run_until(go(uring_source)).unwrap(); + assert_eq!(val, 0x55); + + let eventfd = EventFd::new().unwrap(); + eventfd.write(0xaa).unwrap(); + let poll_ex = FdExecutor::new().unwrap(); + let poll_source = async_poll_from(eventfd, &poll_ex).unwrap(); + let val = poll_ex.run_until(go(poll_source)).unwrap(); + assert_eq!(val, 0xaa); + } + + #[test] + fn fsync() { + if !use_uring() { + return; + } + async fn go(source: Box>) { + let v = vec![0x55u8; 32]; + let v_ptr = v.as_ptr(); + let ret = source.write_from_vec(None, v).await.unwrap(); + assert_eq!(ret.0, 32); + let ret_v = ret.1; + assert_eq!(v_ptr, ret_v.as_ptr()); + source.fsync().await.unwrap(); + } + + let f = tempfile::tempfile().unwrap(); + let ex = Executor::new().unwrap(); + let source = ex.async_from(f).unwrap(); + + ex.run_until(go(source)).unwrap(); + } +} diff --git a/cros_async/src/lib.rs b/cros_async/src/lib.rs new file mode 100644 index 0000000000..b353a2e035 --- /dev/null +++ b/cros_async/src/lib.rs @@ -0,0 +1,540 @@ +// Copyright 2020 The Chromium OS Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +//! An Executor and future combinators based on operations that block on file descriptors. +//! +//! This crate is meant to be used with the `futures-rs` crate that provides further combinators +//! and utility functions to combine and manage futures. All futures will run until they block on a +//! file descriptor becoming readable or writable. Facilities are provided to register future +//! wakers based on such events. +//! +//! # Running top-level futures. +//! +//! Use helper functions based the desired behavior of your application. +//! +//! ## Running one future. +//! +//! If there is only one top-level future to run, use the [`run_one`](fn.run_one.html) function. +//! +//! ## Completing one of several futures. +//! +//! If there are several top level tasks that should run until any one completes, use the "select" +//! family of executor constructors. These return an [`Executor`](trait.Executor.html) whose `run` +//! function will return when the first future completes. The uncompleted futures will also be +//! returned so they can be run further or otherwise cleaned up. 
These functions are inspired by +//! the `select_all` function from futures-rs, but built to be run inside an FD based executor and +//! to poll only when necessary. See the docs for [`select2`](fn.select2.html), +//! [`select3`](fn.select3.html), [`select4`](fn.select4.html), and [`select5`](fn.select5.html). +//! +//! ## Completing all of several futures. +//! +//! If there are several top level tasks that all need to be completed, use the "complete" family +//! of executor constructors. These return an [`Executor`](trait.Executor.html) whose `run` +//! function will return only once all the futures passed to it have completed. These functions are +//! inspired by the `join_all` function from futures-rs, but built to be run inside an FD based +//! executor and to poll only when necessary. See the docs for [`complete2`](fn.complete2.html), +//! [`complete3`](fn.complete3.html), [`complete4`](fn.complete4.html), and +//! [`complete5`](fn.complete5.html). +//! +//! # Implementing new FD-based futures. +//! +//! For URing implementations should provide an implementation of the `IoSource` trait. +//! For the FD executor, new futures can use the existing ability to poll a source to build async +//! functionality on top of. +//! +//! # Implementations +//! +//! Currently there are two paths for using the asynchronous IO. One uses a PollContext and drivers +//! futures based on the FDs signaling they are ready for the opteration. This method will exist so +//! long as kernels < 5.4 are supported. +//! The other method submits operations to io_uring and is signaled when they complete. This is more +//! efficient, but only supported on kernel 5.4+. +//! If `IoSourceExt::new` is used to interface with async IO, then the correct backend will be chosen +//! automatically. +//! +//! # Examples +//! +//! See the docs for `IoSourceExt` if support for kernels <5.4 is required. Focus on `UringSource` if +//! all systems have support for io_uring. + +pub mod audio_streams_async; +mod blocking; +mod complete; +mod event; +mod executor; +mod fd_executor; +mod io_ext; +pub mod mem; +mod poll_source; +mod queue; +mod select; +pub mod sync; +mod timer; +mod uring_executor; +mod uring_source; +mod waker; + +pub use blocking::{block_on, BlockingPool}; +pub use event::EventAsync; +pub use executor::Executor; +pub use fd_executor::FdExecutor; +pub use io_ext::{ + AsyncWrapper, Error as AsyncError, IntoAsync, IoSourceExt, ReadAsync, Result as AsyncResult, + WriteAsync, +}; +pub use mem::{BackingMemory, MemRegion}; +pub use poll_source::PollSource; +pub use select::SelectResult; +pub use sys_util; +pub use timer::TimerAsync; +pub use uring_executor::URingExecutor; +pub use uring_source::UringSource; + +use std::{ + future::Future, + io, + marker::PhantomData, + pin::Pin, + task::{Context, Poll}, +}; + +use remain::sorted; +use thiserror::Error as ThisError; + +#[sorted] +#[derive(ThisError, Debug)] +pub enum Error { + /// Error from the FD executor. + #[error("Failure in the FD executor: {0}")] + FdExecutor(fd_executor::Error), + /// Error from TimerFd. + #[error("Failure in TimerAsync: {0}")] + TimerAsync(AsyncError), + /// Error from TimerFd. + #[error("Failure in TimerFd: {0}")] + TimerFd(sys_util::Error), + /// Error from the uring executor. 
+ #[error("Failure in the uring executor: {0}")] + URingExecutor(uring_executor::Error), +} +pub type Result = std::result::Result; + +impl From for io::Error { + fn from(e: Error) -> Self { + use Error::*; + match e { + FdExecutor(e) => e.into(), + URingExecutor(e) => e.into(), + TimerFd(e) => e.into(), + TimerAsync(e) => e.into(), + } + } +} + +// A Future that never completes. +pub struct Empty { + phantom: PhantomData, +} + +impl Future for Empty { + type Output = T; + + fn poll(self: Pin<&mut Self>, _: &mut Context) -> Poll { + Poll::Pending + } +} + +pub fn empty() -> Empty { + Empty { + phantom: PhantomData, + } +} + +/// Creates an Executor that runs one future to completion. +/// +/// # Example +/// +/// ``` +/// use cros_async::run_one; +/// +/// let fut = async { 55 }; +/// assert_eq!(55, run_one(fut).unwrap()); +/// ``` +pub fn run_one(fut: F) -> Result { + if uring_executor::use_uring() { + run_one_uring(fut) + } else { + run_one_poll(fut) + } +} + +/// Creates a URingExecutor that runs one future to completion. +/// +/// # Example +/// +/// ``` +/// use cros_async::run_one_uring; +/// +/// let fut = async { 55 }; +/// assert_eq!(55, run_one_uring(fut).unwrap()); +/// ``` +pub fn run_one_uring(fut: F) -> Result { + URingExecutor::new() + .and_then(|ex| ex.run_until(fut)) + .map_err(Error::URingExecutor) +} + +/// Creates a FdExecutor that runs one future to completion. +/// +/// # Example +/// +/// ``` +/// use cros_async::run_one_poll; +/// +/// let fut = async { 55 }; +/// assert_eq!(55, run_one_poll(fut).unwrap()); +/// ``` +pub fn run_one_poll(fut: F) -> Result { + FdExecutor::new() + .and_then(|ex| ex.run_until(fut)) + .map_err(Error::FdExecutor) +} + +// Select helpers to run until any future completes. + +/// Creates a combinator that runs the two given futures until one completes, returning a tuple +/// containing the result of the finished future and the still pending future. +/// +/// # Example +/// +/// ``` +/// use cros_async::{SelectResult, select2, run_one}; +/// use futures::future::pending; +/// use futures::pin_mut; +/// +/// let first = async {5}; +/// let second = async {let () = pending().await;}; +/// pin_mut!(first); +/// pin_mut!(second); +/// match run_one(select2(first, second)) { +/// Ok((SelectResult::Finished(5), SelectResult::Pending(_second))) => (), +/// _ => panic!("Select didn't return the first future"), +/// }; +/// ``` +pub async fn select2( + f1: F1, + f2: F2, +) -> (SelectResult, SelectResult) { + select::Select2::new(f1, f2).await +} + +/// Creates a combinator that runs the three given futures until one or more completes, returning a +/// tuple containing the result of the finished future(s) and the still pending future(s). 
+/// +/// # Example +/// +/// ``` +/// use cros_async::{SelectResult, select3, run_one}; +/// use futures::future::pending; +/// use futures::pin_mut; +/// +/// let first = async {4}; +/// let second = async {let () = pending().await;}; +/// let third = async {5}; +/// pin_mut!(first); +/// pin_mut!(second); +/// pin_mut!(third); +/// match run_one(select3(first, second, third)) { +/// Ok((SelectResult::Finished(4), +/// SelectResult::Pending(_second), +/// SelectResult::Finished(5))) => (), +/// _ => panic!("Select didn't return the futures"), +/// }; +/// ``` +pub async fn select3( + f1: F1, + f2: F2, + f3: F3, +) -> (SelectResult, SelectResult, SelectResult) { + select::Select3::new(f1, f2, f3).await +} + +/// Creates a combinator that runs the four given futures until one or more completes, returning a +/// tuple containing the result of the finished future(s) and the still pending future(s). +/// +/// # Example +/// +/// ``` +/// use cros_async::{SelectResult, select4, run_one}; +/// use futures::future::pending; +/// use futures::pin_mut; +/// +/// let first = async {4}; +/// let second = async {let () = pending().await;}; +/// let third = async {5}; +/// let fourth = async {let () = pending().await;}; +/// pin_mut!(first); +/// pin_mut!(second); +/// pin_mut!(third); +/// pin_mut!(fourth); +/// match run_one(select4(first, second, third, fourth)) { +/// Ok((SelectResult::Finished(4), SelectResult::Pending(_second), +/// SelectResult::Finished(5), SelectResult::Pending(_fourth))) => (), +/// _ => panic!("Select didn't return the futures"), +/// }; +/// ``` +pub async fn select4< + F1: Future + Unpin, + F2: Future + Unpin, + F3: Future + Unpin, + F4: Future + Unpin, +>( + f1: F1, + f2: F2, + f3: F3, + f4: F4, +) -> ( + SelectResult, + SelectResult, + SelectResult, + SelectResult, +) { + select::Select4::new(f1, f2, f3, f4).await +} + +/// Creates a combinator that runs the five given futures until one or more completes, returning a +/// tuple containing the result of the finished future(s) and the still pending future(s). +/// +/// # Example +/// +/// ``` +/// use cros_async::{SelectResult, select5, run_one}; +/// use futures::future::pending; +/// use futures::pin_mut; +/// +/// let first = async {4}; +/// let second = async {let () = pending().await;}; +/// let third = async {5}; +/// let fourth = async {let () = pending().await;}; +/// let fifth = async {6}; +/// pin_mut!(first); +/// pin_mut!(second); +/// pin_mut!(third); +/// pin_mut!(fourth); +/// pin_mut!(fifth); +/// match run_one(select5(first, second, third, fourth, fifth)) { +/// Ok((SelectResult::Finished(4), SelectResult::Pending(_second), +/// SelectResult::Finished(5), SelectResult::Pending(_fourth), +/// SelectResult::Finished(6))) => (), +/// _ => panic!("Select didn't return the futures"), +/// }; +/// ``` +pub async fn select5< + F1: Future + Unpin, + F2: Future + Unpin, + F3: Future + Unpin, + F4: Future + Unpin, + F5: Future + Unpin, +>( + f1: F1, + f2: F2, + f3: F3, + f4: F4, + f5: F5, +) -> ( + SelectResult, + SelectResult, + SelectResult, + SelectResult, + SelectResult, +) { + select::Select5::new(f1, f2, f3, f4, f5).await +} + +/// Creates a combinator that runs the six given futures until one or more completes, returning a +/// tuple containing the result of the finished future(s) and the still pending future(s). 
+/// +/// # Example +/// +/// ``` +/// use cros_async::{SelectResult, select6, run_one}; +/// use futures::future::pending; +/// use futures::pin_mut; +/// +/// let first = async {1}; +/// let second = async {let () = pending().await;}; +/// let third = async {3}; +/// let fourth = async {let () = pending().await;}; +/// let fifth = async {5}; +/// let sixth = async {6}; +/// pin_mut!(first); +/// pin_mut!(second); +/// pin_mut!(third); +/// pin_mut!(fourth); +/// pin_mut!(fifth); +/// pin_mut!(sixth); +/// match run_one(select6(first, second, third, fourth, fifth, sixth)) { +/// Ok((SelectResult::Finished(1), SelectResult::Pending(_second), +/// SelectResult::Finished(3), SelectResult::Pending(_fourth), +/// SelectResult::Finished(5), SelectResult::Finished(6))) => (), +/// _ => panic!("Select didn't return the futures"), +/// }; +/// ``` +pub async fn select6< + F1: Future + Unpin, + F2: Future + Unpin, + F3: Future + Unpin, + F4: Future + Unpin, + F5: Future + Unpin, + F6: Future + Unpin, +>( + f1: F1, + f2: F2, + f3: F3, + f4: F4, + f5: F5, + f6: F6, +) -> ( + SelectResult, + SelectResult, + SelectResult, + SelectResult, + SelectResult, + SelectResult, +) { + select::Select6::new(f1, f2, f3, f4, f5, f6).await +} + +pub async fn select7< + F1: Future + Unpin, + F2: Future + Unpin, + F3: Future + Unpin, + F4: Future + Unpin, + F5: Future + Unpin, + F6: Future + Unpin, + F7: Future + Unpin, +>( + f1: F1, + f2: F2, + f3: F3, + f4: F4, + f5: F5, + f6: F6, + f7: F7, +) -> ( + SelectResult, + SelectResult, + SelectResult, + SelectResult, + SelectResult, + SelectResult, + SelectResult, +) { + select::Select7::new(f1, f2, f3, f4, f5, f6, f7).await +} +// Combination helpers to run until all futures are complete. + +/// Creates a combinator that runs the two given futures to completion, returning a tuple of the +/// outputs each yields. +/// +/// # Example +/// +/// ``` +/// use cros_async::{complete2, run_one}; +/// +/// let first = async {5}; +/// let second = async {6}; +/// assert_eq!(run_one(complete2(first, second)).unwrap_or((0,0)), (5,6)); +/// ``` +pub async fn complete2(f1: F1, f2: F2) -> (F1::Output, F2::Output) +where + F1: Future, + F2: Future, +{ + complete::Complete2::new(f1, f2).await +} + +/// Creates a combinator that runs the three given futures to completion, returning a tuple of the +/// outputs each yields. +/// +/// # Example +/// +/// ``` +/// use cros_async::{complete3, run_one}; +/// +/// let first = async {5}; +/// let second = async {6}; +/// let third = async {7}; +/// assert_eq!(run_one(complete3(first, second, third)).unwrap_or((0,0,0)), (5,6,7)); +/// ``` +pub async fn complete3(f1: F1, f2: F2, f3: F3) -> (F1::Output, F2::Output, F3::Output) +where + F1: Future, + F2: Future, + F3: Future, +{ + complete::Complete3::new(f1, f2, f3).await +} + +/// Creates a combinator that runs the four given futures to completion, returning a tuple of the +/// outputs each yields. 
+/// +/// # Example +/// +/// ``` +/// use cros_async::{complete4, run_one}; +/// +/// let first = async {5}; +/// let second = async {6}; +/// let third = async {7}; +/// let fourth = async {8}; +/// assert_eq!(run_one(complete4(first, second, third, fourth)).unwrap_or((0,0,0,0)), (5,6,7,8)); +/// ``` +pub async fn complete4( + f1: F1, + f2: F2, + f3: F3, + f4: F4, +) -> (F1::Output, F2::Output, F3::Output, F4::Output) +where + F1: Future, + F2: Future, + F3: Future, + F4: Future, +{ + complete::Complete4::new(f1, f2, f3, f4).await +} + +/// Creates a combinator that runs the five given futures to completion, returning a tuple of the +/// outputs each yields. +/// +/// # Example +/// +/// ``` +/// use cros_async::{complete5, run_one}; +/// +/// let first = async {5}; +/// let second = async {6}; +/// let third = async {7}; +/// let fourth = async {8}; +/// let fifth = async {9}; +/// assert_eq!(run_one(complete5(first, second, third, fourth, fifth)).unwrap_or((0,0,0,0,0)), +/// (5,6,7,8,9)); +/// ``` +pub async fn complete5( + f1: F1, + f2: F2, + f3: F3, + f4: F4, + f5: F5, +) -> (F1::Output, F2::Output, F3::Output, F4::Output, F5::Output) +where + F1: Future, + F2: Future, + F3: Future, + F4: Future, + F5: Future, +{ + complete::Complete5::new(f1, f2, f3, f4, f5).await +} diff --git a/cros_async/src/mem.rs b/cros_async/src/mem.rs new file mode 100644 index 0000000000..691e629fec --- /dev/null +++ b/cros_async/src/mem.rs @@ -0,0 +1,98 @@ +// Copyright 2020 The Chromium OS Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +use data_model::VolatileSlice; +use remain::sorted; +use thiserror::Error as ThisError; + +#[sorted] +#[derive(ThisError, Debug)] +pub enum Error { + /// Invalid offset or length given for an iovec in backing memory. + #[error("Invalid offset/len for getting a slice from {0} with len {1}.")] + InvalidOffset(u64, usize), +} +pub type Result = std::result::Result; + +/// Used to index subslices of backing memory. Like an iovec, but relative to the start of the +/// memory region instead of an absolute pointer. +/// The backing memory referenced by the region can be an array, an mmapped file, or guest memory. +/// The offset is a u64 to allow having file or guest offsets >4GB when run on a 32bit host. +#[derive(Copy, Clone, Debug)] +pub struct MemRegion { + pub offset: u64, + pub len: usize, +} + +/// Trait for memory that can yeild both iovecs in to the backing memory. +/// Must be OK to modify the backing memory without owning a mut able reference. For example, +/// this is safe for GuestMemory and VolatileSlices in crosvm as those types guarantee they are +/// dealt with as volatile. +pub unsafe trait BackingMemory { + /// Returns VolatileSlice pointing to the backing memory. This is most commonly unsafe. + /// To implement this safely the implementor must guarantee that the backing memory can be + /// modified out of band without affecting safety guarantees. + fn get_volatile_slice(&self, mem_range: MemRegion) -> Result; +} + +/// Wrapper to be used for passing a Vec in as backing memory for asynchronous operations. The +/// wrapper owns a Vec according to the borrow checker. It is loaning this vec out to the kernel(or +/// other modifiers) through the `BackingMemory` trait. This allows multiple modifiers of the array +/// in the `Vec` while this struct is alive. 
The data in the Vec is loaned to the kernel not the +/// data structure itself, the length, capacity, and pointer to memory cannot be modified. +/// To ensure that those operations can be done safely, no access is allowed to the `Vec`'s memory +/// starting at the time that `VecIoWrapper` is constructed until the time it is turned back in to a +/// `Vec` using `to_inner`. The returned `Vec` is guaranteed to be valid as any combination of bits +/// in a `Vec` of `u8` is valid. +pub(crate) struct VecIoWrapper { + inner: Box<[u8]>, +} + +impl From> for VecIoWrapper { + fn from(vec: Vec) -> Self { + VecIoWrapper { inner: vec.into() } + } +} + +impl From for Vec { + fn from(v: VecIoWrapper) -> Vec { + v.inner.into() + } +} + +impl VecIoWrapper { + /// Get the length of the Vec that is wrapped. + pub fn len(&self) -> usize { + self.inner.len() + } + + // Check that the offsets are all valid in the backing vec. + fn check_addrs(&self, mem_range: &MemRegion) -> Result<()> { + let end = mem_range + .offset + .checked_add(mem_range.len as u64) + .ok_or(Error::InvalidOffset(mem_range.offset, mem_range.len))?; + if end > self.inner.len() as u64 { + return Err(Error::InvalidOffset(mem_range.offset, mem_range.len)); + } + Ok(()) + } +} + +// Safe to implement BackingMemory as the vec is only accessible inside the wrapper and these iovecs +// are the only thing allowed to modify it. Nothing else can get a reference to the vec until all +// iovecs are dropped because they borrow Self. Nothing can borrow the owned inner vec until self +// is consumed by `into`, which can't happen if there are outstanding mut borrows. +unsafe impl BackingMemory for VecIoWrapper { + fn get_volatile_slice(&self, mem_range: MemRegion) -> Result> { + self.check_addrs(&mem_range)?; + // Safe because the mem_range range is valid in the backing memory as checked above. + unsafe { + Ok(VolatileSlice::from_raw_parts( + self.inner.as_ptr().add(mem_range.offset as usize) as *mut _, + mem_range.len, + )) + } + } +} diff --git a/cros_async/src/poll_source.rs b/cros_async/src/poll_source.rs new file mode 100644 index 0000000000..4a904ba790 --- /dev/null +++ b/cros_async/src/poll_source.rs @@ -0,0 +1,450 @@ +// Copyright 2020 The Chromium OS Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +//! A wrapped IO source that uses FdExecutor to drive asynchronous completion. Used from +//! `IoSourceExt::new` when uring isn't available in the kernel. + +use std::{ + io, + ops::{Deref, DerefMut}, + os::unix::io::AsRawFd, + sync::Arc, +}; + +use async_trait::async_trait; +use data_model::VolatileSlice; +use remain::sorted; +use thiserror::Error as ThisError; + +use super::{ + fd_executor::{ + FdExecutor, RegisteredSource, {self}, + }, + mem::{BackingMemory, MemRegion}, + AsyncError, AsyncResult, IoSourceExt, ReadAsync, WriteAsync, +}; + +#[sorted] +#[derive(ThisError, Debug)] +pub enum Error { + /// An error occurred attempting to register a waker with the executor. + #[error("An error occurred attempting to register a waker with the executor: {0}.")] + AddingWaker(fd_executor::Error), + /// An executor error occurred. + #[error("An executor error occurred: {0}")] + Executor(fd_executor::Error), + /// An error occurred when executing fallocate synchronously. + #[error("An error occurred when executing fallocate synchronously: {0}")] + Fallocate(sys_util::Error), + /// An error occurred when executing fsync synchronously. 
+ #[error("An error occurred when executing fsync synchronously: {0}")] + Fsync(sys_util::Error), + /// An error occurred when reading the FD. + #[error("An error occurred when reading the FD: {0}.")] + Read(sys_util::Error), + /// Can't seek file. + #[error("An error occurred when seeking the FD: {0}.")] + Seeking(sys_util::Error), + /// An error occurred when writing the FD. + #[error("An error occurred when writing the FD: {0}.")] + Write(sys_util::Error), +} +pub type Result = std::result::Result; + +impl From for io::Error { + fn from(e: Error) -> Self { + use Error::*; + match e { + AddingWaker(e) => e.into(), + Executor(e) => e.into(), + Fallocate(e) => e.into(), + Fsync(e) => e.into(), + Read(e) => e.into(), + Seeking(e) => e.into(), + Write(e) => e.into(), + } + } +} + +/// Async wrapper for an IO source that uses the FD executor to drive async operations. +/// Used by `IoSourceExt::new` when uring isn't available. +pub struct PollSource(RegisteredSource); + +impl PollSource { + /// Create a new `PollSource` from the given IO source. + pub fn new(f: F, ex: &FdExecutor) -> Result { + ex.register_source(f) + .map(PollSource) + .map_err(Error::Executor) + } + + /// Return the inner source. + pub fn into_source(self) -> F { + self.0.into_source() + } +} + +impl Deref for PollSource { + type Target = F; + + fn deref(&self) -> &Self::Target { + self.0.as_ref() + } +} + +impl DerefMut for PollSource { + fn deref_mut(&mut self) -> &mut Self::Target { + self.0.as_mut() + } +} + +#[async_trait(?Send)] +impl ReadAsync for PollSource { + /// Reads from the iosource at `file_offset` and fill the given `vec`. + async fn read_to_vec<'a>( + &'a self, + file_offset: Option, + mut vec: Vec, + ) -> AsyncResult<(usize, Vec)> { + loop { + // Safe because this will only modify `vec` and we check the return value. + let res = if let Some(offset) = file_offset { + unsafe { + libc::pread64( + self.as_raw_fd(), + vec.as_mut_ptr() as *mut libc::c_void, + vec.len(), + offset as libc::off64_t, + ) + } + } else { + unsafe { + libc::read( + self.as_raw_fd(), + vec.as_mut_ptr() as *mut libc::c_void, + vec.len(), + ) + } + }; + + if res >= 0 { + return Ok((res as usize, vec)); + } + + match sys_util::Error::last() { + e if e.errno() == libc::EWOULDBLOCK => { + let op = self.0.wait_readable().map_err(Error::AddingWaker)?; + op.await.map_err(Error::Executor)?; + } + e => return Err(Error::Read(e).into()), + } + } + } + + /// Reads to the given `mem` at the given offsets from the file starting at `file_offset`. + async fn read_to_mem<'a>( + &'a self, + file_offset: Option, + mem: Arc, + mem_offsets: &'a [MemRegion], + ) -> AsyncResult { + let mut iovecs = mem_offsets + .iter() + .filter_map(|&mem_vec| mem.get_volatile_slice(mem_vec).ok()) + .collect::>(); + + loop { + // Safe because we trust the kernel not to write path the length given and the length is + // guaranteed to be valid from the pointer by io_slice_mut. 
+ let res = if let Some(offset) = file_offset { + unsafe { + libc::preadv64( + self.as_raw_fd(), + iovecs.as_mut_ptr() as *mut _, + iovecs.len() as i32, + offset as libc::off64_t, + ) + } + } else { + unsafe { + libc::readv( + self.as_raw_fd(), + iovecs.as_mut_ptr() as *mut _, + iovecs.len() as i32, + ) + } + }; + + if res >= 0 { + return Ok(res as usize); + } + + match sys_util::Error::last() { + e if e.errno() == libc::EWOULDBLOCK => { + let op = self.0.wait_readable().map_err(Error::AddingWaker)?; + op.await.map_err(Error::Executor)?; + } + e => return Err(Error::Read(e).into()), + } + } + } + + /// Wait for the FD of `self` to be readable. + async fn wait_readable(&self) -> AsyncResult<()> { + let op = self.0.wait_readable().map_err(Error::AddingWaker)?; + op.await.map_err(Error::Executor)?; + Ok(()) + } + + async fn read_u64(&self) -> AsyncResult { + let mut buf = 0u64.to_ne_bytes(); + loop { + // Safe because this will only modify `buf` and we check the return value. + let res = unsafe { + libc::read( + self.as_raw_fd(), + buf.as_mut_ptr() as *mut libc::c_void, + buf.len(), + ) + }; + + if res >= 0 { + return Ok(u64::from_ne_bytes(buf)); + } + + match sys_util::Error::last() { + e if e.errno() == libc::EWOULDBLOCK => { + let op = self.0.wait_readable().map_err(Error::AddingWaker)?; + op.await.map_err(Error::Executor)?; + } + e => return Err(Error::Read(e).into()), + } + } + } +} + +#[async_trait(?Send)] +impl WriteAsync for PollSource { + /// Writes from the given `vec` to the file starting at `file_offset`. + async fn write_from_vec<'a>( + &'a self, + file_offset: Option, + vec: Vec, + ) -> AsyncResult<(usize, Vec)> { + loop { + // Safe because this will not modify any memory and we check the return value. + let res = if let Some(offset) = file_offset { + unsafe { + libc::pwrite64( + self.as_raw_fd(), + vec.as_ptr() as *const libc::c_void, + vec.len(), + offset as libc::off64_t, + ) + } + } else { + unsafe { + libc::write( + self.as_raw_fd(), + vec.as_ptr() as *const libc::c_void, + vec.len(), + ) + } + }; + + if res >= 0 { + return Ok((res as usize, vec)); + } + + match sys_util::Error::last() { + e if e.errno() == libc::EWOULDBLOCK => { + let op = self.0.wait_writable().map_err(Error::AddingWaker)?; + op.await.map_err(Error::Executor)?; + } + e => return Err(Error::Write(e).into()), + } + } + } + + /// Writes from the given `mem` from the given offsets to the file starting at `file_offset`. + async fn write_from_mem<'a>( + &'a self, + file_offset: Option, + mem: Arc, + mem_offsets: &'a [MemRegion], + ) -> AsyncResult { + let iovecs = mem_offsets + .iter() + .map(|&mem_vec| mem.get_volatile_slice(mem_vec)) + .filter_map(|r| r.ok()) + .collect::>(); + + loop { + // Safe because we trust the kernel not to write path the length given and the length is + // guaranteed to be valid from the pointer by io_slice_mut. + let res = if let Some(offset) = file_offset { + unsafe { + libc::pwritev64( + self.as_raw_fd(), + iovecs.as_ptr() as *mut _, + iovecs.len() as i32, + offset as libc::off64_t, + ) + } + } else { + unsafe { + libc::writev( + self.as_raw_fd(), + iovecs.as_ptr() as *mut _, + iovecs.len() as i32, + ) + } + }; + + if res >= 0 { + return Ok(res as usize); + } + + match sys_util::Error::last() { + e if e.errno() == libc::EWOULDBLOCK => { + let op = self.0.wait_writable().map_err(Error::AddingWaker)?; + op.await.map_err(Error::Executor)?; + } + e => return Err(Error::Write(e).into()), + } + } + } + + /// See `fallocate(2)` for details. 
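+    /// In this poll-based implementation the request is issued directly via `libc::fallocate64`,
+    /// so it completes synchronously instead of yielding to the executor.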
+ async fn fallocate(&self, file_offset: u64, len: u64, mode: u32) -> AsyncResult<()> { + let ret = unsafe { + libc::fallocate64( + self.as_raw_fd(), + mode as libc::c_int, + file_offset as libc::off64_t, + len as libc::off64_t, + ) + }; + if ret == 0 { + Ok(()) + } else { + Err(AsyncError::Poll(Error::Fallocate(sys_util::Error::last()))) + } + } + + /// Sync all completed write operations to the backing storage. + async fn fsync(&self) -> AsyncResult<()> { + let ret = unsafe { libc::fsync(self.as_raw_fd()) }; + if ret == 0 { + Ok(()) + } else { + Err(AsyncError::Poll(Error::Fsync(sys_util::Error::last()))) + } + } +} + +#[async_trait(?Send)] +impl IoSourceExt for PollSource { + /// Yields the underlying IO source. + fn into_source(self: Box) -> F { + self.0.into_source() + } + + /// Provides a mutable ref to the underlying IO source. + fn as_source_mut(&mut self) -> &mut F { + self + } + + /// Provides a ref to the underlying IO source. + fn as_source(&self) -> &F { + self + } +} + +#[cfg(test)] +mod tests { + use std::{ + fs::{File, OpenOptions}, + path::PathBuf, + }; + + use super::*; + + #[test] + fn readvec() { + async fn go(ex: &FdExecutor) { + let f = File::open("/dev/zero").unwrap(); + let async_source = PollSource::new(f, ex).unwrap(); + let v = vec![0x55u8; 32]; + let v_ptr = v.as_ptr(); + let ret = async_source.read_to_vec(None, v).await.unwrap(); + assert_eq!(ret.0, 32); + let ret_v = ret.1; + assert_eq!(v_ptr, ret_v.as_ptr()); + assert!(ret_v.iter().all(|&b| b == 0)); + } + + let ex = FdExecutor::new().unwrap(); + ex.run_until(go(&ex)).unwrap(); + } + + #[test] + fn writevec() { + async fn go(ex: &FdExecutor) { + let f = OpenOptions::new().write(true).open("/dev/null").unwrap(); + let async_source = PollSource::new(f, ex).unwrap(); + let v = vec![0x55u8; 32]; + let v_ptr = v.as_ptr(); + let ret = async_source.write_from_vec(None, v).await.unwrap(); + assert_eq!(ret.0, 32); + let ret_v = ret.1; + assert_eq!(v_ptr, ret_v.as_ptr()); + } + + let ex = FdExecutor::new().unwrap(); + ex.run_until(go(&ex)).unwrap(); + } + + #[test] + fn fallocate() { + async fn go(ex: &FdExecutor) { + let dir = tempfile::TempDir::new().unwrap(); + let mut file_path = PathBuf::from(dir.path()); + file_path.push("test"); + + let f = OpenOptions::new() + .create(true) + .write(true) + .open(&file_path) + .unwrap(); + let source = PollSource::new(f, ex).unwrap(); + source.fallocate(0, 4096, 0).await.unwrap(); + + let meta_data = std::fs::metadata(&file_path).unwrap(); + assert_eq!(meta_data.len(), 4096); + } + + let ex = FdExecutor::new().unwrap(); + ex.run_until(go(&ex)).unwrap(); + } + + #[test] + fn memory_leak() { + // This test needs to run under ASAN to detect memory leaks. + + async fn owns_poll_source(source: PollSource) { + let _ = source.wait_readable().await; + } + + let (rx, _tx) = sys_util::pipe(true).unwrap(); + let ex = FdExecutor::new().unwrap(); + let source = PollSource::new(rx, &ex).unwrap(); + ex.spawn_local(owns_poll_source(source)).detach(); + + // Drop `ex` without running. This would cause a memory leak if PollSource owned a strong + // reference to the executor because it owns a reference to the future that owns PollSource + // (via its Runnable). The strong reference prevents the drop impl from running, which would + // otherwise poll the future and have it return with an error. 
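+        // `ex` is dropped implicitly here, at the end of the test, without ever having been run.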
+ } +} diff --git a/cros_async/src/queue.rs b/cros_async/src/queue.rs new file mode 100644 index 0000000000..f95de06290 --- /dev/null +++ b/cros_async/src/queue.rs @@ -0,0 +1,66 @@ +// Copyright 2020 The Chromium OS Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +use std::collections::VecDeque; + +use async_task::Runnable; +use sync::Mutex; + +/// A queue of `Runnables`. Intended to be used by executors to keep track of futures that have been +/// scheduled to run. +pub struct RunnableQueue { + runnables: Mutex>, +} + +impl RunnableQueue { + /// Create a new, empty `RunnableQueue`. + pub fn new() -> RunnableQueue { + RunnableQueue { + runnables: Mutex::new(VecDeque::new()), + } + } + + /// Schedule `runnable` to run in the future by adding it to this `RunnableQueue`. + pub fn push_back(&self, runnable: Runnable) { + self.runnables.lock().push_back(runnable); + } + + /// Remove and return the first `Runnable` in this `RunnableQueue` or `None` if it is empty. + pub fn pop_front(&self) -> Option { + self.runnables.lock().pop_front() + } + + /// Create an iterator over this `RunnableQueue` that repeatedly calls `pop_front()` until it is + /// empty. + pub fn iter(&self) -> RunnableQueueIter { + self.into_iter() + } +} + +impl Default for RunnableQueue { + fn default() -> Self { + Self::new() + } +} + +impl<'q> IntoIterator for &'q RunnableQueue { + type Item = Runnable; + type IntoIter = RunnableQueueIter<'q>; + + fn into_iter(self) -> Self::IntoIter { + RunnableQueueIter { queue: self } + } +} + +/// An iterator over a `RunnableQueue`. +pub struct RunnableQueueIter<'q> { + queue: &'q RunnableQueue, +} + +impl<'q> Iterator for RunnableQueueIter<'q> { + type Item = Runnable; + fn next(&mut self) -> Option { + self.queue.pop_front() + } +} diff --git a/cros_async/src/select.rs b/cros_async/src/select.rs new file mode 100644 index 0000000000..acf83e7683 --- /dev/null +++ b/cros_async/src/select.rs @@ -0,0 +1,92 @@ +// Copyright 2020 The Chromium OS Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +// Need non-snake case so the macro can re-use type names for variables. +#![allow(non_snake_case)] + +use std::{ + future::Future, + pin::Pin, + task::{Context, Poll}, +}; + +use futures::future::{maybe_done, FutureExt, MaybeDone}; + +pub enum SelectResult { + Pending(F), + Finished(F::Output), +} + +// Macro-generate future combinators to allow for running different numbers of top-level futures in +// this FutureList. Generates the implementation of `FutureList` for the select types. For an +// explicit example this is modeled after, see `UnitFutures`. +macro_rules! generate { + ($( + $(#[$doc:meta])* + ($Select:ident, <$($Fut:ident),*>), + )*) => ($( + + paste::item! { + pub(crate) struct $Select<$($Fut: Future + Unpin),*> { + $($Fut: MaybeDone<$Fut>,)* + } + } + + impl<$($Fut: Future + Unpin),*> $Select<$($Fut),*> { + paste::item! { + pub(crate) fn new($($Fut: $Fut),*) -> $Select<$($Fut),*> { + $Select { + $($Fut: maybe_done($Fut),)* + } + } + } + } + + impl<$($Fut: Future + Unpin),*> Future for $Select<$($Fut),*> { + type Output = ($(SelectResult<$Fut>),*); + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll { + let mut complete = false; + $( + let $Fut = Pin::new(&mut self.$Fut); + // The future impls `Unpin`, use `poll_unpin` to avoid wrapping it in + // `Pin` to call `poll`. 
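+                    // `maybe_done` stores each finished output in place, so the Ready branch
+                    // below can hand back either the finished value or the still-pending future.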
+ complete |= self.$Fut.poll_unpin(cx).is_ready(); + )* + + if complete { + Poll::Ready(($( + match std::mem::replace(&mut self.$Fut, MaybeDone::Gone) { + MaybeDone::Future(f) => SelectResult::Pending(f), + MaybeDone::Done(o) => SelectResult::Finished(o), + MaybeDone::Gone => unreachable!(), + } + ), *)) + } else { + Poll::Pending + } + } + } + )*) +} + +generate! { + /// _Future for the [`select2`] function. + (Select2, <_Fut1, _Fut2>), + + /// _Future for the [`select3`] function. + (Select3, <_Fut1, _Fut2, _Fut3>), + + /// _Future for the [`select4`] function. + (Select4, <_Fut1, _Fut2, _Fut3, _Fut4>), + + /// _Future for the [`select5`] function. + (Select5, <_Fut1, _Fut2, _Fut3, _Fut4, _Fut5>), + + /// _Future for the [`select6`] function. + (Select6, <_Fut1, _Fut2, _Fut3, _Fut4, _Fut5, _Fut6>), + + /// _Future for the [`select7`] function. + (Select7, <_Fut1, _Fut2, _Fut3, _Fut4, _Fut5, _Fut6, _Fut7>), +} diff --git a/cros_async/src/sync.rs b/cros_async/src/sync.rs new file mode 100644 index 0000000000..e2c1a1452f --- /dev/null +++ b/cros_async/src/sync.rs @@ -0,0 +1,12 @@ +// Copyright 2020 The Chromium OS Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +mod cv; +mod mu; +mod spin; +mod waiter; + +pub use cv::Condvar; +pub use mu::Mutex; +pub use spin::SpinLock; diff --git a/cros_async/src/sync/cv.rs b/cros_async/src/sync/cv.rs new file mode 100644 index 0000000000..46b96dd188 --- /dev/null +++ b/cros_async/src/sync/cv.rs @@ -0,0 +1,1179 @@ +// Copyright 2020 The Chromium OS Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +use std::{ + cell::UnsafeCell, + hint, mem, + sync::{ + atomic::{AtomicUsize, Ordering}, + Arc, + }, +}; + +use super::super::sync::{ + mu::{MutexGuard, MutexReadGuard, RawMutex}, + waiter::{Kind as WaiterKind, Waiter, WaiterAdapter, WaiterList, WaitingFor}, +}; + +const SPINLOCK: usize = 1 << 0; +const HAS_WAITERS: usize = 1 << 1; + +/// A primitive to wait for an event to occur without consuming CPU time. +/// +/// Condition variables are used in combination with a `Mutex` when a thread wants to wait for some +/// condition to become true. The condition must always be verified while holding the `Mutex` lock. +/// It is an error to use a `Condvar` with more than one `Mutex` while there are threads waiting on +/// the `Condvar`. +/// +/// # Examples +/// +/// ```edition2018 +/// use std::sync::Arc; +/// use std::thread; +/// use std::sync::mpsc::channel; +/// +/// use cros_async::{ +/// block_on, +/// sync::{Condvar, Mutex}, +/// }; +/// +/// const N: usize = 13; +/// +/// // Spawn a few threads to increment a shared variable (non-atomically), and +/// // let all threads waiting on the Condvar know once the increments are done. +/// let data = Arc::new(Mutex::new(0)); +/// let cv = Arc::new(Condvar::new()); +/// +/// for _ in 0..N { +/// let (data, cv) = (data.clone(), cv.clone()); +/// thread::spawn(move || { +/// let mut data = block_on(data.lock()); +/// *data += 1; +/// if *data == N { +/// cv.notify_all(); +/// } +/// }); +/// } +/// +/// let mut val = block_on(data.lock()); +/// while *val != N { +/// val = block_on(cv.wait(val)); +/// } +/// ``` +#[repr(align(128))] +pub struct Condvar { + state: AtomicUsize, + waiters: UnsafeCell, + mu: UnsafeCell, +} + +impl Condvar { + /// Creates a new condition variable ready to be waited on and notified. 
+ pub fn new() -> Condvar { + Condvar { + state: AtomicUsize::new(0), + waiters: UnsafeCell::new(WaiterList::new(WaiterAdapter::new())), + mu: UnsafeCell::new(0), + } + } + + /// Block the current thread until this `Condvar` is notified by another thread. + /// + /// This method will atomically unlock the `Mutex` held by `guard` and then block the current + /// thread. Any call to `notify_one` or `notify_all` after the `Mutex` is unlocked may wake up + /// the thread. + /// + /// To allow for more efficient scheduling, this call may return even when the programmer + /// doesn't expect the thread to be woken. Therefore, calls to `wait()` should be used inside a + /// loop that checks the predicate before continuing. + /// + /// Callers that are not in an async context may wish to use the `block_on` method to block the + /// thread until the `Condvar` is notified. + /// + /// # Panics + /// + /// This method will panic if used with more than one `Mutex` at the same time. + /// + /// # Examples + /// + /// ``` + /// # use std::sync::Arc; + /// # use std::thread; + /// + /// # use cros_async::{ + /// # block_on, + /// # sync::{Condvar, Mutex}, + /// # }; + /// + /// # let mu = Arc::new(Mutex::new(false)); + /// # let cv = Arc::new(Condvar::new()); + /// # let (mu2, cv2) = (mu.clone(), cv.clone()); + /// + /// # let t = thread::spawn(move || { + /// # *block_on(mu2.lock()) = true; + /// # cv2.notify_all(); + /// # }); + /// + /// let mut ready = block_on(mu.lock()); + /// while !*ready { + /// ready = block_on(cv.wait(ready)); + /// } + /// + /// # t.join().expect("failed to join thread"); + /// ``` + // Clippy doesn't like the lifetime parameters here but doing what it suggests leads to code + // that doesn't compile. + #[allow(clippy::needless_lifetimes)] + pub async fn wait<'g, T>(&self, guard: MutexGuard<'g, T>) -> MutexGuard<'g, T> { + let waiter = Arc::new(Waiter::new( + WaiterKind::Exclusive, + cancel_waiter, + self as *const Condvar as usize, + WaitingFor::Condvar, + )); + + self.add_waiter(waiter.clone(), guard.as_raw_mutex()); + + // Get a reference to the mutex and then drop the lock. + let mu = guard.into_inner(); + + // Wait to be woken up. + waiter.wait().await; + + // Now re-acquire the lock. + mu.lock_from_cv().await + } + + /// Like `wait()` but takes and returns a `MutexReadGuard` instead. + // Clippy doesn't like the lifetime parameters here but doing what it suggests leads to code + // that doesn't compile. + #[allow(clippy::needless_lifetimes)] + pub async fn wait_read<'g, T>(&self, guard: MutexReadGuard<'g, T>) -> MutexReadGuard<'g, T> { + let waiter = Arc::new(Waiter::new( + WaiterKind::Shared, + cancel_waiter, + self as *const Condvar as usize, + WaitingFor::Condvar, + )); + + self.add_waiter(waiter.clone(), guard.as_raw_mutex()); + + // Get a reference to the mutex and then drop the lock. + let mu = guard.into_inner(); + + // Wait to be woken up. + waiter.wait().await; + + // Now re-acquire the lock. + mu.read_lock_from_cv().await + } + + fn add_waiter(&self, waiter: Arc, raw_mutex: &RawMutex) { + // Acquire the spin lock. + let mut oldstate = self.state.load(Ordering::Relaxed); + while (oldstate & SPINLOCK) != 0 + || self + .state + .compare_exchange_weak( + oldstate, + oldstate | SPINLOCK | HAS_WAITERS, + Ordering::Acquire, + Ordering::Relaxed, + ) + .is_err() + { + hint::spin_loop(); + oldstate = self.state.load(Ordering::Relaxed); + } + + // Safe because the spin lock guarantees exclusive access and the reference does not escape + // this function. 
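+        // `self.mu` records the address of the RawMutex this Condvar is currently associated
+        // with; 0 means it is not yet associated with any mutex.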
+ let mu = unsafe { &mut *self.mu.get() }; + let muptr = raw_mutex as *const RawMutex as usize; + + match *mu { + 0 => *mu = muptr, + p if p == muptr => {} + _ => panic!("Attempting to use Condvar with more than one Mutex at the same time"), + } + + // Safe because the spin lock guarantees exclusive access. + unsafe { (*self.waiters.get()).push_back(waiter) }; + + // Release the spin lock. Use a direct store here because no other thread can modify + // `self.state` while we hold the spin lock. Keep the `HAS_WAITERS` bit that we set earlier + // because we just added a waiter. + self.state.store(HAS_WAITERS, Ordering::Release); + } + + /// Notify at most one thread currently waiting on the `Condvar`. + /// + /// If there is a thread currently waiting on the `Condvar` it will be woken up from its call to + /// `wait`. + /// + /// Unlike more traditional condition variable interfaces, this method requires a reference to + /// the `Mutex` associated with this `Condvar`. This is because it is inherently racy to call + /// `notify_one` or `notify_all` without first acquiring the `Mutex` lock. Additionally, taking + /// a reference to the `Mutex` here allows us to make some optimizations that can improve + /// performance by reducing unnecessary wakeups. + pub fn notify_one(&self) { + let mut oldstate = self.state.load(Ordering::Relaxed); + if (oldstate & HAS_WAITERS) == 0 { + // No waiters. + return; + } + + while (oldstate & SPINLOCK) != 0 + || self + .state + .compare_exchange_weak( + oldstate, + oldstate | SPINLOCK, + Ordering::Acquire, + Ordering::Relaxed, + ) + .is_err() + { + hint::spin_loop(); + oldstate = self.state.load(Ordering::Relaxed); + } + + // Safe because the spin lock guarantees exclusive access and the reference does not escape + // this function. + let waiters = unsafe { &mut *self.waiters.get() }; + let wake_list = get_wake_list(waiters); + + let newstate = if waiters.is_empty() { + // Also clear the mutex associated with this Condvar since there are no longer any + // waiters. Safe because the spin lock guarantees exclusive access. + unsafe { *self.mu.get() = 0 }; + + // We are releasing the spin lock and there are no more waiters so we can clear all bits + // in `self.state`. + 0 + } else { + // There are still waiters so we need to keep the HAS_WAITERS bit in the state. + HAS_WAITERS + }; + + // Release the spin lock. + self.state.store(newstate, Ordering::Release); + + // Now wake any waiters in the wake list. + for w in wake_list { + w.wake(); + } + } + + /// Notify all threads currently waiting on the `Condvar`. + /// + /// All threads currently waiting on the `Condvar` will be woken up from their call to `wait`. + /// + /// Unlike more traditional condition variable interfaces, this method requires a reference to + /// the `Mutex` associated with this `Condvar`. This is because it is inherently racy to call + /// `notify_one` or `notify_all` without first acquiring the `Mutex` lock. Additionally, taking + /// a reference to the `Mutex` here allows us to make some optimizations that can improve + /// performance by reducing unnecessary wakeups. + pub fn notify_all(&self) { + let mut oldstate = self.state.load(Ordering::Relaxed); + if (oldstate & HAS_WAITERS) == 0 { + // No waiters. 
+ return; + } + + while (oldstate & SPINLOCK) != 0 + || self + .state + .compare_exchange_weak( + oldstate, + oldstate | SPINLOCK, + Ordering::Acquire, + Ordering::Relaxed, + ) + .is_err() + { + hint::spin_loop(); + oldstate = self.state.load(Ordering::Relaxed); + } + + // Safe because the spin lock guarantees exclusive access to `self.waiters`. + let wake_list = unsafe { (*self.waiters.get()).take() }; + + // Clear the mutex associated with this Condvar since there are no longer any waiters. Safe + // because we the spin lock guarantees exclusive access. + unsafe { *self.mu.get() = 0 }; + + // Mark any waiters left as no longer waiting for the Condvar. + for w in &wake_list { + w.set_waiting_for(WaitingFor::None); + } + + // Release the spin lock. We can clear all bits in the state since we took all the waiters. + self.state.store(0, Ordering::Release); + + // Now wake any waiters in the wake list. + for w in wake_list { + w.wake(); + } + } + + fn cancel_waiter(&self, waiter: &Waiter, wake_next: bool) { + let mut oldstate = self.state.load(Ordering::Relaxed); + while oldstate & SPINLOCK != 0 + || self + .state + .compare_exchange_weak( + oldstate, + oldstate | SPINLOCK, + Ordering::Acquire, + Ordering::Relaxed, + ) + .is_err() + { + hint::spin_loop(); + oldstate = self.state.load(Ordering::Relaxed); + } + + // Safe because the spin lock provides exclusive access and the reference does not escape + // this function. + let waiters = unsafe { &mut *self.waiters.get() }; + + let waiting_for = waiter.is_waiting_for(); + // Don't drop the old waiter now as we're still holding the spin lock. + let old_waiter = if waiter.is_linked() && waiting_for == WaitingFor::Condvar { + // Safe because we know that the waiter is still linked and is waiting for the Condvar, + // which guarantees that it is still in `self.waiters`. + let mut cursor = unsafe { waiters.cursor_mut_from_ptr(waiter as *const Waiter) }; + cursor.remove() + } else { + None + }; + + let wake_list = if wake_next || waiting_for == WaitingFor::None { + // Either the waiter was already woken or it's been removed from the condvar's waiter + // list and is going to be woken. Either way, we need to wake up another thread. + get_wake_list(waiters) + } else { + WaiterList::new(WaiterAdapter::new()) + }; + + let set_on_release = if waiters.is_empty() { + // Clear the mutex associated with this Condvar since there are no longer any waiters. Safe + // because we the spin lock guarantees exclusive access. + unsafe { *self.mu.get() = 0 }; + + 0 + } else { + HAS_WAITERS + }; + + self.state.store(set_on_release, Ordering::Release); + + // Now wake any waiters still left in the wake list. + for w in wake_list { + w.wake(); + } + + mem::drop(old_waiter); + } +} + +unsafe impl Send for Condvar {} +unsafe impl Sync for Condvar {} + +impl Default for Condvar { + fn default() -> Self { + Self::new() + } +} + +// Scan `waiters` and return all waiters that should be woken up. +// +// If the first waiter is trying to acquire a shared lock, then all waiters in the list that are +// waiting for a shared lock are also woken up. In addition one writer is woken up, if possible. +// +// If the first waiter is trying to acquire an exclusive lock, then only that waiter is returned and +// the rest of the list is not scanned. 
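A hedged sketch of the observable effect of this wake-up policy (a hypothetical standalone test, not part of the patch; it mirrors the `wake_all_readers` unit test further down): when every queued waiter is a reader, a single `notify_one` releases all of them.

```rust
use std::sync::Arc;
use std::thread;

use cros_async::{
    block_on,
    sync::{Condvar, Mutex},
};

fn main() {
    let mu = Arc::new(Mutex::new(false));
    let cv = Arc::new(Condvar::new());

    let readers: Vec<_> = (0..4)
        .map(|_| {
            let (mu, cv) = (mu.clone(), cv.clone());
            thread::spawn(move || {
                let mut ready = block_on(mu.read_lock());
                while !*ready {
                    ready = block_on(cv.wait_read(ready));
                }
            })
        })
        .collect();

    *block_on(mu.lock()) = true;
    // One notification is enough: the wait list holds only shared waiters, so
    // the wake list built by `get_wake_list` contains every one of them.
    cv.notify_one();

    for r in readers {
        r.join().expect("reader panicked");
    }
}
```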
+fn get_wake_list(waiters: &mut WaiterList) -> WaiterList { + let mut to_wake = WaiterList::new(WaiterAdapter::new()); + let mut cursor = waiters.front_mut(); + + let mut waking_readers = false; + let mut all_readers = true; + while let Some(w) = cursor.get() { + match w.kind() { + WaiterKind::Exclusive if !waking_readers => { + // This is the first waiter and it's a writer. No need to check the other waiters. + // Also mark the waiter as having been removed from the Condvar's waiter list. + let waiter = cursor.remove().unwrap(); + waiter.set_waiting_for(WaitingFor::None); + to_wake.push_back(waiter); + break; + } + + WaiterKind::Shared => { + // This is a reader and the first waiter in the list was not a writer so wake up all + // the readers in the wait list. + let waiter = cursor.remove().unwrap(); + waiter.set_waiting_for(WaitingFor::None); + to_wake.push_back(waiter); + waking_readers = true; + } + + WaiterKind::Exclusive => { + debug_assert!(waking_readers); + if all_readers { + // We are waking readers but we need to ensure that at least one writer is woken + // up. Since we haven't yet woken up a writer, wake up this one. + let waiter = cursor.remove().unwrap(); + waiter.set_waiting_for(WaitingFor::None); + to_wake.push_back(waiter); + all_readers = false; + } else { + // We are waking readers and have already woken one writer. Skip this one. + cursor.move_next(); + } + } + } + } + + to_wake +} + +fn cancel_waiter(cv: usize, waiter: &Waiter, wake_next: bool) { + let condvar = cv as *const Condvar; + + // Safe because the thread that owns the waiter being canceled must also own a reference to the + // Condvar, which guarantees that this pointer is valid. + unsafe { (*condvar).cancel_waiter(waiter, wake_next) } +} + +#[cfg(test)] +mod test { + use super::*; + + use std::{ + future::Future, + mem, ptr, + rc::Rc, + sync::{ + mpsc::{channel, Sender}, + Arc, + }, + task::{Context, Poll}, + thread::{ + JoinHandle, {self}, + }, + time::Duration, + }; + + use futures::{ + channel::oneshot, + select, + task::{waker_ref, ArcWake}, + FutureExt, + }; + use futures_executor::{LocalPool, LocalSpawner, ThreadPool}; + use futures_util::task::LocalSpawnExt; + + use super::super::super::{block_on, sync::Mutex}; + + // Dummy waker used when we want to manually drive futures. + struct TestWaker; + impl ArcWake for TestWaker { + fn wake_by_ref(_arc_self: &Arc) {} + } + + #[test] + fn smoke() { + let cv = Condvar::new(); + cv.notify_one(); + cv.notify_all(); + } + + #[test] + fn notify_one() { + let mu = Arc::new(Mutex::new(())); + let cv = Arc::new(Condvar::new()); + + let mu2 = mu.clone(); + let cv2 = cv.clone(); + + let guard = block_on(mu.lock()); + thread::spawn(move || { + let _g = block_on(mu2.lock()); + cv2.notify_one(); + }); + + let guard = block_on(cv.wait(guard)); + mem::drop(guard); + } + + #[test] + fn multi_mutex() { + const NUM_THREADS: usize = 5; + + let mu = Arc::new(Mutex::new(false)); + let cv = Arc::new(Condvar::new()); + + let mut threads = Vec::with_capacity(NUM_THREADS); + for _ in 0..NUM_THREADS { + let mu = mu.clone(); + let cv = cv.clone(); + + threads.push(thread::spawn(move || { + let mut ready = block_on(mu.lock()); + while !*ready { + ready = block_on(cv.wait(ready)); + } + })); + } + + let mut g = block_on(mu.lock()); + *g = true; + mem::drop(g); + cv.notify_all(); + + threads + .into_iter() + .try_for_each(JoinHandle::join) + .expect("Failed to join threads"); + + // Now use the Condvar with a different mutex. 
+ let alt_mu = Arc::new(Mutex::new(None)); + let alt_mu2 = alt_mu.clone(); + let cv2 = cv.clone(); + let handle = thread::spawn(move || { + let mut g = block_on(alt_mu2.lock()); + while g.is_none() { + g = block_on(cv2.wait(g)); + } + }); + + let mut alt_g = block_on(alt_mu.lock()); + *alt_g = Some(()); + mem::drop(alt_g); + cv.notify_all(); + + handle + .join() + .expect("Failed to join thread alternate mutex"); + } + + #[test] + fn notify_one_single_thread_async() { + async fn notify(mu: Rc>, cv: Rc) { + let _g = mu.lock().await; + cv.notify_one(); + } + + async fn wait(mu: Rc>, cv: Rc, spawner: LocalSpawner) { + let mu2 = Rc::clone(&mu); + let cv2 = Rc::clone(&cv); + + let g = mu.lock().await; + // Has to be spawned _after_ acquiring the lock to prevent a race + // where the notify happens before the waiter has acquired the lock. + spawner + .spawn_local(notify(mu2, cv2)) + .expect("Failed to spawn `notify` task"); + let _g = cv.wait(g).await; + } + + let mut ex = LocalPool::new(); + let spawner = ex.spawner(); + + let mu = Rc::new(Mutex::new(())); + let cv = Rc::new(Condvar::new()); + + spawner + .spawn_local(wait(mu, cv, spawner.clone())) + .expect("Failed to spawn `wait` task"); + + ex.run(); + } + + #[test] + fn notify_one_multi_thread_async() { + async fn notify(mu: Arc>, cv: Arc) { + let _g = mu.lock().await; + cv.notify_one(); + } + + async fn wait(mu: Arc>, cv: Arc, tx: Sender<()>, pool: ThreadPool) { + let mu2 = Arc::clone(&mu); + let cv2 = Arc::clone(&cv); + + let g = mu.lock().await; + // Has to be spawned _after_ acquiring the lock to prevent a race + // where the notify happens before the waiter has acquired the lock. + pool.spawn_ok(notify(mu2, cv2)); + let _g = cv.wait(g).await; + + tx.send(()).expect("Failed to send completion notification"); + } + + let ex = ThreadPool::new().expect("Failed to create ThreadPool"); + + let mu = Arc::new(Mutex::new(())); + let cv = Arc::new(Condvar::new()); + + let (tx, rx) = channel(); + ex.spawn_ok(wait(mu, cv, tx, ex.clone())); + + rx.recv_timeout(Duration::from_secs(5)) + .expect("Failed to receive completion notification"); + } + + #[test] + fn notify_one_with_cancel() { + const TASKS: usize = 17; + const OBSERVERS: usize = 7; + const ITERATIONS: usize = 103; + + async fn observe(mu: &Arc>, cv: &Arc) { + let mut count = mu.read_lock().await; + while *count == 0 { + count = cv.wait_read(count).await; + } + let _ = unsafe { ptr::read_volatile(&*count as *const usize) }; + } + + async fn decrement(mu: &Arc>, cv: &Arc) { + let mut count = mu.lock().await; + while *count == 0 { + count = cv.wait(count).await; + } + *count -= 1; + } + + async fn increment(mu: Arc>, cv: Arc, done: Sender<()>) { + for _ in 0..TASKS * OBSERVERS * ITERATIONS { + *mu.lock().await += 1; + cv.notify_one(); + } + + done.send(()).expect("Failed to send completion message"); + } + + async fn observe_either( + mu: Arc>, + cv: Arc, + alt_mu: Arc>, + alt_cv: Arc, + done: Sender<()>, + ) { + for _ in 0..ITERATIONS { + select! { + () = observe(&mu, &cv).fuse() => {}, + () = observe(&alt_mu, &alt_cv).fuse() => {}, + } + } + + done.send(()).expect("Failed to send completion message"); + } + + async fn decrement_either( + mu: Arc>, + cv: Arc, + alt_mu: Arc>, + alt_cv: Arc, + done: Sender<()>, + ) { + for _ in 0..ITERATIONS { + select! 
{ + () = decrement(&mu, &cv).fuse() => {}, + () = decrement(&alt_mu, &alt_cv).fuse() => {}, + } + } + + done.send(()).expect("Failed to send completion message"); + } + + let ex = ThreadPool::new().expect("Failed to create ThreadPool"); + + let mu = Arc::new(Mutex::new(0usize)); + let alt_mu = Arc::new(Mutex::new(0usize)); + + let cv = Arc::new(Condvar::new()); + let alt_cv = Arc::new(Condvar::new()); + + let (tx, rx) = channel(); + for _ in 0..TASKS { + ex.spawn_ok(decrement_either( + Arc::clone(&mu), + Arc::clone(&cv), + Arc::clone(&alt_mu), + Arc::clone(&alt_cv), + tx.clone(), + )); + } + + for _ in 0..OBSERVERS { + ex.spawn_ok(observe_either( + Arc::clone(&mu), + Arc::clone(&cv), + Arc::clone(&alt_mu), + Arc::clone(&alt_cv), + tx.clone(), + )); + } + + ex.spawn_ok(increment(Arc::clone(&mu), Arc::clone(&cv), tx.clone())); + ex.spawn_ok(increment(Arc::clone(&alt_mu), Arc::clone(&alt_cv), tx)); + + for _ in 0..TASKS + OBSERVERS + 2 { + if let Err(e) = rx.recv_timeout(Duration::from_secs(20)) { + panic!("Error while waiting for threads to complete: {}", e); + } + } + + assert_eq!( + *block_on(mu.read_lock()) + *block_on(alt_mu.read_lock()), + (TASKS * OBSERVERS * ITERATIONS * 2) - (TASKS * ITERATIONS) + ); + assert_eq!(cv.state.load(Ordering::Relaxed), 0); + assert_eq!(alt_cv.state.load(Ordering::Relaxed), 0); + } + + #[test] + fn notify_all_with_cancel() { + const TASKS: usize = 17; + const ITERATIONS: usize = 103; + + async fn decrement(mu: &Arc>, cv: &Arc) { + let mut count = mu.lock().await; + while *count == 0 { + count = cv.wait(count).await; + } + *count -= 1; + } + + async fn increment(mu: Arc>, cv: Arc, done: Sender<()>) { + for _ in 0..TASKS * ITERATIONS { + *mu.lock().await += 1; + cv.notify_all(); + } + + done.send(()).expect("Failed to send completion message"); + } + + async fn decrement_either( + mu: Arc>, + cv: Arc, + alt_mu: Arc>, + alt_cv: Arc, + done: Sender<()>, + ) { + for _ in 0..ITERATIONS { + select! 
{ + () = decrement(&mu, &cv).fuse() => {}, + () = decrement(&alt_mu, &alt_cv).fuse() => {}, + } + } + + done.send(()).expect("Failed to send completion message"); + } + + let ex = ThreadPool::new().expect("Failed to create ThreadPool"); + + let mu = Arc::new(Mutex::new(0usize)); + let alt_mu = Arc::new(Mutex::new(0usize)); + + let cv = Arc::new(Condvar::new()); + let alt_cv = Arc::new(Condvar::new()); + + let (tx, rx) = channel(); + for _ in 0..TASKS { + ex.spawn_ok(decrement_either( + Arc::clone(&mu), + Arc::clone(&cv), + Arc::clone(&alt_mu), + Arc::clone(&alt_cv), + tx.clone(), + )); + } + + ex.spawn_ok(increment(Arc::clone(&mu), Arc::clone(&cv), tx.clone())); + ex.spawn_ok(increment(Arc::clone(&alt_mu), Arc::clone(&alt_cv), tx)); + + for _ in 0..TASKS + 2 { + if let Err(e) = rx.recv_timeout(Duration::from_secs(10)) { + panic!("Error while waiting for threads to complete: {}", e); + } + } + + assert_eq!( + *block_on(mu.read_lock()) + *block_on(alt_mu.read_lock()), + TASKS * ITERATIONS + ); + assert_eq!(cv.state.load(Ordering::Relaxed), 0); + assert_eq!(alt_cv.state.load(Ordering::Relaxed), 0); + } + #[test] + fn notify_all() { + const THREADS: usize = 13; + + let mu = Arc::new(Mutex::new(0)); + let cv = Arc::new(Condvar::new()); + let (tx, rx) = channel(); + + let mut threads = Vec::with_capacity(THREADS); + for _ in 0..THREADS { + let mu2 = mu.clone(); + let cv2 = cv.clone(); + let tx2 = tx.clone(); + + threads.push(thread::spawn(move || { + let mut count = block_on(mu2.lock()); + *count += 1; + if *count == THREADS { + tx2.send(()).unwrap(); + } + + while *count != 0 { + count = block_on(cv2.wait(count)); + } + })); + } + + mem::drop(tx); + + // Wait till all threads have started. + rx.recv_timeout(Duration::from_secs(5)).unwrap(); + + let mut count = block_on(mu.lock()); + *count = 0; + mem::drop(count); + cv.notify_all(); + + for t in threads { + t.join().unwrap(); + } + } + + #[test] + fn notify_all_single_thread_async() { + const TASKS: usize = 13; + + async fn reset(mu: Rc>, cv: Rc) { + let mut count = mu.lock().await; + *count = 0; + cv.notify_all(); + } + + async fn watcher(mu: Rc>, cv: Rc, spawner: LocalSpawner) { + let mut count = mu.lock().await; + *count += 1; + if *count == TASKS { + spawner + .spawn_local(reset(mu.clone(), cv.clone())) + .expect("Failed to spawn reset task"); + } + + while *count != 0 { + count = cv.wait(count).await; + } + } + + let mut ex = LocalPool::new(); + let spawner = ex.spawner(); + + let mu = Rc::new(Mutex::new(0)); + let cv = Rc::new(Condvar::new()); + + for _ in 0..TASKS { + spawner + .spawn_local(watcher(mu.clone(), cv.clone(), spawner.clone())) + .expect("Failed to spawn watcher task"); + } + + ex.run(); + } + + #[test] + fn notify_all_multi_thread_async() { + const TASKS: usize = 13; + + async fn reset(mu: Arc>, cv: Arc) { + let mut count = mu.lock().await; + *count = 0; + cv.notify_all(); + } + + async fn watcher( + mu: Arc>, + cv: Arc, + pool: ThreadPool, + tx: Sender<()>, + ) { + let mut count = mu.lock().await; + *count += 1; + if *count == TASKS { + pool.spawn_ok(reset(mu.clone(), cv.clone())); + } + + while *count != 0 { + count = cv.wait(count).await; + } + + tx.send(()).expect("Failed to send completion notification"); + } + + let pool = ThreadPool::new().expect("Failed to create ThreadPool"); + + let mu = Arc::new(Mutex::new(0)); + let cv = Arc::new(Condvar::new()); + + let (tx, rx) = channel(); + for _ in 0..TASKS { + pool.spawn_ok(watcher(mu.clone(), cv.clone(), pool.clone(), tx.clone())); + } + + for _ in 0..TASKS { + 
rx.recv_timeout(Duration::from_secs(5)) + .expect("Failed to receive completion notification"); + } + } + + #[test] + fn wake_all_readers() { + async fn read(mu: Arc>, cv: Arc) { + let mut ready = mu.read_lock().await; + while !*ready { + ready = cv.wait_read(ready).await; + } + } + + let mu = Arc::new(Mutex::new(false)); + let cv = Arc::new(Condvar::new()); + let mut readers = [ + Box::pin(read(mu.clone(), cv.clone())), + Box::pin(read(mu.clone(), cv.clone())), + Box::pin(read(mu.clone(), cv.clone())), + Box::pin(read(mu.clone(), cv.clone())), + ]; + + let arc_waker = Arc::new(TestWaker); + let waker = waker_ref(&arc_waker); + let mut cx = Context::from_waker(&waker); + + // First have all the readers wait on the Condvar. + for r in &mut readers { + if let Poll::Ready(()) = r.as_mut().poll(&mut cx) { + panic!("reader unexpectedly ready"); + } + } + + assert_eq!(cv.state.load(Ordering::Relaxed) & HAS_WAITERS, HAS_WAITERS); + + // Now make the condition true and notify the condvar. Even though we will call notify_one, + // all the readers should be woken up. + *block_on(mu.lock()) = true; + cv.notify_one(); + + assert_eq!(cv.state.load(Ordering::Relaxed), 0); + + // All readers should now be able to complete. + for r in &mut readers { + if r.as_mut().poll(&mut cx).is_pending() { + panic!("reader unable to complete"); + } + } + } + + #[test] + fn cancel_before_notify() { + async fn dec(mu: Arc>, cv: Arc) { + let mut count = mu.lock().await; + + while *count == 0 { + count = cv.wait(count).await; + } + + *count -= 1; + } + + let mu = Arc::new(Mutex::new(0)); + let cv = Arc::new(Condvar::new()); + + let arc_waker = Arc::new(TestWaker); + let waker = waker_ref(&arc_waker); + let mut cx = Context::from_waker(&waker); + + let mut fut1 = Box::pin(dec(mu.clone(), cv.clone())); + let mut fut2 = Box::pin(dec(mu.clone(), cv.clone())); + + if let Poll::Ready(()) = fut1.as_mut().poll(&mut cx) { + panic!("future unexpectedly ready"); + } + if let Poll::Ready(()) = fut2.as_mut().poll(&mut cx) { + panic!("future unexpectedly ready"); + } + assert_eq!(cv.state.load(Ordering::Relaxed) & HAS_WAITERS, HAS_WAITERS); + + *block_on(mu.lock()) = 2; + // Drop fut1 before notifying the cv. + mem::drop(fut1); + cv.notify_one(); + + // fut2 should now be ready to complete. + assert_eq!(cv.state.load(Ordering::Relaxed), 0); + + if fut2.as_mut().poll(&mut cx).is_pending() { + panic!("future unable to complete"); + } + + assert_eq!(*block_on(mu.lock()), 1); + } + + #[test] + fn cancel_after_notify_one() { + async fn dec(mu: Arc>, cv: Arc) { + let mut count = mu.lock().await; + + while *count == 0 { + count = cv.wait(count).await; + } + + *count -= 1; + } + + let mu = Arc::new(Mutex::new(0)); + let cv = Arc::new(Condvar::new()); + + let arc_waker = Arc::new(TestWaker); + let waker = waker_ref(&arc_waker); + let mut cx = Context::from_waker(&waker); + + let mut fut1 = Box::pin(dec(mu.clone(), cv.clone())); + let mut fut2 = Box::pin(dec(mu.clone(), cv.clone())); + + if let Poll::Ready(()) = fut1.as_mut().poll(&mut cx) { + panic!("future unexpectedly ready"); + } + if let Poll::Ready(()) = fut2.as_mut().poll(&mut cx) { + panic!("future unexpectedly ready"); + } + assert_eq!(cv.state.load(Ordering::Relaxed) & HAS_WAITERS, HAS_WAITERS); + + *block_on(mu.lock()) = 2; + cv.notify_one(); + + // fut1 should now be ready to complete. Drop it before polling. This should wake up fut2. 
+ mem::drop(fut1); + assert_eq!(cv.state.load(Ordering::Relaxed), 0); + + if fut2.as_mut().poll(&mut cx).is_pending() { + panic!("future unable to complete"); + } + + assert_eq!(*block_on(mu.lock()), 1); + } + + #[test] + fn cancel_after_notify_all() { + async fn dec(mu: Arc>, cv: Arc) { + let mut count = mu.lock().await; + + while *count == 0 { + count = cv.wait(count).await; + } + + *count -= 1; + } + + let mu = Arc::new(Mutex::new(0)); + let cv = Arc::new(Condvar::new()); + + let arc_waker = Arc::new(TestWaker); + let waker = waker_ref(&arc_waker); + let mut cx = Context::from_waker(&waker); + + let mut fut1 = Box::pin(dec(mu.clone(), cv.clone())); + let mut fut2 = Box::pin(dec(mu.clone(), cv.clone())); + + if let Poll::Ready(()) = fut1.as_mut().poll(&mut cx) { + panic!("future unexpectedly ready"); + } + if let Poll::Ready(()) = fut2.as_mut().poll(&mut cx) { + panic!("future unexpectedly ready"); + } + assert_eq!(cv.state.load(Ordering::Relaxed) & HAS_WAITERS, HAS_WAITERS); + + let mut count = block_on(mu.lock()); + *count = 2; + + // Notify the cv while holding the lock. This should wake up both waiters. + cv.notify_all(); + assert_eq!(cv.state.load(Ordering::Relaxed), 0); + + mem::drop(count); + + mem::drop(fut1); + + if fut2.as_mut().poll(&mut cx).is_pending() { + panic!("future unable to complete"); + } + + assert_eq!(*block_on(mu.lock()), 1); + } + + #[test] + fn timed_wait() { + async fn wait_deadline( + mu: Arc>, + cv: Arc, + timeout: oneshot::Receiver<()>, + ) { + let mut count = mu.lock().await; + + if *count == 0 { + let mut rx = timeout.fuse(); + + while *count == 0 { + select! { + res = rx => { + if let Err(e) = res { + panic!("Error while receiving timeout notification: {}", e); + } + + return; + }, + c = cv.wait(count).fuse() => count = c, + } + } + } + + *count += 1; + } + + let mu = Arc::new(Mutex::new(0)); + let cv = Arc::new(Condvar::new()); + + let arc_waker = Arc::new(TestWaker); + let waker = waker_ref(&arc_waker); + let mut cx = Context::from_waker(&waker); + + let (tx, rx) = oneshot::channel(); + let mut wait = Box::pin(wait_deadline(mu.clone(), cv.clone(), rx)); + + if let Poll::Ready(()) = wait.as_mut().poll(&mut cx) { + panic!("wait_deadline unexpectedly ready"); + } + + assert_eq!(cv.state.load(Ordering::Relaxed), HAS_WAITERS); + + // Signal the channel, which should cancel the wait. + tx.send(()).expect("Failed to send wakeup"); + + // Wait for the timer to run out. + if wait.as_mut().poll(&mut cx).is_pending() { + panic!("wait_deadline unable to complete in time"); + } + + assert_eq!(cv.state.load(Ordering::Relaxed), 0); + assert_eq!(*block_on(mu.lock()), 0); + } +} diff --git a/cros_async/src/sync/mu.rs b/cros_async/src/sync/mu.rs new file mode 100644 index 0000000000..e1f408dfa2 --- /dev/null +++ b/cros_async/src/sync/mu.rs @@ -0,0 +1,2305 @@ +// Copyright 2020 The Chromium OS Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +use std::{ + cell::UnsafeCell, + hint, mem, + ops::{Deref, DerefMut}, + sync::{ + atomic::{AtomicUsize, Ordering}, + Arc, + }, + thread::yield_now, +}; + +use super::super::sync::waiter::{ + Kind as WaiterKind, Waiter, WaiterAdapter, WaiterList, WaitingFor, +}; + +// Set when the mutex is exclusively locked. +const LOCKED: usize = 1 << 0; +// Set when there are one or more threads waiting to acquire the lock. +const HAS_WAITERS: usize = 1 << 1; +// Set when a thread has been woken up from the wait queue. 
Cleared when that thread either acquires +// the lock or adds itself back into the wait queue. Used to prevent unnecessary wake ups when a +// thread has been removed from the wait queue but has not gotten CPU time yet. +const DESIGNATED_WAKER: usize = 1 << 2; +// Used to provide exclusive access to the `waiters` field in `Mutex`. Should only be held while +// modifying the waiter list. +const SPINLOCK: usize = 1 << 3; +// Set when a thread that wants an exclusive lock adds itself to the wait queue. New threads +// attempting to acquire a shared lock will be preventing from getting it when this bit is set. +// However, this bit is ignored once a thread has gone through the wait queue at least once. +const WRITER_WAITING: usize = 1 << 4; +// Set when a thread has gone through the wait queue many times but has failed to acquire the lock +// every time it is woken up. When this bit is set, all other threads are prevented from acquiring +// the lock until the thread that set the `LONG_WAIT` bit has acquired the lock. +const LONG_WAIT: usize = 1 << 5; +// The bit that is added to the mutex state in order to acquire a shared lock. Since more than one +// thread can acquire a shared lock, we cannot use a single bit. Instead we use all the remaining +// bits in the state to track the number of threads that have acquired a shared lock. +const READ_LOCK: usize = 1 << 8; +// Mask used for checking if any threads currently hold a shared lock. +const READ_MASK: usize = !0xff; + +// The number of times the thread should just spin and attempt to re-acquire the lock. +const SPIN_THRESHOLD: usize = 7; + +// The number of times the thread needs to go through the wait queue before it sets the `LONG_WAIT` +// bit and forces all other threads to wait for it to acquire the lock. This value is set relatively +// high so that we don't lose the benefit of having running threads unless it is absolutely +// necessary. +const LONG_WAIT_THRESHOLD: usize = 19; + +// Common methods between shared and exclusive locks. +trait Kind { + // The bits that must be zero for the thread to acquire this kind of lock. If any of these bits + // are not zero then the thread will first spin and retry a few times before adding itself to + // the wait queue. + fn zero_to_acquire() -> usize; + + // The bit that must be added in order to acquire this kind of lock. This should either be + // `LOCKED` or `READ_LOCK`. + fn add_to_acquire() -> usize; + + // The bits that should be set when a thread adds itself to the wait queue while waiting to + // acquire this kind of lock. + fn set_when_waiting() -> usize; + + // The bits that should be cleared when a thread acquires this kind of lock. + fn clear_on_acquire() -> usize; + + // The waiter that a thread should use when waiting to acquire this kind of lock. + fn new_waiter(raw: &RawMutex) -> Arc; +} + +// A lock type for shared read-only access to the data. More than one thread may hold this kind of +// lock simultaneously. +struct Shared; + +impl Kind for Shared { + fn zero_to_acquire() -> usize { + LOCKED | WRITER_WAITING | LONG_WAIT + } + + fn add_to_acquire() -> usize { + READ_LOCK + } + + fn set_when_waiting() -> usize { + 0 + } + + fn clear_on_acquire() -> usize { + 0 + } + + fn new_waiter(raw: &RawMutex) -> Arc { + Arc::new(Waiter::new( + WaiterKind::Shared, + cancel_waiter, + raw as *const RawMutex as usize, + WaitingFor::Mutex, + )) + } +} + +// A lock type for mutually exclusive read-write access to the data. Only one thread can hold this +// kind of lock at a time. 
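To make the state layout concrete, here is a small illustrative sketch (not part of the patch) with the constants mirrored locally at the same values as in this file, showing how a state word decomposes into the flag byte and the shared-lock count:

```rust
// Flag bits live in the low byte; the reader count occupies the remaining bits.
const LOCKED: usize = 1 << 0;
const HAS_WAITERS: usize = 1 << 1;
const WRITER_WAITING: usize = 1 << 4;
const READ_LOCK: usize = 1 << 8;
const READ_MASK: usize = !0xff;

fn main() {
    // Three readers hold the lock while one writer waits in the queue.
    let state = 3 * READ_LOCK | HAS_WAITERS | WRITER_WAITING;

    assert_eq!(state & READ_MASK, 3 * READ_LOCK); // three shared holders
    assert_eq!(state & LOCKED, 0);                // not exclusively locked
    assert_eq!((state & READ_MASK) / READ_LOCK, 3); // recover the reader count
}
```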
+struct Exclusive; + +impl Kind for Exclusive { + fn zero_to_acquire() -> usize { + LOCKED | READ_MASK | LONG_WAIT + } + + fn add_to_acquire() -> usize { + LOCKED + } + + fn set_when_waiting() -> usize { + WRITER_WAITING + } + + fn clear_on_acquire() -> usize { + WRITER_WAITING + } + + fn new_waiter(raw: &RawMutex) -> Arc { + Arc::new(Waiter::new( + WaiterKind::Exclusive, + cancel_waiter, + raw as *const RawMutex as usize, + WaitingFor::Mutex, + )) + } +} + +// Scan `waiters` and return the ones that should be woken up. Also returns any bits that should be +// set in the mutex state when the current thread releases the spin lock protecting the waiter list. +// +// If the first waiter is trying to acquire a shared lock, then all waiters in the list that are +// waiting for a shared lock are also woken up. If any waiters waiting for an exclusive lock are +// found when iterating through the list, then the returned `usize` contains the `WRITER_WAITING` +// bit, which should be set when the thread releases the spin lock. +// +// If the first waiter is trying to acquire an exclusive lock, then only that waiter is returned and +// no bits are set in the returned `usize`. +fn get_wake_list(waiters: &mut WaiterList) -> (WaiterList, usize) { + let mut to_wake = WaiterList::new(WaiterAdapter::new()); + let mut set_on_release = 0; + let mut cursor = waiters.front_mut(); + + let mut waking_readers = false; + while let Some(w) = cursor.get() { + match w.kind() { + WaiterKind::Exclusive if !waking_readers => { + // This is the first waiter and it's a writer. No need to check the other waiters. + let waiter = cursor.remove().unwrap(); + waiter.set_waiting_for(WaitingFor::None); + to_wake.push_back(waiter); + break; + } + + WaiterKind::Shared => { + // This is a reader and the first waiter in the list was not a writer so wake up all + // the readers in the wait list. + let waiter = cursor.remove().unwrap(); + waiter.set_waiting_for(WaitingFor::None); + to_wake.push_back(waiter); + waking_readers = true; + } + + WaiterKind::Exclusive => { + // We found a writer while looking for more readers to wake up. Set the + // `WRITER_WAITING` bit to prevent any new readers from acquiring the lock. All + // readers currently in the wait list will ignore this bit since they already waited + // once. + set_on_release |= WRITER_WAITING; + cursor.move_next(); + } + } + } + + (to_wake, set_on_release) +} + +#[inline] +fn cpu_relax(iterations: usize) { + for _ in 0..iterations { + hint::spin_loop(); + } +} + +pub(crate) struct RawMutex { + state: AtomicUsize, + waiters: UnsafeCell, +} + +impl RawMutex { + pub fn new() -> RawMutex { + RawMutex { + state: AtomicUsize::new(0), + waiters: UnsafeCell::new(WaiterList::new(WaiterAdapter::new())), + } + } + + #[inline] + pub async fn lock(&self) { + match self + .state + .compare_exchange_weak(0, LOCKED, Ordering::Acquire, Ordering::Relaxed) + { + Ok(_) => {} + Err(oldstate) => { + // If any bits that should be zero are not zero or if we fail to acquire the lock + // with a single compare_exchange then go through the slow path. 
+ if (oldstate & Exclusive::zero_to_acquire()) != 0 + || self + .state + .compare_exchange_weak( + oldstate, + (oldstate + Exclusive::add_to_acquire()) + & !Exclusive::clear_on_acquire(), + Ordering::Acquire, + Ordering::Relaxed, + ) + .is_err() + { + self.lock_slow::(0, 0).await; + } + } + } + } + + #[inline] + pub async fn read_lock(&self) { + match self + .state + .compare_exchange_weak(0, READ_LOCK, Ordering::Acquire, Ordering::Relaxed) + { + Ok(_) => {} + Err(oldstate) => { + if (oldstate & Shared::zero_to_acquire()) != 0 + || self + .state + .compare_exchange_weak( + oldstate, + (oldstate + Shared::add_to_acquire()) & !Shared::clear_on_acquire(), + Ordering::Acquire, + Ordering::Relaxed, + ) + .is_err() + { + self.lock_slow::(0, 0).await; + } + } + } + } + + // Slow path for acquiring the lock. `clear` should contain any bits that need to be cleared + // when the lock is acquired. Any bits set in `zero_mask` are cleared from the bits returned by + // `K::zero_to_acquire()`. + #[cold] + async fn lock_slow(&self, mut clear: usize, zero_mask: usize) { + let mut zero_to_acquire = K::zero_to_acquire() & !zero_mask; + + let mut spin_count = 0; + let mut wait_count = 0; + let mut waiter = None; + loop { + let oldstate = self.state.load(Ordering::Relaxed); + // If all the bits in `zero_to_acquire` are actually zero then try to acquire the lock + // directly. + if (oldstate & zero_to_acquire) == 0 { + if self + .state + .compare_exchange_weak( + oldstate, + (oldstate + K::add_to_acquire()) & !(clear | K::clear_on_acquire()), + Ordering::Acquire, + Ordering::Relaxed, + ) + .is_ok() + { + return; + } + } else if (oldstate & SPINLOCK) == 0 { + // The mutex is locked and the spin lock is available. Try to add this thread + // to the waiter queue. + let w = waiter.get_or_insert_with(|| K::new_waiter(self)); + w.reset(WaitingFor::Mutex); + + if self + .state + .compare_exchange_weak( + oldstate, + (oldstate | SPINLOCK | HAS_WAITERS | K::set_when_waiting()) & !clear, + Ordering::Acquire, + Ordering::Relaxed, + ) + .is_ok() + { + let mut set_on_release = 0; + + // Safe because we have acquired the spin lock and it provides exclusive + // access to the waiter queue. + if wait_count < LONG_WAIT_THRESHOLD { + // Add the waiter to the back of the queue. + unsafe { (*self.waiters.get()).push_back(w.clone()) }; + } else { + // This waiter has gone through the queue too many times. Put it in the + // front of the queue and block all other threads from acquiring the lock + // until this one has acquired it at least once. + unsafe { (*self.waiters.get()).push_front(w.clone()) }; + + // Set the LONG_WAIT bit to prevent all other threads from acquiring the + // lock. + set_on_release |= LONG_WAIT; + + // Make sure we clear the LONG_WAIT bit when we do finally get the lock. + clear |= LONG_WAIT; + + // Since we set the LONG_WAIT bit we shouldn't allow that bit to prevent us + // from acquiring the lock. + zero_to_acquire &= !LONG_WAIT; + } + + // Release the spin lock. + let mut state = oldstate; + loop { + match self.state.compare_exchange_weak( + state, + (state | set_on_release) & !SPINLOCK, + Ordering::Release, + Ordering::Relaxed, + ) { + Ok(_) => break, + Err(w) => state = w, + } + } + + // Now wait until we are woken. + w.wait().await; + + // The `DESIGNATED_WAKER` bit gets set when this thread is woken up by the + // thread that originally held the lock. 
While this bit is set, no other waiters + // will be woken up so it's important to clear it the next time we try to + // acquire the main lock or the spin lock. + clear |= DESIGNATED_WAKER; + + // Now that the thread has waited once, we no longer care if there is a writer + // waiting. Only the limits of mutual exclusion can prevent us from acquiring + // the lock. + zero_to_acquire &= !WRITER_WAITING; + + // Reset the spin count since we just went through the wait queue. + spin_count = 0; + + // Increment the wait count since we went through the wait queue. + wait_count += 1; + + // Skip the `cpu_relax` below. + continue; + } + } + + // Both the lock and the spin lock are held by one or more other threads. First, we'll + // spin a few times in case we can acquire the lock or the spin lock. If that fails then + // we yield because we might be preventing the threads that do hold the 2 locks from + // getting cpu time. + if spin_count < SPIN_THRESHOLD { + cpu_relax(1 << spin_count); + spin_count += 1; + } else { + yield_now(); + } + } + } + + #[inline] + pub fn unlock(&self) { + // Fast path, if possible. We can directly clear the locked bit since we have exclusive + // access to the mutex. + let oldstate = self.state.fetch_sub(LOCKED, Ordering::Release); + + // Panic if we just tried to unlock a mutex that wasn't held by this thread. This shouldn't + // really be possible since `unlock` is not a public method. + debug_assert_eq!( + oldstate & READ_MASK, + 0, + "`unlock` called on mutex held in read-mode" + ); + debug_assert_ne!( + oldstate & LOCKED, + 0, + "`unlock` called on mutex not held in write-mode" + ); + + if (oldstate & HAS_WAITERS) != 0 && (oldstate & DESIGNATED_WAKER) == 0 { + // The oldstate has waiters but no designated waker has been chosen yet. + self.unlock_slow(); + } + } + + #[inline] + pub fn read_unlock(&self) { + // Fast path, if possible. We can directly subtract the READ_LOCK bit since we had + // previously added it. + let oldstate = self.state.fetch_sub(READ_LOCK, Ordering::Release); + + debug_assert_eq!( + oldstate & LOCKED, + 0, + "`read_unlock` called on mutex held in write-mode" + ); + debug_assert_ne!( + oldstate & READ_MASK, + 0, + "`read_unlock` called on mutex not held in read-mode" + ); + + if (oldstate & HAS_WAITERS) != 0 + && (oldstate & DESIGNATED_WAKER) == 0 + && (oldstate & READ_MASK) == READ_LOCK + { + // There are waiters, no designated waker has been chosen yet, and the last reader is + // unlocking so we have to take the slow path. + self.unlock_slow(); + } + } + + #[cold] + fn unlock_slow(&self) { + let mut spin_count = 0; + + loop { + let oldstate = self.state.load(Ordering::Relaxed); + if (oldstate & HAS_WAITERS) == 0 || (oldstate & DESIGNATED_WAKER) != 0 { + // No more waiters or a designated waker has been chosen. Nothing left for us to do. + return; + } else if (oldstate & SPINLOCK) == 0 { + // The spin lock is not held by another thread. Try to acquire it. Also set the + // `DESIGNATED_WAKER` bit since we are likely going to wake up one or more threads. + if self + .state + .compare_exchange_weak( + oldstate, + oldstate | SPINLOCK | DESIGNATED_WAKER, + Ordering::Acquire, + Ordering::Relaxed, + ) + .is_ok() + { + // Acquired the spinlock. Try to wake a waiter. We may also end up wanting to + // clear the HAS_WAITER and DESIGNATED_WAKER bits so start collecting the bits + // to be cleared. 
+ let mut clear = SPINLOCK; + + // Safe because the spinlock guarantees exclusive access to the waiter list and + // the reference does not escape this function. + let waiters = unsafe { &mut *self.waiters.get() }; + let (wake_list, set_on_release) = get_wake_list(waiters); + + // If the waiter list is now empty, clear the HAS_WAITERS bit. + if waiters.is_empty() { + clear |= HAS_WAITERS; + } + + if wake_list.is_empty() { + // Since we are not going to wake any waiters clear the DESIGNATED_WAKER bit + // that we set when we acquired the spin lock. + clear |= DESIGNATED_WAKER; + } + + // Release the spin lock and clear any other bits as necessary. Also, set any + // bits returned by `get_wake_list`. For now, this is just the `WRITER_WAITING` + // bit, which needs to be set when we are waking up a bunch of readers and there + // are still writers in the wait queue. This will prevent any readers that + // aren't in `wake_list` from acquiring the read lock. + let mut state = oldstate; + loop { + match self.state.compare_exchange_weak( + state, + (state | set_on_release) & !clear, + Ordering::Release, + Ordering::Relaxed, + ) { + Ok(_) => break, + Err(w) => state = w, + } + } + + // Now wake the waiters, if any. + for w in wake_list { + w.wake(); + } + + // We're done. + return; + } + } + + // Spin and try again. It's ok to block here as we have already released the lock. + if spin_count < SPIN_THRESHOLD { + cpu_relax(1 << spin_count); + spin_count += 1; + } else { + yield_now(); + } + } + } + + fn cancel_waiter(&self, waiter: &Waiter, wake_next: bool) { + let mut oldstate = self.state.load(Ordering::Relaxed); + while oldstate & SPINLOCK != 0 + || self + .state + .compare_exchange_weak( + oldstate, + oldstate | SPINLOCK, + Ordering::Acquire, + Ordering::Relaxed, + ) + .is_err() + { + hint::spin_loop(); + oldstate = self.state.load(Ordering::Relaxed); + } + + // Safe because the spin lock provides exclusive access and the reference does not escape + // this function. + let waiters = unsafe { &mut *self.waiters.get() }; + + let mut clear = SPINLOCK; + + // If we are about to remove the first waiter in the wait list, then clear the LONG_WAIT + // bit. Also clear the bit if we are going to be waking some other waiters. In this case the + // waiter that set the bit may have already been removed from the waiter list (and could be + // the one that is currently being dropped). If it is still in the waiter list then clearing + // this bit may starve it for one more iteration through the lock_slow() loop, whereas not + // clearing this bit could cause a deadlock if the waiter that set it is the one that is + // being dropped. + if wake_next + || waiters + .front() + .get() + .map(|front| std::ptr::eq(front, waiter)) + .unwrap_or(false) + { + clear |= LONG_WAIT; + } + + let waiting_for = waiter.is_waiting_for(); + + // Don't drop the old waiter while holding the spin lock. + let old_waiter = if waiter.is_linked() && waiting_for == WaitingFor::Mutex { + // We know that the waiter is still linked and is waiting for the mutex, which + // guarantees that it is still linked into `self.waiters`. + let mut cursor = unsafe { waiters.cursor_mut_from_ptr(waiter as *const Waiter) }; + cursor.remove() + } else { + None + }; + + let (wake_list, set_on_release) = if wake_next || waiting_for == WaitingFor::None { + // Either the waiter was already woken or it's been removed from the mutex's waiter + // list and is going to be woken. Either way, we need to wake up another thread. 
+ get_wake_list(waiters) + } else { + (WaiterList::new(WaiterAdapter::new()), 0) + }; + + if waiters.is_empty() { + clear |= HAS_WAITERS; + } + + if wake_list.is_empty() { + // We're not waking any other threads so clear the DESIGNATED_WAKER bit. In the worst + // case this leads to an additional thread being woken up but we risk a deadlock if we + // don't clear it. + clear |= DESIGNATED_WAKER; + } + + if let WaiterKind::Exclusive = waiter.kind() { + // The waiter being dropped is a writer so clear the writer waiting bit for now. If we + // found more writers in the list while fetching waiters to wake up then this bit will + // be set again via `set_on_release`. + clear |= WRITER_WAITING; + } + + while self + .state + .compare_exchange_weak( + oldstate, + (oldstate & !clear) | set_on_release, + Ordering::Release, + Ordering::Relaxed, + ) + .is_err() + { + hint::spin_loop(); + oldstate = self.state.load(Ordering::Relaxed); + } + + for w in wake_list { + w.wake(); + } + + mem::drop(old_waiter); + } +} + +unsafe impl Send for RawMutex {} +unsafe impl Sync for RawMutex {} + +fn cancel_waiter(raw: usize, waiter: &Waiter, wake_next: bool) { + let raw_mutex = raw as *const RawMutex; + + // Safe because the thread that owns the waiter that is being canceled must + // also own a reference to the mutex, which ensures that this pointer is + // valid. + unsafe { (*raw_mutex).cancel_waiter(waiter, wake_next) } +} + +/// A high-level primitive that provides safe, mutable access to a shared resource. +/// +/// Unlike more traditional mutexes, `Mutex` can safely provide both shared, immutable access (via +/// `read_lock()`) as well as exclusive, mutable access (via `lock()`) to an underlying resource +/// with no loss of performance. +/// +/// # Poisoning +/// +/// `Mutex` does not support lock poisoning so if a thread panics while holding the lock, the +/// poisoned data will be accessible by other threads in your program. If you need to guarantee that +/// other threads cannot access poisoned data then you may wish to wrap this `Mutex` inside another +/// type that provides the poisoning feature. See the implementation of `std::sync::Mutex` for an +/// example of this. +/// +/// +/// # Fairness +/// +/// This `Mutex` implementation does not guarantee that threads will acquire the lock in the same +/// order that they call `lock()` or `read_lock()`. However it will attempt to prevent long-term +/// starvation: if a thread repeatedly fails to acquire the lock beyond a threshold then all other +/// threads will fail to acquire the lock until the starved thread has acquired it. +/// +/// Similarly, this `Mutex` will attempt to balance reader and writer threads: once there is a +/// writer thread waiting to acquire the lock no new reader threads will be allowed to acquire it. +/// However, any reader threads that were already waiting will still be allowed to acquire it. +/// +/// # Examples +/// +/// ```edition2018 +/// use std::sync::Arc; +/// use std::thread; +/// use std::sync::mpsc::channel; +/// +/// use cros_async::{block_on, sync::Mutex}; +/// +/// const N: usize = 10; +/// +/// // Spawn a few threads to increment a shared variable (non-atomically), and +/// // let the main thread know once all increments are done. +/// // +/// // Here we're using an Arc to share memory among threads, and the data inside +/// // the Arc is protected with a mutex. 
+/// let data = Arc::new(Mutex::new(0)); +/// +/// let (tx, rx) = channel(); +/// for _ in 0..N { +/// let (data, tx) = (Arc::clone(&data), tx.clone()); +/// thread::spawn(move || { +/// // The shared state can only be accessed once the lock is held. +/// // Our non-atomic increment is safe because we're the only thread +/// // which can access the shared state when the lock is held. +/// let mut data = block_on(data.lock()); +/// *data += 1; +/// if *data == N { +/// tx.send(()).unwrap(); +/// } +/// // the lock is unlocked here when `data` goes out of scope. +/// }); +/// } +/// +/// rx.recv().unwrap(); +/// ``` +#[repr(align(128))] +pub struct Mutex { + raw: RawMutex, + value: UnsafeCell, +} + +impl Mutex { + /// Create a new, unlocked `Mutex` ready for use. + pub fn new(v: T) -> Mutex { + Mutex { + raw: RawMutex::new(), + value: UnsafeCell::new(v), + } + } + + /// Consume the `Mutex` and return the contained value. This method does not perform any locking + /// as the compiler will guarantee that there are no other references to `self` and the caller + /// owns the `Mutex`. + pub fn into_inner(self) -> T { + // Don't need to acquire the lock because the compiler guarantees that there are + // no references to `self`. + self.value.into_inner() + } +} + +impl Mutex { + /// Acquires exclusive, mutable access to the resource protected by the `Mutex`, blocking the + /// current thread until it is able to do so. Upon returning, the current thread will be the + /// only thread with access to the resource. The `Mutex` will be released when the returned + /// `MutexGuard` is dropped. + /// + /// Calling `lock()` while holding a `MutexGuard` or a `MutexReadGuard` will cause a deadlock. + /// + /// Callers that are not in an async context may wish to use the `block_on` method to block the + /// thread until the `Mutex` is acquired. + #[inline] + pub async fn lock(&self) -> MutexGuard<'_, T> { + self.raw.lock().await; + + // Safe because we have exclusive access to `self.value`. + MutexGuard { + mu: self, + value: unsafe { &mut *self.value.get() }, + } + } + + /// Acquires shared, immutable access to the resource protected by the `Mutex`, blocking the + /// current thread until it is able to do so. Upon returning there may be other threads that + /// also have immutable access to the resource but there will not be any threads that have + /// mutable access to the resource. When the returned `MutexReadGuard` is dropped the thread + /// releases its access to the resource. + /// + /// Calling `read_lock()` while holding a `MutexReadGuard` may deadlock. Calling `read_lock()` + /// while holding a `MutexGuard` will deadlock. + /// + /// Callers that are not in an async context may wish to use the `block_on` method to block the + /// thread until the `Mutex` is acquired. + #[inline] + pub async fn read_lock(&self) -> MutexReadGuard<'_, T> { + self.raw.read_lock().await; + + // Safe because we have shared read-only access to `self.value`. + MutexReadGuard { + mu: self, + value: unsafe { &*self.value.get() }, + } + } + + // Called from `Condvar::wait` when the thread wants to reacquire the lock. + #[inline] + pub(crate) async fn lock_from_cv(&self) -> MutexGuard<'_, T> { + self.raw.lock_slow::(DESIGNATED_WAKER, 0).await; + + // Safe because we have exclusive access to `self.value`. + MutexGuard { + mu: self, + value: unsafe { &mut *self.value.get() }, + } + } + + // Like `lock_from_cv` but for acquiring a shared lock. 
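The doc example above only exercises `lock()`. A hedged usage sketch for the shared path described by `read_lock()` (hypothetical `main`, same crate APIs) could look like this: many read guards may coexist, and a writer takes over once they are gone.

```rust
use std::sync::Arc;
use std::thread;

use cros_async::{block_on, sync::Mutex};

fn main() {
    let data = Arc::new(Mutex::new(vec![1, 2, 3]));

    let readers: Vec<_> = (0..4)
        .map(|_| {
            let data = data.clone();
            thread::spawn(move || {
                // Shared, immutable access; other readers may hold the lock
                // at the same time.
                let guard = block_on(data.read_lock());
                assert_eq!(guard.len(), 3);
            })
        })
        .collect();

    for r in readers {
        r.join().unwrap();
    }

    // Exclusive access once the readers are done.
    block_on(data.lock()).push(4);
    assert_eq!(block_on(data.read_lock()).len(), 4);
}
```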
+ #[inline] + pub(crate) async fn read_lock_from_cv(&self) -> MutexReadGuard<'_, T> { + // Threads that have waited in the Condvar's waiter list don't have to care if there is a + // writer waiting since they have already waited once. + self.raw + .lock_slow::(DESIGNATED_WAKER, WRITER_WAITING) + .await; + + // Safe because we have exclusive access to `self.value`. + MutexReadGuard { + mu: self, + value: unsafe { &*self.value.get() }, + } + } + + #[inline] + fn unlock(&self) { + self.raw.unlock(); + } + + #[inline] + fn read_unlock(&self) { + self.raw.read_unlock(); + } + + pub fn get_mut(&mut self) -> &mut T { + // Safe because the compiler statically guarantees that are no other references to `self`. + // This is also why we don't need to acquire the lock first. + unsafe { &mut *self.value.get() } + } +} + +unsafe impl Send for Mutex {} +unsafe impl Sync for Mutex {} + +impl Default for Mutex { + fn default() -> Self { + Self::new(Default::default()) + } +} + +impl From for Mutex { + fn from(source: T) -> Self { + Self::new(source) + } +} + +/// An RAII implementation of a "scoped exclusive lock" for a `Mutex`. When this structure is +/// dropped, the lock will be released. The resource protected by the `Mutex` can be accessed via +/// the `Deref` and `DerefMut` implementations of this structure. +pub struct MutexGuard<'a, T: ?Sized + 'a> { + mu: &'a Mutex, + value: &'a mut T, +} + +impl<'a, T: ?Sized> MutexGuard<'a, T> { + pub(crate) fn into_inner(self) -> &'a Mutex { + self.mu + } + + pub(crate) fn as_raw_mutex(&self) -> &RawMutex { + &self.mu.raw + } +} + +impl<'a, T: ?Sized> Deref for MutexGuard<'a, T> { + type Target = T; + + fn deref(&self) -> &Self::Target { + self.value + } +} + +impl<'a, T: ?Sized> DerefMut for MutexGuard<'a, T> { + fn deref_mut(&mut self) -> &mut Self::Target { + self.value + } +} + +impl<'a, T: ?Sized> Drop for MutexGuard<'a, T> { + fn drop(&mut self) { + self.mu.unlock() + } +} + +/// An RAII implementation of a "scoped shared lock" for a `Mutex`. When this structure is dropped, +/// the lock will be released. The resource protected by the `Mutex` can be accessed via the `Deref` +/// implementation of this structure. +pub struct MutexReadGuard<'a, T: ?Sized + 'a> { + mu: &'a Mutex, + value: &'a T, +} + +impl<'a, T: ?Sized> MutexReadGuard<'a, T> { + pub(crate) fn into_inner(self) -> &'a Mutex { + self.mu + } + + pub(crate) fn as_raw_mutex(&self) -> &RawMutex { + &self.mu.raw + } +} + +impl<'a, T: ?Sized> Deref for MutexReadGuard<'a, T> { + type Target = T; + + fn deref(&self) -> &Self::Target { + self.value + } +} + +impl<'a, T: ?Sized> Drop for MutexReadGuard<'a, T> { + fn drop(&mut self) { + self.mu.read_unlock() + } +} + +#[cfg(test)] +mod test { + use super::*; + + use std::{ + future::Future, + mem, + pin::Pin, + rc::Rc, + sync::{ + atomic::{AtomicUsize, Ordering}, + mpsc::{channel, Sender}, + Arc, + }, + task::{Context, Poll, Waker}, + thread, + time::Duration, + }; + + use futures::{ + channel::oneshot, + pending, select, + task::{waker_ref, ArcWake}, + FutureExt, + }; + use futures_executor::{LocalPool, ThreadPool}; + use futures_util::task::LocalSpawnExt; + + use super::super::super::{ + block_on, + sync::{Condvar, SpinLock}, + }; + + #[derive(Debug, Eq, PartialEq)] + struct NonCopy(u32); + + // Dummy waker used when we want to manually drive futures. 
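The tests below repeatedly use the same manual-polling pattern: a no-op `ArcWake` waker and a `Context` built from it let the test poll a pinned future one step at a time and assert on intermediate lock state. A self-contained sketch of that pattern (illustrative names, trivially ready future) under the same `futures` APIs:

```rust
use std::future::Future;
use std::sync::Arc;
use std::task::{Context, Poll};

use futures::task::{waker_ref, ArcWake};

// Waker that does nothing when woken; polling is driven by hand instead.
struct NoopWaker;
impl ArcWake for NoopWaker {
    fn wake_by_ref(_arc_self: &Arc<Self>) {}
}

fn main() {
    let arc_waker = Arc::new(NoopWaker);
    let waker = waker_ref(&arc_waker);
    let mut cx = Context::from_waker(&waker);

    // Any future can be stepped this way; here one that is immediately ready.
    let mut fut = Box::pin(async { 40 + 2 });
    match fut.as_mut().poll(&mut cx) {
        Poll::Ready(v) => assert_eq!(v, 42),
        Poll::Pending => panic!("future unexpectedly pending"),
    }
}
```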
+ struct TestWaker; + impl ArcWake for TestWaker { + fn wake_by_ref(_arc_self: &Arc) {} + } + + #[test] + fn it_works() { + let mu = Mutex::new(NonCopy(13)); + + assert_eq!(*block_on(mu.lock()), NonCopy(13)); + } + + #[test] + fn smoke() { + let mu = Mutex::new(NonCopy(7)); + + mem::drop(block_on(mu.lock())); + mem::drop(block_on(mu.lock())); + } + + #[test] + fn rw_smoke() { + let mu = Mutex::new(NonCopy(7)); + + mem::drop(block_on(mu.lock())); + mem::drop(block_on(mu.read_lock())); + mem::drop((block_on(mu.read_lock()), block_on(mu.read_lock()))); + mem::drop(block_on(mu.lock())); + } + + #[test] + fn async_smoke() { + async fn lock(mu: Rc>) { + mu.lock().await; + } + + async fn read_lock(mu: Rc>) { + mu.read_lock().await; + } + + async fn double_read_lock(mu: Rc>) { + let first = mu.read_lock().await; + mu.read_lock().await; + + // Make sure first lives past the second read lock. + first.as_raw_mutex(); + } + + let mu = Rc::new(Mutex::new(NonCopy(7))); + + let mut ex = LocalPool::new(); + let spawner = ex.spawner(); + + spawner + .spawn_local(lock(Rc::clone(&mu))) + .expect("Failed to spawn future"); + spawner + .spawn_local(read_lock(Rc::clone(&mu))) + .expect("Failed to spawn future"); + spawner + .spawn_local(double_read_lock(Rc::clone(&mu))) + .expect("Failed to spawn future"); + spawner + .spawn_local(lock(Rc::clone(&mu))) + .expect("Failed to spawn future"); + + ex.run(); + } + + #[test] + fn send() { + let mu = Mutex::new(NonCopy(19)); + + thread::spawn(move || { + let value = block_on(mu.lock()); + assert_eq!(*value, NonCopy(19)); + }) + .join() + .unwrap(); + } + + #[test] + fn arc_nested() { + // Tests nested mutexes and access to underlying data. + let mu = Mutex::new(1); + let arc = Arc::new(Mutex::new(mu)); + thread::spawn(move || { + let nested = block_on(arc.lock()); + let lock2 = block_on(nested.lock()); + assert_eq!(*lock2, 1); + }) + .join() + .unwrap(); + } + + #[test] + fn arc_access_in_unwind() { + let arc = Arc::new(Mutex::new(1)); + let arc2 = arc.clone(); + thread::spawn(move || { + struct Unwinder { + i: Arc>, + } + impl Drop for Unwinder { + fn drop(&mut self) { + *block_on(self.i.lock()) += 1; + } + } + let _u = Unwinder { i: arc2 }; + panic!(); + }) + .join() + .expect_err("thread did not panic"); + let lock = block_on(arc.lock()); + assert_eq!(*lock, 2); + } + + #[test] + fn unsized_value() { + let mutex: &Mutex<[i32]> = &Mutex::new([1, 2, 3]); + { + let b = &mut *block_on(mutex.lock()); + b[0] = 4; + b[2] = 5; + } + let expected: &[i32] = &[4, 2, 5]; + assert_eq!(&*block_on(mutex.lock()), expected); + } + #[test] + fn high_contention() { + const THREADS: usize = 17; + const ITERATIONS: usize = 103; + + let mut threads = Vec::with_capacity(THREADS); + + let mu = Arc::new(Mutex::new(0usize)); + for _ in 0..THREADS { + let mu2 = mu.clone(); + threads.push(thread::spawn(move || { + for _ in 0..ITERATIONS { + *block_on(mu2.lock()) += 1; + } + })); + } + + for t in threads.into_iter() { + t.join().unwrap(); + } + + assert_eq!(*block_on(mu.read_lock()), THREADS * ITERATIONS); + assert_eq!(mu.raw.state.load(Ordering::Relaxed), 0); + } + + #[test] + fn high_contention_with_cancel() { + const TASKS: usize = 17; + const ITERATIONS: usize = 103; + + async fn increment(mu: Arc>, alt_mu: Arc>, tx: Sender<()>) { + for _ in 0..ITERATIONS { + select! 
{ + mut count = mu.lock().fuse() => *count += 1, + mut count = alt_mu.lock().fuse() => *count += 1, + } + } + tx.send(()).expect("Failed to send completion signal"); + } + + let ex = ThreadPool::new().expect("Failed to create ThreadPool"); + + let mu = Arc::new(Mutex::new(0usize)); + let alt_mu = Arc::new(Mutex::new(0usize)); + + let (tx, rx) = channel(); + for _ in 0..TASKS { + ex.spawn_ok(increment(Arc::clone(&mu), Arc::clone(&alt_mu), tx.clone())); + } + + for _ in 0..TASKS { + if let Err(e) = rx.recv_timeout(Duration::from_secs(10)) { + panic!("Error while waiting for threads to complete: {}", e); + } + } + + assert_eq!( + *block_on(mu.read_lock()) + *block_on(alt_mu.read_lock()), + TASKS * ITERATIONS + ); + assert_eq!(mu.raw.state.load(Ordering::Relaxed), 0); + assert_eq!(alt_mu.raw.state.load(Ordering::Relaxed), 0); + } + + #[test] + fn single_thread_async() { + const TASKS: usize = 17; + const ITERATIONS: usize = 103; + + // Async closures are unstable. + async fn increment(mu: Rc>) { + for _ in 0..ITERATIONS { + *mu.lock().await += 1; + } + } + + let mut ex = LocalPool::new(); + let spawner = ex.spawner(); + + let mu = Rc::new(Mutex::new(0usize)); + for _ in 0..TASKS { + spawner + .spawn_local(increment(Rc::clone(&mu))) + .expect("Failed to spawn task"); + } + + ex.run(); + + assert_eq!(*block_on(mu.read_lock()), TASKS * ITERATIONS); + assert_eq!(mu.raw.state.load(Ordering::Relaxed), 0); + } + + #[test] + fn multi_thread_async() { + const TASKS: usize = 17; + const ITERATIONS: usize = 103; + + // Async closures are unstable. + async fn increment(mu: Arc>, tx: Sender<()>) { + for _ in 0..ITERATIONS { + *mu.lock().await += 1; + } + tx.send(()).expect("Failed to send completion signal"); + } + + let ex = ThreadPool::new().expect("Failed to create ThreadPool"); + + let mu = Arc::new(Mutex::new(0usize)); + let (tx, rx) = channel(); + for _ in 0..TASKS { + ex.spawn_ok(increment(Arc::clone(&mu), tx.clone())); + } + + for _ in 0..TASKS { + rx.recv_timeout(Duration::from_secs(5)) + .expect("Failed to receive completion signal"); + } + assert_eq!(*block_on(mu.read_lock()), TASKS * ITERATIONS); + assert_eq!(mu.raw.state.load(Ordering::Relaxed), 0); + } + + #[test] + fn get_mut() { + let mut mu = Mutex::new(NonCopy(13)); + *mu.get_mut() = NonCopy(17); + + assert_eq!(mu.into_inner(), NonCopy(17)); + } + + #[test] + fn into_inner() { + let mu = Mutex::new(NonCopy(29)); + assert_eq!(mu.into_inner(), NonCopy(29)); + } + + #[test] + fn into_inner_drop() { + struct NeedsDrop(Arc); + impl Drop for NeedsDrop { + fn drop(&mut self) { + self.0.fetch_add(1, Ordering::AcqRel); + } + } + + let value = Arc::new(AtomicUsize::new(0)); + let needs_drop = Mutex::new(NeedsDrop(value.clone())); + assert_eq!(value.load(Ordering::Acquire), 0); + + { + let inner = needs_drop.into_inner(); + assert_eq!(inner.0.load(Ordering::Acquire), 0); + } + + assert_eq!(value.load(Ordering::Acquire), 1); + } + + #[test] + fn rw_arc() { + const THREADS: isize = 7; + const ITERATIONS: isize = 13; + + let mu = Arc::new(Mutex::new(0isize)); + let mu2 = mu.clone(); + + let (tx, rx) = channel(); + thread::spawn(move || { + let mut guard = block_on(mu2.lock()); + for _ in 0..ITERATIONS { + let tmp = *guard; + *guard = -1; + thread::yield_now(); + *guard = tmp + 1; + } + tx.send(()).unwrap(); + }); + + let mut readers = Vec::with_capacity(10); + for _ in 0..THREADS { + let mu3 = mu.clone(); + let handle = thread::spawn(move || { + let guard = block_on(mu3.read_lock()); + assert!(*guard >= 0); + }); + + readers.push(handle); + } + + // 
Wait for the readers to finish their checks. + for r in readers { + r.join().expect("One or more readers saw a negative value"); + } + + // Wait for the writer to finish. + rx.recv_timeout(Duration::from_secs(5)).unwrap(); + assert_eq!(*block_on(mu.read_lock()), ITERATIONS); + assert_eq!(mu.raw.state.load(Ordering::Relaxed), 0); + } + + #[test] + fn rw_single_thread_async() { + // A Future that returns `Poll::pending` the first time it is polled and `Poll::Ready` every + // time after that. + struct TestFuture { + polled: bool, + waker: Arc>>, + } + + impl Future for TestFuture { + type Output = (); + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll { + if self.polled { + Poll::Ready(()) + } else { + self.polled = true; + *self.waker.lock() = Some(cx.waker().clone()); + Poll::Pending + } + } + } + + fn wake_future(waker: Arc>>) { + loop { + if let Some(w) = waker.lock().take() { + w.wake(); + return; + } + + // This sleep cannot be moved into an else branch because we would end up holding + // the lock while sleeping due to rust's drop ordering rules. + thread::sleep(Duration::from_millis(10)); + } + } + + async fn writer(mu: Rc>) { + let mut guard = mu.lock().await; + for _ in 0..ITERATIONS { + let tmp = *guard; + *guard = -1; + let waker = Arc::new(SpinLock::new(None)); + let waker2 = Arc::clone(&waker); + thread::spawn(move || wake_future(waker2)); + let fut = TestFuture { + polled: false, + waker, + }; + fut.await; + *guard = tmp + 1; + } + } + + async fn reader(mu: Rc>) { + let guard = mu.read_lock().await; + assert!(*guard >= 0); + } + + const TASKS: isize = 7; + const ITERATIONS: isize = 13; + + let mu = Rc::new(Mutex::new(0isize)); + let mut ex = LocalPool::new(); + let spawner = ex.spawner(); + + spawner + .spawn_local(writer(Rc::clone(&mu))) + .expect("Failed to spawn writer"); + + for _ in 0..TASKS { + spawner + .spawn_local(reader(Rc::clone(&mu))) + .expect("Failed to spawn reader"); + } + + ex.run(); + + assert_eq!(*block_on(mu.read_lock()), ITERATIONS); + assert_eq!(mu.raw.state.load(Ordering::Relaxed), 0); + } + + #[test] + fn rw_multi_thread_async() { + async fn writer(mu: Arc>, tx: Sender<()>) { + let mut guard = mu.lock().await; + for _ in 0..ITERATIONS { + let tmp = *guard; + *guard = -1; + thread::yield_now(); + *guard = tmp + 1; + } + + mem::drop(guard); + tx.send(()).unwrap(); + } + + async fn reader(mu: Arc>, tx: Sender<()>) { + let guard = mu.read_lock().await; + assert!(*guard >= 0); + + mem::drop(guard); + tx.send(()).expect("Failed to send completion message"); + } + + const TASKS: isize = 7; + const ITERATIONS: isize = 13; + + let mu = Arc::new(Mutex::new(0isize)); + let ex = ThreadPool::new().expect("Failed to create ThreadPool"); + + let (txw, rxw) = channel(); + ex.spawn_ok(writer(Arc::clone(&mu), txw)); + + let (txr, rxr) = channel(); + for _ in 0..TASKS { + ex.spawn_ok(reader(Arc::clone(&mu), txr.clone())); + } + + // Wait for the readers to finish their checks. + for _ in 0..TASKS { + rxr.recv_timeout(Duration::from_secs(5)) + .expect("Failed to receive completion message from reader"); + } + + // Wait for the writer to finish. 
+ rxw.recv_timeout(Duration::from_secs(5)) + .expect("Failed to receive completion message from writer"); + + assert_eq!(*block_on(mu.read_lock()), ITERATIONS); + assert_eq!(mu.raw.state.load(Ordering::Relaxed), 0); + } + + #[test] + fn wake_all_readers() { + async fn read(mu: Arc>) { + let g = mu.read_lock().await; + pending!(); + mem::drop(g); + } + + async fn write(mu: Arc>) { + mu.lock().await; + } + + let mu = Arc::new(Mutex::new(())); + let mut futures: [Pin>>; 5] = [ + Box::pin(read(mu.clone())), + Box::pin(read(mu.clone())), + Box::pin(read(mu.clone())), + Box::pin(write(mu.clone())), + Box::pin(read(mu.clone())), + ]; + const NUM_READERS: usize = 4; + + let arc_waker = Arc::new(TestWaker); + let waker = waker_ref(&arc_waker); + let mut cx = Context::from_waker(&waker); + + // Acquire the lock so that the futures cannot get it. + let g = block_on(mu.lock()); + + for r in &mut futures { + if let Poll::Ready(()) = r.as_mut().poll(&mut cx) { + panic!("future unexpectedly ready"); + } + } + + assert_eq!( + mu.raw.state.load(Ordering::Relaxed) & HAS_WAITERS, + HAS_WAITERS + ); + + assert_eq!( + mu.raw.state.load(Ordering::Relaxed) & WRITER_WAITING, + WRITER_WAITING + ); + + // Drop the lock. This should allow all readers to make progress. Since they already waited + // once they should ignore the WRITER_WAITING bit that is currently set. + mem::drop(g); + for r in &mut futures { + if let Poll::Ready(()) = r.as_mut().poll(&mut cx) { + panic!("future unexpectedly ready"); + } + } + + // Check that all readers were able to acquire the lock. + assert_eq!( + mu.raw.state.load(Ordering::Relaxed) & READ_MASK, + READ_LOCK * NUM_READERS + ); + assert_eq!( + mu.raw.state.load(Ordering::Relaxed) & WRITER_WAITING, + WRITER_WAITING + ); + + let mut needs_poll = None; + + // All the readers can now finish but the writer needs to be polled again. + for (i, r) in futures.iter_mut().enumerate() { + match r.as_mut().poll(&mut cx) { + Poll::Ready(()) => {} + Poll::Pending => { + if needs_poll.is_some() { + panic!("More than one future unable to complete"); + } + needs_poll = Some(i); + } + } + } + + if futures[needs_poll.expect("Writer unexpectedly able to complete")] + .as_mut() + .poll(&mut cx) + .is_pending() + { + panic!("Writer unable to complete"); + } + + assert_eq!(mu.raw.state.load(Ordering::Relaxed), 0); + } + + #[test] + fn long_wait() { + async fn tight_loop(mu: Arc>) { + loop { + let ready = mu.lock().await; + if *ready { + break; + } + pending!(); + } + } + + async fn mark_ready(mu: Arc>) { + *mu.lock().await = true; + } + + let mu = Arc::new(Mutex::new(false)); + let mut tl = Box::pin(tight_loop(mu.clone())); + let mut mark = Box::pin(mark_ready(mu.clone())); + + let arc_waker = Arc::new(TestWaker); + let waker = waker_ref(&arc_waker); + let mut cx = Context::from_waker(&waker); + + for _ in 0..=LONG_WAIT_THRESHOLD { + if let Poll::Ready(()) = tl.as_mut().poll(&mut cx) { + panic!("tight_loop unexpectedly ready"); + } + + if let Poll::Ready(()) = mark.as_mut().poll(&mut cx) { + panic!("mark_ready unexpectedly ready"); + } + } + + assert_eq!( + mu.raw.state.load(Ordering::Relaxed), + LOCKED | HAS_WAITERS | WRITER_WAITING | LONG_WAIT + ); + + // This time the tight loop will fail to acquire the lock. + if let Poll::Ready(()) = tl.as_mut().poll(&mut cx) { + panic!("tight_loop unexpectedly ready"); + } + + // Which will finally allow the mark_ready function to make progress. 
+ if mark.as_mut().poll(&mut cx).is_pending() { + panic!("mark_ready not able to make progress"); + } + + // Now the tight loop will finish. + if tl.as_mut().poll(&mut cx).is_pending() { + panic!("tight_loop not able to finish"); + } + + assert!(*block_on(mu.lock())); + assert_eq!(mu.raw.state.load(Ordering::Relaxed), 0); + } + + #[test] + fn cancel_long_wait_before_wake() { + async fn tight_loop(mu: Arc>) { + loop { + let ready = mu.lock().await; + if *ready { + break; + } + pending!(); + } + } + + async fn mark_ready(mu: Arc>) { + *mu.lock().await = true; + } + + let mu = Arc::new(Mutex::new(false)); + let mut tl = Box::pin(tight_loop(mu.clone())); + let mut mark = Box::pin(mark_ready(mu.clone())); + + let arc_waker = Arc::new(TestWaker); + let waker = waker_ref(&arc_waker); + let mut cx = Context::from_waker(&waker); + + for _ in 0..=LONG_WAIT_THRESHOLD { + if let Poll::Ready(()) = tl.as_mut().poll(&mut cx) { + panic!("tight_loop unexpectedly ready"); + } + + if let Poll::Ready(()) = mark.as_mut().poll(&mut cx) { + panic!("mark_ready unexpectedly ready"); + } + } + + assert_eq!( + mu.raw.state.load(Ordering::Relaxed), + LOCKED | HAS_WAITERS | WRITER_WAITING | LONG_WAIT + ); + + // Now drop the mark_ready future, which should clear the LONG_WAIT bit. + mem::drop(mark); + assert_eq!(mu.raw.state.load(Ordering::Relaxed), LOCKED); + + mem::drop(tl); + assert_eq!(mu.raw.state.load(Ordering::Relaxed), 0); + } + + #[test] + fn cancel_long_wait_after_wake() { + async fn tight_loop(mu: Arc>) { + loop { + let ready = mu.lock().await; + if *ready { + break; + } + pending!(); + } + } + + async fn mark_ready(mu: Arc>) { + *mu.lock().await = true; + } + + let mu = Arc::new(Mutex::new(false)); + let mut tl = Box::pin(tight_loop(mu.clone())); + let mut mark = Box::pin(mark_ready(mu.clone())); + + let arc_waker = Arc::new(TestWaker); + let waker = waker_ref(&arc_waker); + let mut cx = Context::from_waker(&waker); + + for _ in 0..=LONG_WAIT_THRESHOLD { + if let Poll::Ready(()) = tl.as_mut().poll(&mut cx) { + panic!("tight_loop unexpectedly ready"); + } + + if let Poll::Ready(()) = mark.as_mut().poll(&mut cx) { + panic!("mark_ready unexpectedly ready"); + } + } + + assert_eq!( + mu.raw.state.load(Ordering::Relaxed), + LOCKED | HAS_WAITERS | WRITER_WAITING | LONG_WAIT + ); + + // This time the tight loop will fail to acquire the lock. + if let Poll::Ready(()) = tl.as_mut().poll(&mut cx) { + panic!("tight_loop unexpectedly ready"); + } + + // Now drop the mark_ready future, which should clear the LONG_WAIT bit. + mem::drop(mark); + assert_eq!(mu.raw.state.load(Ordering::Relaxed) & LONG_WAIT, 0); + + // Since the lock is not held, we should be able to spawn a future to set the ready flag. + block_on(mark_ready(mu.clone())); + + // Now the tight loop will finish. + if tl.as_mut().poll(&mut cx).is_pending() { + panic!("tight_loop not able to finish"); + } + + assert_eq!(mu.raw.state.load(Ordering::Relaxed), 0); + } + + #[test] + fn designated_waker() { + async fn inc(mu: Arc>) { + *mu.lock().await += 1; + } + + let mu = Arc::new(Mutex::new(0)); + + let mut futures = [ + Box::pin(inc(mu.clone())), + Box::pin(inc(mu.clone())), + Box::pin(inc(mu.clone())), + ]; + + let arc_waker = Arc::new(TestWaker); + let waker = waker_ref(&arc_waker); + let mut cx = Context::from_waker(&waker); + + let count = block_on(mu.lock()); + + // Poll 2 futures. Since neither will be able to acquire the lock, they should get added to + // the waiter list. 
+ if let Poll::Ready(()) = futures[0].as_mut().poll(&mut cx) { + panic!("future unexpectedly ready"); + } + if let Poll::Ready(()) = futures[1].as_mut().poll(&mut cx) { + panic!("future unexpectedly ready"); + } + + assert_eq!( + mu.raw.state.load(Ordering::Relaxed), + LOCKED | HAS_WAITERS | WRITER_WAITING, + ); + + // Now drop the lock. This should set the DESIGNATED_WAKER bit and wake up the first future + // in the wait list. + mem::drop(count); + + assert_eq!( + mu.raw.state.load(Ordering::Relaxed), + DESIGNATED_WAKER | HAS_WAITERS | WRITER_WAITING, + ); + + // Now poll the third future. It should be able to acquire the lock immediately. + if futures[2].as_mut().poll(&mut cx).is_pending() { + panic!("future unable to complete"); + } + assert_eq!(*block_on(mu.lock()), 1); + + // There should still be a waiter in the wait list and the DESIGNATED_WAKER bit should still + // be set. + assert_eq!( + mu.raw.state.load(Ordering::Relaxed) & DESIGNATED_WAKER, + DESIGNATED_WAKER + ); + assert_eq!( + mu.raw.state.load(Ordering::Relaxed) & HAS_WAITERS, + HAS_WAITERS + ); + + // Now let the future that was woken up run. + if futures[0].as_mut().poll(&mut cx).is_pending() { + panic!("future unable to complete"); + } + assert_eq!(*block_on(mu.lock()), 2); + + if futures[1].as_mut().poll(&mut cx).is_pending() { + panic!("future unable to complete"); + } + assert_eq!(*block_on(mu.lock()), 3); + + assert_eq!(mu.raw.state.load(Ordering::Relaxed), 0); + } + + #[test] + fn cancel_designated_waker() { + async fn inc(mu: Arc>) { + *mu.lock().await += 1; + } + + let mu = Arc::new(Mutex::new(0)); + + let mut fut = Box::pin(inc(mu.clone())); + + let arc_waker = Arc::new(TestWaker); + let waker = waker_ref(&arc_waker); + let mut cx = Context::from_waker(&waker); + + let count = block_on(mu.lock()); + + if let Poll::Ready(()) = fut.as_mut().poll(&mut cx) { + panic!("Future unexpectedly ready when lock is held"); + } + + // Drop the lock. This will wake up the future. + mem::drop(count); + + // Now drop the future without polling. This should clear all the state in the mutex. + mem::drop(fut); + + assert_eq!(mu.raw.state.load(Ordering::Relaxed), 0); + } + + #[test] + fn cancel_before_wake() { + async fn inc(mu: Arc>) { + *mu.lock().await += 1; + } + + let mu = Arc::new(Mutex::new(0)); + + let mut fut1 = Box::pin(inc(mu.clone())); + + let mut fut2 = Box::pin(inc(mu.clone())); + + let arc_waker = Arc::new(TestWaker); + let waker = waker_ref(&arc_waker); + let mut cx = Context::from_waker(&waker); + + // First acquire the lock. + let count = block_on(mu.lock()); + + // Now poll the futures. Since the lock is acquired they will both get queued in the waiter + // list. + match fut1.as_mut().poll(&mut cx) { + Poll::Pending => {} + Poll::Ready(()) => panic!("Future is unexpectedly ready"), + } + + match fut2.as_mut().poll(&mut cx) { + Poll::Pending => {} + Poll::Ready(()) => panic!("Future is unexpectedly ready"), + } + + assert_eq!( + mu.raw.state.load(Ordering::Relaxed) & WRITER_WAITING, + WRITER_WAITING + ); + + // Drop fut1. This should remove it from the waiter list but shouldn't wake fut2. + mem::drop(fut1); + + // There should be no designated waker. + assert_eq!(mu.raw.state.load(Ordering::Relaxed) & DESIGNATED_WAKER, 0); + + // Since the waiter was a writer, we should clear the WRITER_WAITING bit. 
+ assert_eq!(mu.raw.state.load(Ordering::Relaxed) & WRITER_WAITING, 0); + + match fut2.as_mut().poll(&mut cx) { + Poll::Pending => {} + Poll::Ready(()) => panic!("Future is unexpectedly ready"), + } + + // Now drop the lock. This should mark fut2 as ready to make progress. + mem::drop(count); + + match fut2.as_mut().poll(&mut cx) { + Poll::Pending => panic!("Future is not ready to make progress"), + Poll::Ready(()) => {} + } + + // Verify that we only incremented the count once. + assert_eq!(*block_on(mu.lock()), 1); + assert_eq!(mu.raw.state.load(Ordering::Relaxed), 0); + } + + #[test] + fn cancel_after_wake() { + async fn inc(mu: Arc>) { + *mu.lock().await += 1; + } + + let mu = Arc::new(Mutex::new(0)); + + let mut fut1 = Box::pin(inc(mu.clone())); + + let mut fut2 = Box::pin(inc(mu.clone())); + + let arc_waker = Arc::new(TestWaker); + let waker = waker_ref(&arc_waker); + let mut cx = Context::from_waker(&waker); + + // First acquire the lock. + let count = block_on(mu.lock()); + + // Now poll the futures. Since the lock is acquired they will both get queued in the waiter + // list. + match fut1.as_mut().poll(&mut cx) { + Poll::Pending => {} + Poll::Ready(()) => panic!("Future is unexpectedly ready"), + } + + match fut2.as_mut().poll(&mut cx) { + Poll::Pending => {} + Poll::Ready(()) => panic!("Future is unexpectedly ready"), + } + + assert_eq!( + mu.raw.state.load(Ordering::Relaxed) & WRITER_WAITING, + WRITER_WAITING + ); + + // Drop the lock. This should mark fut1 as ready to make progress. + mem::drop(count); + + // Now drop fut1. This should make fut2 ready to make progress. + mem::drop(fut1); + + // Since there was still another waiter in the list we shouldn't have cleared the + // DESIGNATED_WAKER bit. + assert_eq!( + mu.raw.state.load(Ordering::Relaxed) & DESIGNATED_WAKER, + DESIGNATED_WAKER + ); + + // Since the waiter was a writer, we should clear the WRITER_WAITING bit. + assert_eq!(mu.raw.state.load(Ordering::Relaxed) & WRITER_WAITING, 0); + + match fut2.as_mut().poll(&mut cx) { + Poll::Pending => panic!("Future is not ready to make progress"), + Poll::Ready(()) => {} + } + + // Verify that we only incremented the count once. + assert_eq!(*block_on(mu.lock()), 1); + assert_eq!(mu.raw.state.load(Ordering::Relaxed), 0); + } + + #[test] + fn timeout() { + async fn timed_lock(timer: oneshot::Receiver<()>, mu: Arc>) { + select! { + res = timer.fuse() => { + match res { + Ok(()) => {}, + Err(e) => panic!("Timer unexpectedly canceled: {}", e), + } + } + _ = mu.lock().fuse() => panic!("Successfuly acquired lock"), + } + } + + let mu = Arc::new(Mutex::new(())); + let (tx, rx) = oneshot::channel(); + + let mut timeout = Box::pin(timed_lock(rx, mu.clone())); + + let arc_waker = Arc::new(TestWaker); + let waker = waker_ref(&arc_waker); + let mut cx = Context::from_waker(&waker); + + // Acquire the lock. + let g = block_on(mu.lock()); + + // Poll the future. + if let Poll::Ready(()) = timeout.as_mut().poll(&mut cx) { + panic!("timed_lock unexpectedly ready"); + } + + assert_eq!( + mu.raw.state.load(Ordering::Relaxed) & HAS_WAITERS, + HAS_WAITERS + ); + + // Signal the channel, which should cancel the lock. + tx.send(()).expect("Failed to send wakeup"); + + // Now the future should have completed without acquiring the lock. + if timeout.as_mut().poll(&mut cx).is_pending() { + panic!("timed_lock not ready after timeout"); + } + + // The mutex state should not show any waiters. 
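// The same cancellation pattern as `timed_lock` above, shaped as a reusable helper
// (hypothetical function; `oneshot`, `select!`, and `FutureExt::fuse` are the same `futures`
// items these tests already use). Returns false if the deadline fires first; dropping the
// unfinished `lock()` future removes this task from the mutex's waiter list.
async fn lock_and_incr_or_give_up(mu: &Mutex<usize>, deadline: oneshot::Receiver<()>) -> bool {
    select! {
        mut guard = mu.lock().fuse() => {
            *guard += 1;
            true
        }
        _ = deadline.fuse() => false,
    }
}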
+ assert_eq!(mu.raw.state.load(Ordering::Relaxed) & HAS_WAITERS, 0); + + mem::drop(g); + + assert_eq!(mu.raw.state.load(Ordering::Relaxed), 0); + } + + #[test] + fn writer_waiting() { + async fn read_zero(mu: Arc>) { + let val = mu.read_lock().await; + pending!(); + + assert_eq!(*val, 0); + } + + async fn inc(mu: Arc>) { + *mu.lock().await += 1; + } + + async fn read_one(mu: Arc>) { + let val = mu.read_lock().await; + + assert_eq!(*val, 1); + } + + let mu = Arc::new(Mutex::new(0)); + + let mut r1 = Box::pin(read_zero(mu.clone())); + let mut r2 = Box::pin(read_zero(mu.clone())); + + let mut w = Box::pin(inc(mu.clone())); + let mut r3 = Box::pin(read_one(mu.clone())); + + let arc_waker = Arc::new(TestWaker); + let waker = waker_ref(&arc_waker); + let mut cx = Context::from_waker(&waker); + + if let Poll::Ready(()) = r1.as_mut().poll(&mut cx) { + panic!("read_zero unexpectedly ready"); + } + if let Poll::Ready(()) = r2.as_mut().poll(&mut cx) { + panic!("read_zero unexpectedly ready"); + } + assert_eq!( + mu.raw.state.load(Ordering::Relaxed) & READ_MASK, + 2 * READ_LOCK + ); + + if let Poll::Ready(()) = w.as_mut().poll(&mut cx) { + panic!("inc unexpectedly ready"); + } + assert_eq!( + mu.raw.state.load(Ordering::Relaxed) & WRITER_WAITING, + WRITER_WAITING + ); + + // The WRITER_WAITING bit should prevent the next reader from acquiring the lock. + if let Poll::Ready(()) = r3.as_mut().poll(&mut cx) { + panic!("read_one unexpectedly ready"); + } + assert_eq!( + mu.raw.state.load(Ordering::Relaxed) & READ_MASK, + 2 * READ_LOCK + ); + + if r1.as_mut().poll(&mut cx).is_pending() { + panic!("read_zero unable to complete"); + } + if r2.as_mut().poll(&mut cx).is_pending() { + panic!("read_zero unable to complete"); + } + if w.as_mut().poll(&mut cx).is_pending() { + panic!("inc unable to complete"); + } + if r3.as_mut().poll(&mut cx).is_pending() { + panic!("read_one unable to complete"); + } + + assert_eq!(mu.raw.state.load(Ordering::Relaxed), 0); + } + + #[test] + fn notify_one() { + async fn read(mu: Arc>, cv: Arc) { + let mut count = mu.read_lock().await; + while *count == 0 { + count = cv.wait_read(count).await; + } + } + + async fn write(mu: Arc>, cv: Arc) { + let mut count = mu.lock().await; + while *count == 0 { + count = cv.wait(count).await; + } + + *count -= 1; + } + + let mu = Arc::new(Mutex::new(0)); + let cv = Arc::new(Condvar::new()); + + let arc_waker = Arc::new(TestWaker); + let waker = waker_ref(&arc_waker); + let mut cx = Context::from_waker(&waker); + + let mut readers = [ + Box::pin(read(mu.clone(), cv.clone())), + Box::pin(read(mu.clone(), cv.clone())), + Box::pin(read(mu.clone(), cv.clone())), + Box::pin(read(mu.clone(), cv.clone())), + ]; + let mut writer = Box::pin(write(mu.clone(), cv.clone())); + + for r in &mut readers { + if let Poll::Ready(()) = r.as_mut().poll(&mut cx) { + panic!("reader unexpectedly ready"); + } + } + if let Poll::Ready(()) = writer.as_mut().poll(&mut cx) { + panic!("writer unexpectedly ready"); + } + + let mut count = block_on(mu.lock()); + *count = 1; + + // This should wake all readers + one writer. + cv.notify_one(); + + // Poll the readers and the writer so they add themselves to the mutex's waiter list. 
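// The canonical wait loop exercised by `read`/`write` above, shaped as an application-style
// sketch (hypothetical queue helpers, assuming this module's `Mutex` and `Condvar`):
async fn pop_front(mu: Arc<Mutex<Vec<u32>>>, cv: Arc<Condvar>) -> u32 {
    let mut queue = mu.lock().await;
    while queue.is_empty() {
        // `wait` releases the lock while blocked and hands back a reacquired guard.
        queue = cv.wait(queue).await;
    }
    queue.remove(0)
}

async fn push_back(mu: Arc<Mutex<Vec<u32>>>, cv: Arc<Condvar>, val: u32) {
    mu.lock().await.push(val);
    // Notifying works with or without the lock held; see `notify_when_unlocked` below.
    cv.notify_one();
}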
+ for r in &mut readers { + if r.as_mut().poll(&mut cx).is_ready() { + panic!("reader unexpectedly ready"); + } + } + + if writer.as_mut().poll(&mut cx).is_ready() { + panic!("writer unexpectedly ready"); + } + + assert_eq!( + mu.raw.state.load(Ordering::Relaxed) & HAS_WAITERS, + HAS_WAITERS + ); + assert_eq!( + mu.raw.state.load(Ordering::Relaxed) & WRITER_WAITING, + WRITER_WAITING + ); + + mem::drop(count); + + assert_eq!( + mu.raw.state.load(Ordering::Relaxed) & (HAS_WAITERS | WRITER_WAITING), + HAS_WAITERS | WRITER_WAITING + ); + + for r in &mut readers { + if r.as_mut().poll(&mut cx).is_pending() { + panic!("reader unable to complete"); + } + } + + if writer.as_mut().poll(&mut cx).is_pending() { + panic!("writer unable to complete"); + } + + assert_eq!(*block_on(mu.read_lock()), 0); + } + + #[test] + fn notify_when_unlocked() { + async fn dec(mu: Arc>, cv: Arc) { + let mut count = mu.lock().await; + + while *count == 0 { + count = cv.wait(count).await; + } + + *count -= 1; + } + + let mu = Arc::new(Mutex::new(0)); + let cv = Arc::new(Condvar::new()); + + let arc_waker = Arc::new(TestWaker); + let waker = waker_ref(&arc_waker); + let mut cx = Context::from_waker(&waker); + + let mut futures = [ + Box::pin(dec(mu.clone(), cv.clone())), + Box::pin(dec(mu.clone(), cv.clone())), + Box::pin(dec(mu.clone(), cv.clone())), + Box::pin(dec(mu.clone(), cv.clone())), + ]; + + for f in &mut futures { + if let Poll::Ready(()) = f.as_mut().poll(&mut cx) { + panic!("future unexpectedly ready"); + } + } + + *block_on(mu.lock()) = futures.len(); + cv.notify_all(); + + // Since we haven't polled `futures` yet, the mutex should not have any waiters. + assert_eq!(mu.raw.state.load(Ordering::Relaxed) & HAS_WAITERS, 0); + + for f in &mut futures { + if f.as_mut().poll(&mut cx).is_pending() { + panic!("future unexpectedly ready"); + } + } + assert_eq!(mu.raw.state.load(Ordering::Relaxed), 0); + } + + #[test] + fn notify_reader_writer() { + async fn read(mu: Arc>, cv: Arc) { + let mut count = mu.read_lock().await; + while *count == 0 { + count = cv.wait_read(count).await; + } + + // Yield once while holding the read lock, which should prevent the writer from waking + // up. + pending!(); + } + + async fn write(mu: Arc>, cv: Arc) { + let mut count = mu.lock().await; + while *count == 0 { + count = cv.wait(count).await; + } + + *count -= 1; + } + + async fn lock(mu: Arc>) { + mem::drop(mu.lock().await); + } + + let mu = Arc::new(Mutex::new(0)); + let cv = Arc::new(Condvar::new()); + + let arc_waker = Arc::new(TestWaker); + let waker = waker_ref(&arc_waker); + let mut cx = Context::from_waker(&waker); + + let mut futures: [Pin>>; 5] = [ + Box::pin(read(mu.clone(), cv.clone())), + Box::pin(read(mu.clone(), cv.clone())), + Box::pin(read(mu.clone(), cv.clone())), + Box::pin(write(mu.clone(), cv.clone())), + Box::pin(read(mu.clone(), cv.clone())), + ]; + const NUM_READERS: usize = 4; + + let mut l = Box::pin(lock(mu.clone())); + + for f in &mut futures { + if let Poll::Ready(()) = f.as_mut().poll(&mut cx) { + panic!("future unexpectedly ready"); + } + } + + assert_eq!(mu.raw.state.load(Ordering::Relaxed), 0); + + let mut count = block_on(mu.lock()); + *count = 1; + + // Now poll the lock function. Since the lock is held by us, it will get queued on the + // waiter list. 
+ if let Poll::Ready(()) = l.as_mut().poll(&mut cx) { + panic!("lock() unexpectedly ready"); + } + + assert_eq!( + mu.raw.state.load(Ordering::Relaxed) & (HAS_WAITERS | WRITER_WAITING), + HAS_WAITERS | WRITER_WAITING + ); + + // Wake up waiters while holding the lock. + cv.notify_all(); + + // Drop the lock. This should wake up the lock function. + mem::drop(count); + + if l.as_mut().poll(&mut cx).is_pending() { + panic!("lock() unable to complete"); + } + + // Since we haven't polled `futures` yet, the mutex state should now be empty. + assert_eq!(mu.raw.state.load(Ordering::Relaxed), 0); + + // Poll everything again. The readers should be able to make progress (but not complete) but + // the writer should be blocked. + for f in &mut futures { + if let Poll::Ready(()) = f.as_mut().poll(&mut cx) { + panic!("future unexpectedly ready"); + } + } + + assert_eq!( + mu.raw.state.load(Ordering::Relaxed) & READ_MASK, + READ_LOCK * NUM_READERS + ); + + // All the readers can now finish but the writer needs to be polled again. + let mut needs_poll = None; + for (i, r) in futures.iter_mut().enumerate() { + match r.as_mut().poll(&mut cx) { + Poll::Ready(()) => {} + Poll::Pending => { + if needs_poll.is_some() { + panic!("More than one future unable to complete"); + } + needs_poll = Some(i); + } + } + } + + if futures[needs_poll.expect("Writer unexpectedly able to complete")] + .as_mut() + .poll(&mut cx) + .is_pending() + { + panic!("Writer unable to complete"); + } + + assert_eq!(*block_on(mu.lock()), 0); + assert_eq!(mu.raw.state.load(Ordering::Relaxed), 0); + } + + #[test] + fn notify_readers_with_read_lock() { + async fn read(mu: Arc>, cv: Arc) { + let mut count = mu.read_lock().await; + while *count == 0 { + count = cv.wait_read(count).await; + } + + // Yield once while holding the read lock. + pending!(); + } + + let mu = Arc::new(Mutex::new(0)); + let cv = Arc::new(Condvar::new()); + + let arc_waker = Arc::new(TestWaker); + let waker = waker_ref(&arc_waker); + let mut cx = Context::from_waker(&waker); + + let mut futures = [ + Box::pin(read(mu.clone(), cv.clone())), + Box::pin(read(mu.clone(), cv.clone())), + Box::pin(read(mu.clone(), cv.clone())), + Box::pin(read(mu.clone(), cv.clone())), + ]; + + for f in &mut futures { + if let Poll::Ready(()) = f.as_mut().poll(&mut cx) { + panic!("future unexpectedly ready"); + } + } + + // Increment the count and then grab a read lock. + *block_on(mu.lock()) = 1; + + let g = block_on(mu.read_lock()); + + // Notify the condvar while holding the read lock. This should wake up all the waiters. + cv.notify_all(); + + // Since the lock is held in shared mode, all the readers should immediately be able to + // acquire the read lock. + for f in &mut futures { + if let Poll::Ready(()) = f.as_mut().poll(&mut cx) { + panic!("future unexpectedly ready"); + } + } + assert_eq!(mu.raw.state.load(Ordering::Relaxed) & HAS_WAITERS, 0); + assert_eq!( + mu.raw.state.load(Ordering::Relaxed) & READ_MASK, + READ_LOCK * (futures.len() + 1) + ); + + mem::drop(g); + + for f in &mut futures { + if f.as_mut().poll(&mut cx).is_pending() { + panic!("future unable to complete"); + } + } + + assert_eq!(mu.raw.state.load(Ordering::Relaxed), 0); + } +} diff --git a/cros_async/src/sync/spin.rs b/cros_async/src/sync/spin.rs new file mode 100644 index 0000000000..7faec6dc20 --- /dev/null +++ b/cros_async/src/sync/spin.rs @@ -0,0 +1,284 @@ +// Copyright 2020 The Chromium OS Authors. All rights reserved. 
+// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +use std::{ + cell::UnsafeCell, + hint, + ops::{Deref, DerefMut}, + sync::atomic::{AtomicBool, Ordering}, +}; + +const UNLOCKED: bool = false; +const LOCKED: bool = true; + +/// A primitive that provides safe, mutable access to a shared resource. +/// +/// Unlike `Mutex`, a `SpinLock` will not voluntarily yield its CPU time until the resource is +/// available and will instead keep spinning until the resource is acquired. For the vast majority +/// of cases, `Mutex` is a better choice than `SpinLock`. If a `SpinLock` must be used then users +/// should try to do as little work as possible while holding the `SpinLock` and avoid any sort of +/// blocking at all costs as it can severely penalize performance. +/// +/// # Poisoning +/// +/// This `SpinLock` does not implement lock poisoning so it is possible for threads to access +/// poisoned data if a thread panics while holding the lock. If lock poisoning is needed, it can be +/// implemented by wrapping the `SpinLock` in a new type that implements poisoning. See the +/// implementation of `std::sync::Mutex` for an example of how to do this. +#[repr(align(128))] +pub struct SpinLock { + lock: AtomicBool, + value: UnsafeCell, +} + +impl SpinLock { + /// Creates a new, unlocked `SpinLock` that's ready for use. + pub fn new(value: T) -> SpinLock { + SpinLock { + lock: AtomicBool::new(UNLOCKED), + value: UnsafeCell::new(value), + } + } + + /// Consumes the `SpinLock` and returns the value guarded by it. This method doesn't perform any + /// locking as the compiler guarantees that there are no references to `self`. + pub fn into_inner(self) -> T { + // No need to take the lock because the compiler can statically guarantee + // that there are no references to the SpinLock. + self.value.into_inner() + } +} + +impl SpinLock { + /// Acquires exclusive, mutable access to the resource protected by the `SpinLock`, blocking the + /// current thread until it is able to do so. Upon returning, the current thread will be the + /// only thread with access to the resource. The `SpinLock` will be released when the returned + /// `SpinLockGuard` is dropped. Attempting to call `lock` while already holding the `SpinLock` + /// will cause a deadlock. + pub fn lock(&self) -> SpinLockGuard { + loop { + let state = self.lock.load(Ordering::Relaxed); + if state == UNLOCKED + && self + .lock + .compare_exchange_weak(UNLOCKED, LOCKED, Ordering::Acquire, Ordering::Relaxed) + .is_ok() + { + break; + } + hint::spin_loop(); + } + + SpinLockGuard { + lock: self, + value: unsafe { &mut *self.value.get() }, + } + } + + fn unlock(&self) { + // Don't need to compare and swap because we exclusively hold the lock. + self.lock.store(UNLOCKED, Ordering::Release); + } + + /// Returns a mutable reference to the contained value. This method doesn't perform any locking + /// as the compiler will statically guarantee that there are no other references to `self`. + pub fn get_mut(&mut self) -> &mut T { + // Safe because the compiler can statically guarantee that there are no other references to + // `self`. This is also why we don't need to acquire the lock. 
+ unsafe { &mut *self.value.get() } + } +} + +unsafe impl Send for SpinLock {} +unsafe impl Sync for SpinLock {} + +impl Default for SpinLock { + fn default() -> Self { + Self::new(Default::default()) + } +} + +impl From for SpinLock { + fn from(source: T) -> Self { + Self::new(source) + } +} + +/// An RAII implementation of a "scoped lock" for a `SpinLock`. When this structure is dropped, the +/// lock will be released. The resource protected by the `SpinLock` can be accessed via the `Deref` +/// and `DerefMut` implementations of this structure. +pub struct SpinLockGuard<'a, T: 'a + ?Sized> { + lock: &'a SpinLock, + value: &'a mut T, +} + +impl<'a, T: ?Sized> Deref for SpinLockGuard<'a, T> { + type Target = T; + fn deref(&self) -> &T { + self.value + } +} + +impl<'a, T: ?Sized> DerefMut for SpinLockGuard<'a, T> { + fn deref_mut(&mut self) -> &mut T { + self.value + } +} + +impl<'a, T: ?Sized> Drop for SpinLockGuard<'a, T> { + fn drop(&mut self) { + self.lock.unlock(); + } +} + +#[cfg(test)] +mod test { + use super::*; + + use std::{ + mem, + sync::{ + atomic::{AtomicUsize, Ordering}, + Arc, + }, + thread, + }; + + #[derive(PartialEq, Eq, Debug)] + struct NonCopy(u32); + + #[test] + fn it_works() { + let sl = SpinLock::new(NonCopy(13)); + + assert_eq!(*sl.lock(), NonCopy(13)); + } + + #[test] + fn smoke() { + let sl = SpinLock::new(NonCopy(7)); + + mem::drop(sl.lock()); + mem::drop(sl.lock()); + } + + #[test] + fn send() { + let sl = SpinLock::new(NonCopy(19)); + + thread::spawn(move || { + let value = sl.lock(); + assert_eq!(*value, NonCopy(19)); + }) + .join() + .unwrap(); + } + + #[test] + fn high_contention() { + const THREADS: usize = 23; + const ITERATIONS: usize = 101; + + let mut threads = Vec::with_capacity(THREADS); + + let sl = Arc::new(SpinLock::new(0usize)); + for _ in 0..THREADS { + let sl2 = sl.clone(); + threads.push(thread::spawn(move || { + for _ in 0..ITERATIONS { + *sl2.lock() += 1; + } + })); + } + + for t in threads.into_iter() { + t.join().unwrap(); + } + + assert_eq!(*sl.lock(), THREADS * ITERATIONS); + } + + #[test] + fn get_mut() { + let mut sl = SpinLock::new(NonCopy(13)); + *sl.get_mut() = NonCopy(17); + + assert_eq!(sl.into_inner(), NonCopy(17)); + } + + #[test] + fn into_inner() { + let sl = SpinLock::new(NonCopy(29)); + assert_eq!(sl.into_inner(), NonCopy(29)); + } + + #[test] + fn into_inner_drop() { + struct NeedsDrop(Arc); + impl Drop for NeedsDrop { + fn drop(&mut self) { + self.0.fetch_add(1, Ordering::AcqRel); + } + } + + let value = Arc::new(AtomicUsize::new(0)); + let needs_drop = SpinLock::new(NeedsDrop(value.clone())); + assert_eq!(value.load(Ordering::Acquire), 0); + + { + let inner = needs_drop.into_inner(); + assert_eq!(inner.0.load(Ordering::Acquire), 0); + } + + assert_eq!(value.load(Ordering::Acquire), 1); + } + + #[test] + fn arc_nested() { + // Tests nested sltexes and access to underlying data. 
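// A sketch of the poisoning wrapper suggested in the `SpinLock` docs above (hypothetical
// types, not provided by this crate): the guard marks the lock as poisoned when it is dropped
// during a panic, so later callers can detect it instead of silently reading half-updated data.
use std::ops::{Deref, DerefMut};
use std::sync::atomic::{AtomicBool, Ordering};

struct PoisonLock<T> {
    inner: SpinLock<T>,
    poisoned: AtomicBool,
}

struct PoisonGuard<'a, T> {
    guard: SpinLockGuard<'a, T>,
    poisoned: &'a AtomicBool,
}

impl<T> PoisonLock<T> {
    fn new(value: T) -> Self {
        PoisonLock {
            inner: SpinLock::new(value),
            poisoned: AtomicBool::new(false),
        }
    }

    // Returns `Err` if a previous holder panicked while the lock was held.
    fn lock(&self) -> Result<PoisonGuard<'_, T>, ()> {
        let guard = self.inner.lock();
        if self.poisoned.load(Ordering::Acquire) {
            return Err(());
        }
        Ok(PoisonGuard {
            guard,
            poisoned: &self.poisoned,
        })
    }
}

impl<'a, T> Deref for PoisonGuard<'a, T> {
    type Target = T;
    fn deref(&self) -> &T {
        &*self.guard
    }
}

impl<'a, T> DerefMut for PoisonGuard<'a, T> {
    fn deref_mut(&mut self) -> &mut T {
        &mut *self.guard
    }
}

impl<'a, T> Drop for PoisonGuard<'a, T> {
    fn drop(&mut self) {
        // Runs before the inner `SpinLockGuard` field is dropped, so the flag is set while the
        // lock is still held.
        if std::thread::panicking() {
            self.poisoned.store(true, Ordering::Release);
        }
    }
}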
+ let sl = SpinLock::new(1); + let arc = Arc::new(SpinLock::new(sl)); + thread::spawn(move || { + let nested = arc.lock(); + let lock2 = nested.lock(); + assert_eq!(*lock2, 1); + }) + .join() + .unwrap(); + } + + #[test] + fn arc_access_in_unwind() { + let arc = Arc::new(SpinLock::new(1)); + let arc2 = arc.clone(); + thread::spawn(move || { + struct Unwinder { + i: Arc>, + } + impl Drop for Unwinder { + fn drop(&mut self) { + *self.i.lock() += 1; + } + } + let _u = Unwinder { i: arc2 }; + panic!(); + }) + .join() + .expect_err("thread did not panic"); + let lock = arc.lock(); + assert_eq!(*lock, 2); + } + + #[test] + fn unsized_value() { + let sltex: &SpinLock<[i32]> = &SpinLock::new([1, 2, 3]); + { + let b = &mut *sltex.lock(); + b[0] = 4; + b[2] = 5; + } + let expected: &[i32] = &[4, 2, 5]; + assert_eq!(&*sltex.lock(), expected); + } +} diff --git a/cros_async/src/sync/waiter.rs b/cros_async/src/sync/waiter.rs new file mode 100644 index 0000000000..e161c623d0 --- /dev/null +++ b/cros_async/src/sync/waiter.rs @@ -0,0 +1,288 @@ +// Copyright 2020 The Chromium OS Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +use std::{ + cell::UnsafeCell, + future::Future, + mem, + pin::Pin, + ptr::NonNull, + sync::{ + atomic::{AtomicBool, AtomicU8, Ordering}, + Arc, + }, + task::{Context, Poll, Waker}, +}; + +use intrusive_collections::{ + intrusive_adapter, + linked_list::{LinkedList, LinkedListOps}, + DefaultLinkOps, LinkOps, +}; + +use super::super::sync::SpinLock; + +// An atomic version of a LinkedListLink. See https://github.com/Amanieu/intrusive-rs/issues/47 for +// more details. +#[repr(align(128))] +pub struct AtomicLink { + prev: UnsafeCell>>, + next: UnsafeCell>>, + linked: AtomicBool, +} + +impl AtomicLink { + fn new() -> AtomicLink { + AtomicLink { + linked: AtomicBool::new(false), + prev: UnsafeCell::new(None), + next: UnsafeCell::new(None), + } + } + + fn is_linked(&self) -> bool { + self.linked.load(Ordering::Relaxed) + } +} + +impl DefaultLinkOps for AtomicLink { + type Ops = AtomicLinkOps; + + const NEW: Self::Ops = AtomicLinkOps; +} + +// Safe because the only way to mutate `AtomicLink` is via the `LinkedListOps` trait whose methods +// are all unsafe and require that the caller has first called `acquire_link` (and had it return +// true) to use them safely. +unsafe impl Send for AtomicLink {} +unsafe impl Sync for AtomicLink {} + +#[derive(Copy, Clone, Default)] +pub struct AtomicLinkOps; + +unsafe impl LinkOps for AtomicLinkOps { + type LinkPtr = NonNull; + + unsafe fn acquire_link(&mut self, ptr: Self::LinkPtr) -> bool { + !ptr.as_ref().linked.swap(true, Ordering::Acquire) + } + + unsafe fn release_link(&mut self, ptr: Self::LinkPtr) { + ptr.as_ref().linked.store(false, Ordering::Release) + } +} + +unsafe impl LinkedListOps for AtomicLinkOps { + unsafe fn next(&self, ptr: Self::LinkPtr) -> Option { + *ptr.as_ref().next.get() + } + + unsafe fn prev(&self, ptr: Self::LinkPtr) -> Option { + *ptr.as_ref().prev.get() + } + + unsafe fn set_next(&mut self, ptr: Self::LinkPtr, next: Option) { + *ptr.as_ref().next.get() = next; + } + + unsafe fn set_prev(&mut self, ptr: Self::LinkPtr, prev: Option) { + *ptr.as_ref().prev.get() = prev; + } +} + +#[derive(Clone, Copy)] +pub enum Kind { + Shared, + Exclusive, +} + +enum State { + Init, + Waiting(Waker), + Woken, + Finished, + Processing, +} + +// Indicates the queue to which the waiter belongs. 
It is the responsibility of the Mutex and +// Condvar implementations to update this value when adding/removing a Waiter from their respective +// waiter lists. +#[repr(u8)] +#[derive(Debug, Eq, PartialEq)] +pub enum WaitingFor { + // The waiter is either not linked into a waiter list or it is linked into a temporary list. + None = 0, + // The waiter is linked into the Mutex's waiter list. + Mutex = 1, + // The waiter is linked into the Condvar's waiter list. + Condvar = 2, +} + +// Represents a thread currently blocked on a Condvar or on acquiring a Mutex. +pub struct Waiter { + link: AtomicLink, + state: SpinLock, + cancel: fn(usize, &Waiter, bool), + cancel_data: usize, + kind: Kind, + waiting_for: AtomicU8, +} + +impl Waiter { + // Create a new, initialized Waiter. + // + // `kind` should indicate whether this waiter represent a thread that is waiting for a shared + // lock or an exclusive lock. + // + // `cancel` is the function that is called when a `WaitFuture` (returned by the `wait()` + // function) is dropped before it can complete. `cancel_data` is used as the first parameter of + // the `cancel` function. The second parameter is the `Waiter` that was canceled and the third + // parameter indicates whether the `WaitFuture` was dropped after it was woken (but before it + // was polled to completion). A value of `false` for the third parameter may already be stale + // by the time the cancel function runs and so does not guarantee that the waiter was not woken. + // In this case, implementations should still check if the Waiter was woken. However, a value of + // `true` guarantees that the waiter was already woken up so no additional checks are necessary. + // In this case, the cancel implementation should wake up the next waiter in its wait list, if + // any. + // + // `waiting_for` indicates the waiter list to which this `Waiter` will be added. See the + // documentation of the `WaitingFor` enum for the meaning of the different values. + pub fn new( + kind: Kind, + cancel: fn(usize, &Waiter, bool), + cancel_data: usize, + waiting_for: WaitingFor, + ) -> Waiter { + Waiter { + link: AtomicLink::new(), + state: SpinLock::new(State::Init), + cancel, + cancel_data, + kind, + waiting_for: AtomicU8::new(waiting_for as u8), + } + } + + // The kind of lock that this `Waiter` is waiting to acquire. + pub fn kind(&self) -> Kind { + self.kind + } + + // Returns true if this `Waiter` is currently linked into a waiter list. + pub fn is_linked(&self) -> bool { + self.link.is_linked() + } + + // Indicates the waiter list to which this `Waiter` belongs. + pub fn is_waiting_for(&self) -> WaitingFor { + match self.waiting_for.load(Ordering::Acquire) { + 0 => WaitingFor::None, + 1 => WaitingFor::Mutex, + 2 => WaitingFor::Condvar, + v => panic!("Unknown value for `WaitingFor`: {}", v), + } + } + + // Change the waiter list to which this `Waiter` belongs. This will panic if called when the + // `Waiter` is still linked into a waiter list. + pub fn set_waiting_for(&self, waiting_for: WaitingFor) { + self.waiting_for.store(waiting_for as u8, Ordering::Release); + } + + // Reset the Waiter back to its initial state. Panics if this `Waiter` is still linked into a + // waiter list. 
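// A sketch of the cancel-callback contract described above (hypothetical names): a lock
// implementation typically passes a pointer-sized handle to itself as `cancel_data` so the
// callback can find the owning lock, unlink `waiter`, and pass the wakeup on when needed.
fn cancel_waiter_sketch(cancel_data: usize, waiter: &Waiter, wake_next: bool) {
    // A real implementation would reinterpret `cancel_data` as a pointer to the owning lock,
    // remove `waiter` from that lock's waiter list, and, when `wake_next` is true (the waiter
    // had already been woken), wake the next waiter in the list.
    let _ = (cancel_data, waiter, wake_next);
}

fn new_exclusive_waiter_sketch(lock_handle: usize) -> Waiter {
    Waiter::new(
        Kind::Exclusive,
        cancel_waiter_sketch,
        lock_handle,
        WaitingFor::Mutex,
    )
}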
+ pub fn reset(&self, waiting_for: WaitingFor) { + debug_assert!(!self.is_linked(), "Cannot reset `Waiter` while linked"); + self.set_waiting_for(waiting_for); + + let mut state = self.state.lock(); + if let State::Waiting(waker) = mem::replace(&mut *state, State::Init) { + mem::drop(state); + mem::drop(waker); + } + } + + // Wait until woken up by another thread. + pub fn wait(&self) -> WaitFuture<'_> { + WaitFuture { waiter: self } + } + + // Wake up the thread associated with this `Waiter`. Panics if `waiting_for()` does not return + // `WaitingFor::None` or if `is_linked()` returns true. + pub fn wake(&self) { + debug_assert!(!self.is_linked(), "Cannot wake `Waiter` while linked"); + debug_assert_eq!(self.is_waiting_for(), WaitingFor::None); + + let mut state = self.state.lock(); + + if let State::Waiting(waker) = mem::replace(&mut *state, State::Woken) { + mem::drop(state); + waker.wake(); + } + } +} + +pub struct WaitFuture<'w> { + waiter: &'w Waiter, +} + +impl<'w> Future for WaitFuture<'w> { + type Output = (); + + fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll { + let mut state = self.waiter.state.lock(); + + match mem::replace(&mut *state, State::Processing) { + State::Init => { + *state = State::Waiting(cx.waker().clone()); + + Poll::Pending + } + State::Waiting(old_waker) => { + *state = State::Waiting(cx.waker().clone()); + mem::drop(state); + mem::drop(old_waker); + + Poll::Pending + } + State::Woken => { + *state = State::Finished; + Poll::Ready(()) + } + State::Finished => { + panic!("Future polled after returning Poll::Ready"); + } + State::Processing => { + panic!("Unexpected waker state"); + } + } + } +} + +impl<'w> Drop for WaitFuture<'w> { + fn drop(&mut self) { + let state = self.waiter.state.lock(); + + match *state { + State::Finished => {} + State::Processing => panic!("Unexpected waker state"), + State::Woken => { + mem::drop(state); + + // We were woken but not polled. Wake up the next waiter. + (self.waiter.cancel)(self.waiter.cancel_data, self.waiter, true); + } + _ => { + mem::drop(state); + + // Not woken. No need to wake up any waiters. + (self.waiter.cancel)(self.waiter.cancel_data, self.waiter, false); + } + } + } +} + +intrusive_adapter!(pub WaiterAdapter = Arc: Waiter { link: AtomicLink }); + +pub type WaiterList = LinkedList; diff --git a/cros_async/src/timer.rs b/cros_async/src/timer.rs new file mode 100644 index 0000000000..f9624e0e9c --- /dev/null +++ b/cros_async/src/timer.rs @@ -0,0 +1,126 @@ +// Copyright 2020 The Chromium OS Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +use std::time::Duration; + +use sys_util::{Result as SysResult, TimerFd}; + +use super::{AsyncResult, Error, Executor, IntoAsync, IoSourceExt}; + +#[cfg(test)] +use super::{FdExecutor, URingExecutor}; + +/// An async version of sys_util::TimerFd. +pub struct TimerAsync { + io_source: Box>, +} + +impl TimerAsync { + pub fn new(timer: TimerFd, ex: &Executor) -> AsyncResult { + ex.async_from(timer) + .map(|io_source| TimerAsync { io_source }) + } + + #[cfg(test)] + pub(crate) fn new_poll(timer: TimerFd, ex: &FdExecutor) -> AsyncResult { + super::executor::async_poll_from(timer, ex).map(|io_source| TimerAsync { io_source }) + } + + #[cfg(test)] + pub(crate) fn new_uring(timer: TimerFd, ex: &URingExecutor) -> AsyncResult { + super::executor::async_uring_from(timer, ex).map(|io_source| TimerAsync { io_source }) + } + + /// Gets the next value from the timer. 
+ pub async fn next_val(&self) -> AsyncResult { + self.io_source.read_u64().await + } + + /// Async sleep for the given duration + pub async fn sleep(ex: &Executor, dur: Duration) -> std::result::Result<(), Error> { + let tfd = TimerFd::new().map_err(Error::TimerFd)?; + tfd.reset(dur, None).map_err(Error::TimerFd)?; + let t = TimerAsync::new(tfd, ex).map_err(Error::TimerAsync)?; + t.next_val().await.map_err(Error::TimerAsync)?; + Ok(()) + } + + /// Sets the timer to expire after `dur`. If `interval` is not `None` it represents + /// the period for repeated expirations after the initial expiration. Otherwise + /// the timer will expire just once. Cancels any existing duration and repeating interval. + pub fn reset(&mut self, dur: Duration, interval: Option) -> SysResult<()> { + self.io_source.as_source_mut().reset(dur, interval) + } +} + +impl IntoAsync for TimerFd {} + +#[cfg(test)] +mod tests { + use super::{super::uring_executor::use_uring, *}; + use std::time::{Duration, Instant}; + + #[test] + fn one_shot() { + if !use_uring() { + return; + } + + async fn this_test(ex: &URingExecutor) { + let tfd = TimerFd::new().expect("failed to create timerfd"); + assert_eq!(tfd.is_armed().unwrap(), false); + + let dur = Duration::from_millis(200); + let now = Instant::now(); + tfd.reset(dur, None).expect("failed to arm timer"); + + assert_eq!(tfd.is_armed().unwrap(), true); + + let t = TimerAsync::new_uring(tfd, ex).unwrap(); + let count = t.next_val().await.expect("unable to wait for timer"); + + assert_eq!(count, 1); + assert!(now.elapsed() >= dur); + } + + let ex = URingExecutor::new().unwrap(); + ex.run_until(this_test(&ex)).unwrap(); + } + + #[test] + fn one_shot_fd() { + async fn this_test(ex: &FdExecutor) { + let tfd = TimerFd::new().expect("failed to create timerfd"); + assert_eq!(tfd.is_armed().unwrap(), false); + + let dur = Duration::from_millis(200); + let now = Instant::now(); + tfd.reset(dur, None).expect("failed to arm timer"); + + assert_eq!(tfd.is_armed().unwrap(), true); + + let t = TimerAsync::new_poll(tfd, ex).unwrap(); + let count = t.next_val().await.expect("unable to wait for timer"); + + assert_eq!(count, 1); + assert!(now.elapsed() >= dur); + } + + let ex = FdExecutor::new().unwrap(); + ex.run_until(this_test(&ex)).unwrap(); + } + + #[test] + fn timer() { + async fn this_test(ex: &Executor) { + let dur = Duration::from_millis(200); + let now = Instant::now(); + TimerAsync::sleep(ex, dur).await.expect("unable to sleep"); + assert!(now.elapsed() >= dur); + } + + let ex = Executor::new().expect("creating an executor failed"); + ex.run_until(this_test(&ex)).unwrap(); + } +} diff --git a/cros_async/src/uring_executor.rs b/cros_async/src/uring_executor.rs new file mode 100644 index 0000000000..63c0bb7453 --- /dev/null +++ b/cros_async/src/uring_executor.rs @@ -0,0 +1,1151 @@ +// Copyright 2020 The Chromium OS Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +//! `URingExecutor` +//! +//! The executor runs all given futures to completion. Futures register wakers associated with +//! io_uring operations. A waker is called when the set of uring ops the waker is waiting on +//! completes. +//! +//! `URingExecutor` is meant to be used with the `futures-rs` crate that provides combinators and +//! utility functions to combine futures. In general, the interface provided by `URingExecutor` +//! shouldn't be used directly. Instead, use them by interacting with implementors of `IoSource`, +//! 
and the high-level future functions. +//! +//! +//! ## Read/Write buffer management. +//! +//! There are two key issues managing asynchronous IO buffers in rust. +//! 1) The kernel has a mutable reference to the memory until the completion is returned. Rust must +//! not have any references to it during that time. +//! 2) The memory must remain valid as long as the kernel has a reference to it. +//! +//! ### The kernel's mutable borrow of the buffer +//! +//! Because the buffers used for read and write must be passed to the kernel for an unknown +//! duration, the functions must maintain ownership of the memory. The core of this problem is that +//! the lifetime of the future isn't tied to the scope in which the kernel can modify the buffer the +//! future has a reference to. The buffer can be modified at any point from submission until the +//! operation completes. The operation can't be synchronously canceled when the future is dropped, +//! and Drop can't be used for safety guarantees. To ensure this never happens, only memory that +//! implements `BackingMemory` is accepted. For implementors of `BackingMemory` the mut borrow +//! isn't an issue because those are already Ok with external modifications to the memory (Like a +//! `VolatileSlice`). +//! +//! ### Buffer lifetime +//! +//! What if the kernel's reference to the buffer outlives the buffer itself? This could happen if a +//! read operation was submitted, then the memory is dropped. To solve this, the executor takes an +//! Arc to the backing memory. Vecs being read to are also wrapped in an Arc before being passed to +//! the executor. The executor holds the Arc and ensures all operations are complete before dropping +//! it, that guarantees the memory is valid for the duration. +//! +//! The buffers _have_ to be on the heap. Because we don't have a way to cancel a future if it is +//! dropped(can't rely on drop running), there is no way to ensure the kernel's buffer remains valid +//! until the operation completes unless the executor holds an Arc to the memory on the heap. +//! +//! ## Using `Vec` for reads/writes. +//! +//! There is a convenience wrapper `VecIoWrapper` provided for fully owned vectors. This type +//! ensures that only the kernel is allowed to access the `Vec` and wraps the the `Vec` in an Arc to +//! ensure it lives long enough. + +use std::{ + convert::TryInto, + ffi::CStr, + fs::File, + future::Future, + io, + mem::{ + MaybeUninit, {self}, + }, + os::unix::io::{AsRawFd, FromRawFd, RawFd}, + pin::Pin, + sync::{ + atomic::{AtomicI32, Ordering}, + Arc, Weak, + }, + task::{Context, Poll, Waker}, + thread::{ + ThreadId, {self}, + }, +}; + +use async_task::Task; +use futures::task::noop_waker; +use io_uring::URingContext; +use once_cell::sync::Lazy; +use pin_utils::pin_mut; +use remain::sorted; +use slab::Slab; +use sync::Mutex; +use sys_util::{warn, WatchingEvents}; +use thiserror::Error as ThisError; + +use super::{ + mem::{BackingMemory, MemRegion}, + queue::RunnableQueue, + waker::{new_waker, WakerToken, WeakWake}, + BlockingPool, +}; + +#[sorted] +#[derive(Debug, ThisError)] +pub enum Error { + /// Creating a context to wait on FDs failed. + #[error("Error creating the fd waiting context: {0}")] + CreatingContext(io_uring::Error), + /// Failed to copy the FD for the polling context. + #[error("Failed to copy the FD for the polling context: {0}")] + DuplicatingFd(sys_util::Error), + /// The Executor is gone. 
+ #[error("The URingExecutor is gone")] + ExecutorGone, + /// Invalid offset or length given for an iovec in backing memory. + #[error("Invalid offset/len for getting an iovec")] + InvalidOffset, + /// Invalid FD source specified. + #[error("Invalid source, FD not registered for use")] + InvalidSource, + /// Error doing the IO. + #[error("Error during IO: {0}")] + Io(io::Error), + /// Failed to remove the waker remove the polling context. + #[error("Error removing from the URing context: {0}")] + RemovingWaker(io_uring::Error), + /// Failed to submit the operation to the polling context. + #[error("Error adding to the URing context: {0}")] + SubmittingOp(io_uring::Error), + /// URingContext failure. + #[error("URingContext failure: {0}")] + URingContextError(io_uring::Error), + /// Failed to submit or wait for io_uring events. + #[error("URing::enter: {0}")] + URingEnter(io_uring::Error), +} +pub type Result = std::result::Result; + +impl From for io::Error { + fn from(e: Error) -> Self { + use Error::*; + match e { + DuplicatingFd(e) => e.into(), + ExecutorGone => io::Error::new(io::ErrorKind::Other, ExecutorGone), + InvalidOffset => io::Error::new(io::ErrorKind::InvalidInput, InvalidOffset), + InvalidSource => io::Error::new(io::ErrorKind::InvalidData, InvalidSource), + Io(e) => e, + CreatingContext(e) => e.into(), + RemovingWaker(e) => e.into(), + SubmittingOp(e) => e.into(), + URingContextError(e) => e.into(), + URingEnter(e) => e.into(), + } + } +} + +static USE_URING: Lazy = Lazy::new(|| { + let mut utsname = MaybeUninit::zeroed(); + + // Safe because this will only modify `utsname` and we check the return value. + let res = unsafe { libc::uname(utsname.as_mut_ptr()) }; + if res < 0 { + return false; + } + + // Safe because the kernel has initialized `utsname`. + let utsname = unsafe { utsname.assume_init() }; + + // Safe because the pointer is valid and the kernel guarantees that this is a valid C string. + let release = unsafe { CStr::from_ptr(utsname.release.as_ptr()) }; + + let mut components = match release.to_str().map(|r| r.split('.').map(str::parse)) { + Ok(c) => c, + Err(_) => return false, + }; + + // Kernels older than 5.10 either didn't support io_uring or had bugs in the implementation. + match (components.next(), components.next()) { + (Some(Ok(major)), Some(Ok(minor))) if (major, minor) >= (5, 10) => { + // The kernel version is new enough so check if we can actually make a uring context. + URingContext::new(8).is_ok() + } + _ => false, + } +}); + +// Checks if the uring executor is available. +// Caches the result so that the check is only run once. +// Useful for falling back to the FD executor on pre-uring kernels. 
+pub(crate) fn use_uring() -> bool { + *USE_URING +} + +pub struct RegisteredSource { + tag: usize, + ex: Weak, +} + +impl RegisteredSource { + pub fn start_read_to_mem( + &self, + file_offset: u64, + mem: Arc, + addrs: &[MemRegion], + ) -> Result { + let ex = self.ex.upgrade().ok_or(Error::ExecutorGone)?; + let token = ex.submit_read_to_vectored(self, mem, file_offset, addrs)?; + + Ok(PendingOperation { + waker_token: Some(token), + ex: self.ex.clone(), + submitted: false, + }) + } + + pub fn start_write_from_mem( + &self, + file_offset: u64, + mem: Arc, + addrs: &[MemRegion], + ) -> Result { + let ex = self.ex.upgrade().ok_or(Error::ExecutorGone)?; + let token = ex.submit_write_from_vectored(self, mem, file_offset, addrs)?; + + Ok(PendingOperation { + waker_token: Some(token), + ex: self.ex.clone(), + submitted: false, + }) + } + + pub fn start_fallocate(&self, offset: u64, len: u64, mode: u32) -> Result { + let ex = self.ex.upgrade().ok_or(Error::ExecutorGone)?; + let token = ex.submit_fallocate(self, offset, len, mode)?; + + Ok(PendingOperation { + waker_token: Some(token), + ex: self.ex.clone(), + submitted: false, + }) + } + + pub fn start_fsync(&self) -> Result { + let ex = self.ex.upgrade().ok_or(Error::ExecutorGone)?; + let token = ex.submit_fsync(self)?; + + Ok(PendingOperation { + waker_token: Some(token), + ex: self.ex.clone(), + submitted: false, + }) + } + + pub fn poll_fd_readable(&self) -> Result { + let events = WatchingEvents::empty().set_read(); + + let ex = self.ex.upgrade().ok_or(Error::ExecutorGone)?; + let token = ex.submit_poll(self, &events)?; + + Ok(PendingOperation { + waker_token: Some(token), + ex: self.ex.clone(), + submitted: false, + }) + } +} + +impl Drop for RegisteredSource { + fn drop(&mut self) { + if let Some(ex) = self.ex.upgrade() { + let _ = ex.deregister_source(self); + } + } +} + +// Indicates that the executor is either within or about to make an io_uring_enter syscall. When a +// waker sees this value, it will add and submit a NOP to the uring, which will wake up the thread +// blocked on the io_uring_enter syscall. +const WAITING: i32 = 0xb80d_21b5u32 as i32; + +// Indicates that the executor is processing any futures that are ready to run. +const PROCESSING: i32 = 0xdb31_83a3u32 as i32; + +// Indicates that one or more futures may be ready to make progress. +const WOKEN: i32 = 0x0fc7_8f7eu32 as i32; + +// Number of entries in the ring. +const NUM_ENTRIES: usize = 256; + +// An operation that has been submitted to the uring and is potentially being waited on. +struct OpData { + _file: Arc, + _mem: Option>, + waker: Option, + canceled: bool, +} + +// The current status of an operation that's been submitted to the uring. +enum OpStatus { + Nop, + Pending(OpData), + Completed(Option<::std::io::Result>), +} + +struct Ring { + ops: Slab, + registered_sources: Slab>, +} + +struct RawExecutor { + // The URingContext needs to be first so that it is dropped first, closing the uring fd, and + // releasing the resources borrowed by the kernel before we free them. 
+ ctx: URingContext, + queue: RunnableQueue, + ring: Mutex, + blocking_pool: BlockingPool, + thread_id: Mutex>, + state: AtomicI32, +} + +impl RawExecutor { + fn new() -> Result { + Ok(RawExecutor { + ctx: URingContext::new(NUM_ENTRIES).map_err(Error::CreatingContext)?, + queue: RunnableQueue::new(), + ring: Mutex::new(Ring { + ops: Slab::with_capacity(NUM_ENTRIES), + registered_sources: Slab::with_capacity(NUM_ENTRIES), + }), + blocking_pool: Default::default(), + thread_id: Mutex::new(None), + state: AtomicI32::new(PROCESSING), + }) + } + + fn wake(&self) { + let oldstate = self.state.swap(WOKEN, Ordering::Release); + if oldstate == WAITING { + let mut ring = self.ring.lock(); + let entry = ring.ops.vacant_entry(); + let next_op_token = entry.key(); + if let Err(e) = self.ctx.add_nop(usize_to_u64(next_op_token)) { + warn!("Failed to add NOP for waking up executor: {}", e); + } + entry.insert(OpStatus::Nop); + mem::drop(ring); + + match self.ctx.submit() { + Ok(()) => {} + // If the kernel's submit ring is full then we know we won't block when calling + // io_uring_enter, which is all we really care about. + Err(io_uring::Error::RingEnter(libc::EBUSY)) => {} + Err(e) => warn!("Failed to submit NOP for waking up executor: {}", e), + } + } + } + + fn spawn(self: &Arc, f: F) -> Task + where + F: Future + Send + 'static, + F::Output: Send + 'static, + { + let raw = Arc::downgrade(self); + let schedule = move |runnable| { + if let Some(r) = raw.upgrade() { + r.queue.push_back(runnable); + r.wake(); + } + }; + let (runnable, task) = async_task::spawn(f, schedule); + runnable.schedule(); + task + } + + fn spawn_local(self: &Arc, f: F) -> Task + where + F: Future + 'static, + F::Output: 'static, + { + let raw = Arc::downgrade(self); + let schedule = move |runnable| { + if let Some(r) = raw.upgrade() { + r.queue.push_back(runnable); + r.wake(); + } + }; + let (runnable, task) = async_task::spawn_local(f, schedule); + runnable.schedule(); + task + } + + fn spawn_blocking(self: &Arc, f: F) -> Task + where + F: FnOnce() -> R + Send + 'static, + R: Send + 'static, + { + self.blocking_pool.spawn(f) + } + + fn runs_tasks_on_current_thread(&self) -> bool { + let executor_thread = self.thread_id.lock(); + executor_thread + .map(|id| id == thread::current().id()) + .unwrap_or(false) + } + + fn run(&self, cx: &mut Context, done: F) -> Result { + let current_thread = thread::current().id(); + let mut thread_id = self.thread_id.lock(); + assert_eq!( + *thread_id.get_or_insert(current_thread), + current_thread, + "`URingExecutor::run` cannot be called from more than one thread" + ); + mem::drop(thread_id); + + pin_mut!(done); + loop { + self.state.store(PROCESSING, Ordering::Release); + for runnable in self.queue.iter() { + runnable.run(); + } + + if let Poll::Ready(val) = done.as_mut().poll(cx) { + return Ok(val); + } + + let oldstate = self.state.compare_exchange( + PROCESSING, + WAITING, + Ordering::Acquire, + Ordering::Acquire, + ); + if let Err(oldstate) = oldstate { + debug_assert_eq!(oldstate, WOKEN); + // One or more futures have become runnable. + continue; + } + + let events = self.ctx.wait().map_err(Error::URingEnter)?; + + // Set the state back to PROCESSING to prevent any tasks woken up by the loop below from + // writing to the eventfd. 
+ self.state.store(PROCESSING, Ordering::Release); + + let mut ring = self.ring.lock(); + for (raw_token, result) in events { + // While the `expect()` might fail on arbitrary `u64`s, the `raw_token` was + // something that we originally gave to the kernel and that was created from a + // `usize` so we should always be able to convert it back into a `usize`. + let token = raw_token + .try_into() + .expect("`u64` doesn't fit inside a `usize`"); + + let op = ring + .ops + .get_mut(token) + .expect("Received completion token for unexpected operation"); + match mem::replace(op, OpStatus::Completed(Some(result))) { + // No one is waiting on a Nop. + OpStatus::Nop => mem::drop(ring.ops.remove(token)), + OpStatus::Pending(data) => { + if data.canceled { + // No one is waiting for this operation and the uring is done with + // it so it's safe to remove. + ring.ops.remove(token); + } + if let Some(waker) = data.waker { + waker.wake(); + } + } + OpStatus::Completed(_) => panic!("uring operation completed more than once"), + } + } + } + } + + fn get_result(&self, token: &WakerToken, cx: &mut Context) -> Option> { + let mut ring = self.ring.lock(); + + let op = ring + .ops + .get_mut(token.0) + .expect("`get_result` called on unknown operation"); + match op { + OpStatus::Nop => panic!("`get_result` called on nop"), + OpStatus::Pending(data) => { + if data.canceled { + panic!("`get_result` called on canceled operation"); + } + data.waker = Some(cx.waker().clone()); + None + } + OpStatus::Completed(res) => { + let out = res.take(); + ring.ops.remove(token.0); + Some(out.expect("Missing result in completed operation")) + } + } + } + + // Remove the waker for the given token if it hasn't fired yet. + fn cancel_operation(&self, token: WakerToken) { + let mut ring = self.ring.lock(); + if let Some(op) = ring.ops.get_mut(token.0) { + match op { + OpStatus::Nop => panic!("`cancel_operation` called on nop"), + OpStatus::Pending(data) => { + if data.canceled { + panic!("uring operation canceled more than once"); + } + + // Clear the waker as it is no longer needed. + data.waker = None; + data.canceled = true; + + // Keep the rest of the op data as the uring might still be accessing either + // the source of the backing memory so it needs to live until the kernel + // completes the operation. TODO: cancel the operation in the uring. + } + OpStatus::Completed(_) => { + ring.ops.remove(token.0); + } + } + } + } + + fn register_source(&self, f: Arc) -> usize { + self.ring.lock().registered_sources.insert(f) + } + + fn deregister_source(&self, source: &RegisteredSource) { + // There isn't any need to pull pending ops out, the all have Arc's to the file and mem they + // need.let them complete. deregister with pending ops is not a common path no need to + // optimize that case yet. 
+ self.ring.lock().registered_sources.remove(source.tag); + } + + fn submit_poll( + &self, + source: &RegisteredSource, + events: &sys_util::WatchingEvents, + ) -> Result { + let mut ring = self.ring.lock(); + let src = ring + .registered_sources + .get(source.tag) + .map(Arc::clone) + .ok_or(Error::InvalidSource)?; + let entry = ring.ops.vacant_entry(); + let next_op_token = entry.key(); + self.ctx + .add_poll_fd(src.as_raw_fd(), events, usize_to_u64(next_op_token)) + .map_err(Error::SubmittingOp)?; + + entry.insert(OpStatus::Pending(OpData { + _file: src, + _mem: None, + waker: None, + canceled: false, + })); + + Ok(WakerToken(next_op_token)) + } + + fn submit_fallocate( + &self, + source: &RegisteredSource, + offset: u64, + len: u64, + mode: u32, + ) -> Result { + let mut ring = self.ring.lock(); + let src = ring + .registered_sources + .get(source.tag) + .map(Arc::clone) + .ok_or(Error::InvalidSource)?; + let entry = ring.ops.vacant_entry(); + let next_op_token = entry.key(); + self.ctx + .add_fallocate( + src.as_raw_fd(), + offset, + len, + mode, + usize_to_u64(next_op_token), + ) + .map_err(Error::SubmittingOp)?; + + entry.insert(OpStatus::Pending(OpData { + _file: src, + _mem: None, + waker: None, + canceled: false, + })); + + Ok(WakerToken(next_op_token)) + } + + fn submit_fsync(&self, source: &RegisteredSource) -> Result { + let mut ring = self.ring.lock(); + let src = ring + .registered_sources + .get(source.tag) + .map(Arc::clone) + .ok_or(Error::InvalidSource)?; + let entry = ring.ops.vacant_entry(); + let next_op_token = entry.key(); + self.ctx + .add_fsync(src.as_raw_fd(), usize_to_u64(next_op_token)) + .map_err(Error::SubmittingOp)?; + + entry.insert(OpStatus::Pending(OpData { + _file: src, + _mem: None, + waker: None, + canceled: false, + })); + + Ok(WakerToken(next_op_token)) + } + + fn submit_read_to_vectored( + &self, + source: &RegisteredSource, + mem: Arc, + offset: u64, + addrs: &[MemRegion], + ) -> Result { + if addrs + .iter() + .any(|&mem_range| mem.get_volatile_slice(mem_range).is_err()) + { + return Err(Error::InvalidOffset); + } + + let mut ring = self.ring.lock(); + let src = ring + .registered_sources + .get(source.tag) + .map(Arc::clone) + .ok_or(Error::InvalidSource)?; + + // We can't insert the OpData into the slab yet because `iovecs` borrows `mem` below. + let entry = ring.ops.vacant_entry(); + let next_op_token = entry.key(); + + // The addresses have already been validated, so unwrapping them will succeed. + // validate their addresses before submitting. + let iovecs = addrs.iter().map(|&mem_range| { + *mem.get_volatile_slice(mem_range) + .unwrap() + .as_iobuf() + .as_ref() + }); + + unsafe { + // Safe because all the addresses are within the Memory that an Arc is kept for the + // duration to ensure the memory is valid while the kernel accesses it. + // Tested by `dont_drop_backing_mem_read` unit test. 
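// The Arc to `mem` is stashed in the OpStatus::Pending entry inserted below and is only
// released once the completion for `next_op_token` is reaped, which is what keeps these
// addresses valid while the kernel uses them.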
+ self.ctx + .add_readv_iter(iovecs, src.as_raw_fd(), offset, usize_to_u64(next_op_token)) + .map_err(Error::SubmittingOp)?; + } + + entry.insert(OpStatus::Pending(OpData { + _file: src, + _mem: Some(mem), + waker: None, + canceled: false, + })); + + Ok(WakerToken(next_op_token)) + } + + fn submit_write_from_vectored( + &self, + source: &RegisteredSource, + mem: Arc, + offset: u64, + addrs: &[MemRegion], + ) -> Result { + if addrs + .iter() + .any(|&mem_range| mem.get_volatile_slice(mem_range).is_err()) + { + return Err(Error::InvalidOffset); + } + + let mut ring = self.ring.lock(); + let src = ring + .registered_sources + .get(source.tag) + .map(Arc::clone) + .ok_or(Error::InvalidSource)?; + + // We can't insert the OpData into the slab yet because `iovecs` borrows `mem` below. + let entry = ring.ops.vacant_entry(); + let next_op_token = entry.key(); + + // The addresses have already been validated, so unwrapping them will succeed. + // validate their addresses before submitting. + let iovecs = addrs.iter().map(|&mem_range| { + *mem.get_volatile_slice(mem_range) + .unwrap() + .as_iobuf() + .as_ref() + }); + + unsafe { + // Safe because all the addresses are within the Memory that an Arc is kept for the + // duration to ensure the memory is valid while the kernel accesses it. + // Tested by `dont_drop_backing_mem_write` unit test. + self.ctx + .add_writev_iter(iovecs, src.as_raw_fd(), offset, usize_to_u64(next_op_token)) + .map_err(Error::SubmittingOp)?; + } + + entry.insert(OpStatus::Pending(OpData { + _file: src, + _mem: Some(mem), + waker: None, + canceled: false, + })); + + Ok(WakerToken(next_op_token)) + } +} + +impl WeakWake for RawExecutor { + fn wake_by_ref(weak_self: &Weak) { + if let Some(arc_self) = weak_self.upgrade() { + RawExecutor::wake(&arc_self); + } + } +} + +impl Drop for RawExecutor { + fn drop(&mut self) { + // Wake up any futures still waiting on uring operations. + let ring = self.ring.get_mut(); + for (_, op) in ring.ops.iter_mut() { + match op { + OpStatus::Nop => {} + OpStatus::Pending(data) => { + // If the operation wasn't already canceled then wake up the future waiting on + // it. When polled that future will get an ExecutorGone error anyway so there's + // no point in waiting until the operation completes to wake it up. + if !data.canceled { + if let Some(waker) = data.waker.take() { + waker.wake(); + } + } + + data.canceled = true; + } + OpStatus::Completed(_) => {} + } + } + + // Since the RawExecutor is wrapped in an Arc it may end up being dropped from a different + // thread than the one that called `run` or `run_until`. Since we know there are no other + // references, just clear the thread id so that we don't panic. + *self.thread_id.lock() = None; + + // Now run the executor loop once more to poll any futures we just woke up. + let waker = noop_waker(); + let mut cx = Context::from_waker(&waker); + let res = self.run(&mut cx, async {}); + + if let Err(e) = res { + warn!("Failed to drive uring to completion: {}", e); + } + } +} + +/// An executor that uses io_uring for its asynchronous I/O operations. See the documentation of +/// `Executor` for more details about the methods. 
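///
/// A minimal usage sketch, assuming the host kernel supports io_uring:
///
/// ```ignore
/// let ex = URingExecutor::new()?;
/// let task = ex.spawn(async { 1 + 1 });
/// assert_eq!(ex.run_until(task)?, 2);
/// ```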
+#[derive(Clone)] +pub struct URingExecutor { + raw: Arc, +} + +impl URingExecutor { + pub fn new() -> Result { + let raw = RawExecutor::new().map(Arc::new)?; + + Ok(URingExecutor { raw }) + } + + pub fn spawn(&self, f: F) -> Task + where + F: Future + Send + 'static, + F::Output: Send + 'static, + { + self.raw.spawn(f) + } + + pub fn spawn_local(&self, f: F) -> Task + where + F: Future + 'static, + F::Output: 'static, + { + self.raw.spawn_local(f) + } + + pub fn spawn_blocking(&self, f: F) -> Task + where + F: FnOnce() -> R + Send + 'static, + R: Send + 'static, + { + self.raw.spawn_blocking(f) + } + + pub fn run(&self) -> Result<()> { + let waker = new_waker(Arc::downgrade(&self.raw)); + let mut cx = Context::from_waker(&waker); + + self.raw.run(&mut cx, super::empty::<()>()) + } + + pub fn run_until(&self, f: F) -> Result { + let waker = new_waker(Arc::downgrade(&self.raw)); + let mut ctx = Context::from_waker(&waker); + self.raw.run(&mut ctx, f) + } + + /// Register a file and memory pair for buffered asynchronous operation. + pub(crate) fn register_source(&self, fd: &F) -> Result { + let duped_fd = unsafe { + // Safe because duplicating an FD doesn't affect memory safety, and the dup'd FD + // will only be added to the poll loop. + File::from_raw_fd(dup_fd(fd.as_raw_fd())?) + }; + + Ok(RegisteredSource { + tag: self.raw.register_source(Arc::new(duped_fd)), + ex: Arc::downgrade(&self.raw), + }) + } +} + +// Used to dup the FDs passed to the executor so there is a guarantee they aren't closed while +// waiting in TLS to be added to the main polling context. +unsafe fn dup_fd(fd: RawFd) -> Result { + let ret = libc::fcntl(fd, libc::F_DUPFD_CLOEXEC, 0); + if ret < 0 { + Err(Error::DuplicatingFd(sys_util::Error::last())) + } else { + Ok(ret) + } +} + +// Converts a `usize` into a `u64` and panics if the conversion fails. +#[inline] +fn usize_to_u64(val: usize) -> u64 { + val.try_into().expect("`usize` doesn't fit inside a `u64`") +} + +pub struct PendingOperation { + waker_token: Option, + ex: Weak, + submitted: bool, +} + +impl Future for PendingOperation { + type Output = Result; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll { + let token = self + .waker_token + .as_ref() + .expect("PendingOperation polled after returning Poll::Ready"); + if let Some(ex) = self.ex.upgrade() { + if let Some(result) = ex.get_result(token, cx) { + self.waker_token = None; + Poll::Ready(result.map_err(Error::Io)) + } else { + // If we haven't submitted the operation yet, and the executor runs on a different + // thread then submit it now. Otherwise the executor will submit it automatically + // the next time it calls UringContext::wait. + if !self.submitted && !ex.runs_tasks_on_current_thread() { + match ex.ctx.submit() { + Ok(()) => self.submitted = true, + // If the kernel ring is full then wait until some ops are removed from the + // completion queue. This op should get submitted the next time the executor + // calls UringContext::wait. 
+ Err(io_uring::Error::RingEnter(libc::EBUSY)) => {} + Err(e) => return Poll::Ready(Err(Error::URingEnter(e))), + } + } + Poll::Pending + } + } else { + Poll::Ready(Err(Error::ExecutorGone)) + } + } +} + +impl Drop for PendingOperation { + fn drop(&mut self) { + if let Some(waker_token) = self.waker_token.take() { + if let Some(ex) = self.ex.upgrade() { + ex.cancel_operation(waker_token); + } + } + } +} + +#[cfg(test)] +mod tests { + use std::{ + future::Future, + io::{Read, Write}, + mem, + pin::Pin, + task::{Context, Poll}, + }; + + use futures::executor::block_on; + + use super::{ + super::mem::{BackingMemory, MemRegion, VecIoWrapper}, + *, + }; + + // A future that returns ready when the uring queue is empty. + struct UringQueueEmpty<'a> { + ex: &'a URingExecutor, + } + + impl<'a> Future for UringQueueEmpty<'a> { + type Output = (); + + fn poll(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll { + if self.ex.raw.ring.lock().ops.is_empty() { + Poll::Ready(()) + } else { + Poll::Pending + } + } + } + + #[test] + fn dont_drop_backing_mem_read() { + if !use_uring() { + return; + } + + // Create a backing memory wrapped in an Arc and check that the drop isn't called while the + // op is pending. + let bm = + Arc::new(VecIoWrapper::from(vec![0u8; 4096])) as Arc; + + // Use pipes to create a future that will block forever. + let (rx, mut tx) = sys_util::pipe(true).unwrap(); + + // Set up the TLS for the uring_executor by creating one. + let ex = URingExecutor::new().unwrap(); + + // Register the receive side of the pipe with the executor. + let registered_source = ex.register_source(&rx).expect("register source failed"); + + // Submit the op to the kernel. Next, test that the source keeps its Arc open for the duration + // of the op. + let pending_op = registered_source + .start_read_to_mem(0, Arc::clone(&bm), &[MemRegion { offset: 0, len: 8 }]) + .expect("failed to start read to mem"); + + // Here the Arc count must be two, one for `bm` and one to signify that the kernel has a + // reference while the op is active. + assert_eq!(Arc::strong_count(&bm), 2); + + // Dropping the operation shouldn't reduce the Arc count, as the kernel could still be using + // it. + drop(pending_op); + assert_eq!(Arc::strong_count(&bm), 2); + + // Finishing the operation should put the Arc count back to 1. + // write to the pipe to wake the read pipe and then wait for the uring result in the + // executor. + tx.write_all(&[0u8; 8]).expect("write failed"); + ex.run_until(UringQueueEmpty { ex: &ex }) + .expect("Failed to wait for read pipe ready"); + assert_eq!(Arc::strong_count(&bm), 1); + } + + #[test] + fn dont_drop_backing_mem_write() { + if !use_uring() { + return; + } + + // Create a backing memory wrapped in an Arc and check that the drop isn't called while the + // op is pending. + let bm = + Arc::new(VecIoWrapper::from(vec![0u8; 4096])) as Arc; + + // Use pipes to create a future that will block forever. + let (mut rx, tx) = sys_util::new_pipe_full().expect("Pipe failed"); + + // Set up the TLS for the uring_executor by creating one. + let ex = URingExecutor::new().unwrap(); + + // Register the receive side of the pipe with the executor. + let registered_source = ex.register_source(&tx).expect("register source failed"); + + // Submit the op to the kernel. Next, test that the source keeps its Arc open for the duration + // of the op. 
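// (`start_write_from_mem` hands the Arc to `submit_write_from_vectored`, which stores it
// in the executor's pending-op slab; that stored clone is the second reference counted
// below.)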
+ let pending_op = registered_source + .start_write_from_mem(0, Arc::clone(&bm), &[MemRegion { offset: 0, len: 8 }]) + .expect("failed to start write to mem"); + + // Here the Arc count must be two, one for `bm` and one to signify that the kernel has a + // reference while the op is active. + assert_eq!(Arc::strong_count(&bm), 2); + + // Dropping the operation shouldn't reduce the Arc count, as the kernel could still be using + // it. + drop(pending_op); + assert_eq!(Arc::strong_count(&bm), 2); + + // Finishing the operation should put the Arc count back to 1. + // write to the pipe to wake the read pipe and then wait for the uring result in the + // executor. + let mut buf = vec![0u8; sys_util::round_up_to_page_size(1)]; + rx.read_exact(&mut buf).expect("read to empty failed"); + ex.run_until(UringQueueEmpty { ex: &ex }) + .expect("Failed to wait for write pipe ready"); + assert_eq!(Arc::strong_count(&bm), 1); + } + + #[test] + fn canceled_before_completion() { + if !use_uring() { + return; + } + + async fn cancel_io(op: PendingOperation) { + mem::drop(op); + } + + async fn check_result(op: PendingOperation, expected: u32) { + let actual = op.await.expect("operation failed to complete"); + assert_eq!(expected, actual); + } + + let bm = + Arc::new(VecIoWrapper::from(vec![0u8; 16])) as Arc; + + let (rx, tx) = sys_util::pipe(true).expect("Pipe failed"); + + let ex = URingExecutor::new().unwrap(); + + let rx_source = ex.register_source(&rx).expect("register source failed"); + let tx_source = ex.register_source(&tx).expect("register source failed"); + + let read_task = rx_source + .start_read_to_mem(0, Arc::clone(&bm), &[MemRegion { offset: 0, len: 8 }]) + .expect("failed to start read to mem"); + + ex.spawn_local(cancel_io(read_task)).detach(); + + // Write to the pipe so that the kernel operation will complete. + let buf = + Arc::new(VecIoWrapper::from(vec![0xc2u8; 16])) as Arc; + let write_task = tx_source + .start_write_from_mem(0, Arc::clone(&buf), &[MemRegion { offset: 0, len: 8 }]) + .expect("failed to start write from mem"); + + ex.run_until(check_result(write_task, 8)) + .expect("Failed to run executor"); + } + + #[test] + fn drop_before_completion() { + if !use_uring() { + return; + } + + const VALUE: u64 = 0xef6c_a8df_b842_eb9c; + + async fn check_op(op: PendingOperation) { + let err = op.await.expect_err("Op completed successfully"); + match err { + Error::ExecutorGone => {} + e => panic!("Unexpected error from op: {}", e), + } + } + + let (mut rx, mut tx) = sys_util::pipe(true).expect("Pipe failed"); + + let ex = URingExecutor::new().unwrap(); + + let tx_source = ex.register_source(&tx).expect("Failed to register source"); + let bm = Arc::new(VecIoWrapper::from(VALUE.to_ne_bytes().to_vec())); + let op = tx_source + .start_write_from_mem( + 0, + bm, + &[MemRegion { + offset: 0, + len: mem::size_of::(), + }], + ) + .expect("Failed to start write from mem"); + + ex.spawn_local(check_op(op)).detach(); + + // Now drop the executor. It shouldn't run the write operation. + mem::drop(ex); + + // Make sure the executor did not complete the uring operation. 
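// If the write had been driven to completion, the pipe would already contain VALUE, and
// the 8 bytes read back below would be VALUE's bytes instead of `new_val`.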
+ let new_val = [0x2e; 8]; + tx.write_all(&new_val).unwrap(); + + let mut buf = 0u64.to_ne_bytes(); + rx.read_exact(&mut buf[..]) + .expect("Failed to read from pipe"); + + assert_eq!(buf, new_val); + } + + #[test] + fn drop_on_different_thread() { + if !use_uring() { + return; + } + + let ex = URingExecutor::new().unwrap(); + + let ex2 = ex.clone(); + let t = thread::spawn(move || ex2.run_until(async {})); + + t.join().unwrap().unwrap(); + + // Leave an uncompleted operation in the queue so that the drop impl will try to drive it to + // completion. + let (_rx, tx) = sys_util::pipe(true).expect("Pipe failed"); + let tx = ex.register_source(&tx).expect("Failed to register source"); + let bm = Arc::new(VecIoWrapper::from(0xf2e96u64.to_ne_bytes().to_vec())); + let op = tx + .start_write_from_mem( + 0, + bm, + &[MemRegion { + offset: 0, + len: mem::size_of::(), + }], + ) + .expect("Failed to start write from mem"); + + mem::drop(ex); + + match block_on(op).expect_err("Pending operation completed after executor was dropped") { + Error::ExecutorGone => {} + e => panic!("Unexpected error after dropping executor: {}", e), + } + } +} diff --git a/cros_async/src/uring_source.rs b/cros_async/src/uring_source.rs new file mode 100644 index 0000000000..c5d44452c8 --- /dev/null +++ b/cros_async/src/uring_source.rs @@ -0,0 +1,643 @@ +// Copyright 2020 The Chromium OS Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +use std::{ + convert::TryInto, + io, + ops::{Deref, DerefMut}, + os::unix::io::AsRawFd, + sync::Arc, +}; + +use async_trait::async_trait; + +use super::{ + mem::{BackingMemory, MemRegion, VecIoWrapper}, + uring_executor::{Error, RegisteredSource, Result, URingExecutor}, + AsyncError, AsyncResult, +}; + +/// `UringSource` wraps FD backed IO sources for use with io_uring. It is a thin wrapper around +/// registering an IO source with the uring that provides an `IoSource` implementation. +/// Most useful functions are provided by 'IoSourceExt'. +pub struct UringSource { + registered_source: RegisteredSource, + source: F, +} + +impl UringSource { + /// Creates a new `UringSource` that wraps the given `io_source` object. + pub fn new(io_source: F, ex: &URingExecutor) -> Result> { + let r = ex.register_source(&io_source)?; + Ok(UringSource { + registered_source: r, + source: io_source, + }) + } + + /// Consume `self` and return the object used to create it. + pub fn into_source(self) -> F { + self.source + } +} + +#[async_trait(?Send)] +impl super::ReadAsync for UringSource { + /// Reads from the iosource at `file_offset` and fill the given `vec`. + async fn read_to_vec<'a>( + &'a self, + file_offset: Option, + vec: Vec, + ) -> AsyncResult<(usize, Vec)> { + let buf = Arc::new(VecIoWrapper::from(vec)); + let op = self.registered_source.start_read_to_mem( + file_offset.unwrap_or(0), + buf.clone(), + &[MemRegion { + offset: 0, + len: buf.len(), + }], + )?; + let len = op.await?; + let bytes = if let Ok(v) = Arc::try_unwrap(buf) { + v.into() + } else { + panic!("too many refs on buf"); + }; + + Ok((len as usize, bytes)) + } + + /// Wait for the FD of `self` to be readable. + async fn wait_readable(&self) -> AsyncResult<()> { + let op = self.registered_source.poll_fd_readable()?; + op.await?; + Ok(()) + } + + /// Reads a single u64 (e.g. from an eventfd). + async fn read_u64(&self) -> AsyncResult { + // This doesn't just forward to read_to_vec to avoid an unnecessary extra allocation from + // async-trait. 
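// (async-trait turns every async trait method into a boxed future, so calling
// `self.read_to_vec()` here would allocate a second box on top of this one.)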
+ let buf = Arc::new(VecIoWrapper::from(0u64.to_ne_bytes().to_vec())); + let op = self.registered_source.start_read_to_mem( + 0, + buf.clone(), + &[MemRegion { + offset: 0, + len: buf.len(), + }], + )?; + let len = op.await?; + if len != buf.len() as u32 { + Err(AsyncError::Uring(Error::Io(io::Error::new( + io::ErrorKind::Other, + format!("expected to read {} bytes, but read {}", buf.len(), len), + )))) + } else { + let bytes: Vec = if let Ok(v) = Arc::try_unwrap(buf) { + v.into() + } else { + panic!("too many refs on buf"); + }; + + // Will never panic because bytes is of the appropriate size. + Ok(u64::from_ne_bytes(bytes[..].try_into().unwrap())) + } + } + + /// Reads to the given `mem` at the given offsets from the file starting at `file_offset`. + async fn read_to_mem<'a>( + &'a self, + file_offset: Option, + mem: Arc, + mem_offsets: &'a [MemRegion], + ) -> AsyncResult { + let op = + self.registered_source + .start_read_to_mem(file_offset.unwrap_or(0), mem, mem_offsets)?; + let len = op.await?; + Ok(len as usize) + } +} + +#[async_trait(?Send)] +impl super::WriteAsync for UringSource { + /// Writes from the given `vec` to the file starting at `file_offset`. + async fn write_from_vec<'a>( + &'a self, + file_offset: Option, + vec: Vec, + ) -> AsyncResult<(usize, Vec)> { + let buf = Arc::new(VecIoWrapper::from(vec)); + let op = self.registered_source.start_write_from_mem( + file_offset.unwrap_or(0), + buf.clone(), + &[MemRegion { + offset: 0, + len: buf.len(), + }], + )?; + let len = op.await?; + let bytes = if let Ok(v) = Arc::try_unwrap(buf) { + v.into() + } else { + panic!("too many refs on buf"); + }; + + Ok((len as usize, bytes)) + } + + /// Writes from the given `mem` from the given offsets to the file starting at `file_offset`. + async fn write_from_mem<'a>( + &'a self, + file_offset: Option, + mem: Arc, + mem_offsets: &'a [MemRegion], + ) -> AsyncResult { + let op = self.registered_source.start_write_from_mem( + file_offset.unwrap_or(0), + mem, + mem_offsets, + )?; + let len = op.await?; + Ok(len as usize) + } + + /// See `fallocate(2)`. Note this op is synchronous when using the Polled backend. + async fn fallocate(&self, file_offset: u64, len: u64, mode: u32) -> AsyncResult<()> { + let op = self + .registered_source + .start_fallocate(file_offset, len, mode)?; + let _ = op.await?; + Ok(()) + } + + /// Sync all completed write operations to the backing storage. + async fn fsync(&self) -> AsyncResult<()> { + let op = self.registered_source.start_fsync()?; + let _ = op.await?; + Ok(()) + } +} + +#[async_trait(?Send)] +impl super::IoSourceExt for UringSource { + /// Yields the underlying IO source. + fn into_source(self: Box) -> F { + self.source + } + + /// Provides a mutable ref to the underlying IO source. + fn as_source(&self) -> &F { + &self.source + } + + /// Provides a ref to the underlying IO source. 
+ fn as_source_mut(&mut self) -> &mut F { + &mut self.source + } +} + +impl Deref for UringSource { + type Target = F; + + fn deref(&self) -> &Self::Target { + &self.source + } +} + +impl DerefMut for UringSource { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.source + } +} + +#[cfg(test)] +mod tests { + use std::{ + fs::{File, OpenOptions}, + os::unix::io::AsRawFd, + path::PathBuf, + }; + + use super::super::{ + io_ext::{ReadAsync, WriteAsync}, + uring_executor::use_uring, + UringSource, + }; + + use super::*; + + #[test] + fn read_to_mem() { + if !use_uring() { + return; + } + + use super::super::mem::VecIoWrapper; + use std::io::Write; + use tempfile::tempfile; + + let ex = URingExecutor::new().unwrap(); + // Use guest memory as a test file, it implements AsRawFd. + let mut source = tempfile().unwrap(); + let data = vec![0x55; 8192]; + source.write_all(&data).unwrap(); + + let io_obj = UringSource::new(source, &ex).unwrap(); + + // Start with memory filled with 0x44s. + let buf: Arc = Arc::new(VecIoWrapper::from(vec![0x44; 8192])); + + let fut = io_obj.read_to_mem( + None, + Arc::::clone(&buf), + &[MemRegion { + offset: 0, + len: 8192, + }], + ); + assert_eq!(8192, ex.run_until(fut).unwrap().unwrap()); + let vec: Vec = match Arc::try_unwrap(buf) { + Ok(v) => v.into(), + Err(_) => panic!("Too many vec refs"), + }; + assert!(vec.iter().all(|&b| b == 0x55)); + } + + #[test] + fn readvec() { + if !use_uring() { + return; + } + + async fn go(ex: &URingExecutor) { + let f = File::open("/dev/zero").unwrap(); + let source = UringSource::new(f, ex).unwrap(); + let v = vec![0x55u8; 32]; + let v_ptr = v.as_ptr(); + let ret = source.read_to_vec(None, v).await.unwrap(); + assert_eq!(ret.0, 32); + let ret_v = ret.1; + assert_eq!(v_ptr, ret_v.as_ptr()); + assert!(ret_v.iter().all(|&b| b == 0)); + } + + let ex = URingExecutor::new().unwrap(); + ex.run_until(go(&ex)).unwrap(); + } + + #[test] + fn readmulti() { + if !use_uring() { + return; + } + + async fn go(ex: &URingExecutor) { + let f = File::open("/dev/zero").unwrap(); + let source = UringSource::new(f, ex).unwrap(); + let v = vec![0x55u8; 32]; + let v2 = vec![0x55u8; 32]; + let (ret, ret2) = futures::future::join( + source.read_to_vec(None, v), + source.read_to_vec(Some(32), v2), + ) + .await; + + assert!(ret.unwrap().1.iter().all(|&b| b == 0)); + assert!(ret2.unwrap().1.iter().all(|&b| b == 0)); + } + + let ex = URingExecutor::new().unwrap(); + ex.run_until(go(&ex)).unwrap(); + } + + async fn read_u64(source: &UringSource) -> u64 { + // Init a vec that translates to u64::max; + let u64_mem = vec![0xffu8; std::mem::size_of::()]; + let (ret, u64_mem) = source.read_to_vec(None, u64_mem).await.unwrap(); + assert_eq!(ret as usize, std::mem::size_of::()); + let mut val = 0u64.to_ne_bytes(); + val.copy_from_slice(&u64_mem); + u64::from_ne_bytes(val) + } + + #[test] + fn u64_from_file() { + if !use_uring() { + return; + } + + let f = File::open("/dev/zero").unwrap(); + let ex = URingExecutor::new().unwrap(); + let source = UringSource::new(f, &ex).unwrap(); + + assert_eq!(0u64, ex.run_until(read_u64(&source)).unwrap()); + } + + #[test] + fn event() { + if !use_uring() { + return; + } + + use sys_util::EventFd; + + async fn write_event(ev: EventFd, wait: EventFd, ex: &URingExecutor) { + let wait = UringSource::new(wait, ex).unwrap(); + ev.write(55).unwrap(); + read_u64(&wait).await; + ev.write(66).unwrap(); + read_u64(&wait).await; + ev.write(77).unwrap(); + read_u64(&wait).await; + } + + async fn read_events(ev: EventFd, signal: EventFd, 
ex: &URingExecutor) { + let source = UringSource::new(ev, ex).unwrap(); + assert_eq!(read_u64(&source).await, 55); + signal.write(1).unwrap(); + assert_eq!(read_u64(&source).await, 66); + signal.write(1).unwrap(); + assert_eq!(read_u64(&source).await, 77); + signal.write(1).unwrap(); + } + + let event = EventFd::new().unwrap(); + let signal_wait = EventFd::new().unwrap(); + let ex = URingExecutor::new().unwrap(); + let write_task = write_event( + event.try_clone().unwrap(), + signal_wait.try_clone().unwrap(), + &ex, + ); + let read_task = read_events(event, signal_wait, &ex); + ex.run_until(futures::future::join(read_task, write_task)) + .unwrap(); + } + + #[test] + fn pend_on_pipe() { + if !use_uring() { + return; + } + + use std::io::Write; + + use futures::future::Either; + + async fn do_test(ex: &URingExecutor) { + let (read_source, mut w) = sys_util::pipe(true).unwrap(); + let source = UringSource::new(read_source, ex).unwrap(); + let done = Box::pin(async { 5usize }); + let pending = Box::pin(read_u64(&source)); + match futures::future::select(pending, done).await { + Either::Right((5, pending)) => { + // Write to the pipe so that the kernel will release the memory associated with + // the uring read operation. + w.write_all(&[0]).expect("failed to write to pipe"); + ::std::mem::drop(pending); + } + _ => panic!("unexpected select result"), + }; + } + + let ex = URingExecutor::new().unwrap(); + ex.run_until(do_test(&ex)).unwrap(); + } + + #[test] + fn readmem() { + if !use_uring() { + return; + } + + async fn go(ex: &URingExecutor) { + let f = File::open("/dev/zero").unwrap(); + let source = UringSource::new(f, ex).unwrap(); + let v = vec![0x55u8; 64]; + let vw = Arc::new(VecIoWrapper::from(v)); + let ret = source + .read_to_mem( + None, + Arc::::clone(&vw), + &[MemRegion { offset: 0, len: 32 }], + ) + .await + .unwrap(); + assert_eq!(32, ret); + let vec: Vec = match Arc::try_unwrap(vw) { + Ok(v) => v.into(), + Err(_) => panic!("Too many vec refs"), + }; + assert!(vec.iter().take(32).all(|&b| b == 0)); + assert!(vec.iter().skip(32).all(|&b| b == 0x55)); + + // test second half of memory too. 
+ let v = vec![0x55u8; 64]; + let vw = Arc::new(VecIoWrapper::from(v)); + let ret = source + .read_to_mem( + None, + Arc::::clone(&vw), + &[MemRegion { + offset: 32, + len: 32, + }], + ) + .await + .unwrap(); + assert_eq!(32, ret); + let v: Vec = match Arc::try_unwrap(vw) { + Ok(v) => v.into(), + Err(_) => panic!("Too many vec refs"), + }; + assert!(v.iter().take(32).all(|&b| b == 0x55)); + assert!(v.iter().skip(32).all(|&b| b == 0)); + } + + let ex = URingExecutor::new().unwrap(); + ex.run_until(go(&ex)).unwrap(); + } + + #[test] + fn range_error() { + if !use_uring() { + return; + } + + async fn go(ex: &URingExecutor) { + let f = File::open("/dev/zero").unwrap(); + let source = UringSource::new(f, ex).unwrap(); + let v = vec![0x55u8; 64]; + let vw = Arc::new(VecIoWrapper::from(v)); + let ret = source + .read_to_mem( + None, + Arc::::clone(&vw), + &[MemRegion { + offset: 32, + len: 33, + }], + ) + .await; + assert!(ret.is_err()); + } + + let ex = URingExecutor::new().unwrap(); + ex.run_until(go(&ex)).unwrap(); + } + + #[test] + fn fallocate() { + if !use_uring() { + return; + } + + async fn go(ex: &URingExecutor) { + let dir = tempfile::TempDir::new().unwrap(); + let mut file_path = PathBuf::from(dir.path()); + file_path.push("test"); + + let f = OpenOptions::new() + .create(true) + .write(true) + .open(&file_path) + .unwrap(); + let source = UringSource::new(f, ex).unwrap(); + if let Err(e) = source.fallocate(0, 4096, 0).await { + match e { + super::super::io_ext::Error::Uring( + super::super::uring_executor::Error::Io(io_err), + ) => { + if io_err.kind() == std::io::ErrorKind::InvalidInput { + // Skip the test on kernels before fallocate support. + return; + } + } + _ => panic!("Unexpected uring error on fallocate: {}", e), + } + } + + let meta_data = std::fs::metadata(&file_path).unwrap(); + assert_eq!(meta_data.len(), 4096); + } + + let ex = URingExecutor::new().unwrap(); + ex.run_until(go(&ex)).unwrap(); + } + + #[test] + fn fsync() { + if !use_uring() { + return; + } + + async fn go(ex: &URingExecutor) { + let f = tempfile::tempfile().unwrap(); + let source = UringSource::new(f, ex).unwrap(); + source.fsync().await.unwrap(); + } + + let ex = URingExecutor::new().unwrap(); + ex.run_until(go(&ex)).unwrap(); + } + + #[test] + fn wait_read() { + if !use_uring() { + return; + } + + async fn go(ex: &URingExecutor) { + let f = File::open("/dev/zero").unwrap(); + let source = UringSource::new(f, ex).unwrap(); + source.wait_readable().await.unwrap(); + } + + let ex = URingExecutor::new().unwrap(); + ex.run_until(go(&ex)).unwrap(); + } + + #[test] + fn writemem() { + if !use_uring() { + return; + } + + async fn go(ex: &URingExecutor) { + let f = OpenOptions::new() + .create(true) + .write(true) + .open("/tmp/write_from_vec") + .unwrap(); + let source = UringSource::new(f, ex).unwrap(); + let v = vec![0x55u8; 64]; + let vw = Arc::new(super::super::mem::VecIoWrapper::from(v)); + let ret = source + .write_from_mem(None, vw, &[MemRegion { offset: 0, len: 32 }]) + .await + .unwrap(); + assert_eq!(32, ret); + } + + let ex = URingExecutor::new().unwrap(); + ex.run_until(go(&ex)).unwrap(); + } + + #[test] + fn writevec() { + if !use_uring() { + return; + } + + async fn go(ex: &URingExecutor) { + let f = OpenOptions::new() + .create(true) + .truncate(true) + .write(true) + .open("/tmp/write_from_vec") + .unwrap(); + let source = UringSource::new(f, ex).unwrap(); + let v = vec![0x55u8; 32]; + let v_ptr = v.as_ptr(); + let (ret, ret_v) = source.write_from_vec(None, v).await.unwrap(); + assert_eq!(32, 
ret); + assert_eq!(v_ptr, ret_v.as_ptr()); + } + + let ex = URingExecutor::new().unwrap(); + ex.run_until(go(&ex)).unwrap(); + } + + #[test] + fn writemulti() { + if !use_uring() { + return; + } + + async fn go(ex: &URingExecutor) { + let f = OpenOptions::new() + .create(true) + .truncate(true) + .write(true) + .open("/tmp/write_from_vec") + .unwrap(); + let source = UringSource::new(f, ex).unwrap(); + let v = vec![0x55u8; 32]; + let v2 = vec![0x55u8; 32]; + let (r, r2) = futures::future::join( + source.write_from_vec(None, v), + source.write_from_vec(Some(32), v2), + ) + .await; + assert_eq!(32, r.unwrap().0); + assert_eq!(32, r2.unwrap().0); + } + + let ex = URingExecutor::new().unwrap(); + ex.run_until(go(&ex)).unwrap(); + } +} diff --git a/cros_async/src/waker.rs b/cros_async/src/waker.rs new file mode 100644 index 0000000000..5daeb57710 --- /dev/null +++ b/cros_async/src/waker.rs @@ -0,0 +1,70 @@ +// Copyright 2020 The Chromium OS Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +use std::{ + mem::{drop, ManuallyDrop}, + sync::Weak, + task::{RawWaker, RawWakerVTable, Waker}, +}; + +/// Wrapper around a usize used as a token to uniquely identify a pending waker. +#[derive(Debug)] +pub(crate) struct WakerToken(pub(crate) usize); + +/// Like `futures::task::ArcWake` but uses `Weak` instead of `Arc`. +pub(crate) trait WeakWake: Send + Sync { + fn wake_by_ref(weak_self: &Weak); + + fn wake(weak_self: Weak) { + Self::wake_by_ref(&weak_self) + } +} + +fn waker_vtable() -> &'static RawWakerVTable { + &RawWakerVTable::new( + clone_weak_raw::, + wake_weak_raw::, + wake_by_ref_weak_raw::, + drop_weak_raw::, + ) +} + +unsafe fn clone_weak_raw(data: *const ()) -> RawWaker { + // Get a handle to the Weak but wrap it in a ManuallyDrop so that we don't reduce the + // refcount at the end of this function. + let weak = ManuallyDrop::new(Weak::::from_raw(data as *const W)); + + // Now increase the weak count and keep it in a ManuallyDrop so that it doesn't get decreased + // at the end of this function. + let _weak_clone: ManuallyDrop<_> = weak.clone(); + + RawWaker::new(data, waker_vtable::()) +} + +unsafe fn wake_weak_raw(data: *const ()) { + let weak: Weak = Weak::from_raw(data as *const W); + + WeakWake::wake(weak) +} + +unsafe fn wake_by_ref_weak_raw(data: *const ()) { + // Get a handle to the Weak but wrap it in a ManuallyDrop so that we don't reduce the + // refcount at the end of this function. 
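// Unlike `wake_weak_raw` above, which consumes the caller's reference, wake-by-ref must
// leave the stored Weak usable after this call, so nothing here may drop it.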
+ let weak = ManuallyDrop::new(Weak::::from_raw(data as *const W)); + + WeakWake::wake_by_ref(&weak) +} + +unsafe fn drop_weak_raw(data: *const ()) { + drop(Weak::from_raw(data as *const W)) +} + +pub(crate) fn new_waker(w: Weak) -> Waker { + unsafe { + Waker::from_raw(RawWaker::new( + w.into_raw() as *const (), + waker_vtable::(), + )) + } +} diff --git a/devices/Cargo.toml b/devices/Cargo.toml index 5ebd84658e..7594fa725a 100644 --- a/devices/Cargo.toml +++ b/devices/Cargo.toml @@ -28,7 +28,7 @@ audio_streams = "*" balloon_control = { path = "../common/balloon_control" } base = { path = "../base" } bit_field = { path = "../bit_field" } -cros_async = { path = "../common/cros_async" } +cros_async = { path = "../cros_async" } data_model = { path = "../common/data_model" } dbus = { version = "0.9", optional = true } disk = { path = "../disk" } diff --git a/disk/Cargo.toml b/disk/Cargo.toml index a72d42a49f..941e694524 100644 --- a/disk/Cargo.toml +++ b/disk/Cargo.toml @@ -20,7 +20,7 @@ remain = "*" tempfile = "3" thiserror = "*" uuid = { version = "0.8.2", features = ["v4"], optional = true } -cros_async = { path = "../common/cros_async" } +cros_async = { path = "../cros_async" } data_model = { path = "../common/data_model" } protos = { path = "../protos", features = ["composite-disk"], optional = true } vm_memory = { path = "../vm_memory" } diff --git a/net_util/Cargo.toml b/net_util/Cargo.toml index 62f7972a99..972500edf4 100644 --- a/net_util/Cargo.toml +++ b/net_util/Cargo.toml @@ -9,6 +9,6 @@ libc = "*" data_model = { path = "../common/data_model" } net_sys = { path = "../net_sys" } base = { path = "../base" } -cros_async = { path = "../common/cros_async" } +cros_async = { path = "../cros_async" } remain = "*" thiserror = "*" diff --git a/vm_memory/Cargo.toml b/vm_memory/Cargo.toml index e4c2ec7acf..fc32d67261 100644 --- a/vm_memory/Cargo.toml +++ b/vm_memory/Cargo.toml @@ -6,7 +6,7 @@ edition = "2021" include = ["src/**/*", "Cargo.toml"] [dependencies] -cros_async = { path = "../common/cros_async" } +cros_async = { path = "../cros_async" } data_model = { path = "../common/data_model" } libc = "*" base = { path = "../base" }
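// A minimal consumer sketch of the relocated crate, mirroring the `readvec` unit test above.
// It assumes a kernel with io_uring support and that cros_async's lib.rs re-exports
// `URingExecutor`, `UringSource`, and the `ReadAsync` trait (the tests reach them through
// crate-internal paths, so the exact public paths are an assumption here).

use std::fs::File;

use cros_async::{ReadAsync, URingExecutor, UringSource};

fn read_some_zeroes() {
    let ex = URingExecutor::new().expect("failed to create io_uring executor");
    ex.run_until(async {
        let f = File::open("/dev/zero").unwrap();
        let source = UringSource::new(f, &ex).unwrap();
        // `read_to_vec` consumes the Vec, submits it to the ring, and hands it back
        // together with the number of bytes read once the completion arrives.
        let (len, buf) = source.read_to_vec(None, vec![0x55u8; 32]).await.unwrap();
        assert_eq!(len, 32);
        assert!(buf.iter().all(|&b| b == 0));
    })
    .expect("executor failed");
}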