diff --git a/crates/lsp/src/lsp.rs b/crates/lsp/src/lsp.rs index ed3502208e..0510388381 100644 --- a/crates/lsp/src/lsp.rs +++ b/crates/lsp/src/lsp.rs @@ -476,18 +476,22 @@ impl Drop for Subscription { #[cfg(any(test, feature = "test-support"))] pub struct FakeLanguageServer { - handlers: Arc< - Mutex< - HashMap< - &'static str, - Box Vec>, - >, - >, - >, + handlers: FakeLanguageServerHandlers, outgoing_tx: futures::channel::mpsc::UnboundedSender>, incoming_rx: futures::channel::mpsc::UnboundedReceiver>, + _input_task: Task>, + _output_task: Task>, } +type FakeLanguageServerHandlers = Arc< + Mutex< + HashMap< + &'static str, + Box Vec>, + >, + >, +>; + #[cfg(any(test, feature = "test-support"))] impl LanguageServer { pub fn fake(cx: &mut gpui::MutableAppContext) -> (Arc, FakeLanguageServer) { @@ -533,59 +537,69 @@ impl FakeLanguageServer { let (incoming_tx, incoming_rx) = futures::channel::mpsc::unbounded(); let (outgoing_tx, mut outgoing_rx) = futures::channel::mpsc::unbounded(); - let this = Self { - outgoing_tx: outgoing_tx.clone(), - incoming_rx, - handlers: Default::default(), - }; + let handlers = FakeLanguageServerHandlers::default(); - // Receive incoming messages - let handlers = this.handlers.clone(); - cx.spawn(|cx| async move { - let mut buffer = Vec::new(); - let mut stdin = smol::io::BufReader::new(stdin); - while Self::receive(&mut stdin, &mut buffer).await.is_ok() { - cx.background().simulate_random_delay().await; - if let Ok(request) = serde_json::from_slice::(&buffer) { - assert_eq!(request.jsonrpc, JSON_RPC_VERSION); + let input_task = cx.spawn(|cx| { + let handlers = handlers.clone(); + let outgoing_tx = outgoing_tx.clone(); + async move { + let mut buffer = Vec::new(); + let mut stdin = smol::io::BufReader::new(stdin); + while Self::receive(&mut stdin, &mut buffer).await.is_ok() { + cx.background().simulate_random_delay().await; + if let Ok(request) = serde_json::from_slice::(&buffer) { + assert_eq!(request.jsonrpc, JSON_RPC_VERSION); - if let Some(handler) = handlers.lock().get_mut(request.method) { - let response = - handler(request.id, request.params.get().as_bytes(), cx.clone()); - log::debug!("handled lsp request. method:{}", request.method); - outgoing_tx.unbounded_send(response)?; - } else { - log::debug!("unhandled lsp request. method:{}", request.method); - outgoing_tx.unbounded_send( - serde_json::to_vec(&AnyResponse { + let response; + if let Some(handler) = handlers.lock().get_mut(request.method) { + response = + handler(request.id, request.params.get().as_bytes(), cx.clone()); + log::debug!("handled lsp request. method:{}", request.method); + } else { + response = serde_json::to_vec(&AnyResponse { id: request.id, error: Some(Error { message: "no handler".to_string(), }), result: None, }) - .unwrap(), - )?; + .unwrap(); + log::debug!("unhandled lsp request. method:{}", request.method); + } + outgoing_tx.unbounded_send(response)?; + } else { + incoming_tx.unbounded_send(buffer.clone())?; } - } else { - incoming_tx.unbounded_send(buffer.clone())?; } + Ok::<_, anyhow::Error>(()) } - Ok::<_, anyhow::Error>(()) - }) - .detach(); + }); - // Send outgoing messages - cx.background() - .spawn(async move { - let mut stdout = smol::io::BufWriter::new(stdout); - while let Some(notification) = outgoing_rx.next().await { - Self::send(&mut stdout, ¬ification).await; - } - }) - .detach(); + let output_task = cx.background().spawn(async move { + let mut stdout = smol::io::BufWriter::new(PipeWriterCloseOnDrop(stdout)); + while let Some(message) = outgoing_rx.next().await { + stdout + .write_all(CONTENT_LEN_HEADER.as_bytes()) + .await + .unwrap(); + stdout + .write_all((format!("{}", message.len())).as_bytes()) + .await + .unwrap(); + stdout.write_all("\r\n\r\n".as_bytes()).await.unwrap(); + stdout.write_all(&message).await.unwrap(); + stdout.flush().await.unwrap(); + } + Ok(()) + }); - this + Self { + outgoing_tx, + incoming_rx, + handlers, + _input_task: input_task, + _output_task: output_task, + } } pub async fn notify(&mut self, params: T::Params) { @@ -665,20 +679,6 @@ impl FakeLanguageServer { .await; } - async fn send(stdout: &mut smol::io::BufWriter, message: &[u8]) { - stdout - .write_all(CONTENT_LEN_HEADER.as_bytes()) - .await - .unwrap(); - stdout - .write_all((format!("{}", message.len())).as_bytes()) - .await - .unwrap(); - stdout.write_all("\r\n\r\n".as_bytes()).await.unwrap(); - stdout.write_all(&message).await.unwrap(); - stdout.flush().await.unwrap(); - } - async fn receive( stdin: &mut smol::io::BufReader, buffer: &mut Vec, @@ -699,6 +699,44 @@ impl FakeLanguageServer { } } +struct PipeWriterCloseOnDrop(async_pipe::PipeWriter); + +impl Drop for PipeWriterCloseOnDrop { + fn drop(&mut self) { + self.0.close().ok(); + } +} + +impl AsyncWrite for PipeWriterCloseOnDrop { + fn poll_write( + mut self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + buf: &[u8], + ) -> std::task::Poll> { + let pipe = &mut self.0; + smol::pin!(pipe); + pipe.poll_write(cx, buf) + } + + fn poll_flush( + mut self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + let pipe = &mut self.0; + smol::pin!(pipe); + pipe.poll_flush(cx) + } + + fn poll_close( + mut self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + let pipe = &mut self.0; + smol::pin!(pipe); + pipe.poll_close(cx) + } +} + #[cfg(test)] mod tests { use super::*;