use anyhow::{anyhow, Result}; use async_task::Runnable; pub use async_task::Task; use parking_lot::Mutex; use rand::prelude::*; use smol::{channel, prelude::*, Executor}; use std::{ marker::PhantomData, mem, pin::Pin, rc::Rc, sync::{mpsc::SyncSender, Arc}, thread, }; use crate::platform; pub enum Foreground { Platform { dispatcher: Arc, _not_send_or_sync: PhantomData>, }, Test(smol::LocalExecutor<'static>), Deterministic(Arc), } pub enum Background { Deterministic(Arc), Production { executor: Arc>, threads: usize, _stop: channel::Sender<()>, }, } pub struct Deterministic { seed: u64, runnables: Arc, Option>)>>, } impl Deterministic { fn new(seed: u64) -> Self { Self { seed, runnables: Default::default(), } } pub fn spawn_local(&self, future: F) -> Task where T: 'static, F: Future + 'static, { let runnables = self.runnables.clone(); let (runnable, task) = async_task::spawn_local(future, move |runnable| { let mut runnables = runnables.lock(); runnables.0.push(runnable); if let Some(wake_tx) = runnables.1.as_ref() { wake_tx.send(()).ok(); } }); runnable.schedule(); task } pub fn spawn(&self, future: F) -> Task where T: 'static + Send, F: 'static + Send + Future, { let runnables = self.runnables.clone(); let (runnable, task) = async_task::spawn(future, move |runnable| { let mut runnables = runnables.lock(); runnables.0.push(runnable); if let Some(wake_tx) = runnables.1.as_ref() { wake_tx.send(()).ok(); } }); runnable.schedule(); task } pub fn run(&self, future: F) -> T where T: 'static, F: Future + 'static, { let (wake_tx, wake_rx) = std::sync::mpsc::sync_channel(32); let runnables = self.runnables.clone(); runnables.lock().1 = Some(wake_tx); let (output_tx, output_rx) = std::sync::mpsc::channel(); self.spawn_local(async move { let output = future.await; output_tx.send(output).unwrap(); }) .detach(); let mut rng = StdRng::seed_from_u64(self.seed); loop { if let Ok(value) = output_rx.try_recv() { runnables.lock().1 = None; return value; } wake_rx.recv().unwrap(); let runnable = { let mut runnables = runnables.lock(); let runnables = &mut runnables.0; let index = rng.gen_range(0..runnables.len()); runnables.remove(index) }; runnable.run(); } } } impl Foreground { pub fn platform(dispatcher: Arc) -> Result { if dispatcher.is_main_thread() { Ok(Self::Platform { dispatcher, _not_send_or_sync: PhantomData, }) } else { Err(anyhow!("must be constructed on main thread")) } } pub fn test() -> Self { Self::Test(smol::LocalExecutor::new()) } pub fn spawn(&self, future: impl Future + 'static) -> Task { match self { Self::Platform { dispatcher, .. } => { let dispatcher = dispatcher.clone(); let schedule = move |runnable: Runnable| dispatcher.run_on_main_thread(runnable); let (runnable, task) = async_task::spawn_local(future, schedule); runnable.schedule(); task } Self::Test(executor) => executor.spawn(future), Self::Deterministic(executor) => executor.spawn_local(future), } } pub fn run(&self, future: impl 'static + Future) -> T { match self { Self::Platform { .. } => panic!("you can't call run on a platform foreground executor"), Self::Test(executor) => smol::block_on(executor.run(future)), Self::Deterministic(executor) => executor.run(future), } } } impl Background { pub fn new() -> Self { let executor = Arc::new(Executor::new()); let stop = channel::unbounded::<()>(); let threads = num_cpus::get(); for i in 0..threads { let executor = executor.clone(); let stop = stop.1.clone(); thread::Builder::new() .name(format!("background-executor-{}", i)) .spawn(move || smol::block_on(executor.run(stop.recv()))) .unwrap(); } Self::Production { executor, threads, _stop: stop.0, } } pub fn threads(&self) -> usize { match self { Self::Deterministic(_) => 1, Self::Production { threads, .. } => *threads, } } pub fn spawn(&self, future: F) -> Task where T: 'static + Send, F: Send + Future + 'static, { match self { Self::Production { executor, .. } => executor.spawn(future), Self::Deterministic(executor) => executor.spawn(future), } } pub async fn scoped<'scope, F>(&self, scheduler: F) where F: FnOnce(&mut Scope<'scope>), { let mut scope = Scope { futures: Default::default(), _phantom: PhantomData, }; (scheduler)(&mut scope); let spawned = scope .futures .into_iter() .map(|f| self.spawn(f)) .collect::>(); for task in spawned { task.await; } } } pub struct Scope<'a> { futures: Vec + Send + 'static>>>, _phantom: PhantomData<&'a ()>, } impl<'a> Scope<'a> { pub fn spawn(&mut self, f: F) where F: Future + Send + 'a, { let f = unsafe { mem::transmute::< Pin + Send + 'a>>, Pin + Send + 'static>>, >(Box::pin(f)) }; self.futures.push(f); } } pub fn deterministic(seed: u64) -> (Rc, Arc) { let executor = Arc::new(Deterministic::new(seed)); ( Rc::new(Foreground::Deterministic(executor.clone())), Arc::new(Background::Deterministic(executor)), ) }