Extract completion provider crate (#14823)

We will soon need `semantic_index` to be able to use
`CompletionProvider`. This is currently impossible due to a cyclic crate
dependency, because `CompletionProvider` lives in the `assistant` crate,
which depends on `semantic_index`.

This PR breaks the dependency cycle by extracting two crates out of
`assistant`: `language_model` and `completion`.

Only one piece of logic changed: [this
code](922fcaf5a6 (diff-3857b3707687a4d585f1200eec4c34a7a079eae8d303b4ce5b4fce46234ace9fR61-R69)).
* As of https://github.com/zed-industries/zed/pull/13276, whenever we
ask a given completion provider for its available models, OpenAI
providers would go and ask the global assistant settings whether the
user had configured an `available_models` setting, and if so, return
that.
* This PR changes it so that instead of eagerly asking the assistant
settings for this info (the new crate must not depend on `assistant`, or
else the dependency cycle would be back), OpenAI completion providers
now store the user-configured settings as part of their struct, and
whenever the settings change, we update the provider.

In theory, this change should not change user-visible behavior...but
since it's the only change in this large PR that's more than just moving
code around, I'm mentioning it here in case there's an unexpected
regression in practice! (cc @amtoaer in case you'd like to try out this
branch and verify that the feature is still working the way you expect.)

Release Notes:

- N/A

---------

Co-authored-by: Marshall Bowers <elliott.codes@gmail.com>
This commit is contained in:
Richard Feldman 2024-07-19 13:35:34 -04:00 committed by GitHub
parent b9a53ffa0b
commit ec487d8f64
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
30 changed files with 820 additions and 610 deletions

64
Cargo.lock generated
View file

@ -382,6 +382,7 @@ dependencies = [
"clock",
"collections",
"command_palette_hooks",
"completion",
"ctor",
"editor",
"env_logger",
@ -396,6 +397,7 @@ dependencies = [
"indexed_docs",
"indoc",
"language",
"language_model",
"log",
"menu",
"multi_buffer",
@ -418,13 +420,11 @@ dependencies = [
"settings",
"similar",
"smol",
"strum",
"telemetry_events",
"terminal",
"terminal_view",
"text",
"theme",
"tiktoken-rs",
"toml 0.8.10",
"ui",
"unindent",
@ -2491,6 +2491,7 @@ dependencies = [
"clock",
"collab_ui",
"collections",
"completion",
"ctor",
"dashmap",
"dev_server_projects",
@ -2673,6 +2674,42 @@ dependencies = [
"gpui",
]
[[package]]
name = "completion"
version = "0.1.0"
dependencies = [
"anthropic",
"anyhow",
"client",
"collections",
"ctor",
"editor",
"env_logger",
"futures 0.3.28",
"gpui",
"http 0.1.0",
"language",
"language_model",
"log",
"menu",
"ollama",
"open_ai",
"parking_lot",
"project",
"rand 0.8.5",
"serde",
"serde_json",
"settings",
"smol",
"strum",
"text",
"theme",
"tiktoken-rs",
"ui",
"unindent",
"util",
]
[[package]]
name = "concurrent-queue"
version = "2.2.0"
@ -5996,6 +6033,28 @@ dependencies = [
"util",
]
[[package]]
name = "language_model"
version = "0.1.0"
dependencies = [
"anthropic",
"ctor",
"editor",
"env_logger",
"language",
"log",
"ollama",
"open_ai",
"project",
"proto",
"rand 0.8.5",
"schemars",
"serde",
"strum",
"text",
"unindent",
]
[[package]]
name = "language_selector"
version = "0.1.0"
@ -9510,6 +9569,7 @@ dependencies = [
"client",
"clock",
"collections",
"completion",
"env_logger",
"fs",
"futures 0.3.28",

View file

@ -19,6 +19,7 @@ members = [
"crates/collections",
"crates/command_palette",
"crates/command_palette_hooks",
"crates/completion",
"crates/copilot",
"crates/db",
"crates/dev_server_projects",
@ -50,6 +51,7 @@ members = [
"crates/install_cli",
"crates/journal",
"crates/language",
"crates/language_model",
"crates/language_selector",
"crates/language_tools",
"crates/languages",
@ -176,6 +178,7 @@ collab_ui = { path = "crates/collab_ui" }
collections = { path = "crates/collections" }
command_palette = { path = "crates/command_palette" }
command_palette_hooks = { path = "crates/command_palette_hooks" }
completion = { path = "crates/completion" }
copilot = { path = "crates/copilot" }
db = { path = "crates/db" }
dev_server_projects = { path = "crates/dev_server_projects" }
@ -205,6 +208,7 @@ inline_completion_button = { path = "crates/inline_completion_button" }
install_cli = { path = "crates/install_cli" }
journal = { path = "crates/journal" }
language = { path = "crates/language" }
language_model = { path = "crates/language_model" }
language_selector = { path = "crates/language_selector" }
language_tools = { path = "crates/language_tools" }
languages = { path = "crates/languages" }

View file

@ -33,6 +33,7 @@ client.workspace = true
clock.workspace = true
collections.workspace = true
command_palette_hooks.workspace = true
completion.workspace = true
editor.workspace = true
feature_flags.workspace = true
fs.workspace = true
@ -45,6 +46,7 @@ http.workspace = true
indexed_docs.workspace = true
indoc.workspace = true
language.workspace = true
language_model.workspace = true
log.workspace = true
menu.workspace = true
multi_buffer.workspace = true
@ -64,12 +66,10 @@ serde_json.workspace = true
settings.workspace = true
similar.workspace = true
smol.workspace = true
strum.workspace = true
telemetry_events.workspace = true
terminal.workspace = true
terminal_view.workspace = true
theme.workspace = true
tiktoken-rs.workspace = true
toml.workspace = true
ui.workspace = true
util.workspace = true
@ -79,6 +79,7 @@ picker.workspace = true
roxmltree = "0.20.0"
[dev-dependencies]
completion = { workspace = true, features = ["test-support"] }
ctor.workspace = true
editor = { workspace = true, features = ["test-support"] }
env_logger.workspace = true

View file

@ -1,6 +1,5 @@
pub mod assistant_panel;
pub mod assistant_settings;
mod completion_provider;
mod context;
pub mod context_store;
mod inline_assistant;
@ -12,17 +11,20 @@ mod streaming_diff;
mod terminal_inline_assistant;
pub use assistant_panel::{AssistantPanel, AssistantPanelEvent};
use assistant_settings::{AnthropicModel, AssistantSettings, CloudModel, OllamaModel, OpenAiModel};
use assistant_settings::AssistantSettings;
use assistant_slash_command::SlashCommandRegistry;
use client::{proto, Client};
use command_palette_hooks::CommandPaletteFilter;
pub use completion_provider::*;
use completion::CompletionProvider;
pub use context::*;
pub use context_store::*;
use fs::Fs;
use gpui::{actions, impl_actions, AppContext, Global, SharedString, UpdateGlobal};
use gpui::{
actions, impl_actions, AppContext, BorrowAppContext, Global, SharedString, UpdateGlobal,
};
use indexed_docs::IndexedDocsRegistry;
pub(crate) use inline_assistant::*;
use language_model::LanguageModelResponseMessage;
pub(crate) use model_selector::*;
use semantic_index::{CloudEmbeddingProvider, SemanticIndex};
use serde::{Deserialize, Serialize};
@ -32,10 +34,7 @@ use slash_command::{
file_command, now_command, project_command, prompt_command, search_command, symbols_command,
tabs_command, term_command,
};
use std::{
fmt::{self, Display},
sync::Arc,
};
use std::sync::Arc;
pub(crate) use streaming_diff::*;
actions!(
@ -73,166 +72,6 @@ impl MessageId {
}
}
#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
User,
Assistant,
System,
}
impl Role {
pub fn from_proto(role: i32) -> Role {
match proto::LanguageModelRole::from_i32(role) {
Some(proto::LanguageModelRole::LanguageModelUser) => Role::User,
Some(proto::LanguageModelRole::LanguageModelAssistant) => Role::Assistant,
Some(proto::LanguageModelRole::LanguageModelSystem) => Role::System,
Some(proto::LanguageModelRole::LanguageModelTool) => Role::System,
None => Role::User,
}
}
pub fn to_proto(&self) -> proto::LanguageModelRole {
match self {
Role::User => proto::LanguageModelRole::LanguageModelUser,
Role::Assistant => proto::LanguageModelRole::LanguageModelAssistant,
Role::System => proto::LanguageModelRole::LanguageModelSystem,
}
}
pub fn cycle(self) -> Role {
match self {
Role::User => Role::Assistant,
Role::Assistant => Role::System,
Role::System => Role::User,
}
}
}
impl Display for Role {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Role::User => write!(f, "user"),
Role::Assistant => write!(f, "assistant"),
Role::System => write!(f, "system"),
}
}
}
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
pub enum LanguageModel {
Cloud(CloudModel),
OpenAi(OpenAiModel),
Anthropic(AnthropicModel),
Ollama(OllamaModel),
}
impl Default for LanguageModel {
fn default() -> Self {
LanguageModel::Cloud(CloudModel::default())
}
}
impl LanguageModel {
pub fn telemetry_id(&self) -> String {
match self {
LanguageModel::OpenAi(model) => format!("openai/{}", model.id()),
LanguageModel::Anthropic(model) => format!("anthropic/{}", model.id()),
LanguageModel::Cloud(model) => format!("zed.dev/{}", model.id()),
LanguageModel::Ollama(model) => format!("ollama/{}", model.id()),
}
}
pub fn display_name(&self) -> String {
match self {
LanguageModel::OpenAi(model) => model.display_name().into(),
LanguageModel::Anthropic(model) => model.display_name().into(),
LanguageModel::Cloud(model) => model.display_name().into(),
LanguageModel::Ollama(model) => model.display_name().into(),
}
}
pub fn max_token_count(&self) -> usize {
match self {
LanguageModel::OpenAi(model) => model.max_token_count(),
LanguageModel::Anthropic(model) => model.max_token_count(),
LanguageModel::Cloud(model) => model.max_token_count(),
LanguageModel::Ollama(model) => model.max_token_count(),
}
}
pub fn id(&self) -> &str {
match self {
LanguageModel::OpenAi(model) => model.id(),
LanguageModel::Anthropic(model) => model.id(),
LanguageModel::Cloud(model) => model.id(),
LanguageModel::Ollama(model) => model.id(),
}
}
}
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct LanguageModelRequestMessage {
pub role: Role,
pub content: String,
}
impl LanguageModelRequestMessage {
pub fn to_proto(&self) -> proto::LanguageModelRequestMessage {
proto::LanguageModelRequestMessage {
role: self.role.to_proto() as i32,
content: self.content.clone(),
tool_calls: Vec::new(),
tool_call_id: None,
}
}
}
#[derive(Debug, Default, Serialize, Deserialize)]
pub struct LanguageModelRequest {
pub model: LanguageModel,
pub messages: Vec<LanguageModelRequestMessage>,
pub stop: Vec<String>,
pub temperature: f32,
}
impl LanguageModelRequest {
pub fn to_proto(&self) -> proto::CompleteWithLanguageModel {
proto::CompleteWithLanguageModel {
model: self.model.id().to_string(),
messages: self.messages.iter().map(|m| m.to_proto()).collect(),
stop: self.stop.clone(),
temperature: self.temperature,
tool_choice: None,
tools: Vec::new(),
}
}
/// Before we send the request to the server, we can perform fixups on it appropriate to the model.
pub fn preprocess(&mut self) {
match &self.model {
LanguageModel::OpenAi(_) => {}
LanguageModel::Anthropic(_) => {}
LanguageModel::Ollama(_) => {}
LanguageModel::Cloud(model) => match model {
CloudModel::Claude3Opus
| CloudModel::Claude3Sonnet
| CloudModel::Claude3Haiku
| CloudModel::Claude3_5Sonnet => {
preprocess_anthropic_request(self);
}
_ => {}
},
}
}
}
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct LanguageModelResponseMessage {
pub role: Option<Role>,
pub content: Option<String>,
}
#[derive(Deserialize, Debug)]
pub struct LanguageModelUsage {
pub prompt_tokens: u32,
@ -343,7 +182,7 @@ pub fn init(fs: Arc<dyn Fs>, client: Arc<Client>, cx: &mut AppContext) {
context_store::init(&client);
prompt_library::init(cx);
completion_provider::init(client.clone(), cx);
init_completion_provider(Arc::clone(&client), cx);
assistant_slash_command::init(cx);
register_slash_commands(cx);
assistant_panel::init(cx);
@ -368,6 +207,20 @@ pub fn init(fs: Arc<dyn Fs>, client: Arc<Client>, cx: &mut AppContext) {
.detach();
}
fn init_completion_provider(client: Arc<Client>, cx: &mut AppContext) {
let provider = assistant_settings::create_provider_from_settings(client.clone(), 0, cx);
cx.set_global(CompletionProvider::new(provider, Some(client)));
let mut settings_version = 0;
cx.observe_global::<SettingsStore>(move |cx| {
settings_version += 1;
cx.update_global::<CompletionProvider, _>(|provider, cx| {
assistant_settings::update_completion_provider_settings(provider, settings_version, cx);
})
})
.detach();
}
fn register_slash_commands(cx: &mut AppContext) {
let slash_command_registry = SlashCommandRegistry::global(cx);
slash_command_registry.register_command(file_command::FileSlashCommand, true);

View file

@ -8,18 +8,18 @@ use crate::{
SlashCommandCompletionProvider, SlashCommandRegistry,
},
terminal_inline_assistant::TerminalInlineAssistant,
Assist, CompletionProvider, ConfirmCommand, Context, ContextEvent, ContextId, ContextStore,
CycleMessageRole, DebugEditSteps, DeployHistory, DeployPromptLibrary, EditStep,
EditStepOperations, EditSuggestionGroup, InlineAssist, InlineAssistId, InlineAssistant,
InsertIntoEditor, MessageStatus, ModelSelector, PendingSlashCommand, PendingSlashCommandStatus,
QuoteSelection, RemoteContextMetadata, ResetKey, Role, SavedContextMetadata, Split,
ToggleFocus, ToggleModelSelector,
Assist, ConfirmCommand, Context, ContextEvent, ContextId, ContextStore, CycleMessageRole,
DebugEditSteps, DeployHistory, DeployPromptLibrary, EditStep, EditStepOperations,
EditSuggestionGroup, InlineAssist, InlineAssistId, InlineAssistant, InsertIntoEditor,
MessageStatus, ModelSelector, PendingSlashCommand, PendingSlashCommandStatus, QuoteSelection,
RemoteContextMetadata, ResetKey, SavedContextMetadata, Split, ToggleFocus, ToggleModelSelector,
};
use anyhow::{anyhow, Result};
use assistant_slash_command::{SlashCommand, SlashCommandOutputSection};
use breadcrumbs::Breadcrumbs;
use client::proto;
use collections::{BTreeSet, HashMap, HashSet};
use completion::CompletionProvider;
use editor::{
actions::{FoldAt, MoveToEndOfLine, Newline, ShowCompletions, UnfoldAt},
display_map::{
@ -43,6 +43,7 @@ use language::{
language_settings::SoftWrap, Buffer, Capability, LanguageRegistry, LspAdapterDelegate, Point,
ToOffset,
};
use language_model::Role;
use multi_buffer::MultiBufferRow;
use picker::{Picker, PickerDelegate};
use project::{Project, ProjectLspAdapterDelegate};

View file

@ -1,166 +1,19 @@
use std::fmt;
use std::{sync::Arc, time::Duration};
use crate::{preprocess_anthropic_request, LanguageModel, LanguageModelRequest};
pub use anthropic::Model as AnthropicModel;
use gpui::Pixels;
pub use ollama::Model as OllamaModel;
pub use open_ai::Model as OpenAiModel;
use schemars::{
schema::{InstanceType, Metadata, Schema, SchemaObject},
JsonSchema,
};
use serde::{
de::{self, Visitor},
Deserialize, Deserializer, Serialize, Serializer,
use anthropic::Model as AnthropicModel;
use client::Client;
use completion::{
AnthropicCompletionProvider, CloudCompletionProvider, CompletionProvider,
LanguageModelCompletionProvider, OllamaCompletionProvider, OpenAiCompletionProvider,
};
use gpui::{AppContext, Pixels};
use language_model::{CloudModel, LanguageModel};
use ollama::Model as OllamaModel;
use open_ai::Model as OpenAiModel;
use parking_lot::RwLock;
use schemars::{schema::Schema, JsonSchema};
use serde::{Deserialize, Serialize};
use settings::{Settings, SettingsSources};
use strum::{EnumIter, IntoEnumIterator};
#[derive(Clone, Debug, Default, PartialEq, EnumIter)]
pub enum CloudModel {
Gpt3Point5Turbo,
Gpt4,
Gpt4Turbo,
#[default]
Gpt4Omni,
Gpt4OmniMini,
Claude3_5Sonnet,
Claude3Opus,
Claude3Sonnet,
Claude3Haiku,
Gemini15Pro,
Gemini15Flash,
Custom(String),
}
impl Serialize for CloudModel {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
serializer.serialize_str(self.id())
}
}
impl<'de> Deserialize<'de> for CloudModel {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
struct ZedDotDevModelVisitor;
impl<'de> Visitor<'de> for ZedDotDevModelVisitor {
type Value = CloudModel;
fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
formatter.write_str("a string for a ZedDotDevModel variant or a custom model")
}
fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
where
E: de::Error,
{
let model = CloudModel::iter()
.find(|model| model.id() == value)
.unwrap_or_else(|| CloudModel::Custom(value.to_string()));
Ok(model)
}
}
deserializer.deserialize_str(ZedDotDevModelVisitor)
}
}
impl JsonSchema for CloudModel {
fn schema_name() -> String {
"ZedDotDevModel".to_owned()
}
fn json_schema(_generator: &mut schemars::gen::SchemaGenerator) -> Schema {
let variants = CloudModel::iter()
.filter_map(|model| {
let id = model.id();
if id.is_empty() {
None
} else {
Some(id.to_string())
}
})
.collect::<Vec<_>>();
Schema::Object(SchemaObject {
instance_type: Some(InstanceType::String.into()),
enum_values: Some(variants.iter().map(|s| s.clone().into()).collect()),
metadata: Some(Box::new(Metadata {
title: Some("ZedDotDevModel".to_owned()),
default: Some(CloudModel::default().id().into()),
examples: variants.into_iter().map(Into::into).collect(),
..Default::default()
})),
..Default::default()
})
}
}
impl CloudModel {
pub fn id(&self) -> &str {
match self {
Self::Gpt3Point5Turbo => "gpt-3.5-turbo",
Self::Gpt4 => "gpt-4",
Self::Gpt4Turbo => "gpt-4-turbo-preview",
Self::Gpt4Omni => "gpt-4o",
Self::Gpt4OmniMini => "gpt-4o-mini",
Self::Claude3_5Sonnet => "claude-3-5-sonnet",
Self::Claude3Opus => "claude-3-opus",
Self::Claude3Sonnet => "claude-3-sonnet",
Self::Claude3Haiku => "claude-3-haiku",
Self::Gemini15Pro => "gemini-1.5-pro",
Self::Gemini15Flash => "gemini-1.5-flash",
Self::Custom(id) => id,
}
}
pub fn display_name(&self) -> &str {
match self {
Self::Gpt3Point5Turbo => "GPT 3.5 Turbo",
Self::Gpt4 => "GPT 4",
Self::Gpt4Turbo => "GPT 4 Turbo",
Self::Gpt4Omni => "GPT 4 Omni",
Self::Gpt4OmniMini => "GPT 4 Omni Mini",
Self::Claude3_5Sonnet => "Claude 3.5 Sonnet",
Self::Claude3Opus => "Claude 3 Opus",
Self::Claude3Sonnet => "Claude 3 Sonnet",
Self::Claude3Haiku => "Claude 3 Haiku",
Self::Gemini15Pro => "Gemini 1.5 Pro",
Self::Gemini15Flash => "Gemini 1.5 Flash",
Self::Custom(id) => id.as_str(),
}
}
pub fn max_token_count(&self) -> usize {
match self {
Self::Gpt3Point5Turbo => 2048,
Self::Gpt4 => 4096,
Self::Gpt4Turbo | Self::Gpt4Omni => 128000,
Self::Gpt4OmniMini => 128000,
Self::Claude3_5Sonnet
| Self::Claude3Opus
| Self::Claude3Sonnet
| Self::Claude3Haiku => 200000,
Self::Gemini15Pro => 128000,
Self::Gemini15Flash => 32000,
Self::Custom(_) => 4096, // TODO: Make this configurable
}
}
pub fn preprocess_request(&self, request: &mut LanguageModelRequest) {
match self {
Self::Claude3Opus | Self::Claude3Sonnet | Self::Claude3Haiku => {
preprocess_anthropic_request(request)
}
_ => {}
}
}
}
#[derive(Copy, Clone, Default, Debug, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "snake_case")]
@ -620,6 +473,124 @@ fn merge<T>(target: &mut T, value: Option<T>) {
}
}
pub fn update_completion_provider_settings(
provider: &mut CompletionProvider,
version: usize,
cx: &mut AppContext,
) {
let updated = match &AssistantSettings::get_global(cx).provider {
AssistantProvider::ZedDotDev { model } => provider
.update_current_as::<_, CloudCompletionProvider>(|provider| {
provider.update(model.clone(), version);
}),
AssistantProvider::OpenAi {
model,
api_url,
low_speed_timeout_in_seconds,
available_models,
} => provider.update_current_as::<_, OpenAiCompletionProvider>(|provider| {
provider.update(
choose_openai_model(&model, &available_models),
api_url.clone(),
low_speed_timeout_in_seconds.map(Duration::from_secs),
version,
);
}),
AssistantProvider::Anthropic {
model,
api_url,
low_speed_timeout_in_seconds,
} => provider.update_current_as::<_, AnthropicCompletionProvider>(|provider| {
provider.update(
model.clone(),
api_url.clone(),
low_speed_timeout_in_seconds.map(Duration::from_secs),
version,
);
}),
AssistantProvider::Ollama {
model,
api_url,
low_speed_timeout_in_seconds,
} => provider.update_current_as::<_, OllamaCompletionProvider>(|provider| {
provider.update(
model.clone(),
api_url.clone(),
low_speed_timeout_in_seconds.map(Duration::from_secs),
version,
cx,
);
}),
};
// Previously configured provider was changed to another one
if updated.is_none() {
provider.update_provider(|client| create_provider_from_settings(client, version, cx));
}
}
pub(crate) fn create_provider_from_settings(
client: Arc<Client>,
settings_version: usize,
cx: &mut AppContext,
) -> Arc<RwLock<dyn LanguageModelCompletionProvider>> {
match &AssistantSettings::get_global(cx).provider {
AssistantProvider::ZedDotDev { model } => Arc::new(RwLock::new(
CloudCompletionProvider::new(model.clone(), client.clone(), settings_version, cx),
)),
AssistantProvider::OpenAi {
model,
api_url,
low_speed_timeout_in_seconds,
available_models,
} => Arc::new(RwLock::new(OpenAiCompletionProvider::new(
choose_openai_model(&model, &available_models),
api_url.clone(),
client.http_client(),
low_speed_timeout_in_seconds.map(Duration::from_secs),
settings_version,
available_models.clone(),
))),
AssistantProvider::Anthropic {
model,
api_url,
low_speed_timeout_in_seconds,
} => Arc::new(RwLock::new(AnthropicCompletionProvider::new(
model.clone(),
api_url.clone(),
client.http_client(),
low_speed_timeout_in_seconds.map(Duration::from_secs),
settings_version,
))),
AssistantProvider::Ollama {
model,
api_url,
low_speed_timeout_in_seconds,
} => Arc::new(RwLock::new(OllamaCompletionProvider::new(
model.clone(),
api_url.clone(),
client.http_client(),
low_speed_timeout_in_seconds.map(Duration::from_secs),
settings_version,
cx,
))),
}
}
/// Choose which model to use for openai provider.
/// If the model is not available, try to use the first available model, or fallback to the original model.
fn choose_openai_model(
model: &::open_ai::Model,
available_models: &[::open_ai::Model],
) -> ::open_ai::Model {
available_models
.iter()
.find(|&m| m == model)
.or_else(|| available_models.first())
.unwrap_or_else(|| model)
.clone()
}
#[cfg(test)]
mod tests {
use gpui::{AppContext, UpdateGlobal};

View file

@ -1,12 +1,12 @@
use crate::{
prompt_library::PromptStore, slash_command::SlashCommandLine, CompletionProvider,
LanguageModelRequest, LanguageModelRequestMessage, MessageId, MessageStatus, Role,
prompt_library::PromptStore, slash_command::SlashCommandLine, CompletionProvider, MessageId,
MessageStatus,
};
use anyhow::{anyhow, Context as _, Result};
use assistant_slash_command::{
SlashCommandOutput, SlashCommandOutputSection, SlashCommandRegistry,
};
use client::{proto, telemetry::Telemetry};
use client::{self, proto, telemetry::Telemetry};
use clock::ReplicaId;
use collections::{HashMap, HashSet};
use fs::Fs;
@ -18,6 +18,8 @@ use gpui::{AppContext, Context as _, EventEmitter, Model, ModelContext, Subscrip
use language::{
AnchorRangeExt, Bias, Buffer, LanguageRegistry, OffsetRangeExt, ParseStatus, Point, ToOffset,
};
use language_model::LanguageModelRequestMessage;
use language_model::{LanguageModelRequest, Role};
use open_ai::Model as OpenAiModel;
use paths::contexts_dir;
use project::Project;
@ -2477,9 +2479,10 @@ mod tests {
use crate::{
assistant_panel, prompt_library,
slash_command::{active_command, file_command},
FakeCompletionProvider, MessageId,
MessageId,
};
use assistant_slash_command::{ArgumentCompletion, SlashCommand};
use completion::FakeCompletionProvider;
use fs::FakeFs;
use gpui::{AppContext, TestAppContext, WeakView};
use indoc::indoc;

View file

@ -1,7 +1,6 @@
use crate::{
assistant_settings::AssistantSettings, humanize_token_count, prompts::generate_content_prompt,
AssistantPanel, AssistantPanelEvent, CompletionProvider, Hunk, LanguageModelRequest,
LanguageModelRequestMessage, Role, StreamingDiff,
AssistantPanel, AssistantPanelEvent, CompletionProvider, Hunk, StreamingDiff,
};
use anyhow::{anyhow, Context as _, Result};
use client::telemetry::Telemetry;
@ -28,6 +27,7 @@ use gpui::{
WhiteSpace, WindowContext,
};
use language::{Buffer, Point, Selection, TransactionId};
use language_model::{LanguageModelRequest, LanguageModelRequestMessage, Role};
use multi_buffer::MultiBufferRow;
use parking_lot::Mutex;
use rope::Rope;
@ -1432,8 +1432,7 @@ impl Render for PromptEditor {
PopoverMenu::new("model-switcher")
.menu(move |cx| {
ContextMenu::build(cx, |mut menu, cx| {
for model in CompletionProvider::global(cx).available_models(cx)
{
for model in CompletionProvider::global(cx).available_models() {
menu = menu.custom_entry(
{
let model = model.clone();
@ -2606,7 +2605,7 @@ fn merge_ranges(ranges: &mut Vec<Range<Anchor>>, buffer: &MultiBufferSnapshot) {
#[cfg(test)]
mod tests {
use super::*;
use crate::FakeCompletionProvider;
use completion::FakeCompletionProvider;
use futures::stream::{self};
use gpui::{Context, TestAppContext};
use indoc::indoc;

View file

@ -23,7 +23,7 @@ impl RenderOnce for ModelSelector {
.with_handle(self.handle)
.menu(move |cx| {
ContextMenu::build(cx, |mut menu, cx| {
for model in CompletionProvider::global(cx).available_models(cx) {
for model in CompletionProvider::global(cx).available_models() {
menu = menu.custom_entry(
{
let model = model.clone();

View file

@ -1,6 +1,6 @@
use crate::{
slash_command::SlashCommandCompletionProvider, AssistantPanel, CompletionProvider,
InlineAssist, InlineAssistant, LanguageModelRequest, LanguageModelRequestMessage, Role,
InlineAssist, InlineAssistant,
};
use anyhow::{anyhow, Result};
use assets::Assets;
@ -19,6 +19,7 @@ use gpui::{
};
use heed::{types::SerdeBincode, Database, RoTxn};
use language::{language_settings::SoftWrap, Buffer, LanguageRegistry};
use language_model::{LanguageModelRequest, LanguageModelRequestMessage, Role};
use parking_lot::RwLock;
use picker::{Picker, PickerDelegate};
use rope::Rope;

View file

@ -1,7 +1,7 @@
use crate::{
assistant_settings::AssistantSettings, humanize_token_count,
prompts::generate_terminal_assistant_prompt, AssistantPanel, AssistantPanelEvent,
CompletionProvider, LanguageModelRequest, LanguageModelRequestMessage, Role,
CompletionProvider,
};
use anyhow::{Context as _, Result};
use client::telemetry::Telemetry;
@ -17,6 +17,7 @@ use gpui::{
Model, ModelContext, Subscription, Task, TextStyle, UpdateGlobal, View, WeakView, WhiteSpace,
};
use language::Buffer;
use language_model::{LanguageModelRequest, LanguageModelRequestMessage, Role};
use settings::{update_settings_file, Settings};
use std::{
cmp,
@ -558,8 +559,7 @@ impl Render for PromptEditor {
PopoverMenu::new("model-switcher")
.menu(move |cx| {
ContextMenu::build(cx, |mut menu, cx| {
for model in CompletionProvider::global(cx).available_models(cx)
{
for model in CompletionProvider::global(cx).available_models() {
menu = menu.custom_entry(
{
let model = model.clone();

View file

@ -30,6 +30,7 @@ chrono.workspace = true
clock.workspace = true
clickhouse.workspace = true
collections.workspace = true
completion.workspace = true
dashmap = "5.4"
envy = "0.4.2"
futures.workspace = true
@ -79,6 +80,7 @@ channel.workspace = true
client = { workspace = true, features = ["test-support"] }
collab_ui = { workspace = true, features = ["test-support"] }
collections = { workspace = true, features = ["test-support"] }
completion = { workspace = true, features = ["test-support"] }
ctor.workspace = true
editor = { workspace = true, features = ["test-support"] }
env_logger.workspace = true

View file

@ -295,7 +295,7 @@ impl TestServer {
menu::init();
dev_server_projects::init(client.clone(), cx);
settings::KeymapFile::load_asset(os_keymap, cx).unwrap();
assistant::FakeCompletionProvider::setup_test(cx);
completion::FakeCompletionProvider::setup_test(cx);
assistant::context_store::init(&client);
});

View file

@ -0,0 +1,56 @@
[package]
name = "completion"
version = "0.1.0"
edition = "2021"
publish = false
license = "GPL-3.0-or-later"
[lints]
workspace = true
[lib]
path = "src/completion.rs"
doctest = false
[features]
test-support = [
"editor/test-support",
"language/test-support",
"project/test-support",
"text/test-support",
]
[dependencies]
anthropic = { workspace = true, features = ["schemars"] }
anyhow.workspace = true
client.workspace = true
collections.workspace = true
editor.workspace = true
futures.workspace = true
gpui.workspace = true
http.workspace = true
language_model.workspace = true
log.workspace = true
menu.workspace = true
ollama = { workspace = true, features = ["schemars"] }
open_ai = { workspace = true, features = ["schemars"] }
parking_lot.workspace = true
serde.workspace = true
serde_json.workspace = true
settings.workspace = true
smol.workspace = true
strum.workspace = true
theme.workspace = true
tiktoken-rs.workspace = true
ui.workspace = true
util.workspace = true
[dev-dependencies]
ctor.workspace = true
editor = { workspace = true, features = ["test-support"] }
env_logger.workspace = true
language = { workspace = true, features = ["test-support"] }
project = { workspace = true, features = ["test-support"] }
rand.workspace = true
text = { workspace = true, features = ["test-support"] }
unindent.workspace = true

View file

@ -0,0 +1 @@
../../LICENSE-GPL

View file

@ -1,14 +1,12 @@
use crate::{
assistant_settings::AnthropicModel, CompletionProvider, LanguageModel, LanguageModelRequest,
Role,
};
use crate::{count_open_ai_tokens, LanguageModelCompletionProvider, LanguageModelRequestMessage};
use anthropic::{stream_completion, Request, RequestMessage};
use crate::{count_open_ai_tokens, LanguageModelCompletionProvider};
use crate::{CompletionProvider, LanguageModel, LanguageModelRequest};
use anthropic::{stream_completion, Model as AnthropicModel, Request, RequestMessage};
use anyhow::{anyhow, Result};
use editor::{Editor, EditorElement, EditorStyle};
use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
use gpui::{AnyView, AppContext, FontStyle, Task, TextStyle, View, WhiteSpace};
use http::HttpClient;
use language_model::Role;
use settings::Settings;
use std::time::Duration;
use std::{env, sync::Arc};
@ -27,7 +25,7 @@ pub struct AnthropicCompletionProvider {
}
impl LanguageModelCompletionProvider for AnthropicCompletionProvider {
fn available_models(&self, _cx: &AppContext) -> Vec<LanguageModel> {
fn available_models(&self) -> Vec<LanguageModel> {
AnthropicModel::iter()
.map(LanguageModel::Anthropic)
.collect()
@ -176,7 +174,7 @@ impl AnthropicCompletionProvider {
}
fn to_anthropic_request(&self, mut request: LanguageModelRequest) -> Request {
preprocess_anthropic_request(&mut request);
request.preprocess_anthropic();
let model = match request.model {
LanguageModel::Anthropic(model) => model,
@ -213,49 +211,6 @@ impl AnthropicCompletionProvider {
}
}
pub fn preprocess_anthropic_request(request: &mut LanguageModelRequest) {
let mut new_messages: Vec<LanguageModelRequestMessage> = Vec::new();
let mut system_message = String::new();
for message in request.messages.drain(..) {
if message.content.is_empty() {
continue;
}
match message.role {
Role::User | Role::Assistant => {
if let Some(last_message) = new_messages.last_mut() {
if last_message.role == message.role {
last_message.content.push_str("\n\n");
last_message.content.push_str(&message.content);
continue;
}
}
new_messages.push(message);
}
Role::System => {
if !system_message.is_empty() {
system_message.push_str("\n\n");
}
system_message.push_str(&message.content);
}
}
}
if !system_message.is_empty() {
new_messages.insert(
0,
LanguageModelRequestMessage {
role: Role::System,
content: system_message,
},
);
}
request.messages = new_messages;
}
struct AuthenticationPrompt {
api_key: View<Editor>,
api_url: String,

View file

@ -1,11 +1,12 @@
use crate::{
assistant_settings::CloudModel, count_open_ai_tokens, CompletionProvider, LanguageModel,
LanguageModelCompletionProvider, LanguageModelRequest,
count_open_ai_tokens, CompletionProvider, LanguageModel, LanguageModelCompletionProvider,
LanguageModelRequest,
};
use anyhow::{anyhow, Result};
use client::{proto, Client};
use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt, TryFutureExt};
use gpui::{AnyView, AppContext, Task};
use language_model::CloudModel;
use std::{future, sync::Arc};
use strum::IntoEnumIterator;
use ui::prelude::*;
@ -52,7 +53,7 @@ impl CloudCompletionProvider {
}
impl LanguageModelCompletionProvider for CloudCompletionProvider {
fn available_models(&self, _cx: &AppContext) -> Vec<LanguageModel> {
fn available_models(&self) -> Vec<LanguageModel> {
let mut custom_model = if let CloudModel::Custom(custom_model) = self.model.clone() {
Some(custom_model)
} else {

View file

@ -6,52 +6,19 @@ mod ollama;
mod open_ai;
pub use anthropic::*;
use anyhow::Result;
use client::Client;
pub use cloud::*;
#[cfg(any(test, feature = "test-support"))]
pub use fake::*;
use futures::{future::BoxFuture, stream::BoxStream, StreamExt};
use gpui::{AnyView, AppContext, Task, WindowContext};
use language_model::{LanguageModel, LanguageModelRequest};
pub use ollama::*;
pub use open_ai::*;
use parking_lot::RwLock;
use smol::lock::{Semaphore, SemaphoreGuardArc};
use crate::{
assistant_settings::{AssistantProvider, AssistantSettings},
LanguageModel, LanguageModelRequest,
};
use anyhow::Result;
use client::Client;
use futures::{future::BoxFuture, stream::BoxStream, StreamExt};
use gpui::{AnyView, AppContext, BorrowAppContext, Task, WindowContext};
use settings::{Settings, SettingsStore};
use std::{any::Any, pin::Pin, sync::Arc, task::Poll, time::Duration};
/// Choose which model to use for openai provider.
/// If the model is not available, try to use the first available model, or fallback to the original model.
fn choose_openai_model(
model: &::open_ai::Model,
available_models: &[::open_ai::Model],
) -> ::open_ai::Model {
available_models
.iter()
.find(|&m| m == model)
.or_else(|| available_models.first())
.unwrap_or_else(|| model)
.clone()
}
pub fn init(client: Arc<Client>, cx: &mut AppContext) {
let provider = create_provider_from_settings(client.clone(), 0, cx);
cx.set_global(CompletionProvider::new(provider, Some(client)));
let mut settings_version = 0;
cx.observe_global::<SettingsStore>(move |cx| {
settings_version += 1;
cx.update_global::<CompletionProvider, _>(|provider, cx| {
provider.update_settings(settings_version, cx);
})
})
.detach();
}
use std::{any::Any, pin::Pin, sync::Arc, task::Poll};
pub struct CompletionResponse {
inner: BoxStream<'static, Result<String>>,
@ -70,7 +37,7 @@ impl futures::Stream for CompletionResponse {
}
pub trait LanguageModelCompletionProvider: Send + Sync {
fn available_models(&self, cx: &AppContext) -> Vec<LanguageModel>;
fn available_models(&self) -> Vec<LanguageModel>;
fn settings_version(&self) -> usize;
fn is_authenticated(&self) -> bool;
fn authenticate(&self, cx: &AppContext) -> Task<Result<()>>;
@ -110,8 +77,8 @@ impl CompletionProvider {
}
}
pub fn available_models(&self, cx: &AppContext) -> Vec<LanguageModel> {
self.provider.read().available_models(cx)
pub fn available_models(&self) -> Vec<LanguageModel> {
self.provider.read().available_models()
}
pub fn settings_version(&self) -> usize {
@ -176,6 +143,17 @@ impl CompletionProvider {
Ok(completion)
})
}
pub fn update_provider(
&mut self,
get_provider: impl FnOnce(Arc<Client>) -> Arc<RwLock<dyn LanguageModelCompletionProvider>>,
) {
if let Some(client) = &self.client {
self.provider = get_provider(Arc::clone(client));
} else {
log::warn!("completion provider cannot be updated because its client was not set");
}
}
}
impl gpui::Global for CompletionProvider {}
@ -196,109 +174,6 @@ impl CompletionProvider {
None
}
}
pub fn update_settings(&mut self, version: usize, cx: &mut AppContext) {
let updated = match &AssistantSettings::get_global(cx).provider {
AssistantProvider::ZedDotDev { model } => self
.update_current_as::<_, CloudCompletionProvider>(|provider| {
provider.update(model.clone(), version);
}),
AssistantProvider::OpenAi {
model,
api_url,
low_speed_timeout_in_seconds,
available_models,
} => self.update_current_as::<_, OpenAiCompletionProvider>(|provider| {
provider.update(
choose_openai_model(&model, &available_models),
api_url.clone(),
low_speed_timeout_in_seconds.map(Duration::from_secs),
version,
);
}),
AssistantProvider::Anthropic {
model,
api_url,
low_speed_timeout_in_seconds,
} => self.update_current_as::<_, AnthropicCompletionProvider>(|provider| {
provider.update(
model.clone(),
api_url.clone(),
low_speed_timeout_in_seconds.map(Duration::from_secs),
version,
);
}),
AssistantProvider::Ollama {
model,
api_url,
low_speed_timeout_in_seconds,
} => self.update_current_as::<_, OllamaCompletionProvider>(|provider| {
provider.update(
model.clone(),
api_url.clone(),
low_speed_timeout_in_seconds.map(Duration::from_secs),
version,
cx,
);
}),
};
// Previously configured provider was changed to another one
if updated.is_none() {
if let Some(client) = self.client.clone() {
self.provider = create_provider_from_settings(client, version, cx);
} else {
log::warn!("completion provider cannot be created because client is not set");
}
}
}
}
fn create_provider_from_settings(
client: Arc<Client>,
settings_version: usize,
cx: &mut AppContext,
) -> Arc<RwLock<dyn LanguageModelCompletionProvider>> {
match &AssistantSettings::get_global(cx).provider {
AssistantProvider::ZedDotDev { model } => Arc::new(RwLock::new(
CloudCompletionProvider::new(model.clone(), client.clone(), settings_version, cx),
)),
AssistantProvider::OpenAi {
model,
api_url,
low_speed_timeout_in_seconds,
available_models,
} => Arc::new(RwLock::new(OpenAiCompletionProvider::new(
choose_openai_model(&model, &available_models),
api_url.clone(),
client.http_client(),
low_speed_timeout_in_seconds.map(Duration::from_secs),
settings_version,
))),
AssistantProvider::Anthropic {
model,
api_url,
low_speed_timeout_in_seconds,
} => Arc::new(RwLock::new(AnthropicCompletionProvider::new(
model.clone(),
api_url.clone(),
client.http_client(),
low_speed_timeout_in_seconds.map(Duration::from_secs),
settings_version,
))),
AssistantProvider::Ollama {
model,
api_url,
low_speed_timeout_in_seconds,
} => Arc::new(RwLock::new(OllamaCompletionProvider::new(
model.clone(),
api_url.clone(),
client.http_client(),
low_speed_timeout_in_seconds.map(Duration::from_secs),
settings_version,
cx,
))),
}
}
#[cfg(test)]
@ -311,8 +186,8 @@ mod tests {
use smol::stream::StreamExt;
use crate::{
completion_provider::MAX_CONCURRENT_COMPLETION_REQUESTS, CompletionProvider,
FakeCompletionProvider, LanguageModelRequest,
CompletionProvider, FakeCompletionProvider, LanguageModelRequest,
MAX_CONCURRENT_COMPLETION_REQUESTS,
};
#[gpui::test]

View file

@ -62,7 +62,7 @@ impl FakeCompletionProvider {
}
impl LanguageModelCompletionProvider for FakeCompletionProvider {
fn available_models(&self, _cx: &AppContext) -> Vec<LanguageModel> {
fn available_models(&self) -> Vec<LanguageModel> {
vec![LanguageModel::default()]
}

View file

@ -1,15 +1,14 @@
use crate::LanguageModelCompletionProvider;
use crate::{
assistant_settings::OllamaModel, CompletionProvider, LanguageModel, LanguageModelRequest, Role,
};
use crate::{CompletionProvider, LanguageModel, LanguageModelRequest};
use anyhow::Result;
use futures::StreamExt as _;
use futures::{future::BoxFuture, stream::BoxStream, FutureExt};
use gpui::{AnyView, AppContext, Task};
use http::HttpClient;
use language_model::Role;
use ollama::Model as OllamaModel;
use ollama::{
get_models, preload_model, stream_chat_completion, ChatMessage, ChatOptions, ChatRequest,
Role as OllamaRole,
};
use std::sync::Arc;
use std::time::Duration;
@ -28,7 +27,7 @@ pub struct OllamaCompletionProvider {
}
impl LanguageModelCompletionProvider for OllamaCompletionProvider {
fn available_models(&self, _cx: &AppContext) -> Vec<LanguageModel> {
fn available_models(&self) -> Vec<LanguageModel> {
self.available_models
.iter()
.map(|m| LanguageModel::Ollama(m.clone()))
@ -262,16 +261,6 @@ impl OllamaCompletionProvider {
}
}
impl From<Role> for ollama::Role {
fn from(val: Role) -> Self {
match val {
Role::User => OllamaRole::User,
Role::Assistant => OllamaRole::Assistant,
Role::System => OllamaRole::System,
}
}
}
struct DownloadOllamaMessage {
retry_connection: Box<dyn Fn(&mut WindowContext) -> Task<Result<()>>>,
}

View file

@ -1,15 +1,13 @@
use crate::assistant_settings::CloudModel;
use crate::assistant_settings::{AssistantProvider, AssistantSettings};
use crate::CompletionProvider;
use crate::LanguageModelCompletionProvider;
use crate::{
assistant_settings::OpenAiModel, CompletionProvider, LanguageModel, LanguageModelRequest, Role,
};
use anyhow::{anyhow, Result};
use editor::{Editor, EditorElement, EditorStyle};
use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
use gpui::{AnyView, AppContext, FontStyle, Task, TextStyle, View, WhiteSpace};
use http::HttpClient;
use open_ai::{stream_completion, Request, RequestMessage, Role as OpenAiRole};
use language_model::{CloudModel, LanguageModel, LanguageModelRequest, Role};
use open_ai::Model as OpenAiModel;
use open_ai::{stream_completion, Request, RequestMessage};
use settings::Settings;
use std::time::Duration;
use std::{env, sync::Arc};
@ -25,6 +23,7 @@ pub struct OpenAiCompletionProvider {
http_client: Arc<dyn HttpClient>,
low_speed_timeout: Option<Duration>,
settings_version: usize,
available_models_from_settings: Vec<OpenAiModel>,
}
impl OpenAiCompletionProvider {
@ -34,6 +33,7 @@ impl OpenAiCompletionProvider {
http_client: Arc<dyn HttpClient>,
low_speed_timeout: Option<Duration>,
settings_version: usize,
available_models_from_settings: Vec<OpenAiModel>,
) -> Self {
Self {
api_key: None,
@ -42,6 +42,7 @@ impl OpenAiCompletionProvider {
http_client,
low_speed_timeout,
settings_version,
available_models_from_settings,
}
}
@ -92,19 +93,8 @@ impl OpenAiCompletionProvider {
}
impl LanguageModelCompletionProvider for OpenAiCompletionProvider {
fn available_models(&self, cx: &AppContext) -> Vec<LanguageModel> {
if let AssistantProvider::OpenAi {
available_models, ..
} = &AssistantSettings::get_global(cx).provider
{
if !available_models.is_empty() {
return available_models
.iter()
.cloned()
.map(LanguageModel::OpenAi)
.collect();
}
}
fn available_models(&self) -> Vec<LanguageModel> {
if self.available_models_from_settings.is_empty() {
let available_models = if matches!(self.model, OpenAiModel::Custom { .. }) {
vec![self.model.clone()]
} else {
@ -116,6 +106,13 @@ impl LanguageModelCompletionProvider for OpenAiCompletionProvider {
.into_iter()
.map(LanguageModel::OpenAi)
.collect()
} else {
self.available_models_from_settings
.iter()
.cloned()
.map(LanguageModel::OpenAi)
.collect()
}
}
fn settings_version(&self) -> usize {
@ -255,16 +252,6 @@ pub fn count_open_ai_tokens(
.boxed()
}
impl From<Role> for open_ai::Role {
fn from(val: Role) -> Self {
match val {
Role::User => OpenAiRole::User,
Role::Assistant => OpenAiRole::Assistant,
Role::System => OpenAiRole::System,
}
}
}
struct AuthenticationPrompt {
api_key: View<Editor>,
api_url: String,

View file

@ -0,0 +1,41 @@
[package]
name = "language_model"
version = "0.1.0"
edition = "2021"
publish = false
license = "GPL-3.0-or-later"
[lints]
workspace = true
[lib]
path = "src/language_model.rs"
doctest = false
[features]
test-support = [
"editor/test-support",
"language/test-support",
"project/test-support",
"text/test-support",
]
[dependencies]
anthropic = { workspace = true, features = ["schemars"] }
ollama = { workspace = true, features = ["schemars"] }
open_ai = { workspace = true, features = ["schemars"] }
schemars.workspace = true
serde.workspace = true
strum.workspace = true
proto = { workspace = true, features = ["test-support"] }
[dev-dependencies]
ctor.workspace = true
editor = { workspace = true, features = ["test-support"] }
env_logger.workspace = true
language = { workspace = true, features = ["test-support"] }
log.workspace = true
project = { workspace = true, features = ["test-support"] }
rand.workspace = true
text = { workspace = true, features = ["test-support"] }
unindent.workspace = true

View file

@ -0,0 +1 @@
../../LICENSE-GPL

View file

@ -0,0 +1,7 @@
mod model;
mod request;
mod role;
pub use model::*;
pub use request::*;
pub use role::*;

View file

@ -0,0 +1,160 @@
use crate::LanguageModelRequest;
pub use anthropic::Model as AnthropicModel;
pub use ollama::Model as OllamaModel;
pub use open_ai::Model as OpenAiModel;
use schemars::{
schema::{InstanceType, Metadata, Schema, SchemaObject},
JsonSchema,
};
use serde::{
de::{self, Visitor},
Deserialize, Deserializer, Serialize, Serializer,
};
use std::fmt;
use strum::{EnumIter, IntoEnumIterator};
#[derive(Clone, Debug, Default, PartialEq, EnumIter)]
pub enum CloudModel {
Gpt3Point5Turbo,
Gpt4,
Gpt4Turbo,
#[default]
Gpt4Omni,
Gpt4OmniMini,
Claude3_5Sonnet,
Claude3Opus,
Claude3Sonnet,
Claude3Haiku,
Gemini15Pro,
Gemini15Flash,
Custom(String),
}
impl Serialize for CloudModel {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
serializer.serialize_str(self.id())
}
}
impl<'de> Deserialize<'de> for CloudModel {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
struct ZedDotDevModelVisitor;
impl<'de> Visitor<'de> for ZedDotDevModelVisitor {
type Value = CloudModel;
fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
formatter.write_str("a string for a ZedDotDevModel variant or a custom model")
}
fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
where
E: de::Error,
{
let model = CloudModel::iter()
.find(|model| model.id() == value)
.unwrap_or_else(|| CloudModel::Custom(value.to_string()));
Ok(model)
}
}
deserializer.deserialize_str(ZedDotDevModelVisitor)
}
}
impl JsonSchema for CloudModel {
fn schema_name() -> String {
"ZedDotDevModel".to_owned()
}
fn json_schema(_generator: &mut schemars::gen::SchemaGenerator) -> Schema {
let variants = CloudModel::iter()
.filter_map(|model| {
let id = model.id();
if id.is_empty() {
None
} else {
Some(id.to_string())
}
})
.collect::<Vec<_>>();
Schema::Object(SchemaObject {
instance_type: Some(InstanceType::String.into()),
enum_values: Some(variants.iter().map(|s| s.clone().into()).collect()),
metadata: Some(Box::new(Metadata {
title: Some("ZedDotDevModel".to_owned()),
default: Some(CloudModel::default().id().into()),
examples: variants.into_iter().map(Into::into).collect(),
..Default::default()
})),
..Default::default()
})
}
}
impl CloudModel {
pub fn id(&self) -> &str {
match self {
Self::Gpt3Point5Turbo => "gpt-3.5-turbo",
Self::Gpt4 => "gpt-4",
Self::Gpt4Turbo => "gpt-4-turbo-preview",
Self::Gpt4Omni => "gpt-4o",
Self::Gpt4OmniMini => "gpt-4o-mini",
Self::Claude3_5Sonnet => "claude-3-5-sonnet",
Self::Claude3Opus => "claude-3-opus",
Self::Claude3Sonnet => "claude-3-sonnet",
Self::Claude3Haiku => "claude-3-haiku",
Self::Gemini15Pro => "gemini-1.5-pro",
Self::Gemini15Flash => "gemini-1.5-flash",
Self::Custom(id) => id,
}
}
pub fn display_name(&self) -> &str {
match self {
Self::Gpt3Point5Turbo => "GPT 3.5 Turbo",
Self::Gpt4 => "GPT 4",
Self::Gpt4Turbo => "GPT 4 Turbo",
Self::Gpt4Omni => "GPT 4 Omni",
Self::Gpt4OmniMini => "GPT 4 Omni Mini",
Self::Claude3_5Sonnet => "Claude 3.5 Sonnet",
Self::Claude3Opus => "Claude 3 Opus",
Self::Claude3Sonnet => "Claude 3 Sonnet",
Self::Claude3Haiku => "Claude 3 Haiku",
Self::Gemini15Pro => "Gemini 1.5 Pro",
Self::Gemini15Flash => "Gemini 1.5 Flash",
Self::Custom(id) => id.as_str(),
}
}
pub fn max_token_count(&self) -> usize {
match self {
Self::Gpt3Point5Turbo => 2048,
Self::Gpt4 => 4096,
Self::Gpt4Turbo | Self::Gpt4Omni => 128000,
Self::Gpt4OmniMini => 128000,
Self::Claude3_5Sonnet
| Self::Claude3Opus
| Self::Claude3Sonnet
| Self::Claude3Haiku => 200000,
Self::Gemini15Pro => 128000,
Self::Gemini15Flash => 32000,
Self::Custom(_) => 4096, // TODO: Make this configurable
}
}
pub fn preprocess_request(&self, request: &mut LanguageModelRequest) {
match self {
Self::Claude3Opus | Self::Claude3Sonnet | Self::Claude3Haiku => {
request.preprocess_anthropic()
}
_ => {}
}
}
}

View file

@ -0,0 +1,60 @@
pub mod cloud_model;
pub use anthropic::Model as AnthropicModel;
pub use cloud_model::*;
pub use ollama::Model as OllamaModel;
pub use open_ai::Model as OpenAiModel;
use serde::{Deserialize, Serialize};
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
pub enum LanguageModel {
Cloud(CloudModel),
OpenAi(OpenAiModel),
Anthropic(AnthropicModel),
Ollama(OllamaModel),
}
impl Default for LanguageModel {
fn default() -> Self {
LanguageModel::Cloud(CloudModel::default())
}
}
impl LanguageModel {
pub fn telemetry_id(&self) -> String {
match self {
LanguageModel::OpenAi(model) => format!("openai/{}", model.id()),
LanguageModel::Anthropic(model) => format!("anthropic/{}", model.id()),
LanguageModel::Cloud(model) => format!("zed.dev/{}", model.id()),
LanguageModel::Ollama(model) => format!("ollama/{}", model.id()),
}
}
pub fn display_name(&self) -> String {
match self {
LanguageModel::OpenAi(model) => model.display_name().into(),
LanguageModel::Anthropic(model) => model.display_name().into(),
LanguageModel::Cloud(model) => model.display_name().into(),
LanguageModel::Ollama(model) => model.display_name().into(),
}
}
pub fn max_token_count(&self) -> usize {
match self {
LanguageModel::OpenAi(model) => model.max_token_count(),
LanguageModel::Anthropic(model) => model.max_token_count(),
LanguageModel::Cloud(model) => model.max_token_count(),
LanguageModel::Ollama(model) => model.max_token_count(),
}
}
pub fn id(&self) -> &str {
match self {
LanguageModel::OpenAi(model) => model.id(),
LanguageModel::Anthropic(model) => model.id(),
LanguageModel::Cloud(model) => model.id(),
LanguageModel::Ollama(model) => model.id(),
}
}
}

View file

@ -0,0 +1,110 @@
use crate::{
model::{CloudModel, LanguageModel},
role::Role,
};
use serde::{Deserialize, Serialize};
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct LanguageModelRequestMessage {
pub role: Role,
pub content: String,
}
impl LanguageModelRequestMessage {
pub fn to_proto(&self) -> proto::LanguageModelRequestMessage {
proto::LanguageModelRequestMessage {
role: self.role.to_proto() as i32,
content: self.content.clone(),
tool_calls: Vec::new(),
tool_call_id: None,
}
}
}
#[derive(Debug, Default, Serialize, Deserialize)]
pub struct LanguageModelRequest {
pub model: LanguageModel,
pub messages: Vec<LanguageModelRequestMessage>,
pub stop: Vec<String>,
pub temperature: f32,
}
impl LanguageModelRequest {
pub fn to_proto(&self) -> proto::CompleteWithLanguageModel {
proto::CompleteWithLanguageModel {
model: self.model.id().to_string(),
messages: self.messages.iter().map(|m| m.to_proto()).collect(),
stop: self.stop.clone(),
temperature: self.temperature,
tool_choice: None,
tools: Vec::new(),
}
}
/// Before we send the request to the server, we can perform fixups on it appropriate to the model.
pub fn preprocess(&mut self) {
match &self.model {
LanguageModel::OpenAi(_) => {}
LanguageModel::Anthropic(_) => {}
LanguageModel::Ollama(_) => {}
LanguageModel::Cloud(model) => match model {
CloudModel::Claude3Opus
| CloudModel::Claude3Sonnet
| CloudModel::Claude3Haiku
| CloudModel::Claude3_5Sonnet => {
self.preprocess_anthropic();
}
_ => {}
},
}
}
pub fn preprocess_anthropic(&mut self) {
let mut new_messages: Vec<LanguageModelRequestMessage> = Vec::new();
let mut system_message = String::new();
for message in self.messages.drain(..) {
if message.content.is_empty() {
continue;
}
match message.role {
Role::User | Role::Assistant => {
if let Some(last_message) = new_messages.last_mut() {
if last_message.role == message.role {
last_message.content.push_str("\n\n");
last_message.content.push_str(&message.content);
continue;
}
}
new_messages.push(message);
}
Role::System => {
if !system_message.is_empty() {
system_message.push_str("\n\n");
}
system_message.push_str(&message.content);
}
}
}
if !system_message.is_empty() {
new_messages.insert(
0,
LanguageModelRequestMessage {
role: Role::System,
content: system_message,
},
);
}
self.messages = new_messages;
}
}
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct LanguageModelResponseMessage {
pub role: Option<Role>,
pub content: Option<String>,
}

View file

@ -0,0 +1,68 @@
use serde::{Deserialize, Serialize};
use std::fmt::{self, Display};
#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
User,
Assistant,
System,
}
impl Role {
pub fn from_proto(role: i32) -> Role {
match proto::LanguageModelRole::from_i32(role) {
Some(proto::LanguageModelRole::LanguageModelUser) => Role::User,
Some(proto::LanguageModelRole::LanguageModelAssistant) => Role::Assistant,
Some(proto::LanguageModelRole::LanguageModelSystem) => Role::System,
Some(proto::LanguageModelRole::LanguageModelTool) => Role::System,
None => Role::User,
}
}
pub fn to_proto(&self) -> proto::LanguageModelRole {
match self {
Role::User => proto::LanguageModelRole::LanguageModelUser,
Role::Assistant => proto::LanguageModelRole::LanguageModelAssistant,
Role::System => proto::LanguageModelRole::LanguageModelSystem,
}
}
pub fn cycle(self) -> Role {
match self {
Role::User => Role::Assistant,
Role::Assistant => Role::System,
Role::System => Role::User,
}
}
}
impl Display for Role {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Role::User => write!(f, "user"),
Role::Assistant => write!(f, "assistant"),
Role::System => write!(f, "system"),
}
}
}
impl From<Role> for ollama::Role {
fn from(val: Role) -> Self {
match val {
Role::User => ollama::Role::User,
Role::Assistant => ollama::Role::Assistant,
Role::System => ollama::Role::System,
}
}
}
impl From<Role> for open_ai::Role {
fn from(val: Role) -> Self {
match val {
Role::User => open_ai::Role::User,
Role::Assistant => open_ai::Role::Assistant,
Role::System => open_ai::Role::System,
}
}
}

View file

@ -22,6 +22,7 @@ anyhow.workspace = true
client.workspace = true
clock.workspace = true
collections.workspace = true
completion.workspace = true
fs.workspace = true
futures.workspace = true
futures-batch.workspace = true

View file

@ -1261,3 +1261,6 @@ mod tests {
);
}
}
// See https://github.com/zed-industries/zed/pull/14823#discussion_r1684616398 for why this is here and when it should be removed.
type _TODO = completion::CompletionProvider;