diff --git a/cros_async/src/blocking/pool.rs b/cros_async/src/blocking/pool.rs index d63bf6661b..2109fd7911 100644 --- a/cros_async/src/blocking/pool.rs +++ b/cros_async/src/blocking/pool.rs @@ -5,26 +5,32 @@ use std::{ collections::VecDeque, mem, - sync::Arc, + sync::{ + mpsc::{channel, Receiver, Sender}, + Arc, + }, thread::{self, JoinHandle}, - time::Duration, + time::{Duration, Instant}, }; use async_task::{Runnable, Task}; use slab::Slab; use sync::{Condvar, Mutex}; +use sys_util::{error, warn}; + +const DEFAULT_SHUTDOWN_TIMEOUT: Duration = Duration::from_secs(10); -#[derive(Default)] struct State { tasks: VecDeque, num_threads: usize, num_idle: usize, worker_threads: Slab>, - last_exited_thread: Option>, + exited_threads: Option>, + exit: Sender, shutting_down: bool, } -fn run_blocking_thread(idx: usize, inner: Arc) { +fn run_blocking_thread(idx: usize, inner: Arc, exit: Sender) { let mut state = inner.state.lock(); while !state.shutting_down { if let Some(runnable) = state.tasks.pop_front() { @@ -56,9 +62,11 @@ fn run_blocking_thread(idx: usize, inner: Arc) { // If we're shutting down then the BlockingPool will take care of joining all the threads. // Otherwise, we need to join the last worker thread that exited here. - let last_exited_thread = if !state.shutting_down { - let this_thread = state.worker_threads.remove(idx); - mem::replace(&mut state.last_exited_thread, Some(this_thread)) + let last_exited_thread = if let Some(exited_threads) = state.exited_threads.as_mut() { + exited_threads + .try_recv() + .map(|idx| state.worker_threads.remove(idx)) + .ok() } else { None }; @@ -69,6 +77,10 @@ fn run_blocking_thread(idx: usize, inner: Arc) { if let Some(handle) = last_exited_thread { let _ = handle.join(); } + + if let Err(e) = exit.send(idx) { + error!("Failed to send thread exit event on channel: {}", e); + } } struct Inner { @@ -93,13 +105,14 @@ impl Inner { // There are no idle threads. Spawn a new one if possible. if state.num_threads < self.max_threads { state.num_threads += 1; + let exit = state.exit.clone(); let entry = state.worker_threads.vacant_entry(); let idx = entry.key(); let inner = self.clone(); entry.insert( thread::Builder::new() .name(format!("blockingPool{}", idx)) - .spawn(move || run_blocking_thread(idx, inner)) + .spawn(move || run_blocking_thread(idx, inner, exit)) .unwrap(), ); } @@ -111,6 +124,10 @@ impl Inner { } } +#[derive(Debug, thiserror::Error)] +#[error("{0} BlockingPool threads did not exit in time and will be detached")] +pub struct ShutdownTimedOut(usize); + /// A thread pool for running work that may block. /// /// It is generally discouraged to do any blocking work inside an async function. However, this is @@ -125,8 +142,9 @@ impl Inner { /// should just use `thread::spawn` directly. /// /// There is no way to cancel work once it has been picked up by one of the worker threads in the -/// `BlockingPool` and dropping or shutting down the pool will block until all worker threads finish -/// their current task. +/// `BlockingPool`. Dropping or shutting down the pool will block up to a timeout (default 10 +/// seconds) to wait for any active blocking work to finish. Any threads running tasks that have not +/// completed by that time will be detached. /// /// # Examples /// @@ -167,9 +185,18 @@ impl BlockingPool { /// `BlockingPool`. `keepalive` determines the idle duration after which the worker thread will /// exit. The default value is 10 seconds. pub fn new(max_threads: usize, keepalive: Duration) -> BlockingPool { + let (exit, exited_threads) = channel(); BlockingPool { inner: Arc::new(Inner { - state: Default::default(), + state: Mutex::new(State { + tasks: VecDeque::new(), + num_threads: 0, + num_idle: 0, + worker_threads: Slab::new(), + exited_threads: Some(exited_threads), + exit, + shutting_down: false, + }), condvar: Condvar::new(), max_threads, keepalive, @@ -179,6 +206,7 @@ impl BlockingPool { /// Like new but with pre-allocating capacity for up to `max_threads`. pub fn with_capacity(max_threads: usize, keepalive: Duration) -> BlockingPool { + let (exit, exited_threads) = channel(); BlockingPool { inner: Arc::new(Inner { state: Mutex::new(State { @@ -186,7 +214,8 @@ impl BlockingPool { num_threads: 0, num_idle: 0, worker_threads: Slab::with_capacity(max_threads), - last_exited_thread: None, + exited_threads: Some(exited_threads), + exit, shutting_down: false, }), condvar: Condvar::new(), @@ -224,19 +253,20 @@ impl BlockingPool { /// Shut down the `BlockingPool`. /// - /// This will block until all work that has been started by the worker threads is finished. Any - /// work that was added to the `BlockingPool` but not yet picked up by a worker thread will not - /// complete and `await`ing on the `Task` for that work will panic. - pub fn shutdown(&self) { + /// If `deadline` is provided then this will block until either all worker threads exit or the + /// deadline is exceeded. If `deadline` is not given then this will block indefinitely until all + /// worker threads exit. Any work that was added to the `BlockingPool` but not yet picked up by + /// a worker thread will not complete and `await`ing on the `Task` for that work will panic. + pub fn shutdown(&self, deadline: Option) -> Result<(), ShutdownTimedOut> { let mut state = self.inner.state.lock(); if state.shutting_down { // We've already shut down this BlockingPool. - return; + return Ok(()); } state.shutting_down = true; - let last_exited_thread = state.last_exited_thread.take(); + let exited_threads = state.exited_threads.take().expect("exited_threads missing"); let unfinished_tasks = mem::replace(&mut state.tasks, VecDeque::new()); let mut worker_threads = mem::replace(&mut state.worker_threads, Slab::new()); drop(state); @@ -247,12 +277,28 @@ impl BlockingPool { drop(unfinished_tasks); // Now wait for all worker threads to exit. - if let Some(handle) = last_exited_thread { - let _ = handle.join(); - } + if let Some(deadline) = deadline { + let mut now = Instant::now(); + while now < deadline && !worker_threads.is_empty() { + if let Ok(idx) = exited_threads.recv_timeout(deadline - now) { + let _ = worker_threads.remove(idx).join(); + } + now = Instant::now(); + } - for handle in worker_threads.drain() { - let _ = handle.join(); + // Any threads that have not yet joined will just be detached. + if !worker_threads.is_empty() { + return Err(ShutdownTimedOut(worker_threads.len())); + } + + Ok(()) + } else { + // Block indefinitely until all worker threads exit. + for handle in worker_threads.drain() { + let _ = handle.join(); + } + + Ok(()) } } } @@ -265,14 +311,16 @@ impl Default for BlockingPool { impl Drop for BlockingPool { fn drop(&mut self) { - self.shutdown() + if let Err(e) = self.shutdown(Some(Instant::now() + DEFAULT_SHUTDOWN_TIMEOUT)) { + warn!("{}", e); + } } } #[cfg(test)] mod test { use std::{ - sync::Arc, + sync::{Arc, Barrier}, thread, time::{Duration, Instant}, }; @@ -313,7 +361,8 @@ mod test { let results = block_on(stream.collect::>()); assert_eq!(results.len(), 19); - pool.shutdown(); + pool.shutdown(Some(Instant::now() + Duration::from_secs(10))) + .unwrap(); let state = pool.inner.state.lock(); assert_eq!(state.num_threads, 0); } @@ -382,9 +431,59 @@ mod test { *mu.lock() = true; cv.notify_all(); }); - pool.shutdown(); + pool.shutdown(None).unwrap(); // This should panic. assert_eq!(block_on(unfinished), 5); } + + #[test] + fn unfinished_worker_thread() { + let pool = BlockingPool::default(); + + let ready = Arc::new(Mutex::new(false)); + let cv = Arc::new(Condvar::new()); + let barrier = Arc::new(Barrier::new(2)); + + let thread_ready = ready.clone(); + let thread_barrier = barrier.clone(); + let thread_cv = cv.clone(); + + let task = pool.spawn(move || { + thread_barrier.wait(); + let mut ready = thread_ready.lock(); + while !*ready { + ready = thread_cv.wait(ready); + } + }); + + // Wait to shut down the pool until after the worker thread has started. + barrier.wait(); + pool.shutdown(Some(Instant::now() + Duration::from_millis(5))) + .unwrap_err(); + + let num_threads = pool.inner.state.lock().num_threads; + assert_eq!(num_threads, 1); + + // Now wake up the blocked task so we don't leak the thread. + *ready.lock() = true; + cv.notify_all(); + + block_on(task); + + let deadline = Instant::now() + Duration::from_secs(10); + while Instant::now() < deadline { + thread::sleep(Duration::from_millis(100)); + let state = pool.inner.state.lock(); + if state.num_threads == 0 { + break; + } + } + + { + let state = pool.inner.state.lock(); + assert_eq!(state.num_threads, 0); + assert_eq!(state.num_idle, 0); + } + } }