Add support for interacting with Claude in the assistant panel (#11798)
Release Notes:

- Added support for interacting with Claude in the assistant panel. You can enable it by adding the following to your `settings.json`:

```json
"assistant": {
  "version": "1",
  "provider": {
    "name": "anthropic"
  }
}
```
Parent 019d98898e, commit 5944caaa90.
12 changed files with 446 additions and 21 deletions
Cargo.lock (generated): 3 changes

@@ -225,6 +225,8 @@ dependencies = [
 "anyhow",
 "futures 0.3.28",
 "http 0.1.0",
 "isahc",
 "schemars",
 "serde",
 "serde_json",
 "tokio",
@@ -332,6 +334,7 @@ dependencies = [
 name = "assistant"
 version = "0.1.0"
 dependencies = [
+ "anthropic",
  "anyhow",
  "chrono",
  "client",
@@ -5,6 +5,10 @@ edition = "2021"
 publish = false
 license = "AGPL-3.0-or-later"

+[features]
+default = []
+schemars = ["dep:schemars"]
+
 [lints]
 workspace = true

@@ -15,6 +19,8 @@ path = "src/anthropic.rs"
 anyhow.workspace = true
 futures.workspace = true
 http.workspace = true
 isahc.workspace = true
+schemars = { workspace = true, optional = true }
 serde.workspace = true
 serde_json.workspace = true
@@ -1,17 +1,21 @@
 use anyhow::{anyhow, Result};
 use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, StreamExt};
 use http::{AsyncBody, HttpClient, Method, Request as HttpRequest};
 use isahc::config::Configurable;
 use serde::{Deserialize, Serialize};
-use std::{convert::TryFrom, sync::Arc};
+use std::{convert::TryFrom, time::Duration};

 pub const ANTHROPIC_API_URL: &'static str = "https://api.anthropic.com";

+#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
 #[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
 pub enum Model {
     #[default]
-    #[serde(rename = "claude-3-opus-20240229")]
+    #[serde(rename = "claude-3-opus", alias = "claude-3-opus-20240229")]
     Claude3Opus,
-    #[serde(rename = "claude-3-sonnet-20240229")]
+    #[serde(rename = "claude-3-sonnet", alias = "claude-3-sonnet-20240229")]
     Claude3Sonnet,
-    #[serde(rename = "claude-3-haiku-20240307")]
+    #[serde(rename = "claude-3-haiku", alias = "claude-3-haiku-20240307")]
     Claude3Haiku,
 }

@@ -28,6 +32,14 @@ impl Model {
         }
     }

+    pub fn id(&self) -> &'static str {
+        match self {
+            Model::Claude3Opus => "claude-3-opus-20240229",
+            Model::Claude3Sonnet => "claude-3-sonnet-20240229",
+            Model::Claude3Haiku => "claude-3-haiku-20240307",
+        }
+    }
+
     pub fn display_name(&self) -> &'static str {
         match self {
             Self::Claude3Opus => "Claude 3 Opus",
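With the rename/alias change above, settings serialize to the short model names while configs that still carry the dated ids keep deserializing. A standalone sketch of that behavior (local stand-in enum, not the actual `anthropic` crate):

```rust
// Standalone sketch of the serde attributes used above (stand-in enum, not Zed's `anthropic` crate).
use serde::{Deserialize, Serialize};

#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
enum Model {
    #[default]
    #[serde(rename = "claude-3-opus", alias = "claude-3-opus-20240229")]
    Claude3Opus,
    #[serde(rename = "claude-3-sonnet", alias = "claude-3-sonnet-20240229")]
    Claude3Sonnet,
    #[serde(rename = "claude-3-haiku", alias = "claude-3-haiku-20240307")]
    Claude3Haiku,
}

fn main() -> serde_json::Result<()> {
    // Old settings that still use the dated id deserialize via the alias...
    let from_old: Model = serde_json::from_str(r#""claude-3-opus-20240229""#)?;
    assert_eq!(from_old, Model::Claude3Opus);
    // ...while serialization now emits the short name.
    assert_eq!(serde_json::to_string(&from_old)?, r#""claude-3-opus""#);
    Ok(())
}
```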
@@ -141,20 +153,24 @@ pub enum TextDelta {
 }

 pub async fn stream_completion(
-    client: Arc<dyn HttpClient>,
+    client: &dyn HttpClient,
     api_url: &str,
     api_key: &str,
     request: Request,
+    low_speed_timeout: Option<Duration>,
 ) -> Result<BoxStream<'static, Result<ResponseEvent>>> {
     let uri = format!("{api_url}/v1/messages");
-    let request = HttpRequest::builder()
+    let mut request_builder = HttpRequest::builder()
         .method(Method::POST)
         .uri(uri)
         .header("Anthropic-Version", "2023-06-01")
-        .header("Anthropic-Beta", "messages-2023-12-15")
+        .header("Anthropic-Beta", "tools-2024-04-04")
         .header("X-Api-Key", api_key)
-        .header("Content-Type", "application/json")
-        .body(AsyncBody::from(serde_json::to_string(&request)?))?;
+        .header("Content-Type", "application/json");
+    if let Some(low_speed_timeout) = low_speed_timeout {
+        request_builder = request_builder.low_speed_timeout(100, low_speed_timeout);
+    }
+    let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
     let mut response = client.send(request).await?;
     if response.status().is_success() {
         let reader = BufReader::new(response.into_body());
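The signature change above drops the owned `Arc<dyn HttpClient>` in favor of a borrowed `&dyn HttpClient` plus an optional low-speed timeout, so callers lend out their shared client per request. A toy sketch of that calling convention (hypothetical trait and client, not Zed's `http` crate):

```rust
// Toy illustration of the new calling convention: callers keep an `Arc<dyn HttpClient>`
// and pass it as `&dyn HttpClient` per request. The trait and client here are made up.
use std::{sync::Arc, time::Duration};

trait HttpClient {
    fn send(&self, body: &str) -> String;
}

struct FakeClient;

impl HttpClient for FakeClient {
    fn send(&self, body: &str) -> String {
        format!("sent: {body}")
    }
}

// Mirrors the shape of `stream_completion(client: &dyn HttpClient, ..., low_speed_timeout: Option<Duration>)`.
fn stream_completion(client: &dyn HttpClient, low_speed_timeout: Option<Duration>) -> String {
    // A real implementation would configure the transfer timeout on the request builder here.
    let _ = low_speed_timeout;
    client.send("{\"model\":\"claude-3-opus\"}")
}

fn main() {
    let client: Arc<dyn HttpClient> = Arc::new(FakeClient);
    // `as_ref()` is how the provider and collab callers hand the shared client in.
    let response = stream_completion(client.as_ref(), Some(Duration::from_secs(30)));
    println!("{response}");
}
```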
@@ -11,6 +11,7 @@ doctest = false

 [dependencies]
 anyhow.workspace = true
+anthropic = { workspace = true, features = ["schemars"] }
 chrono.workspace = true
 client.workspace = true
 collections.workspace = true
@@ -7,7 +7,7 @@ mod saved_conversation;
 mod streaming_diff;

 pub use assistant_panel::AssistantPanel;
-use assistant_settings::{AssistantSettings, OpenAiModel, ZedDotDevModel};
+use assistant_settings::{AnthropicModel, AssistantSettings, OpenAiModel, ZedDotDevModel};
 use client::{proto, Client};
 use command_palette_hooks::CommandPaletteFilter;
 pub(crate) use completion_provider::*;
@@ -72,6 +72,7 @@ impl Display for Role {
 pub enum LanguageModel {
     ZedDotDev(ZedDotDevModel),
     OpenAi(OpenAiModel),
+    Anthropic(AnthropicModel),
 }

 impl Default for LanguageModel {

@@ -84,6 +85,7 @@ impl LanguageModel {
     pub fn telemetry_id(&self) -> String {
         match self {
             LanguageModel::OpenAi(model) => format!("openai/{}", model.id()),
+            LanguageModel::Anthropic(model) => format!("anthropic/{}", model.id()),
             LanguageModel::ZedDotDev(model) => format!("zed.dev/{}", model.id()),
         }
     }

@@ -91,6 +93,7 @@ impl LanguageModel {
     pub fn display_name(&self) -> String {
         match self {
             LanguageModel::OpenAi(model) => model.display_name().into(),
+            LanguageModel::Anthropic(model) => model.display_name().into(),
             LanguageModel::ZedDotDev(model) => model.display_name().into(),
         }
     }

@@ -98,6 +101,7 @@ impl LanguageModel {
     pub fn max_token_count(&self) -> usize {
         match self {
             LanguageModel::OpenAi(model) => model.max_token_count(),
+            LanguageModel::Anthropic(model) => model.max_token_count(),
             LanguageModel::ZedDotDev(model) => model.max_token_count(),
         }
     }

@@ -105,6 +109,7 @@ impl LanguageModel {
     pub fn id(&self) -> &str {
         match self {
             LanguageModel::OpenAi(model) => model.id(),
+            LanguageModel::Anthropic(model) => model.id(),
             LanguageModel::ZedDotDev(model) => model.id(),
         }
     }
@@ -800,6 +800,11 @@ impl AssistantPanel {
                 open_ai::Model::FourTurbo => open_ai::Model::FourOmni,
                 open_ai::Model::FourOmni => open_ai::Model::ThreePointFiveTurbo,
             }),
+            LanguageModel::Anthropic(model) => LanguageModel::Anthropic(match &model {
+                anthropic::Model::Claude3Opus => anthropic::Model::Claude3Sonnet,
+                anthropic::Model::Claude3Sonnet => anthropic::Model::Claude3Haiku,
+                anthropic::Model::Claude3Haiku => anthropic::Model::Claude3Opus,
+            }),
             LanguageModel::ZedDotDev(model) => LanguageModel::ZedDotDev(match &model {
                 ZedDotDevModel::Gpt3Point5Turbo => ZedDotDevModel::Gpt4,
                 ZedDotDevModel::Gpt4 => ZedDotDevModel::Gpt4Turbo,
@@ -1,5 +1,6 @@
 use std::fmt;

+pub use anthropic::Model as AnthropicModel;
 use gpui::Pixels;
 pub use open_ai::Model as OpenAiModel;
 use schemars::{

@@ -161,6 +162,15 @@ pub enum AssistantProvider {
         #[serde(default)]
         low_speed_timeout_in_seconds: Option<u64>,
     },
+    #[serde(rename = "anthropic")]
+    Anthropic {
+        #[serde(default)]
+        default_model: AnthropicModel,
+        #[serde(default = "anthropic_api_url")]
+        api_url: String,
+        #[serde(default)]
+        low_speed_timeout_in_seconds: Option<u64>,
+    },
 }

 impl Default for AssistantProvider {

@@ -172,7 +182,11 @@ impl Default for AssistantProvider {
 }

 fn open_ai_url() -> String {
-    "https://api.openai.com/v1".into()
+    open_ai::OPEN_AI_API_URL.to_string()
 }

+fn anthropic_api_url() -> String {
+    anthropic::ANTHROPIC_API_URL.to_string()
+}
+
 #[derive(Default, Debug, Deserialize, Serialize)]
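The new `AssistantProvider::Anthropic` variant is what the `"provider": { "name": "anthropic" }` block from the release notes deserializes into; `default_model`, `api_url`, and `low_speed_timeout_in_seconds` are all optional. A standalone sketch with stand-in types showing how the serde defaults fill in the missing fields:

```rust
// Standalone sketch with stand-in types (not Zed's settings structs): shows how the optional
// fields on the Anthropic provider settings fall back to their serde defaults.
use serde::Deserialize;

fn anthropic_api_url() -> String {
    "https://api.anthropic.com".to_string() // same value as ANTHROPIC_API_URL in the diff
}

#[derive(Debug, Deserialize)]
#[serde(tag = "name", rename_all = "snake_case")]
enum ProviderSettings {
    Anthropic {
        #[serde(default)]
        default_model: Option<String>, // stand-in for `AnthropicModel`
        #[serde(default = "anthropic_api_url")]
        api_url: String,
        #[serde(default)]
        low_speed_timeout_in_seconds: Option<u64>,
    },
}

fn main() -> serde_json::Result<()> {
    // Only "name" is required, as in the release notes; everything else is optional.
    let settings: ProviderSettings = serde_json::from_str(
        r#"{ "name": "anthropic", "default_model": "claude-3-sonnet", "low_speed_timeout_in_seconds": 30 }"#,
    )?;
    println!("{settings:?}"); // api_url falls back to https://api.anthropic.com
    Ok(())
}
```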
@@ -1,8 +1,10 @@
+mod anthropic;
 #[cfg(test)]
 mod fake;
 mod open_ai;
 mod zed;

+pub use anthropic::*;
 #[cfg(test)]
 pub use fake::*;
 pub use open_ai::*;
@@ -42,6 +44,17 @@ pub fn init(client: Arc<Client>, cx: &mut AppContext) {
            low_speed_timeout_in_seconds.map(Duration::from_secs),
            settings_version,
        )),
+        AssistantProvider::Anthropic {
+            default_model,
+            api_url,
+            low_speed_timeout_in_seconds,
+        } => CompletionProvider::Anthropic(AnthropicCompletionProvider::new(
+            default_model.clone(),
+            api_url.clone(),
+            client.http_client(),
+            low_speed_timeout_in_seconds.map(Duration::from_secs),
+            settings_version,
+        )),
    };
    cx.set_global(provider);
@@ -64,13 +77,28 @@ pub fn init(client: Arc<Client>, cx: &mut AppContext) {
                    settings_version,
                );
            }
+            (
+                CompletionProvider::Anthropic(provider),
+                AssistantProvider::Anthropic {
+                    default_model,
+                    api_url,
+                    low_speed_timeout_in_seconds,
+                },
+            ) => {
+                provider.update(
+                    default_model.clone(),
+                    api_url.clone(),
+                    low_speed_timeout_in_seconds.map(Duration::from_secs),
+                    settings_version,
+                );
+            }
            (
                CompletionProvider::ZedDotDev(provider),
                AssistantProvider::ZedDotDev { default_model },
            ) => {
                provider.update(default_model.clone(), settings_version);
            }
-           (CompletionProvider::OpenAi(_), AssistantProvider::ZedDotDev { default_model }) => {
+           (_, AssistantProvider::ZedDotDev { default_model }) => {
                *provider = CompletionProvider::ZedDotDev(ZedDotDevCompletionProvider::new(
                    default_model.clone(),
                    client.clone(),
@@ -79,7 +107,7 @@ pub fn init(client: Arc<Client>, cx: &mut AppContext) {
                ));
            }
            (
-               CompletionProvider::ZedDotDev(_),
+               _,
                AssistantProvider::OpenAi {
                    default_model,
                    api_url,
@@ -94,8 +122,22 @@ pub fn init(client: Arc<Client>, cx: &mut AppContext) {
                    settings_version,
                ));
            }
            #[cfg(test)]
            (CompletionProvider::Fake(_), _) => unimplemented!(),
+           (
+               _,
+               AssistantProvider::Anthropic {
+                   default_model,
+                   api_url,
+                   low_speed_timeout_in_seconds,
+               },
+           ) => {
+               *provider = CompletionProvider::Anthropic(AnthropicCompletionProvider::new(
+                   default_model.clone(),
+                   api_url.clone(),
+                   client.http_client(),
+                   low_speed_timeout_in_seconds.map(Duration::from_secs),
+                   settings_version,
+               ));
+           }
        }
    })
})
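Taken together, the arms above update the active provider in place when the provider kind is unchanged and otherwise build a replacement from the new settings. A reduced sketch of that (current provider, new settings) match with toy types:

```rust
// Toy reduction of the (current provider, new settings) match used above: update in place when
// the provider kind is unchanged, otherwise rebuild and replace the provider.
#[derive(Debug)]
enum Provider {
    OpenAi { api_url: String },
    Anthropic { api_url: String },
}

enum ProviderSettings {
    OpenAi { api_url: String },
    Anthropic { api_url: String },
}

fn apply_settings(provider: &mut Provider, settings: ProviderSettings) {
    match (provider, settings) {
        // Same kind: mutate the existing provider (keeps any state such as a cached API key).
        (Provider::Anthropic { api_url }, ProviderSettings::Anthropic { api_url: new_url }) => {
            *api_url = new_url;
        }
        (Provider::OpenAi { api_url }, ProviderSettings::OpenAi { api_url: new_url }) => {
            *api_url = new_url;
        }
        // Kind changed: build a new provider from scratch.
        (provider, ProviderSettings::Anthropic { api_url }) => {
            *provider = Provider::Anthropic { api_url };
        }
        (provider, ProviderSettings::OpenAi { api_url }) => {
            *provider = Provider::OpenAi { api_url };
        }
    }
}

fn main() {
    let mut provider = Provider::OpenAi { api_url: "https://api.openai.com/v1".into() };
    let new_settings = ProviderSettings::Anthropic { api_url: "https://api.anthropic.com".into() };
    apply_settings(&mut provider, new_settings);
    println!("{provider:?}"); // provider was rebuilt as Anthropic
}
```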
@@ -104,6 +146,7 @@ pub fn init(client: Arc<Client>, cx: &mut AppContext) {

 pub enum CompletionProvider {
     OpenAi(OpenAiCompletionProvider),
+    Anthropic(AnthropicCompletionProvider),
     ZedDotDev(ZedDotDevCompletionProvider),
     #[cfg(test)]
     Fake(FakeCompletionProvider),
@@ -119,6 +162,7 @@ impl CompletionProvider {
     pub fn settings_version(&self) -> usize {
         match self {
             CompletionProvider::OpenAi(provider) => provider.settings_version(),
+            CompletionProvider::Anthropic(provider) => provider.settings_version(),
             CompletionProvider::ZedDotDev(provider) => provider.settings_version(),
             #[cfg(test)]
             CompletionProvider::Fake(_) => unimplemented!(),

@@ -128,6 +172,7 @@ impl CompletionProvider {
     pub fn is_authenticated(&self) -> bool {
         match self {
             CompletionProvider::OpenAi(provider) => provider.is_authenticated(),
+            CompletionProvider::Anthropic(provider) => provider.is_authenticated(),
             CompletionProvider::ZedDotDev(provider) => provider.is_authenticated(),
             #[cfg(test)]
             CompletionProvider::Fake(_) => true,

@@ -137,6 +182,7 @@ impl CompletionProvider {
     pub fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
         match self {
             CompletionProvider::OpenAi(provider) => provider.authenticate(cx),
+            CompletionProvider::Anthropic(provider) => provider.authenticate(cx),
             CompletionProvider::ZedDotDev(provider) => provider.authenticate(cx),
             #[cfg(test)]
             CompletionProvider::Fake(_) => Task::ready(Ok(())),

@@ -146,6 +192,7 @@ impl CompletionProvider {
     pub fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
         match self {
             CompletionProvider::OpenAi(provider) => provider.authentication_prompt(cx),
+            CompletionProvider::Anthropic(provider) => provider.authentication_prompt(cx),
             CompletionProvider::ZedDotDev(provider) => provider.authentication_prompt(cx),
             #[cfg(test)]
             CompletionProvider::Fake(_) => unimplemented!(),

@@ -155,6 +202,7 @@ impl CompletionProvider {
     pub fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
         match self {
             CompletionProvider::OpenAi(provider) => provider.reset_credentials(cx),
+            CompletionProvider::Anthropic(provider) => provider.reset_credentials(cx),
             CompletionProvider::ZedDotDev(_) => Task::ready(Ok(())),
             #[cfg(test)]
             CompletionProvider::Fake(_) => Task::ready(Ok(())),

@@ -164,6 +212,9 @@ impl CompletionProvider {
     pub fn default_model(&self) -> LanguageModel {
         match self {
             CompletionProvider::OpenAi(provider) => LanguageModel::OpenAi(provider.default_model()),
+            CompletionProvider::Anthropic(provider) => {
+                LanguageModel::Anthropic(provider.default_model())
+            }
             CompletionProvider::ZedDotDev(provider) => {
                 LanguageModel::ZedDotDev(provider.default_model())
             }

@@ -179,6 +230,7 @@ impl CompletionProvider {
     ) -> BoxFuture<'static, Result<usize>> {
         match self {
             CompletionProvider::OpenAi(provider) => provider.count_tokens(request, cx),
+            CompletionProvider::Anthropic(provider) => provider.count_tokens(request, cx),
             CompletionProvider::ZedDotDev(provider) => provider.count_tokens(request, cx),
             #[cfg(test)]
             CompletionProvider::Fake(_) => unimplemented!(),

@@ -191,6 +243,7 @@ impl CompletionProvider {
     ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
         match self {
             CompletionProvider::OpenAi(provider) => provider.complete(request),
+            CompletionProvider::Anthropic(provider) => provider.complete(request),
             CompletionProvider::ZedDotDev(provider) => provider.complete(request),
             #[cfg(test)]
             CompletionProvider::Fake(provider) => provider.complete(),
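Each `complete` implementation returns a boxed future that resolves to a boxed stream of text chunks. A minimal sketch of draining that `BoxStream<'static, Result<String>>`, with a stub stream standing in for the network response:

```rust
// Minimal sketch of consuming the `BoxFuture<BoxStream<Result<String>>>` shape returned by
// `complete()`. A stub in-memory stream stands in for the streamed completion.
use anyhow::Result;
use futures::{future::BoxFuture, stream, stream::BoxStream, FutureExt, StreamExt};

fn complete() -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
    async {
        let chunks = vec![Ok("Hello".to_string()), Ok(", world!".to_string())];
        let stream: BoxStream<'static, Result<String>> = stream::iter(chunks).boxed();
        Ok(stream)
    }
    .boxed()
}

fn main() -> Result<()> {
    futures::executor::block_on(async {
        let mut stream = complete().await?;
        let mut text = String::new();
        while let Some(chunk) = stream.next().await {
            text.push_str(&chunk?); // each item is a Result<String>, as in the provider API
        }
        assert_eq!(text, "Hello, world!");
        Ok(())
    })
}
```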
crates/assistant/src/completion_provider/anthropic.rs (new file, 317 lines)

@@ -0,0 +1,317 @@
use crate::count_open_ai_tokens;
use crate::{
    assistant_settings::AnthropicModel, CompletionProvider, LanguageModel, LanguageModelRequest,
    Role,
};
use anthropic::{stream_completion, Request, RequestMessage, Role as AnthropicRole};
use anyhow::{anyhow, Result};
use editor::{Editor, EditorElement, EditorStyle};
use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
use gpui::{AnyView, AppContext, FontStyle, FontWeight, Task, TextStyle, View, WhiteSpace};
use http::HttpClient;
use settings::Settings;
use std::time::Duration;
use std::{env, sync::Arc};
use theme::ThemeSettings;
use ui::prelude::*;
use util::ResultExt;

pub struct AnthropicCompletionProvider {
    api_key: Option<String>,
    api_url: String,
    default_model: AnthropicModel,
    http_client: Arc<dyn HttpClient>,
    low_speed_timeout: Option<Duration>,
    settings_version: usize,
}

impl AnthropicCompletionProvider {
    pub fn new(
        default_model: AnthropicModel,
        api_url: String,
        http_client: Arc<dyn HttpClient>,
        low_speed_timeout: Option<Duration>,
        settings_version: usize,
    ) -> Self {
        Self {
            api_key: None,
            api_url,
            default_model,
            http_client,
            low_speed_timeout,
            settings_version,
        }
    }

    pub fn update(
        &mut self,
        default_model: AnthropicModel,
        api_url: String,
        low_speed_timeout: Option<Duration>,
        settings_version: usize,
    ) {
        self.default_model = default_model;
        self.api_url = api_url;
        self.low_speed_timeout = low_speed_timeout;
        self.settings_version = settings_version;
    }

    pub fn settings_version(&self) -> usize {
        self.settings_version
    }

    pub fn is_authenticated(&self) -> bool {
        self.api_key.is_some()
    }

    pub fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
        if self.is_authenticated() {
            Task::ready(Ok(()))
        } else {
            let api_url = self.api_url.clone();
            cx.spawn(|mut cx| async move {
                let api_key = if let Ok(api_key) = env::var("ANTHROPIC_API_KEY") {
                    api_key
                } else {
                    let (_, api_key) = cx
                        .update(|cx| cx.read_credentials(&api_url))?
                        .await?
                        .ok_or_else(|| anyhow!("credentials not found"))?;
                    String::from_utf8(api_key)?
                };
                cx.update_global::<CompletionProvider, _>(|provider, _cx| {
                    if let CompletionProvider::Anthropic(provider) = provider {
                        provider.api_key = Some(api_key);
                    }
                })
            })
        }
    }

    pub fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
        let delete_credentials = cx.delete_credentials(&self.api_url);
        cx.spawn(|mut cx| async move {
            delete_credentials.await.log_err();
            cx.update_global::<CompletionProvider, _>(|provider, _cx| {
                if let CompletionProvider::Anthropic(provider) = provider {
                    provider.api_key = None;
                }
            })
        })
    }

    pub fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
        cx.new_view(|cx| AuthenticationPrompt::new(self.api_url.clone(), cx))
            .into()
    }

    pub fn default_model(&self) -> AnthropicModel {
        self.default_model.clone()
    }

    pub fn count_tokens(
        &self,
        request: LanguageModelRequest,
        cx: &AppContext,
    ) -> BoxFuture<'static, Result<usize>> {
        count_open_ai_tokens(request, cx.background_executor())
    }

    pub fn complete(
        &self,
        request: LanguageModelRequest,
    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
        let request = self.to_anthropic_request(request);

        let http_client = self.http_client.clone();
        let api_key = self.api_key.clone();
        let api_url = self.api_url.clone();
        let low_speed_timeout = self.low_speed_timeout;
        async move {
            let api_key = api_key.ok_or_else(|| anyhow!("missing api key"))?;
            let request = stream_completion(
                http_client.as_ref(),
                &api_url,
                &api_key,
                request,
                low_speed_timeout,
            );
            let response = request.await?;
            let stream = response
                .filter_map(|response| async move {
                    match response {
                        Ok(response) => match response {
                            anthropic::ResponseEvent::ContentBlockStart {
                                content_block, ..
                            } => match content_block {
                                anthropic::ContentBlock::Text { text } => Some(Ok(text)),
                            },
                            anthropic::ResponseEvent::ContentBlockDelta { delta, .. } => {
                                match delta {
                                    anthropic::TextDelta::TextDelta { text } => Some(Ok(text)),
                                }
                            }
                            _ => None,
                        },
                        Err(error) => Some(Err(error)),
                    }
                })
                .boxed();
            Ok(stream)
        }
        .boxed()
    }

    fn to_anthropic_request(&self, request: LanguageModelRequest) -> Request {
        let model = match request.model {
            LanguageModel::Anthropic(model) => model,
            _ => self.default_model(),
        };

        let mut system_message = String::new();
        let messages = request
            .messages
            .into_iter()
            .filter_map(|message| {
                match message.role {
                    Role::User => Some(RequestMessage {
                        role: AnthropicRole::User,
                        content: message.content,
                    }),
                    Role::Assistant => Some(RequestMessage {
                        role: AnthropicRole::Assistant,
                        content: message.content,
                    }),
                    // Anthropic's API breaks system instructions out as a separate field rather
                    // than having a system message role.
                    Role::System => {
                        if !system_message.is_empty() {
                            system_message.push_str("\n\n");
                        }
                        system_message.push_str(&message.content);

                        None
                    }
                }
            })
            .collect();

        Request {
            model,
            messages,
            stream: true,
            system: system_message,
            max_tokens: 4092,
        }
    }
}

struct AuthenticationPrompt {
    api_key: View<Editor>,
    api_url: String,
}

impl AuthenticationPrompt {
    fn new(api_url: String, cx: &mut WindowContext) -> Self {
        Self {
            api_key: cx.new_view(|cx| {
                let mut editor = Editor::single_line(cx);
                editor.set_placeholder_text(
                    "sk-000000000000000000000000000000000000000000000000",
                    cx,
                );
                editor
            }),
            api_url,
        }
    }

    fn save_api_key(&mut self, _: &menu::Confirm, cx: &mut ViewContext<Self>) {
        let api_key = self.api_key.read(cx).text(cx);
        if api_key.is_empty() {
            return;
        }

        let write_credentials = cx.write_credentials(&self.api_url, "Bearer", api_key.as_bytes());
        cx.spawn(|_, mut cx| async move {
            write_credentials.await?;
            cx.update_global::<CompletionProvider, _>(|provider, _cx| {
                if let CompletionProvider::Anthropic(provider) = provider {
                    provider.api_key = Some(api_key);
                }
            })
        })
        .detach_and_log_err(cx);
    }

    fn render_api_key_editor(&self, cx: &mut ViewContext<Self>) -> impl IntoElement {
        let settings = ThemeSettings::get_global(cx);
        let text_style = TextStyle {
            color: cx.theme().colors().text,
            font_family: settings.ui_font.family.clone(),
            font_features: settings.ui_font.features.clone(),
            font_size: rems(0.875).into(),
            font_weight: FontWeight::NORMAL,
            font_style: FontStyle::Normal,
            line_height: relative(1.3),
            background_color: None,
            underline: None,
            strikethrough: None,
            white_space: WhiteSpace::Normal,
        };
        EditorElement::new(
            &self.api_key,
            EditorStyle {
                background: cx.theme().colors().editor_background,
                local_player: cx.theme().players().local(),
                text: text_style,
                ..Default::default()
            },
        )
    }
}

impl Render for AuthenticationPrompt {
    fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
        const INSTRUCTIONS: [&str; 4] = [
            "To use the assistant panel or inline assistant, you need to add your Anthropic API key.",
            "You can create an API key at: https://console.anthropic.com/settings/keys",
            "",
            "Paste your Anthropic API key below and hit enter to use the assistant:",
        ];

        v_flex()
            .p_4()
            .size_full()
            .on_action(cx.listener(Self::save_api_key))
            .children(
                INSTRUCTIONS.map(|instruction| Label::new(instruction).size(LabelSize::Small)),
            )
            .child(
                h_flex()
                    .w_full()
                    .my_2()
                    .px_2()
                    .py_1()
                    .bg(cx.theme().colors().editor_background)
                    .rounded_md()
                    .child(self.render_api_key_editor(cx)),
            )
            .child(
                Label::new(
                    "You can also assign the ANTHROPIC_API_KEY environment variable and restart Zed.",
                )
                .size(LabelSize::Small),
            )
            .child(
                h_flex()
                    .gap_2()
                    .child(Label::new("Click on").size(LabelSize::Small))
                    .child(Icon::new(IconName::Ai).size(IconSize::XSmall))
                    .child(
                        Label::new("in the status bar to close this panel.").size(LabelSize::Small),
                    ),
            )
            .into_any()
    }
}
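In `to_anthropic_request`, system messages never become chat messages: they are concatenated into the request's separate `system` field, matching Anthropic's Messages API shape. A standalone sketch of that split with plain stand-in types:

```rust
// Standalone sketch of the system-message handling in `to_anthropic_request`: system messages
// are concatenated into a separate `system` string, everything else stays in `messages`.
#[derive(Clone, Copy, PartialEq)]
enum Role {
    User,
    Assistant,
    System,
}

struct Message {
    role: Role,
    content: String,
}

fn split_system(messages: Vec<Message>) -> (String, Vec<Message>) {
    let mut system = String::new();
    let chat: Vec<Message> = messages
        .into_iter()
        .filter_map(|message| {
            if message.role == Role::System {
                if !system.is_empty() {
                    system.push_str("\n\n"); // multiple system messages are joined, as in the diff
                }
                system.push_str(&message.content);
                None
            } else {
                Some(message)
            }
        })
        .collect();
    (system, chat)
}

fn main() {
    let (system, chat) = split_system(vec![
        Message { role: Role::System, content: "You are terse.".into() },
        Message { role: Role::User, content: "Hi!".into() },
    ]);
    assert_eq!(system, "You are terse.");
    assert_eq!(chat.len(), 1); // only the user message remains in the chat transcript
}
```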
@@ -151,8 +151,8 @@ impl OpenAiCompletionProvider {

    fn to_open_ai_request(&self, request: LanguageModelRequest) -> Request {
        let model = match request.model {
-           LanguageModel::ZedDotDev(_) => self.default_model(),
            LanguageModel::OpenAi(model) => model,
+           _ => self.default_model(),
        };

        Request {
@@ -205,8 +205,12 @@ pub fn count_open_ai_tokens(

    match request.model {
        LanguageModel::OpenAi(OpenAiModel::FourOmni)
-       | LanguageModel::ZedDotDev(ZedDotDevModel::Gpt4Omni) => {
-           // Tiktoken doesn't yet support gpt-4o, so we manually use the
+       | LanguageModel::ZedDotDev(ZedDotDevModel::Gpt4Omni)
+       | LanguageModel::Anthropic(_)
+       | LanguageModel::ZedDotDev(ZedDotDevModel::Claude3Opus)
+       | LanguageModel::ZedDotDev(ZedDotDevModel::Claude3Sonnet)
+       | LanguageModel::ZedDotDev(ZedDotDevModel::Claude3Haiku) => {
+           // Tiktoken doesn't yet support these models, so we manually use the
            // same tokenizer as GPT-4.
            tiktoken_rs::num_tokens_from_messages("gpt-4", &messages)
        }
@@ -78,7 +78,6 @@ impl ZedDotDevCompletionProvider {
        cx: &AppContext,
    ) -> BoxFuture<'static, Result<usize>> {
        match request.model {
-           LanguageModel::OpenAi(_) => future::ready(Err(anyhow!("invalid model"))).boxed(),
            LanguageModel::ZedDotDev(ZedDotDevModel::Gpt4)
            | LanguageModel::ZedDotDev(ZedDotDevModel::Gpt4Turbo)
            | LanguageModel::ZedDotDev(ZedDotDevModel::Gpt4Omni)

@@ -108,6 +107,7 @@ impl ZedDotDevCompletionProvider {
                }
                .boxed()
            }
+           _ => future::ready(Err(anyhow!("invalid model"))).boxed(),
        }
    }
@@ -4489,8 +4489,8 @@ async fn complete_with_anthropic(
        .collect();

    let mut stream = anthropic::stream_completion(
-       session.http_client.clone(),
-       "https://api.anthropic.com",
+       session.http_client.as_ref(),
+       anthropic::ANTHROPIC_API_URL,
        &api_key,
        anthropic::Request {
            model,

@@ -4499,6 +4499,7 @@ async fn complete_with_anthropic(
            system: system_message,
            max_tokens: 4092,
        },
+       None,
    )
    .await?;