diff --git a/crates/lsp/src/lsp.rs b/crates/lsp/src/lsp.rs index c563c23311..2e2efdb28c 100644 --- a/crates/lsp/src/lsp.rs +++ b/crates/lsp/src/lsp.rs @@ -33,7 +33,7 @@ type ResponseHandler = Box)>; pub struct LanguageServer { next_id: AtomicUsize, - outbound_tx: RwLock>>>, + outbound_tx: channel::Sender>, capabilities: watch::Receiver>, notification_handlers: Arc>>, response_handlers: Arc>>, @@ -213,7 +213,7 @@ impl LanguageServer { response_handlers, capabilities: capabilities_rx, next_id: Default::default(), - outbound_tx: RwLock::new(Some(outbound_tx)), + outbound_tx, executor: executor.clone(), io_tasks: Mutex::new(Some((input_task, output_task))), initialized: initialized_rx, @@ -296,37 +296,41 @@ impl LanguageServer { let request = Self::request_internal::( &this.next_id, &this.response_handlers, - this.outbound_tx.read().as_ref(), + &this.outbound_tx, params, ); let response = request.await?; Self::notify_internal::( - this.outbound_tx.read().as_ref(), + &this.outbound_tx, InitializedParams {}, )?; Ok(response.capabilities) } - pub fn shutdown(&self) -> Option>> { + pub fn shutdown(&self) -> Option>> { if let Some(tasks) = self.io_tasks.lock().take() { let response_handlers = self.response_handlers.clone(); - let outbound_tx = self.outbound_tx.write().take(); let next_id = AtomicUsize::new(self.next_id.load(SeqCst)); + let outbound_tx = self.outbound_tx.clone(); let mut output_done = self.output_done_rx.lock().take().unwrap(); - Some(async move { - Self::request_internal::( - &next_id, - &response_handlers, - outbound_tx.as_ref(), - (), - ) - .await?; - Self::notify_internal::(outbound_tx.as_ref(), ())?; - drop(outbound_tx); - output_done.recv().await; - drop(tasks); - Ok(()) - }) + let shutdown_request = Self::request_internal::( + &next_id, + &response_handlers, + &outbound_tx, + (), + ); + let exit = Self::notify_internal::(&outbound_tx, ()); + outbound_tx.close(); + Some( + async move { + shutdown_request.await?; + exit?; + output_done.recv().await; + drop(tasks); + Ok(()) + } + .log_err(), + ) } else { None } @@ -375,7 +379,7 @@ impl LanguageServer { Self::request_internal::( &this.next_id, &this.response_handlers, - this.outbound_tx.read().as_ref(), + &this.outbound_tx, params, ) .await @@ -385,7 +389,7 @@ impl LanguageServer { fn request_internal( next_id: &AtomicUsize, response_handlers: &Mutex>, - outbound_tx: Option<&channel::Sender>>, + outbound_tx: &channel::Sender>, params: T::Params, ) -> impl 'static + Future> where @@ -415,16 +419,8 @@ impl LanguageServer { ); let send = outbound_tx - .as_ref() - .ok_or_else(|| { - anyhow!("tried to send a request to a language server that has been shut down") - }) - .and_then(|outbound_tx| { - outbound_tx - .try_send(message) - .context("failed to write to language server's stdin")?; - Ok(()) - }); + .try_send(message) + .context("failed to write to language server's stdin"); async move { send?; rx.recv().await.unwrap() @@ -438,13 +434,13 @@ impl LanguageServer { let this = self.clone(); async move { this.initialized.clone().recv().await; - Self::notify_internal::(this.outbound_tx.read().as_ref(), params)?; + Self::notify_internal::(&this.outbound_tx, params)?; Ok(()) } } fn notify_internal( - outbound_tx: Option<&channel::Sender>>, + outbound_tx: &channel::Sender>, params: T::Params, ) -> Result<()> { let message = serde_json::to_vec(&Notification { @@ -453,9 +449,6 @@ impl LanguageServer { params, }) .unwrap(); - let outbound_tx = outbound_tx - .as_ref() - .ok_or_else(|| anyhow!("tried to notify a language server that has been shut down"))?; outbound_tx.try_send(message)?; Ok(()) }