use anyhow::{anyhow, Context, Result}; use gpui::{executor, AppContext, Task}; use parking_lot::Mutex; use postage::{barrier, prelude::Stream}; use serde::{Deserialize, Serialize}; use serde_json::value::RawValue; use smol::{ channel, io::{AsyncBufReadExt, AsyncReadExt, AsyncWriteExt, BufReader}, process::Command, }; use std::{ collections::HashMap, future::Future, io::Write, sync::{ atomic::{AtomicUsize, Ordering::SeqCst}, Arc, }, }; use std::{path::Path, process::Stdio}; use util::TryFutureExt; const JSON_RPC_VERSION: &'static str = "2.0"; const CONTENT_LEN_HEADER: &'static str = "Content-Length: "; pub struct LanguageServer { next_id: AtomicUsize, outbound_tx: channel::Sender>, response_handlers: Arc>>, _input_task: Task>, _output_task: Task>, initialized: barrier::Receiver, } type ResponseHandler = Box)>; #[derive(Serialize)] struct Request { jsonrpc: &'static str, id: usize, method: &'static str, params: T, } #[derive(Deserialize)] struct Response<'a> { id: usize, #[serde(default)] error: Option, #[serde(default, borrow)] result: Option<&'a RawValue>, } #[derive(Serialize)] struct OutboundNotification { jsonrpc: &'static str, method: &'static str, params: T, } #[derive(Deserialize)] struct InboundNotification<'a> { #[serde(borrow)] method: &'a str, #[serde(borrow)] params: &'a RawValue, } #[derive(Deserialize)] struct Error { message: String, } impl LanguageServer { pub fn rust(cx: &AppContext) -> Result> { const BUNDLE: Option<&'static str> = option_env!("BUNDLE"); const TARGET: &'static str = env!("TARGET"); let rust_analyzer_name = format!("rust-analyzer-{}", TARGET); if BUNDLE.map_or(Ok(false), |b| b.parse())? { let rust_analyzer_path = cx .platform() .path_for_resource(Some(&rust_analyzer_name), None)?; Self::new(&rust_analyzer_path, cx.background()) } else { Self::new(Path::new(&rust_analyzer_name), cx.background()) } } pub fn new(path: &Path, background: &executor::Background) -> Result> { let mut server = Command::new(path) .stdin(Stdio::piped()) .stdout(Stdio::piped()) .stderr(Stdio::inherit()) .spawn()?; let mut stdin = server.stdin.take().unwrap(); let mut stdout = BufReader::new(server.stdout.take().unwrap()); let (outbound_tx, outbound_rx) = channel::unbounded::>(); let response_handlers = Arc::new(Mutex::new(HashMap::::new())); let _input_task = background.spawn( { let response_handlers = response_handlers.clone(); async move { let mut buffer = Vec::new(); loop { buffer.clear(); stdout.read_until(b'\n', &mut buffer).await?; stdout.read_until(b'\n', &mut buffer).await?; let message_len: usize = std::str::from_utf8(&buffer)? .strip_prefix(CONTENT_LEN_HEADER) .ok_or_else(|| anyhow!("invalid header"))? .trim_end() .parse()?; buffer.resize(message_len, 0); stdout.read_exact(&mut buffer).await?; if let Ok(InboundNotification { .. }) = serde_json::from_slice(&buffer) { } else if let Ok(Response { id, error, result }) = serde_json::from_slice(&buffer) { if let Some(handler) = response_handlers.lock().remove(&id) { if let Some(result) = result { handler(Ok(result.get())); } else if let Some(error) = error { handler(Err(error)); } } } else { return Err(anyhow!( "failed to deserialize message:\n{}", std::str::from_utf8(&buffer)? )); } } } } .log_err(), ); let _output_task = background.spawn( async move { let mut content_len_buffer = Vec::new(); loop { let message = outbound_rx.recv().await?; write!(content_len_buffer, "{}", message.len()).unwrap(); stdin.write_all(CONTENT_LEN_HEADER.as_bytes()).await?; stdin.write_all(&content_len_buffer).await?; stdin.write_all("\r\n\r\n".as_bytes()).await?; stdin.write_all(&message).await?; } } .log_err(), ); let (initialized_tx, initialized_rx) = barrier::channel(); let this = Arc::new(Self { response_handlers, next_id: Default::default(), outbound_tx, _input_task, _output_task, initialized: initialized_rx, }); background .spawn({ let this = this.clone(); async move { this.init().log_err().await; drop(initialized_tx); } }) .detach(); Ok(this) } async fn init(self: Arc) -> Result<()> { let res = self .request_internal::( lsp_types::InitializeParams { process_id: Default::default(), root_path: Default::default(), root_uri: Default::default(), initialization_options: Default::default(), capabilities: Default::default(), trace: Default::default(), workspace_folders: Default::default(), client_info: Default::default(), locale: Default::default(), }, false, ) .await?; self.notify_internal::( lsp_types::InitializedParams {}, false, ) .await?; Ok(()) } pub fn request( self: &Arc, params: T::Params, ) -> impl Future> where T::Result: 'static + Send, { self.request_internal::(params, true) } fn request_internal( self: &Arc, params: T::Params, wait_for_initialization: bool, ) -> impl Future> where T::Result: 'static + Send, { let id = self.next_id.fetch_add(1, SeqCst); let message = serde_json::to_vec(&Request { jsonrpc: JSON_RPC_VERSION, id, method: T::METHOD, params, }) .unwrap(); let mut response_handlers = self.response_handlers.lock(); let (tx, rx) = smol::channel::bounded(1); response_handlers.insert( id, Box::new(move |result| { let response = match result { Ok(response) => { serde_json::from_str(response).context("failed to deserialize response") } Err(error) => Err(anyhow!("{}", error.message)), }; let _ = smol::block_on(tx.send(response)); }), ); let this = self.clone(); async move { if wait_for_initialization { this.initialized.clone().recv().await; } this.outbound_tx.send(message).await?; rx.recv().await? } } pub fn notify( self: &Arc, params: T::Params, ) -> impl Future> { self.notify_internal::(params, true) } fn notify_internal( self: &Arc, params: T::Params, wait_for_initialization: bool, ) -> impl Future> { let message = serde_json::to_vec(&OutboundNotification { jsonrpc: JSON_RPC_VERSION, method: T::METHOD, params, }) .unwrap(); let this = self.clone(); async move { if wait_for_initialization { this.initialized.clone().recv().await; } this.outbound_tx.send(message).await?; Ok(()) } } } #[cfg(test)] mod tests { use super::*; use gpui::TestAppContext; #[gpui::test] async fn test_basic(cx: TestAppContext) { let server = cx.read(|cx| LanguageServer::rust(cx).unwrap()); } }