use async_tungstenite::tungstenite::{Error as WebSocketError, Message as WebSocketMessage}; use futures::{channel::mpsc, SinkExt as _, Stream, StreamExt as _}; use std::{io, task::Poll}; pub struct Connection { pub(crate) tx: Box>, pub(crate) rx: Box< dyn 'static + Send + Unpin + futures::Stream>, >, } impl Connection { pub fn new(stream: S) -> Self where S: 'static + Send + Unpin + futures::Sink + futures::Stream>, { let (tx, rx) = stream.split(); Self { tx: Box::new(tx), rx: Box::new(rx), } } pub async fn send(&mut self, message: WebSocketMessage) -> Result<(), WebSocketError> { self.tx.send(message).await } #[cfg(any(test, feature = "test-support"))] pub fn in_memory() -> (Self, Self, postage::watch::Sender>) { let (kill_tx, mut kill_rx) = postage::watch::channel_with(None); postage::stream::Stream::try_recv(&mut kill_rx).unwrap(); let (a_tx, a_rx) = Self::channel(kill_rx.clone()); let (b_tx, b_rx) = Self::channel(kill_rx); ( Self { tx: a_tx, rx: b_rx }, Self { tx: b_tx, rx: a_rx }, kill_tx, ) } #[cfg(any(test, feature = "test-support"))] fn channel( kill_rx: postage::watch::Receiver>, ) -> ( Box>, Box>>, ) { use futures::{future, SinkExt as _}; use io::{Error, ErrorKind}; let (tx, rx) = mpsc::unbounded::(); let tx = tx .sink_map_err(|e| WebSocketError::from(Error::new(ErrorKind::Other, e))) .with({ let kill_rx = kill_rx.clone(); move |msg| { if kill_rx.borrow().is_none() { future::ready(Ok(msg)) } else { future::ready(Err(Error::new(ErrorKind::Other, "connection killed").into())) } } }); let rx = KillableReceiver { kill_rx, rx }; (Box::new(tx), Box::new(rx)) } } struct KillableReceiver { rx: mpsc::UnboundedReceiver, kill_rx: postage::watch::Receiver>, } impl Stream for KillableReceiver { type Item = Result; fn poll_next( mut self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>, ) -> Poll> { if let Poll::Ready(Some(Some(()))) = self.kill_rx.poll_next_unpin(cx) { Poll::Ready(Some(Err(io::Error::new( io::ErrorKind::Other, "connection killed", ) .into()))) } else { self.rx.poll_next_unpin(cx).map(|value| value.map(Ok)) } } }