AI models
Some checks are pending
trivy / Check (push) Waiting to run

This commit is contained in:
mdecimus 2024-10-05 19:05:04 +02:00
parent d5c2dcb817
commit d0ce2b1a96
41 changed files with 1171 additions and 406 deletions

372
Cargo.lock generated

File diff suppressed because it is too large Load diff

View file

@ -4,14 +4,19 @@
* SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-SEL
*/
use std::sync::Arc;
use std::{str::FromStr, sync::Arc};
use arc_swap::ArcSwap;
use base64::{engine::general_purpose, Engine};
use directory::{Directories, Directory};
use hyper::{
header::{HeaderName, HeaderValue, AUTHORIZATION},
HeaderMap,
};
use ring::signature::{EcdsaKeyPair, RsaKeyPair};
use store::{BlobBackend, BlobStore, FtsStore, LookupStore, Store, Stores};
use telemetry::Metrics;
use utils::config::Config;
use utils::config::{utils::AsKey, Config};
use crate::{
auth::oauth::config::OAuthConfig, expr::*, listener::tls::AcmeProviders,
@ -229,3 +234,52 @@ pub fn build_ecdsa_pem(
Ok(None) => Err("No ECDSA key found in PEM".to_string()),
}
}
pub(crate) fn parse_http_headers(config: &mut Config, prefix: impl AsKey) -> HeaderMap {
let prefix = prefix.as_key();
let mut headers = HeaderMap::new();
for (header, value) in config
.values((&prefix, "headers"))
.map(|(_, v)| {
if let Some((k, v)) = v.split_once(':') {
Ok((
HeaderName::from_str(k.trim()).map_err(|err| {
format!("Invalid header found in property \"{prefix}.headers\": {err}",)
})?,
HeaderValue::from_str(v.trim()).map_err(|err| {
format!("Invalid header found in property \"{prefix}.headers\": {err}",)
})?,
))
} else {
Err(format!(
"Invalid header found in property \"{prefix}.headers\": {v}",
))
}
})
.collect::<Result<Vec<(HeaderName, HeaderValue)>, String>>()
.map_err(|e| config.new_parse_error((&prefix, "headers"), e))
.unwrap_or_default()
{
headers.insert(header, value);
}
if let (Some(name), Some(secret)) = (
config.value((&prefix, "auth.username")),
config.value((&prefix, "auth.secret")),
) {
headers.insert(
AUTHORIZATION,
format!(
"Basic {}",
general_purpose::STANDARD.encode(format!("{}:{}", name, secret))
)
.parse()
.unwrap(),
);
} else if let Some(token) = config.value((&prefix, "auth.token")) {
headers.insert(AUTHORIZATION, format!("Bearer {}", token).parse().unwrap());
}
headers
}

View file

