Use local ids, not remote ids, to identify buffers to copilot

This commit is contained in:
Max Brunsfeld 2023-06-05 14:12:19 -07:00
parent 6b89243902
commit 12dd91c89c

View file

@ -4,7 +4,7 @@ mod sign_in;
use anyhow::{anyhow, Context, Result};
use async_compression::futures::bufread::GzipDecoder;
use async_tar::Archive;
use collections::HashMap;
use collections::{HashMap, HashSet};
use futures::{channel::oneshot, future::Shared, Future, FutureExt, TryFutureExt};
use gpui::{
actions, AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, Task, WeakModelHandle,
@ -127,7 +127,7 @@ impl CopilotServer {
struct RunningCopilotServer {
lsp: Arc<LanguageServer>,
sign_in_status: SignInStatus,
registered_buffers: HashMap<u64, RegisteredBuffer>,
registered_buffers: HashMap<usize, RegisteredBuffer>,
}
#[derive(Clone, Debug)]
@ -163,7 +163,6 @@ impl Status {
}
struct RegisteredBuffer {
id: u64,
uri: lsp::Url,
language_id: String,
snapshot: BufferSnapshot,
@ -178,13 +177,13 @@ impl RegisteredBuffer {
buffer: &ModelHandle<Buffer>,
cx: &mut ModelContext<Copilot>,
) -> oneshot::Receiver<(i32, BufferSnapshot)> {
let id = self.id;
let (done_tx, done_rx) = oneshot::channel();
if buffer.read(cx).version() == self.snapshot.version {
let _ = done_tx.send((self.snapshot_version, self.snapshot.clone()));
} else {
let buffer = buffer.downgrade();
let id = buffer.id();
let prev_pending_change =
mem::replace(&mut self.pending_buffer_change, Task::ready(None));
self.pending_buffer_change = cx.spawn_weak(|copilot, mut cx| async move {
@ -268,7 +267,7 @@ pub struct Copilot {
http: Arc<dyn HttpClient>,
node_runtime: Arc<NodeRuntime>,
server: CopilotServer,
buffers: HashMap<u64, WeakModelHandle<Buffer>>,
buffers: HashSet<WeakModelHandle<Buffer>>,
}
impl Entity for Copilot {
@ -559,8 +558,8 @@ impl Copilot {
}
pub fn register_buffer(&mut self, buffer: &ModelHandle<Buffer>, cx: &mut ModelContext<Self>) {
let buffer_id = buffer.read(cx).remote_id();
self.buffers.insert(buffer_id, buffer.downgrade());
let weak_buffer = buffer.downgrade();
self.buffers.insert(weak_buffer.clone());
if let CopilotServer::Running(RunningCopilotServer {
lsp: server,
@ -573,8 +572,7 @@ impl Copilot {
return;
}
let buffer_id = buffer.read(cx).remote_id();
registered_buffers.entry(buffer_id).or_insert_with(|| {
registered_buffers.entry(buffer.id()).or_insert_with(|| {
let uri: lsp::Url = uri_for_buffer(buffer, cx);
let language_id = id_for_language(buffer.read(cx).language());
let snapshot = buffer.read(cx).snapshot();
@ -592,7 +590,6 @@ impl Copilot {
.log_err();
RegisteredBuffer {
id: buffer_id,
uri,
language_id,
snapshot,
@ -603,8 +600,8 @@ impl Copilot {
this.handle_buffer_event(buffer, event, cx).log_err();
}),
cx.observe_release(buffer, move |this, _buffer, _cx| {
this.buffers.remove(&buffer_id);
this.unregister_buffer(buffer_id);
this.buffers.remove(&weak_buffer);
this.unregister_buffer(&weak_buffer);
}),
],
}
@ -619,8 +616,7 @@ impl Copilot {
cx: &mut ModelContext<Self>,
) -> Result<()> {
if let Ok(server) = self.server.as_running() {
let buffer_id = buffer.read(cx).remote_id();
if let Some(registered_buffer) = server.registered_buffers.get_mut(&buffer_id) {
if let Some(registered_buffer) = server.registered_buffers.get_mut(&buffer.id()) {
match event {
language::Event::Edited => {
let _ = registered_buffer.report_changes(&buffer, cx);
@ -674,9 +670,9 @@ impl Copilot {
Ok(())
}
fn unregister_buffer(&mut self, buffer_id: u64) {
fn unregister_buffer(&mut self, buffer: &WeakModelHandle<Buffer>) {
if let Ok(server) = self.server.as_running() {
if let Some(buffer) = server.registered_buffers.remove(&buffer_id) {
if let Some(buffer) = server.registered_buffers.remove(&buffer.id()) {
server
.lsp
.notify::<lsp::notification::DidCloseTextDocument>(
@ -779,8 +775,7 @@ impl Copilot {
Err(error) => return Task::ready(Err(error)),
};
let lsp = server.lsp.clone();
let buffer_id = buffer.read(cx).remote_id();
let registered_buffer = server.registered_buffers.get_mut(&buffer_id).unwrap();
let registered_buffer = server.registered_buffers.get_mut(&buffer.id()).unwrap();
let snapshot = registered_buffer.report_changes(buffer, cx);
let buffer = buffer.read(cx);
let uri = registered_buffer.uri.clone();
@ -850,7 +845,7 @@ impl Copilot {
lsp_status: request::SignInStatus,
cx: &mut ModelContext<Self>,
) {
self.buffers.retain(|_, buffer| buffer.is_upgradable(cx));
self.buffers.retain(|buffer| buffer.is_upgradable(cx));
if let Ok(server) = self.server.as_running() {
match lsp_status {
@ -858,7 +853,7 @@ impl Copilot {
| request::SignInStatus::MaybeOk { .. }
| request::SignInStatus::AlreadySignedIn { .. } => {
server.sign_in_status = SignInStatus::Authorized;
for buffer in self.buffers.values().cloned().collect::<Vec<_>>() {
for buffer in self.buffers.iter().cloned().collect::<Vec<_>>() {
if let Some(buffer) = buffer.upgrade(cx) {
self.register_buffer(&buffer, cx);
}
@ -866,14 +861,14 @@ impl Copilot {
}
request::SignInStatus::NotAuthorized { .. } => {
server.sign_in_status = SignInStatus::Unauthorized;
for buffer_id in self.buffers.keys().copied().collect::<Vec<_>>() {
self.unregister_buffer(buffer_id);
for buffer in self.buffers.iter().copied().collect::<Vec<_>>() {
self.unregister_buffer(&buffer);
}
}
request::SignInStatus::NotSignedIn => {
server.sign_in_status = SignInStatus::SignedOut;
for buffer_id in self.buffers.keys().copied().collect::<Vec<_>>() {
self.unregister_buffer(buffer_id);
for buffer in self.buffers.iter().copied().collect::<Vec<_>>() {
self.unregister_buffer(&buffer);
}
}
}
@ -896,9 +891,7 @@ fn uri_for_buffer(buffer: &ModelHandle<Buffer>, cx: &AppContext) -> lsp::Url {
if let Some(file) = buffer.read(cx).file().and_then(|file| file.as_local()) {
lsp::Url::from_file_path(file.abs_path(cx)).unwrap()
} else {
format!("buffer://{}", buffer.read(cx).remote_id())
.parse()
.unwrap()
format!("buffer://{}", buffer.id()).parse().unwrap()
}
}