Clean up tasks properly when dropping a FakeLanguageServer

* Make sure the fake's IO tasks are stopped
* Ensure that the fake's stdout is closed, so that the corresponding language
  server's IO tasks are woken up and halted.
This commit is contained in:
Max Brunsfeld 2022-03-01 13:26:59 -08:00
parent 0e6686916c
commit 74469a46ba

View file

@ -476,17 +476,21 @@ impl Drop for Subscription {
#[cfg(any(test, feature = "test-support"))] #[cfg(any(test, feature = "test-support"))]
pub struct FakeLanguageServer { pub struct FakeLanguageServer {
handlers: Arc< handlers: FakeLanguageServerHandlers,
outgoing_tx: futures::channel::mpsc::UnboundedSender<Vec<u8>>,
incoming_rx: futures::channel::mpsc::UnboundedReceiver<Vec<u8>>,
_input_task: Task<Result<()>>,
_output_task: Task<Result<()>>,
}
type FakeLanguageServerHandlers = Arc<
Mutex< Mutex<
HashMap< HashMap<
&'static str, &'static str,
Box<dyn Send + FnMut(usize, &[u8], gpui::AsyncAppContext) -> Vec<u8>>, Box<dyn Send + FnMut(usize, &[u8], gpui::AsyncAppContext) -> Vec<u8>>,
>, >,
>, >,
>, >;
outgoing_tx: futures::channel::mpsc::UnboundedSender<Vec<u8>>,
incoming_rx: futures::channel::mpsc::UnboundedReceiver<Vec<u8>>,
}
#[cfg(any(test, feature = "test-support"))] #[cfg(any(test, feature = "test-support"))]
impl LanguageServer { impl LanguageServer {
@ -533,15 +537,12 @@ impl FakeLanguageServer {
let (incoming_tx, incoming_rx) = futures::channel::mpsc::unbounded(); let (incoming_tx, incoming_rx) = futures::channel::mpsc::unbounded();
let (outgoing_tx, mut outgoing_rx) = futures::channel::mpsc::unbounded(); let (outgoing_tx, mut outgoing_rx) = futures::channel::mpsc::unbounded();
let this = Self { let handlers = FakeLanguageServerHandlers::default();
outgoing_tx: outgoing_tx.clone(),
incoming_rx,
handlers: Default::default(),
};
// Receive incoming messages let input_task = cx.spawn(|cx| {
let handlers = this.handlers.clone(); let handlers = handlers.clone();
cx.spawn(|cx| async move { let outgoing_tx = outgoing_tx.clone();
async move {
let mut buffer = Vec::new(); let mut buffer = Vec::new();
let mut stdin = smol::io::BufReader::new(stdin); let mut stdin = smol::io::BufReader::new(stdin);
while Self::receive(&mut stdin, &mut buffer).await.is_ok() { while Self::receive(&mut stdin, &mut buffer).await.is_ok() {
@ -549,43 +550,56 @@ impl FakeLanguageServer {
if let Ok(request) = serde_json::from_slice::<AnyRequest>(&buffer) { if let Ok(request) = serde_json::from_slice::<AnyRequest>(&buffer) {
assert_eq!(request.jsonrpc, JSON_RPC_VERSION); assert_eq!(request.jsonrpc, JSON_RPC_VERSION);
let response;
if let Some(handler) = handlers.lock().get_mut(request.method) { if let Some(handler) = handlers.lock().get_mut(request.method) {
let response = response =
handler(request.id, request.params.get().as_bytes(), cx.clone()); handler(request.id, request.params.get().as_bytes(), cx.clone());
log::debug!("handled lsp request. method:{}", request.method); log::debug!("handled lsp request. method:{}", request.method);
outgoing_tx.unbounded_send(response)?;
} else { } else {
log::debug!("unhandled lsp request. method:{}", request.method); response = serde_json::to_vec(&AnyResponse {
outgoing_tx.unbounded_send(
serde_json::to_vec(&AnyResponse {
id: request.id, id: request.id,
error: Some(Error { error: Some(Error {
message: "no handler".to_string(), message: "no handler".to_string(),
}), }),
result: None, result: None,
}) })
.unwrap(), .unwrap();
)?; log::debug!("unhandled lsp request. method:{}", request.method);
} }
outgoing_tx.unbounded_send(response)?;
} else { } else {
incoming_tx.unbounded_send(buffer.clone())?; incoming_tx.unbounded_send(buffer.clone())?;
} }
} }
Ok::<_, anyhow::Error>(()) Ok::<_, anyhow::Error>(())
})
.detach();
// Send outgoing messages
cx.background()
.spawn(async move {
let mut stdout = smol::io::BufWriter::new(stdout);
while let Some(notification) = outgoing_rx.next().await {
Self::send(&mut stdout, &notification).await;
} }
}) });
.detach();
this let output_task = cx.background().spawn(async move {
let mut stdout = smol::io::BufWriter::new(PipeWriterCloseOnDrop(stdout));
while let Some(message) = outgoing_rx.next().await {
stdout
.write_all(CONTENT_LEN_HEADER.as_bytes())
.await
.unwrap();
stdout
.write_all((format!("{}", message.len())).as_bytes())
.await
.unwrap();
stdout.write_all("\r\n\r\n".as_bytes()).await.unwrap();
stdout.write_all(&message).await.unwrap();
stdout.flush().await.unwrap();
}
Ok(())
});
Self {
outgoing_tx,
incoming_rx,
handlers,
_input_task: input_task,
_output_task: output_task,
}
} }
pub async fn notify<T: notification::Notification>(&mut self, params: T::Params) { pub async fn notify<T: notification::Notification>(&mut self, params: T::Params) {
@ -665,20 +679,6 @@ impl FakeLanguageServer {
.await; .await;
} }
async fn send(stdout: &mut smol::io::BufWriter<async_pipe::PipeWriter>, message: &[u8]) {
stdout
.write_all(CONTENT_LEN_HEADER.as_bytes())
.await
.unwrap();
stdout
.write_all((format!("{}", message.len())).as_bytes())
.await
.unwrap();
stdout.write_all("\r\n\r\n".as_bytes()).await.unwrap();
stdout.write_all(&message).await.unwrap();
stdout.flush().await.unwrap();
}
async fn receive( async fn receive(
stdin: &mut smol::io::BufReader<async_pipe::PipeReader>, stdin: &mut smol::io::BufReader<async_pipe::PipeReader>,
buffer: &mut Vec<u8>, buffer: &mut Vec<u8>,
@ -699,6 +699,44 @@ impl FakeLanguageServer {
} }
} }
struct PipeWriterCloseOnDrop(async_pipe::PipeWriter);
impl Drop for PipeWriterCloseOnDrop {
fn drop(&mut self) {
self.0.close().ok();
}
}
impl AsyncWrite for PipeWriterCloseOnDrop {
fn poll_write(
mut self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
buf: &[u8],
) -> std::task::Poll<std::io::Result<usize>> {
let pipe = &mut self.0;
smol::pin!(pipe);
pipe.poll_write(cx, buf)
}
fn poll_flush(
mut self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<std::io::Result<()>> {
let pipe = &mut self.0;
smol::pin!(pipe);
pipe.poll_flush(cx)
}
fn poll_close(
mut self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<std::io::Result<()>> {
let pipe = &mut self.0;
smol::pin!(pipe);
pipe.poll_close(cx)
}
}
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;