diff --git a/crates/copilot/src/copilot.rs b/crates/copilot/src/copilot.rs index eafc4d2d98..9a24139ad6 100644 --- a/crates/copilot/src/copilot.rs +++ b/crates/copilot/src/copilot.rs @@ -4,7 +4,7 @@ use anyhow::{anyhow, Result}; use async_compression::futures::bufread::GzipDecoder; use client::Client; use gpui::{actions, AppContext, Entity, ModelContext, ModelHandle, MutableAppContext, Task}; -use language::{point_to_lsp, Buffer, ToPointUtf16}; +use language::{point_from_lsp, point_to_lsp, Anchor, Bias, Buffer, BufferSnapshot, ToPointUtf16}; use lsp::LanguageServer; use settings::Settings; use smol::{fs, io::BufReader, stream::StreamExt}; @@ -77,6 +77,12 @@ impl Status { } } +#[derive(Debug)] +pub struct Completion { + pub position: Anchor, + pub text: String, +} + struct Copilot { server: CopilotServer, } @@ -186,12 +192,12 @@ impl Copilot { } } - pub fn completions( + pub fn completion( &self, buffer: &ModelHandle, position: T, cx: &mut ModelContext, - ) -> Task> + ) -> Task>> where T: ToPointUtf16, { @@ -201,43 +207,45 @@ impl Copilot { }; let buffer = buffer.read(cx).snapshot(); - let position = position.to_point_utf16(&buffer); - let language_name = buffer.language_at(position).map(|language| language.name()); - let language_name = language_name.as_deref(); + let request = server + .request::(build_completion_params(&buffer, position, cx)); + cx.background().spawn(async move { + let result = request.await?; + let completion = result + .completions + .into_iter() + .next() + .map(|completion| completion_from_lsp(completion, &buffer)); + anyhow::Ok(completion) + }) + } - let path; - let relative_path; - if let Some(file) = buffer.file() { - if let Some(file) = file.as_local() { - path = file.abs_path(cx); - } else { - path = file.full_path(cx); - } - relative_path = file.path().to_path_buf(); - } else { - path = PathBuf::from("/untitled"); - relative_path = PathBuf::from("untitled"); - } + pub fn completions_cycling( + &self, + buffer: &ModelHandle, + position: T, + cx: &mut ModelContext, + ) -> Task>> + where + T: ToPointUtf16, + { + let server = match self.authenticated_server() { + Ok(server) => server, + Err(error) => return Task::ready(Err(error)), + }; - let settings = cx.global::(); - let request = server.request::(request::GetCompletionsParams { - doc: request::GetCompletionsDocument { - source: buffer.text(), - tab_size: settings.tab_size(language_name).into(), - indent_size: 1, - insert_spaces: !settings.hard_tabs(language_name), - uri: lsp::Url::from_file_path(&path).unwrap(), - path: path.to_string_lossy().into(), - relative_path: relative_path.to_string_lossy().into(), - language_id: "csharp".into(), - position: point_to_lsp(position), - version: 0, - }, - }); - cx.spawn(|this, cx| async move { - dbg!(request.await?); - - anyhow::Ok(()) + let buffer = buffer.read(cx).snapshot(); + let request = server.request::(build_completion_params( + &buffer, position, cx, + )); + cx.background().spawn(async move { + let result = request.await?; + let completions = result + .completions + .into_iter() + .map(|completion| completion_from_lsp(completion, &buffer)) + .collect(); + anyhow::Ok(completions) }) } @@ -290,6 +298,62 @@ impl Copilot { } } +fn build_completion_params( + buffer: &BufferSnapshot, + position: T, + cx: &AppContext, +) -> request::GetCompletionsParams +where + T: ToPointUtf16, +{ + let position = position.to_point_utf16(&buffer); + let language_name = buffer.language_at(position).map(|language| language.name()); + let language_name = language_name.as_deref(); + + let path; + let relative_path; + if let Some(file) = buffer.file() { + if let Some(file) = file.as_local() { + path = file.abs_path(cx); + } else { + path = file.full_path(cx); + } + relative_path = file.path().to_path_buf(); + } else { + path = PathBuf::from("/untitled"); + relative_path = PathBuf::from("untitled"); + } + + let settings = cx.global::(); + let language_id = match language_name { + Some("Plain Text") => "plaintext".to_string(), + Some(language_name) => language_name.to_lowercase(), + None => "plaintext".to_string(), + }; + request::GetCompletionsParams { + doc: request::GetCompletionsDocument { + source: buffer.text(), + tab_size: settings.tab_size(language_name).into(), + indent_size: 1, + insert_spaces: !settings.hard_tabs(language_name), + uri: lsp::Url::from_file_path(&path).unwrap(), + path: path.to_string_lossy().into(), + relative_path: relative_path.to_string_lossy().into(), + language_id, + position: point_to_lsp(position), + version: 0, + }, + } +} + +fn completion_from_lsp(completion: request::Completion, buffer: &BufferSnapshot) -> Completion { + let position = buffer.clip_point_utf16(point_from_lsp(completion.position), Bias::Left); + Completion { + position: buffer.anchor_before(position), + text: completion.display_text, + } +} + async fn get_lsp_binary(http: Arc) -> anyhow::Result { ///Check for the latest copilot language server and download it if we haven't already async fn fetch_latest(http: Arc) -> anyhow::Result { @@ -354,17 +418,22 @@ mod tests { Settings::test_async(cx); let http = http::client(); let copilot = cx.add_model(|cx| Copilot::start(http, cx)); - smol::Timer::after(std::time::Duration::from_secs(5)).await; + smol::Timer::after(std::time::Duration::from_secs(2)).await; copilot .update(cx, |copilot, cx| copilot.sign_in(cx)) .await .unwrap(); dbg!(copilot.read_with(cx, |copilot, _| copilot.status())); - let buffer = cx.add_model(|cx| language::Buffer::new(0, "Lorem ipsum dol", cx)); - copilot - .update(cx, |copilot, cx| copilot.completions(&buffer, 15, cx)) + let buffer = cx.add_model(|cx| language::Buffer::new(0, "fn foo() -> ", cx)); + dbg!(copilot + .update(cx, |copilot, cx| copilot.completion(&buffer, 12, cx)) .await - .unwrap(); + .unwrap()); + dbg!(copilot + .update(cx, |copilot, cx| copilot + .completions_cycling(&buffer, 12, cx)) + .await + .unwrap()); } } diff --git a/crates/copilot/src/request.rs b/crates/copilot/src/request.rs index 3fe04532e1..f3a86698e1 100644 --- a/crates/copilot/src/request.rs +++ b/crates/copilot/src/request.rs @@ -114,17 +114,17 @@ pub struct GetCompletionsDocument { #[derive(Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct GetCompletionsResult { - completions: Vec, + pub completions: Vec, } #[derive(Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct Completion { - text: String, - position: lsp::Position, - uuid: String, - range: lsp::Range, - display_text: String, + pub text: String, + pub position: lsp::Position, + pub uuid: String, + pub range: lsp::Range, + pub display_text: String, } impl lsp::request::Request for GetCompletions { @@ -132,3 +132,11 @@ impl lsp::request::Request for GetCompletions { type Result = GetCompletionsResult; const METHOD: &'static str = "getCompletions"; } + +pub enum GetCompletionsCycling {} + +impl lsp::request::Request for GetCompletionsCycling { + type Params = GetCompletionsParams; + type Result = GetCompletionsResult; + const METHOD: &'static str = "getCompletionsCycling"; +}