@ -15,7 +15,10 @@ use sieve::{compiler::grammar::Capability, Compiler, Runtime, Sieve};
use store::Stores;
use utils::config::Config;
use crate::scripts::{functions::register_functions, plugins::RegisterSievePlugins};
use crate::scripts::{
functions::{register_functions_trusted, register_functions_untrusted},
plugins::RegisterSievePlugins,
};
use super::{if_block::IfBlock, smtp::SMTP_RCPT_TO_VARS, tokenizer::TokenMap};
@ -40,6 +43,7 @@ pub struct RemoteList {
impl Scripting {
pub async fn parse(config: &mut Config, stores: &Stores) -> Self {
// Parse untrusted compiler
let mut fnc_map_untrusted = register_functions_untrusted().register_plugins_untrusted();
let untrusted_compiler = Compiler::new()
.with_max_script_size(
config
@ -90,10 +94,12 @@ impl Scripting {
config
.property("sieve.untrusted.limits.includes")
.unwrap_or(3),
);
)
.register_functions(&mut fnc_map_untrusted);
// Parse untrusted runtime
let untrusted_runtime = Runtime::new()
.with_functions(&mut fnc_map_untrusted)
.with_max_nested_includes(
config
.property("sieve.untrusted.limits.nested-includes")
@ -141,6 +147,7 @@ impl Scripting {
.unwrap_or(Duration::from_secs(7 * 86400))
.as_secs(),
)
.with_capability(Capability::Expressions)
.without_capabilities(
config
.values("sieve.untrusted.disable-capabilities")
@ -191,7 +198,7 @@ impl Scripting {
.with_env_variable("phase", "during");
// Parse trusted compiler and runtime
let mut fnc_map = register_functions().register_plugins();
let mut fnc_map_trusted = register_functions_trusted().register_plugins_trusted();
// Allocate compiler and runtime
let trusted_compiler = Compiler::new()
@ -208,7 +215,7 @@ impl Scripting {
.property_or_default("sieve.trusted.no-capability-check", "true")
.unwrap_or(true),
)
.register_functions(&mut fnc_map);
.register_functions(&mut fnc_map_trusted);
let mut trusted_runtime = Runtime::new()
.without_capabilities([
@ -233,7 +240,7 @@ impl Scripting {
.with_max_header_size(10240)
.with_valid_notification_uri("mailto")
.with_valid_ext_lists(stores.lookup_stores.keys().map(|k| k.to_string()))
.with_functions(&mut fnc_map)
.with_functions(&mut fnc_map_trusted)
.with_max_redirects(
config
.property_or_default("sieve.trusted.limits.redirects", "3")

View file

@ -8,10 +8,7 @@ use std::{collections::HashMap, str::FromStr, sync::Arc, time::Duration};
use ahash::{AHashMap, AHashSet};
use base64::{engine::general_purpose::STANDARD, Engine};
use hyper::{
header::{HeaderName, HeaderValue, AUTHORIZATION, CONTENT_TYPE},
HeaderMap,
};
use hyper::{header::CONTENT_TYPE, HeaderMap};
use opentelemetry::{InstrumentationLibrary, KeyValue};
use opentelemetry_otlp::WithExportConfig;
use opentelemetry_sdk::{
@ -27,6 +24,8 @@ use store::Stores;
use trc::{ipc::subscriber::Interests, EventType, Level, TelemetryEvent};
use utils::config::{utils::ParseValue, Config};
use super::parse_http_headers;
#[derive(Debug)]
pub struct TelemetrySubscriber {
pub id: String,
@ -753,45 +752,8 @@ fn parse_webhook(
id: &str,
global_interests: &mut Interests,
) -> Option<TelemetrySubscriber> {
let mut headers = HeaderMap::new();
for (header, value) in config
.values(("webhook", id, "headers"))
.map(|(_, v)| {
if let Some((k, v)) = v.split_once(':') {
Ok((
HeaderName::from_str(k.trim()).map_err(|err| {
format!("Invalid header found in property \"webhook.{id}.headers\": {err}",)
})?,
HeaderValue::from_str(v.trim()).map_err(|err| {
format!("Invalid header found in property \"webhook.{id}.headers\": {err}",)
})?,
))
} else {
Err(format!(
"Invalid header found in property \"webhook.{id}.headers\": {v}",
))
}
})
.collect::<Result<Vec<(HeaderName, HeaderValue)>, String>>()
.map_err(|e| config.new_parse_error(("webhook", id, "headers"), e))
.unwrap_or_default()
{
headers.insert(header, value);
}
let mut headers = parse_http_headers(config, ("webhook", id));
headers.insert(CONTENT_TYPE, "application/json".parse().unwrap());
if let (Some(name), Some(secret)) = (
config.value(("webhook", id, "auth.username")),
config.value(("webhook", id, "auth.secret")),
) {
headers.insert(
AUTHORIZATION,
format!("Basic {}", STANDARD.encode(format!("{}:{}", name, secret)))
.parse()
.unwrap(),
);
}
// Build tracer
let mut tracer = TelemetrySubscriber {

View file

@ -10,6 +10,7 @@
use std::time::Duration;
use ahash::AHashMap;
use directory::{backend::internal::manage::ManageDirectory, Type};
use store::{Store, Stores};
use trc::{EventType, MetricType, TOTAL_EVENT_COUNT};
@ -22,8 +23,8 @@ use utils::config::{
use crate::expr::{tokenizer::TokenMap, Expression};
use super::{
license::LicenseValidator, AlertContent, AlertContentToken, AlertMethod, Enterprise,
MetricAlert, MetricStore, TraceStore, Undelete,
license::LicenseValidator, llm::AiApiConfig, AlertContent, AlertContentToken, AlertMethod,
Enterprise, MetricAlert, MetricStore, TraceStore, Undelete,
};
impl Enterprise {
@ -111,6 +112,18 @@ impl Enterprise {
None
};
// Parse AI APIs
let mut ai_apis = AHashMap::new();
for id in config
.sub_keys("enterprise.ai", ".url")
.map(|s| s.to_string())
.collect::<Vec<_>>()
{
if let Some(api) = AiApiConfig::parse(config, &id) {
ai_apis.insert(id, api);
}
}
Some(Enterprise {
license,
undelete: config
@ -121,6 +134,7 @@ impl Enterprise {
trace_store,
metrics_store,
metrics_alerts: parse_metric_alerts(config),
ai_apis,
})
}
}

View file

@ -0,0 +1,234 @@
/*
* SPDX-FileCopyrightText: 2020 Stalwart Labs Ltd <hello@stalw.art>
*
* SPDX-License-Identifier: LicenseRef-SEL
*
* This file is subject to the Stalwart Enterprise License Agreement (SEL) and
* is NOT open source software.
*
*/
use std::time::Duration;
use hyper::{header::CONTENT_TYPE, HeaderMap};
use serde::{Deserialize, Serialize};
use utils::config::Config;
use crate::config::parse_http_headers;
#[derive(Clone)]
pub struct AiApiConfig {
pub id: String,
pub api_type: ApiType,
pub url: String,
pub model: String,
pub timeout: Duration,
pub headers: HeaderMap,
pub tls_allow_invalid_certs: bool,
pub default_temperature: f64,
}
#[derive(Clone, Copy)]
pub enum ApiType {
ChatCompletion,
TextCompletion,
}
#[derive(Serialize, Deserialize, Debug)]
pub struct ChatCompletionRequest {
pub model: String,
pub messages: Vec<Message>,
pub temperature: f64,
}
#[derive(Serialize, Deserialize, Debug)]
pub struct Message {
pub role: String,
pub content: String,
}
#[derive(Serialize, Deserialize, Debug)]
pub struct ChatCompletionResponse {
pub created: i64,
pub object: String,
pub id: String,
pub model: String,
pub choices: Vec<ChatCompletionChoice>,
}
#[derive(Serialize, Deserialize, Debug)]
pub struct ChatCompletionChoice {
pub index: i32,
pub finish_reason: String,
pub message: Message,
}
#[derive(Serialize, Deserialize, Debug)]
pub struct TextCompletionRequest {
pub model: String,
pub prompt: String,
pub temperature: f64,
}
#[derive(Deserialize, Debug)]
pub struct TextCompletionResponse {
pub created: i64,
pub object: String,
pub id: String,
pub model: String,
pub choices: Vec<TextCompletionChoice>,
}
#[derive(Deserialize, Debug)]
pub struct TextCompletionChoice {
pub index: i32,
pub finish_reason: String,
pub text: String,
}
impl AiApiConfig {
pub async fn send_request(
&self,
prompt: impl Into<String>,
temperature: Option<f64>,
) -> trc::Result<String> {
self.post_api(prompt, temperature).await.map_err(|err| {
trc::Error::new(trc::EventType::Ai(trc::AiEvent::ApiError))
.id(self.id.clone())
.details("OpenAPI request failed")
.reason(err)
})
}
async fn post_api(
&self,
prompt: impl Into<String>,
temperature: Option<f64>,
) -> Result<String, String> {
// Serialize body
let body = match self.api_type {
ApiType::ChatCompletion => serde_json::to_string(&ChatCompletionRequest {
model: self.model.to_string(),
messages: vec![Message {
role: "user".to_string(),
content: prompt.into(),
}],
temperature: temperature.unwrap_or(self.default_temperature),
})
.map_err(|err| format!("Failed to serialize request: {}", err))?,
ApiType::TextCompletion => serde_json::to_string(&TextCompletionRequest {
model: self.model.to_string(),
prompt: prompt.into(),
temperature: temperature.unwrap_or(self.default_temperature),
})
.map_err(|err| format!("Failed to serialize request: {}", err))?,
};
// Send request
let response = reqwest::Client::builder()
.timeout(self.timeout)
.danger_accept_invalid_certs(self.tls_allow_invalid_certs)
.build()
.map_err(|err| format!("Failed to create HTTP client: {}", err))?
.post(&self.url)
.headers(self.headers.clone())
.body(body)
.send()
.await
.map_err(|err| format!("API request to {} failed: {err}", self.url))?;
if response.status().is_success() {
let bytes = response.bytes().await.map_err(|err| {
format!("Failed to read response body from {}: {}", self.url, err)
})?;
match self.api_type {
ApiType::ChatCompletion => {
let response = serde_json::from_slice::<ChatCompletionResponse>(&bytes)
.map_err(|err| {
format!(
"Failed to chat completion parse response from {}: {}",
self.url, err
)
})?;
response
.choices
.into_iter()
.next()
.map(|choice| choice.message.content)
.filter(|text| !text.is_empty())
.ok_or_else(|| {
format!(
"Chat completion response from {} did not contain any choices: {}",
self.url,
std::str::from_utf8(&bytes).unwrap_or_default()
)
})
}
ApiType::TextCompletion => {
let response = serde_json::from_slice::<TextCompletionResponse>(&bytes)
.map_err(|err| {
format!(
"Failed to parse text completion response from {}: {}",
self.url, err
)
})?;
response
.choices
.into_iter()
.next()
.map(|choice| choice.text)
.filter(|text| !text.is_empty())
.ok_or_else(|| {
format!(
"Text completion response from {} did not contain any choices: {}",
self.url,
std::str::from_utf8(&bytes).unwrap_or_default()
)
})
}
}
} else {
Err(format!(
"OpenAPI request to {} failed with code {}: {}",
self.url,
response.status().as_u16(),
response.status().canonical_reason().unwrap_or("Unknown")
))
}
}
pub fn parse(config: &mut Config, id: &str) -> Option<Self> {
let url = config.value(("enterprise.ai", id, "endpoint"))?.to_string();
let api_type = match config.value(("enterprise.ai", id, "type"))? {
"chat" => ApiType::ChatCompletion,
"text" => ApiType::TextCompletion,
_ => {
config.new_build_error(("enterprise.ai", id, "type"), "Invalid API type");
return None;
}
};
let mut headers = parse_http_headers(config, ("enterprise.ai", id));
headers.insert(CONTENT_TYPE, "application/json".parse().unwrap());
Some(AiApiConfig {
id: id.to_string(),
api_type,
url,
headers,
model: config
.value_require(("enterprise.ai", id, "model"))?
.to_string(),
timeout: config
.property_or_default(("enterprise.ai", id, "timeout"), "2m")
.unwrap_or_else(|| Duration::from_secs(120)),
tls_allow_invalid_certs: config
.property_or_default(("enterprise.ai", id, "allow-invalid-certs"), "false")
.unwrap_or_default(),
default_temperature: config
.property_or_default(("enterprise.ai", id, "default-temperature"), "0.7")
.unwrap_or(0.7),
})
}
}

View file

@ -11,15 +11,18 @@
pub mod alerts;
pub mod config;
pub mod license;
pub mod llm;
pub mod undelete;
use std::time::Duration;
use ahash::AHashMap;
use directory::{
backend::internal::{lookup::DirectoryStore, PrincipalField},
QueryBy, Type,
};
use license::LicenseKey;
use llm::AiApiConfig;
use mail_parser::DateTime;
use store::Store;
use trc::{AddContext, EventType, MetricType};
@ -35,6 +38,7 @@ pub struct Enterprise {
pub trace_store: Option<TraceStore>,
pub metrics_store: Option<MetricStore>,
pub metrics_alerts: Vec<MetricAlert>,
pub ai_apis: AHashMap<String, AiApiConfig>,
}
#[derive(Clone)]

View file

@ -76,6 +76,7 @@ pub(crate) const FUNCTIONS: &[(&str, fn(Vec<Variable>) -> Variable, u32)] = &[
("rsplit", text::fn_rsplit, 2),
("split_once", text::fn_split_once, 2),
("rsplit_once", text::fn_rsplit_once, 2),
("split_n", text::fn_split_n, 3),
("split_words", text::fn_split_words, 1),
];

View file

@ -239,6 +239,36 @@ pub(crate) fn fn_rsplit(v: Vec<Variable>) -> Variable {
}
}
pub(crate) fn fn_split_n(v: Vec<Variable>) -> Variable {
let mut v = v.into_iter();
let value = v.next().unwrap().into_string();
let arg = v.next().unwrap().into_string();
let num = v.next().unwrap().to_integer().unwrap_or_default() as usize;
fn split_n<'x, 'y>(s: &'x str, arg: &'y str, num: usize, mut f: impl FnMut(&'x str)) {
let mut s = s;
for _ in 0..num {
if let Some((a, b)) = s.split_once(arg) {
f(a);
s = b;
} else {
break;
}
}
f(s);
}
let mut result = Vec::new();
match value {
Cow::Borrowed(s) => split_n(s, arg.as_ref(), num, |s| result.push(Variable::from(s))),
Cow::Owned(s) => split_n(&s, arg.as_ref(), num, |s| {
result.push(Variable::from(s.to_string()))
}),
}
result.into()
}
pub(crate) fn fn_split_once(v: Vec<Variable>) -> Variable {
let mut v = v.into_iter();
let value = v.next().unwrap().into_string();

View file

@ -94,7 +94,7 @@ pub fn fn_hash<'x>(_: &'x Context<'x>, v: Vec<Variable>) -> Variable {
})
}
pub fn fn_is_var_names<'x>(ctx: &'x Context<'x>, _: Vec<Variable>) -> Variable {
pub fn fn_get_var_names<'x>(ctx: &'x Context<'x>, _: Vec<Variable>) -> Variable {
Variable::Array(
ctx.global_variable_names()
.map(|v| Variable::from(v.to_uppercase()))

View file

@ -20,7 +20,7 @@ use self::{
array::*, email::*, header::*, html::*, image::*, misc::*, text::*, unicode::*, url::*,
};
pub fn register_functions() -> FunctionMap {
pub fn register_functions_trusted() -> FunctionMap {
FunctionMap::new()
.with_function("trim", fn_trim)
.with_function("trim_start", fn_trim_start)
@ -80,6 +80,7 @@ pub fn register_functions() -> FunctionMap {
.with_function_args("rsplit", fn_rsplit, 2)
.with_function_args("split_once", fn_split_once, 2)
.with_function_args("rsplit_once", fn_rsplit_once, 2)
.with_function_args("split_n", fn_split_n, 3)
.with_function_args("strip_prefix", fn_strip_prefix, 2)
.with_function_args("strip_suffix", fn_strip_suffix, 2)
.with_function_args("is_intersect", fn_is_intersect, 2)
@ -87,11 +88,58 @@ pub fn register_functions() -> FunctionMap {
.with_function_no_args("is_encoding_problem", fn_is_encoding_problem)
.with_function_no_args("is_attachment", fn_is_attachment)
.with_function_no_args("is_body", fn_is_body)
.with_function_no_args("var_names", fn_is_var_names)
.with_function_no_args("var_names", fn_get_var_names)
.with_function_no_args("attachment_name", fn_attachment_name)
.with_function_no_args("mime_part_len", fn_mime_part_len)
}
pub fn register_functions_untrusted() -> FunctionMap {
FunctionMap::new()
.with_function("trim", fn_trim)
.with_function("trim_start", fn_trim_start)
.with_function("trim_end", fn_trim_end)
.with_function("len", fn_len)
.with_function("count", fn_count)
.with_function("is_empty", fn_is_empty)
.with_function("is_number", fn_is_number)
.with_function("is_ascii", fn_is_ascii)
.with_function("to_lowercase", fn_to_lowercase)
.with_function("to_uppercase", fn_to_uppercase)
.with_function("is_email", fn_is_email)
.with_function("thread_name", fn_thread_name)
.with_function("html_to_text", fn_html_to_text)
.with_function("is_uppercase", fn_is_uppercase)
.with_function("is_lowercase", fn_is_lowercase)
.with_function("has_digits", fn_has_digits)
.with_function("count_spaces", fn_count_spaces)
.with_function("count_uppercase", fn_count_uppercase)
.with_function("count_lowercase", fn_count_lowercase)
.with_function("count_chars", fn_count_chars)
.with_function("dedup", fn_dedup)
.with_function("lines", fn_lines)
.with_function("is_ip_addr", fn_is_ip_addr)
.with_function("is_ipv4_addr", fn_is_ipv4_addr)
.with_function("is_ipv6_addr", fn_is_ipv6_addr)
.with_function("winnow", fn_winnow)
.with_function_args("sort", fn_sort, 2)
.with_function_args("email_part", fn_email_part, 2)
.with_function_args("eq_ignore_case", fn_eq_ignore_case, 2)
.with_function_args("contains", fn_contains, 2)
.with_function_args("contains_ignore_case", fn_contains_ignore_case, 2)
.with_function_args("starts_with", fn_starts_with, 2)
.with_function_args("ends_with", fn_ends_with, 2)
.with_function_args("uri_part", fn_uri_part, 2)
.with_function_args("substring", fn_substring, 3)
.with_function_args("split", fn_split, 2)
.with_function_args("rsplit", fn_rsplit, 2)
.with_function_args("split_once", fn_split_once, 2)
.with_function_args("rsplit_once", fn_rsplit_once, 2)
.with_function_args("split_n", fn_split_n, 3)
.with_function_args("strip_prefix", fn_strip_prefix, 2)
.with_function_args("strip_suffix", fn_strip_suffix, 2)
.with_function_args("is_intersect", fn_is_intersect, 2)
}
pub trait ApplyString<'x> {
fn transform(&self, f: impl Fn(&'_ str) -> Variable) -> Variable;
}

View file

@ -191,6 +191,25 @@ pub fn fn_rsplit<'x>(_: &'x Context<'x>, v: Vec<Variable>) -> Variable {
.into()
}
pub fn fn_split_n<'x>(_: &'x Context<'x>, v: Vec<Variable>) -> Variable {
let value = v[0].to_string();
let arg = v[1].to_string();
let num = v[2].to_integer() as usize;
let mut result = Vec::new();
let mut s = value.as_ref();
for _ in 0..num {
if let Some((a, b)) = s.split_once(arg.as_ref()) {
result.push(Variable::from(a.to_string()));
s = b;
} else {
break;
}
}
result.push(Variable::from(s.to_string()));
result.into()
}
pub fn fn_split_once<'x>(_: &'x Context<'x>, v: Vec<Variable>) -> Variable {
v[0].to_string()
.split_once(v[1].to_string().as_ref())

View file

@ -0,0 +1,80 @@
/*
* SPDX-FileCopyrightText: 2020 Stalwart Labs Ltd <hello@stalw.art>
*
* SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-SEL
*/
use std::time::Instant;
use directory::Permission;
use sieve::{runtime::Variable, FunctionMap};
use trc::{AiEvent, SecurityEvent};
use super::PluginContext;
pub fn register(plugin_id: u32, fnc_map: &mut FunctionMap) {
fnc_map.set_external_function("llm_prompt", plugin_id, 2);
}
pub async fn exec(ctx: PluginContext<'_>) -> trc::Result<Variable> {
// SPDX-SnippetBegin
// SPDX-FileCopyrightText: 2020 Stalwart Labs Ltd <hello@stalw.art>
// SPDX-License-Identifier: LicenseRef-SEL
#[cfg(feature = "enterprise")]
if let (Variable::String(name), Variable::String(prompt)) =
(&ctx.arguments[0], &ctx.arguments[1])
{
#[cfg(feature = "test_mode")]
if name.as_ref() == "echo-test" {
return Ok(prompt.to_string().into());
}
if let Some(ai_api) = ctx.server.core.enterprise.as_ref().and_then(|e| {
if ctx.access_token.map_or(true, |token| {
if token.has_permission(Permission::AiModelInteract) {
true
} else {
trc::event!(
Security(SecurityEvent::Unauthorized),
AccountId = token.primary_id(),
Details = Permission::AiModelInteract.name(),
SpanId = ctx.session_id,
);
false
}
}) {
if e.ai_apis.len() == 1 && name.is_empty() {
e.ai_apis.values().next()
} else {
e.ai_apis.get(name.as_ref())
}
} else {
None
}
}) {
let time = Instant::now();
match ai_api.send_request(prompt.as_ref(), None).await {
Ok(response) => {
trc::event!(
Ai(AiEvent::LlmResponse),
Id = ai_api.id.clone(),
Value = prompt.to_string(),
Details = response.clone(),
Elapsed = time.elapsed(),
SpanId = ctx.session_id,
);
return Ok(response.into());
}
Err(err) => {
trc::error!(err.span_id(ctx.session_id));
}
}
}
}
// SPDX-SnippetEnd
Ok(false.into())
}

View file

@ -9,6 +9,7 @@ pub mod dns;
pub mod exec;
pub mod headers;
pub mod http;
pub mod llm_prompt;
pub mod lookup;
pub mod pyzor;
pub mod query;
@ -17,7 +18,7 @@ pub mod text;
use mail_parser::Message;
use sieve::{runtime::Variable, FunctionMap, Input};
use crate::{Core, Server};
use crate::{auth::AccessToken, Core, Server};
use super::ScriptModification;
@ -25,13 +26,14 @@ type RegisterPluginFnc = fn(u32, &mut FunctionMap) -> ();
pub struct PluginContext<'x> {
pub session_id: u64,
pub access_token: Option<&'x AccessToken>,
pub server: &'x Server,
pub message: &'x Message<'x>,
pub modifications: &'x mut Vec<ScriptModification>,
pub arguments: Vec<Variable>,
}
const PLUGINS_REGISTER: [RegisterPluginFnc; 18] = [
const PLUGINS_REGISTER: [RegisterPluginFnc; 19] = [
query::register,
exec::register,
lookup::register,
@ -50,14 +52,16 @@ const PLUGINS_REGISTER: [RegisterPluginFnc; 18] = [
headers::register,
text::register_tokenize,
text::register_domain_part,
llm_prompt::register,
];
pub trait RegisterSievePlugins {
fn register_plugins(self) -> Self;
fn register_plugins_trusted(self) -> Self;
fn register_plugins_untrusted(self) -> Self;
}
impl RegisterSievePlugins for FunctionMap {
fn register_plugins(mut self) -> Self {
fn register_plugins_trusted(mut self) -> Self {
#[cfg(feature = "test_mode")]
{
self.set_external_function("print", PLUGINS_REGISTER.len() as u32, 1)
@ -68,6 +72,11 @@ impl RegisterSievePlugins for FunctionMap {
}
self
}
fn register_plugins_untrusted(mut self) -> Self {
llm_prompt::register(18, &mut self);
self
}
}
impl Core {
@ -97,6 +106,7 @@ impl Core {
15 => headers::exec(ctx),
16 => text::exec_tokenize(ctx),
17 => text::exec_domain_part(ctx),
18 => llm_prompt::exec(ctx).await,
_ => unreachable!(),
};

View file

@ -197,6 +197,7 @@ impl Permission {
Permission::OauthClientCreate => "Create new OAuth clients",
Permission::OauthClientUpdate => "Modify OAuth clients",
Permission::OauthClientDelete => "Remove OAuth clients",
Permission::AiModelInteract => "Interact with AI models",
}
}
}

View file

@ -262,6 +262,8 @@ pub enum Permission {
// OAuth client registration
OauthClientRegistration,
OauthClientOverride,
AiModelInteract,
// WARNING: add new ids at the end (TODO: use static ids)
}

View file

@ -1,6 +1,6 @@
[package]
name = "jmap_proto"
version = "0.1.0"
version = "0.10.2"
edition = "2021"
resolver = "2"

View file

@ -111,7 +111,10 @@ impl SieveHandler for Server {
}
// Run script
let result = match self.run_script(script_id, script, params, 0).await {
let result = match self
.run_script(script_id, script, params.with_access_token(access_token))
.await
{
ScriptResult::Accept { modifications } => Response::Accept { modifications },
ScriptResult::Replace {
message,

View file

@ -6,7 +6,9 @@
use std::borrow::Cow;
use common::{auth::AccessToken, listener::stream::NullIo, Server};
use common::{
auth::AccessToken, listener::stream::NullIo, scripts::plugins::PluginContext, Server,
};
use directory::{backend::internal::PrincipalField, QueryBy};
use jmap_proto::types::{collection::Collection, id::Id, keyword::Keyword, property::Property};
use mail_parser::MessageParser;
@ -398,12 +400,27 @@ impl SieveScriptIngest for Server {
}
}
Event::ListContains { .. }
| Event::Function { .. }
| Event::Notify { .. }
| Event::SetEnvelope { .. } => {
// Not allowed
input = false.into();
}
Event::Function { id, arguments } => {
input = self
.core
.run_plugin(
id,
PluginContext {
session_id,
server: self,
message: instance.message(),
modifications: &mut Vec::new(),
access_token: access_token.into(),
arguments,
},
)
.await;
}
Event::CreatedMessage { message, .. } => {
messages.push(SieveMessage {
raw_message: message.into(),

View file

@ -31,7 +31,6 @@ pub trait RunScript: Sync + Send {
script_id: String,
script: Arc<Sieve>,
params: ScriptParameters<'_>,
session_id: u64,
) -> impl Future<Output = ScriptResult> + Send;
}
@ -41,7 +40,6 @@ impl RunScript for Server {
script_id: String,
script: Arc<Sieve>,
params: ScriptParameters<'_>,
session_id: u64,
) -> ScriptResult {
// Create filter instance
let time = Instant::now();
@ -56,6 +54,7 @@ impl RunScript for Server {
.with_user_full_name(&params.from_name);
let mut input = Input::script("__script", script);
let mut messages: Vec<Vec<u8>> = Vec::new();
let session_id = params.session_id;
let mut reject_reason = None;
let mut modifications = vec![];
@ -124,6 +123,7 @@ impl RunScript for Server {
server: self,
message: instance.message(),
modifications: &mut modifications,
access_token: params.access_token,
arguments,
},
)

View file

@ -132,9 +132,9 @@ impl<T: SessionStream> Session<T> {
script_id,
script,
params
.with_session_id(self.data.session_id)
.with_envelope(&self.server, self, self.data.session_id)
.await,
self.data.session_id,
)
.await
}

View file

@ -7,7 +7,9 @@
use std::borrow::Cow;
use ahash::AHashMap;
use common::{expr::functions::ResolveVariable, scripts::ScriptModification, Server};
use common::{
auth::AccessToken, expr::functions::ResolveVariable, scripts::ScriptModification, Server,
};
use sieve::{runtime::Variable, Envelope};
pub mod envelope;
@ -38,6 +40,8 @@ pub struct ScriptParameters<'x> {
sign: Vec<String>,
#[cfg(feature = "test_mode")]
expected_variables: Option<AHashMap<String, Variable>>,
access_token: Option<&'x AccessToken>,
session_id: u64,
}
impl<'x> ScriptParameters<'x> {
@ -53,6 +57,8 @@ impl<'x> ScriptParameters<'x> {
from_name: Default::default(),
return_path: Default::default(),
sign: Default::default(),
access_token: None,
session_id: Default::default(),
}
}
@ -108,6 +114,16 @@ impl<'x> ScriptParameters<'x> {
self
}
pub fn with_access_token(mut self, access_token: &'x AccessToken) -> Self {
self.access_token = Some(access_token);
self
}
pub fn with_session_id(mut self, session_id: u64) -> Self {
self.session_id = session_id;
self
}
#[cfg(feature = "test_mode")]
pub fn with_expected_variables(
mut self,

View file

@ -51,6 +51,7 @@ impl EventType {
EventType::Telemetry(event) => event.description(),
EventType::MessageIngest(event) => event.description(),
EventType::Security(event) => event.description(),
EventType::Ai(event) => event.description(),
}
}
@ -98,6 +99,7 @@ impl EventType {
EventType::Telemetry(event) => event.explain(),
EventType::MessageIngest(event) => event.explain(),
EventType::Security(event) => event.explain(),
EventType::Ai(event) => event.explain(),
}
}
}
@ -1806,3 +1808,19 @@ impl SecurityEvent {
}
}
}
impl AiEvent {
pub fn description(&self) -> &'static str {
match self {
AiEvent::LlmResponse => "LLM response",
AiEvent::ApiError => "AI API error",
}
}
pub fn explain(&self) -> &'static str {
match self {
AiEvent::LlmResponse => "An LLM response has been received",
AiEvent::ApiError => "An AI API error occurred",
}
}
}

View file

@ -533,6 +533,10 @@ impl EventType {
MessageIngestEvent::Error => Level::Error,
},
EventType::Security(_) => Level::Info,
EventType::Ai(event) => match event {
AiEvent::LlmResponse => Level::Trace,
AiEvent::ApiError => Level::Warn,
},
}
}
}

View file

@ -183,6 +183,7 @@ pub enum EventType {
OutgoingReport(OutgoingReportEvent),
Telemetry(TelemetryEvent),
Security(SecurityEvent),
Ai(AiEvent),
}
#[event_type]
@ -939,6 +940,12 @@ pub enum ResourceEvent {
WebadminUnpacked,
}
#[event_type]
pub enum AiEvent {
LlmResponse,
ApiError,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum MetricType {
ServerMemory,

View file

@ -861,6 +861,8 @@ impl EventType {
EventType::Limit(LimitEvent::TenantQuota) => 553,
EventType::Auth(AuthEvent::TokenExpired) => 554,
EventType::Auth(AuthEvent::ClientRegistration) => 555,
EventType::Ai(AiEvent::LlmResponse) => 556,
EventType::Ai(AiEvent::ApiError) => 557,
}
}
@ -1462,6 +1464,8 @@ impl EventType {
553 => Some(EventType::Limit(LimitEvent::TenantQuota)),
554 => Some(EventType::Auth(AuthEvent::TokenExpired)),
555 => Some(EventType::Auth(AuthEvent::ClientRegistration)),
556 => Some(EventType::Ai(AiEvent::LlmResponse)),
557 => Some(EventType::Ai(AiEvent::ApiError)),
_ => None,
}
}

View file

@ -31,6 +31,7 @@ scripts = {
"url.sieve",
"rbl.sieve",
"pyzor.sieve",
"llm.sieve",
"composites.sieve",
"scores.sieve",
"reputation.sieve",
@ -80,7 +81,7 @@ def read_file(file):
return f.read() + "\n"
def build_spam_filters(scripts):
spam_filter = "[version]\nspam-filter = \"1.1\"\n\n"
spam_filter = "[version]\nspam-filter = \"1.2\"\n\n"
for script_name, file_list in scripts.items():
script_content = read_and_concatenate(file_list).replace("'''", "\\'\\'\\'")
script_description = script_names[script_name]

View file

@ -361,4 +361,16 @@ spam-scores = {"ABUSE_SURBL" = "5.0",
"SHORT_PART_BAD_HEADERS" = "7.0",
"MISSING_ESSENTIAL_HEADERS" = "7.0",
"SINGLE_SHORT_PART" = "0.0",
"COMPLETELY_EMPTY" = "7.0"}
"COMPLETELY_EMPTY" = "7.0",
"LLM_UNSOLICITED_HIGH" = "3.0",
"LLM_UNSOLICITED_MEDIUM" = "2.0",
"LLM_UNSOLICITED_LOW" = "0.5",
"LLM_COMMERCIAL_HIGH" = "3.0",
"LLM_COMMERCIAL_MEDIUM" = "2.0",
"LLM_COMMERCIAL_LOW" = "0.5",
"LLM_HARMFUL_HIGH" = "3.0",
"LLM_HARMFUL_MEDIUM" = "2.0",
"LLM_HARMFUL_LOW" = "0.5",
"LLM_LEGITIMATE_HIGH" = "-3.0",
"LLM_LEGITIMATE_MEDIUM" = "-2.0",
"LLM_LEGITIMATE_LOW" = "-0.5"}

View file

@ -10,5 +10,30 @@ spam-config = {
"threshold-discard" = "0.0",
"threshold-reject" = "0.0",
"directory" = "",
"lookup" = ""
"lookup" = "",
"llm-model" = "",
"llm-prompt" = "You are an AI assistant specialized in analyzing email content to detect unsolicited, commercial, or harmful messages. Your task is to examine the provided email, including its subject line, and determine if it falls into any of these categories. Please follow these steps:
- Carefully read the entire email content, including the subject line.
- Look for indicators of unsolicited messages, such as:
* Lack of prior relationship or consent
* Mass-mailing characteristics
* Vague or misleading sender information
- Identify commercial content by checking for:
* Promotional language
* Product or service offerings
* Call-to-action for purchases
- Detect potentially harmful content by searching for:
* Phishing attempts (requests for personal information, suspicious links)
* Malware indicators (suspicious attachments, urgent calls to action)
* Scams or fraudulent schemes
- Analyze the overall tone, intent, and legitimacy of the email.
- Determine the most appropriate single category for the email: Unsolicited, Commercial, Harmful, or Legitimate.
- Assess your confidence level in this determination: High, Medium, or Low.
- Provide a brief explanation for your determination.
- Format your response as follows, separated by commas: Category,Confidence,Explanation
* Example: Unsolicited,High,The email contains mass-mailing characteristics without any prior relationship context.
Here's the email to analyze, please provide your analysis based on the above instructions, ensuring your response is in the specified comma-separated format:",
"add-llm-result" = true
}

View file

@ -33,3 +33,12 @@ let "DOMAIN_DIRECTORY" "key_get('spam-config', 'directory')";
# Store to use for Bayes tokens and ids (leave empty for default)
let "SPAM_DB" "key_get('spam-config', 'lookup')";
# LLM model to use for spam classification
let "LLM_MODEL" "key_get('spam-config', 'llm-model')";
# LLM prompt to use for spam classification
let "LLM_PROMPT_TEXT" "key_get('spam-config', 'llm-prompt')";
# Whether to add an X-Spam-Llm-Result header
let "ADD_HEADER_LLM" "key_get('spam-config', 'add-llm-result')";

View file

@ -0,0 +1,41 @@
if eval "LLM_MODEL && LLM_PROMPT_TEXT" {
let "llm_result" "trim(split_n(llm_prompt(LLM_MODEL, LLM_PROMPT_TEXT + '\n\nSubject: ' + subject_clean + '\n\n' + text_body), ',', 3))";
if eval "eq_ignore_case(llm_result[0], 'Unsolicited')" {
if eval "eq_ignore_case(llm_result[1], 'High')" {
let "t.LLM_UNSOLICITED_HIGH" "1";
} elsif eval "eq_ignore_case(llm_result[1], 'Medium')" {
let "t.LLM_UNSOLICITED_MEDIUM" "1";
} else {
let "t.LLM_UNSOLICITED_LOW" "1";
}
} elsif eval "eq_ignore_case(llm_result[0], 'Commercial')" {
if eval "eq_ignore_case(llm_result[1], 'High')" {
let "t.LLM_COMMERCIAL_HIGH" "1";
} elsif eval "eq_ignore_case(llm_result[1], 'Medium')" {
let "t.LLM_COMMERCIAL_MEDIUM" "1";
} else {
let "t.LLM_COMMERCIAL_LOW" "1";
}
} elsif eval "eq_ignore_case(llm_result[0], 'Harmful')" {
if eval "eq_ignore_case(llm_result[1], 'High')" {
let "t.LLM_HARMFUL_HIGH" "1";
} elsif eval "eq_ignore_case(llm_result[1], 'Medium')" {
let "t.LLM_HARMFUL_MEDIUM" "1";
} else {
let "t.LLM_HARMFUL_LOW" "1";
}
} elsif eval "eq_ignore_case(llm_result[0], 'Legitimate')" {
if eval "eq_ignore_case(llm_result[1], 'High')" {
let "t.LLM_LEGITIMATE_HIGH" "1";
} elsif eval "eq_ignore_case(llm_result[1], 'Medium')" {
let "t.LLM_LEGITIMATE_MEDIUM" "1";
} else {
let "t.LLM_LEGITIMATE_LOW" "1";
}
}
if eval "ADD_HEADER_LLM && count(llm_result) > 2" {
eval "add_header('X-Spam-Llm-Result', 'Category=' + llm_result[0] + '; Confidence=' + llm_result[1] + '; Explanation=' + llm_result[2])";
}
}

View file

@ -1,4 +1,4 @@
require ["fileinto", "mailbox", "mailboxid", "special-use", "ihave", "imap4flags"];
require ["fileinto", "mailbox", "mailboxid", "special-use", "ihave", "imap4flags", "vnd.stalwart.expressions"];
# SpecialUse extension tests
if not specialuse_exists ["inbox", "trash"] {
@ -61,3 +61,6 @@ if not mailboxexists "My" {
error "'My' not found.";
}
if eval "llm_prompt('echo-test', 'hello world') != 'hello world'" {
error "llm_prompt is unavailable.";
}

View file

@ -0,0 +1,72 @@
expect LLM_UNSOLICITED_HIGH
Subject: Unsolicited,High,Test
Test
<!-- NEXT TEST -->
expect LLM_COMMERCIAL_HIGH
Subject: Commercial,High,Test
Test
<!-- NEXT TEST -->
expect LLM_HARMFUL_HIGH
Subject: Harmful,High,Test
Test
<!-- NEXT TEST -->
expect LLM_LEGITIMATE_HIGH
Subject: Legitimate,High,Test
Test
<!-- NEXT TEST -->
expect LLM_UNSOLICITED_MEDIUM
Subject: Unsolicited,Medium,Test
Test
<!-- NEXT TEST -->
expect LLM_COMMERCIAL_MEDIUM
Subject: Commercial,Medium,Test
Test
<!-- NEXT TEST -->
expect LLM_HARMFUL_MEDIUM
Subject: Harmful,Medium,Test
Test
<!-- NEXT TEST -->
expect LLM_LEGITIMATE_MEDIUM
Subject: Legitimate,Medium,Test
Test
<!-- NEXT TEST -->
expect LLM_UNSOLICITED_LOW
Subject: Unsolicited,Low,Test
Test
<!-- NEXT TEST -->
expect LLM_COMMERCIAL_LOW
Subject: Commercial,Low,Test
Test
<!-- NEXT TEST -->
expect LLM_HARMFUL_LOW
Subject: Harmful,Low,Test
Test
<!-- NEXT TEST -->
expect LLM_LEGITIMATE_LOW
Subject: Legitimate,Low,Test
Test

View file

@ -1,10 +0,0 @@
co.uk
org.uk
com
net
org
info
biz
*.wildcard
!test.wildcard
disposable.org

View file

@ -10,23 +10,18 @@
use std::sync::Arc;
use ahash::AHashMap;
use base64::{engine::general_purpose, Engine};
use common::{config::server::Listeners, listener::SessionData, Core, Data, Inner};
use directory::{backend::internal::PrincipalField, QueryBy};
use hyper::{body, server::conn::http1, service::service_fn, Method, StatusCode, Uri};
use hyper_util::rt::TokioIo;
use jmap::api::{
http::{fetch_body, ToHttpResponse},
HttpResponse, JsonResponse,
};
use hyper::{Method, StatusCode};
use jmap::api::{http::ToHttpResponse, JsonResponse};
use mail_send::Credentials;
use serde_json::json;
use tokio::sync::watch;
use trc::{AuthEvent, EventType};
use utils::config::Config;
use crate::{add_test_certs, directory::DirectoryTest, AssertConfig};
use crate::{
directory::DirectoryTest,
http_server::{spawn_mock_http_server, HttpMessage},
};
static TEST_TOKEN: &str = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ";
@ -143,121 +138,3 @@ async fn oidc_directory() {
assert_eq!(principal.description(), Some("John Doe"));
}
}
const MOCK_HTTP_SERVER: &str = r#"
[server]
hostname = "'oidc.example.org'"
http.url = "'https://127.0.0.1:9090'"
[server.listener.jmap]
bind = ['127.0.0.1:9090']
protocol = 'http'
tls.implicit = true
[server.socket]
reuse-addr = true
[certificate.default]
cert = '%{file:{CERT}}%'
private-key = '%{file:{PK}}%'
default = true
"#;
#[derive(Clone)]
pub struct HttpSessionManager {
inner: HttpRequestHandler,
}
pub type HttpRequestHandler = Arc<dyn Fn(HttpMessage) -> HttpResponse + Sync + Send>;
#[derive(Debug)]
pub struct HttpMessage {
method: Method,
headers: AHashMap<String, String>,
uri: Uri,
body: Option<Vec<u8>>,
}
impl HttpMessage {
pub fn get_url_encoded(&self, key: &str) -> Option<String> {
form_urlencoded::parse(self.body.as_ref()?.as_slice())
.find(|(k, _)| k == key)
.map(|(_, v)| v.into_owned())
}
}
pub async fn spawn_mock_http_server(
handler: HttpRequestHandler,
) -> (watch::Sender<bool>, watch::Receiver<bool>) {
// Start mock push server
let mut settings = Config::new(add_test_certs(MOCK_HTTP_SERVER)).unwrap();
settings.resolve_all_macros().await;
let mock_inner = Arc::new(Inner {
shared_core: Core::parse(&mut settings, Default::default(), Default::default())
.await
.into_shared(),
data: Data::parse(&mut settings),
..Default::default()
});
settings.errors.clear();
settings.warnings.clear();
let mut servers = Listeners::parse(&mut settings);
servers.parse_tcp_acceptors(&mut settings, mock_inner.clone());
// Start JMAP server
servers.bind_and_drop_priv(&mut settings);
settings.assert_no_errors();
servers.spawn(|server, acceptor, shutdown_rx| {
server.spawn(
HttpSessionManager {
inner: handler.clone(),
},
mock_inner.clone(),
acceptor,
shutdown_rx,
);
})
}
impl common::listener::SessionManager for HttpSessionManager {
#[allow(clippy::manual_async_fn)]
fn handle<T: common::listener::SessionStream>(
self,
session: SessionData<T>,
) -> impl std::future::Future<Output = ()> + Send {
async move {
let sender = self.inner;
let _ = http1::Builder::new()
.keep_alive(false)
.serve_connection(
TokioIo::new(session.stream),
service_fn(|mut req: hyper::Request<body::Incoming>| {
let sender = sender.clone();
async move {
let response = sender(HttpMessage {
method: req.method().clone(),
uri: req.uri().clone(),
headers: req
.headers()
.iter()
.map(|(k, v)| {
(k.as_str().to_lowercase(), v.to_str().unwrap().to_string())
})
.collect(),
body: fetch_body(&mut req, 1024 * 1024, 0).await,
});
Ok::<_, hyper::Error>(response.build())
}
}),
)
.await;
}
}
#[allow(clippy::manual_async_fn)]
fn shutdown(&self) -> impl std::future::Future<Output = ()> + Send {
async {}
}
}

135
tests/src/http_server.rs Normal file
View file

@ -0,0 +1,135 @@
/*
* SPDX-FileCopyrightText: 2020 Stalwart Labs Ltd <hello@stalw.art>
*
* SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-SEL
*/
use std::sync::Arc;
use ahash::AHashMap;
use common::{config::server::Listeners, listener::SessionData, Core, Data, Inner};
use hyper::{body, server::conn::http1, service::service_fn, Method, Uri};
use hyper_util::rt::TokioIo;
use jmap::api::{http::fetch_body, HttpResponse};
use tokio::sync::watch;
use utils::config::Config;
use crate::{add_test_certs, AssertConfig};
const MOCK_HTTP_SERVER: &str = r#"
[server]
hostname = "'oidc.example.org'"
http.url = "'https://127.0.0.1:9090'"
[server.listener.jmap]
bind = ['127.0.0.1:9090']
protocol = 'http'
tls.implicit = true
[server.socket]
reuse-addr = true
[certificate.default]
cert = '%{file:{CERT}}%'
private-key = '%{file:{PK}}%'
default = true
"#;
#[derive(Clone)]
pub struct HttpSessionManager {
inner: HttpRequestHandler,
}
pub type HttpRequestHandler = Arc<dyn Fn(HttpMessage) -> HttpResponse + Sync + Send>;
#[derive(Debug)]
pub struct HttpMessage {
pub method: Method,
pub headers: AHashMap<String, String>,
pub uri: Uri,
pub body: Option<Vec<u8>>,
}
impl HttpMessage {
pub fn get_url_encoded(&self, key: &str) -> Option<String> {
form_urlencoded::parse(self.body.as_ref()?.as_slice())
.find(|(k, _)| k == key)
.map(|(_, v)| v.into_owned())
}
}
pub async fn spawn_mock_http_server(
handler: HttpRequestHandler,
) -> (watch::Sender<bool>, watch::Receiver<bool>) {
// Start mock push server
let mut settings = Config::new(add_test_certs(MOCK_HTTP_SERVER)).unwrap();
settings.resolve_all_macros().await;
let mock_inner = Arc::new(Inner {
shared_core: Core::parse(&mut settings, Default::default(), Default::default())
.await
.into_shared(),
data: Data::parse(&mut settings),
..Default::default()
});
settings.errors.clear();
settings.warnings.clear();
let mut servers = Listeners::parse(&mut settings);
servers.parse_tcp_acceptors(&mut settings, mock_inner.clone());
// Start JMAP server
servers.bind_and_drop_priv(&mut settings);
settings.assert_no_errors();
servers.spawn(|server, acceptor, shutdown_rx| {
server.spawn(
HttpSessionManager {
inner: handler.clone(),
},
mock_inner.clone(),
acceptor,
shutdown_rx,
);
})
}
impl common::listener::SessionManager for HttpSessionManager {
#[allow(clippy::manual_async_fn)]
fn handle<T: common::listener::SessionStream>(
self,
session: SessionData<T>,
) -> impl std::future::Future<Output = ()> + Send {
async move {
let sender = self.inner;
let _ = http1::Builder::new()
.keep_alive(false)
.serve_connection(
TokioIo::new(session.stream),
service_fn(|mut req: hyper::Request<body::Incoming>| {
let sender = sender.clone();
async move {
let response = sender(HttpMessage {
method: req.method().clone(),
uri: req.uri().clone(),
headers: req
.headers()
.iter()
.map(|(k, v)| {
(k.as_str().to_lowercase(), v.to_str().unwrap().to_string())
})
.collect(),
body: fetch_body(&mut req, 1024 * 1024, 0).await,
});
Ok::<_, hyper::Error>(response.build())
}
}),
)
.await;
}
}
#[allow(clippy::manual_async_fn)]
fn shutdown(&self) -> impl std::future::Future<Output = ()> + Send {
async {}
}
}

View file

@ -106,6 +106,7 @@ pub async fn test(params: &mut JMAPTest) {
.into(),
metrics_alerts: parse_metric_alerts(&mut config),
logo_url: None,
ai_apis: Default::default(),
}
.into();
config.assert_no_errors();
@ -170,6 +171,7 @@ impl EnterpriseCore for Core {
metrics_store: None,
metrics_alerts: vec![],
logo_url: None,
ai_apis: Default::default(),
}
.into();
self

View file

@ -384,12 +384,12 @@ pub async fn jmap_tests() {
mailbox::test(&mut params).await;
delivery::test(&mut params).await;
auth_acl::test(&mut params).await;
auth_limits::test(&mut params).await;*/
auth_limits::test(&mut params).await;
auth_oauth::test(&mut params).await;
/*event_source::test(&mut params).await;
push_subscription::test(&mut params).await;
event_source::test(&mut params).await;
push_subscription::test(&mut params).await;*/
sieve_script::test(&mut params).await;
vacation_response::test(&mut params).await;
/*vacation_response::test(&mut params).await;
email_submission::test(&mut params).await;
websocket::test(&mut params).await;
quota::test(&mut params).await;

View file

@ -18,6 +18,8 @@ static GLOBAL: Jemalloc = Jemalloc;
#[cfg(test)]
pub mod directory;
#[cfg(test)]
pub mod http_server;
#[cfg(test)]
pub mod imap;
#[cfg(test)]
pub mod jmap;

View file

@ -10,12 +10,17 @@ use std::{
use ahash::AHashMap;
use common::{
auth::AccessToken,
enterprise::llm::{
AiApiConfig, ChatCompletionChoice, ChatCompletionRequest, ChatCompletionResponse, Message,
},
scripts::{
functions::html::{get_attribute, html_attr_tokens, html_img_area, html_to_tokens},
ScriptModification,
},
Core,
};
use hyper::Method;
use jmap::api::{http::ToHttpResponse, JsonResponse};
use mail_auth::{dmarc::Policy, DkimResult, DmarcResult, IprevResult, SpfResult, MX};
use sieve::runtime::Variable;
use smtp::{
@ -26,7 +31,11 @@ use smtp::{
use store::Stores;
use utils::config::Config;
use crate::smtp::{session::TestSession, TempDir, TestSMTP};
use crate::{
http_server::{spawn_mock_http_server, HttpMessage},
jmap::enterprise::EnterpriseCore,
smtp::{session::TestSession, TempDir, TestSMTP},
};
const CONFIG: &str = r#"
[spam.header]
@ -46,6 +55,10 @@ threshold-discard = 0
threshold-reject = 0
directory = ""
lookup = ""
llm-model = "dummy"
llm-prompt = "You are an AI assistant specialized in analyzing email content to detect unsolicited, commercial, or harmful messages. Format your response as follows, separated by commas: Category,Confidence,Explanation
Here's the email to analyze, please provide your analysis based on the above instructions, ensuring your response is in the specified comma-separated format:"
add-llm-result = false
[session.rcpt]
relay = true
@ -70,6 +83,11 @@ data = "spamdb"
lookup = "spamdb"
blob = "spamdb"
fts = "spamdb"
directory = "spamdb"
[directory."spamdb"]
type = "internal"
store = "spamdb"
[store."spamdb"]
type = "sqlite"
@ -79,6 +97,12 @@ path = "{PATH}/test_antispam.db"
#type = "redis"
#url = "redis://127.0.0.1"
[enterprise.ai.dummy]
endpoint = "https://127.0.0.1:9090/v1/chat/completions"
type = "chat"
model = "gpt-dummy"
allow-invalid-certs = true
[lookup]
"spam-free" = {"gmail.com", "googlemail.com", "yahoomail.com", "*freemail.org"}
"spam-disposable" = {"guerrillamail.com", "*disposable.org"}
@ -94,9 +118,6 @@ path = "{PATH}/test_antispam.db"
"spam-trap" = {"spamtrap@*"}
"spam-allow" = {"stalw.art"}
[resolver]
public-suffix = "file://{LIST_PATH}/public-suffix.dat"
[sieve.trusted.scripts]
"#;
@ -129,6 +150,7 @@ async fn antispam() {
"bayes_classify",
"reputation",
"pyzor",
"llm",
];
let tmp_dir = TempDir::new("smtp_antispam_test", true);
let base_path = PathBuf::from(env!("CARGO_MANIFEST_DIR"))
@ -192,8 +214,14 @@ async fn antispam() {
let mut config = Config::new(&config).unwrap();
config.resolve_all_macros().await;
let stores = Stores::parse_all(&mut config).await;
let core = Core::parse(&mut config, stores, Default::default()).await;
//config.assert_no_errors();
let mut core = Core::parse(&mut config, stores, Default::default())
.await
.enable_enterprise();
core.enterprise.as_mut().unwrap().ai_apis.insert(
"dummy".to_string(),
AiApiConfig::parse(&mut config, "dummy").unwrap(),
);
crate::AssertConfig::assert_no_errors(config);
// Add mock DNS entries
for (domain, ip) in [
@ -252,6 +280,34 @@ async fn antispam() {
let server = TestSMTP::from_core(core).server;
// Spawn mock OpenAI server
let _tx = spawn_mock_http_server(Arc::new(|req: HttpMessage| {
assert_eq!(req.uri.path(), "/v1/chat/completions");
assert_eq!(req.method, Method::POST);
let req =
serde_json::from_slice::<ChatCompletionRequest>(req.body.as_ref().unwrap()).unwrap();
assert_eq!(req.model, "gpt-dummy");
let message = &req.messages[0].content;
assert!(message.contains("You are an AI assistant specialized in analyzing email"));
JsonResponse::new(&ChatCompletionResponse {
created: 0,
object: String::new(),
id: String::new(),
model: req.model,
choices: vec![ChatCompletionChoice {
index: 0,
finish_reason: "stop".to_string(),
message: Message {
role: "assistant".to_string(),
content: message.split_once("Subject: ").unwrap().1.to_string(),
},
}],
})
.into_http_response()
}))
.await;
// Run tests
let base_path = PathBuf::from(env!("CARGO_MANIFEST_DIR"))
.join("resources")
@ -420,10 +476,7 @@ async fn antispam() {
// Run script
let server_ = server.clone();
let script = script.clone();
match server_
.run_script("test".to_string(), script, params, 0)
.await
{
match server_.run_script("test".to_string(), script, params).await {
ScriptResult::Accept { modifications } => {
if modifications.len() != expected_headers.len() {
panic!(

View file

@ -177,7 +177,7 @@ async fn sieve_scripts() {
.await;
match test
.server
.run_script(name.to_string(), script, params, 0)
.run_script(name.to_string(), script, params)
.await
{
ScriptResult::Accept { .. } => (),