assistant: Fix issues when configuring different providers (#15072)

Release Notes:

- N/A

---------

Co-authored-by: Antonio Scandurra <me@as-cii.com>
Bennet Bo Fenner 2024-07-24 11:21:31 +02:00 committed by GitHub
parent ba6c36f370
commit af4b9805c9
16 changed files with 225 additions and 148 deletions

View file

@ -853,7 +853,17 @@
}
},
// Different settings for specific language models.
"language_models": {},
"language_models": {
"anthropic": {
"api_url": "https://api.anthropic.com"
},
"openai": {
"api_url": "https://api.openai.com/v1"
},
"ollama": {
"api_url": "http://localhost:11434"
}
},
// Zed's Prettier integration settings.
// Allows you to enable/disable formatting with Prettier
// and configure default Prettier, used when no project-level Prettier installation is found.

View file

@ -23,7 +23,7 @@ use gpui::{actions, impl_actions, AppContext, Global, SharedString, UpdateGlobal
use indexed_docs::IndexedDocsRegistry;
pub(crate) use inline_assistant::*;
use language_model::{
LanguageModelId, LanguageModelProviderName, LanguageModelRegistry, LanguageModelResponseMessage,
LanguageModelId, LanguageModelProviderId, LanguageModelRegistry, LanguageModelResponseMessage,
};
pub(crate) use model_selector::*;
use semantic_index::{CloudEmbeddingProvider, SemanticIndex};
@ -231,7 +231,7 @@ fn init_completion_provider(cx: &mut AppContext) {
fn update_active_language_model_from_settings(cx: &mut AppContext) {
let settings = AssistantSettings::get_global(cx);
let provider_name = LanguageModelProviderName::from(settings.default_model.provider.clone());
let provider_name = LanguageModelProviderId::from(settings.default_model.provider.clone());
let model_id = LanguageModelId::from(settings.default_model.model.clone());
let Some(provider) = LanguageModelRegistry::global(cx)

View file

@ -144,8 +144,8 @@ impl AssistantSettingsContent {
fs,
cx,
move |content, _| {
if content.open_ai.is_none() {
content.open_ai =
if content.openai.is_none() {
content.openai =
Some(language_model::settings::OpenAiSettingsContent {
api_url,
low_speed_timeout_in_seconds,
@ -243,7 +243,7 @@ impl AssistantSettingsContent {
pub fn set_model(&mut self, language_model: Arc<dyn LanguageModel>) {
let model = language_model.id().0.to_string();
let provider = language_model.provider_name().0.to_string();
let provider = language_model.provider_id().0.to_string();
match self {
AssistantSettingsContent::Versioned(settings) => match settings {

View file

@ -1438,7 +1438,7 @@ impl Render for PromptEditor {
{
let model_name = available_model.name().0.clone();
let provider =
available_model.provider_name().0.clone();
available_model.provider_id().0.clone();
move |_| {
h_flex()
.w_full()

View file

@ -565,7 +565,7 @@ impl Render for PromptEditor {
{
let model_name = available_model.name().0.clone();
let provider =
available_model.provider_name().0.clone();
available_model.provider_id().0.clone();
move |_| {
h_flex()
.w_full()

View file

@ -2,7 +2,7 @@ use anyhow::{anyhow, Result};
use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
use gpui::{AppContext, Global, Model, ModelContext, Task};
use language_model::{
LanguageModel, LanguageModelProvider, LanguageModelProviderName, LanguageModelRegistry,
LanguageModel, LanguageModelProvider, LanguageModelProviderId, LanguageModelRegistry,
LanguageModelRequest,
};
use smol::lock::{Semaphore, SemaphoreGuardArc};
@ -89,7 +89,7 @@ impl LanguageModelCompletionProvider {
pub fn set_active_provider(
&mut self,
provider_name: LanguageModelProviderName,
provider_name: LanguageModelProviderId,
cx: &mut ModelContext<Self>,
) {
self.active_provider = LanguageModelRegistry::read_global(cx).provider(&provider_name);
@ -103,14 +103,19 @@ impl LanguageModelCompletionProvider {
pub fn set_active_model(&mut self, model: Arc<dyn LanguageModel>, cx: &mut ModelContext<Self>) {
if self.active_model.as_ref().map_or(false, |m| {
m.id() == model.id() && m.provider_name() == model.provider_name()
m.id() == model.id() && m.provider_id() == model.provider_id()
}) {
return;
}
self.active_provider =
LanguageModelRegistry::read_global(cx).provider(&model.provider_name());
self.active_model = Some(model);
LanguageModelRegistry::read_global(cx).provider(&model.provider_id());
self.active_model = Some(model.clone());
if let Some(provider) = self.active_provider.as_ref() {
provider.load_model(model, cx);
}
cx.notify();
}

View file

@ -25,6 +25,7 @@ pub fn init(client: Arc<Client>, cx: &mut AppContext) {
pub trait LanguageModel: Send + Sync {
fn id(&self) -> LanguageModelId;
fn name(&self) -> LanguageModelName;
fn provider_id(&self) -> LanguageModelProviderId;
fn provider_name(&self) -> LanguageModelProviderName;
fn telemetry_id(&self) -> String;
@ -44,8 +45,10 @@ pub trait LanguageModel: Send + Sync {
}
pub trait LanguageModelProvider: 'static {
fn id(&self) -> LanguageModelProviderId;
fn name(&self) -> LanguageModelProviderName;
fn provided_models(&self, cx: &AppContext) -> Vec<Arc<dyn LanguageModel>>;
fn load_model(&self, _model: Arc<dyn LanguageModel>, _cx: &AppContext) {}
fn is_authenticated(&self, cx: &AppContext) -> bool;
fn authenticate(&self, cx: &AppContext) -> Task<Result<()>>;
fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView;
@ -62,6 +65,9 @@ pub struct LanguageModelId(pub SharedString);
#[derive(Clone, Eq, PartialEq, Hash, Debug)]
pub struct LanguageModelName(pub SharedString);
#[derive(Clone, Eq, PartialEq, Hash, Debug)]
pub struct LanguageModelProviderId(pub SharedString);
#[derive(Clone, Eq, PartialEq, Hash, Debug)]
pub struct LanguageModelProviderName(pub SharedString);
@ -77,6 +83,12 @@ impl From<String> for LanguageModelName {
}
}
impl From<String> for LanguageModelProviderId {
fn from(value: String) -> Self {
Self(SharedString::from(value))
}
}
impl From<String> for LanguageModelProviderName {
fn from(value: String) -> Self {
Self(SharedString::from(value))

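To make the new id/name split concrete, here is a standalone sketch of the invariant it encodes (not Zed's actual code; plain `String` stands in for gpui's `SharedString`): the id is the stable key used by settings and the registry, while the name exists only for display.

```rust
use std::collections::HashMap;

// Simplified stand-ins for the newtypes added above.
#[derive(Clone, Eq, PartialEq, Hash, Debug)]
struct LanguageModelProviderId(String);

#[derive(Clone, Eq, PartialEq, Hash, Debug)]
struct LanguageModelProviderName(String);

fn main() {
    // The id is what settings and the registry key on ("openai");
    // the name is what the UI renders ("OpenAI").
    let id = LanguageModelProviderId("openai".into());
    let name = LanguageModelProviderName("OpenAI".into());

    let mut display_names = HashMap::new();
    display_names.insert(id.clone(), name.clone());
    assert_eq!(display_names.get(&id), Some(&name));
}
```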
View file

@ -1,6 +1,5 @@
use anthropic::{stream_completion, Request, RequestMessage};
use anyhow::{anyhow, Result};
use collections::HashMap;
use editor::{Editor, EditorElement, EditorStyle};
use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
use gpui::{
@ -9,7 +8,7 @@ use gpui::{
};
use http_client::HttpClient;
use settings::{Settings, SettingsStore};
use std::{sync::Arc, time::Duration};
use std::{collections::BTreeMap, sync::Arc, time::Duration};
use strum::IntoEnumIterator;
use theme::ThemeSettings;
use ui::prelude::*;
@ -17,11 +16,12 @@ use util::ResultExt;
use crate::{
settings::AllLanguageModelSettings, LanguageModel, LanguageModelId, LanguageModelName,
LanguageModelProvider, LanguageModelProviderName, LanguageModelProviderState,
LanguageModelRequest, LanguageModelRequestMessage, Role,
LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
LanguageModelProviderState, LanguageModelRequest, LanguageModelRequestMessage, Role,
};
const PROVIDER_NAME: &str = "anthropic";
const PROVIDER_ID: &str = "anthropic";
const PROVIDER_NAME: &str = "Anthropic";
#[derive(Default, Clone, Debug, PartialEq)]
pub struct AnthropicSettings {
@ -37,7 +37,6 @@ pub struct AnthropicLanguageModelProvider {
struct State {
api_key: Option<String>,
settings: AnthropicSettings,
_subscription: Subscription,
}
@ -45,9 +44,7 @@ impl AnthropicLanguageModelProvider {
pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut AppContext) -> Self {
let state = cx.new_model(|cx| State {
api_key: None,
settings: AnthropicSettings::default(),
_subscription: cx.observe_global::<SettingsStore>(|this: &mut State, cx| {
this.settings = AllLanguageModelSettings::get_global(cx).anthropic.clone();
_subscription: cx.observe_global::<SettingsStore>(|_, cx| {
cx.notify();
}),
});
@ -64,12 +61,16 @@ impl LanguageModelProviderState for AnthropicLanguageModelProvider {
}
impl LanguageModelProvider for AnthropicLanguageModelProvider {
fn id(&self) -> LanguageModelProviderId {
LanguageModelProviderId(PROVIDER_ID.into())
}
fn name(&self) -> LanguageModelProviderName {
LanguageModelProviderName(PROVIDER_NAME.into())
}
fn provided_models(&self, cx: &AppContext) -> Vec<Arc<dyn LanguageModel>> {
let mut models = HashMap::default();
let mut models = BTreeMap::default();
// Add base models from anthropic::Model::iter()
for model in anthropic::Model::iter() {
@ -79,7 +80,11 @@ impl LanguageModelProvider for AnthropicLanguageModelProvider {
}
// Override with available models from settings
for model in &self.state.read(cx).settings.available_models {
for model in AllLanguageModelSettings::get_global(cx)
.anthropic
.available_models
.iter()
{
models.insert(model.id().to_string(), model.clone());
}
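The `HashMap` to `BTreeMap` switch above buys two things, shown in this standalone sketch (plain strings stand in for model handles): settings-defined models override base models that share an id because later inserts win, and iteration order becomes stable and sorted.

```rust
use std::collections::BTreeMap;

fn main() {
    let base_models = ["claude-3-haiku", "claude-3-opus", "claude-3-sonnet"];
    let settings_models = [("claude-3-sonnet", "from settings")];

    let mut models: BTreeMap<String, &str> = BTreeMap::new();
    for id in base_models {
        models.insert(id.to_string(), "base");
    }
    // Later inserts win, so models declared in settings override base
    // models that share an id.
    for (id, source) in settings_models {
        models.insert(id.to_string(), source);
    }
    // BTreeMap iteration is sorted by id, so provided_models returns a
    // deterministic order regardless of insertion order.
    for (id, source) in &models {
        println!("{id}: {source}");
    }
}
```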
@ -104,7 +109,10 @@ impl LanguageModelProvider for AnthropicLanguageModelProvider {
if self.is_authenticated(cx) {
Task::ready(Ok(()))
} else {
let api_url = self.state.read(cx).settings.api_url.clone();
let api_url = AllLanguageModelSettings::get_global(cx)
.anthropic
.api_url
.clone();
let state = self.state.clone();
cx.spawn(|mut cx| async move {
let api_key = if let Ok(api_key) = std::env::var("ANTHROPIC_API_KEY") {
@ -132,7 +140,8 @@ impl LanguageModelProvider for AnthropicLanguageModelProvider {
fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
let state = self.state.clone();
let delete_credentials = cx.delete_credentials(&self.state.read(cx).settings.api_url);
let delete_credentials =
cx.delete_credentials(&AllLanguageModelSettings::get_global(cx).anthropic.api_url);
cx.spawn(|mut cx| async move {
delete_credentials.await.log_err();
state.update(&mut cx, |this, cx| {
@ -221,6 +230,10 @@ impl LanguageModel for AnthropicModel {
LanguageModelName::from(self.model.display_name().to_string())
}
fn provider_id(&self) -> LanguageModelProviderId {
LanguageModelProviderId(PROVIDER_ID.into())
}
fn provider_name(&self) -> LanguageModelProviderName {
LanguageModelProviderName(PROVIDER_NAME.into())
}
@ -249,11 +262,13 @@ impl LanguageModel for AnthropicModel {
let request = self.to_anthropic_request(request);
let http_client = self.http_client.clone();
let Ok((api_key, api_url, low_speed_timeout)) = cx.read_model(&self.state, |state, _| {
let Ok((api_key, api_url, low_speed_timeout)) = cx.read_model(&self.state, |state, cx| {
let settings = &AllLanguageModelSettings::get_global(cx).anthropic;
(
state.api_key.clone(),
state.settings.api_url.clone(),
state.settings.low_speed_timeout,
settings.api_url.clone(),
settings.low_speed_timeout,
)
}) else {
return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
@ -365,7 +380,10 @@ impl AuthenticationPrompt {
}
let write_credentials = cx.write_credentials(
&self.state.read(cx).settings.api_url,
AllLanguageModelSettings::get_global(cx)
.anthropic
.api_url
.as_str(),
"Bearer",
api_key.as_bytes(),
);

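The refactor running through this file stops caching a settings copy in `State` and instead reads the global value at each call site. A standalone sketch of the idea, with an `RwLock`-guarded global standing in for gpui's `SettingsStore`:

```rust
use std::sync::RwLock;

// Stand-in for the global settings store: one authoritative value that
// every caller reads at the moment of use, so nothing can go stale.
static API_URL: RwLock<String> = RwLock::new(String::new());

fn current_api_url() -> String {
    API_URL.read().unwrap().clone()
}

fn main() {
    *API_URL.write().unwrap() = "https://api.anthropic.com".to_string();
    assert_eq!(current_api_url(), "https://api.anthropic.com");

    // A settings change is visible to the next caller without any
    // subscription machinery copying it into per-provider state.
    *API_URL.write().unwrap() = "https://proxy.example.com".to_string();
    assert_eq!(current_api_url(), "https://proxy.example.com");
}
```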
View file

@ -1,15 +1,15 @@
use super::open_ai::count_open_ai_tokens;
use crate::{
settings::AllLanguageModelSettings, CloudModel, LanguageModel, LanguageModelId,
LanguageModelName, LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
LanguageModelName, LanguageModelProviderId, LanguageModelProviderName,
LanguageModelProviderState, LanguageModelRequest,
};
use anyhow::Result;
use client::Client;
use collections::HashMap;
use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt, TryFutureExt};
use gpui::{AnyView, AppContext, AsyncAppContext, Subscription, Task};
use settings::{Settings, SettingsStore};
use std::sync::Arc;
use std::{collections::BTreeMap, sync::Arc};
use strum::IntoEnumIterator;
use ui::prelude::*;
@ -17,6 +17,7 @@ use crate::LanguageModelProvider;
use super::anthropic::{count_anthropic_tokens, preprocess_anthropic_request};
pub const PROVIDER_ID: &str = "zed.dev";
pub const PROVIDER_NAME: &str = "zed.dev";
#[derive(Default, Clone, Debug, PartialEq)]
@ -33,7 +34,6 @@ pub struct CloudLanguageModelProvider {
struct State {
client: Arc<Client>,
status: client::Status,
settings: ZedDotDevSettings,
_subscription: Subscription,
}
@ -52,9 +52,7 @@ impl CloudLanguageModelProvider {
let state = cx.new_model(|cx| State {
client: client.clone(),
status,
settings: ZedDotDevSettings::default(),
_subscription: cx.observe_global::<SettingsStore>(|this: &mut State, cx| {
this.settings = AllLanguageModelSettings::get_global(cx).zed_dot_dev.clone();
_subscription: cx.observe_global::<SettingsStore>(|_, cx| {
cx.notify();
}),
});
@ -90,12 +88,16 @@ impl LanguageModelProviderState for CloudLanguageModelProvider {
}
impl LanguageModelProvider for CloudLanguageModelProvider {
fn id(&self) -> LanguageModelProviderId {
LanguageModelProviderId(PROVIDER_ID.into())
}
fn name(&self) -> LanguageModelProviderName {
LanguageModelProviderName(PROVIDER_NAME.into())
}
fn provided_models(&self, cx: &AppContext) -> Vec<Arc<dyn LanguageModel>> {
let mut models = HashMap::default();
let mut models = BTreeMap::default();
// Add base models from CloudModel::iter()
for model in CloudModel::iter() {
@ -105,7 +107,10 @@ impl LanguageModelProvider for CloudLanguageModelProvider {
}
// Override with available models from settings
for model in &self.state.read(cx).settings.available_models {
for model in &AllLanguageModelSettings::get_global(cx)
.zed_dot_dev
.available_models
{
models.insert(model.id().to_string(), model.clone());
}
@ -156,6 +161,10 @@ impl LanguageModel for CloudLanguageModel {
LanguageModelName::from(self.model.display_name().to_string())
}
fn provider_id(&self) -> LanguageModelProviderId {
LanguageModelProviderId(PROVIDER_ID.into())
}
fn provider_name(&self) -> LanguageModelProviderName {
LanguageModelProviderName(PROVIDER_NAME.into())
}
@ -187,6 +196,9 @@ impl LanguageModel for CloudLanguageModel {
| CloudModel::Claude3Opus
| CloudModel::Claude3Sonnet
| CloudModel::Claude3Haiku => count_anthropic_tokens(request, cx),
CloudModel::Custom { name, .. } if name.starts_with("anthropic/") => {
count_anthropic_tokens(request, cx)
}
_ => {
let request = self.client.request(proto::CountTokensWithLanguageModel {
model: self.model.id().to_string(),

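The new match arm sends custom cloud models whose id starts with `anthropic/` through the local Anthropic token counter instead of the server round-trip. A standalone sketch of that dispatch, with illustrative model ids:

```rust
// The hardcoded arms mirror the Claude 3 variants matched above; the
// prefix check is the new behavior added in this hunk.
fn counts_tokens_locally(model_id: &str) -> bool {
    matches!(model_id, "claude-3-opus" | "claude-3-sonnet" | "claude-3-haiku")
        || model_id.starts_with("anthropic/")
}

fn main() {
    assert!(counts_tokens_locally("anthropic/claude-3-5-sonnet"));
    assert!(counts_tokens_locally("claude-3-opus"));
    // Anything else falls through to the server-side CountTokens request.
    assert!(!counts_tokens_locally("some-custom-model"));
}
```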
View file

@ -5,7 +5,8 @@ use futures::{channel::mpsc, future::BoxFuture, stream::BoxStream, FutureExt, St
use crate::{
LanguageModel, LanguageModelId, LanguageModelName, LanguageModelProvider,
LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
LanguageModelRequest,
};
use gpui::{AnyView, AppContext, AsyncAppContext, Task};
use http_client::Result;
@ -19,8 +20,12 @@ pub fn language_model_name() -> LanguageModelName {
LanguageModelName::from("Fake".to_string())
}
pub fn provider_id() -> LanguageModelProviderId {
LanguageModelProviderId::from("fake".to_string())
}
pub fn provider_name() -> LanguageModelProviderName {
LanguageModelProviderName::from("fake".to_string())
LanguageModelProviderName::from("Fake".to_string())
}
#[derive(Clone, Default)]
@ -35,6 +40,10 @@ impl LanguageModelProviderState for FakeLanguageModelProvider {
}
impl LanguageModelProvider for FakeLanguageModelProvider {
fn id(&self) -> LanguageModelProviderId {
provider_id()
}
fn name(&self) -> LanguageModelProviderName {
provider_name()
}
@ -125,6 +134,10 @@ impl LanguageModel for FakeLanguageModel {
language_model_name()
}
fn provider_id(&self) -> LanguageModelProviderId {
provider_id()
}
fn provider_name(&self) -> LanguageModelProviderName {
provider_name()
}

View file

@ -2,21 +2,24 @@ use anyhow::{anyhow, Result};
use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
use gpui::{AnyView, AppContext, AsyncAppContext, ModelContext, Subscription, Task};
use http_client::HttpClient;
use ollama::{get_models, stream_chat_completion, ChatMessage, ChatOptions, ChatRequest};
use ollama::{
get_models, preload_model, stream_chat_completion, ChatMessage, ChatOptions, ChatRequest,
};
use settings::{Settings, SettingsStore};
use std::{sync::Arc, time::Duration};
use ui::{prelude::*, ButtonLike, ElevationIndex};
use crate::{
settings::AllLanguageModelSettings, LanguageModel, LanguageModelId, LanguageModelName,
LanguageModelProvider, LanguageModelProviderName, LanguageModelProviderState,
LanguageModelRequest, Role,
LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
LanguageModelProviderState, LanguageModelRequest, Role,
};
const OLLAMA_DOWNLOAD_URL: &str = "https://ollama.com/download";
const OLLAMA_LIBRARY_URL: &str = "https://ollama.com/library";
const PROVIDER_NAME: &str = "ollama";
const PROVIDER_ID: &str = "ollama";
const PROVIDER_NAME: &str = "Ollama";
#[derive(Default, Debug, Clone, PartialEq)]
pub struct OllamaSettings {
@ -32,14 +35,14 @@ pub struct OllamaLanguageModelProvider {
struct State {
http_client: Arc<dyn HttpClient>,
available_models: Vec<ollama::Model>,
settings: OllamaSettings,
_subscription: Subscription,
}
impl State {
fn fetch_models(&self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
fn fetch_models(&self, cx: &ModelContext<Self>) -> Task<Result<()>> {
let settings = &AllLanguageModelSettings::get_global(cx).ollama;
let http_client = self.http_client.clone();
let api_url = self.settings.api_url.clone();
let api_url = settings.api_url.clone();
// As a proxy for the server being "authenticated", we'll check if it's up by fetching the models
cx.spawn(|this, mut cx| async move {
@ -66,23 +69,25 @@ impl State {
impl OllamaLanguageModelProvider {
pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut AppContext) -> Self {
Self {
let this = Self {
http_client: http_client.clone(),
state: cx.new_model(|cx| State {
http_client,
available_models: Default::default(),
settings: OllamaSettings::default(),
_subscription: cx.observe_global::<SettingsStore>(|this: &mut State, cx| {
this.settings = AllLanguageModelSettings::get_global(cx).ollama.clone();
this.fetch_models(cx).detach_and_log_err(cx);
cx.notify();
}),
}),
}
};
this.fetch_models(cx).detach_and_log_err(cx);
this
}
fn fetch_models(&self, cx: &AppContext) -> Task<Result<()>> {
let settings = &AllLanguageModelSettings::get_global(cx).ollama;
let http_client = self.http_client.clone();
let api_url = self.state.read(cx).settings.api_url.clone();
let api_url = settings.api_url.clone();
let state = self.state.clone();
// As a proxy for the server being "authenticated", we'll check if it's up by fetching the models
@ -117,6 +122,10 @@ impl LanguageModelProviderState for OllamaLanguageModelProvider {
}
impl LanguageModelProvider for OllamaLanguageModelProvider {
fn id(&self) -> LanguageModelProviderId {
LanguageModelProviderId(PROVIDER_ID.into())
}
fn name(&self) -> LanguageModelProviderName {
LanguageModelProviderName(PROVIDER_NAME.into())
}
@ -131,12 +140,20 @@ impl LanguageModelProvider for OllamaLanguageModelProvider {
id: LanguageModelId::from(model.name.clone()),
model: model.clone(),
http_client: self.http_client.clone(),
state: self.state.clone(),
}) as Arc<dyn LanguageModel>
})
.collect()
}
fn load_model(&self, model: Arc<dyn LanguageModel>, cx: &AppContext) {
let settings = &AllLanguageModelSettings::get_global(cx).ollama;
let http_client = self.http_client.clone();
let api_url = settings.api_url.clone();
let id = model.id().0.to_string();
cx.spawn(|_| async move { preload_model(http_client, &api_url, &id).await })
.detach_and_log_err(cx);
}
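The new `load_model` hook warms a model by hitting Ollama's `/api/generate` endpoint before the first completion, so that request does not pay the load latency. A standalone sketch of the request it builds; the endpoint matches the `preload_model` code later in this diff, but the minimal body shown is an assumption:

```rust
// The uri matches the endpoint in ollama.rs; the single-field body is an
// assumption about the minimal payload that triggers a model load.
fn preload_request(api_url: &str, model: &str) -> (String, String) {
    let uri = format!("{api_url}/api/generate");
    let body = format!(r#"{{"model":"{model}"}}"#);
    (uri, body)
}

fn main() {
    let (uri, body) = preload_request("http://localhost:11434", "mistral");
    println!("POST {uri}");
    println!("{body}");
}
```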
fn is_authenticated(&self, cx: &AppContext) -> bool {
!self.state.read(cx).available_models.is_empty()
}
@ -167,7 +184,6 @@ impl LanguageModelProvider for OllamaLanguageModelProvider {
pub struct OllamaLanguageModel {
id: LanguageModelId,
model: ollama::Model,
state: gpui::Model<State>,
http_client: Arc<dyn HttpClient>,
}
@ -211,6 +227,14 @@ impl LanguageModel for OllamaLanguageModel {
LanguageModelName::from(self.model.display_name().to_string())
}
fn provider_id(&self) -> LanguageModelProviderId {
LanguageModelProviderId(PROVIDER_ID.into())
}
fn provider_name(&self) -> LanguageModelProviderName {
LanguageModelProviderName(PROVIDER_NAME.into())
}
fn max_token_count(&self) -> usize {
self.model.max_token_count()
}
@ -219,10 +243,6 @@ impl LanguageModel for OllamaLanguageModel {
format!("ollama/{}", self.model.id())
}
fn provider_name(&self) -> LanguageModelProviderName {
LanguageModelProviderName(PROVIDER_NAME.into())
}
fn count_tokens(
&self,
request: LanguageModelRequest,
@ -248,11 +268,9 @@ impl LanguageModel for OllamaLanguageModel {
let request = self.to_ollama_request(request);
let http_client = self.http_client.clone();
let Ok((api_url, low_speed_timeout)) = cx.read_model(&self.state, |state, _| {
(
state.settings.api_url.clone(),
state.settings.low_speed_timeout,
)
let Ok((api_url, low_speed_timeout)) = cx.update(|cx| {
let settings = &AllLanguageModelSettings::get_global(cx).ollama;
(settings.api_url.clone(), settings.low_speed_timeout)
}) else {
return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
};

View file

@ -1,5 +1,5 @@
use anyhow::{anyhow, Result};
use collections::HashMap;
use collections::BTreeMap;
use editor::{Editor, EditorElement, EditorStyle};
use futures::{future::BoxFuture, FutureExt, StreamExt};
use gpui::{
@ -17,11 +17,12 @@ use util::ResultExt;
use crate::{
settings::AllLanguageModelSettings, LanguageModel, LanguageModelId, LanguageModelName,
LanguageModelProvider, LanguageModelProviderName, LanguageModelProviderState,
LanguageModelRequest, Role,
LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
LanguageModelProviderState, LanguageModelRequest, Role,
};
const PROVIDER_NAME: &str = "openai";
const PROVIDER_ID: &str = "openai";
const PROVIDER_NAME: &str = "OpenAI";
#[derive(Default, Clone, Debug, PartialEq)]
pub struct OpenAiSettings {
@ -37,7 +38,6 @@ pub struct OpenAiLanguageModelProvider {
struct State {
api_key: Option<String>,
settings: OpenAiSettings,
_subscription: Subscription,
}
@ -45,9 +45,7 @@ impl OpenAiLanguageModelProvider {
pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut AppContext) -> Self {
let state = cx.new_model(|cx| State {
api_key: None,
settings: OpenAiSettings::default(),
_subscription: cx.observe_global::<SettingsStore>(|this: &mut State, cx| {
this.settings = AllLanguageModelSettings::get_global(cx).open_ai.clone();
_subscription: cx.observe_global::<SettingsStore>(|_this: &mut State, cx| {
cx.notify();
}),
});
@ -65,12 +63,16 @@ impl LanguageModelProviderState for OpenAiLanguageModelProvider {
}
impl LanguageModelProvider for OpenAiLanguageModelProvider {
fn id(&self) -> LanguageModelProviderId {
LanguageModelProviderId(PROVIDER_ID.into())
}
fn name(&self) -> LanguageModelProviderName {
LanguageModelProviderName(PROVIDER_NAME.into())
}
fn provided_models(&self, cx: &AppContext) -> Vec<Arc<dyn LanguageModel>> {
let mut models = HashMap::default();
let mut models = BTreeMap::default();
// Add base models from open_ai::Model::iter()
for model in open_ai::Model::iter() {
@ -80,7 +82,10 @@ impl LanguageModelProvider for OpenAiLanguageModelProvider {
}
// Override with available models from settings
for model in &self.state.read(cx).settings.available_models {
for model in &AllLanguageModelSettings::get_global(cx)
.openai
.available_models
{
models.insert(model.id().to_string(), model.clone());
}
@ -105,7 +110,10 @@ impl LanguageModelProvider for OpenAiLanguageModelProvider {
if self.is_authenticated(cx) {
Task::ready(Ok(()))
} else {
let api_url = self.state.read(cx).settings.api_url.clone();
let api_url = AllLanguageModelSettings::get_global(cx)
.openai
.api_url
.clone();
let state = self.state.clone();
cx.spawn(|mut cx| async move {
let api_key = if let Ok(api_key) = std::env::var("OPENAI_API_KEY") {
@ -131,7 +139,8 @@ impl LanguageModelProvider for OpenAiLanguageModelProvider {
}
fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
let delete_credentials = cx.delete_credentials(&self.state.read(cx).settings.api_url);
let settings = &AllLanguageModelSettings::get_global(cx).openai;
let delete_credentials = cx.delete_credentials(&settings.api_url);
let state = self.state.clone();
cx.spawn(|mut cx| async move {
delete_credentials.await.log_err();
@ -188,6 +197,10 @@ impl LanguageModel for OpenAiLanguageModel {
LanguageModelName::from(self.model.display_name().to_string())
}
fn provider_id(&self) -> LanguageModelProviderId {
LanguageModelProviderId(PROVIDER_ID.into())
}
fn provider_name(&self) -> LanguageModelProviderName {
LanguageModelProviderName(PROVIDER_NAME.into())
}
@ -216,11 +229,12 @@ impl LanguageModel for OpenAiLanguageModel {
let request = self.to_open_ai_request(request);
let http_client = self.http_client.clone();
let Ok((api_key, api_url, low_speed_timeout)) = cx.read_model(&self.state, |state, _| {
let Ok((api_key, api_url, low_speed_timeout)) = cx.read_model(&self.state, |state, cx| {
let settings = &AllLanguageModelSettings::get_global(cx).openai;
(
state.api_key.clone(),
state.settings.api_url.clone(),
state.settings.low_speed_timeout,
settings.api_url.clone(),
settings.low_speed_timeout,
)
}) else {
return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
@ -307,11 +321,9 @@ impl AuthenticationPrompt {
return;
}
let write_credentials = cx.write_credentials(
&self.state.read(cx).settings.api_url,
"Bearer",
api_key.as_bytes(),
);
let settings = &AllLanguageModelSettings::get_global(cx).openai;
let write_credentials =
cx.write_credentials(&settings.api_url, "Bearer", api_key.as_bytes());
let state = self.state.clone();
cx.spawn(|_, mut cx| async move {
write_credentials.await?;

View file

@ -9,7 +9,7 @@ use crate::{
anthropic::AnthropicLanguageModelProvider, cloud::CloudLanguageModelProvider,
ollama::OllamaLanguageModelProvider, open_ai::OpenAiLanguageModelProvider,
},
LanguageModel, LanguageModelProvider, LanguageModelProviderName, LanguageModelProviderState,
LanguageModel, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderState,
};
pub fn init(client: Arc<Client>, cx: &mut AppContext) {
@ -48,7 +48,7 @@ fn register_language_model_providers(
registry.register_provider(CloudLanguageModelProvider::new(client.clone(), cx), cx);
} else {
registry.unregister_provider(
&LanguageModelProviderName::from(
&LanguageModelProviderId::from(
crate::provider::cloud::PROVIDER_NAME.to_string(),
),
cx,
@ -65,7 +65,7 @@ impl Global for GlobalLanguageModelRegistry {}
#[derive(Default)]
pub struct LanguageModelRegistry {
providers: HashMap<LanguageModelProviderName, Arc<dyn LanguageModelProvider>>,
providers: HashMap<LanguageModelProviderId, Arc<dyn LanguageModelProvider>>,
}
impl LanguageModelRegistry {
@ -94,7 +94,7 @@ impl LanguageModelRegistry {
provider: T,
cx: &mut ModelContext<Self>,
) {
let name = provider.name();
let name = provider.id();
if let Some(subscription) = provider.subscribe(cx) {
subscription.detach();
@ -106,7 +106,7 @@ impl LanguageModelRegistry {
pub fn unregister_provider(
&mut self,
name: &LanguageModelProviderName,
name: &LanguageModelProviderId,
cx: &mut ModelContext<Self>,
) {
if self.providers.remove(name).is_some() {
@ -116,7 +116,7 @@ impl LanguageModelRegistry {
pub fn providers(
&self,
) -> impl Iterator<Item = (&LanguageModelProviderName, &Arc<dyn LanguageModelProvider>)> {
) -> impl Iterator<Item = (&LanguageModelProviderId, &Arc<dyn LanguageModelProvider>)> {
self.providers.iter()
}
@ -130,7 +130,7 @@ impl LanguageModelRegistry {
pub fn available_models_grouped_by_provider(
&self,
cx: &AppContext,
) -> HashMap<LanguageModelProviderName, Vec<Arc<dyn LanguageModel>>> {
) -> HashMap<LanguageModelProviderId, Vec<Arc<dyn LanguageModel>>> {
self.providers
.iter()
.map(|(name, provider)| (name.clone(), provider.provided_models(cx)))
@ -139,7 +139,7 @@ impl LanguageModelRegistry {
pub fn provider(
&self,
name: &LanguageModelProviderName,
name: &LanguageModelProviderId,
) -> Option<Arc<dyn LanguageModelProvider>> {
self.providers.get(name).cloned()
}
@ -160,10 +160,10 @@ mod tests {
let providers = registry.read(cx).providers().collect::<Vec<_>>();
assert_eq!(providers.len(), 1);
assert_eq!(providers[0].0, &crate::provider::fake::provider_name());
assert_eq!(providers[0].0, &crate::provider::fake::provider_id());
registry.update(cx, |registry, cx| {
registry.unregister_provider(&crate::provider::fake::provider_name(), cx);
registry.unregister_provider(&crate::provider::fake::provider_id(), cx);
});
let providers = registry.read(cx).providers().collect::<Vec<_>>();

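Keying the registry by `LanguageModelProviderId` instead of `LanguageModelProviderName` means display renames ("openai" becoming "OpenAI") can no longer break lookups. A standalone sketch of that invariant with simplified types:

```rust
use std::collections::HashMap;

// Simplified stand-ins for the registry types above.
#[derive(Clone, Eq, PartialEq, Hash, Debug)]
struct ProviderId(&'static str);

struct Provider {
    display_name: &'static str,
}

fn main() {
    let mut registry: HashMap<ProviderId, Provider> = HashMap::new();
    registry.insert(
        ProviderId("openai"),
        Provider { display_name: "OpenAI" },
    );

    // Lookups key on the stable id; the display name can change freely.
    let provider = registry.get(&ProviderId("openai")).unwrap();
    assert_eq!(provider.display_name, "OpenAI");
}
```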
View file

@ -21,9 +21,9 @@ pub fn init(cx: &mut AppContext) {
#[derive(Default)]
pub struct AllLanguageModelSettings {
pub open_ai: OpenAiSettings,
pub anthropic: AnthropicSettings,
pub ollama: OllamaSettings,
pub openai: OpenAiSettings,
pub zed_dot_dev: ZedDotDevSettings,
}
@ -31,7 +31,7 @@ pub struct AllLanguageModelSettings {
pub struct AllLanguageModelSettingsContent {
pub anthropic: Option<AnthropicSettingsContent>,
pub ollama: Option<OllamaSettingsContent>,
pub open_ai: Option<OpenAiSettingsContent>,
pub openai: Option<OpenAiSettingsContent>,
#[serde(rename = "zed.dev")]
pub zed_dot_dev: Option<ZedDotDevSettingsContent>,
}
@ -110,21 +110,21 @@ impl settings::Settings for AllLanguageModelSettings {
}
merge(
&mut settings.open_ai.api_url,
value.open_ai.as_ref().and_then(|s| s.api_url.clone()),
&mut settings.openai.api_url,
value.openai.as_ref().and_then(|s| s.api_url.clone()),
);
if let Some(low_speed_timeout_in_seconds) = value
.open_ai
.openai
.as_ref()
.and_then(|s| s.low_speed_timeout_in_seconds)
{
settings.open_ai.low_speed_timeout =
settings.openai.low_speed_timeout =
Some(Duration::from_secs(low_speed_timeout_in_seconds));
}
merge(
&mut settings.open_ai.available_models,
&mut settings.openai.available_models,
value
.open_ai
.openai
.as_ref()
.and_then(|s| s.available_models.clone()),
);
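The renamed `openai` settings section still layers user values over defaults via `merge`; this standalone sketch assumes the helper keeps the usual only-override-when-present behavior implied by the calls above:

```rust
// Assumed shape of the merge helper: apply a user value only when present.
fn merge<T>(target: &mut T, value: Option<T>) {
    if let Some(value) = value {
        *target = value;
    }
}

fn main() {
    let mut api_url = "https://api.openai.com/v1".to_string();

    // No user override: the default survives.
    merge(&mut api_url, None);
    assert_eq!(api_url, "https://api.openai.com/v1");

    // A value under the renamed "openai" key wins over the default.
    merge(&mut api_url, Some("http://localhost:11434/v1".to_string()));
    assert_eq!(api_url, "http://localhost:11434/v1");
}
```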

View file

@ -4,7 +4,7 @@ use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
use isahc::config::Configurable;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use std::{convert::TryFrom, time::Duration};
use std::{convert::TryFrom, sync::Arc, time::Duration};
pub const OLLAMA_API_URL: &str = "http://localhost:11434";
@ -243,7 +243,7 @@ pub async fn get_models(
}
/// Sends an empty request to Ollama to trigger loading the model
pub async fn preload_model(client: &dyn HttpClient, api_url: &str, model: &str) -> Result<()> {
pub async fn preload_model(client: Arc<dyn HttpClient>, api_url: &str, model: &str) -> Result<()> {
let uri = format!("{api_url}/api/generate");
let request = HttpRequest::builder()
.method(Method::POST)

View file

@ -85,12 +85,8 @@ To do so, add the following to your Zed `settings.json`:
```json
{
"assistant": {
"version": "1",
"provider": {
"name": "openai",
"type": "openai",
"default_model": "gpt-4-turbo-preview",
"language_models": {
"openai": {
"api_url": "http://localhost:11434/v1"
}
}
@ -103,51 +99,32 @@ The custom URL here is `http://localhost:11434/v1`.
You can use Ollama with the Zed assistant by making Ollama appear as an OpenAI-compatible endpoint.
1. Add the following to your Zed `settings.json`:
1. Download, for example, the `mistral` model with Ollama:
```
ollama pull mistral
```
2. Make sure that the Ollama server is running. You can start it either by running the Ollama app or by launching:
```
ollama serve
```
3. In the assistant panel, select one of the Ollama models using the model dropdown.
4. (Optional) If you want to change the default URL used to access the Ollama server, you can do so by adding the following settings:
```json
{
"assistant": {
"version": "1",
"provider": {
"name": "openai",
"type": "openai",
"default_model": "gpt-4-turbo-preview",
"api_url": "http://localhost:11434/v1"
}
```json
{
"language_models": {
"ollama": {
"api_url": "http://localhost:11434"
}
}
```
2. Download, for example, the `mistral` model with Ollama:
```
ollama run mistral
```
3. Copy the model and change its name to match the model in the Zed `settings.json`:
```
ollama cp mistral gpt-4-turbo-preview
```
4. Use `assistant: reset key` (see the [Setup](#setup) section above) and enter the following API key:
```
ollama
```
5. Restart Zed
}
```
### Using Claude 3.5 Sonnet
You can use Claude with the Zed assistant by adding the following settings:
You can use Claude with the Zed assistant by choosing it via the model dropdown in the assistant panel.
```json
"assistant": {
"version": "1",
"provider": {
"default_model": "claude-3-5-sonnet",
"name": "anthropic"
}
},
```
When you save the settings, the assistant panel will open and ask you to add your Anthropic API key.
You can obtain this key [here](https://console.anthropic.com/settings/keys).
You can obtain an API key [here](https://console.anthropic.com/settings/keys).
Even if you pay for Claude Pro, you will still have to [pay for additional credits](https://console.anthropic.com/settings/plans) to use it via the API.