diff --git a/Cargo.lock b/Cargo.lock index 18539ffe0a..8f884a185c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2572,6 +2572,7 @@ dependencies = [ "clock", "collab_ui", "collections", + "context_servers", "ctor", "dashmap 6.0.1", "derive_more", @@ -2818,7 +2819,6 @@ name = "context_servers" version = "0.1.0" dependencies = [ "anyhow", - "async-trait", "collections", "command_palette_hooks", "futures 0.3.30", @@ -4205,7 +4205,6 @@ dependencies = [ "assistant_slash_command", "async-compression", "async-tar", - "async-trait", "client", "collections", "context_servers", @@ -4222,6 +4221,7 @@ dependencies = [ "http_client", "indexed_docs", "language", + "log", "lsp", "node_runtime", "num-format", diff --git a/crates/assistant/src/context_store.rs b/crates/assistant/src/context_store.rs index b82e252bc2..568b04e492 100644 --- a/crates/assistant/src/context_store.rs +++ b/crates/assistant/src/context_store.rs @@ -8,9 +8,8 @@ use anyhow::{anyhow, Context as _, Result}; use client::{proto, telemetry::Telemetry, Client, TypedEnvelope}; use clock::ReplicaId; use collections::HashMap; -use command_palette_hooks::CommandPaletteFilter; -use context_servers::manager::{ContextServerManager, ContextServerSettings}; -use context_servers::{ContextServerFactoryRegistry, CONTEXT_SERVERS_NAMESPACE}; +use context_servers::manager::ContextServerManager; +use context_servers::ContextServerFactoryRegistry; use fs::Fs; use futures::StreamExt; use fuzzy::StringMatchCandidate; @@ -22,7 +21,6 @@ use paths::contexts_dir; use project::Project; use regex::Regex; use rpc::AnyProtoClient; -use settings::{Settings as _, SettingsStore}; use std::{ cmp::Reverse, ffi::OsStr, @@ -111,7 +109,11 @@ impl ContextStore { let (mut events, _) = fs.watch(contexts_dir(), CONTEXT_WATCH_DURATION).await; let this = cx.new_model(|cx: &mut ModelContext| { - let context_server_manager = cx.new_model(|_cx| ContextServerManager::new()); + let context_server_factory_registry = + ContextServerFactoryRegistry::default_global(cx); + let context_server_manager = cx.new_model(|cx| { + ContextServerManager::new(context_server_factory_registry, project.clone(), cx) + }); let mut this = Self { contexts: Vec::new(), contexts_metadata: Vec::new(), @@ -148,91 +150,16 @@ impl ContextStore { this.handle_project_changed(project.clone(), cx); this.synchronize_contexts(cx); this.register_context_server_handlers(cx); - - if project.read(cx).is_local() { - // TODO: At the time when we construct the `ContextStore` we may not have yet initialized the extensions. - // In order to register the context servers when the extension is loaded, we're periodically looping to - // see if there are context servers to register. - // - // I tried doing this in a subscription on the `ExtensionStore`, but it never seemed to fire. - // - // We should find a more elegant way to do this. - let context_server_factory_registry = - ContextServerFactoryRegistry::default_global(cx); - cx.spawn(|context_store, mut cx| async move { - loop { - let mut servers_to_register = Vec::new(); - for (_id, factory) in - context_server_factory_registry.context_server_factories() - { - if let Some(server) = factory(project.clone(), &cx).await.log_err() - { - servers_to_register.push(server); - } - } - - let Some(_) = context_store - .update(&mut cx, |this, cx| { - this.context_server_manager.update(cx, |this, cx| { - for server in servers_to_register { - this.add_server(server, cx).detach_and_log_err(cx); - } - }) - }) - .log_err() - else { - break; - }; - - smol::Timer::after(Duration::from_millis(100)).await; - } - - anyhow::Ok(()) - }) - .detach_and_log_err(cx); - } - this })?; this.update(&mut cx, |this, cx| this.reload(cx))? .await .log_err(); - this.update(&mut cx, |this, cx| { - this.watch_context_server_settings(cx); - }) - .log_err(); - Ok(this) }) } - fn watch_context_server_settings(&self, cx: &mut ModelContext) { - cx.observe_global::(move |this, cx| { - this.context_server_manager.update(cx, |manager, cx| { - let location = this.project.read(cx).worktrees(cx).next().map(|worktree| { - settings::SettingsLocation { - worktree_id: worktree.read(cx).id(), - path: Path::new(""), - } - }); - let settings = ContextServerSettings::get(location, cx); - - manager.maintain_servers(settings, cx); - - let has_any_context_servers = !manager.servers().is_empty(); - CommandPaletteFilter::update_global(cx, |filter, _cx| { - if has_any_context_servers { - filter.show_namespace(CONTEXT_SERVERS_NAMESPACE); - } else { - filter.hide_namespace(CONTEXT_SERVERS_NAMESPACE); - } - }); - }) - }) - .detach(); - } - async fn handle_advertise_contexts( this: Model, envelope: TypedEnvelope, diff --git a/crates/assistant/src/slash_command/context_server_command.rs b/crates/assistant/src/slash_command/context_server_command.rs index 843c2081a7..672af37115 100644 --- a/crates/assistant/src/slash_command/context_server_command.rs +++ b/crates/assistant/src/slash_command/context_server_command.rs @@ -27,7 +27,7 @@ pub struct ContextServerSlashCommand { impl ContextServerSlashCommand { pub fn new( server_manager: Model, - server: &Arc, + server: &Arc, prompt: Prompt, ) -> Self { Self { diff --git a/crates/collab/Cargo.toml b/crates/collab/Cargo.toml index 417353e39d..c5282689f4 100644 --- a/crates/collab/Cargo.toml +++ b/crates/collab/Cargo.toml @@ -78,6 +78,7 @@ uuid.workspace = true [dev-dependencies] assistant = { workspace = true, features = ["test-support"] } +context_servers.workspace = true async-trait.workspace = true audio.workspace = true call = { workspace = true, features = ["test-support"] } diff --git a/crates/collab/src/tests/integration_tests.rs b/crates/collab/src/tests/integration_tests.rs index 751469a700..5ec9a574a1 100644 --- a/crates/collab/src/tests/integration_tests.rs +++ b/crates/collab/src/tests/integration_tests.rs @@ -6486,6 +6486,8 @@ async fn test_context_collaboration_with_reconnect( assert_eq!(project.collaborators().len(), 1); }); + cx_a.update(context_servers::init); + cx_b.update(context_servers::init); let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap()); let context_store_a = cx_a .update(|cx| { diff --git a/crates/command_palette_hooks/src/command_palette_hooks.rs b/crates/command_palette_hooks/src/command_palette_hooks.rs index 108708c258..c1a61e287c 100644 --- a/crates/command_palette_hooks/src/command_palette_hooks.rs +++ b/crates/command_palette_hooks/src/command_palette_hooks.rs @@ -39,11 +39,13 @@ impl CommandPaletteFilter { } /// Updates the global [`CommandPaletteFilter`] using the given closure. - pub fn update_global(cx: &mut AppContext, update: F) -> R + pub fn update_global(cx: &mut AppContext, update: F) where - F: FnOnce(&mut Self, &mut AppContext) -> R, + F: FnOnce(&mut Self, &mut AppContext), { - cx.update_global(|this: &mut GlobalCommandPaletteFilter, cx| update(&mut this.0, cx)) + if cx.has_global::() { + cx.update_global(|this: &mut GlobalCommandPaletteFilter, cx| update(&mut this.0, cx)) + } } /// Returns whether the given [`Action`] is hidden by the filter. diff --git a/crates/context_servers/Cargo.toml b/crates/context_servers/Cargo.toml index c2453748c3..de1e991887 100644 --- a/crates/context_servers/Cargo.toml +++ b/crates/context_servers/Cargo.toml @@ -13,7 +13,6 @@ path = "src/context_servers.rs" [dependencies] anyhow.workspace = true -async-trait.workspace = true collections.workspace = true command_palette_hooks.workspace = true futures.workspace = true diff --git a/crates/context_servers/src/context_servers.rs b/crates/context_servers/src/context_servers.rs index af7d7f6218..87a98ca14f 100644 --- a/crates/context_servers/src/context_servers.rs +++ b/crates/context_servers/src/context_servers.rs @@ -8,7 +8,6 @@ use command_palette_hooks::CommandPaletteFilter; use gpui::{actions, AppContext}; use settings::Settings; -pub use crate::manager::ContextServer; use crate::manager::ContextServerSettings; pub use crate::registry::ContextServerFactoryRegistry; diff --git a/crates/context_servers/src/manager.rs b/crates/context_servers/src/manager.rs index ec1891463c..fc0c77e821 100644 --- a/crates/context_servers/src/manager.rs +++ b/crates/context_servers/src/manager.rs @@ -15,23 +15,23 @@ //! and react to changes in settings. use std::path::Path; -use std::pin::Pin; use std::sync::Arc; use anyhow::{bail, Result}; -use async_trait::async_trait; -use collections::{HashMap, HashSet}; -use futures::{Future, FutureExt}; -use gpui::{AsyncAppContext, EventEmitter, ModelContext, Task}; +use collections::HashMap; +use command_palette_hooks::CommandPaletteFilter; +use gpui::{AsyncAppContext, EventEmitter, Model, ModelContext, Subscription, Task, WeakModel}; use log; use parking_lot::RwLock; +use project::Project; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; -use settings::{Settings, SettingsSources}; +use settings::{Settings, SettingsSources, SettingsStore}; +use util::ResultExt as _; use crate::{ client::{self, Client}, - types, + types, ContextServerFactoryRegistry, CONTEXT_SERVERS_NAMESPACE, }; #[derive(Deserialize, Serialize, Default, Clone, PartialEq, Eq, JsonSchema, Debug)] @@ -66,25 +66,13 @@ impl Settings for ContextServerSettings { } } -#[async_trait(?Send)] -pub trait ContextServer: Send + Sync + 'static { - fn id(&self) -> Arc; - fn config(&self) -> Arc; - fn client(&self) -> Option>; - fn start<'a>( - self: Arc, - cx: &'a AsyncAppContext, - ) -> Pin>>>; - fn stop(&self) -> Result<()>; -} - -pub struct NativeContextServer { +pub struct ContextServer { pub id: Arc, pub config: Arc, pub client: RwLock>>, } -impl NativeContextServer { +impl ContextServer { pub fn new(id: Arc, config: Arc) -> Self { Self { id, @@ -92,61 +80,52 @@ impl NativeContextServer { client: RwLock::new(None), } } -} -#[async_trait(?Send)] -impl ContextServer for NativeContextServer { - fn id(&self) -> Arc { + pub fn id(&self) -> Arc { self.id.clone() } - fn config(&self) -> Arc { + pub fn config(&self) -> Arc { self.config.clone() } - fn client(&self) -> Option> { + pub fn client(&self) -> Option> { self.client.read().clone() } - fn start<'a>( - self: Arc, - cx: &'a AsyncAppContext, - ) -> Pin>>> { - async move { - log::info!("starting context server {}", self.id); - let Some(command) = &self.config.command else { - bail!("no command specified for server {}", self.id); - }; - let client = Client::new( - client::ContextServerId(self.id.clone()), - client::ModelContextServerBinary { - executable: Path::new(&command.path).to_path_buf(), - args: command.args.clone(), - env: command.env.clone(), - }, - cx.clone(), - )?; + pub async fn start(self: Arc, cx: &AsyncAppContext) -> Result<()> { + log::info!("starting context server {}", self.id); + let Some(command) = &self.config.command else { + bail!("no command specified for server {}", self.id); + }; + let client = Client::new( + client::ContextServerId(self.id.clone()), + client::ModelContextServerBinary { + executable: Path::new(&command.path).to_path_buf(), + args: command.args.clone(), + env: command.env.clone(), + }, + cx.clone(), + )?; - let protocol = crate::protocol::ModelContextProtocol::new(client); - let client_info = types::Implementation { - name: "Zed".to_string(), - version: env!("CARGO_PKG_VERSION").to_string(), - }; - let initialized_protocol = protocol.initialize(client_info).await?; + let protocol = crate::protocol::ModelContextProtocol::new(client); + let client_info = types::Implementation { + name: "Zed".to_string(), + version: env!("CARGO_PKG_VERSION").to_string(), + }; + let initialized_protocol = protocol.initialize(client_info).await?; - log::debug!( - "context server {} initialized: {:?}", - self.id, - initialized_protocol.initialize, - ); + log::debug!( + "context server {} initialized: {:?}", + self.id, + initialized_protocol.initialize, + ); - *self.client.write() = Some(Arc::new(initialized_protocol)); - Ok(()) - } - .boxed_local() + *self.client.write() = Some(Arc::new(initialized_protocol)); + Ok(()) } - fn stop(&self) -> Result<()> { + pub fn stop(&self) -> Result<()> { let mut client = self.client.write(); if let Some(protocol) = client.take() { drop(protocol); @@ -155,13 +134,13 @@ impl ContextServer for NativeContextServer { } } -/// A Context server manager manages the starting and stopping -/// of all servers. To obtain a server to interact with, a crate -/// must go through the `GlobalContextServerManager` which holds -/// a model to the ContextServerManager. pub struct ContextServerManager { - servers: HashMap, Arc>, - pending_servers: HashSet>, + servers: HashMap, Arc>, + project: Model, + registry: Model, + update_servers_task: Option>>, + needs_server_update: bool, + _subscriptions: Vec, } pub enum Event { @@ -171,74 +150,66 @@ pub enum Event { impl EventEmitter for ContextServerManager {} -impl Default for ContextServerManager { - fn default() -> Self { - Self::new() - } -} - impl ContextServerManager { - pub fn new() -> Self { - Self { + pub fn new( + registry: Model, + project: Model, + cx: &mut ModelContext, + ) -> Self { + let mut this = Self { + _subscriptions: vec![ + cx.observe(®istry, |this, _registry, cx| { + this.available_context_servers_changed(cx); + }), + cx.observe_global::(|this, cx| { + this.available_context_servers_changed(cx); + }), + ], + project, + registry, + needs_server_update: false, servers: HashMap::default(), - pending_servers: HashSet::default(), - } - } - - pub fn add_server( - &mut self, - server: Arc, - cx: &ModelContext, - ) -> Task> { - let server_id = server.id(); - - if self.servers.contains_key(&server_id) || self.pending_servers.contains(&server_id) { - return Task::ready(Ok(())); - } - - let task = { - let server_id = server_id.clone(); - cx.spawn(|this, mut cx| async move { - server.clone().start(&cx).await?; - this.update(&mut cx, |this, cx| { - this.servers.insert(server_id.clone(), server); - this.pending_servers.remove(&server_id); - cx.emit(Event::ServerStarted { - server_id: server_id.clone(), - }); - })?; - Ok(()) - }) + update_servers_task: None, }; - - self.pending_servers.insert(server_id); - task + this.available_context_servers_changed(cx); + this } - pub fn get_server(&self, id: &str) -> Option> { - self.servers.get(id).cloned() + fn available_context_servers_changed(&mut self, cx: &mut ModelContext) { + if self.update_servers_task.is_some() { + self.needs_server_update = true; + } else { + self.update_servers_task = Some(cx.spawn(|this, mut cx| async move { + this.update(&mut cx, |this, _| { + this.needs_server_update = false; + })?; + + Self::maintain_servers(this.clone(), cx.clone()).await?; + + this.update(&mut cx, |this, cx| { + let has_any_context_servers = !this.servers().is_empty(); + if has_any_context_servers { + CommandPaletteFilter::update_global(cx, |filter, _cx| { + filter.show_namespace(CONTEXT_SERVERS_NAMESPACE); + }); + } + + this.update_servers_task.take(); + if this.needs_server_update { + this.available_context_servers_changed(cx); + } + })?; + + Ok(()) + })); + } } - pub fn remove_server( - &mut self, - id: &Arc, - cx: &ModelContext, - ) -> Task> { - let id = id.clone(); - cx.spawn(|this, mut cx| async move { - if let Some(server) = - this.update(&mut cx, |this, _cx| this.servers.remove(id.as_ref()))? - { - server.stop()?; - } - this.update(&mut cx, |this, cx| { - this.pending_servers.remove(id.as_ref()); - cx.emit(Event::ServerStopped { - server_id: id.clone(), - }) - })?; - Ok(()) - }) + pub fn get_server(&self, id: &str) -> Option> { + self.servers + .get(id) + .filter(|server| server.client().is_some()) + .cloned() } pub fn restart_server( @@ -251,7 +222,7 @@ impl ContextServerManager { if let Some(server) = this.update(&mut cx, |this, _cx| this.servers.remove(&id))? { server.stop()?; let config = server.config(); - let new_server = Arc::new(NativeContextServer::new(id.clone(), config)); + let new_server = Arc::new(ContextServer::new(id.clone(), config)); new_server.clone().start(&cx).await?; this.update(&mut cx, |this, cx| { this.servers.insert(id.clone(), new_server); @@ -267,45 +238,83 @@ impl ContextServerManager { }) } - pub fn servers(&self) -> Vec> { - self.servers.values().cloned().collect() + pub fn servers(&self) -> Vec> { + self.servers + .values() + .filter(|server| server.client().is_some()) + .cloned() + .collect() } - pub fn maintain_servers(&mut self, settings: &ContextServerSettings, cx: &ModelContext) { - let current_servers = self - .servers() - .into_iter() - .map(|server| (server.id(), server.config())) - .collect::>(); + async fn maintain_servers(this: WeakModel, mut cx: AsyncAppContext) -> Result<()> { + let mut desired_servers = HashMap::default(); - let new_servers = settings - .context_servers - .iter() - .map(|(id, config)| (id.clone(), config.clone())) - .collect::>(); + let (registry, project) = this.update(&mut cx, |this, cx| { + let location = this.project.read(cx).worktrees(cx).next().map(|worktree| { + settings::SettingsLocation { + worktree_id: worktree.read(cx).id(), + path: Path::new(""), + } + }); + let settings = ContextServerSettings::get(location, cx); + desired_servers = settings.context_servers.clone(); - let servers_to_add = new_servers - .iter() - .filter(|(id, _)| !current_servers.contains_key(id.as_ref())) - .map(|(id, config)| (id.clone(), config.clone())) - .collect::>(); + (this.registry.clone(), this.project.clone()) + })?; - let servers_to_remove = current_servers - .keys() - .filter(|id| !new_servers.contains_key(id.as_ref())) - .cloned() - .collect::>(); - - log::trace!("servers_to_add={:?}", servers_to_add); - for (id, config) in servers_to_add { - if config.command.is_some() { - let server = Arc::new(NativeContextServer::new(id, Arc::new(config))); - self.add_server(server, cx).detach_and_log_err(cx); + for (id, factory) in + registry.read_with(&cx, |registry, _| registry.context_server_factories())? + { + let config = desired_servers.entry(id).or_default(); + if config.command.is_none() { + if let Some(extension_command) = factory(project.clone(), &cx).await.log_err() { + config.command = Some(extension_command); + } } } - for id in servers_to_remove { - self.remove_server(&id, cx).detach_and_log_err(cx); + let mut servers_to_start = HashMap::default(); + let mut servers_to_stop = HashMap::default(); + + this.update(&mut cx, |this, _cx| { + this.servers.retain(|id, server| { + if desired_servers.contains_key(id) { + true + } else { + servers_to_stop.insert(id.clone(), server.clone()); + false + } + }); + + for (id, config) in desired_servers { + let existing_config = this.servers.get(&id).map(|server| server.config()); + if existing_config.as_deref() != Some(&config) { + let config = Arc::new(config); + let server = Arc::new(ContextServer::new(id.clone(), config)); + servers_to_start.insert(id.clone(), server.clone()); + let old_server = this.servers.insert(id.clone(), server); + if let Some(old_server) = old_server { + servers_to_stop.insert(id, old_server); + } + } + } + })?; + + for (id, server) in servers_to_stop { + server.stop().log_err(); + this.update(&mut cx, |_, cx| { + cx.emit(Event::ServerStopped { server_id: id }) + })?; } + + for (id, server) in servers_to_start { + if server.start(&cx).await.log_err().is_some() { + this.update(&mut cx, |_, cx| { + cx.emit(Event::ServerStarted { server_id: id }) + })?; + } + } + + Ok(()) } } diff --git a/crates/context_servers/src/registry.rs b/crates/context_servers/src/registry.rs index ae27b6d42f..c17c65370a 100644 --- a/crates/context_servers/src/registry.rs +++ b/crates/context_servers/src/registry.rs @@ -2,75 +2,61 @@ use std::sync::Arc; use anyhow::Result; use collections::HashMap; -use gpui::{AppContext, AsyncAppContext, Global, Model, ReadGlobal, Task}; -use parking_lot::RwLock; +use gpui::{AppContext, AsyncAppContext, Context, Global, Model, ReadGlobal, Task}; use project::Project; -use crate::ContextServer; +use crate::manager::ServerCommand; pub type ContextServerFactory = Arc< - dyn Fn(Model, &AsyncAppContext) -> Task>> - + Send - + Sync - + 'static, + dyn Fn(Model, &AsyncAppContext) -> Task> + Send + Sync + 'static, >; -#[derive(Default)] -struct GlobalContextServerFactoryRegistry(Arc); +struct GlobalContextServerFactoryRegistry(Model); impl Global for GlobalContextServerFactoryRegistry {} -#[derive(Default)] -struct ContextServerFactoryRegistryState { - context_servers: HashMap, ContextServerFactory>, -} - #[derive(Default)] pub struct ContextServerFactoryRegistry { - state: RwLock, + context_servers: HashMap, ContextServerFactory>, } impl ContextServerFactoryRegistry { /// Returns the global [`ContextServerFactoryRegistry`]. - pub fn global(cx: &AppContext) -> Arc { + pub fn global(cx: &AppContext) -> Model { GlobalContextServerFactoryRegistry::global(cx).0.clone() } /// Returns the global [`ContextServerFactoryRegistry`]. /// /// Inserts a default [`ContextServerFactoryRegistry`] if one does not yet exist. - pub fn default_global(cx: &mut AppContext) -> Arc { - cx.default_global::() - .0 - .clone() + pub fn default_global(cx: &mut AppContext) -> Model { + if !cx.has_global::() { + let registry = cx.new_model(|_| Self::new()); + cx.set_global(GlobalContextServerFactoryRegistry(registry)); + } + cx.global::().0.clone() } - pub fn new() -> Arc { - Arc::new(Self { - state: RwLock::new(ContextServerFactoryRegistryState { - context_servers: HashMap::default(), - }), - }) + pub fn new() -> Self { + Self { + context_servers: HashMap::default(), + } } pub fn context_server_factories(&self) -> Vec<(Arc, ContextServerFactory)> { - self.state - .read() - .context_servers + self.context_servers .iter() .map(|(id, factory)| (id.clone(), factory.clone())) .collect() } /// Registers the provided [`ContextServerFactory`]. - pub fn register_server_factory(&self, id: Arc, factory: ContextServerFactory) { - let mut state = self.state.write(); - state.context_servers.insert(id, factory); + pub fn register_server_factory(&mut self, id: Arc, factory: ContextServerFactory) { + self.context_servers.insert(id, factory); } /// Unregisters the [`ContextServerFactory`] for the server with the given ID. - pub fn unregister_server_factory_by_id(&self, server_id: &str) { - let mut state = self.state.write(); - state.context_servers.remove(server_id); + pub fn unregister_server_factory_by_id(&mut self, server_id: &str) { + self.context_servers.remove(server_id); } } diff --git a/crates/extension_host/src/extension_host.rs b/crates/extension_host/src/extension_host.rs index 1aed15a05c..7537734eed 100644 --- a/crates/extension_host/src/extension_host.rs +++ b/crates/extension_host/src/extension_host.rs @@ -141,7 +141,7 @@ pub trait ExtensionRegistrationHooks: Send + Sync + 'static { &self, _id: Arc, _extension: WasmExtension, - _host: Arc, + _cx: &mut AppContext, ) { } @@ -1266,7 +1266,7 @@ impl ExtensionStore { this.registration_hooks.register_context_server( id.clone(), wasm_extension.clone(), - this.wasm_host.clone(), + cx, ); } diff --git a/crates/extensions_ui/Cargo.toml b/crates/extensions_ui/Cargo.toml index 817b86bc3f..458925e87e 100644 --- a/crates/extensions_ui/Cargo.toml +++ b/crates/extensions_ui/Cargo.toml @@ -17,7 +17,6 @@ test-support = [] [dependencies] anyhow.workspace = true assistant_slash_command.workspace = true -async-trait.workspace = true client.workspace = true collections.workspace = true context_servers.workspace = true @@ -31,6 +30,7 @@ fuzzy.workspace = true gpui.workspace = true indexed_docs.workspace = true language.workspace = true +log.workspace = true lsp.workspace = true num-format.workspace = true picker.workspace = true diff --git a/crates/extensions_ui/src/extension_context_server.rs b/crates/extensions_ui/src/extension_context_server.rs deleted file mode 100644 index 21185b358f..0000000000 --- a/crates/extensions_ui/src/extension_context_server.rs +++ /dev/null @@ -1,97 +0,0 @@ -use std::pin::Pin; -use std::sync::Arc; - -use anyhow::{anyhow, Result}; -use async_trait::async_trait; -use context_servers::manager::{NativeContextServer, ServerCommand, ServerConfig}; -use context_servers::protocol::InitializedContextServerProtocol; -use context_servers::ContextServer; -use extension_host::wasm_host::{ExtensionProject, WasmExtension, WasmHost}; -use futures::{Future, FutureExt}; -use gpui::{AsyncAppContext, Model}; -use project::Project; -use wasmtime_wasi::WasiView as _; - -pub struct ExtensionContextServer { - #[allow(unused)] - pub(crate) extension: WasmExtension, - #[allow(unused)] - pub(crate) host: Arc, - id: Arc, - context_server: Arc, -} - -impl ExtensionContextServer { - pub async fn new( - extension: WasmExtension, - host: Arc, - id: Arc, - project: Model, - mut cx: AsyncAppContext, - ) -> Result { - let extension_project = project.update(&mut cx, |project, cx| ExtensionProject { - worktree_ids: project - .visible_worktrees(cx) - .map(|worktree| worktree.read(cx).id().to_proto()) - .collect(), - })?; - let command = extension - .call({ - let id = id.clone(); - |extension, store| { - async move { - let project = store.data_mut().table().push(extension_project)?; - let command = extension - .call_context_server_command(store, id.clone(), project) - .await? - .map_err(|e| anyhow!("{}", e))?; - anyhow::Ok(command) - } - .boxed() - } - }) - .await?; - - let config = Arc::new(ServerConfig { - settings: None, - command: Some(ServerCommand { - path: command.command, - args: command.args, - env: Some(command.env.into_iter().collect()), - }), - }); - - anyhow::Ok(Self { - extension, - host, - id: id.clone(), - context_server: Arc::new(NativeContextServer::new(id, config)), - }) - } -} - -#[async_trait(?Send)] -impl ContextServer for ExtensionContextServer { - fn id(&self) -> Arc { - self.id.clone() - } - - fn config(&self) -> Arc { - self.context_server.config() - } - - fn client(&self) -> Option> { - self.context_server.client() - } - - fn start<'a>( - self: Arc, - cx: &'a AsyncAppContext, - ) -> Pin>>> { - self.context_server.clone().start(cx) - } - - fn stop(&self) -> Result<()> { - self.context_server.stop() - } -} diff --git a/crates/extensions_ui/src/extension_registration_hooks.rs b/crates/extensions_ui/src/extension_registration_hooks.rs index bab8d90c56..f8cd9a3429 100644 --- a/crates/extensions_ui/src/extension_registration_hooks.rs +++ b/crates/extensions_ui/src/extension_registration_hooks.rs @@ -1,19 +1,21 @@ use std::{path::PathBuf, sync::Arc}; -use anyhow::Result; +use anyhow::{anyhow, Result}; use assistant_slash_command::{ExtensionSlashCommand, SlashCommandRegistry}; +use context_servers::manager::ServerCommand; use context_servers::ContextServerFactoryRegistry; +use db::smol::future::FutureExt as _; use extension::Extension; +use extension_host::wasm_host::ExtensionProject; use extension_host::{extension_lsp_adapter::ExtensionLspAdapter, wasm_host}; use fs::Fs; -use gpui::{AppContext, BackgroundExecutor, Task}; +use gpui::{AppContext, BackgroundExecutor, Model, Task}; use indexed_docs::{ExtensionIndexedDocsProvider, IndexedDocsRegistry, ProviderId}; use language::{LanguageRegistry, LanguageServerBinaryStatus, LoadedLanguage}; use snippet_provider::SnippetRegistry; use theme::{ThemeRegistry, ThemeSettings}; use ui::SharedString; - -use crate::extension_context_server::ExtensionContextServer; +use wasmtime_wasi::WasiView as _; pub struct ConcreteExtensionRegistrationHooks { slash_command_registry: Arc, @@ -21,7 +23,7 @@ pub struct ConcreteExtensionRegistrationHooks { indexed_docs_registry: Arc, snippet_registry: Arc, language_registry: Arc, - context_server_factory_registry: Arc, + context_server_factory_registry: Model, executor: BackgroundExecutor, } @@ -32,7 +34,7 @@ impl ConcreteExtensionRegistrationHooks { indexed_docs_registry: Arc, snippet_registry: Arc, language_registry: Arc, - context_server_factory_registry: Arc, + context_server_factory_registry: Model, cx: &AppContext, ) -> Arc { Arc::new(Self { @@ -71,25 +73,66 @@ impl extension_host::ExtensionRegistrationHooks for ConcreteExtensionRegistratio &self, id: Arc, extension: wasm_host::WasmExtension, - host: Arc, + cx: &mut AppContext, ) { self.context_server_factory_registry - .register_server_factory( - id.clone(), - Arc::new({ - move |project, cx| { - let id = id.clone(); - let extension = extension.clone(); - let host = host.clone(); - cx.spawn(|cx| async move { - let context_server = - ExtensionContextServer::new(extension, host, id, project, cx) + .update(cx, |registry, _| { + registry.register_server_factory( + id.clone(), + Arc::new({ + move |project, cx| { + log::info!( + "loading command for context server {id} from extension {}", + extension.manifest.id + ); + + let id = id.clone(); + let extension = extension.clone(); + cx.spawn(|mut cx| async move { + let extension_project = + project.update(&mut cx, |project, cx| ExtensionProject { + worktree_ids: project + .visible_worktrees(cx) + .map(|worktree| worktree.read(cx).id().to_proto()) + .collect(), + })?; + + let command = extension + .call({ + let id = id.clone(); + |extension, store| { + async move { + let project = store + .data_mut() + .table() + .push(extension_project)?; + let command = extension + .call_context_server_command( + store, + id.clone(), + project, + ) + .await? + .map_err(|e| anyhow!("{}", e))?; + anyhow::Ok(command) + } + .boxed() + } + }) .await?; - anyhow::Ok(Arc::new(context_server) as _) - }) - } - }), - ); + + log::info!("loaded command for context server {id}: {command:?}"); + + Ok(ServerCommand { + path: command.command, + args: command.args, + env: Some(command.env.into_iter().collect()), + }) + }) + } + }), + ) + }); } fn register_docs_provider(&self, extension: Arc, provider_id: Arc) { diff --git a/crates/extensions_ui/src/extension_store_test.rs b/crates/extensions_ui/src/extension_store_test.rs index 9c5140272d..90b1f7ebc7 100644 --- a/crates/extensions_ui/src/extension_store_test.rs +++ b/crates/extensions_ui/src/extension_store_test.rs @@ -268,7 +268,7 @@ async fn test_extension_store(cx: &mut TestAppContext) { let slash_command_registry = SlashCommandRegistry::new(); let indexed_docs_registry = Arc::new(IndexedDocsRegistry::new(cx.executor())); let snippet_registry = Arc::new(SnippetRegistry::new()); - let context_server_factory_registry = ContextServerFactoryRegistry::new(); + let context_server_factory_registry = cx.new_model(|_| ContextServerFactoryRegistry::new()); let node_runtime = NodeRuntime::unavailable(); let store = cx.new_model(|cx| { @@ -508,7 +508,7 @@ async fn test_extension_store_with_test_extension(cx: &mut TestAppContext) { let slash_command_registry = SlashCommandRegistry::new(); let indexed_docs_registry = Arc::new(IndexedDocsRegistry::new(cx.executor())); let snippet_registry = Arc::new(SnippetRegistry::new()); - let context_server_factory_registry = ContextServerFactoryRegistry::new(); + let context_server_factory_registry = cx.new_model(|_| ContextServerFactoryRegistry::new()); let node_runtime = NodeRuntime::unavailable(); let mut status_updates = language_registry.language_server_binary_statuses(); diff --git a/crates/extensions_ui/src/extensions_ui.rs b/crates/extensions_ui/src/extensions_ui.rs index cd0fe94fa3..d91e49533f 100644 --- a/crates/extensions_ui/src/extensions_ui.rs +++ b/crates/extensions_ui/src/extensions_ui.rs @@ -1,5 +1,4 @@ mod components; -mod extension_context_server; mod extension_registration_hooks; mod extension_suggest; mod extension_version_selector;