Sanitize messages before sending them to Anthropic (#11810)

Release Notes:

- N/A

Co-authored-by: Nathan <nathan@zed.dev>
Co-authored-by: David <davidsp@anthropic.com>
This commit is contained in:
Antonio Scandurra 2024-05-14 17:47:33 +02:00 committed by GitHub
parent 69f9489aa9
commit de09409f01
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 36 additions and 26 deletions

View file

@ -11,11 +11,11 @@ pub const ANTHROPIC_API_URL: &'static str = "https://api.anthropic.com";
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
pub enum Model {
#[default]
#[serde(rename = "claude-3-opus", alias = "claude-3-opus-20240229")]
#[serde(alias = "claude-3-opus", rename = "claude-3-opus-20240229")]
Claude3Opus,
#[serde(rename = "claude-3-sonnet", alias = "claude-3-sonnet-20240229")]
#[serde(alias = "claude-3-sonnet", rename = "claude-3-sonnet-20240229")]
Claude3Sonnet,
#[serde(rename = "claude-3-haiku", alias = "claude-3-haiku-20240307")]
#[serde(alias = "claude-3-haiku", rename = "claude-3-haiku-20240307")]
Claude3Haiku,
}

View file

@ -169,32 +169,42 @@ impl AnthropicCompletionProvider {
};
let mut system_message = String::new();
let messages = request
.messages
.into_iter()
.filter_map(|message| {
match message.role {
Role::User => Some(RequestMessage {
role: AnthropicRole::User,
content: message.content,
}),
Role::Assistant => Some(RequestMessage {
role: AnthropicRole::Assistant,
content: message.content,
}),
// Anthropic's API breaks system instructions out as a separate field rather
// than having a system message role.
Role::System => {
if !system_message.is_empty() {
system_message.push_str("\n\n");
}
system_message.push_str(&message.content);
None
let mut messages: Vec<RequestMessage> = Vec::new();
for message in request.messages {
if message.content.is_empty() {
continue;
}
match message.role {
Role::User | Role::Assistant => {
let role = match message.role {
Role::User => AnthropicRole::User,
Role::Assistant => AnthropicRole::Assistant,
_ => unreachable!(),
};
if let Some(last_message) = messages.last_mut() {
if last_message.role == role {
last_message.content.push_str("\n\n");
last_message.content.push_str(&message.content);
continue;
}
}
messages.push(RequestMessage {
role,
content: message.content,
});
}
})
.collect();
Role::System => {
if !system_message.is_empty() {
system_message.push_str("\n\n");
}
system_message.push_str(&message.content);
}
}
}
Request {
model,