From 784ed62943e1037fd9b65c16075743ce061debfb Mon Sep 17 00:00:00 2001 From: Frederick Mayle Date: Mon, 13 May 2024 14:00:48 -0700 Subject: [PATCH] base: move cros_async's WaitForHandle into base This functionality will also be used by a pure Tokio library for crosvm. The API is now a single async function and is named to match the underlying Windows API. The unit tests have been rewritten to not require an async executor and should have a bit more test coverage. BUG=b:338274203 Change-Id: Iff65ca088c74ce5c5ce396fdf7dcd9f1550ea545 Reviewed-on: https://chromium-review.googlesource.com/c/crosvm/crosvm/+/5536313 Commit-Queue: Frederick Mayle Reviewed-by: Noah Gold --- Cargo.lock | 1 + base/Cargo.toml | 1 + .../windows/async_wait_for_single_object.rs | 123 ++++++++++-------- base/src/sys/windows/mod.rs | 2 + cros_async/src/sys/windows.rs | 2 - cros_async/src/sys/windows/handle_source.rs | 10 +- .../src/sys/windows/overlapped_source.rs | 6 +- cros_async/src/sys/windows/tokio_source.rs | 4 +- 8 files changed, 84 insertions(+), 65 deletions(-) rename cros_async/src/sys/windows/wait_for_handle.rs => base/src/sys/windows/async_wait_for_single_object.rs (73%) diff --git a/Cargo.lock b/Cargo.lock index d61f567966..f3a571ee88 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -263,6 +263,7 @@ dependencies = [ "cfg-if", "chrono", "env_logger", + "futures", "libc", "libtest-mimic", "log", diff --git a/base/Cargo.toml b/base/Cargo.toml index 1f4910e464..566965e510 100644 --- a/base/Cargo.toml +++ b/base/Cargo.toml @@ -44,6 +44,7 @@ tempfile = "3" minijail = "*" [target.'cfg(windows)'.dependencies] +futures = { version = "0.3" } protobuf = "3.2" rand = "0.8" winapi = "*" diff --git a/cros_async/src/sys/windows/wait_for_handle.rs b/base/src/sys/windows/async_wait_for_single_object.rs similarity index 73% rename from cros_async/src/sys/windows/wait_for_handle.rs rename to base/src/sys/windows/async_wait_for_single_object.rs index e0afa642d0..c4ea3d2425 100644 --- a/cros_async/src/sys/windows/wait_for_handle.rs +++ b/base/src/sys/windows/async_wait_for_single_object.rs @@ -1,9 +1,12 @@ -// Copyright 2021 The ChromiumOS Authors +// Copyright 2024 The ChromiumOS Authors // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. use std::ffi::c_void; use std::future::Future; +use std::io::Error; +use std::io::ErrorKind; +use std::io::Result; use std::marker::PhantomData; use std::marker::PhantomPinned; use std::pin::Pin; @@ -13,10 +16,6 @@ use std::task::Context; use std::task::Poll; use std::task::Waker; -use base::error; -use base::warn; -use base::AsRawDescriptor; -use base::Descriptor; use sync::Mutex; use winapi::shared::ntdef::FALSE; use winapi::um::handleapi::INVALID_HANDLE_VALUE; @@ -27,8 +26,19 @@ use winapi::um::winnt::BOOLEAN; use winapi::um::winnt::PVOID; use winapi::um::winnt::WT_EXECUTEONLYONCE; -use crate::sys::windows::handle_source::Error; -use crate::sys::windows::handle_source::Result; +use crate::error; +use crate::warn; +use crate::AsRawDescriptor; +use crate::Descriptor; + +/// Async wrapper around `RegisterWaitForSingleObject`. See the official documentation of that +/// function for a list of supported object types. +/// +/// The implementation is not tied to any specific async runtime. `Waker::wake` gets invoked from +/// an OS maintained thread pool when the object becomes readable. +pub async fn async_wait_for_single_object(source: &impl AsRawDescriptor) -> Result<()> { + WaitForHandle::new(source).await +} /// Inner state shared between the future struct & the kernel invoked waiter callback. struct WaitForHandleInner { @@ -57,8 +67,7 @@ enum WaitState { Failed, } -/// Waits for an object with a handle to be readable. -pub struct WaitForHandle<'a, T: AsRawDescriptor> { +struct WaitForHandle<'a, T: AsRawDescriptor> { handle: Descriptor, inner: Mutex, _marker: PhantomData<&'a T>, @@ -69,7 +78,7 @@ impl<'a, T> WaitForHandle<'a, T> where T: AsRawDescriptor, { - pub fn new(source: &'a T) -> WaitForHandle<'a, T> { + fn new(source: &'a T) -> WaitForHandle<'a, T> { WaitForHandle { handle: Descriptor(source.as_raw_descriptor()), inner: Mutex::new(WaitForHandleInner::new()), @@ -106,7 +115,7 @@ where ) }; if err == 0 { - return Poll::Ready(Err(Error::HandleWaitFailed(base::Error::last()))); + return Poll::Ready(Err(Error::last_os_error())); } inner.wait_state = WaitState::Sleeping; @@ -138,7 +147,9 @@ where Poll::Ready(Ok(())) } - WaitState::Aborted => Poll::Ready(Err(Error::OperationAborted)), + WaitState::Aborted => { + Poll::Ready(Err(Error::new(ErrorKind::Other, "operation aborted"))) + } WaitState::Finished => panic!("polled an already completed WaitForHandle future."), WaitState::Failed => { panic!("WaitForHandle future's waiter callback hit unexpected behavior.") @@ -237,77 +248,79 @@ unsafe fn unregister_wait(desc: Descriptor) { if UnregisterWaitEx(desc.0, INVALID_HANDLE_VALUE) == 0 { warn!( "WaitForHandle: failed to clean up RegisterWaitForSingleObject wait handle: {}", - base::Error::last() + Error::last_os_error() ) } } #[cfg(test)] mod tests { + use std::sync::mpsc; use std::sync::Arc; - use std::sync::Weak; use std::time::Duration; - use base::thread::spawn_with_timeout; - use base::Event; - use futures::pin_mut; - use super::*; - use crate::waker::new_waker; - use crate::waker::WeakWake; - use crate::EventAsync; - use crate::Executor; + use crate::Event; - struct FakeWaker {} - impl WeakWake for FakeWaker { - fn wake_by_ref(_weak_self: &Weak) { - // Do nothing. + struct SimpleWaker { + wake_tx: mpsc::Sender<()>, + } + + impl futures::task::ArcWake for SimpleWaker { + fn wake_by_ref(arc_self: &Arc) { + let _ = arc_self.wake_tx.send(()); } } #[test] fn test_unsignaled_event() { - async fn wait_on_unsignaled_event(evt: EventAsync) { - evt.next_val().await.unwrap(); - panic!("await should never terminate"); - } - - let fake_waker = Arc::new(FakeWaker {}); - let waker = new_waker(Arc::downgrade(&fake_waker)); + let (wake_tx, _wake_rx) = mpsc::channel(); + let waker = futures::task::waker(Arc::new(SimpleWaker { wake_tx })); let mut cx = Context::from_waker(&waker); - let ex = Executor::new().unwrap(); let evt = Event::new().unwrap(); - let async_evt = EventAsync::new(evt, &ex).unwrap(); - - let fut = wait_on_unsignaled_event(async_evt); - pin_mut!(fut); - + let mut fut = std::pin::pin!(async { async_wait_for_single_object(&evt).await.unwrap() }); // Assert we make it to the pending state. This means we've registered a wait. - assert_eq!(fut.poll(&mut cx), Poll::Pending); + assert_eq!(fut.as_mut().poll(&mut cx), Poll::Pending); // If this test doesn't crash trying to drop the future, it is considered successful. } #[test] fn test_signaled_event() { - let join_handle = spawn_with_timeout(|| { - async fn wait_on_signaled_event(evt: EventAsync) { - evt.next_val().await.unwrap(); - } + let (wake_tx, wake_rx) = mpsc::channel(); + let waker = futures::task::waker(Arc::new(SimpleWaker { wake_tx })); + let mut cx = Context::from_waker(&waker); - let ex = Executor::new().unwrap(); - let evt = Event::new().unwrap(); - evt.signal().unwrap(); - let async_evt = EventAsync::new(evt, &ex).unwrap(); + let evt = Event::new().unwrap(); + let mut fut = std::pin::pin!(async { async_wait_for_single_object(&evt).await.unwrap() }); + // Should be pending. + assert_eq!(fut.as_mut().poll(&mut cx), Poll::Pending); + // Should still be pending. + assert_eq!(fut.as_mut().poll(&mut cx), Poll::Pending); + // Signal, wait for wake, then the future should be ready. + evt.signal().unwrap(); + wake_rx + .recv_timeout(Duration::from_secs(5)) + .expect("timeout waiting for wake"); + assert_eq!(fut.as_mut().poll(&mut cx), Poll::Ready(())); + } - let fut = wait_on_signaled_event(async_evt); - pin_mut!(fut); + #[test] + fn test_already_signaled_event() { + let (wake_tx, wake_rx) = mpsc::channel(); + let waker = futures::task::waker(Arc::new(SimpleWaker { wake_tx })); + let mut cx = Context::from_waker(&waker); - ex.run_until(fut).unwrap(); - }); - join_handle - .try_join(Duration::from_secs(5)) - .expect("async wait never returned from signaled event."); + let evt = Event::new().unwrap(); + evt.signal().unwrap(); + let mut fut = std::pin::pin!(async { async_wait_for_single_object(&evt).await.unwrap() }); + // First call is always pending and registers the wait. + assert_eq!(fut.as_mut().poll(&mut cx), Poll::Pending); + // Wait for wake, then the future should be ready. + wake_rx + .recv_timeout(Duration::from_secs(5)) + .expect("timeout waiting for wake"); + assert_eq!(fut.as_mut().poll(&mut cx), Poll::Ready(())); } } diff --git a/base/src/sys/windows/mod.rs b/base/src/sys/windows/mod.rs index 1a1495c81f..6c0bdc9d4d 100644 --- a/base/src/sys/windows/mod.rs +++ b/base/src/sys/windows/mod.rs @@ -10,6 +10,7 @@ pub mod ioctl; #[macro_use] pub mod syslog; +mod async_wait_for_single_object; mod console; mod descriptor; mod event; @@ -42,6 +43,7 @@ pub mod thread; mod write_zeroes; +pub use async_wait_for_single_object::async_wait_for_single_object; pub use console::*; pub use descriptor::*; pub use event::*; diff --git a/cros_async/src/sys/windows.rs b/cros_async/src/sys/windows.rs index 315426f331..bb92ffca68 100644 --- a/cros_async/src/sys/windows.rs +++ b/cros_async/src/sys/windows.rs @@ -13,7 +13,6 @@ pub mod overlapped_source; mod timer; #[cfg(feature = "tokio")] pub mod tokio_source; -pub mod wait_for_handle; pub use error::AsyncErrorSys; pub use executor::ExecutorKindSys; @@ -21,7 +20,6 @@ pub use handle_executor::HandleReactor; pub use handle_source::HandleSource; pub use handle_source::HandleWrapper; pub use overlapped_source::OverlappedSource; -pub(crate) use wait_for_handle::WaitForHandle; use crate::Error; diff --git a/cros_async/src/sys/windows/handle_source.rs b/cros_async/src/sys/windows/handle_source.rs index 51c9a9e6a5..13553e5b5c 100644 --- a/cros_async/src/sys/windows/handle_source.rs +++ b/cros_async/src/sys/windows/handle_source.rs @@ -52,7 +52,7 @@ pub enum Error { #[error("An error occurred trying to duplicate source handles: {0}.")] HandleDuplicationFailed(io::Error), #[error("An error occurred trying to wait on source handles: {0}.")] - HandleWaitFailed(base::Error), + HandleWaitFailed(io::Error), #[error("An error occurred trying to get a VolatileSlice into BackingMemory: {0}.")] BackingMemoryVolatileSliceFetchFailed(crate::mem::Error), #[error("HandleSource is gone, so no handles are available to fulfill the IO request.")] @@ -74,7 +74,7 @@ impl From for io::Error { IoPunchHoleError(e) => e, IoWriteZeroesError(e) => e, HandleDuplicationFailed(e) => e, - HandleWaitFailed(e) => e.into(), + HandleWaitFailed(e) => e, BackingMemoryVolatileSliceFetchFailed(e) => io::Error::new(io::ErrorKind::Other, e), NoHandleSource => io::Error::new(io::ErrorKind::Other, NoHandleSource), OperationCancelled => io::Error::new(io::ErrorKind::Interrupted, OperationCancelled), @@ -385,8 +385,10 @@ impl HandleSource { /// If sources are not interchangeable, behavior is undefined. pub async fn wait_for_handle(&self) -> AsyncResult<()> { - let waiter = super::WaitForHandle::new(&self.source); - Ok(waiter.await?) + base::sys::windows::async_wait_for_single_object(&self.source) + .await + .map_err(Error::HandleWaitFailed)?; + Ok(()) } } diff --git a/cros_async/src/sys/windows/overlapped_source.rs b/cros_async/src/sys/windows/overlapped_source.rs index 7b046b2b86..6c7419cbfa 100644 --- a/cros_async/src/sys/windows/overlapped_source.rs +++ b/cros_async/src/sys/windows/overlapped_source.rs @@ -400,8 +400,10 @@ impl OverlappedSource { } pub async fn wait_for_handle(&self) -> AsyncResult<()> { - let waiter = super::WaitForHandle::new(&self.source); - Ok(waiter.await?) + base::sys::windows::async_wait_for_single_object(&self.source) + .await + .map_err(super::handle_source::Error::HandleWaitFailed)?; + Ok(()) } } diff --git a/cros_async/src/sys/windows/tokio_source.rs b/cros_async/src/sys/windows/tokio_source.rs index 676edb7c2a..7af52d6cf3 100644 --- a/cros_async/src/sys/windows/tokio_source.rs +++ b/cros_async/src/sys/windows/tokio_source.rs @@ -212,8 +212,8 @@ impl TokioSource { unimplemented!(); } pub async fn wait_for_handle(&self) -> AsyncResult<()> { - let waiter = super::wait_for_handle::WaitForHandle::new(self.source.as_ref().unwrap()); - Ok(waiter.await?) + base::sys::windows::async_wait_for_single_object(self.source.as_ref().unwrap()).await?; + Ok(()) } pub async fn write_from_mem( &self,