diff --git a/zrpc/src/conn.rs b/zrpc/src/conn.rs index 4b2356c872..e67b4fa587 100644 --- a/zrpc/src/conn.rs +++ b/zrpc/src/conn.rs @@ -1,5 +1,6 @@ use async_tungstenite::tungstenite::{Error as WebSocketError, Message as WebSocketMessage}; -use futures::{SinkExt as _, StreamExt as _}; +use futures::{channel::mpsc, SinkExt as _, Stream, StreamExt as _}; +use std::{io, task::Poll}; pub struct Conn { pub(crate) tx: @@ -53,10 +54,10 @@ impl Conn { Box>, Box>>, ) { - use futures::{future, stream, SinkExt as _, StreamExt as _}; - use std::io::{Error, ErrorKind}; + use futures::{future, SinkExt as _}; + use io::{Error, ErrorKind}; - let (tx, rx) = futures::channel::mpsc::unbounded::(); + let (tx, rx) = mpsc::unbounded::(); let tx = tx .sink_map_err(|e| WebSocketError::from(Error::new(ErrorKind::Other, e))) .with({ @@ -69,19 +70,32 @@ impl Conn { } } }); - let rx = stream::select( - rx.map(Ok), - kill_rx.filter_map(|kill| { - if kill.is_none() { - future::ready(None) - } else { - future::ready(Some(Err( - Error::new(ErrorKind::Other, "connection killed").into() - ))) - } - }), - ); + let rx = KillableReceiver { kill_rx, rx }; (Box::new(tx), Box::new(rx)) } } + +struct KillableReceiver { + rx: mpsc::UnboundedReceiver, + kill_rx: postage::watch::Receiver>, +} + +impl Stream for KillableReceiver { + type Item = Result; + + fn poll_next( + mut self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll> { + if let Poll::Ready(Some(Some(()))) = self.kill_rx.poll_next_unpin(cx) { + Poll::Ready(Some(Err(io::Error::new( + io::ErrorKind::Other, + "connection killed", + ) + .into()))) + } else { + self.rx.poll_next_unpin(cx).map(|value| value.map(Ok)) + } + } +}