diff --git a/crates/copilot/src/copilot.rs b/crates/copilot/src/copilot.rs index 04e8764f7b..e73424f0cd 100644 --- a/crates/copilot/src/copilot.rs +++ b/crates/copilot/src/copilot.rs @@ -4,7 +4,7 @@ mod sign_in; use anyhow::{anyhow, Context, Result}; use async_compression::futures::bufread::GzipDecoder; use async_tar::Archive; -use collections::HashMap; +use collections::{HashMap, HashSet}; use futures::{channel::oneshot, future::Shared, Future, FutureExt, TryFutureExt}; use gpui::{ actions, AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, Task, WeakModelHandle, @@ -127,7 +127,7 @@ impl CopilotServer { struct RunningCopilotServer { lsp: Arc, sign_in_status: SignInStatus, - registered_buffers: HashMap, + registered_buffers: HashMap, } #[derive(Clone, Debug)] @@ -163,7 +163,6 @@ impl Status { } struct RegisteredBuffer { - id: u64, uri: lsp::Url, language_id: String, snapshot: BufferSnapshot, @@ -178,13 +177,13 @@ impl RegisteredBuffer { buffer: &ModelHandle, cx: &mut ModelContext, ) -> oneshot::Receiver<(i32, BufferSnapshot)> { - let id = self.id; let (done_tx, done_rx) = oneshot::channel(); if buffer.read(cx).version() == self.snapshot.version { let _ = done_tx.send((self.snapshot_version, self.snapshot.clone())); } else { let buffer = buffer.downgrade(); + let id = buffer.id(); let prev_pending_change = mem::replace(&mut self.pending_buffer_change, Task::ready(None)); self.pending_buffer_change = cx.spawn_weak(|copilot, mut cx| async move { @@ -268,7 +267,7 @@ pub struct Copilot { http: Arc, node_runtime: Arc, server: CopilotServer, - buffers: HashMap>, + buffers: HashSet>, } impl Entity for Copilot { @@ -559,8 +558,8 @@ impl Copilot { } pub fn register_buffer(&mut self, buffer: &ModelHandle, cx: &mut ModelContext) { - let buffer_id = buffer.read(cx).remote_id(); - self.buffers.insert(buffer_id, buffer.downgrade()); + let weak_buffer = buffer.downgrade(); + self.buffers.insert(weak_buffer.clone()); if let CopilotServer::Running(RunningCopilotServer { lsp: server, @@ -573,8 +572,7 @@ impl Copilot { return; } - let buffer_id = buffer.read(cx).remote_id(); - registered_buffers.entry(buffer_id).or_insert_with(|| { + registered_buffers.entry(buffer.id()).or_insert_with(|| { let uri: lsp::Url = uri_for_buffer(buffer, cx); let language_id = id_for_language(buffer.read(cx).language()); let snapshot = buffer.read(cx).snapshot(); @@ -592,7 +590,6 @@ impl Copilot { .log_err(); RegisteredBuffer { - id: buffer_id, uri, language_id, snapshot, @@ -603,8 +600,8 @@ impl Copilot { this.handle_buffer_event(buffer, event, cx).log_err(); }), cx.observe_release(buffer, move |this, _buffer, _cx| { - this.buffers.remove(&buffer_id); - this.unregister_buffer(buffer_id); + this.buffers.remove(&weak_buffer); + this.unregister_buffer(&weak_buffer); }), ], } @@ -619,8 +616,7 @@ impl Copilot { cx: &mut ModelContext, ) -> Result<()> { if let Ok(server) = self.server.as_running() { - let buffer_id = buffer.read(cx).remote_id(); - if let Some(registered_buffer) = server.registered_buffers.get_mut(&buffer_id) { + if let Some(registered_buffer) = server.registered_buffers.get_mut(&buffer.id()) { match event { language::Event::Edited => { let _ = registered_buffer.report_changes(&buffer, cx); @@ -674,9 +670,9 @@ impl Copilot { Ok(()) } - fn unregister_buffer(&mut self, buffer_id: u64) { + fn unregister_buffer(&mut self, buffer: &WeakModelHandle) { if let Ok(server) = self.server.as_running() { - if let Some(buffer) = server.registered_buffers.remove(&buffer_id) { + if let Some(buffer) = server.registered_buffers.remove(&buffer.id()) { server .lsp .notify::( @@ -779,8 +775,7 @@ impl Copilot { Err(error) => return Task::ready(Err(error)), }; let lsp = server.lsp.clone(); - let buffer_id = buffer.read(cx).remote_id(); - let registered_buffer = server.registered_buffers.get_mut(&buffer_id).unwrap(); + let registered_buffer = server.registered_buffers.get_mut(&buffer.id()).unwrap(); let snapshot = registered_buffer.report_changes(buffer, cx); let buffer = buffer.read(cx); let uri = registered_buffer.uri.clone(); @@ -850,7 +845,7 @@ impl Copilot { lsp_status: request::SignInStatus, cx: &mut ModelContext, ) { - self.buffers.retain(|_, buffer| buffer.is_upgradable(cx)); + self.buffers.retain(|buffer| buffer.is_upgradable(cx)); if let Ok(server) = self.server.as_running() { match lsp_status { @@ -858,7 +853,7 @@ impl Copilot { | request::SignInStatus::MaybeOk { .. } | request::SignInStatus::AlreadySignedIn { .. } => { server.sign_in_status = SignInStatus::Authorized; - for buffer in self.buffers.values().cloned().collect::>() { + for buffer in self.buffers.iter().cloned().collect::>() { if let Some(buffer) = buffer.upgrade(cx) { self.register_buffer(&buffer, cx); } @@ -866,14 +861,14 @@ impl Copilot { } request::SignInStatus::NotAuthorized { .. } => { server.sign_in_status = SignInStatus::Unauthorized; - for buffer_id in self.buffers.keys().copied().collect::>() { - self.unregister_buffer(buffer_id); + for buffer in self.buffers.iter().copied().collect::>() { + self.unregister_buffer(&buffer); } } request::SignInStatus::NotSignedIn => { server.sign_in_status = SignInStatus::SignedOut; - for buffer_id in self.buffers.keys().copied().collect::>() { - self.unregister_buffer(buffer_id); + for buffer in self.buffers.iter().copied().collect::>() { + self.unregister_buffer(&buffer); } } } @@ -896,9 +891,7 @@ fn uri_for_buffer(buffer: &ModelHandle, cx: &AppContext) -> lsp::Url { if let Some(file) = buffer.read(cx).file().and_then(|file| file.as_local()) { lsp::Url::from_file_path(file.abs_path(cx)).unwrap() } else { - format!("buffer://{}", buffer.read(cx).remote_id()) - .parse() - .unwrap() + format!("buffer://{}", buffer.id()).parse().unwrap() } }