mirror of
https://github.com/zed-industries/zed.git
synced 2024-11-28 01:26:48 +00:00
assistant: Fix issues when configuring different providers (#15072)
Release Notes: - N/A --------- Co-authored-by: Antonio Scandurra <me@as-cii.com>
This commit is contained in:
parent
ba6c36f370
commit
af4b9805c9
16 changed files with 225 additions and 148 deletions
|
@ -853,7 +853,17 @@
|
|||
}
|
||||
},
|
||||
// Different settings for specific language models.
|
||||
"language_models": {},
|
||||
"language_models": {
|
||||
"anthropic": {
|
||||
"api_url": "https://api.anthropic.com"
|
||||
},
|
||||
"openai": {
|
||||
"api_url": "https://api.openai.com/v1"
|
||||
},
|
||||
"ollama": {
|
||||
"api_url": "http://localhost:11434"
|
||||
}
|
||||
},
|
||||
// Zed's Prettier integration settings.
|
||||
// Allows to enable/disable formatting with Prettier
|
||||
// and configure default Prettier, used when no project-level Prettier installation is found.
|
||||
|
|
|
@ -23,7 +23,7 @@ use gpui::{actions, impl_actions, AppContext, Global, SharedString, UpdateGlobal
|
|||
use indexed_docs::IndexedDocsRegistry;
|
||||
pub(crate) use inline_assistant::*;
|
||||
use language_model::{
|
||||
LanguageModelId, LanguageModelProviderName, LanguageModelRegistry, LanguageModelResponseMessage,
|
||||
LanguageModelId, LanguageModelProviderId, LanguageModelRegistry, LanguageModelResponseMessage,
|
||||
};
|
||||
pub(crate) use model_selector::*;
|
||||
use semantic_index::{CloudEmbeddingProvider, SemanticIndex};
|
||||
|
@ -231,7 +231,7 @@ fn init_completion_provider(cx: &mut AppContext) {
|
|||
|
||||
fn update_active_language_model_from_settings(cx: &mut AppContext) {
|
||||
let settings = AssistantSettings::get_global(cx);
|
||||
let provider_name = LanguageModelProviderName::from(settings.default_model.provider.clone());
|
||||
let provider_name = LanguageModelProviderId::from(settings.default_model.provider.clone());
|
||||
let model_id = LanguageModelId::from(settings.default_model.model.clone());
|
||||
|
||||
let Some(provider) = LanguageModelRegistry::global(cx)
|
||||
|
|
|
@ -144,8 +144,8 @@ impl AssistantSettingsContent {
|
|||
fs,
|
||||
cx,
|
||||
move |content, _| {
|
||||
if content.open_ai.is_none() {
|
||||
content.open_ai =
|
||||
if content.openai.is_none() {
|
||||
content.openai =
|
||||
Some(language_model::settings::OpenAiSettingsContent {
|
||||
api_url,
|
||||
low_speed_timeout_in_seconds,
|
||||
|
@ -243,7 +243,7 @@ impl AssistantSettingsContent {
|
|||
|
||||
pub fn set_model(&mut self, language_model: Arc<dyn LanguageModel>) {
|
||||
let model = language_model.id().0.to_string();
|
||||
let provider = language_model.provider_name().0.to_string();
|
||||
let provider = language_model.provider_id().0.to_string();
|
||||
|
||||
match self {
|
||||
AssistantSettingsContent::Versioned(settings) => match settings {
|
||||
|
|
|
@ -1438,7 +1438,7 @@ impl Render for PromptEditor {
|
|||
{
|
||||
let model_name = available_model.name().0.clone();
|
||||
let provider =
|
||||
available_model.provider_name().0.clone();
|
||||
available_model.provider_id().0.clone();
|
||||
move |_| {
|
||||
h_flex()
|
||||
.w_full()
|
||||
|
|
|
@ -565,7 +565,7 @@ impl Render for PromptEditor {
|
|||
{
|
||||
let model_name = available_model.name().0.clone();
|
||||
let provider =
|
||||
available_model.provider_name().0.clone();
|
||||
available_model.provider_id().0.clone();
|
||||
move |_| {
|
||||
h_flex()
|
||||
.w_full()
|
||||
|
|
|
@ -2,7 +2,7 @@ use anyhow::{anyhow, Result};
|
|||
use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
|
||||
use gpui::{AppContext, Global, Model, ModelContext, Task};
|
||||
use language_model::{
|
||||
LanguageModel, LanguageModelProvider, LanguageModelProviderName, LanguageModelRegistry,
|
||||
LanguageModel, LanguageModelProvider, LanguageModelProviderId, LanguageModelRegistry,
|
||||
LanguageModelRequest,
|
||||
};
|
||||
use smol::lock::{Semaphore, SemaphoreGuardArc};
|
||||
|
@ -89,7 +89,7 @@ impl LanguageModelCompletionProvider {
|
|||
|
||||
pub fn set_active_provider(
|
||||
&mut self,
|
||||
provider_name: LanguageModelProviderName,
|
||||
provider_name: LanguageModelProviderId,
|
||||
cx: &mut ModelContext<Self>,
|
||||
) {
|
||||
self.active_provider = LanguageModelRegistry::read_global(cx).provider(&provider_name);
|
||||
|
@ -103,14 +103,19 @@ impl LanguageModelCompletionProvider {
|
|||
|
||||
pub fn set_active_model(&mut self, model: Arc<dyn LanguageModel>, cx: &mut ModelContext<Self>) {
|
||||
if self.active_model.as_ref().map_or(false, |m| {
|
||||
m.id() == model.id() && m.provider_name() == model.provider_name()
|
||||
m.id() == model.id() && m.provider_id() == model.provider_id()
|
||||
}) {
|
||||
return;
|
||||
}
|
||||
|
||||
self.active_provider =
|
||||
LanguageModelRegistry::read_global(cx).provider(&model.provider_name());
|
||||
self.active_model = Some(model);
|
||||
LanguageModelRegistry::read_global(cx).provider(&model.provider_id());
|
||||
self.active_model = Some(model.clone());
|
||||
|
||||
if let Some(provider) = self.active_provider.as_ref() {
|
||||
provider.load_model(model, cx);
|
||||
}
|
||||
|
||||
cx.notify();
|
||||
}
|
||||
|
||||
|
|
|
@ -25,6 +25,7 @@ pub fn init(client: Arc<Client>, cx: &mut AppContext) {
|
|||
pub trait LanguageModel: Send + Sync {
|
||||
fn id(&self) -> LanguageModelId;
|
||||
fn name(&self) -> LanguageModelName;
|
||||
fn provider_id(&self) -> LanguageModelProviderId;
|
||||
fn provider_name(&self) -> LanguageModelProviderName;
|
||||
fn telemetry_id(&self) -> String;
|
||||
|
||||
|
@ -44,8 +45,10 @@ pub trait LanguageModel: Send + Sync {
|
|||
}
|
||||
|
||||
pub trait LanguageModelProvider: 'static {
|
||||
fn id(&self) -> LanguageModelProviderId;
|
||||
fn name(&self) -> LanguageModelProviderName;
|
||||
fn provided_models(&self, cx: &AppContext) -> Vec<Arc<dyn LanguageModel>>;
|
||||
fn load_model(&self, _model: Arc<dyn LanguageModel>, _cx: &AppContext) {}
|
||||
fn is_authenticated(&self, cx: &AppContext) -> bool;
|
||||
fn authenticate(&self, cx: &AppContext) -> Task<Result<()>>;
|
||||
fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView;
|
||||
|
@ -62,6 +65,9 @@ pub struct LanguageModelId(pub SharedString);
|
|||
#[derive(Clone, Eq, PartialEq, Hash, Debug)]
|
||||
pub struct LanguageModelName(pub SharedString);
|
||||
|
||||
#[derive(Clone, Eq, PartialEq, Hash, Debug)]
|
||||
pub struct LanguageModelProviderId(pub SharedString);
|
||||
|
||||
#[derive(Clone, Eq, PartialEq, Hash, Debug)]
|
||||
pub struct LanguageModelProviderName(pub SharedString);
|
||||
|
||||
|
@ -77,6 +83,12 @@ impl From<String> for LanguageModelName {
|
|||
}
|
||||
}
|
||||
|
||||
impl From<String> for LanguageModelProviderId {
|
||||
fn from(value: String) -> Self {
|
||||
Self(SharedString::from(value))
|
||||
}
|
||||
}
|
||||
|
||||
impl From<String> for LanguageModelProviderName {
|
||||
fn from(value: String) -> Self {
|
||||
Self(SharedString::from(value))
|
||||
|
|
|
@ -1,6 +1,5 @@
|
|||
use anthropic::{stream_completion, Request, RequestMessage};
|
||||
use anyhow::{anyhow, Result};
|
||||
use collections::HashMap;
|
||||
use editor::{Editor, EditorElement, EditorStyle};
|
||||
use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
|
||||
use gpui::{
|
||||
|
@ -9,7 +8,7 @@ use gpui::{
|
|||
};
|
||||
use http_client::HttpClient;
|
||||
use settings::{Settings, SettingsStore};
|
||||
use std::{sync::Arc, time::Duration};
|
||||
use std::{collections::BTreeMap, sync::Arc, time::Duration};
|
||||
use strum::IntoEnumIterator;
|
||||
use theme::ThemeSettings;
|
||||
use ui::prelude::*;
|
||||
|
@ -17,11 +16,12 @@ use util::ResultExt;
|
|||
|
||||
use crate::{
|
||||
settings::AllLanguageModelSettings, LanguageModel, LanguageModelId, LanguageModelName,
|
||||
LanguageModelProvider, LanguageModelProviderName, LanguageModelProviderState,
|
||||
LanguageModelRequest, LanguageModelRequestMessage, Role,
|
||||
LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
|
||||
LanguageModelProviderState, LanguageModelRequest, LanguageModelRequestMessage, Role,
|
||||
};
|
||||
|
||||
const PROVIDER_NAME: &str = "anthropic";
|
||||
const PROVIDER_ID: &str = "anthropic";
|
||||
const PROVIDER_NAME: &str = "Anthropic";
|
||||
|
||||
#[derive(Default, Clone, Debug, PartialEq)]
|
||||
pub struct AnthropicSettings {
|
||||
|
@ -37,7 +37,6 @@ pub struct AnthropicLanguageModelProvider {
|
|||
|
||||
struct State {
|
||||
api_key: Option<String>,
|
||||
settings: AnthropicSettings,
|
||||
_subscription: Subscription,
|
||||
}
|
||||
|
||||
|
@ -45,9 +44,7 @@ impl AnthropicLanguageModelProvider {
|
|||
pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut AppContext) -> Self {
|
||||
let state = cx.new_model(|cx| State {
|
||||
api_key: None,
|
||||
settings: AnthropicSettings::default(),
|
||||
_subscription: cx.observe_global::<SettingsStore>(|this: &mut State, cx| {
|
||||
this.settings = AllLanguageModelSettings::get_global(cx).anthropic.clone();
|
||||
_subscription: cx.observe_global::<SettingsStore>(|_, cx| {
|
||||
cx.notify();
|
||||
}),
|
||||
});
|
||||
|
@ -64,12 +61,16 @@ impl LanguageModelProviderState for AnthropicLanguageModelProvider {
|
|||
}
|
||||
|
||||
impl LanguageModelProvider for AnthropicLanguageModelProvider {
|
||||
fn id(&self) -> LanguageModelProviderId {
|
||||
LanguageModelProviderId(PROVIDER_ID.into())
|
||||
}
|
||||
|
||||
fn name(&self) -> LanguageModelProviderName {
|
||||
LanguageModelProviderName(PROVIDER_NAME.into())
|
||||
}
|
||||
|
||||
fn provided_models(&self, cx: &AppContext) -> Vec<Arc<dyn LanguageModel>> {
|
||||
let mut models = HashMap::default();
|
||||
let mut models = BTreeMap::default();
|
||||
|
||||
// Add base models from anthropic::Model::iter()
|
||||
for model in anthropic::Model::iter() {
|
||||
|
@ -79,7 +80,11 @@ impl LanguageModelProvider for AnthropicLanguageModelProvider {
|
|||
}
|
||||
|
||||
// Override with available models from settings
|
||||
for model in &self.state.read(cx).settings.available_models {
|
||||
for model in AllLanguageModelSettings::get_global(cx)
|
||||
.anthropic
|
||||
.available_models
|
||||
.iter()
|
||||
{
|
||||
models.insert(model.id().to_string(), model.clone());
|
||||
}
|
||||
|
||||
|
@ -104,7 +109,10 @@ impl LanguageModelProvider for AnthropicLanguageModelProvider {
|
|||
if self.is_authenticated(cx) {
|
||||
Task::ready(Ok(()))
|
||||
} else {
|
||||
let api_url = self.state.read(cx).settings.api_url.clone();
|
||||
let api_url = AllLanguageModelSettings::get_global(cx)
|
||||
.anthropic
|
||||
.api_url
|
||||
.clone();
|
||||
let state = self.state.clone();
|
||||
cx.spawn(|mut cx| async move {
|
||||
let api_key = if let Ok(api_key) = std::env::var("ANTHROPIC_API_KEY") {
|
||||
|
@ -132,7 +140,8 @@ impl LanguageModelProvider for AnthropicLanguageModelProvider {
|
|||
|
||||
fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
|
||||
let state = self.state.clone();
|
||||
let delete_credentials = cx.delete_credentials(&self.state.read(cx).settings.api_url);
|
||||
let delete_credentials =
|
||||
cx.delete_credentials(&AllLanguageModelSettings::get_global(cx).anthropic.api_url);
|
||||
cx.spawn(|mut cx| async move {
|
||||
delete_credentials.await.log_err();
|
||||
state.update(&mut cx, |this, cx| {
|
||||
|
@ -221,6 +230,10 @@ impl LanguageModel for AnthropicModel {
|
|||
LanguageModelName::from(self.model.display_name().to_string())
|
||||
}
|
||||
|
||||
fn provider_id(&self) -> LanguageModelProviderId {
|
||||
LanguageModelProviderId(PROVIDER_ID.into())
|
||||
}
|
||||
|
||||
fn provider_name(&self) -> LanguageModelProviderName {
|
||||
LanguageModelProviderName(PROVIDER_NAME.into())
|
||||
}
|
||||
|
@ -249,11 +262,13 @@ impl LanguageModel for AnthropicModel {
|
|||
let request = self.to_anthropic_request(request);
|
||||
|
||||
let http_client = self.http_client.clone();
|
||||
let Ok((api_key, api_url, low_speed_timeout)) = cx.read_model(&self.state, |state, _| {
|
||||
|
||||
let Ok((api_key, api_url, low_speed_timeout)) = cx.read_model(&self.state, |state, cx| {
|
||||
let settings = &AllLanguageModelSettings::get_global(cx).anthropic;
|
||||
(
|
||||
state.api_key.clone(),
|
||||
state.settings.api_url.clone(),
|
||||
state.settings.low_speed_timeout,
|
||||
settings.api_url.clone(),
|
||||
settings.low_speed_timeout,
|
||||
)
|
||||
}) else {
|
||||
return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
|
||||
|
@ -365,7 +380,10 @@ impl AuthenticationPrompt {
|
|||
}
|
||||
|
||||
let write_credentials = cx.write_credentials(
|
||||
&self.state.read(cx).settings.api_url,
|
||||
AllLanguageModelSettings::get_global(cx)
|
||||
.anthropic
|
||||
.api_url
|
||||
.as_str(),
|
||||
"Bearer",
|
||||
api_key.as_bytes(),
|
||||
);
|
||||
|
|
|
@ -1,15 +1,15 @@
|
|||
use super::open_ai::count_open_ai_tokens;
|
||||
use crate::{
|
||||
settings::AllLanguageModelSettings, CloudModel, LanguageModel, LanguageModelId,
|
||||
LanguageModelName, LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
|
||||
LanguageModelName, LanguageModelProviderId, LanguageModelProviderName,
|
||||
LanguageModelProviderState, LanguageModelRequest,
|
||||
};
|
||||
use anyhow::Result;
|
||||
use client::Client;
|
||||
use collections::HashMap;
|
||||
use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt, TryFutureExt};
|
||||
use gpui::{AnyView, AppContext, AsyncAppContext, Subscription, Task};
|
||||
use settings::{Settings, SettingsStore};
|
||||
use std::sync::Arc;
|
||||
use std::{collections::BTreeMap, sync::Arc};
|
||||
use strum::IntoEnumIterator;
|
||||
use ui::prelude::*;
|
||||
|
||||
|
@ -17,6 +17,7 @@ use crate::LanguageModelProvider;
|
|||
|
||||
use super::anthropic::{count_anthropic_tokens, preprocess_anthropic_request};
|
||||
|
||||
pub const PROVIDER_ID: &str = "zed.dev";
|
||||
pub const PROVIDER_NAME: &str = "zed.dev";
|
||||
|
||||
#[derive(Default, Clone, Debug, PartialEq)]
|
||||
|
@ -33,7 +34,6 @@ pub struct CloudLanguageModelProvider {
|
|||
struct State {
|
||||
client: Arc<Client>,
|
||||
status: client::Status,
|
||||
settings: ZedDotDevSettings,
|
||||
_subscription: Subscription,
|
||||
}
|
||||
|
||||
|
@ -52,9 +52,7 @@ impl CloudLanguageModelProvider {
|
|||
let state = cx.new_model(|cx| State {
|
||||
client: client.clone(),
|
||||
status,
|
||||
settings: ZedDotDevSettings::default(),
|
||||
_subscription: cx.observe_global::<SettingsStore>(|this: &mut State, cx| {
|
||||
this.settings = AllLanguageModelSettings::get_global(cx).zed_dot_dev.clone();
|
||||
_subscription: cx.observe_global::<SettingsStore>(|_, cx| {
|
||||
cx.notify();
|
||||
}),
|
||||
});
|
||||
|
@ -90,12 +88,16 @@ impl LanguageModelProviderState for CloudLanguageModelProvider {
|
|||
}
|
||||
|
||||
impl LanguageModelProvider for CloudLanguageModelProvider {
|
||||
fn id(&self) -> LanguageModelProviderId {
|
||||
LanguageModelProviderId(PROVIDER_ID.into())
|
||||
}
|
||||
|
||||
fn name(&self) -> LanguageModelProviderName {
|
||||
LanguageModelProviderName(PROVIDER_NAME.into())
|
||||
}
|
||||
|
||||
fn provided_models(&self, cx: &AppContext) -> Vec<Arc<dyn LanguageModel>> {
|
||||
let mut models = HashMap::default();
|
||||
let mut models = BTreeMap::default();
|
||||
|
||||
// Add base models from CloudModel::iter()
|
||||
for model in CloudModel::iter() {
|
||||
|
@ -105,7 +107,10 @@ impl LanguageModelProvider for CloudLanguageModelProvider {
|
|||
}
|
||||
|
||||
// Override with available models from settings
|
||||
for model in &self.state.read(cx).settings.available_models {
|
||||
for model in &AllLanguageModelSettings::get_global(cx)
|
||||
.zed_dot_dev
|
||||
.available_models
|
||||
{
|
||||
models.insert(model.id().to_string(), model.clone());
|
||||
}
|
||||
|
||||
|
@ -156,6 +161,10 @@ impl LanguageModel for CloudLanguageModel {
|
|||
LanguageModelName::from(self.model.display_name().to_string())
|
||||
}
|
||||
|
||||
fn provider_id(&self) -> LanguageModelProviderId {
|
||||
LanguageModelProviderId(PROVIDER_ID.into())
|
||||
}
|
||||
|
||||
fn provider_name(&self) -> LanguageModelProviderName {
|
||||
LanguageModelProviderName(PROVIDER_NAME.into())
|
||||
}
|
||||
|
@ -187,6 +196,9 @@ impl LanguageModel for CloudLanguageModel {
|
|||
| CloudModel::Claude3Opus
|
||||
| CloudModel::Claude3Sonnet
|
||||
| CloudModel::Claude3Haiku => count_anthropic_tokens(request, cx),
|
||||
CloudModel::Custom { name, .. } if name.starts_with("anthropic/") => {
|
||||
count_anthropic_tokens(request, cx)
|
||||
}
|
||||
_ => {
|
||||
let request = self.client.request(proto::CountTokensWithLanguageModel {
|
||||
model: self.model.id().to_string(),
|
||||
|
|
|
@ -5,7 +5,8 @@ use futures::{channel::mpsc, future::BoxFuture, stream::BoxStream, FutureExt, St
|
|||
|
||||
use crate::{
|
||||
LanguageModel, LanguageModelId, LanguageModelName, LanguageModelProvider,
|
||||
LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
|
||||
LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
|
||||
LanguageModelRequest,
|
||||
};
|
||||
use gpui::{AnyView, AppContext, AsyncAppContext, Task};
|
||||
use http_client::Result;
|
||||
|
@ -19,8 +20,12 @@ pub fn language_model_name() -> LanguageModelName {
|
|||
LanguageModelName::from("Fake".to_string())
|
||||
}
|
||||
|
||||
pub fn provider_id() -> LanguageModelProviderId {
|
||||
LanguageModelProviderId::from("fake".to_string())
|
||||
}
|
||||
|
||||
pub fn provider_name() -> LanguageModelProviderName {
|
||||
LanguageModelProviderName::from("fake".to_string())
|
||||
LanguageModelProviderName::from("Fake".to_string())
|
||||
}
|
||||
|
||||
#[derive(Clone, Default)]
|
||||
|
@ -35,6 +40,10 @@ impl LanguageModelProviderState for FakeLanguageModelProvider {
|
|||
}
|
||||
|
||||
impl LanguageModelProvider for FakeLanguageModelProvider {
|
||||
fn id(&self) -> LanguageModelProviderId {
|
||||
provider_id()
|
||||
}
|
||||
|
||||
fn name(&self) -> LanguageModelProviderName {
|
||||
provider_name()
|
||||
}
|
||||
|
@ -125,6 +134,10 @@ impl LanguageModel for FakeLanguageModel {
|
|||
language_model_name()
|
||||
}
|
||||
|
||||
fn provider_id(&self) -> LanguageModelProviderId {
|
||||
provider_id()
|
||||
}
|
||||
|
||||
fn provider_name(&self) -> LanguageModelProviderName {
|
||||
provider_name()
|
||||
}
|
||||
|
|
|
@ -2,21 +2,24 @@ use anyhow::{anyhow, Result};
|
|||
use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
|
||||
use gpui::{AnyView, AppContext, AsyncAppContext, ModelContext, Subscription, Task};
|
||||
use http_client::HttpClient;
|
||||
use ollama::{get_models, stream_chat_completion, ChatMessage, ChatOptions, ChatRequest};
|
||||
use ollama::{
|
||||
get_models, preload_model, stream_chat_completion, ChatMessage, ChatOptions, ChatRequest,
|
||||
};
|
||||
use settings::{Settings, SettingsStore};
|
||||
use std::{sync::Arc, time::Duration};
|
||||
use ui::{prelude::*, ButtonLike, ElevationIndex};
|
||||
|
||||
use crate::{
|
||||
settings::AllLanguageModelSettings, LanguageModel, LanguageModelId, LanguageModelName,
|
||||
LanguageModelProvider, LanguageModelProviderName, LanguageModelProviderState,
|
||||
LanguageModelRequest, Role,
|
||||
LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
|
||||
LanguageModelProviderState, LanguageModelRequest, Role,
|
||||
};
|
||||
|
||||
const OLLAMA_DOWNLOAD_URL: &str = "https://ollama.com/download";
|
||||
const OLLAMA_LIBRARY_URL: &str = "https://ollama.com/library";
|
||||
|
||||
const PROVIDER_NAME: &str = "ollama";
|
||||
const PROVIDER_ID: &str = "ollama";
|
||||
const PROVIDER_NAME: &str = "Ollama";
|
||||
|
||||
#[derive(Default, Debug, Clone, PartialEq)]
|
||||
pub struct OllamaSettings {
|
||||
|
@ -32,14 +35,14 @@ pub struct OllamaLanguageModelProvider {
|
|||
struct State {
|
||||
http_client: Arc<dyn HttpClient>,
|
||||
available_models: Vec<ollama::Model>,
|
||||
settings: OllamaSettings,
|
||||
_subscription: Subscription,
|
||||
}
|
||||
|
||||
impl State {
|
||||
fn fetch_models(&self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
|
||||
fn fetch_models(&self, cx: &ModelContext<Self>) -> Task<Result<()>> {
|
||||
let settings = &AllLanguageModelSettings::get_global(cx).ollama;
|
||||
let http_client = self.http_client.clone();
|
||||
let api_url = self.settings.api_url.clone();
|
||||
let api_url = settings.api_url.clone();
|
||||
|
||||
// As a proxy for the server being "authenticated", we'll check if its up by fetching the models
|
||||
cx.spawn(|this, mut cx| async move {
|
||||
|
@ -66,23 +69,25 @@ impl State {
|
|||
|
||||
impl OllamaLanguageModelProvider {
|
||||
pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut AppContext) -> Self {
|
||||
Self {
|
||||
let this = Self {
|
||||
http_client: http_client.clone(),
|
||||
state: cx.new_model(|cx| State {
|
||||
http_client,
|
||||
available_models: Default::default(),
|
||||
settings: OllamaSettings::default(),
|
||||
_subscription: cx.observe_global::<SettingsStore>(|this: &mut State, cx| {
|
||||
this.settings = AllLanguageModelSettings::get_global(cx).ollama.clone();
|
||||
this.fetch_models(cx).detach_and_log_err(cx);
|
||||
cx.notify();
|
||||
}),
|
||||
}),
|
||||
}
|
||||
};
|
||||
this.fetch_models(cx).detach_and_log_err(cx);
|
||||
this
|
||||
}
|
||||
|
||||
fn fetch_models(&self, cx: &AppContext) -> Task<Result<()>> {
|
||||
let settings = &AllLanguageModelSettings::get_global(cx).ollama;
|
||||
let http_client = self.http_client.clone();
|
||||
let api_url = self.state.read(cx).settings.api_url.clone();
|
||||
let api_url = settings.api_url.clone();
|
||||
|
||||
let state = self.state.clone();
|
||||
// As a proxy for the server being "authenticated", we'll check if its up by fetching the models
|
||||
|
@ -117,6 +122,10 @@ impl LanguageModelProviderState for OllamaLanguageModelProvider {
|
|||
}
|
||||
|
||||
impl LanguageModelProvider for OllamaLanguageModelProvider {
|
||||
fn id(&self) -> LanguageModelProviderId {
|
||||
LanguageModelProviderId(PROVIDER_ID.into())
|
||||
}
|
||||
|
||||
fn name(&self) -> LanguageModelProviderName {
|
||||
LanguageModelProviderName(PROVIDER_NAME.into())
|
||||
}
|
||||
|
@ -131,12 +140,20 @@ impl LanguageModelProvider for OllamaLanguageModelProvider {
|
|||
id: LanguageModelId::from(model.name.clone()),
|
||||
model: model.clone(),
|
||||
http_client: self.http_client.clone(),
|
||||
state: self.state.clone(),
|
||||
}) as Arc<dyn LanguageModel>
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn load_model(&self, model: Arc<dyn LanguageModel>, cx: &AppContext) {
|
||||
let settings = &AllLanguageModelSettings::get_global(cx).ollama;
|
||||
let http_client = self.http_client.clone();
|
||||
let api_url = settings.api_url.clone();
|
||||
let id = model.id().0.to_string();
|
||||
cx.spawn(|_| async move { preload_model(http_client, &api_url, &id).await })
|
||||
.detach_and_log_err(cx);
|
||||
}
|
||||
|
||||
fn is_authenticated(&self, cx: &AppContext) -> bool {
|
||||
!self.state.read(cx).available_models.is_empty()
|
||||
}
|
||||
|
@ -167,7 +184,6 @@ impl LanguageModelProvider for OllamaLanguageModelProvider {
|
|||
pub struct OllamaLanguageModel {
|
||||
id: LanguageModelId,
|
||||
model: ollama::Model,
|
||||
state: gpui::Model<State>,
|
||||
http_client: Arc<dyn HttpClient>,
|
||||
}
|
||||
|
||||
|
@ -211,6 +227,14 @@ impl LanguageModel for OllamaLanguageModel {
|
|||
LanguageModelName::from(self.model.display_name().to_string())
|
||||
}
|
||||
|
||||
fn provider_id(&self) -> LanguageModelProviderId {
|
||||
LanguageModelProviderId(PROVIDER_ID.into())
|
||||
}
|
||||
|
||||
fn provider_name(&self) -> LanguageModelProviderName {
|
||||
LanguageModelProviderName(PROVIDER_NAME.into())
|
||||
}
|
||||
|
||||
fn max_token_count(&self) -> usize {
|
||||
self.model.max_token_count()
|
||||
}
|
||||
|
@ -219,10 +243,6 @@ impl LanguageModel for OllamaLanguageModel {
|
|||
format!("ollama/{}", self.model.id())
|
||||
}
|
||||
|
||||
fn provider_name(&self) -> LanguageModelProviderName {
|
||||
LanguageModelProviderName(PROVIDER_NAME.into())
|
||||
}
|
||||
|
||||
fn count_tokens(
|
||||
&self,
|
||||
request: LanguageModelRequest,
|
||||
|
@ -248,11 +268,9 @@ impl LanguageModel for OllamaLanguageModel {
|
|||
let request = self.to_ollama_request(request);
|
||||
|
||||
let http_client = self.http_client.clone();
|
||||
let Ok((api_url, low_speed_timeout)) = cx.read_model(&self.state, |state, _| {
|
||||
(
|
||||
state.settings.api_url.clone(),
|
||||
state.settings.low_speed_timeout,
|
||||
)
|
||||
let Ok((api_url, low_speed_timeout)) = cx.update(|cx| {
|
||||
let settings = &AllLanguageModelSettings::get_global(cx).ollama;
|
||||
(settings.api_url.clone(), settings.low_speed_timeout)
|
||||
}) else {
|
||||
return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
|
||||
};
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
use anyhow::{anyhow, Result};
|
||||
use collections::HashMap;
|
||||
use collections::BTreeMap;
|
||||
use editor::{Editor, EditorElement, EditorStyle};
|
||||
use futures::{future::BoxFuture, FutureExt, StreamExt};
|
||||
use gpui::{
|
||||
|
@ -17,11 +17,12 @@ use util::ResultExt;
|
|||
|
||||
use crate::{
|
||||
settings::AllLanguageModelSettings, LanguageModel, LanguageModelId, LanguageModelName,
|
||||
LanguageModelProvider, LanguageModelProviderName, LanguageModelProviderState,
|
||||
LanguageModelRequest, Role,
|
||||
LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
|
||||
LanguageModelProviderState, LanguageModelRequest, Role,
|
||||
};
|
||||
|
||||
const PROVIDER_NAME: &str = "openai";
|
||||
const PROVIDER_ID: &str = "openai";
|
||||
const PROVIDER_NAME: &str = "OpenAI";
|
||||
|
||||
#[derive(Default, Clone, Debug, PartialEq)]
|
||||
pub struct OpenAiSettings {
|
||||
|
@ -37,7 +38,6 @@ pub struct OpenAiLanguageModelProvider {
|
|||
|
||||
struct State {
|
||||
api_key: Option<String>,
|
||||
settings: OpenAiSettings,
|
||||
_subscription: Subscription,
|
||||
}
|
||||
|
||||
|
@ -45,9 +45,7 @@ impl OpenAiLanguageModelProvider {
|
|||
pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut AppContext) -> Self {
|
||||
let state = cx.new_model(|cx| State {
|
||||
api_key: None,
|
||||
settings: OpenAiSettings::default(),
|
||||
_subscription: cx.observe_global::<SettingsStore>(|this: &mut State, cx| {
|
||||
this.settings = AllLanguageModelSettings::get_global(cx).open_ai.clone();
|
||||
_subscription: cx.observe_global::<SettingsStore>(|_this: &mut State, cx| {
|
||||
cx.notify();
|
||||
}),
|
||||
});
|
||||
|
@ -65,12 +63,16 @@ impl LanguageModelProviderState for OpenAiLanguageModelProvider {
|
|||
}
|
||||
|
||||
impl LanguageModelProvider for OpenAiLanguageModelProvider {
|
||||
fn id(&self) -> LanguageModelProviderId {
|
||||
LanguageModelProviderId(PROVIDER_ID.into())
|
||||
}
|
||||
|
||||
fn name(&self) -> LanguageModelProviderName {
|
||||
LanguageModelProviderName(PROVIDER_NAME.into())
|
||||
}
|
||||
|
||||
fn provided_models(&self, cx: &AppContext) -> Vec<Arc<dyn LanguageModel>> {
|
||||
let mut models = HashMap::default();
|
||||
let mut models = BTreeMap::default();
|
||||
|
||||
// Add base models from open_ai::Model::iter()
|
||||
for model in open_ai::Model::iter() {
|
||||
|
@ -80,7 +82,10 @@ impl LanguageModelProvider for OpenAiLanguageModelProvider {
|
|||
}
|
||||
|
||||
// Override with available models from settings
|
||||
for model in &self.state.read(cx).settings.available_models {
|
||||
for model in &AllLanguageModelSettings::get_global(cx)
|
||||
.openai
|
||||
.available_models
|
||||
{
|
||||
models.insert(model.id().to_string(), model.clone());
|
||||
}
|
||||
|
||||
|
@ -105,7 +110,10 @@ impl LanguageModelProvider for OpenAiLanguageModelProvider {
|
|||
if self.is_authenticated(cx) {
|
||||
Task::ready(Ok(()))
|
||||
} else {
|
||||
let api_url = self.state.read(cx).settings.api_url.clone();
|
||||
let api_url = AllLanguageModelSettings::get_global(cx)
|
||||
.openai
|
||||
.api_url
|
||||
.clone();
|
||||
let state = self.state.clone();
|
||||
cx.spawn(|mut cx| async move {
|
||||
let api_key = if let Ok(api_key) = std::env::var("OPENAI_API_KEY") {
|
||||
|
@ -131,7 +139,8 @@ impl LanguageModelProvider for OpenAiLanguageModelProvider {
|
|||
}
|
||||
|
||||
fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
|
||||
let delete_credentials = cx.delete_credentials(&self.state.read(cx).settings.api_url);
|
||||
let settings = &AllLanguageModelSettings::get_global(cx).openai;
|
||||
let delete_credentials = cx.delete_credentials(&settings.api_url);
|
||||
let state = self.state.clone();
|
||||
cx.spawn(|mut cx| async move {
|
||||
delete_credentials.await.log_err();
|
||||
|
@ -188,6 +197,10 @@ impl LanguageModel for OpenAiLanguageModel {
|
|||
LanguageModelName::from(self.model.display_name().to_string())
|
||||
}
|
||||
|
||||
fn provider_id(&self) -> LanguageModelProviderId {
|
||||
LanguageModelProviderId(PROVIDER_ID.into())
|
||||
}
|
||||
|
||||
fn provider_name(&self) -> LanguageModelProviderName {
|
||||
LanguageModelProviderName(PROVIDER_NAME.into())
|
||||
}
|
||||
|
@ -216,11 +229,12 @@ impl LanguageModel for OpenAiLanguageModel {
|
|||
let request = self.to_open_ai_request(request);
|
||||
|
||||
let http_client = self.http_client.clone();
|
||||
let Ok((api_key, api_url, low_speed_timeout)) = cx.read_model(&self.state, |state, _| {
|
||||
let Ok((api_key, api_url, low_speed_timeout)) = cx.read_model(&self.state, |state, cx| {
|
||||
let settings = &AllLanguageModelSettings::get_global(cx).openai;
|
||||
(
|
||||
state.api_key.clone(),
|
||||
state.settings.api_url.clone(),
|
||||
state.settings.low_speed_timeout,
|
||||
settings.api_url.clone(),
|
||||
settings.low_speed_timeout,
|
||||
)
|
||||
}) else {
|
||||
return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
|
||||
|
@ -307,11 +321,9 @@ impl AuthenticationPrompt {
|
|||
return;
|
||||
}
|
||||
|
||||
let write_credentials = cx.write_credentials(
|
||||
&self.state.read(cx).settings.api_url,
|
||||
"Bearer",
|
||||
api_key.as_bytes(),
|
||||
);
|
||||
let settings = &AllLanguageModelSettings::get_global(cx).openai;
|
||||
let write_credentials =
|
||||
cx.write_credentials(&settings.api_url, "Bearer", api_key.as_bytes());
|
||||
let state = self.state.clone();
|
||||
cx.spawn(|_, mut cx| async move {
|
||||
write_credentials.await?;
|
||||
|
|
|
@ -9,7 +9,7 @@ use crate::{
|
|||
anthropic::AnthropicLanguageModelProvider, cloud::CloudLanguageModelProvider,
|
||||
ollama::OllamaLanguageModelProvider, open_ai::OpenAiLanguageModelProvider,
|
||||
},
|
||||
LanguageModel, LanguageModelProvider, LanguageModelProviderName, LanguageModelProviderState,
|
||||
LanguageModel, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderState,
|
||||
};
|
||||
|
||||
pub fn init(client: Arc<Client>, cx: &mut AppContext) {
|
||||
|
@ -48,7 +48,7 @@ fn register_language_model_providers(
|
|||
registry.register_provider(CloudLanguageModelProvider::new(client.clone(), cx), cx);
|
||||
} else {
|
||||
registry.unregister_provider(
|
||||
&LanguageModelProviderName::from(
|
||||
&LanguageModelProviderId::from(
|
||||
crate::provider::cloud::PROVIDER_NAME.to_string(),
|
||||
),
|
||||
cx,
|
||||
|
@ -65,7 +65,7 @@ impl Global for GlobalLanguageModelRegistry {}
|
|||
|
||||
#[derive(Default)]
|
||||
pub struct LanguageModelRegistry {
|
||||
providers: HashMap<LanguageModelProviderName, Arc<dyn LanguageModelProvider>>,
|
||||
providers: HashMap<LanguageModelProviderId, Arc<dyn LanguageModelProvider>>,
|
||||
}
|
||||
|
||||
impl LanguageModelRegistry {
|
||||
|
@ -94,7 +94,7 @@ impl LanguageModelRegistry {
|
|||
provider: T,
|
||||
cx: &mut ModelContext<Self>,
|
||||
) {
|
||||
let name = provider.name();
|
||||
let name = provider.id();
|
||||
|
||||
if let Some(subscription) = provider.subscribe(cx) {
|
||||
subscription.detach();
|
||||
|
@ -106,7 +106,7 @@ impl LanguageModelRegistry {
|
|||
|
||||
pub fn unregister_provider(
|
||||
&mut self,
|
||||
name: &LanguageModelProviderName,
|
||||
name: &LanguageModelProviderId,
|
||||
cx: &mut ModelContext<Self>,
|
||||
) {
|
||||
if self.providers.remove(name).is_some() {
|
||||
|
@ -116,7 +116,7 @@ impl LanguageModelRegistry {
|
|||
|
||||
pub fn providers(
|
||||
&self,
|
||||
) -> impl Iterator<Item = (&LanguageModelProviderName, &Arc<dyn LanguageModelProvider>)> {
|
||||
) -> impl Iterator<Item = (&LanguageModelProviderId, &Arc<dyn LanguageModelProvider>)> {
|
||||
self.providers.iter()
|
||||
}
|
||||
|
||||
|
@ -130,7 +130,7 @@ impl LanguageModelRegistry {
|
|||
pub fn available_models_grouped_by_provider(
|
||||
&self,
|
||||
cx: &AppContext,
|
||||
) -> HashMap<LanguageModelProviderName, Vec<Arc<dyn LanguageModel>>> {
|
||||
) -> HashMap<LanguageModelProviderId, Vec<Arc<dyn LanguageModel>>> {
|
||||
self.providers
|
||||
.iter()
|
||||
.map(|(name, provider)| (name.clone(), provider.provided_models(cx)))
|
||||
|
@ -139,7 +139,7 @@ impl LanguageModelRegistry {
|
|||
|
||||
pub fn provider(
|
||||
&self,
|
||||
name: &LanguageModelProviderName,
|
||||
name: &LanguageModelProviderId,
|
||||
) -> Option<Arc<dyn LanguageModelProvider>> {
|
||||
self.providers.get(name).cloned()
|
||||
}
|
||||
|
@ -160,10 +160,10 @@ mod tests {
|
|||
|
||||
let providers = registry.read(cx).providers().collect::<Vec<_>>();
|
||||
assert_eq!(providers.len(), 1);
|
||||
assert_eq!(providers[0].0, &crate::provider::fake::provider_name());
|
||||
assert_eq!(providers[0].0, &crate::provider::fake::provider_id());
|
||||
|
||||
registry.update(cx, |registry, cx| {
|
||||
registry.unregister_provider(&crate::provider::fake::provider_name(), cx);
|
||||
registry.unregister_provider(&crate::provider::fake::provider_id(), cx);
|
||||
});
|
||||
|
||||
let providers = registry.read(cx).providers().collect::<Vec<_>>();
|
||||
|
|
|
@ -21,9 +21,9 @@ pub fn init(cx: &mut AppContext) {
|
|||
|
||||
#[derive(Default)]
|
||||
pub struct AllLanguageModelSettings {
|
||||
pub open_ai: OpenAiSettings,
|
||||
pub anthropic: AnthropicSettings,
|
||||
pub ollama: OllamaSettings,
|
||||
pub openai: OpenAiSettings,
|
||||
pub zed_dot_dev: ZedDotDevSettings,
|
||||
}
|
||||
|
||||
|
@ -31,7 +31,7 @@ pub struct AllLanguageModelSettings {
|
|||
pub struct AllLanguageModelSettingsContent {
|
||||
pub anthropic: Option<AnthropicSettingsContent>,
|
||||
pub ollama: Option<OllamaSettingsContent>,
|
||||
pub open_ai: Option<OpenAiSettingsContent>,
|
||||
pub openai: Option<OpenAiSettingsContent>,
|
||||
#[serde(rename = "zed.dev")]
|
||||
pub zed_dot_dev: Option<ZedDotDevSettingsContent>,
|
||||
}
|
||||
|
@ -110,21 +110,21 @@ impl settings::Settings for AllLanguageModelSettings {
|
|||
}
|
||||
|
||||
merge(
|
||||
&mut settings.open_ai.api_url,
|
||||
value.open_ai.as_ref().and_then(|s| s.api_url.clone()),
|
||||
&mut settings.openai.api_url,
|
||||
value.openai.as_ref().and_then(|s| s.api_url.clone()),
|
||||
);
|
||||
if let Some(low_speed_timeout_in_seconds) = value
|
||||
.open_ai
|
||||
.openai
|
||||
.as_ref()
|
||||
.and_then(|s| s.low_speed_timeout_in_seconds)
|
||||
{
|
||||
settings.open_ai.low_speed_timeout =
|
||||
settings.openai.low_speed_timeout =
|
||||
Some(Duration::from_secs(low_speed_timeout_in_seconds));
|
||||
}
|
||||
merge(
|
||||
&mut settings.open_ai.available_models,
|
||||
&mut settings.openai.available_models,
|
||||
value
|
||||
.open_ai
|
||||
.openai
|
||||
.as_ref()
|
||||
.and_then(|s| s.available_models.clone()),
|
||||
);
|
||||
|
|
|
@ -4,7 +4,7 @@ use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
|
|||
use isahc::config::Configurable;
|
||||
use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::{convert::TryFrom, time::Duration};
|
||||
use std::{convert::TryFrom, sync::Arc, time::Duration};
|
||||
|
||||
pub const OLLAMA_API_URL: &str = "http://localhost:11434";
|
||||
|
||||
|
@ -243,7 +243,7 @@ pub async fn get_models(
|
|||
}
|
||||
|
||||
/// Sends an empty request to Ollama to trigger loading the model
|
||||
pub async fn preload_model(client: &dyn HttpClient, api_url: &str, model: &str) -> Result<()> {
|
||||
pub async fn preload_model(client: Arc<dyn HttpClient>, api_url: &str, model: &str) -> Result<()> {
|
||||
let uri = format!("{api_url}/api/generate");
|
||||
let request = HttpRequest::builder()
|
||||
.method(Method::POST)
|
||||
|
|
|
@ -85,12 +85,8 @@ To do so, add the following to your Zed `settings.json`:
|
|||
|
||||
```json
|
||||
{
|
||||
"assistant": {
|
||||
"version": "1",
|
||||
"provider": {
|
||||
"name": "openai",
|
||||
"type": "openai",
|
||||
"default_model": "gpt-4-turbo-preview",
|
||||
"language_models": {
|
||||
"openai": {
|
||||
"api_url": "http://localhost:11434/v1"
|
||||
}
|
||||
}
|
||||
|
@ -103,51 +99,32 @@ The custom URL here is `http://localhost:11434/v1`.
|
|||
|
||||
You can use Ollama with the Zed assistant by making Ollama appear as an OpenAPI endpoint.
|
||||
|
||||
1. Add the following to your Zed `settings.json`:
|
||||
1. Download, for example, the `mistral` model with Ollama:
|
||||
```
|
||||
ollama pull mistral
|
||||
```
|
||||
2. Make sure that the Ollama server is running. You can start it either via running the Ollama app, or launching:
|
||||
```
|
||||
ollama serve
|
||||
```
|
||||
3. In the assistant panel, select one of the Ollama models using the model dropdown.
|
||||
4. (Optional) If you want to change the default url that is used to access the Ollama server, you can do so by adding the following settings:
|
||||
|
||||
```json
|
||||
{
|
||||
"assistant": {
|
||||
"version": "1",
|
||||
"provider": {
|
||||
"name": "openai",
|
||||
"type": "openai",
|
||||
"default_model": "gpt-4-turbo-preview",
|
||||
"api_url": "http://localhost:11434/v1"
|
||||
}
|
||||
```json
|
||||
{
|
||||
"language_models": {
|
||||
"ollama": {
|
||||
"api_url": "http://localhost:11434"
|
||||
}
|
||||
}
|
||||
```
|
||||
2. Download, for example, the `mistral` model with Ollama:
|
||||
```
|
||||
ollama run mistral
|
||||
```
|
||||
3. Copy the model and change its name to match the model in the Zed `settings.json`:
|
||||
```
|
||||
ollama cp mistral gpt-4-turbo-preview
|
||||
```
|
||||
4. Use `assistant: reset key` (see the [Setup](#setup) section above) and enter the following API key:
|
||||
```
|
||||
ollama
|
||||
```
|
||||
5. Restart Zed
|
||||
}
|
||||
```
|
||||
|
||||
### Using Claude 3.5 Sonnet
|
||||
|
||||
You can use Claude with the Zed assistant by adding the following settings:
|
||||
You can use Claude with the Zed assistant by choosing it via the model dropdown in the assistant panel.
|
||||
|
||||
```json
|
||||
"assistant": {
|
||||
"version": "1",
|
||||
"provider": {
|
||||
"default_model": "claude-3-5-sonnet",
|
||||
"name": "anthropic"
|
||||
}
|
||||
},
|
||||
```
|
||||
|
||||
When you save the settings, the assistant panel will open and ask you to add your Anthropic API key.
|
||||
You need can obtain this key [here](https://console.anthropic.com/settings/keys).
|
||||
You need can obtain an API key [here](https://console.anthropic.com/settings/keys).
|
||||
|
||||
Even if you pay for Claude Pro, you will still have to [pay for additional credits](https://console.anthropic.com/settings/plans) to use it via the API.
|
||||
|
||||
|
|
Loading…
Reference in a new issue