Distributed SMTP queues (untested)

This commit is contained in:
Mauro D 2024-02-08 20:03:57 -03:00
parent d15f598460
commit d16119f54b
60 changed files with 2990 additions and 3828 deletions

630
Cargo.lock generated

File diff suppressed because it is too large Load diff

View file

@ -25,11 +25,14 @@ use std::{borrow::Borrow, io::Write};
use store::{
write::{DeserializeFrom, SerializeInto},
BlobClass, BlobHash,
BlobClass,
};
use utils::codec::{
use utils::{
codec::{
base32_custom::{Base32Reader, Base32Writer},
leb128::{Leb128Iterator, Leb128Writer},
},
BlobHash,
};
use crate::parser::{base32::JsonBase32Reader, json::Parser, JsonObjectParser};

View file

@ -35,7 +35,8 @@ use mail_parser::{
decoders::{base64::base64_decode, quoted_printable::quoted_printable_decode},
Encoding,
};
use store::{BlobClass, BlobHash};
use store::BlobClass;
use utils::BlobHash;
use crate::{auth::AccessToken, JMAP};

View file

@ -33,8 +33,9 @@ use jmap_proto::{
};
use store::{
write::{now, BatchBuilder, BlobOp},
BlobClass, BlobHash, Serialize,
BlobClass, Serialize,
};
use utils::BlobHash;
use crate::{auth::AccessToken, JMAP};

View file

@ -48,12 +48,12 @@ use jmap_proto::{
};
use mail_parser::{parsers::fields::thread::thread_name, HeaderName, HeaderValue};
use store::{
write::{BatchBuilder, ValueClass, F_BITMAP, F_VALUE},
write::{BatchBuilder, Bincode, ValueClass, F_BITMAP, F_VALUE},
BlobClass,
};
use utils::map::vec_map::VecMap;
use crate::{auth::AccessToken, mailbox::UidMailbox, services::housekeeper::Event, Bincode, JMAP};
use crate::{auth::AccessToken, mailbox::UidMailbox, services::housekeeper::Event, JMAP};
use super::{
index::{EmailIndexBuilder, TrimTextValue, VisitValues, MAX_ID_LENGTH, MAX_SORT_FIELD_LENGTH},

View file

@ -37,9 +37,9 @@ use jmap_proto::{
},
};
use mail_parser::HeaderName;
use store::BlobClass;
use store::{write::Bincode, BlobClass};
use crate::{auth::AccessToken, email::headers::HeaderToValue, mailbox::UidMailbox, Bincode, JMAP};
use crate::{auth::AccessToken, email::headers::HeaderToValue, mailbox::UidMailbox, JMAP};
use super::{
body::{ToBodyPart, TruncateBody},

View file

@ -35,12 +35,13 @@ use store::{
backend::MAX_TOKEN_LENGTH,
fts::{index::FtsDocument, Field},
write::{
BatchBuilder, BlobOp, DirectoryClass, IntoOperations, F_BITMAP, F_CLEAR, F_INDEX, F_VALUE,
BatchBuilder, Bincode, BlobOp, DirectoryClass, IntoOperations, F_BITMAP, F_CLEAR, F_INDEX,
F_VALUE,
},
BlobHash,
};
use utils::BlobHash;
use crate::{mailbox::UidMailbox, Bincode};
use crate::mailbox::UidMailbox;
use super::metadata::MessageMetadata;

View file

@ -32,7 +32,7 @@ use mail_parser::{
MessagePartId, MimeHeaders, PartType,
};
use serde::{Deserialize, Serialize};
use store::BlobHash;
use utils::BlobHash;
#[derive(Debug, Serialize, Deserialize)]
pub struct MessageMetadata<'x> {

View file

@ -53,15 +53,14 @@ use mail_parser::MessageParser;
use store::{
ahash::AHashSet,
write::{
assert::HashedValue, log::ChangeLogBuilder, BatchBuilder, DeserializeFrom, SerializeInto,
ToBitmaps, ValueClass, F_BITMAP, F_CLEAR, F_VALUE,
assert::HashedValue, log::ChangeLogBuilder, BatchBuilder, Bincode, DeserializeFrom,
SerializeInto, ToBitmaps, ValueClass, F_BITMAP, F_CLEAR, F_VALUE,
},
Serialize,
};
use crate::{
auth::AccessToken, mailbox::UidMailbox, services::housekeeper::Event, Bincode, IngestError,
JMAP,
auth::AccessToken, mailbox::UidMailbox, services::housekeeper::Event, IngestError, JMAP,
};
use super::{

View file

@ -31,9 +31,9 @@ use jmap_proto::{
};
use mail_parser::{decoders::html::html_to_text, GetHeader, HeaderName, PartType};
use nlp::language::{search_snippet::generate_snippet, stemmer::Stemmer, Language};
use store::backend::MAX_TOKEN_LENGTH;
use store::{backend::MAX_TOKEN_LENGTH, write::Bincode};
use crate::{auth::AccessToken, Bincode, JMAP};
use crate::{auth::AccessToken, JMAP};
use super::metadata::{MessageMetadata, MetadataPartType};

View file

@ -54,8 +54,8 @@ use store::{
fts::FtsFilter,
query::{sort::Pagination, Comparator, Filter, ResultSet, SortedResultSet},
roaring::RoaringBitmap,
write::{BatchBuilder, BitmapClass, DirectoryClass, TagValue, ToBitmaps, ValueClass},
BitmapKey, BlobStore, Deserialize, FtsStore, Serialize, Store, Stores, ValueKey,
write::{BatchBuilder, BitmapClass, DirectoryClass, TagValue, ValueClass},
BitmapKey, BlobStore, Deserialize, FtsStore, Store, Stores, ValueKey,
};
use tokio::sync::mpsc;
use utils::{
@ -171,10 +171,6 @@ pub struct Config {
pub capabilities: BaseCapabilities,
}
pub struct Bincode<T: serde::Serialize + serde::de::DeserializeOwned> {
pub inner: T,
}
#[derive(Debug)]
pub enum IngestError {
Temporary,
@ -759,56 +755,6 @@ impl JMAP {
}
}
impl<T: serde::Serialize + serde::de::DeserializeOwned> Bincode<T> {
pub fn new(inner: T) -> Self {
Self { inner }
}
}
impl<T: serde::Serialize + serde::de::DeserializeOwned> Serialize for &Bincode<T> {
fn serialize(self) -> Vec<u8> {
lz4_flex::compress_prepend_size(&bincode::serialize(&self.inner).unwrap_or_default())
}
}
impl<T: serde::Serialize + serde::de::DeserializeOwned> Serialize for Bincode<T> {
fn serialize(self) -> Vec<u8> {
lz4_flex::compress_prepend_size(&bincode::serialize(&self.inner).unwrap_or_default())
}
}
impl<T: serde::Serialize + serde::de::DeserializeOwned + Sized + Sync + Send> Deserialize
for Bincode<T>
{
fn deserialize(bytes: &[u8]) -> store::Result<Self> {
lz4_flex::decompress_size_prepended(bytes)
.map_err(|err| {
store::Error::InternalError(format!("Bincode decompression failed: {err:?}"))
})
.and_then(|result| {
bincode::deserialize(&result).map_err(|err| {
store::Error::InternalError(format!(
"Bincode deserialization failed (len {}): {err:?}",
result.len()
))
})
})
.map(|inner| Self { inner })
}
}
impl<T: serde::Serialize + serde::de::DeserializeOwned> ToBitmaps for Bincode<T> {
fn to_bitmaps(&self, _ops: &mut Vec<store::write::Operation>, _field: u8, _set: bool) {
unreachable!()
}
}
impl<T: serde::Serialize + serde::de::DeserializeOwned> ToBitmaps for &Bincode<T> {
fn to_bitmaps(&self, _ops: &mut Vec<store::write::Operation>, _field: u8, _set: bool) {
unreachable!()
}
}
trait UpdateResults: Sized {
fn update_results(&mut self, sorted_results: SortedResultSet) -> Result<(), MethodError>;
}

View file

@ -24,13 +24,13 @@
use jmap_proto::types::{collection::Collection, property::Property};
use store::{
fts::index::FtsDocument,
write::{key::DeserializeBigEndian, BatchBuilder, ValueClass},
write::{key::DeserializeBigEndian, BatchBuilder, Bincode, ValueClass},
Deserialize, IterateParams, ValueKey, U32_LEN, U64_LEN,
};
use crate::{
email::{index::IndexMessageText, metadata::MessageMetadata},
Bincode, JMAP,
JMAP,
};
use super::housekeeper::Event;

View file

@ -32,9 +32,13 @@ use crate::{email::ingest::IngestEmail, mailbox::INBOX_ID, IngestError, JMAP};
impl JMAP {
pub async fn deliver_message(&self, message: IngestMessage) -> Vec<DeliveryResult> {
// Read message
let raw_message = match message.read_message().await {
Ok(raw_message) => raw_message,
Err(_) => {
let raw_message = match self
.blob_store
.get_blob(message.message_blob.as_slice(), 0..u32::MAX)
.await
{
Ok(Some(raw_message)) => raw_message,
_ => {
return (0..message.recipients.len())
.map(|_| DeliveryResult::TemporaryFailure {
reason: "Temporary I/O error.".into(),

View file

@ -32,11 +32,11 @@ use jmap_proto::{
use sieve::Sieve;
use store::{
query::Filter,
write::{assert::HashedValue, BatchBuilder, BlobOp},
write::{assert::HashedValue, BatchBuilder, Bincode, BlobOp},
Deserialize, Serialize,
};
use crate::{sieve::SeenIds, Bincode, JMAP};
use crate::{sieve::SeenIds, JMAP};
use super::ActiveScript;

View file

@ -30,7 +30,7 @@ use sieve::{Envelope, Event, Input, Mailbox, Recipient};
use smtp::core::{Session, SessionAddress};
use store::{
ahash::AHashSet,
write::{now, BatchBuilder, F_VALUE},
write::{now, BatchBuilder, Bincode, F_VALUE},
};
use utils::listener::stream::NullIo;
@ -38,7 +38,7 @@ use crate::{
email::ingest::{IngestEmail, IngestedEmail},
mailbox::{INBOX_ID, TRASH_ID},
sieve::SeenIdHash,
Bincode, IngestError, JMAP,
IngestError, JMAP,
};
use super::ActiveScript;

View file

@ -27,8 +27,7 @@ use jmap_proto::{
object::Object,
types::{collection::Collection, property::Property, value::Value},
};
use smtp::{core::management::QueueRequest, queue};
use tokio::sync::oneshot;
use smtp::queue;
use crate::JMAP;
@ -97,25 +96,10 @@ impl JMAP {
};
// Obtain queueId
let mut queued_message = None;
let (result_tx, result_rx) = oneshot::channel();
if self
let queued_message = self
.smtp
.queue
.tx
.send(queue::Event::Manage(QueueRequest::Status {
queue_ids: vec![push.get(&Property::MessageId).as_uint().unwrap_or(u64::MAX)],
result_tx,
}))
.await
.is_ok()
{
queued_message = result_rx
.await
.ok()
.and_then(|mut result| result.pop())
.flatten();
}
.read_message(push.get(&Property::MessageId).as_uint().unwrap_or(u64::MAX))
.await;
let mut result = Object::with_capacity(properties.len());
for property in &properties {
@ -124,11 +108,7 @@ impl JMAP {
Property::DeliveryStatus => {
match (queued_message.as_ref(), push.remove(property)) {
(Some(message), Value::Object(mut status)) => {
for rcpt in message
.domains
.iter()
.flat_map(|rcpts| rcpts.recipients.iter())
{
for rcpt in &message.recipients {
status.set(
Property::_T(rcpt.address.clone()),
Object::with_capacity(3)
@ -146,10 +126,12 @@ impl JMAP {
.with_property(
Property::SmtpReply,
match &rcpt.status {
queue::Status::Completed(reply)
| queue::Status::TemporaryFailure(reply)
queue::Status::Completed(reply) => {
reply.response.message()
}
queue::Status::TemporaryFailure(reply)
| queue::Status::PermanentFailure(reply) => {
reply.as_str()
reply.response.message()
}
queue::Status::Scheduled => "250 2.1.5 Queued",
}

View file

@ -48,19 +48,15 @@ use jmap_proto::{
},
};
use mail_parser::{HeaderName, HeaderValue};
use smtp::{
core::{management::QueueRequest, Session, SessionData, State},
queue,
};
use smtp::core::{Session, SessionData, State};
use smtp_proto::{request::parser::Rfc5321Parser, MailFrom, RcptTo};
use store::write::{assert::HashedValue, log::ChangeLogBuilder, now, BatchBuilder};
use tokio::sync::oneshot;
use store::write::{assert::HashedValue, log::ChangeLogBuilder, now, BatchBuilder, Bincode};
use utils::{
listener::{stream::NullIo, ServerInstance},
map::vec_map::VecMap,
};
use crate::{email::metadata::MessageMetadata, identity::set::sanitize_email, Bincode, JMAP};
use crate::{email::metadata::MessageMetadata, identity::set::sanitize_email, JMAP};
pub static SCHEMA: &[IndexProperty] = &[
IndexProperty::new(Property::UndoStatus).index_as(IndexAs::Text {
@ -176,24 +172,11 @@ impl JMAP {
match undo_status {
Some(undo_status) if undo_status == "canceled" => {
let (result_tx, result_rx) = oneshot::channel();
if self
.smtp
.queue
.tx
.send(queue::Event::Manage(QueueRequest::Cancel {
queue_ids: vec![queue_id],
item: None,
result_tx,
}))
.await
.is_ok()
&& result_rx
.await
.ok()
.and_then(|mut r| r.pop())
.unwrap_or(false)
{
if let Some(queue_message) = self.smtp.read_message(queue_id).await {
// Delete message from queue
let message_due = queue_message.next_event().unwrap_or_default();
queue_message.remove(&self.smtp, message_due).await;
// Update record
let mut batch = BatchBuilder::new();
batch

View file

@ -32,7 +32,7 @@ jemallocator = "0.5.0"
[features]
#default = ["sqlite", "foundationdb", "postgres", "mysql", "rocks", "elastic", "s3", "redis"]
default = ["sqlite", "postgres", "mysql"]
default = ["sqlite", "postgres", "mysql", "redis"]
sqlite = ["store/sqlite"]
foundationdb = ["store/foundation"]
postgres = ["store/postgres"]

View file

@ -20,7 +20,7 @@ mail-auth = { version = "0.3" }
mail-send = { version = "0.4", default-features = false, features = ["cram-md5"] }
mail-parser = { version = "0.9", features = ["full_encoding", "ludicrous_mode"] }
mail-builder = { version = "0.3", features = ["ludicrous_mode"] }
smtp-proto = { version = "0.1" }
smtp-proto = { version = "0.1", features = ["serde_support"] }
sieve-rs = { version = "0.4" }
ahash = { version = "0.8" }
rustls = "0.22"

View file

@ -72,6 +72,17 @@ impl ConfigShared for Config {
)
})?
.clone(),
default_blob_store: self
.value_or_default("storage.blob", "storage.data")
.and_then(|id| ctx.stores.blob_stores.get(id))
.ok_or_else(|| {
format!(
"Lookup store {:?} not found for key \"storage.blob\".",
self.value_or_default("storage.blob", "storage.data")
.unwrap()
)
})?
.clone(),
})
}

View file

@ -3,16 +3,13 @@ use std::{borrow::Cow, net::IpAddr, sync::Arc, vec::IntoIter};
use directory::Directory;
use mail_auth::IpLookupStrategy;
use sieve::Sieve;
use store::{LookupKey, LookupStore, LookupValue};
use store::{Deserialize, LookupStore};
use utils::{
config::if_block::IfBlock,
expr::{Expression, Variable},
};
use crate::{
config::{ArcSealer, DkimSigner, RelayHost},
scripts::plugins::lookup::VariableExists,
};
use crate::config::{ArcSealer, DkimSigner, RelayHost};
use super::{ResolveVariable, SMTP};
@ -165,15 +162,9 @@ impl SMTP {
let key = params.next_as_string();
self.get_lookup_store(store.as_ref())
.key_get::<String>(LookupKey::Key(key.into_owned().into_bytes()))
.key_get::<VariableWrapper>(key.into_owned().into_bytes())
.await
.map(|value| {
if let LookupValue::Value { value, .. } = value {
Variable::from(value)
} else {
Variable::default()
}
})
.map(|value| value.map(|v| v.into_inner()).unwrap_or_default())
.unwrap_or_else(|err| {
tracing::warn!(
context = "eval_if",
@ -191,9 +182,8 @@ impl SMTP {
let key = params.next_as_string();
self.get_lookup_store(store.as_ref())
.key_get::<VariableExists>(LookupKey::Key(key.into_owned().into_bytes()))
.key_exists(key.into_owned().into_bytes())
.await
.map(|value| matches!(value, LookupValue::Value { .. }))
.unwrap_or_else(|err| {
tracing::warn!(
context = "eval_if",
@ -395,3 +385,30 @@ impl<'x> FncParams<'x> {
self.params.next().unwrap().into_string()
}
}
#[derive(Debug)]
struct VariableWrapper(Variable<'static>);
impl From<i64> for VariableWrapper {
fn from(value: i64) -> Self {
VariableWrapper(Variable::Integer(value))
}
}
impl Deserialize for VariableWrapper {
fn deserialize(bytes: &[u8]) -> store::Result<Self> {
String::deserialize(bytes).map(|v| VariableWrapper(Variable::String(v.into())))
}
}
impl From<store::Value<'static>> for VariableWrapper {
fn from(value: store::Value<'static>) -> Self {
VariableWrapper(value.into())
}
}
impl VariableWrapper {
pub fn into_inner(self) -> Variable<'static> {
self.0
}
}

View file

@ -21,7 +21,7 @@
* for more details.
*/
use std::{borrow::Cow, fmt::Display, net::IpAddr, sync::Arc, time::Instant};
use std::{borrow::Cow, net::IpAddr, sync::Arc};
use directory::{AuthResult, Type};
use http_body_util::{combinators::BoxBody, BodyExt, Empty, Full};
@ -35,70 +35,24 @@ use hyper::{
use hyper_util::rt::TokioIo;
use mail_parser::{decoders::base64::base64_decode, DateTime};
use mail_send::Credentials;
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use tokio::sync::oneshot;
use serde::{Deserializer, Serializer};
use store::{
write::{key::DeserializeBigEndian, now, Bincode, QueueClass, ReportEvent, ValueClass},
Deserialize, IterateParams, ValueKey,
};
use utils::listener::{limiter::InFlight, SessionData, SessionManager, SessionStream};
use crate::{
queue::{self, instant_to_timestamp, InstantFromTimestamp, QueueId, Status},
reporting::{
self,
scheduler::{ReportKey, ReportPolicy, ReportType, ReportValue},
},
};
use crate::queue::{self, HostResponse, QueueId, Status};
use super::{SmtpAdminSessionManager, SMTP};
#[derive(Debug)]
pub enum QueueRequest {
List {
from: Option<String>,
to: Option<String>,
before: Option<Instant>,
after: Option<Instant>,
result_tx: oneshot::Sender<Vec<u64>>,
},
Status {
queue_ids: Vec<QueueId>,
result_tx: oneshot::Sender<Vec<Option<Message>>>,
},
Cancel {
queue_ids: Vec<QueueId>,
item: Option<String>,
result_tx: oneshot::Sender<Vec<bool>>,
},
Retry {
queue_ids: Vec<QueueId>,
item: Option<String>,
time: Instant,
result_tx: oneshot::Sender<Vec<bool>>,
},
}
#[derive(Debug)]
pub enum ReportRequest {
List {
type_: Option<ReportType<(), ()>>,
domain: Option<String>,
result_tx: oneshot::Sender<Vec<String>>,
},
Status {
report_ids: Vec<ReportKey>,
result_tx: oneshot::Sender<Vec<Option<Report>>>,
},
Cancel {
report_ids: Vec<ReportKey>,
result_tx: oneshot::Sender<Vec<bool>>,
},
}
#[derive(Debug, Serialize)]
#[derive(Debug, serde::Serialize)]
pub struct Response<T> {
data: T,
}
#[derive(Debug, Serialize, Deserialize, PartialEq, Eq)]
#[derive(Debug, serde::Serialize, serde::Deserialize, PartialEq, Eq)]
pub struct Message {
pub return_path: String,
pub domains: Vec<Domain>,
@ -113,7 +67,7 @@ pub struct Message {
pub env_id: Option<String>,
}
#[derive(Debug, Serialize, Deserialize, PartialEq, Eq)]
#[derive(Debug, serde::Serialize, serde::Deserialize, PartialEq, Eq)]
pub struct Domain {
pub name: String,
pub status: Status<String, String>,
@ -131,7 +85,7 @@ pub struct Domain {
pub expires: DateTime,
}
#[derive(Debug, Serialize, Deserialize, PartialEq, Eq)]
#[derive(Debug, serde::Serialize, serde::Deserialize, PartialEq, Eq)]
pub struct Recipient {
pub address: String,
pub status: Status<String, String>,
@ -139,7 +93,7 @@ pub struct Recipient {
pub orcpt: Option<String>,
}
#[derive(Debug, Serialize, Deserialize)]
#[derive(Debug, serde::Serialize, serde::Deserialize)]
pub struct Report {
pub domain: String,
#[serde(rename = "type")]
@ -362,18 +316,48 @@ impl SMTP {
match error {
None => {
let (result_tx, result_rx) = oneshot::channel();
self.send_queue_event(
QueueRequest::List {
from,
to,
before,
after,
result_tx,
let mut result = Vec::new();
let from_key = ValueKey::from(ValueClass::Queue(QueueClass::Message(0)));
let to_key =
ValueKey::from(ValueClass::Queue(QueueClass::Message(u64::MAX)));
let has_filters =
from.is_some() || to.is_some() || before.is_some() || after.is_some();
let _ =
self.shared
.default_data_store
.iterate(
IterateParams::new(from_key, to_key).ascending(),
|key, value| {
if has_filters {
let message =
Bincode::<queue::Message>::deserialize(value)?
.inner;
if from.as_ref().map_or(true, |from| {
message.return_path.contains(from)
}) && to.as_ref().map_or(true, |to| {
message
.recipients
.iter()
.any(|r| r.address_lcase.contains(to))
}) && before.as_ref().map_or(true, |before| {
message.next_delivery_event() < *before
}) && after.as_ref().map_or(true, |after| {
message.next_delivery_event() > *after
}) {
result.push(key.deserialize_be_u64(1)?);
}
} else {
result.push(key.deserialize_be_u64(1)?);
}
Ok(true)
},
result_rx,
)
.await
.await;
(
StatusCode::OK,
serde_json::to_string(&Response { data: result }).unwrap_or_default(),
)
}
Some(error) => error.into_bad_request(),
}
@ -404,22 +388,24 @@ impl SMTP {
match error {
None => {
let (result_tx, result_rx) = oneshot::channel();
self.send_queue_event(
QueueRequest::Status {
queue_ids,
result_tx,
},
result_rx,
let mut result = Vec::with_capacity(queue_ids.len());
for queue_id in queue_ids {
if let Some(message) = self.read_message(queue_id).await {
result.push(Message::from(&message));
}
}
(
StatusCode::OK,
serde_json::to_string(&Response { data: result }).unwrap_or_default(),
)
.await
}
Some(error) => error.into_bad_request(),
}
}
(&Method::GET, "queue", "retry") => {
let mut queue_ids = Vec::new();
let mut time = Instant::now();
let mut time = now();
let mut item = None;
let mut error = None;
@ -457,17 +443,49 @@ impl SMTP {
match error {
None => {
let (result_tx, result_rx) = oneshot::channel();
self.send_queue_event(
QueueRequest::Retry {
queue_ids,
item,
time,
result_tx,
},
result_rx,
let mut result = Vec::with_capacity(queue_ids.len());
for queue_id in queue_ids {
let mut found = false;
if let Some(mut message) = self.read_message(queue_id).await {
let prev_event = message.next_event().unwrap_or_default();
for domain in &mut message.domains {
if matches!(
domain.status,
Status::Scheduled | Status::TemporaryFailure(_)
) && item
.as_ref()
.map_or(true, |item| domain.domain.contains(item))
{
domain.retry.due = time;
if domain.expires > time {
domain.expires = time + 10;
}
found = true;
}
}
if found {
let next_event = message.next_event().unwrap_or_default();
message
.save_changes(self, prev_event.into(), next_event.into())
.await;
}
}
result.push(found);
}
if result.iter().any(|r| *r) {
let _ = self.queue.tx.send(queue::Event::Reload).await;
}
(
StatusCode::OK,
serde_json::to_string(&Response { data: result }).unwrap_or_default(),
)
.await
}
Some(error) => error.into_bad_request(),
}
@ -502,16 +520,93 @@ impl SMTP {
match error {
None => {
let (result_tx, result_rx) = oneshot::channel();
self.send_queue_event(
QueueRequest::Cancel {
queue_ids,
item,
result_tx,
let mut result = Vec::with_capacity(queue_ids.len());
for queue_id in queue_ids {
let mut found = false;
if let Some(mut message) = self.read_message(queue_id).await {
let prev_event = message.next_event().unwrap_or_default();
if let Some(item) = &item {
// Cancel delivery for all recipients that match
for rcpt in &mut message.recipients {
if rcpt.address_lcase.contains(item) {
rcpt.status = Status::Completed(HostResponse {
hostname: String::new(),
response: smtp_proto::Response {
code: 0,
esc: [0, 0, 0],
message: "Delivery canceled.".to_string(),
},
result_rx,
});
found = true;
}
}
if found {
// Mark as completed domains without any pending deliveries
for (domain_idx, domain) in
message.domains.iter_mut().enumerate()
{
if matches!(
domain.status,
Status::TemporaryFailure(_) | Status::Scheduled
) {
let mut total_rcpt = 0;
let mut total_completed = 0;
for rcpt in &message.recipients {
if rcpt.domain_idx == domain_idx {
total_rcpt += 1;
if matches!(
rcpt.status,
Status::PermanentFailure(_)
| Status::Completed(_)
) {
total_completed += 1;
}
}
}
if total_rcpt == total_completed {
domain.status = Status::Completed(());
}
}
}
// Delete message if there are no pending deliveries
if message.domains.iter().any(|domain| {
matches!(
domain.status,
Status::TemporaryFailure(_) | Status::Scheduled
)
}) {
let next_event =
message.next_event().unwrap_or_default();
message
.save_changes(
self,
next_event.into(),
prev_event.into(),
)
.await;
} else {
message.remove(self, prev_event).await;
}
}
} else {
message.remove(self, prev_event).await;
found = true;
}
}
result.push(found);
}
(
StatusCode::OK,
serde_json::to_string(&Response { data: result }).unwrap_or_default(),
)
.await
}
Some(error) => error.into_bad_request(),
}
@ -526,10 +621,10 @@ impl SMTP {
match key.as_ref() {
"type" => match value.as_ref() {
"dmarc" => {
type_ = ReportType::Dmarc(()).into();
type_ = 0u8.into();
}
"tls" => {
type_ = ReportType::Tls(()).into();
type_ = 1u8.into();
}
_ => {
error = format!("Invalid report type {value:?}.").into();
@ -549,16 +644,54 @@ impl SMTP {
match error {
None => {
let (result_tx, result_rx) = oneshot::channel();
self.send_report_event(
ReportRequest::List {
type_,
domain,
result_tx,
let mut result = Vec::new();
let from_key = ValueKey::from(ValueClass::Queue(
QueueClass::DmarcReportHeader(ReportEvent {
due: 0,
policy_hash: 0,
seq_id: 0,
domain: String::new(),
}),
));
let to_key = ValueKey::from(ValueClass::Queue(
QueueClass::TlsReportHeader(ReportEvent {
due: u64::MAX,
policy_hash: 0,
seq_id: 0,
domain: String::new(),
}),
));
let _ =
self.shared
.default_data_store
.iterate(
IterateParams::new(from_key, to_key).ascending().no_values(),
|key, _| {
if type_.map_or(true, |t| t == *key.last().unwrap()) {
let event = ReportEvent::deserialize(key)?;
if domain.as_ref().map_or(true, |d| {
d.eq_ignore_ascii_case(&event.domain)
}) {
result.push(
if *key.last().unwrap() == 0 {
QueueClass::DmarcReportHeader(event)
} else {
QueueClass::TlsReportHeader(event)
}
.queue_id(),
);
}
}
Ok(true)
},
result_rx,
)
.await
.await;
(
StatusCode::OK,
serde_json::to_string(&Response { data: result }).unwrap_or_default(),
)
}
Some(error) => error.into_bad_request(),
}
@ -588,17 +721,13 @@ impl SMTP {
}
match error {
None => {
let (result_tx, result_rx) = oneshot::channel();
self.send_report_event(
ReportRequest::Status {
report_ids,
result_tx,
},
result_rx,
)
.await
}
None => (
StatusCode::OK,
serde_json::to_string(&Response {
data: report_ids.into_iter().map(Report::from).collect::<Vec<_>>(),
})
.unwrap_or_default(),
),
Some(error) => error.into_bad_request(),
}
}
@ -628,15 +757,26 @@ impl SMTP {
match error {
None => {
let (result_tx, result_rx) = oneshot::channel();
self.send_report_event(
ReportRequest::Cancel {
report_ids,
result_tx,
},
result_rx,
let mut result = Vec::with_capacity(report_ids.len());
for report_id in report_ids {
match report_id {
QueueClass::DmarcReportHeader(event) => {
self.delete_dmarc_report(event).await;
}
QueueClass::TlsReportHeader(event) => {
self.delete_tls_report(vec![event]).await;
}
_ => (),
}
result.push(true);
}
(
StatusCode::OK,
serde_json::to_string(&Response { data: result }).unwrap_or_default(),
)
.await
}
Some(error) => error.into_bad_request(),
}
@ -660,85 +800,11 @@ impl SMTP {
)
.unwrap()
}
async fn send_queue_event<T: Serialize>(
&self,
request: QueueRequest,
rx: oneshot::Receiver<T>,
) -> (StatusCode, String) {
match self.queue.tx.send(queue::Event::Manage(request)).await {
Ok(_) => match rx.await {
Ok(result) => {
return (
StatusCode::OK,
serde_json::to_string(&Response { data: result }).unwrap_or_default(),
)
}
Err(_) => {
tracing::debug!(
context = "queue",
event = "recv-error",
reason = "Failed to receive manage request response."
);
}
},
Err(_) => {
tracing::debug!(
context = "queue",
event = "send-error",
reason = "Failed to send manage request event."
);
}
}
(
StatusCode::INTERNAL_SERVER_ERROR,
"{\"error\": \"internal-error\", \"details\": \"Resource unavailable, try again later.\"}"
.to_string(),
)
}
async fn send_report_event<T: Serialize>(
&self,
request: ReportRequest,
rx: oneshot::Receiver<T>,
) -> (StatusCode, String) {
match self.report.tx.send(reporting::Event::Manage(request)).await {
Ok(_) => match rx.await {
Ok(result) => {
return (
StatusCode::OK,
serde_json::to_string(&Response { data: result }).unwrap_or_default(),
)
}
Err(_) => {
tracing::debug!(
context = "queue",
event = "recv-error",
reason = "Failed to receive manage request response."
);
}
},
Err(_) => {
tracing::debug!(
context = "queue",
event = "send-error",
reason = "Failed to send manage request event."
);
}
}
(
StatusCode::INTERNAL_SERVER_ERROR,
"{\"error\": \"internal-error\", \"details\": \"Resource unavailable, try again later.\"}"
.to_string(),
)
}
}
impl From<&queue::Message> for Message {
fn from(message: &queue::Message) -> Self {
let now = Instant::now();
let now = now();
Message {
return_path: message.return_path.clone(),
@ -764,20 +830,12 @@ impl From<&queue::Message> for Message {
},
retry_num: domain.retry.inner,
next_retry: if domain.retry.due > now {
DateTime::from_timestamp(instant_to_timestamp(now, domain.retry.due) as i64)
.into()
DateTime::from_timestamp(domain.retry.due as i64).into()
} else {
None
},
next_notify: if domain.notify.due > now {
DateTime::from_timestamp(
instant_to_timestamp(
now,
domain.notify.due,
)
as i64,
)
.into()
DateTime::from_timestamp(domain.notify.due as i64).into()
} else {
None
},
@ -802,61 +860,64 @@ impl From<&queue::Message> for Message {
orcpt: rcpt.orcpt.clone(),
})
.collect(),
expires: DateTime::from_timestamp(
instant_to_timestamp(now, domain.expires) as i64
),
expires: DateTime::from_timestamp(domain.expires as i64),
})
.collect(),
}
}
}
impl From<(&ReportKey, &ReportValue)> for Report {
fn from((key, value): (&ReportKey, &ReportValue)) -> Self {
match (key, value) {
(ReportType::Dmarc(domain), ReportType::Dmarc(value)) => Report {
domain: domain.inner.clone(),
range_from: DateTime::from_timestamp(value.created as i64),
range_to: DateTime::from_timestamp(
(value.created + value.deliver_at.as_secs()) as i64,
),
size: value.size,
impl From<QueueClass> for Report {
fn from(value: QueueClass) -> Self {
match value {
QueueClass::DmarcReportHeader(event) => Report {
domain: event.domain,
type_: "dmarc".to_string(),
range_from: DateTime::from_timestamp(event.due as i64),
range_to: DateTime::from_timestamp(event.due as i64),
size: 0,
},
(ReportType::Tls(domain), ReportType::Tls(value)) => Report {
domain: domain.clone(),
range_from: DateTime::from_timestamp(value.created as i64),
range_to: DateTime::from_timestamp(
(value.created + value.deliver_at.as_secs()) as i64,
),
size: value.size,
QueueClass::TlsReportHeader(event) => Report {
domain: event.domain,
type_: "tls".to_string(),
range_from: DateTime::from_timestamp(event.due as i64),
range_to: DateTime::from_timestamp(event.due as i64),
size: 0,
},
_ => unreachable!(),
}
}
}
impl Display for ReportKey {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
trait GenerateQueueId {
fn queue_id(&self) -> String;
}
impl GenerateQueueId for QueueClass {
fn queue_id(&self) -> String {
match self {
ReportType::Dmarc(policy) => write!(f, "d!{}!{}", policy.inner, policy.policy),
ReportType::Tls(domain) => write!(f, "t!{domain}"),
QueueClass::DmarcReportHeader(h) => {
format!("d!{}!{}!{}!{}", h.domain, h.policy_hash, h.seq_id, h.due)
}
QueueClass::TlsReportHeader(h) => {
format!("t!{}!{}!{}!{}", h.domain, h.policy_hash, h.seq_id, h.due)
}
_ => unreachable!(),
}
}
}
trait ParseValues {
fn parse_timestamp(&self) -> Result<Instant, String>;
fn parse_timestamp(&self) -> Result<u64, String>;
fn parse_queue_ids(&self) -> Result<Vec<QueueId>, String>;
fn parse_report_ids(&self) -> Result<Vec<ReportKey>, String>;
fn parse_report_ids(&self) -> Result<Vec<QueueClass>, String>;
}
impl ParseValues for Cow<'_, str> {
fn parse_timestamp(&self) -> Result<Instant, String> {
fn parse_timestamp(&self) -> Result<u64, String> {
if let Some(dt) = DateTime::parse_rfc3339(self.as_ref()) {
let instant = (dt.to_timestamp() as u64).to_instant();
if instant >= Instant::now() {
let instant = dt.to_timestamp() as u64;
if instant >= now() {
return Ok(instant);
}
}
@ -881,31 +942,44 @@ impl ParseValues for Cow<'_, str> {
Ok(ids)
}
fn parse_report_ids(&self) -> Result<Vec<ReportKey>, String> {
fn parse_report_ids(&self) -> Result<Vec<QueueClass>, String> {
let mut ids = Vec::new();
for id in self.split(',') {
if !id.is_empty() {
let mut parts = id.split('!');
match (parts.next(), parts.next()) {
(Some("d"), Some(domain)) if !domain.is_empty() => {
if let Some(policy) = parts.next().and_then(|policy| policy.parse().ok()) {
ids.push(ReportType::Dmarc(ReportPolicy {
inner: domain.to_string(),
policy,
match (
parts.next(),
parts.next(),
parts.next().and_then(|p| p.parse::<u64>().ok()),
parts.next().and_then(|p| p.parse::<u64>().ok()),
parts.next().and_then(|p| p.parse::<u64>().ok()),
) {
(Some("d"), Some(domain), Some(policy), Some(seq_id), Some(due))
if !domain.is_empty() =>
{
ids.push(QueueClass::DmarcReportHeader(ReportEvent {
due,
policy_hash: policy,
seq_id,
domain: domain.to_string(),
}));
continue;
}
(Some("t"), Some(domain), Some(policy), Some(seq_id), Some(due))
if !domain.is_empty() =>
{
ids.push(QueueClass::TlsReportHeader(ReportEvent {
due,
policy_hash: policy,
seq_id,
domain: domain.to_string(),
}));
}
(Some("t"), Some(domain)) if !domain.is_empty() => {
ids.push(ReportType::Tls(domain.to_string()));
continue;
}
_ => (),
}
_ => {
return Err(format!("Failed to parse id {id:?}."));
}
}
}
}
Ok(ids)
}
}
@ -944,7 +1018,7 @@ fn deserialize_maybe_datetime<'de, D>(deserializer: D) -> Result<Option<DateTime
where
D: Deserializer<'de>,
{
if let Some(value) = Option::<&str>::deserialize(deserializer)? {
if let Some(value) = <Option<&str> as serde::Deserialize>::deserialize(deserializer)? {
if let Some(value) = DateTime::parse_rfc3339(value) {
Ok(Some(value))
} else {
@ -968,6 +1042,8 @@ fn deserialize_datetime<'de, D>(deserializer: D) -> Result<DateTime, D::Error>
where
D: Deserializer<'de>,
{
use serde::Deserialize;
if let Some(value) = DateTime::parse_rfc3339(<&str>::deserialize(deserializer)?) {
Ok(value)
} else {

View file

@ -24,7 +24,7 @@
use std::{
hash::Hash,
net::IpAddr,
sync::{atomic::AtomicU32, Arc},
sync::Arc,
time::{Duration, Instant},
};
@ -40,7 +40,7 @@ use smtp_proto::{
},
IntoString,
};
use store::{LookupStore, Store, Value};
use store::{BlobStore, LookupStore, Store, Value};
use tokio::{
io::{AsyncRead, AsyncWrite},
sync::mpsc,
@ -50,7 +50,12 @@ use tracing::Span;
use utils::{
expr,
ipc::DeliveryEvent,
listener::{limiter::InFlight, stream::NullIo, ServerInstance, TcpAcceptor},
listener::{
limiter::{ConcurrencyLimiter, InFlight},
stream::NullIo,
ServerInstance, TcpAcceptor,
},
snowflake::SnowflakeIdGenerator,
};
use crate::{
@ -63,11 +68,11 @@ use crate::{
dane::{DnssecResolver, Tlsa},
mta_sts,
},
queue::{self, DomainPart, QueueId, QuotaLimiter},
queue::{self, DomainPart, QueueId},
reporting,
};
use self::throttle::{Limiter, ThrottleKey, ThrottleKeyHasherBuilder};
use self::throttle::{ThrottleKey, ThrottleKeyHasherBuilder};
pub mod eval;
pub mod management;
@ -121,6 +126,7 @@ pub struct Shared {
// Default store and directory
pub default_directory: Arc<Directory>,
pub default_data_store: Store,
pub default_blob_store: BlobStore,
pub default_lookup_store: LookupStore,
}
@ -145,15 +151,14 @@ pub struct DnsCache {
pub struct SessionCore {
pub config: SessionConfig,
pub throttle: DashMap<ThrottleKey, Limiter, ThrottleKeyHasherBuilder>,
pub throttle: DashMap<ThrottleKey, ConcurrencyLimiter, ThrottleKeyHasherBuilder>,
}
pub struct QueueCore {
pub config: QueueConfig,
pub throttle: DashMap<ThrottleKey, Limiter, ThrottleKeyHasherBuilder>,
pub quota: DashMap<ThrottleKey, Arc<QuotaLimiter>, ThrottleKeyHasherBuilder>,
pub throttle: DashMap<ThrottleKey, ConcurrencyLimiter, ThrottleKeyHasherBuilder>,
pub tx: mpsc::Sender<queue::Event>,
pub id_seq: AtomicU32,
pub snowflake_id: SnowflakeIdGenerator,
pub connectors: TlsConnectors,
}

View file

@ -21,7 +21,7 @@
* for more details.
*/
use ::utils::listener::limiter::{ConcurrencyLimiter, RateLimiter};
use ::utils::listener::limiter::ConcurrencyLimiter;
use dashmap::mapref::entry::Entry;
use tokio::io::{AsyncRead, AsyncWrite};
use utils::config::Rate;
@ -32,12 +32,6 @@ use crate::config::*;
use super::{eval::*, ResolveVariable, Session};
#[derive(Debug)]
pub struct Limiter {
pub rate: Option<RateLimiter>,
pub concurrency: Option<ConcurrencyLimiter>,
}
#[derive(Debug, Clone, Eq)]
pub struct ThrottleKey {
hash: [u8; 32],
@ -55,6 +49,12 @@ impl Hash for ThrottleKey {
}
}
impl AsRef<[u8]> for ThrottleKey {
fn as_ref(&self) -> &[u8] {
&self.hash
}
}
#[derive(Default)]
pub struct ThrottleKeyHasher {
hash: u64,
@ -236,10 +236,36 @@ impl<T: AsyncRead + AsyncWrite> Session<T> {
}
// Build throttle key
match self.core.session.throttle.entry(t.new_key(self)) {
let key = t.new_key(self);
// Check rate
if let Some(rate) = &t.rate {
if self
.core
.shared
.default_lookup_store
.is_rate_allowed(key.hash.as_slice(), rate, false)
.await
.unwrap_or_default()
.is_some()
{
tracing::debug!(
parent: &self.span,
context = "throttle",
event = "rate-limit-exceeded",
max_requests = rate.requests,
max_interval = rate.period.as_secs(),
"Rate limit exceeded."
);
return false;
}
}
// Check concurrency
if let Some(concurrency) = &t.concurrency {
match self.core.session.throttle.entry(key) {
Entry::Occupied(mut e) => {
let limiter = e.get_mut();
if let Some(limiter) = &limiter.concurrency {
if let Some(inflight) = limiter.is_allowed() {
self.in_flight.push(inflight);
} else {
@ -253,35 +279,13 @@ impl<T: AsyncRead + AsyncWrite> Session<T> {
return false;
}
}
if let (Some(limiter), Some(rate)) = (&mut limiter.rate, &t.rate) {
if !limiter.is_allowed(rate) {
tracing::debug!(
parent: &self.span,
context = "throttle",
event = "rate-limit-exceeded",
max_requests = rate.requests,
max_interval = rate.period.as_secs(),
"Rate limit exceeded."
);
return false;
}
}
}
Entry::Vacant(e) => {
let concurrency = t.concurrency.map(|concurrency| {
let limiter = ConcurrencyLimiter::new(concurrency);
let limiter = ConcurrencyLimiter::new(*concurrency);
if let Some(inflight) = limiter.is_allowed() {
self.in_flight.push(inflight);
}
limiter
});
let rate = t.rate.as_ref().map(|rate| {
let r = RateLimiter::new(rate);
r.is_allowed(rate);
r
});
e.insert(Limiter { rate, concurrency });
e.insert(limiter);
}
}
}
}
@ -290,33 +294,19 @@ impl<T: AsyncRead + AsyncWrite> Session<T> {
true
}
pub fn throttle_rcpt(&self, rcpt: &str, rate: &Rate, ctx: &str) -> bool {
pub async fn throttle_rcpt(&self, rcpt: &str, rate: &Rate, ctx: &str) -> bool {
let mut hasher = blake3::Hasher::new();
hasher.update(rcpt.as_bytes());
hasher.update(ctx.as_bytes());
hasher.update(&rate.period.as_secs().to_ne_bytes()[..]);
hasher.update(&rate.requests.to_ne_bytes()[..]);
let key = ThrottleKey {
hash: hasher.finalize().into(),
};
match self.core.session.throttle.entry(key) {
Entry::Occupied(mut e) => {
if let Some(limiter) = &mut e.get_mut().rate {
limiter.is_allowed(rate)
} else {
false
}
}
Entry::Vacant(e) => {
let limiter = RateLimiter::new(rate);
limiter.is_allowed(rate);
e.insert(Limiter {
rate: limiter.into(),
concurrency: None,
});
true
}
}
self.core
.shared
.default_lookup_store
.is_rate_allowed(hasher.finalize().as_bytes(), rate, false)
.await
.unwrap_or_default()
.is_none()
}
}

View file

@ -54,16 +54,8 @@ impl SMTP {
fn cleanup(&self) {
for throttle in [&self.session.throttle, &self.queue.throttle] {
throttle.retain(|_, v| {
v.concurrency
.as_ref()
.map_or(false, |c| c.concurrent.load(Ordering::Relaxed) > 0)
|| v.rate.as_ref().map_or(false, |r| r.is_active())
});
throttle.retain(|_, v| v.concurrent.load(Ordering::Relaxed) > 0);
}
self.queue.quota.retain(|_, v| {
v.messages.load(Ordering::Relaxed) > 0 || v.size.load(Ordering::Relaxed) > 0
});
}
}

View file

@ -23,10 +23,9 @@
use std::{
borrow::Cow,
path::PathBuf,
process::Stdio,
sync::Arc,
time::{Duration, Instant, SystemTime},
time::{Duration, SystemTime},
};
use mail_auth::{
@ -38,6 +37,7 @@ use sieve::runtime::Variable;
use smtp_proto::{
MAIL_BY_RETURN, RCPT_NOTIFY_DELAY, RCPT_NOTIFY_FAILURE, RCPT_NOTIFY_NEVER, RCPT_NOTIFY_SUCCESS,
};
use store::write::now;
use tokio::{io::AsyncWriteExt, process::Command};
use utils::{config::Rate, listener::SessionStream};
@ -654,10 +654,8 @@ impl<T: SessionStream> Session<T> {
// Verify queue quota
if self.core.has_quota(&mut message).await {
let queue_id = message.id;
if self
.core
.queue
.queue_message(message, Some(&headers), &raw_message, &self.span)
if message
.queue(Some(&headers), &raw_message, &self.core, &self.span)
.await
{
self.state = State::Accepted(queue_id);
@ -682,14 +680,14 @@ impl<T: SessionStream> Session<T> {
&self,
mail_from: SessionAddress,
mut rcpt_to: Vec<SessionAddress>,
) -> Box<Message> {
) -> Message {
// Build message
let mut message = Box::new(Message {
id: self.core.queue.queue_id(),
path: PathBuf::new(),
created: SystemTime::now()
let created = SystemTime::now()
.duration_since(SystemTime::UNIX_EPOCH)
.map_or(0, |d| d.as_secs()),
.map_or(0, |d| d.as_secs());
let mut message = Message {
id: self.core.queue.snowflake_id.generate().unwrap_or(created),
created,
return_path: mail_from.address,
return_path_lcase: mail_from.address_lcase,
return_path_domain: mail_from.domain,
@ -699,8 +697,9 @@ impl<T: SessionStream> Session<T> {
priority: self.data.priority,
size: 0,
env_id: mail_from.dsn_info,
queue_refs: Vec::with_capacity(0),
});
blob_hash: Default::default(),
quota_keys: Vec::new(),
};
// Add recipients
let future_release = Duration::from_secs(self.data.future_release);
@ -711,7 +710,7 @@ impl<T: SessionStream> Session<T> {
.last()
.map_or(true, |d| d.domain != rcpt.domain)
{
let envelope = SimpleEnvelope::new(message.as_ref(), &rcpt.domain);
let envelope = SimpleEnvelope::new(&message, &rcpt.domain);
// Set next retry time
let retry = if self.data.future_release == 0 {
@ -731,18 +730,19 @@ impl<T: SessionStream> Session<T> {
let (notify, expires) = if self.data.delivery_by == 0 {
(
queue::Schedule::later(future_release + next_notify),
Instant::now()
+ future_release
now()
+ future_release.as_secs()
+ self
.core
.eval_if(&config.expire, &envelope)
.await
.unwrap_or_else(|| Duration::from_secs(5 * 86400)),
.unwrap_or_else(|| Duration::from_secs(5 * 86400))
.as_secs(),
)
} else if (message.flags & MAIL_BY_RETURN) != 0 {
(
queue::Schedule::later(future_release + next_notify),
Instant::now() + Duration::from_secs(self.data.delivery_by as u64),
now() + self.data.delivery_by as u64,
)
} else {
let expire = self
@ -769,7 +769,7 @@ impl<T: SessionStream> Session<T> {
let mut notify = queue::Schedule::later(future_release + notify);
notify.inner = (num_intervals - 1) as u32; // Disable further notification attempts
(notify, Instant::now() + expire)
(notify, now() + expire_secs)
};
message.domains.push(queue::Domain {
@ -779,7 +779,6 @@ impl<T: SessionStream> Session<T> {
status: queue::Status::Scheduled,
domain: rcpt.domain,
disable_tls: false,
changed: false,
});
}

View file

@ -42,6 +42,7 @@ use store::Stores;
use tokio::sync::mpsc;
use utils::{
config::{Config, ServerProtocol, Servers},
snowflake::SnowflakeIdGenerator,
UnwrapFailure,
};
@ -129,15 +130,10 @@ impl SMTP {
.unwrap_or(32)
.next_power_of_two() as usize,
),
id_seq: 0.into(),
quota: DashMap::with_capacity_and_hasher_and_shard_amount(
config.property("global.shared-map.capacity")?.unwrap_or(2),
ThrottleKeyHasherBuilder::default(),
config
.property::<u64>("global.shared-map.shard")?
.unwrap_or(32)
.next_power_of_two() as usize,
),
snowflake_id: config
.property::<u64>("storage.cluster.node-id")?
.map(SnowflakeIdGenerator::with_node_id)
.unwrap_or_else(SnowflakeIdGenerator::new),
tx: queue_tx,
connectors: TlsConnectors {
pki_verify: build_tls_connector(false),
@ -156,10 +152,10 @@ impl SMTP {
});
// Spawn queue manager
queue_rx.spawn(core.clone(), core.queue.read_queue().await);
queue_rx.spawn(core.clone());
// Spawn report manager
report_rx.spawn(core.clone(), core.report.read_reports().await);
report_rx.spawn(core.clone());
Ok(core)
}

View file

@ -24,7 +24,7 @@
use std::{
net::{IpAddr, Ipv4Addr, SocketAddr},
sync::Arc,
time::{Duration, Instant},
time::Duration,
};
use mail_auth::{
@ -33,6 +33,7 @@ use mail_auth::{
};
use mail_send::SmtpClient;
use smtp_proto::MAIL_REQUIRETLS;
use store::write::now;
use utils::config::ServerProtocol;
use crate::{
@ -49,12 +50,12 @@ use super::{
NextHop,
};
use crate::queue::{
manager::Queue, throttle, DeliveryAttempt, Domain, Error, Event, OnHold, QueueEnvelope,
Schedule, Status, WorkerResult,
throttle, DeliveryAttempt, Domain, Error, Event, OnHold, QueueEnvelope, Status,
};
impl DeliveryAttempt {
pub async fn try_deliver(mut self, core: Arc<SMTP>, queue: &mut Queue) {
pub async fn try_deliver(mut self, core: Arc<SMTP>) {
tokio::spawn(async move {
// Check that the message still has recipients to be delivered
let has_pending_delivery = self.has_pending_delivery();
@ -64,56 +65,65 @@ impl DeliveryAttempt {
if has_pending_delivery {
// Re-queue the message if its not yet due for delivery
let due = self.message.next_delivery_event();
if due > Instant::now() {
// Save changes to disk
self.message.save_changes().await;
queue.schedule(Schedule {
due,
inner: self.message,
});
if due > now() {
// Save changes
self.message
.save_changes(&core, self.event.due.into(), due.into())
.await;
if core.queue.tx.send(Event::Reload).await.is_err() {
tracing::warn!("Channel closed while trying to notify queue manager.");
}
return;
}
} else {
// All message recipients expired, do not re-queue. (DSN has been already sent)
self.message.remove().await;
self.message.remove(&core, self.event.due).await;
if core.queue.tx.send(Event::Reload).await.is_err() {
tracing::warn!("Channel closed while trying to notify queue manager.");
}
return;
}
// Throttle sender
for throttle in &core.queue.config.throttle.sender {
if let Err(err) = core
.is_allowed(
throttle,
self.message.as_ref(),
&mut self.in_flight,
&self.span,
)
.is_allowed(throttle, &self.message, &mut self.in_flight, &self.span)
.await
{
// Save changes to disk
self.message.save_changes().await;
match err {
let event = match err {
throttle::Error::Concurrency { limiter } => {
queue.on_hold(OnHold {
next_due: self.message.next_event_after(Instant::now()),
// Save changes to disk
let next_due = self.message.next_event_after(now());
self.message.save_changes(&core, None, None).await;
Event::OnHold(OnHold {
next_due,
limiters: vec![limiter],
message: self.message,
});
message: self.event,
})
}
throttle::Error::Rate { retry_at } => {
queue.schedule(Schedule {
due: retry_at,
inner: self.message,
});
// Save changes to disk
let next_event = std::cmp::min(
retry_at,
self.message.next_event_after(now()).unwrap_or(u64::MAX),
);
self.message
.save_changes(&core, self.event.due.into(), next_event.into())
.await;
Event::Reload
}
};
if core.queue.tx.send(event).await.is_err() {
tracing::warn!("Channel closed while trying to notify queue manager.");
}
return;
}
}
tokio::spawn(async move {
let queue_config = &core.queue.config;
let mut on_hold = Vec::new();
let no_ip = IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0));
@ -123,7 +133,7 @@ impl DeliveryAttempt {
'next_domain: for (domain_idx, domain) in domains.iter_mut().enumerate() {
// Only process domains due for delivery
if !matches!(&domain.status, Status::Scheduled | Status::TemporaryFailure(_)
if domain.retry.due <= Instant::now())
if domain.retry.due <= now())
{
continue;
}
@ -138,7 +148,7 @@ impl DeliveryAttempt {
// Build envelope
let mut envelope = QueueEnvelope {
message: self.message.as_ref(),
message: &self.message,
domain: &domain.domain,
mx: "",
remote_ip: no_ip,
@ -672,6 +682,7 @@ impl DeliveryAttempt {
.unwrap_or_else(|| "localhost".to_string());
let params = SessionParams {
span: &span,
core: &core,
credentials: remote_host.credentials(),
is_smtp: remote_host.is_smtp(),
hostname: envelope.mx,
@ -1018,11 +1029,9 @@ impl DeliveryAttempt {
// Notify queue manager
let span = self.span;
let result = if !on_hold.is_empty() {
// Release quota for completed deliveries
self.message.release_quota();
// Save changes to disk
self.message.save_changes().await;
let next_due = self.message.next_event_after(now());
self.message.save_changes(&core, None, None).await;
tracing::info!(
parent: &span,
@ -1032,17 +1041,16 @@ impl DeliveryAttempt {
"Too many outbound concurrent connections, message moved to on-hold queue."
);
WorkerResult::OnHold(OnHold {
next_due: self.message.next_event_after(Instant::now()),
Event::OnHold(OnHold {
next_due,
limiters: on_hold,
message: self.message,
message: self.event,
})
} else if let Some(due) = self.message.next_event() {
// Release quota for completed deliveries
self.message.release_quota();
// Save changes to disk
self.message.save_changes().await;
self.message
.save_changes(&core, self.event.due.into(), due.into())
.await;
tracing::info!(
parent: &span,
@ -1052,13 +1060,10 @@ impl DeliveryAttempt {
"Delivery was not possible, message re-queued for delivery."
);
WorkerResult::Retry(Schedule {
due,
inner: self.message,
})
Event::Reload
} else {
// Delete message from queue
self.message.remove().await;
self.message.remove(&core, self.event.due).await;
tracing::info!(
parent: &span,
@ -1067,9 +1072,9 @@ impl DeliveryAttempt {
"Delivery completed."
);
WorkerResult::Done
Event::Reload
};
if core.queue.tx.send(Event::Done(result)).await.is_err() {
if core.queue.tx.send(result).await.is_err() {
tracing::warn!(
parent: &span,
"Channel closed while trying to notify queue manager."
@ -1080,7 +1085,7 @@ impl DeliveryAttempt {
/// Marks as failed all domains that reached their expiration time
pub fn has_pending_delivery(&mut self) -> bool {
let now = Instant::now();
let now = now();
let mut has_pending_delivery = false;
let span = self.span.clone();
@ -1103,7 +1108,6 @@ impl DeliveryAttempt {
domain.status =
std::mem::replace(&mut domain.status, Status::Scheduled).into_permanent();
domain.changed = true;
}
Status::Scheduled if domain.expires <= now => {
tracing::info!(
@ -1123,7 +1127,6 @@ impl DeliveryAttempt {
domain.status = Status::PermanentFailure(Error::Io(
"Queue rate limit exceeded.".to_string(),
));
domain.changed = true;
}
Status::Completed(_) | Status::PermanentFailure(_) => (),
_ => {
@ -1139,7 +1142,6 @@ impl DeliveryAttempt {
impl Domain {
pub fn set_status(&mut self, status: impl Into<Status<(), Error>>, schedule: &[Duration]) {
self.status = status.into();
self.changed = true;
if matches!(
&self.status,
Status::TemporaryFailure(_) | Status::Scheduled
@ -1149,8 +1151,8 @@ impl Domain {
}
pub fn retry(&mut self, schedule: &[Duration]) {
self.retry.due =
Instant::now() + schedule[std::cmp::min(self.retry.inner as usize, schedule.len() - 1)];
self.retry.due = now()
+ schedule[std::cmp::min(self.retry.inner as usize, schedule.len() - 1)].as_secs();
self.retry.inner += 1;
}
}

View file

@ -63,7 +63,7 @@ impl Message {
message: IngestMessage {
sender_address: self.return_path_lcase.clone(),
recipients: recipient_addresses,
message_path: self.path.clone(),
message_blob: self.blob_hash.clone(),
message_size: self.size,
},
result_tx,

View file

@ -25,6 +25,7 @@ use std::borrow::Cow;
use mail_send::Credentials;
use smtp_proto::{Response, Severity};
use store::write::QueueEvent;
use utils::config::ServerProtocol;
use crate::{
@ -211,8 +212,8 @@ impl From<mta_sts::Error> for Status<(), Error> {
}
}
impl From<Box<Message>> for DeliveryAttempt {
fn from(message: Box<Message>) -> Self {
impl DeliveryAttempt {
pub fn new(message: Message, event: QueueEvent) -> Self {
DeliveryAttempt {
span: tracing::info_span!(
"delivery",
@ -227,6 +228,7 @@ impl From<Box<Message>> for DeliveryAttempt {
),
in_flight: Vec::new(),
message,
event,
}
}
}

View file

@ -30,7 +30,6 @@ use smtp_proto::{
use std::fmt::Write;
use std::time::Duration;
use tokio::{
fs,
io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt},
net::TcpStream,
};
@ -38,6 +37,7 @@ use tokio_rustls::{client::TlsStream, TlsConnector};
use crate::{
config::{RequireOptional, TlsStrategy},
core::SMTP,
queue::{ErrorDetails, HostResponse, RCPT_STATUS_CHANGED},
};
@ -45,6 +45,7 @@ use crate::queue::{Error, Message, Recipient, Status};
pub struct SessionParams<'x> {
pub span: &'x tracing::Span,
pub core: &'x SMTP,
pub hostname: &'x str,
pub credentials: Option<&'x Credentials<String>>,
pub is_smtp: bool,
@ -532,27 +533,14 @@ pub async fn send_message<T: AsyncRead + AsyncWrite + Unpin>(
bdat_cmd: &Option<String>,
params: &SessionParams<'_>,
) -> Result<(), Status<(), Error>> {
let mut raw_message = vec![0u8; message.size];
let mut file = fs::File::open(&message.path).await.map_err(|err| {
tracing::error!(parent: params.span,
context = "queue",
event = "error",
"Failed to open message file {}: {}",
message.path.display(),
err);
Status::TemporaryFailure(Error::Io("Queue system error.".to_string()))
})?;
file.read_exact(&mut raw_message).await.map_err(|err| {
tracing::error!(parent: params.span,
context = "queue",
event = "error",
"Failed to read {} bytes file {} from disk: {}",
message.size,
message.path.display(),
err);
Status::TemporaryFailure(Error::Io("Queue system error.".to_string()))
})?;
tokio::time::timeout(params.timeout_data, async {
match params
.core
.shared
.default_blob_store
.get_blob(message.blob_hash.as_slice(), 0..u32::MAX)
.await
{
Ok(Some(raw_message)) => tokio::time::timeout(params.timeout_data, async {
if let Some(bdat_cmd) = bdat_cmd {
write_chunks(smtp_client, &[bdat_cmd.as_bytes(), &raw_message]).await
} else {
@ -568,7 +556,30 @@ pub async fn send_message<T: AsyncRead + AsyncWrite + Unpin>(
.map_err(|_| Status::timeout(params.hostname, "sending message"))?
.map_err(|err| {
Status::from_smtp_error(params.hostname, bdat_cmd.as_deref().unwrap_or("DATA"), err)
})
}),
Ok(None) => {
tracing::error!(parent: params.span,
context = "queue",
event = "error",
"BlobHash {:?} does not exist.",
message.blob_hash,
);
Err(Status::TemporaryFailure(Error::Io(
"Queue system error.".to_string(),
)))
}
Err(err) => {
tracing::error!(parent: params.span,
context = "queue",
event = "error",
"Failed to fetch blobId {:?}: {}",
message.blob_hash,
err);
Err(Status::TemporaryFailure(Error::Io(
"Queue system error.".to_string(),
)))
}
}
}
pub async fn say_helo<T: AsyncRead + AsyncWrite + Unpin>(

View file

@ -30,22 +30,21 @@ use smtp_proto::{
Response, RCPT_NOTIFY_DELAY, RCPT_NOTIFY_FAILURE, RCPT_NOTIFY_NEVER, RCPT_NOTIFY_SUCCESS,
};
use std::fmt::Write;
use std::time::{Duration, Instant};
use tokio::fs::File;
use tokio::io::AsyncReadExt;
use std::time::Duration;
use store::write::now;
use crate::core::SMTP;
use super::{
instant_to_timestamp, DeliveryAttempt, Domain, Error, ErrorDetails, HostResponse, Message,
Recipient, SimpleEnvelope, Status, RCPT_DSN_SENT, RCPT_STATUS_CHANGED,
DeliveryAttempt, Domain, Error, ErrorDetails, HostResponse, Message, Recipient, SimpleEnvelope,
Status, RCPT_DSN_SENT, RCPT_STATUS_CHANGED,
};
impl SMTP {
pub async fn send_dsn(&self, attempt: &mut DeliveryAttempt) {
if !attempt.message.return_path.is_empty() {
if let Some(dsn) = attempt.build_dsn(self).await {
let mut dsn_message = Message::new_boxed("", "", "");
let mut dsn_message = self.queue.new_message("", "", "");
dsn_message
.add_recipient_parts(
&attempt.message.return_path,
@ -64,8 +63,8 @@ impl SMTP {
&attempt.span,
)
.await;
self.queue
.queue_message(dsn_message, signature.as_deref(), &dsn, &attempt.span)
dsn_message
.queue(signature.as_deref(), &dsn, self, &attempt.span)
.await;
}
} else {
@ -77,7 +76,7 @@ impl SMTP {
impl DeliveryAttempt {
pub async fn build_dsn(&mut self, core: &SMTP) -> Option<Vec<u8>> {
let config = &core.queue.config;
let now = Instant::now();
let now = now();
let mut txt_success = String::new();
let mut txt_delay = String::new();
@ -245,11 +244,10 @@ impl DeliveryAttempt {
})
{
domain.notify.inner += 1;
domain.notify.due = Instant::now() + next_notify;
domain.notify.due = now + next_notify.as_secs();
} else {
domain.notify.due = domain.expires + Duration::from_secs(10);
domain.notify.due = domain.expires + 10;
}
domain.changed = true;
}
}
self.message.domains = domains;
@ -257,15 +255,15 @@ impl DeliveryAttempt {
// Obtain hostname and sender addresses
let from_name = core
.eval_if(&config.dsn.name, self.message.as_ref())
.eval_if(&config.dsn.name, &self.message)
.await
.unwrap_or_else(|| String::from("Mail Delivery Subsystem"));
let from_addr = core
.eval_if(&config.dsn.address, self.message.as_ref())
.eval_if(&config.dsn.address, &self.message)
.await
.unwrap_or_else(|| String::from("MAILER-DAEMON@localhost"));
let reporting_mta = core
.eval_if(&config.hostname, self.message.as_ref())
.eval_if(&config.hostname, &self.message)
.await
.unwrap_or_else(|| String::from("localhost"));
@ -276,13 +274,15 @@ impl DeliveryAttempt {
let dsn = dsn_header + &dsn;
// Fetch up to 1024 bytes of message headers
let headers = match File::open(&self.message.path).await {
Ok(mut file) => {
let mut buf = vec![0u8; std::cmp::min(self.message.size, 1024)];
match file.read(&mut buf).await {
Ok(br) => {
let headers = match core
.shared
.default_blob_store
.get_blob(self.message.blob_hash.as_slice(), 0..1024)
.await
{
Ok(Some(mut buf)) => {
let mut prev_ch = 0;
let mut last_lf = br;
let mut last_lf = buf.len();
for (pos, &ch) in buf.iter().enumerate() {
match ch {
b'\n' => {
@ -305,26 +305,23 @@ impl DeliveryAttempt {
}
String::from_utf8(buf).unwrap_or_default()
}
Err(err) => {
Ok(None) => {
tracing::error!(
parent: &self.span,
context = "queue",
event = "error",
"Failed to read from {}: {}",
self.message.path.display(),
err
"Failed to open blob {:?}: not found",
self.message.blob_hash
);
String::new()
}
}
}
Err(err) => {
tracing::error!(
parent: &self.span,
context = "queue",
event = "error",
"Failed to open file {}: {}",
self.message.path.display(),
"Failed to open blob {:?}: {}",
self.message.blob_hash,
err
);
String::new()
@ -387,10 +384,10 @@ impl DeliveryAttempt {
}
}
let now = Instant::now();
let now = now();
for domain in &mut message.domains {
if domain.notify.due <= now {
domain.notify.due = domain.expires + Duration::from_secs(10);
domain.notify.due = domain.expires + 10;
}
}
@ -520,13 +517,10 @@ impl Recipient {
impl Domain {
fn write_dsn_will_retry_until(&self, dsn: &mut String) {
let now = Instant::now();
let now = now();
if self.expires > now {
dsn.push_str("Will-Retry-Until: ");
dsn.push_str(
&DateTime::from_timestamp(instant_to_timestamp(now, self.expires) as i64)
.to_rfc822(),
);
dsn.push_str(&DateTime::from_timestamp(self.expires as i64).to_rfc822());
dsn.push_str("\r\n");
}
}

View file

@ -22,311 +22,101 @@
*/
use std::{
collections::BinaryHeap,
sync::{atomic::Ordering, Arc},
time::{Duration, Instant},
time::Duration,
};
use ahash::AHashMap;
use smtp_proto::Response;
use store::write::{now, BatchBuilder, QueueClass, QueueEvent, ValueClass};
use tokio::sync::mpsc;
use crate::core::{
management::{self},
QueueCore, SMTP,
};
use crate::core::SMTP;
use super::{
DeliveryAttempt, Event, HostResponse, Message, OnHold, QueueId, Schedule, Status, WorkerResult,
RCPT_STATUS_CHANGED,
};
use super::{DeliveryAttempt, Event, Message, OnHold, Status};
pub(crate) const SHORT_WAIT: Duration = Duration::from_millis(1);
pub(crate) const LONG_WAIT: Duration = Duration::from_secs(86400 * 365);
#[derive(Debug)]
pub struct Queue {
short_wait: Duration,
long_wait: Duration,
pub scheduled: BinaryHeap<Schedule<QueueId>>,
pub on_hold: Vec<OnHold<QueueId>>,
pub messages: AHashMap<QueueId, Box<Message>>,
pub on_hold: Vec<OnHold<QueueEvent>>,
}
impl SpawnQueue for mpsc::Receiver<Event> {
fn spawn(mut self, core: Arc<SMTP>, mut queue: Queue) {
fn spawn(mut self, core: Arc<SMTP>) {
tokio::spawn(async move {
let mut queue = Queue::default();
let mut next_wake_up = SHORT_WAIT;
loop {
let result = tokio::time::timeout(queue.wake_up_time(), self.recv()).await;
let on_hold = match tokio::time::timeout(next_wake_up, self.recv()).await {
Ok(Some(Event::OnHold(on_hold))) => on_hold.into(),
Ok(Some(Event::Stop)) | Ok(None) => {
break;
}
_ => None,
};
// Deliver any concurrency limited messages
let mut delete_events = Vec::new();
while let Some(queue_event) = queue.next_on_hold() {
if let Some(message) = core.read_message(queue_event.queue_id).await {
DeliveryAttempt::new(message, queue_event)
.try_deliver(core.clone())
.await;
} else {
delete_events.push(queue_event);
}
}
// Deliver scheduled messages
while let Some(message) = queue.next_due() {
DeliveryAttempt::from(message)
.try_deliver(core.clone(), &mut queue)
.await;
}
match result {
Ok(Some(event)) => match event {
Event::Queue(item) => {
// Deliver any concurrency limited messages
while let Some(message) = queue.next_on_hold() {
DeliveryAttempt::from(message)
.try_deliver(core.clone(), &mut queue)
.await;
}
if item.due <= Instant::now() {
DeliveryAttempt::from(item.inner)
.try_deliver(core.clone(), &mut queue)
let now = now();
next_wake_up = LONG_WAIT;
for queue_event in core.next_event().await {
if queue_event.due <= now {
if let Some(message) = core.read_message(queue_event.queue_id).await {
DeliveryAttempt::new(message, queue_event)
.try_deliver(core.clone())
.await;
} else {
queue.schedule(item);
delete_events.push(queue_event);
}
} else {
next_wake_up = Duration::from_secs(queue_event.due - now);
}
}
Event::Done(result) => {
// A worker is done, try delivering concurrency limited messages
while let Some(message) = queue.next_on_hold() {
DeliveryAttempt::from(message)
.try_deliver(core.clone(), &mut queue)
.await;
// Delete unlinked events
if !delete_events.is_empty() {
let core = core.clone();
tokio::spawn(async move {
let mut batch = BatchBuilder::new();
for queue_event in delete_events {
batch.clear(ValueClass::Queue(QueueClass::MessageEvent(queue_event)));
}
match result {
WorkerResult::Done => (),
WorkerResult::Retry(schedule) => {
queue.schedule(schedule);
let _ = core.shared.default_data_store.write(batch.build()).await;
});
}
WorkerResult::OnHold(on_hold) => {
// Add message on hold
if let Some(on_hold) = on_hold {
queue.on_hold(on_hold);
}
}
}
Event::Manage(request) => match request {
management::QueueRequest::List {
from,
to,
before,
after,
result_tx,
} => {
let mut result = Vec::with_capacity(queue.messages.len());
for message in queue.messages.values() {
if from.as_ref().map_or(false, |from| {
!message.return_path_lcase.contains(from)
}) {
continue;
}
if to.as_ref().map_or(false, |to| {
!message
.recipients
.iter()
.any(|rcpt| rcpt.address_lcase.contains(to))
}) {
continue;
}
if (before.is_some() || after.is_some())
&& !message.domains.iter().any(|domain| {
matches!(
&domain.status,
Status::Scheduled | Status::TemporaryFailure(_)
) && match (&before, &after) {
(Some(before), Some(after)) => {
domain.retry.due.lt(before)
&& domain.retry.due.gt(after)
}
(Some(before), None) => domain.retry.due.lt(before),
(None, Some(after)) => domain.retry.due.gt(after),
(None, None) => false,
}
})
{
continue;
}
result.push(message.id);
}
result.sort_unstable_by_key(|id| *id & 0xFFFFFFFF);
let _ = result_tx.send(result);
}
management::QueueRequest::Status {
queue_ids,
result_tx,
} => {
let mut result = Vec::with_capacity(queue_ids.len());
for queue_id in queue_ids {
result.push(
queue
.messages
.get(&queue_id)
.map(|message| message.as_ref().into()),
);
}
let _ = result_tx.send(result);
}
management::QueueRequest::Cancel {
queue_ids,
item,
result_tx,
} => {
let mut result = Vec::with_capacity(queue_ids.len());
for queue_id in &queue_ids {
let mut found = false;
if let Some(item) = &item {
if let Some(message) = queue.messages.get_mut(queue_id) {
// Cancel delivery for all recipients that match
for rcpt in &mut message.recipients {
if rcpt.address_lcase.contains(item) {
rcpt.flags |= RCPT_STATUS_CHANGED;
rcpt.status = Status::Completed(HostResponse {
hostname: String::new(),
response: Response {
code: 0,
esc: [0, 0, 0],
message: "Delivery canceled."
.to_string(),
},
});
found = true;
}
}
if found {
// Mark as completed domains without any pending deliveries
for (domain_idx, domain) in
message.domains.iter_mut().enumerate()
{
if matches!(
domain.status,
Status::TemporaryFailure(_)
| Status::Scheduled
) {
let mut total_rcpt = 0;
let mut total_completed = 0;
for rcpt in &message.recipients {
if rcpt.domain_idx == domain_idx {
total_rcpt += 1;
if matches!(
rcpt.status,
Status::PermanentFailure(_)
| Status::Completed(_)
) {
total_completed += 1;
}
}
}
if total_rcpt == total_completed {
domain.status = Status::Completed(());
domain.changed = true;
}
}
}
// Delete message if there are no pending deliveries
if message.domains.iter().any(|domain| {
matches!(
domain.status,
Status::TemporaryFailure(_)
| Status::Scheduled
)
}) {
message.save_changes().await;
} else {
message.remove().await;
queue.messages.remove(queue_id);
}
}
}
} else if let Some(message) = queue.messages.remove(queue_id) {
message.remove().await;
found = true;
}
result.push(found);
}
let _ = result_tx.send(result);
}
management::QueueRequest::Retry {
queue_ids,
item,
time,
result_tx,
} => {
let mut result = Vec::with_capacity(queue_ids.len());
for queue_id in &queue_ids {
let mut found = false;
if let Some(message) = queue.messages.get_mut(queue_id) {
for domain in &mut message.domains {
if matches!(
domain.status,
Status::Scheduled | Status::TemporaryFailure(_)
) && item
.as_ref()
.map_or(true, |item| domain.domain.contains(item))
{
domain.retry.due = time;
if domain.expires > time {
domain.expires = time + Duration::from_secs(10);
}
domain.changed = true;
found = true;
}
}
if found {
queue.on_hold.retain(|oh| &oh.message != queue_id);
message.save_changes().await;
if let Some(next_event) = message.next_event() {
queue.scheduled.push(Schedule {
due: next_event,
inner: *queue_id,
});
}
}
}
result.push(found);
}
let _ = result_tx.send(result);
}
},
Event::Stop => break,
},
Ok(None) => break,
Err(_) => (),
}
}
});
}
}
impl Queue {
pub fn schedule(&mut self, message: Schedule<Box<Message>>) {
self.scheduled.push(Schedule {
due: message.due,
inner: message.inner.id,
});
self.messages.insert(message.inner.id, message.inner);
}
pub fn on_hold(&mut self, message: OnHold<Box<Message>>) {
pub fn on_hold(&mut self, message: OnHold<QueueEvent>) {
self.on_hold.push(OnHold {
next_due: message.next_due,
limiters: message.limiters,
message: message.message.id,
message: message.message,
});
self.messages.insert(message.message.id, message.message);
}
pub fn next_due(&mut self) -> Option<Box<Message>> {
let item = self.scheduled.peek()?;
if item.due <= Instant::now() {
self.scheduled
.pop()
.and_then(|i| self.messages.remove(&i.inner))
} else {
None
}
}
pub fn next_on_hold(&mut self) -> Option<Box<Message>> {
let now = Instant::now();
pub fn next_on_hold(&mut self) -> Option<QueueEvent> {
let now = now();
self.on_hold
.iter()
.position(|o| {
@ -335,24 +125,13 @@ impl Queue {
.any(|l| l.concurrent.load(Ordering::Relaxed) < l.max_concurrent)
|| o.next_due.map_or(false, |due| due <= now)
})
.and_then(|pos| self.messages.remove(&self.on_hold.remove(pos).message))
}
pub fn wake_up_time(&self) -> Duration {
self.scheduled
.peek()
.map(|item| {
item.due
.checked_duration_since(Instant::now())
.unwrap_or(self.short_wait)
})
.unwrap_or(self.long_wait)
.map(|pos| self.on_hold.remove(pos).message)
}
}
impl Message {
pub fn next_event(&self) -> Option<Instant> {
let mut next_event = Instant::now();
pub fn next_event(&self) -> Option<u64> {
let mut next_event = now();
let mut has_events = false;
for domain in &self.domains {
@ -380,8 +159,8 @@ impl Message {
}
}
pub fn next_delivery_event(&self) -> Instant {
let mut next_delivery = Instant::now();
pub fn next_delivery_event(&self) -> u64 {
let mut next_delivery = now();
for (pos, domain) in self
.domains
@ -397,7 +176,7 @@ impl Message {
next_delivery
}
pub fn next_event_after(&self, instant: Instant) -> Option<Instant> {
pub fn next_event_after(&self, instant: u64) -> Option<u64> {
let mut next_event = None;
for domain in &self.domains {
@ -431,129 +210,14 @@ impl Message {
}
}
impl QueueCore {
pub async fn read_queue(&self) -> Queue {
let mut queue = Queue::default();
let mut messages = Vec::new();
let mut dir = match tokio::fs::read_dir(&self.config.path).await {
Ok(dir) => dir,
Err(err) => {
tracing::warn!(
"Failed to read queue directory {}: {}",
self.config.path.display(),
err
);
return queue;
}
};
loop {
match dir.next_entry().await {
Ok(Some(file)) => {
let file = file.path();
if file.is_dir() {
match tokio::fs::read_dir(&file).await {
Ok(mut dir) => {
let file_ = file;
loop {
match dir.next_entry().await {
Ok(Some(file)) => {
let file = file.path();
if file.extension().map_or(false, |e| e == "msg") {
messages
.push(tokio::spawn(Message::from_path(file)));
}
}
Ok(None) => break,
Err(err) => {
tracing::warn!(
"Failed to read queue directory {}: {}",
file_.display(),
err
);
break;
}
}
}
}
Err(err) => {
tracing::warn!(
"Failed to read queue directory {}: {}",
file.display(),
err
)
}
};
} else if file.extension().map_or(false, |e| e == "msg") {
messages.push(tokio::spawn(Message::from_path(file)));
}
}
Ok(None) => {
break;
}
Err(err) => {
tracing::warn!(
"Failed to read queue directory {}: {}",
self.config.path.display(),
err
);
break;
}
}
}
// Join all futures
for message in messages {
match message.await {
Ok(Ok(mut message)) => {
// Reserve quota
let todo = true;
//self.has_quota(&mut message).await;
// Schedule message
queue.schedule(Schedule {
due: message.next_event().unwrap_or_else(|| {
tracing::warn!(
context = "queue",
event = "warn",
"No due events found for message {}",
message.path.display()
);
Instant::now()
}),
inner: Box::new(message),
});
}
Ok(Err(err)) => {
tracing::warn!(
context = "queue",
event = "error",
"Queue startup error: {}",
err
);
}
Err(err) => {
tracing::error!("Join error while starting queue: {}", err);
}
}
}
queue
}
}
impl Default for Queue {
fn default() -> Self {
Queue {
short_wait: Duration::from_millis(1),
long_wait: Duration::from_secs(86400 * 365),
scheduled: BinaryHeap::with_capacity(128),
on_hold: Vec::with_capacity(128),
messages: AHashMap::with_capacity(128),
}
}
}
pub trait SpawnQueue {
fn spawn(self, core: Arc<SMTP>, queue: Queue);
fn spawn(self, core: Arc<SMTP>);
}

View file

@ -24,21 +24,22 @@
use std::{
fmt::Display,
net::IpAddr,
path::PathBuf,
sync::{atomic::AtomicUsize, Arc},
time::{Duration, Instant, SystemTime},
};
use serde::{Deserialize, Serialize};
use smtp_proto::Response;
use utils::listener::limiter::{ConcurrencyLimiter, InFlight};
use store::write::{now, QueueEvent};
use utils::{
listener::limiter::{ConcurrencyLimiter, InFlight},
BlobHash,
};
use crate::core::{eval::*, management, ResolveVariable};
use crate::core::{eval::*, ResolveVariable};
pub mod dsn;
pub mod manager;
pub mod quota;
pub mod serialize;
pub mod spool;
pub mod throttle;
@ -46,37 +47,29 @@ pub type QueueId = u64;
#[derive(Debug)]
pub enum Event {
Queue(Schedule<Box<Message>>),
Manage(management::QueueRequest),
Done(WorkerResult),
Reload,
OnHold(OnHold<QueueEvent>),
Stop,
}
#[derive(Debug)]
pub enum WorkerResult {
Done,
Retry(Schedule<Box<Message>>),
OnHold(OnHold<Box<Message>>),
}
#[derive(Debug)]
pub struct OnHold<T> {
pub next_due: Option<Instant>,
pub next_due: Option<u64>,
pub limiters: Vec<ConcurrencyLimiter>,
pub message: T,
}
#[derive(Debug)]
#[derive(Debug, Serialize, Deserialize)]
pub struct Schedule<T> {
pub due: Instant,
pub due: u64,
pub inner: T,
}
#[derive(Debug)]
#[derive(Debug, serde::Serialize, serde::Deserialize)]
pub struct Message {
pub id: QueueId,
pub created: u64,
pub path: PathBuf,
pub blob_hash: BlobHash,
pub return_path: String,
pub return_path_lcase: String,
@ -89,21 +82,26 @@ pub struct Message {
pub priority: i16,
pub size: usize,
pub queue_refs: Vec<UsedQuota>,
pub quota_keys: Vec<QuotaKey>,
}
#[derive(Debug, PartialEq, Eq)]
#[derive(Debug, serde::Serialize, serde::Deserialize)]
pub enum QuotaKey {
Size { key: Vec<u8>, id: u64 },
Count { key: Vec<u8>, id: u64 },
}
#[derive(Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct Domain {
pub domain: String,
pub retry: Schedule<u32>,
pub notify: Schedule<u32>,
pub expires: Instant,
pub expires: u64,
pub status: Status<(), Error>,
pub disable_tls: bool,
pub changed: bool,
}
#[derive(Debug, PartialEq, Eq)]
#[derive(Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct Recipient {
pub domain_idx: usize,
pub address: String,
@ -128,13 +126,13 @@ pub enum Status<T, E> {
PermanentFailure(E),
}
#[derive(Debug, PartialEq, Eq)]
#[derive(Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct HostResponse<T> {
pub hostname: T,
pub response: Response<String>,
}
#[derive(Debug, PartialEq, Eq)]
#[derive(Debug, PartialEq, Eq, Serialize, Deserialize)]
pub enum Error {
DnsError(String),
UnexpectedResponse(HostResponse<ErrorDetails>),
@ -147,7 +145,7 @@ pub enum Error {
Io(String),
}
#[derive(Debug, PartialEq, Eq)]
#[derive(Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct ErrorDetails {
pub entity: String,
pub details: String,
@ -156,32 +154,10 @@ pub struct ErrorDetails {
pub struct DeliveryAttempt {
pub span: tracing::Span,
pub in_flight: Vec<InFlight>,
pub message: Box<Message>,
pub message: Message,
pub event: QueueEvent,
}
#[derive(Debug)]
pub struct QuotaLimiter {
pub max_size: usize,
pub max_messages: usize,
pub size: AtomicUsize,
pub messages: AtomicUsize,
}
#[derive(Debug)]
pub struct UsedQuota {
id: u64,
size: usize,
limiter: Arc<QuotaLimiter>,
}
impl PartialEq for UsedQuota {
fn eq(&self, other: &Self) -> bool {
self.id == other.id && self.size == other.size
}
}
impl Eq for UsedQuota {}
impl<T> Ord for Schedule<T> {
fn cmp(&self, other: &Self) -> std::cmp::Ordering {
other.due.cmp(&self.due)
@ -205,14 +181,14 @@ impl<T> Eq for Schedule<T> {}
impl<T: Default> Schedule<T> {
pub fn now() -> Self {
Schedule {
due: Instant::now(),
due: now(),
inner: T::default(),
}
}
pub fn later(duration: Duration) -> Self {
Schedule {
due: Instant::now() + duration,
due: now() + duration.as_secs(),
inner: T::default(),
}
}

View file

@ -21,25 +21,26 @@
* for more details.
*/
use std::sync::{atomic::Ordering, Arc};
use dashmap::mapref::entry::Entry;
use store::{
write::{BatchBuilder, QueueClass, ValueClass},
ValueKey,
};
use crate::{
config::QueueQuota,
core::{ResolveVariable, SMTP},
};
use super::{Message, QuotaLimiter, SimpleEnvelope, Status, UsedQuota};
use super::{Message, QuotaKey, SimpleEnvelope, Status};
impl SMTP {
pub async fn has_quota(&self, message: &mut Message) -> bool {
let mut queue_refs = Vec::new();
let mut quota_keys = Vec::new();
if !self.queue.config.quota.sender.is_empty() {
for quota in &self.queue.config.quota.sender {
if !self
.reserve_quota(quota, message, message.size, 0, &mut queue_refs)
.check_quota(quota, message, message.size, 0, &mut quota_keys)
.await
{
return false;
@ -50,12 +51,12 @@ impl SMTP {
for quota in &self.queue.config.quota.rcpt_domain {
for (pos, domain) in message.domains.iter().enumerate() {
if !self
.reserve_quota(
.check_quota(
quota,
&SimpleEnvelope::new(message, &domain.domain),
message.size,
((pos + 1) << 32) as u64,
&mut queue_refs,
&mut quota_keys,
)
.await
{
@ -67,7 +68,7 @@ impl SMTP {
for quota in &self.queue.config.quota.rcpt {
for (pos, rcpt) in message.recipients.iter().enumerate() {
if !self
.reserve_quota(
.check_quota(
quota,
&SimpleEnvelope::new_rcpt(
message,
@ -76,7 +77,7 @@ impl SMTP {
),
message.size,
(pos + 1) as u64,
&mut queue_refs,
&mut quota_keys,
)
.await
{
@ -85,48 +86,66 @@ impl SMTP {
}
}
message.queue_refs = queue_refs;
message.quota_keys = quota_keys;
true
}
async fn reserve_quota(
async fn check_quota(
&self,
quota: &QueueQuota,
envelope: &impl ResolveVariable,
size: usize,
id: u64,
refs: &mut Vec<UsedQuota>,
refs: &mut Vec<QuotaKey>,
) -> bool {
if !quota.expr.is_empty()
&& self
.eval_expr(&quota.expr, envelope, "reserve_quota")
.eval_expr(&quota.expr, envelope, "check_quota")
.await
.unwrap_or(false)
{
match self.queue.quota.entry(quota.new_key(envelope)) {
Entry::Occupied(e) => {
if let Some(qref) = e.get().is_allowed(id, size) {
refs.push(qref);
} else {
let key = quota.new_key(envelope);
if let Some(max_size) = quota.size {
if self
.shared
.default_data_store
.get_counter(ValueKey::from(ValueClass::Queue(QueueClass::QuotaSize(
key.as_ref().to_vec(),
))))
.await
.unwrap_or(0) as usize
+ size
> max_size
{
return false;
}
}
Entry::Vacant(e) => {
let limiter = Arc::new(QuotaLimiter {
max_size: quota.size.unwrap_or(0),
max_messages: quota.messages.unwrap_or(0),
size: 0.into(),
messages: 0.into(),
} else {
refs.push(QuotaKey::Size {
key: key.as_ref().to_vec(),
id,
});
if let Some(qref) = limiter.is_allowed(id, size) {
refs.push(qref);
e.insert(limiter);
} else {
return false;
}
}
if let Some(max_messages) = quota.messages {
if self
.shared
.default_data_store
.get_counter(ValueKey::from(ValueClass::Queue(QueueClass::QuotaCount(
key.as_ref().to_vec(),
))))
.await
.unwrap_or(0) as usize
+ 1
> max_messages
{
return false;
} else {
refs.push(QuotaKey::Count {
key: key.as_ref().to_vec(),
id,
});
}
}
}
true
@ -134,7 +153,10 @@ impl SMTP {
}
impl Message {
pub fn release_quota(&mut self) {
pub fn release_quota(&mut self, batch: &mut BatchBuilder) {
if self.quota_keys.is_empty() {
return;
}
let mut quota_ids = Vec::with_capacity(self.domains.len() + self.recipients.len());
for (pos, domain) in self.domains.iter().enumerate() {
if matches!(
@ -153,48 +175,21 @@ impl Message {
}
}
if !quota_ids.is_empty() {
self.queue_refs.retain(|q| !quota_ids.contains(&q.id));
}
}
}
trait QuotaLimiterAllowed {
fn is_allowed(&self, id: u64, size: usize) -> Option<UsedQuota>;
}
impl QuotaLimiterAllowed for Arc<QuotaLimiter> {
fn is_allowed(&self, id: u64, size: usize) -> Option<UsedQuota> {
if self.max_messages > 0 {
if self.messages.load(Ordering::Relaxed) < self.max_messages {
self.messages.fetch_add(1, Ordering::Relaxed);
} else {
return None;
}
}
if self.max_size > 0 {
if self.size.load(Ordering::Relaxed) + size < self.max_size {
self.size.fetch_add(size, Ordering::Relaxed);
} else {
return None;
}
}
Some(UsedQuota {
id,
size,
limiter: self.clone(),
})
}
}
impl Drop for UsedQuota {
fn drop(&mut self) {
if self.limiter.max_messages > 0 {
self.limiter.messages.fetch_sub(1, Ordering::Relaxed);
}
if self.limiter.max_size > 0 {
self.limiter.size.fetch_sub(self.size, Ordering::Relaxed);
let mut quota_keys = Vec::new();
for quota_key in std::mem::take(&mut self.quota_keys) {
match quota_key {
QuotaKey::Count { id, key } if quota_ids.contains(&id) => {
batch.clear(ValueClass::Queue(QueueClass::QuotaCount(key)));
}
QuotaKey::Size { id, key } if quota_ids.contains(&id) => {
batch.clear(ValueClass::Queue(QueueClass::QuotaSize(key)));
}
_ => {
quota_keys.push(quota_key);
}
}
}
self.quota_keys = quota_keys;
}
}
}

View file

@ -1,565 +0,0 @@
/*
* Copyright (c) 2023 Stalwart Labs Ltd.
*
* This file is part of Stalwart Mail Server.
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as
* published by the Free Software Foundation, either version 3 of
* the License, or (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
* in the LICENSE file at the top-level directory of this distribution.
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*
* You can be released from the requirements of the AGPLv3 license by
* purchasing a commercial license. Please contact licensing@stalw.art
* for more details.
*/
use mail_auth::common::base32::Base32Reader;
use smtp_proto::Response;
use std::io::SeekFrom;
use std::path::PathBuf;
use std::slice::Iter;
use std::{fmt::Write, time::Instant};
use tokio::fs;
use tokio::fs::File;
use tokio::io::{AsyncReadExt, AsyncSeekExt};
use super::{
instant_to_timestamp, Domain, DomainPart, Error, ErrorDetails, HostResponse,
InstantFromTimestamp, Message, Recipient, Schedule, Status, RCPT_STATUS_CHANGED,
};
pub trait QueueSerializer: Sized {
fn serialize(&self, buf: &mut String);
fn deserialize(bytes: &mut Iter<'_, u8>) -> Option<Self>;
}
impl Message {
pub fn serialize(&self) -> Vec<u8> {
let mut buf = String::with_capacity(
self.return_path.len()
+ self.env_id.as_ref().map_or(0, |e| e.len())
+ (self.domains.len() * 64)
+ (self.recipients.len() * 64)
+ 50,
);
// Serialize message properties
(self.created as usize).serialize(&mut buf);
self.return_path.serialize(&mut buf);
(self.env_id.as_deref().unwrap_or_default()).serialize(&mut buf);
(self.flags as usize).serialize(&mut buf);
self.priority.serialize(&mut buf);
// Serialize domains
let now = Instant::now();
self.domains.len().serialize(&mut buf);
for domain in &self.domains {
domain.domain.serialize(&mut buf);
(instant_to_timestamp(now, domain.expires) as usize).serialize(&mut buf);
}
// Serialize recipients
self.recipients.len().serialize(&mut buf);
for rcpt in &self.recipients {
rcpt.domain_idx.serialize(&mut buf);
rcpt.address.serialize(&mut buf);
(rcpt.orcpt.as_deref().unwrap_or_default()).serialize(&mut buf);
}
// Serialize domain status
for (idx, domain) in self.domains.iter().enumerate() {
domain.serialize(idx, now, &mut buf);
}
// Serialize recipient status
for (idx, rcpt) in self.recipients.iter().enumerate() {
rcpt.serialize(idx, &mut buf);
}
buf.into_bytes()
}
pub fn serialize_changes(&mut self) -> Vec<u8> {
let now = Instant::now();
let mut buf = String::with_capacity(128);
for (idx, domain) in self.domains.iter_mut().enumerate() {
if domain.changed {
domain.changed = false;
domain.serialize(idx, now, &mut buf);
}
}
for (idx, rcpt) in self.recipients.iter_mut().enumerate() {
if rcpt.has_flag(RCPT_STATUS_CHANGED) {
rcpt.flags &= !RCPT_STATUS_CHANGED;
rcpt.serialize(idx, &mut buf);
}
}
buf.into_bytes()
}
pub async fn from_path(path: PathBuf) -> Result<Self, String> {
let filename = path
.file_name()
.and_then(|f| f.to_str())
.and_then(|f| f.rsplit_once('.'))
.map(|(f, _)| f)
.ok_or_else(|| format!("Invalid queue file name {}", path.display()))?;
// Decode file name
let mut id = [0u8; std::mem::size_of::<u64>()];
let mut size = [0u8; std::mem::size_of::<u32>()];
for (pos, byte) in Base32Reader::new(filename.as_bytes()).enumerate() {
match pos {
0..=7 => {
id[pos] = byte;
}
8..=11 => {
size[pos - 8] = byte;
}
_ => {
return Err(format!("Invalid queue file name {}", path.display()));
}
}
}
let id = u64::from_le_bytes(id);
let size = u32::from_le_bytes(size) as u64;
// Obtail file size
let file_size = fs::metadata(&path)
.await
.map_err(|err| {
format!(
"Failed to obtain file metadata for {}: {}",
path.display(),
err
)
})?
.len();
if size == 0 || size >= file_size {
return Err(format!(
"Invalid queue file name size {} for {}",
size,
path.display()
));
}
let mut buf = Vec::with_capacity((file_size - size) as usize);
let mut file = File::open(&path)
.await
.map_err(|err| format!("Failed to open queue file {}: {}", path.display(), err))?;
file.seek(SeekFrom::Start(size))
.await
.map_err(|err| format!("Failed to seek queue file {}: {}", path.display(), err))?;
file.read_to_end(&mut buf)
.await
.map_err(|err| format!("Failed to read queue file {}: {}", path.display(), err))?;
let mut message = Self::deserialize(&buf)
.ok_or_else(|| format!("Failed to deserialize metadata for file {}", path.display()))?;
message.path = path;
message.size = size as usize;
message.id = id;
Ok(message)
}
pub fn deserialize(bytes: &[u8]) -> Option<Self> {
let mut bytes = bytes.iter();
let created = usize::deserialize(&mut bytes)? as u64;
let return_path = String::deserialize(&mut bytes)?;
let return_path_lcase = return_path.to_lowercase();
let env_id = String::deserialize(&mut bytes)?;
let mut message = Message {
id: 0,
path: PathBuf::new(),
created,
return_path_domain: return_path_lcase.domain_part().to_string(),
return_path_lcase,
return_path,
env_id: if !env_id.is_empty() {
env_id.into()
} else {
None
},
flags: usize::deserialize(&mut bytes)? as u64,
priority: i16::deserialize(&mut bytes)?,
size: 0,
recipients: vec![],
domains: vec![],
queue_refs: vec![],
};
// Deserialize domains
let num_domains = usize::deserialize(&mut bytes)?;
message.domains = Vec::with_capacity(num_domains);
for _ in 0..num_domains {
message.domains.push(Domain {
domain: String::deserialize(&mut bytes)?,
expires: Instant::deserialize(&mut bytes)?,
retry: Schedule::now(),
notify: Schedule::now(),
status: Status::Scheduled,
disable_tls: false,
changed: false,
});
}
// Deserialize recipients
let num_recipients = usize::deserialize(&mut bytes)?;
message.recipients = Vec::with_capacity(num_recipients);
for _ in 0..num_recipients {
let domain_idx = usize::deserialize(&mut bytes)?;
let address = String::deserialize(&mut bytes)?;
let orcpt = String::deserialize(&mut bytes)?;
message.recipients.push(Recipient {
domain_idx,
address_lcase: address.to_lowercase(),
address,
status: Status::Scheduled,
flags: 0,
orcpt: if !orcpt.is_empty() {
orcpt.into()
} else {
None
},
});
}
// Deserialize status
while let Some((ch, idx)) = bytes
.next()
.and_then(|ch| (ch, usize::deserialize(&mut bytes)?).into())
{
match ch {
b'D' => {
if let (Some(domain), Some(retry), Some(notify), Some(status)) = (
message.domains.get_mut(idx),
Schedule::deserialize(&mut bytes),
Schedule::deserialize(&mut bytes),
Status::deserialize(&mut bytes),
) {
domain.retry = retry;
domain.notify = notify;
domain.status = status;
} else {
break;
}
}
b'R' => {
if let (Some(rcpt), Some(flags), Some(status)) = (
message.recipients.get_mut(idx),
usize::deserialize(&mut bytes),
Status::deserialize(&mut bytes),
) {
rcpt.flags = flags as u64;
rcpt.status = status;
} else {
break;
}
}
_ => break,
}
}
message.into()
}
}
impl<T: QueueSerializer, E: QueueSerializer> QueueSerializer for Status<T, E> {
fn serialize(&self, buf: &mut String) {
match self {
Status::Scheduled => buf.push('S'),
Status::Completed(s) => {
buf.push('C');
s.serialize(buf);
}
Status::TemporaryFailure(s) => {
buf.push('T');
s.serialize(buf);
}
Status::PermanentFailure(s) => {
buf.push('F');
s.serialize(buf);
}
}
}
fn deserialize(bytes: &mut Iter<'_, u8>) -> Option<Self> {
match bytes.next()? {
b'S' => Self::Scheduled.into(),
b'C' => Self::Completed(T::deserialize(bytes)?).into(),
b'T' => Self::TemporaryFailure(E::deserialize(bytes)?).into(),
b'F' => Self::PermanentFailure(E::deserialize(bytes)?).into(),
_ => None,
}
}
}
impl QueueSerializer for Response<String> {
fn serialize(&self, buf: &mut String) {
let _ = write!(
buf,
"{} {} {} {} {} {}",
self.code,
self.esc[0],
self.esc[1],
self.esc[2],
self.message.len(),
self.message
);
}
fn deserialize(bytes: &mut Iter<'_, u8>) -> Option<Self> {
Response {
code: usize::deserialize(bytes)? as u16,
esc: [
usize::deserialize(bytes)? as u8,
usize::deserialize(bytes)? as u8,
usize::deserialize(bytes)? as u8,
],
message: String::deserialize(bytes)?,
}
.into()
}
}
impl QueueSerializer for usize {
fn serialize(&self, buf: &mut String) {
let _ = write!(buf, "{self} ");
}
fn deserialize(bytes: &mut Iter<'_, u8>) -> Option<Self> {
let mut num = 0;
loop {
match bytes.next()? {
ch @ (b'0'..=b'9') => {
num = (num * 10) + (*ch - b'0') as usize;
}
b' ' => {
return num.into();
}
_ => {
return None;
}
}
}
}
}
impl QueueSerializer for i16 {
fn serialize(&self, buf: &mut String) {
let _ = write!(buf, "{self} ");
}
fn deserialize(bytes: &mut Iter<'_, u8>) -> Option<Self> {
let mut num = 0;
let mut mul = 1;
loop {
match bytes.next()? {
ch @ (b'0'..=b'9') => {
num = (num * 10) + (*ch - b'0') as i16;
}
b' ' => {
return (num * mul).into();
}
b'-' => {
mul = -1;
}
_ => {
return None;
}
}
}
}
}
impl QueueSerializer for ErrorDetails {
fn serialize(&self, buf: &mut String) {
self.entity.serialize(buf);
self.details.serialize(buf);
}
fn deserialize(bytes: &mut Iter<'_, u8>) -> Option<Self> {
ErrorDetails {
entity: String::deserialize(bytes)?,
details: String::deserialize(bytes)?,
}
.into()
}
}
impl<T: QueueSerializer> QueueSerializer for HostResponse<T> {
fn serialize(&self, buf: &mut String) {
self.hostname.serialize(buf);
self.response.serialize(buf);
}
fn deserialize(bytes: &mut Iter<'_, u8>) -> Option<Self> {
HostResponse {
hostname: T::deserialize(bytes)?,
response: Response::deserialize(bytes)?,
}
.into()
}
}
impl QueueSerializer for String {
fn serialize(&self, buf: &mut String) {
if !self.is_empty() {
let _ = write!(buf, "{} {}", self.len(), self);
} else {
buf.push_str("0 ");
}
}
fn deserialize(bytes: &mut Iter<'_, u8>) -> Option<Self> {
match usize::deserialize(bytes)? {
len @ (1..=4096) => {
String::from_utf8(bytes.take(len).copied().collect::<Vec<_>>()).ok()
}
0 => String::new().into(),
_ => None,
}
}
}
impl QueueSerializer for &str {
fn serialize(&self, buf: &mut String) {
if !self.is_empty() {
let _ = write!(buf, "{} {}", self.len(), self);
} else {
buf.push_str("0 ");
}
}
fn deserialize(_bytes: &mut Iter<'_, u8>) -> Option<Self> {
unimplemented!()
}
}
impl QueueSerializer for Instant {
fn serialize(&self, buf: &mut String) {
let _ = write!(buf, "{} ", instant_to_timestamp(Instant::now(), *self),);
}
fn deserialize(bytes: &mut Iter<'_, u8>) -> Option<Self> {
(usize::deserialize(bytes)? as u64).to_instant().into()
}
}
impl QueueSerializer for Schedule<u32> {
fn serialize(&self, buf: &mut String) {
let _ = write!(
buf,
"{} {} ",
self.inner,
instant_to_timestamp(Instant::now(), self.due),
);
}
fn deserialize(bytes: &mut Iter<'_, u8>) -> Option<Self> {
Schedule {
inner: usize::deserialize(bytes)? as u32,
due: Instant::deserialize(bytes)?,
}
.into()
}
}
impl QueueSerializer for Error {
fn serialize(&self, buf: &mut String) {
match self {
Error::DnsError(e) => {
buf.push('0');
e.serialize(buf);
}
Error::UnexpectedResponse(e) => {
buf.push('1');
e.serialize(buf);
}
Error::ConnectionError(e) => {
buf.push('2');
e.serialize(buf);
}
Error::TlsError(e) => {
buf.push('3');
e.serialize(buf);
}
Error::DaneError(e) => {
buf.push('4');
e.serialize(buf);
}
Error::MtaStsError(e) => {
buf.push('5');
e.serialize(buf);
}
Error::RateLimited => {
buf.push('6');
}
Error::ConcurrencyLimited => {
buf.push('7');
}
Error::Io(e) => {
buf.push('8');
e.serialize(buf);
}
}
}
fn deserialize(bytes: &mut Iter<'_, u8>) -> Option<Self> {
match bytes.next()? {
b'0' => Error::DnsError(String::deserialize(bytes)?).into(),
b'1' => Error::UnexpectedResponse(HostResponse::deserialize(bytes)?).into(),
b'2' => Error::ConnectionError(ErrorDetails::deserialize(bytes)?).into(),
b'3' => Error::TlsError(ErrorDetails::deserialize(bytes)?).into(),
b'4' => Error::DaneError(ErrorDetails::deserialize(bytes)?).into(),
b'5' => Error::MtaStsError(String::deserialize(bytes)?).into(),
b'6' => Error::RateLimited.into(),
b'7' => Error::ConcurrencyLimited.into(),
b'8' => Error::Io(String::deserialize(bytes)?).into(),
_ => None,
}
}
}
impl QueueSerializer for () {
fn serialize(&self, _buf: &mut String) {}
fn deserialize(_bytes: &mut Iter<'_, u8>) -> Option<Self> {
Some(())
}
}
impl Domain {
fn serialize(&self, idx: usize, now: Instant, buf: &mut String) {
let _ = write!(
buf,
"D{} {} {} {} {} ",
idx,
self.retry.inner,
instant_to_timestamp(now, self.retry.due),
self.notify.inner,
instant_to_timestamp(now, self.notify.due)
);
self.status.serialize(buf);
}
}
impl Recipient {
fn serialize(&self, idx: usize, buf: &mut String) {
let _ = write!(buf, "R{} {} ", idx, self.flags);
self.status.serialize(buf);
}
}

View file

@ -22,163 +22,32 @@
*/
use crate::queue::DomainPart;
use mail_auth::common::base32::Base32Writer;
use mail_auth::common::headers::Writer;
use std::path::PathBuf;
use std::sync::atomic::Ordering;
use std::time::Instant;
use std::borrow::Cow;
use std::time::{Duration, SystemTime};
use tokio::fs::OpenOptions;
use tokio::{fs, io::AsyncWriteExt};
use store::write::key::DeserializeBigEndian;
use store::write::{now, BatchBuilder, Bincode, BlobOp, QueueClass, QueueEvent, ValueClass};
use store::{IterateParams, Serialize, ValueKey, U64_LEN};
use utils::BlobHash;
use crate::core::{QueueCore, SMTP};
use super::{Domain, Event, Message, Recipient, Schedule, SimpleEnvelope, Status};
use super::{
Domain, Event, Message, QueueId, QuotaKey, Recipient, Schedule, SimpleEnvelope, Status,
};
impl QueueCore {
pub async fn queue_message(
pub fn new_message(
&self,
mut message: Box<Message>,
raw_headers: Option<&[u8]>,
raw_message: &[u8],
span: &tracing::Span,
) -> bool {
// Generate id
if message.id == 0 {
message.id = self.queue_id();
}
if message.size == 0 {
message.size = raw_message.len() + raw_headers.as_ref().map_or(0, |h| h.len());
}
// Build path
let todo = 1;
message.path = self.config.path.clone();
let hash = 1;
if hash > 0 {
message.path.push((message.id % hash).to_string());
}
let _ = fs::create_dir(&message.path).await;
// Encode file name
let mut encoder = Base32Writer::with_capacity(20);
encoder.write(&message.id.to_le_bytes()[..]);
encoder.write(&(message.size as u32).to_le_bytes()[..]);
let mut file = encoder.finalize();
file.push_str(".msg");
message.path.push(file);
// Serialize metadata
let metadata = message.serialize();
// Save message
let mut file = match fs::File::create(&message.path).await {
Ok(file) => file,
Err(err) => {
tracing::error!(
parent: span,
context = "queue",
event = "error",
"Failed to create file {}: {}",
message.path.display(),
err
);
return false;
}
};
let iter = if let Some(raw_headers) = raw_headers {
[raw_headers, raw_message, &metadata].into_iter()
} else {
[raw_message, &metadata, b""].into_iter()
};
for bytes in iter {
if !bytes.is_empty() {
if let Err(err) = file.write_all(bytes).await {
tracing::error!(
parent: span,
context = "queue",
event = "error",
"Failed to write to file {}: {}",
message.path.display(),
err
);
return false;
}
}
}
if let Err(err) = file.flush().await {
tracing::error!(
parent: span,
context = "queue",
event = "error",
"Failed to flush file {}: {}",
message.path.display(),
err
);
return false;
}
tracing::info!(
parent: span,
context = "queue",
event = "scheduled",
id = message.id,
from = if !message.return_path.is_empty() {
message.return_path.as_str()
} else {
"<>"
},
nrcpts = message.recipients.len(),
size = message.size,
"Message queued for delivery."
);
// Queue the message
if self
.tx
.send(Event::Queue(Schedule {
due: message.next_event().unwrap(),
inner: message,
}))
.await
.is_err()
{
tracing::warn!(
parent: span,
context = "queue",
event = "error",
"Queue channel closed: Message queued but won't be sent until next restart."
);
}
true
}
pub fn queue_id(&self) -> u64 {
(SystemTime::now()
.duration_since(SystemTime::UNIX_EPOCH)
.map_or(0, |d| d.as_secs())
.saturating_sub(946684800)
& 0xFFFFFFFF)
| (self.id_seq.fetch_add(1, Ordering::Relaxed) as u64) << 32
}
}
impl Message {
pub fn new_boxed(
return_path: impl Into<String>,
return_path_lcase: impl Into<String>,
return_path_domain: impl Into<String>,
) -> Box<Message> {
Box::new(Message {
id: 0,
path: PathBuf::new(),
created: SystemTime::now()
) -> Message {
let created = SystemTime::now()
.duration_since(SystemTime::UNIX_EPOCH)
.map(|d| d.as_secs())
.unwrap_or(0),
.map_or(0, |d| d.as_secs());
Message {
id: self.snowflake_id.generate().unwrap_or(created),
created,
return_path: return_path.into(),
return_path_lcase: return_path_lcase.into(),
return_path_domain: return_path_domain.into(),
@ -188,8 +57,189 @@ impl Message {
env_id: None,
priority: 0,
size: 0,
queue_refs: vec![],
})
blob_hash: Default::default(),
quota_keys: Vec::new(),
}
}
}
impl SMTP {
pub async fn next_event(&self) -> Vec<QueueEvent> {
let from_key = ValueKey::from(ValueClass::Queue(QueueClass::MessageEvent(QueueEvent {
due: 0,
queue_id: 0,
})));
let to_key = ValueKey::from(ValueClass::Queue(QueueClass::MessageEvent(QueueEvent {
due: u64::MAX,
queue_id: u64::MAX,
})));
let mut events = Vec::new();
let now = now();
let result = self
.shared
.default_data_store
.iterate(
IterateParams::new(from_key, to_key).ascending().no_values(),
|key, _| {
let event = QueueEvent {
due: key.deserialize_be_u64(1)?,
queue_id: key.deserialize_be_u64(U64_LEN + 1)?,
};
let do_continue = event.due <= now;
events.push(event);
Ok(do_continue)
},
)
.await;
if let Err(err) = result {
tracing::error!(
context = "queue",
event = "error",
"Failed to read from store: {}",
err
);
}
events
}
pub async fn read_message(&self, id: QueueId) -> Option<Message> {
match self
.shared
.default_data_store
.get_value::<Bincode<Message>>(ValueKey::from(ValueClass::Queue(QueueClass::Message(
id,
))))
.await
{
Ok(Some(message)) => Some(message.inner),
Ok(None) => None,
Err(err) => {
tracing::error!(
context = "queue",
event = "error",
"Failed to read message from store: {}",
err
);
None
}
}
}
}
impl Message {
pub async fn queue(
mut self,
raw_headers: Option<&[u8]>,
raw_message: &[u8],
core: &SMTP,
span: &tracing::Span,
) -> bool {
// Write blob
let message = if let Some(raw_headers) = raw_headers {
let mut message = Vec::with_capacity(raw_headers.len() + raw_message.len());
message.extend_from_slice(raw_headers);
message.extend_from_slice(raw_message);
Cow::Owned(message)
} else {
raw_message.into()
};
self.blob_hash = BlobHash::from(message.as_ref());
// Generate id
if self.size == 0 {
self.size = message.len();
}
// Reserve and write blob
let mut batch = BatchBuilder::new();
batch.with_account_id(u32::MAX).set(
BlobOp::Reserve {
hash: self.blob_hash.clone(),
until: self.next_delivery_event() + 3600,
},
0u32.serialize(),
);
if let Err(err) = core.shared.default_data_store.write(batch.build()).await {
tracing::error!(
parent: span,
context = "queue",
event = "error",
"Failed to write to data store: {}",
err
);
return false;
}
if let Err(err) = core
.shared
.default_blob_store
.put_blob(self.blob_hash.as_slice(), message.as_ref())
.await
{
tracing::error!(
parent: span,
context = "queue",
event = "error",
"Failed to write to blob store: {}",
err
);
return false;
}
tracing::info!(
parent: span,
context = "queue",
event = "scheduled",
id = self.id,
from = if !self.return_path.is_empty() {
self.return_path.as_str()
} else {
"<>"
},
nrcpts = self.recipients.len(),
size = self.size,
"Message queued for delivery."
);
// Write message to queue
let mut batch = BatchBuilder::new();
batch
.set(
ValueClass::Queue(QueueClass::MessageEvent(QueueEvent {
due: self.next_event().unwrap_or_default(),
queue_id: self.id,
})),
vec![],
)
.set(
ValueClass::Queue(QueueClass::Message(self.id)),
Bincode::new(self).serialize(),
);
if let Err(err) = core.shared.default_data_store.write(batch.build()).await {
tracing::error!(
parent: span,
context = "queue",
event = "error",
"Failed to write to store: {}",
err
);
return false;
}
// Queue the message
if core.queue.tx.send(Event::Reload).await.is_err() {
tracing::warn!(
parent: span,
context = "queue",
event = "error",
"Queue channel closed: Message queued but won't be sent until next restart."
);
}
true
}
pub async fn add_recipient_parts(
@ -216,10 +266,9 @@ impl Message {
domain: rcpt_domain,
retry: Schedule::now(),
notify: Schedule::later(expires + Duration::from_secs(10)),
expires: Instant::now() + expires,
expires: now() + expires.as_secs(),
status: Status::Scheduled,
disable_tls: false,
changed: false,
});
idx
};
@ -241,35 +290,95 @@ impl Message {
.await;
}
pub async fn save_changes(&mut self) {
let buf = self.serialize_changes();
if !buf.is_empty() {
let err = match OpenOptions::new().append(true).open(&self.path).await {
Ok(mut file) => match file.write_all(&buf).await {
Ok(_) => return,
Err(err) => err,
pub async fn save_changes(
mut self,
core: &SMTP,
prev_event: Option<u64>,
next_event: Option<u64>,
) -> bool {
debug_assert!(prev_event.is_some() == next_event.is_some());
let mut batch = BatchBuilder::new();
// Release quota for completed deliveries
self.release_quota(&mut batch);
// Update message queue
let mut batch = BatchBuilder::new();
if let Some(prev_event) = prev_event {
batch.clear(ValueClass::Queue(QueueClass::MessageEvent(QueueEvent {
due: prev_event,
queue_id: self.id,
})));
}
if let Some(next_event) = next_event {
batch.set(
ValueClass::Queue(QueueClass::MessageEvent(QueueEvent {
due: next_event,
queue_id: self.id,
})),
vec![],
);
}
batch
.with_account_id(u32::MAX)
.set(
BlobOp::Reserve {
hash: self.blob_hash.clone(),
until: self.next_delivery_event() + 3600,
},
Err(err) => err,
};
0u32.serialize(),
)
.set(
ValueClass::Queue(QueueClass::Message(self.id)),
Bincode::new(self).serialize(),
);
if let Err(err) = core.shared.default_data_store.write(batch.build()).await {
tracing::error!(
context = "queue",
event = "error",
"Failed to write to {}: {}",
self.path.display(),
"Failed to update queued message: {}",
err
);
false
} else {
true
}
}
pub async fn remove(&self) {
if let Err(err) = fs::remove_file(&self.path).await {
pub async fn remove(self, core: &SMTP, prev_event: u64) -> bool {
let mut batch = BatchBuilder::new();
// Release all quotas
for quota_key in self.quota_keys {
match quota_key {
QuotaKey::Count { key, .. } => {
batch.clear(ValueClass::Queue(QueueClass::QuotaCount(key)));
}
QuotaKey::Size { key, .. } => {
batch.clear(ValueClass::Queue(QueueClass::QuotaSize(key)));
}
}
}
batch
.clear(ValueClass::Queue(QueueClass::MessageEvent(QueueEvent {
due: prev_event,
queue_id: self.id,
})))
.clear(ValueClass::Queue(QueueClass::Message(self.id)));
if let Err(err) = core.shared.default_data_store.write(batch.build()).await {
tracing::error!(
context = "queue",
event = "error",
"Failed to delete queued message {}: {}",
self.path.display(),
"Failed to update queued message: {}",
err
);
false
} else {
true
}
}
}

View file

@ -21,14 +21,13 @@
* for more details.
*/
use std::time::{Duration, Instant};
use dashmap::mapref::entry::Entry;
use utils::listener::limiter::{ConcurrencyLimiter, InFlight, RateLimiter};
use store::write::now;
use utils::listener::limiter::{ConcurrencyLimiter, InFlight};
use crate::{
config::Throttle,
core::{throttle::Limiter, ResolveVariable, SMTP},
core::{ResolveVariable, SMTP},
};
use super::{Domain, Status};
@ -36,7 +35,7 @@ use super::{Domain, Status};
#[derive(Debug)]
pub enum Error {
Concurrency { limiter: ConcurrencyLimiter },
Rate { retry_at: Instant },
Rate { retry_at: u64 },
}
impl SMTP {
@ -53,10 +52,33 @@ impl SMTP {
.await
.unwrap_or(false)
{
match self.queue.throttle.entry(throttle.new_key(envelope)) {
let key = throttle.new_key(envelope);
if let Some(rate) = &throttle.rate {
if let Ok(Some(next_refill)) = self
.shared
.default_lookup_store
.is_rate_allowed(key.as_ref(), rate, false)
.await
{
tracing::info!(
parent: span,
context = "throttle",
event = "rate-limit-exceeded",
max_requests = rate.requests,
max_interval = rate.period.as_secs(),
"Queue rate limit exceeded."
);
return Err(Error::Rate {
retry_at: now() + next_refill,
});
}
}
if let Some(concurrency) = &throttle.concurrency {
match self.queue.throttle.entry(key) {
Entry::Occupied(mut e) => {
let limiter = e.get_mut();
if let Some(limiter) = &limiter.concurrency {
if let Some(inflight) = limiter.is_allowed() {
in_flight.push(inflight);
} else {
@ -72,38 +94,13 @@ impl SMTP {
});
}
}
if let (Some(limiter), Some(rate)) = (&mut limiter.rate, &throttle.rate) {
if !limiter.is_allowed(rate) {
tracing::info!(
parent: span,
context = "throttle",
event = "rate-limit-exceeded",
max_requests = rate.requests,
max_interval = rate.period.as_secs(),
"Queue rate limit exceeded."
);
return Err(Error::Rate {
retry_at: Instant::now()
+ Duration::from_secs(limiter.secs_to_refill()),
});
}
}
}
Entry::Vacant(e) => {
let concurrency = throttle.concurrency.map(|concurrency| {
let limiter = ConcurrencyLimiter::new(concurrency);
let limiter = ConcurrencyLimiter::new(*concurrency);
if let Some(inflight) = limiter.is_allowed() {
in_flight.push(inflight);
}
limiter
});
let rate = throttle.rate.as_ref().map(|rate| {
let r = RateLimiter::new(rate);
r.is_allowed(rate);
r
});
e.insert(Limiter { rate, concurrency });
e.insert(limiter);
}
}
}
}
@ -124,6 +121,5 @@ impl Domain {
self.status = Status::TemporaryFailure(super::Error::RateLimited);
}
}
self.changed = true;
}
}

View file

@ -46,7 +46,7 @@ impl<T: AsyncWrite + AsyncRead + Unpin> Session<T> {
};
// Throttle recipient
if !self.throttle_rcpt(rcpt, rate, "dkim") {
if !self.throttle_rcpt(rcpt, rate, "dkim").await {
tracing::debug!(
parent: &self.span,
context = "report",

View file

@ -21,7 +21,7 @@
* for more details.
*/
use std::{collections::hash_map::Entry, path::PathBuf, sync::Arc};
use std::collections::hash_map::Entry;
use ahash::AHashMap;
use mail_auth::{
@ -31,28 +31,22 @@ use mail_auth::{
ArcOutput, AuthenticatedMessage, AuthenticationResults, DkimOutput, DkimResult, DmarcOutput,
SpfResult,
};
use serde::{Deserialize, Serialize};
use tokio::{
io::{AsyncRead, AsyncWrite},
runtime::Handle,
use store::{
write::{now, BatchBuilder, Bincode, QueueClass, ReportEvent, ValueClass},
Deserialize, IterateParams, Serialize, ValueKey,
};
use tokio::io::{AsyncRead, AsyncWrite};
use utils::config::Rate;
use crate::{
config::AggregateFrequency,
core::{Session, SMTP},
queue::{DomainPart, InstantFromTimestamp, RecipientDomain, Schedule},
queue::{DomainPart, RecipientDomain},
};
use super::{
scheduler::{
json_append, json_read_blocking, json_write, ReportPath, ReportPolicy, ReportType,
Scheduler, ToHash,
},
DmarcEvent,
};
use super::{scheduler::ToHash, DmarcEvent, SerializedSize};
#[derive(Debug, PartialEq, Eq, Serialize, Deserialize)]
#[derive(Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub struct DmarcFormat {
pub rua: Vec<URI>,
pub policy: PolicyPublished,
@ -88,16 +82,15 @@ impl<T: AsyncWrite + AsyncRead + Unpin> Session<T> {
{
Some(rcpts) => {
if !rcpts.is_empty() {
rcpts
.into_iter()
.filter_map(|rcpt| {
if self.throttle_rcpt(rcpt.uri(), &failure_rate, "dmarc") {
rcpt.uri().into()
} else {
None
let mut new_rcpts = Vec::with_capacity(rcpts.len());
for rcpt in rcpts {
if self.throttle_rcpt(rcpt.uri(), &failure_rate, "dmarc").await {
new_rcpts.push(rcpt.uri());
}
})
.collect()
}
new_rcpts
} else {
if !dmarc_record.ruf().is_empty() {
tracing::debug!(
@ -306,38 +299,51 @@ impl<T: AsyncWrite + AsyncRead + Unpin> Session<T> {
}
}
pub trait GenerateDmarcReport {
fn generate_dmarc_report(&self, domain: ReportPolicy<String>, path: ReportPath<PathBuf>);
}
impl GenerateDmarcReport for Arc<SMTP> {
fn generate_dmarc_report(&self, domain: ReportPolicy<String>, path: ReportPath<PathBuf>) {
let core = self.clone();
let handle = Handle::current();
self.worker_pool.spawn(move || {
let deliver_at = path.created + path.deliver_at.as_secs();
impl SMTP {
pub async fn generate_dmarc_report(&self, event: ReportEvent) {
let span = tracing::info_span!(
"dmarc-report",
domain = domain.inner,
range_from = path.created,
range_to = deliver_at,
size = path.size,
domain = event.domain,
range_from = event.seq_id,
range_to = event.due,
);
// Deserialize report
let dmarc = if let Some(dmarc) = json_read_blocking::<DmarcFormat>(&path.path, &span) {
dmarc
} else {
let dmarc = match self
.shared
.default_data_store
.get_value::<Bincode<DmarcFormat>>(ValueKey::from(ValueClass::Queue(
QueueClass::DmarcReportHeader(event.clone()),
)))
.await
{
Ok(Some(dmarc)) => dmarc.inner,
Ok(None) => {
tracing::warn!(
parent: &span,
event = "missing",
"Failed to read DMARC report: Report not found"
);
return;
}
Err(err) => {
tracing::warn!(
parent: &span,
event = "error",
"Failed to read DMARC report: {}",
err
);
return;
}
};
// Verify external reporting addresses
let rua = match handle.block_on(
core.resolvers
let rua = match self
.resolvers
.dns
.verify_dmarc_report_address(&domain.inner, &dmarc.rua),
) {
.verify_dmarc_report_address(&event.domain, &dmarc.rua)
.await
{
Some(rcpts) => {
if !rcpts.is_empty() {
rcpts
@ -352,7 +358,7 @@ impl GenerateDmarcReport for Arc<SMTP> {
rua = ?dmarc.rua,
"Unauthorized external reporting addresses"
);
let _ = std::fs::remove_file(&path.path);
self.delete_dmarc_report(event).await;
return;
}
}
@ -364,76 +370,124 @@ impl GenerateDmarcReport for Arc<SMTP> {
rua = ?dmarc.rua,
"Failed to validate external report addresses",
);
let _ = std::fs::remove_file(&path.path);
self.delete_dmarc_report(event).await;
return;
}
};
let config = &core.report.config.dmarc_aggregate;
let mut serialized_size = serde_json::Serializer::new(SerializedSize::new(
self.eval_if(
&self.report.config.dmarc_aggregate.max_size,
&RecipientDomain::new(event.domain.as_str()),
)
.await
.unwrap_or(25 * 1024 * 1024),
));
let _ = serde::Serialize::serialize(&dmarc, &mut serialized_size);
let config = &self.report.config.dmarc_aggregate;
// Group duplicates
let from_key = ValueKey::from(ValueClass::Queue(QueueClass::DmarcReportEvent(
ReportEvent {
due: event.due,
policy_hash: event.policy_hash,
seq_id: 0,
domain: event.domain.clone(),
},
)));
let to_key = ValueKey::from(ValueClass::Queue(QueueClass::DmarcReportEvent(
ReportEvent {
due: event.due,
policy_hash: event.policy_hash,
seq_id: u64::MAX,
domain: event.domain.clone(),
},
)));
let mut record_map = AHashMap::with_capacity(dmarc.records.len());
for record in dmarc.records {
match record_map.entry(record) {
if let Err(err) = self
.shared
.default_data_store
.iterate(
IterateParams::new(from_key, to_key).ascending(),
|_, v| match record_map.entry(Bincode::<Record>::deserialize(v)?.inner) {
Entry::Occupied(mut e) => {
*e.get_mut() += 1;
Ok(true)
}
Entry::Vacant(e) => {
if serde::Serialize::serialize(e.key(), &mut serialized_size).is_ok() {
e.insert(1u32);
Ok(true)
} else {
Ok(false)
}
}
},
)
.await
{
tracing::warn!(
parent: &span,
event = "error",
"Failed to read DMARC report: {}",
err
);
}
// Create report
let mut report = Report::new()
.with_policy_published(dmarc.policy)
.with_date_range_begin(path.created)
.with_date_range_end(deliver_at)
.with_report_id(format!("{}_{}", domain.policy, path.created))
.with_date_range_begin(event.seq_id)
.with_date_range_end(event.due)
.with_report_id(format!("{}_{}", event.policy_hash, event.seq_id))
.with_email(
handle
.block_on(core.eval_if(
self.eval_if(
&config.address,
&RecipientDomain::new(domain.inner.as_str()),
))
&RecipientDomain::new(event.domain.as_str()),
)
.await
.unwrap_or_else(|| "MAILER-DAEMON@localhost".to_string()),
);
if let Some(org_name) = handle.block_on(core.eval_if::<String, _>(
if let Some(org_name) = self
.eval_if::<String, _>(
&config.org_name,
&RecipientDomain::new(domain.inner.as_str()),
)) {
&RecipientDomain::new(event.domain.as_str()),
)
.await
{
report = report.with_org_name(org_name);
}
if let Some(contact_info) = handle.block_on(core.eval_if::<String, _>(
if let Some(contact_info) = self
.eval_if::<String, _>(
&config.contact_info,
&RecipientDomain::new(domain.inner.as_str()),
)) {
&RecipientDomain::new(event.domain.as_str()),
)
.await
{
report = report.with_extra_contact_info(contact_info);
}
for (record, count) in record_map {
report.add_record(record.with_count(count));
}
let from_addr = handle
.block_on(core.eval_if(
let from_addr = self
.eval_if(
&config.address,
&RecipientDomain::new(domain.inner.as_str()),
))
&RecipientDomain::new(event.domain.as_str()),
)
.await
.unwrap_or_else(|| "MAILER-DAEMON@localhost".to_string());
let mut message = Vec::with_capacity(path.size);
let _ =
report.write_rfc5322(
&handle
.block_on(core.eval_if(
&core.report.config.submitter,
&RecipientDomain::new(domain.inner.as_str()),
))
let mut message = Vec::with_capacity(2048);
let _ = report.write_rfc5322(
&self
.eval_if(
&self.report.config.submitter,
&RecipientDomain::new(event.domain.as_str()),
)
.await
.unwrap_or_else(|| "localhost".to_string()),
(
handle
.block_on(core.eval_if(
&config.name,
&RecipientDomain::new(domain.inner.as_str()),
))
self.eval_if(&config.name, &RecipientDomain::new(event.domain.as_str()))
.await
.unwrap_or_else(|| "Mail Delivery Subsystem".to_string())
.as_str(),
from_addr.as_str(),
@ -443,87 +497,109 @@ impl GenerateDmarcReport for Arc<SMTP> {
);
// Send report
handle.block_on(core.send_report(
&from_addr,
rua.iter(),
message,
&config.sign,
&span,
false,
));
self.send_report(&from_addr, rua.iter(), message, &config.sign, &span, false)
.await;
if let Err(err) = std::fs::remove_file(&path.path) {
self.delete_dmarc_report(event).await;
}
pub async fn delete_dmarc_report(&self, event: ReportEvent) {
let from_key = ReportEvent {
due: event.due,
policy_hash: event.policy_hash,
seq_id: 0,
domain: event.domain.clone(),
};
let to_key = ReportEvent {
due: event.due,
policy_hash: event.policy_hash,
seq_id: u64::MAX,
domain: event.domain.clone(),
};
if let Err(err) = self
.shared
.default_data_store
.delete_range(
ValueKey::from(ValueClass::Queue(QueueClass::DmarcReportEvent(from_key))),
ValueKey::from(ValueClass::Queue(QueueClass::DmarcReportEvent(to_key))),
)
.await
{
tracing::warn!(
context = "report",
event = "error",
"Failed to remove report file {}: {}",
path.path.display(),
"Failed to remove repors: {}",
err
);
return;
}
let mut batch = BatchBuilder::new();
batch.clear(ValueClass::Queue(QueueClass::DmarcReportHeader(event)));
if let Err(err) = self.shared.default_data_store.write(batch.build()).await {
tracing::warn!(
context = "report",
event = "error",
"Failed to remove repors: {}",
err
);
}
});
}
}
impl Scheduler {
pub async fn schedule_dmarc(&mut self, event: Box<DmarcEvent>, core: &SMTP) {
let max_size = core
.eval_if(
&core.report.config.dmarc_aggregate.max_size,
&RecipientDomain::new(event.domain.as_str()),
)
.await
.unwrap_or(25 * 1024 * 1024);
let policy = event.dmarc_record.to_hash();
let (create, path) = match self.reports.entry(ReportType::Dmarc(ReportPolicy {
inner: event.domain,
policy,
})) {
Entry::Occupied(e) => (None, e.into_mut().dmarc_path()),
Entry::Vacant(e) => {
let domain = e.key().domain_name().to_string();
pub async fn schedule_dmarc(&self, event: Box<DmarcEvent>) {
let created = event.interval.to_timestamp();
let deliver_at = created + event.interval.as_secs();
self.main.push(Schedule {
due: deliver_at.to_instant(),
inner: e.key().clone(),
});
let path = core
.build_report_path(ReportType::Dmarc(&domain), policy, created, event.interval)
.await;
let v = e.insert(ReportType::Dmarc(ReportPath {
path,
deliver_at: event.interval,
created,
size: 0,
}));
(domain.into(), v.dmarc_path())
}
let mut report_event = ReportEvent {
due: deliver_at,
policy_hash: event.dmarc_record.to_hash(),
seq_id: created,
domain: event.domain,
};
if let Some(domain) = create {
// Write policy if missing
let mut builder = BatchBuilder::new();
if self
.shared
.default_data_store
.get_value::<()>(ValueKey::from(ValueClass::Queue(
QueueClass::DmarcReportHeader(report_event.clone()),
)))
.await
.unwrap_or_default()
.is_none()
{
// Serialize report
let entry = DmarcFormat {
rua: event.dmarc_record.rua().to_vec(),
policy: PolicyPublished::from_record(domain, &event.dmarc_record),
records: vec![event.report_record],
policy: PolicyPublished::from_record(
report_event.domain.to_string(),
&event.dmarc_record,
),
records: vec![],
};
let bytes_written = json_write(&path.path, &entry).await;
if bytes_written > 0 {
path.size += bytes_written;
} else {
// Something went wrong, remove record
self.reports.remove(&ReportType::Dmarc(ReportPolicy {
inner: entry.policy.domain,
policy,
}));
// Write report
builder.set(
ValueClass::Queue(QueueClass::DmarcReportHeader(report_event.clone())),
Bincode::new(entry).serialize(),
);
}
} else if path.size < max_size {
// Append to existing report
path.size += json_append(&path.path, &event.report_record, max_size - path.size).await;
// Write entry
report_event.seq_id = self.queue.snowflake_id.generate().unwrap_or_else(now);
builder.set(
ValueClass::Queue(QueueClass::DmarcReportEvent(report_event)),
Bincode::new(event.report_record).serialize(),
);
if let Err(err) = self.shared.default_data_store.write(builder.build()).await {
tracing::error!(
context = "report",
event = "error",
"Failed to write DMARC report event: {}",
err
);
}
}
}

View file

@ -21,7 +21,7 @@
* for more details.
*/
use std::{sync::Arc, time::SystemTime};
use std::{io, sync::Arc, time::SystemTime};
use mail_auth::{
common::headers::HeaderWriter,
@ -37,15 +37,13 @@ use tokio::io::{AsyncRead, AsyncWrite};
use utils::config::if_block::IfBlock;
use crate::{
config::{AddressMatch, AggregateFrequency, DkimSigner},
core::{management, Session, SMTP},
config::{AddressMatch, AggregateFrequency},
core::{Session, SMTP},
outbound::{dane::Tlsa, mta_sts::Policy},
queue::{DomainPart, Message},
USER_AGENT,
};
use self::scheduler::{ReportKey, ReportValue};
pub mod analysis;
pub mod dkim;
pub mod dmarc;
@ -57,7 +55,6 @@ pub mod tls;
pub enum Event {
Dmarc(Box<DmarcEvent>),
Tls(Box<TlsEvent>),
Manage(management::ReportRequest),
Stop,
}
@ -137,9 +134,11 @@ impl SMTP {
// Build message
let from_addr_lcase = from_addr.to_lowercase();
let from_addr_domain = from_addr_lcase.domain_part().to_string();
let mut message = Message::new_boxed(from_addr, from_addr_lcase, from_addr_domain);
let mut message = self
.queue
.new_message(from_addr, from_addr_lcase, from_addr_domain);
for rcpt_ in rcpts {
message.add_recipient(rcpt_.as_ref(), &self).await;
message.add_recipient(rcpt_.as_ref(), self).await;
}
// Sign message
@ -164,8 +163,8 @@ impl SMTP {
}
// Queue message
self.queue
.queue_message(message, signature.as_deref(), &report, span)
message
.queue(signature.as_deref(), &report, self, span)
.await;
}
@ -300,42 +299,28 @@ impl From<(&Option<Arc<Policy>>, &Option<Arc<Tlsa>>)> for PolicyType {
}
}
impl ReportKey {
pub fn domain(&self) -> &str {
match self {
scheduler::ReportType::Dmarc(p) => &p.inner,
scheduler::ReportType::Tls(d) => d,
}
pub(crate) struct SerializedSize {
bytes_left: usize,
}
impl SerializedSize {
pub fn new(bytes_left: usize) -> Self {
Self { bytes_left }
}
}
impl ReportValue {
pub async fn delete(&self) {
match self {
scheduler::ReportType::Dmarc(path) => {
if let Err(err) = tokio::fs::remove_file(&path.path).await {
tracing::warn!(
context = "report",
event = "error",
"Failed to remove report file {}: {}",
path.path.display(),
err
);
}
}
scheduler::ReportType::Tls(path) => {
for path in &path.path {
if let Err(err) = tokio::fs::remove_file(&path.inner).await {
tracing::warn!(
context = "report",
event = "error",
"Failed to remove report file {}: {}",
path.inner.display(),
err
);
}
}
impl io::Write for SerializedSize {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
let buf_len = buf.len();
if buf_len <= self.bytes_left {
self.bytes_left -= buf_len;
Ok(buf_len)
} else {
Err(io::Error::new(io::ErrorKind::Other, "Size exceeded"))
}
}
fn flush(&mut self) -> io::Result<()> {
Ok(())
}
}

View file

@ -22,156 +22,82 @@
*/
use ahash::{AHashMap, RandomState};
use mail_auth::{
common::{
base32::{Base32Reader, Base32Writer},
headers::Writer,
},
dmarc::Dmarc,
};
use mail_auth::dmarc::Dmarc;
use serde::{de::DeserializeOwned, Serialize};
use std::{
collections::{hash_map::Entry, BinaryHeap},
hash::Hash,
path::PathBuf,
sync::Arc,
time::{Duration, Instant, SystemTime},
};
use tokio::{
fs::{self, OpenOptions},
io::AsyncWriteExt,
sync::mpsc,
use store::{
write::{now, QueueClass, ReportEvent, ValueClass},
Deserialize, IterateParams, ValueKey,
};
use tokio::sync::mpsc;
use crate::{
config::AggregateFrequency,
core::{management::ReportRequest, worker::SpawnCleanup, ReportCore, SMTP},
queue::{InstantFromTimestamp, Schedule},
core::{worker::SpawnCleanup, SMTP},
queue::manager::LONG_WAIT,
};
use super::{dmarc::GenerateDmarcReport, tls::GenerateTlsReport, Event};
pub type ReportKey = ReportType<ReportPolicy<String>, String>;
pub type ReportValue = ReportType<ReportPath<PathBuf>, ReportPath<Vec<ReportPolicy<PathBuf>>>>;
pub struct Scheduler {
short_wait: Duration,
long_wait: Duration,
pub main: BinaryHeap<Schedule<ReportKey>>,
pub reports: AHashMap<ReportKey, ReportValue>,
}
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize)]
pub enum ReportType<T, U> {
Dmarc(T),
Tls(U),
}
#[derive(Debug, PartialEq, Eq)]
pub struct ReportPath<T> {
pub path: T,
pub size: usize,
pub created: u64,
pub deliver_at: AggregateFrequency,
}
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct ReportPolicy<T> {
pub inner: T,
pub policy: u64,
}
use super::Event;
impl SpawnReport for mpsc::Receiver<Event> {
fn spawn(mut self, core: Arc<SMTP>, mut scheduler: Scheduler) {
fn spawn(mut self, core: Arc<SMTP>) {
tokio::spawn(async move {
let mut last_cleanup = Instant::now();
let mut next_wake_up;
loop {
match tokio::time::timeout(scheduler.wake_up_time(), self.recv()).await {
// Read events
let now = now();
let events = core.next_report_event().await;
next_wake_up = events
.last()
.and_then(|e| match e {
QueueClass::DmarcReportHeader(e) | QueueClass::TlsReportHeader(e)
if e.due > now =>
{
Duration::from_secs(e.due - now).into()
}
_ => None,
})
.unwrap_or(LONG_WAIT);
let core_ = core.clone();
tokio::spawn(async move {
let mut tls_reports = AHashMap::new();
for report_event in events {
match report_event {
QueueClass::DmarcReportHeader(event) if event.due <= now => {
core_.generate_dmarc_report(event).await;
}
QueueClass::TlsReportHeader(event) if event.due <= now => {
tls_reports
.entry(event.domain.clone())
.or_insert_with(Vec::new)
.push(event);
}
_ => (),
}
}
for (domain_name, tls_report) in tls_reports {
core_.generate_tls_report(domain_name, tls_report).await;
}
});
match tokio::time::timeout(next_wake_up, self.recv()).await {
Ok(Some(event)) => match event {
Event::Dmarc(event) => {
scheduler.schedule_dmarc(event, &core).await;
core.schedule_dmarc(event).await;
}
Event::Tls(event) => {
scheduler.schedule_tls(event, &core).await;
core.schedule_tls(event).await;
}
Event::Manage(request) => match request {
ReportRequest::List {
type_,
domain,
result_tx,
} => {
let mut result = Vec::new();
for key in scheduler.reports.keys() {
if domain
.as_ref()
.map_or(false, |domain| domain != key.domain())
{
continue;
}
if let Some(type_) = &type_ {
if !matches!(
(key, type_),
(ReportType::Dmarc(_), ReportType::Dmarc(_))
| (ReportType::Tls(_), ReportType::Tls(_))
) {
continue;
}
}
result.push(key.to_string());
}
let _ = result_tx.send(result);
}
ReportRequest::Status {
report_ids,
result_tx,
} => {
let mut result = Vec::with_capacity(report_ids.len());
for report_id in &report_ids {
result.push(
scheduler
.reports
.get(report_id)
.map(|report_value| (report_id, report_value).into()),
);
}
let _ = result_tx.send(result);
}
ReportRequest::Cancel {
report_ids,
result_tx,
} => {
let mut result = Vec::with_capacity(report_ids.len());
for report_id in &report_ids {
result.push(
if let Some(report) = scheduler.reports.remove(report_id) {
report.delete().await;
true
} else {
false
},
);
}
let _ = result_tx.send(result);
}
},
Event::Stop => break,
},
Ok(None) => break,
Err(_) => {
while let Some(report) = scheduler.next_due() {
match report {
(ReportType::Dmarc(domain), ReportType::Dmarc(path)) => {
core.generate_dmarc_report(domain, path);
}
(ReportType::Tls(domain), ReportType::Tls(path)) => {
core.generate_tls_report(domain, path);
}
_ => unreachable!(),
}
}
// Cleanup expired throttles
if last_cleanup.elapsed().as_secs() >= 86400 {
last_cleanup = Instant::now();
@ -185,429 +111,54 @@ impl SpawnReport for mpsc::Receiver<Event> {
}
impl SMTP {
pub async fn build_report_path(
&self,
domain: ReportType<&str, &str>,
policy: u64,
created: u64,
interval: AggregateFrequency,
) -> PathBuf {
let (ext, domain) = match domain {
ReportType::Dmarc(domain) => ("d", domain),
ReportType::Tls(domain) => ("t", domain),
};
// Build base path
let mut path = self.report.config.path.clone();
let todo = "fix";
let hash = 1;
if hash > 0 {
path.push((policy % hash).to_string());
}
let _ = fs::create_dir(&path).await;
// Build filename
let mut w = Base32Writer::with_capacity(domain.len() + 13);
w.write(&policy.to_le_bytes()[..]);
w.write(&(created.saturating_sub(946684800) as u32).to_le_bytes()[..]);
w.push_byte(
match interval {
AggregateFrequency::Hourly => 0,
AggregateFrequency::Daily => 1,
AggregateFrequency::Weekly => 2,
AggregateFrequency::Never => 3,
pub async fn next_report_event(&self) -> Vec<QueueClass> {
let from_key = ValueKey::from(ValueClass::Queue(QueueClass::DmarcReportHeader(
ReportEvent {
due: 0,
policy_hash: 0,
seq_id: 0,
domain: String::new(),
},
false,
);
w.write(domain.as_bytes());
let mut file = w.finalize();
file.push('.');
file.push_str(ext);
path.push(file);
path
}
}
)));
let to_key = ValueKey::from(ValueClass::Queue(QueueClass::TlsReportHeader(
ReportEvent {
due: u64::MAX,
policy_hash: 0,
seq_id: 0,
domain: String::new(),
},
)));
impl ReportCore {
pub async fn read_reports(&self) -> Scheduler {
let mut scheduler = Scheduler::default();
let mut dir = match tokio::fs::read_dir(&self.config.path).await {
Ok(dir) => dir,
Err(_) => {
return scheduler;
}
};
loop {
match dir.next_entry().await {
Ok(Some(file)) => {
let file = file.path();
if file.is_dir() {
match tokio::fs::read_dir(&file).await {
Ok(mut dir) => {
let file_ = file;
loop {
match dir.next_entry().await {
Ok(Some(file)) => {
let file = file.path();
if file
.extension()
.map_or(false, |e| e == "t" || e == "d")
{
if let Err(err) = scheduler.add_path(file).await {
tracing::warn!("{}", err);
}
}
}
Ok(None) => break,
Err(err) => {
tracing::warn!(
"Failed to read report directory {}: {}",
file_.display(),
err
);
break;
}
}
}
}
Err(err) => {
tracing::warn!(
"Failed to read report directory {}: {}",
file.display(),
err
)
}
};
} else if file.extension().map_or(false, |e| e == "t" || e == "d") {
if let Err(err) = scheduler.add_path(file).await {
tracing::warn!("{}", err);
}
}
}
Ok(None) => {
break;
}
Err(err) => {
tracing::warn!(
"Failed to read report directory {}: {}",
self.config.path.display(),
err
);
break;
}
}
}
scheduler
}
}
impl Scheduler {
pub fn next_due(&mut self) -> Option<(ReportKey, ReportValue)> {
let item = self.main.peek()?;
if item.due <= Instant::now() {
let item = self.main.pop().unwrap();
self.reports
.remove(&item.inner)
.map(|policy| (item.inner, policy))
let mut events = Vec::new();
let now = now();
let result = self
.shared
.default_data_store
.iterate(
IterateParams::new(from_key, to_key).ascending().no_values(),
|key, _| {
let event = ReportEvent::deserialize(key)?;
let do_continue = event.due <= now;
events.push(if *key.last().unwrap() == 0 {
QueueClass::DmarcReportHeader(event)
} else {
None
}
}
pub fn wake_up_time(&self) -> Duration {
self.main
.peek()
.map(|item| {
item.due
.checked_duration_since(Instant::now())
.unwrap_or(self.short_wait)
})
.unwrap_or(self.long_wait)
}
pub async fn add_path(&mut self, path: PathBuf) -> Result<(), String> {
let (file, ext) = path
.file_name()
.and_then(|f| f.to_str())
.and_then(|f| f.rsplit_once('.'))
.ok_or_else(|| format!("Invalid queue file name {}", path.display()))?;
let file_size = fs::metadata(&path)
.await
.map_err(|err| {
format!(
"Failed to obtain file metadata for {}: {}",
path.display(),
err
QueueClass::TlsReportHeader(event)
});
Ok(do_continue)
},
)
})?
.len();
if file_size == 0 {
let _ = fs::remove_file(&path).await;
return Err(format!(
"Removed zero length report file {}",
path.display()
));
}
.await;
// Decode domain name
let mut policy = [0u8; std::mem::size_of::<u64>()];
let mut created = [0u8; std::mem::size_of::<u32>()];
let mut deliver_at = AggregateFrequency::Never;
let mut domain = Vec::new();
for (pos, byte) in Base32Reader::new(file.as_bytes()).enumerate() {
match pos {
0..=7 => {
policy[pos] = byte;
}
8..=11 => {
created[pos - 8] = byte;
}
12 => {
deliver_at = match byte {
0 => AggregateFrequency::Hourly,
1 => AggregateFrequency::Daily,
2 => AggregateFrequency::Weekly,
_ => {
return Err(format!(
"Failed to base32 decode report file {}",
path.display()
));
}
};
}
_ => {
domain.push(byte);
}
}
}
if domain.is_empty() {
return Err(format!(
"Failed to base32 decode report file {}",
path.display()
));
}
let domain = String::from_utf8(domain).map_err(|err| {
format!(
"Failed to base32 decode report file {}: {}",
path.display(),
err
)
})?;
// Rebuild parts
let policy = u64::from_le_bytes(policy);
let created = u32::from_le_bytes(created) as u64 + 946684800;
match ext {
"d" => {
let key = ReportType::Dmarc(ReportPolicy {
inner: domain,
policy,
});
self.reports.insert(
key.clone(),
ReportType::Dmarc(ReportPath {
path,
size: file_size as usize,
created,
deliver_at,
}),
);
self.main.push(Schedule {
due: (created + deliver_at.as_secs()).to_instant(),
inner: key,
});
}
"t" => match self.reports.entry(ReportType::Tls(domain)) {
Entry::Occupied(mut e) => {
if let ReportType::Tls(tls) = e.get_mut() {
tls.size += file_size as usize;
tls.path.push(ReportPolicy {
inner: path,
policy,
});
}
}
Entry::Vacant(e) => {
self.main.push(Schedule {
due: (created + deliver_at.as_secs()).to_instant(),
inner: e.key().clone(),
});
e.insert(ReportType::Tls(ReportPath {
path: vec![ReportPolicy {
inner: path,
policy,
}],
size: file_size as usize,
created,
deliver_at,
}));
}
},
_ => unreachable!(),
}
Ok(())
}
}
pub async fn json_write(path: &PathBuf, entry: &impl Serialize) -> usize {
if let Ok(bytes) = serde_json::to_vec(entry) {
// Save serialized report
let bytes_written = bytes.len() - 2;
match fs::File::create(&path).await {
Ok(mut file) => match file.write_all(&bytes[..bytes_written]).await {
Ok(_) => bytes_written,
Err(err) => {
if let Err(err) = result {
tracing::error!(
context = "report",
context = "queue",
event = "error",
"Failed to write to report file {}: {}",
path.display(),
err
);
0
}
},
Err(err) => {
tracing::error!(
context = "report",
event = "error",
"Failed to create report file {}: {}",
path.display(),
err
);
0
}
}
} else {
0
}
}
pub async fn json_append(path: &PathBuf, entry: &impl Serialize, bytes_left: usize) -> usize {
let mut bytes = Vec::with_capacity(128);
bytes.push(b',');
if serde_json::to_writer(&mut bytes, entry).is_ok() && bytes.len() <= bytes_left {
let err = match OpenOptions::new().append(true).open(&path).await {
Ok(mut file) => match file.write_all(&bytes).await {
Ok(_) => return bytes.len(),
Err(err) => err,
},
Err(err) => err,
};
tracing::error!(
context = "report",
event = "error",
"Failed to append report to {}: {}",
path.display(),
"Failed to read from store: {}",
err
);
}
0
}
pub async fn json_read<T: DeserializeOwned>(path: &PathBuf, span: &tracing::Span) -> Option<T> {
match fs::read_to_string(&path).await {
Ok(mut json) => {
json.push_str("]}");
match serde_json::from_str(&json) {
Ok(report) => Some(report),
Err(err) => {
tracing::error!(
parent: span,
context = "deserialize",
event = "error",
"Failed to deserialize report file {}: {}",
path.display(),
err
);
None
}
}
}
Err(err) => {
tracing::error!(
parent: span,
context = "io",
event = "error",
"Failed to read report file {}: {}",
path.display(),
err
);
None
}
}
}
pub fn json_read_blocking<T: DeserializeOwned>(path: &PathBuf, span: &tracing::Span) -> Option<T> {
match std::fs::read_to_string(path) {
Ok(mut json) => {
json.push_str("]}");
match serde_json::from_str(&json) {
Ok(report) => Some(report),
Err(err) => {
tracing::error!(
parent: span,
context = "deserialize",
event = "error",
"Failed to deserialize report file {}: {}",
path.display(),
err
);
None
}
}
}
Err(err) => {
tracing::error!(
parent: span,
context = "io",
event = "error",
"Failed to read report file {}: {}",
path.display(),
err
);
None
}
}
}
impl Default for Scheduler {
fn default() -> Self {
Self {
short_wait: Duration::from_millis(1),
long_wait: Duration::from_secs(86400 * 365),
main: BinaryHeap::with_capacity(128),
reports: AHashMap::with_capacity(128),
}
}
}
impl ReportKey {
pub fn domain_name(&self) -> &str {
match self {
ReportType::Dmarc(domain) => domain.inner.as_str(),
ReportType::Tls(domain) => domain.as_str(),
}
}
}
impl ReportValue {
pub fn dmarc_path(&mut self) -> &mut ReportPath<PathBuf> {
match self {
ReportType::Dmarc(path) => path,
ReportType::Tls(_) => unreachable!(),
}
}
pub fn tls_path(&mut self) -> &mut ReportPath<Vec<ReportPolicy<PathBuf>>> {
match self {
ReportType::Tls(path) => path,
ReportType::Dmarc(_) => unreachable!(),
}
events
}
}
@ -641,5 +192,5 @@ impl ToTimestamp for Duration {
}
pub trait SpawnReport {
fn spawn(self, core: Arc<SMTP>, scheduler: Scheduler);
fn spawn(self, core: Arc<SMTP>);
}

View file

@ -36,7 +36,7 @@ impl<T: AsyncWrite + AsyncRead + Unpin> Session<T> {
output: &SpfOutput,
) {
// Throttle recipient
if !self.throttle_rcpt(rcpt, rate, "spf") {
if !self.throttle_rcpt(rcpt, rate, "spf").await {
tracing::debug!(
parent: &self.span,
context = "report",

View file

@ -21,7 +21,7 @@
* for more details.
*/
use std::{collections::hash_map::Entry, path::PathBuf, sync::Arc, time::Duration};
use std::{collections::hash_map::Entry, sync::Arc, time::Duration};
use ahash::AHashMap;
use mail_auth::{
@ -34,25 +34,21 @@ use mail_auth::{
use mail_parser::DateTime;
use reqwest::header::CONTENT_TYPE;
use serde::{Deserialize, Serialize};
use std::fmt::Write;
use tokio::runtime::Handle;
use store::{
write::{now, BatchBuilder, Bincode, QueueClass, ReportEvent, ValueClass},
Deserialize, IterateParams, Serialize, ValueKey,
};
use crate::{
config::AggregateFrequency,
core::SMTP,
outbound::mta_sts::{Mode, MxPattern},
queue::{InstantFromTimestamp, RecipientDomain, Schedule},
queue::RecipientDomain,
USER_AGENT,
};
use super::{
scheduler::{
json_append, json_read_blocking, json_write, ReportPath, ReportPolicy, ReportType,
Scheduler, ToHash,
},
TlsEvent,
};
use super::{scheduler::ToHash, SerializedSize, TlsEvent};
#[derive(Debug, Clone)]
pub struct TlsRptOptions {
@ -60,81 +56,155 @@ pub struct TlsRptOptions {
pub interval: AggregateFrequency,
}
#[derive(Debug, Serialize, Deserialize)]
#[derive(Debug, serde::Serialize, serde::Deserialize)]
struct TlsFormat {
rua: Vec<ReportUri>,
policy: PolicyDetails,
records: Vec<Option<FailureDetails>>,
}
pub trait GenerateTlsReport {
fn generate_tls_report(&self, domain: String, paths: ReportPath<Vec<ReportPolicy<PathBuf>>>);
}
#[cfg(feature = "test_mode")]
pub static TLS_HTTP_REPORT: parking_lot::Mutex<Vec<u8>> = parking_lot::Mutex::new(Vec::new());
impl GenerateTlsReport for Arc<SMTP> {
fn generate_tls_report(&self, domain: String, path: ReportPath<Vec<ReportPolicy<PathBuf>>>) {
let core = self.clone();
let handle = Handle::current();
impl SMTP {
pub async fn generate_tls_report(&self, domain_name: String, events: Vec<ReportEvent>) {
let (event_from, event_to, policy) = events
.first()
.map(|e| (e.seq_id, e.due, e.policy_hash))
.unwrap();
self.worker_pool.spawn(move || {
let deliver_at = path.created + path.deliver_at.as_secs();
let span = tracing::info_span!(
"tls-report",
domain = domain,
range_from = path.created,
range_to = deliver_at,
size = path.size,
domain = domain_name,
range_from = event_from,
range_to = event_to,
);
// Deserialize report
let config = &core.report.config.tls;
let config = &self.report.config.tls;
let mut report = TlsReport {
organization_name: handle
.block_on(
core.eval_if(&config.org_name, &RecipientDomain::new(domain.as_str())),
organization_name: self
.eval_if(
&config.org_name,
&RecipientDomain::new(domain_name.as_str()),
)
.await
.clone(),
date_range: DateRange {
start_datetime: DateTime::from_timestamp(path.created as i64),
end_datetime: DateTime::from_timestamp(deliver_at as i64),
start_datetime: DateTime::from_timestamp(event_from as i64),
end_datetime: DateTime::from_timestamp(event_to as i64),
},
contact_info: handle
.block_on(
core.eval_if(&config.contact_info, &RecipientDomain::new(domain.as_str())),
contact_info: self
.eval_if(
&config.contact_info,
&RecipientDomain::new(domain_name.as_str()),
)
.await
.clone(),
report_id: format!(
"{}_{}",
path.created,
path.path.first().map_or(0, |p| p.policy)
),
policies: Vec::with_capacity(path.path.len()),
report_id: format!("{}_{}", event_from, policy),
policies: Vec::with_capacity(events.len()),
};
let mut rua = Vec::new();
for path in &path.path {
if let Some(tls) = json_read_blocking::<TlsFormat>(&path.inner, &span) {
let mut serialized_size = serde_json::Serializer::new(SerializedSize::new(
self.eval_if(
&self.report.config.tls.max_size,
&RecipientDomain::new(domain_name.as_str()),
)
.await
.unwrap_or(25 * 1024 * 1024),
));
let _ = serde::Serialize::serialize(&report, &mut serialized_size);
for event in &events {
// Deserialize report
let tls = match self
.shared
.default_data_store
.get_value::<Bincode<TlsFormat>>(ValueKey::from(ValueClass::Queue(
QueueClass::TlsReportHeader(event.clone()),
)))
.await
{
Ok(Some(dmarc)) => dmarc.inner,
Ok(None) => {
tracing::warn!(
parent: &span,
event = "missing",
"Failed to read DMARC report: Report not found"
);
continue;
}
Err(err) => {
tracing::warn!(
parent: &span,
event = "error",
"Failed to read DMARC report: {}",
err
);
continue;
}
};
let _ = serde::Serialize::serialize(&tls, &mut serialized_size);
// Group duplicates
let mut total_success = 0;
let mut total_failure = 0;
let from_key =
ValueKey::from(ValueClass::Queue(QueueClass::TlsReportEvent(ReportEvent {
due: event.due,
policy_hash: event.policy_hash,
seq_id: 0,
domain: event.domain.clone(),
})));
let to_key =
ValueKey::from(ValueClass::Queue(QueueClass::TlsReportEvent(ReportEvent {
due: event.due,
policy_hash: event.policy_hash,
seq_id: u64::MAX,
domain: event.domain.clone(),
})));
let mut record_map = AHashMap::with_capacity(tls.records.len());
for record in tls.records {
if let Some(record) = record {
match record_map.entry(record) {
if let Err(err) = self
.shared
.default_data_store
.iterate(IterateParams::new(from_key, to_key).ascending(), |_, v| {
if let Some(failure_details) =
Bincode::<Option<FailureDetails>>::deserialize(v)?.inner
{
total_failure += 1;
match record_map.entry(failure_details) {
Entry::Occupied(mut e) => {
*e.get_mut() += 1;
Ok(true)
}
Entry::Vacant(e) => {
if serde::Serialize::serialize(e.key(), &mut serialized_size)
.is_ok()
{
e.insert(1u32);
Ok(true)
} else {
Ok(false)
}
}
}
total_failure += 1;
} else {
total_success += 1;
Ok(true)
}
})
.await
{
tracing::warn!(
parent: &span,
event = "error",
"Failed to read TLS report: {}",
err
);
}
report.policies.push(Policy {
policy: tls.policy,
summary: Summary {
@ -152,7 +222,6 @@ impl GenerateTlsReport for Arc<SMTP> {
rua = tls.rua;
}
}
if report.policies.is_empty() {
// This should not happen
@ -161,15 +230,15 @@ impl GenerateTlsReport for Arc<SMTP> {
event = "empty-report",
"No policies found in report"
);
path.cleanup_blocking();
self.delete_tls_report(events).await;
return;
}
// Compress and serialize report
let json = report.to_json();
let mut e = GzEncoder::new(Vec::with_capacity(json.len()), Compression::default());
let json =
match std::io::Write::write_all(&mut e, json.as_bytes()).and_then(|_| e.finish()) {
let json = match std::io::Write::write_all(&mut e, json.as_bytes()).and_then(|_| e.finish())
{
Ok(report) => report,
Err(err) => {
tracing::error!(
@ -178,6 +247,7 @@ impl GenerateTlsReport for Arc<SMTP> {
"Failed to compress report: {}",
err
);
self.delete_tls_report(events).await;
return;
}
};
@ -195,7 +265,7 @@ impl GenerateTlsReport for Arc<SMTP> {
#[cfg(feature = "test_mode")]
if uri == "https://127.0.0.1/tls" {
TLS_HTTP_REPORT.lock().extend_from_slice(&json);
path.cleanup_blocking();
self.delete_tls_report(events).await;
return;
}
@ -213,7 +283,7 @@ impl GenerateTlsReport for Arc<SMTP> {
event = "success",
url = uri,
);
path.cleanup_blocking();
self.delete_tls_report(events).await;
return;
} else {
tracing::debug!(
@ -245,23 +315,23 @@ impl GenerateTlsReport for Arc<SMTP> {
// Deliver report over SMTP
if !rcpts.is_empty() {
let from_addr = handle
.block_on(core.eval_if(&config.address, &RecipientDomain::new(domain.as_str())))
let from_addr = self
.eval_if(&config.address, &RecipientDomain::new(domain_name.as_str()))
.await
.unwrap_or_else(|| "MAILER-DAEMON@localhost".to_string());
let mut message = Vec::with_capacity(path.size);
let mut message = Vec::with_capacity(2048);
let _ = report.write_rfc5322_from_bytes(
&domain,
&handle
.block_on(core.eval_if(
&core.report.config.submitter,
&RecipientDomain::new(domain.as_str()),
))
&domain_name,
&self
.eval_if(
&self.report.config.submitter,
&RecipientDomain::new(domain_name.as_str()),
)
.await
.unwrap_or_else(|| "localhost".to_string()),
(
handle
.block_on(
core.eval_if(&config.name, &RecipientDomain::new(domain.as_str())),
)
self.eval_if(&config.name, &RecipientDomain::new(domain_name.as_str()))
.await
.unwrap_or_else(|| "Mail Delivery Subsystem".to_string())
.as_str(),
from_addr.as_str(),
@ -272,14 +342,15 @@ impl GenerateTlsReport for Arc<SMTP> {
);
// Send report
handle.block_on(core.send_report(
self.send_report(
&from_addr,
rcpts.iter(),
message,
&config.sign,
&span,
false,
));
)
.await;
} else {
tracing::info!(
parent: &span,
@ -287,83 +358,36 @@ impl GenerateTlsReport for Arc<SMTP> {
"No valid recipients found to deliver report to."
);
}
path.cleanup_blocking();
});
self.delete_tls_report(events).await;
}
}
impl Scheduler {
pub async fn schedule_tls(&mut self, event: Box<TlsEvent>, core: &SMTP) {
let max_size = core
.eval_if(
&core.report.config.tls.max_size,
&RecipientDomain::new(event.domain.as_str()),
)
.await
.unwrap_or(25 * 1024 * 1024);
let policy_hash = event.policy.to_hash();
let (path, pos, create) = match self.reports.entry(ReportType::Tls(event.domain)) {
Entry::Occupied(e) => {
if let ReportType::Tls(path) = e.get() {
if let Some(pos) = path.path.iter().position(|p| p.policy == policy_hash) {
(e.into_mut().tls_path(), pos, None)
} else {
let pos = path.path.len();
let domain = e.key().domain_name().to_string();
let path = e.into_mut().tls_path();
path.path.push(ReportPolicy {
inner: core
.build_report_path(
ReportType::Tls(&domain),
policy_hash,
path.created,
path.deliver_at,
)
.await,
policy: policy_hash,
});
(path, pos, domain.into())
}
} else {
unreachable!()
}
}
Entry::Vacant(e) => {
pub async fn schedule_tls(&self, event: Box<TlsEvent>) {
let created = event.interval.to_timestamp();
let deliver_at = created + event.interval.as_secs();
self.main.push(Schedule {
due: deliver_at.to_instant(),
inner: e.key().clone(),
});
let domain = e.key().domain_name().to_string();
let path = core
.build_report_path(
ReportType::Tls(&domain),
policy_hash,
created,
event.interval,
)
.await;
let v = e.insert(ReportType::Tls(ReportPath {
path: vec![ReportPolicy {
inner: path,
policy: policy_hash,
}],
size: 0,
created,
deliver_at: event.interval,
}));
(v.tls_path(), 0, domain.into())
}
let mut report_event = ReportEvent {
due: deliver_at,
policy_hash: event.policy.to_hash(),
seq_id: created,
domain: event.domain,
};
if let Some(domain) = create {
// Write policy if missing
let mut builder = BatchBuilder::new();
if self
.shared
.default_data_store
.get_value::<()>(ValueKey::from(ValueClass::Queue(
QueueClass::TlsReportHeader(report_event.clone()),
)))
.await
.unwrap_or_default()
.is_none()
{
// Serialize report
let mut policy = PolicyDetails {
policy_type: PolicyType::NoPolicyFound,
policy_string: vec![],
policy_domain: domain,
policy_domain: report_event.domain.clone(),
mx_host: vec![],
};
@ -420,47 +444,78 @@ impl Scheduler {
let entry = TlsFormat {
rua: event.tls_record.rua.clone(),
policy,
records: vec![event.failure],
records: vec![],
};
let bytes_written = json_write(&path.path[pos].inner, &entry).await;
if bytes_written > 0 {
path.size += bytes_written;
} else {
// Something went wrong, remove record
if let Entry::Occupied(mut e) = self
.reports
.entry(ReportType::Tls(entry.policy.policy_domain))
{
if let ReportType::Tls(path) = e.get_mut() {
path.path.retain(|p| p.policy != policy_hash);
if path.path.is_empty() {
e.remove_entry();
// Write report
builder.set(
ValueClass::Queue(QueueClass::TlsReportHeader(report_event.clone())),
Bincode::new(entry).serialize(),
);
}
}
}
}
} else if path.size < max_size {
// Append to existing report
path.size +=
json_append(&path.path[pos].inner, &event.failure, max_size - path.size).await;
}
}
}
impl ReportPath<Vec<ReportPolicy<PathBuf>>> {
fn cleanup_blocking(&self) {
for path in &self.path {
if let Err(err) = std::fs::remove_file(&path.inner) {
// Write entry
report_event.seq_id = self.queue.snowflake_id.generate().unwrap_or_else(now);
builder.set(
ValueClass::Queue(QueueClass::TlsReportEvent(report_event)),
Bincode::new(event.failure).serialize(),
);
if let Err(err) = self.shared.default_data_store.write(builder.build()).await {
tracing::error!(
context = "report",
report = "tls",
event = "error",
"Failed to delete file {}: {}",
path.inner.display(),
"Failed to write DMARC report event: {}",
err
);
}
}
pub async fn delete_tls_report(&self, events: Vec<ReportEvent>) {
let mut batch = BatchBuilder::new();
for event in events {
let from_key = ReportEvent {
due: event.due,
policy_hash: event.policy_hash,
seq_id: 0,
domain: event.domain.clone(),
};
let to_key = ReportEvent {
due: event.due,
policy_hash: event.policy_hash,
seq_id: u64::MAX,
domain: event.domain.clone(),
};
if let Err(err) = self
.shared
.default_data_store
.delete_range(
ValueKey::from(ValueClass::Queue(QueueClass::TlsReportEvent(from_key))),
ValueKey::from(ValueClass::Queue(QueueClass::TlsReportEvent(to_key))),
)
.await
{
tracing::warn!(
context = "report",
event = "error",
"Failed to remove repors: {}",
err
);
return;
}
batch.clear(ValueClass::Queue(QueueClass::TlsReportHeader(event)));
}
if let Err(err) = self.shared.default_data_store.write(batch.build()).await {
tracing::warn!(
context = "report",
event = "error",
"Failed to remove repors: {}",
err
);
}
}
}

View file

@ -21,7 +21,7 @@
* for more details.
*/
use std::{sync::Arc, time::Duration};
use std::sync::Arc;
use mail_auth::common::headers::HeaderWriter;
use sieve::{
@ -32,18 +32,12 @@ use smtp_proto::{
MAIL_BY_TRACE, MAIL_RET_FULL, MAIL_RET_HDRS, RCPT_NOTIFY_DELAY, RCPT_NOTIFY_FAILURE,
RCPT_NOTIFY_NEVER, RCPT_NOTIFY_SUCCESS,
};
use store::{backend::memory::MemoryStore, LookupKey, LookupStore, LookupValue};
use store::{backend::memory::MemoryStore, LookupStore};
use tokio::runtime::Handle;
use crate::{
core::SMTP,
queue::{DomainPart, InstantFromTimestamp, Message},
};
use crate::{core::SMTP, queue::DomainPart};
use super::{
plugins::{lookup::VariableExists, PluginContext},
ScriptModification, ScriptParameters, ScriptResult,
};
use super::{plugins::PluginContext, ScriptModification, ScriptParameters, ScriptResult};
impl SMTP {
pub fn run_script_blocking(
@ -97,15 +91,15 @@ impl SMTP {
'outer: for list in lists {
if let Some(store) = self.shared.lookup_stores.get(&list) {
for value in &values {
if let Ok(LookupValue::Value { .. }) = handle.block_on(
store.key_get::<VariableExists>(LookupKey::Key(
if let Ok(true) = handle.block_on(
store.key_exists(
if !matches!(match_as, MatchAs::Lowercase) {
value.clone()
} else {
value.to_lowercase()
}
.into_bytes(),
)),
),
) {
input = true.into();
break 'outer;
@ -156,7 +150,7 @@ impl SMTP {
// Build message
let return_path_lcase = self.sieve.return_path.to_lowercase();
let return_path_domain = return_path_lcase.domain_part().to_string();
let mut message = Message::new_boxed(
let mut message = self.queue.new_message(
self.sieve.return_path.clone(),
return_path_lcase,
return_path_domain,
@ -223,7 +217,6 @@ impl SMTP {
if trace {
message.flags |= MAIL_BY_TRACE;
}
let rlimit = Duration::from_secs(rlimit);
match mode {
ByMode::Notify => {
for domain in &mut message.domains {
@ -246,16 +239,15 @@ impl SMTP {
if trace {
message.flags |= MAIL_BY_TRACE;
}
let alimit = (alimit as u64).to_instant();
match mode {
ByMode::Notify => {
for domain in &mut message.domains {
domain.notify.due = alimit;
domain.notify.due = alimit as u64;
}
}
ByMode::Return => {
for domain in &mut message.domains {
domain.expires = alimit;
domain.expires = alimit as u64;
}
}
ByMode::Default => (),
@ -302,10 +294,10 @@ impl SMTP {
None
};
handle.block_on(self.queue.queue_message(
message,
handle.block_on(message.queue(
headers.as_deref(),
raw_message,
self,
&span,
));
}

View file

@ -29,12 +29,12 @@ use nlp::{
tokenizers::osb::{OsbToken, OsbTokenizer},
};
use sieve::{runtime::Variable, FunctionMap};
use store::{write::key::KeySerializer, LookupKey, LookupStore, LookupValue, U64_LEN};
use store::{write::key::KeySerializer, LookupStore, U64_LEN};
use tokio::runtime::Handle;
use crate::config::scripts::SieveContext;
use super::{lookup::VariableExists, PluginContext};
use super::PluginContext;
pub fn register_train(plugin_id: u32, fnc_map: &mut FunctionMap<SieveContext>) {
fnc_map.set_external_function("bayes_train", plugin_id, 3);
@ -110,14 +110,13 @@ fn train(ctx: PluginContext<'_>, is_train: bool) -> Variable {
for (hash, weights) in model.weights {
if handle
.block_on(
store.key_set(
store.counter_incr(
KeySerializer::new(U64_LEN)
.write(hash.h1)
.write(hash.h2)
.finalize(),
LookupValue::Counter {
num: weights.into(),
},
weights.into(),
None,
),
)
.is_err()
@ -135,14 +134,13 @@ fn train(ctx: PluginContext<'_>, is_train: bool) -> Variable {
};
if handle
.block_on(
store.key_set(
store.counter_incr(
KeySerializer::new(U64_LEN)
.write(0u64)
.write(0u64)
.finalize(),
LookupValue::Counter {
num: weights.into(),
},
weights.into(),
None,
),
)
.is_err()
@ -337,15 +335,15 @@ impl LookupOrInsert for BayesTokenCache {
) -> Option<Weights> {
if let Some(weights) = self.get(&hash) {
weights.unwrap_or_default().into()
} else if let Ok(result) = handle.block_on(
get_token.key_get::<VariableExists>(LookupKey::Counter(
} else if let Ok(num) = handle.block_on(
get_token.counter_get(
KeySerializer::new(U64_LEN)
.write(hash.h1)
.write(hash.h2)
.finalize(),
)),
),
) {
if let LookupValue::Counter { num } = result {
if num != 0 {
let weights = Weights::from(num);
self.insert_positive(hash, weights);
weights

View file

@ -29,7 +29,7 @@ use std::{
use mail_auth::flate2;
use sieve::{runtime::Variable, FunctionMap};
use store::{Deserialize, LookupKey, LookupValue, Value};
use store::{Deserialize, Value};
use crate::{
config::scripts::{RemoteList, SieveContext},
@ -72,10 +72,7 @@ pub fn exec(ctx: PluginContext<'_>) -> Variable {
if !item.is_empty()
&& ctx
.handle
.block_on(store.key_get::<VariableExists>(LookupKey::Key(
item.to_string().into_owned().into_bytes(),
)))
.map(|v| v != LookupValue::None)
.block_on(store.key_exists(item.to_string().into_owned().into_bytes()))
.unwrap_or(false)
{
return true.into();
@ -85,10 +82,7 @@ pub fn exec(ctx: PluginContext<'_>) -> Variable {
}
v if !v.is_empty() => ctx
.handle
.block_on(store.key_get::<VariableExists>(LookupKey::Key(
v.to_string().into_owned().into_bytes(),
)))
.map(|v| v != LookupValue::None)
.block_on(store.key_exists(v.to_string().into_owned().into_bytes()))
.unwrap_or(false),
_ => false,
}
@ -113,14 +107,13 @@ pub fn exec_get(ctx: PluginContext<'_>) -> Variable {
if let Some(store) = store {
ctx.handle
.block_on(store.key_get::<VariableWrapper>(LookupKey::Key(
.block_on(
store.key_get::<VariableWrapper>(
ctx.arguments[1].to_string().into_owned().into_bytes(),
)))
.map(|v| match v {
LookupValue::Value { value, .. } => value.into_inner(),
LookupValue::Counter { num } => num.into(),
LookupValue::None => Variable::default(),
})
),
)
.unwrap_or_default()
.map(|v| v.into_inner())
.unwrap_or_default()
} else {
tracing::warn!(
@ -142,22 +135,20 @@ pub fn exec_set(ctx: PluginContext<'_>) -> Variable {
if let Some(store) = store {
let expires = match &ctx.arguments[3] {
Variable::Integer(v) => *v as u64,
Variable::Float(v) => *v as u64,
_ => 0,
Variable::Integer(v) => Some(*v as u64),
Variable::Float(v) => Some(*v as u64),
_ => None,
};
ctx.handle
.block_on(store.key_set(
ctx.arguments[1].to_string().into_owned().into_bytes(),
LookupValue::Value {
value: if !ctx.arguments[2].is_empty() {
if !ctx.arguments[2].is_empty() {
bincode::serialize(&ctx.arguments[2]).unwrap_or_default()
} else {
vec![]
},
expires,
},
))
.is_ok()
.into()
@ -426,9 +417,6 @@ pub fn exec_local_domain(ctx: PluginContext<'_>) -> Variable {
#[derive(Debug, PartialEq, Eq)]
pub struct VariableWrapper(Variable);
#[derive(Debug, PartialEq, Eq)]
pub struct VariableExists;
impl Deserialize for VariableWrapper {
fn deserialize(bytes: &[u8]) -> store::Result<Self> {
Ok(VariableWrapper(
@ -439,9 +427,9 @@ impl Deserialize for VariableWrapper {
}
}
impl Deserialize for VariableExists {
fn deserialize(_: &[u8]) -> store::Result<Self> {
Ok(VariableExists)
impl From<i64> for VariableWrapper {
fn from(value: i64) -> Self {
VariableWrapper(value.into())
}
}
@ -451,12 +439,6 @@ impl VariableWrapper {
}
}
impl From<Value<'static>> for VariableExists {
fn from(_: Value<'static>) -> Self {
VariableExists
}
}
impl From<Value<'static>> for VariableWrapper {
fn from(value: Value<'static>) -> Self {
VariableWrapper(into_sieve_value(value))

View file

@ -44,6 +44,7 @@ flate2 = "1.0"
async-trait = "0.1.68"
redis = { version = "0.24.0", features = [ "tokio-comp", "tokio-rustls-comp", "tls-rustls-insecure", "tls-rustls-webpki-roots", "cluster-async"], optional = true }
deadpool = { version = "0.10.0", features = ["managed"], optional = true }
bincode = "1.3.3"
[dev-dependencies]
tokio = { version = "1.23", features = ["full"] }

View file

@ -23,68 +23,145 @@
use redis::AsyncCommands;
use crate::{Deserialize, LookupKey, LookupValue};
use crate::Deserialize;
use super::{RedisPool, RedisStore};
impl RedisStore {
pub async fn key_set(&self, key: Vec<u8>, value: LookupValue<Vec<u8>>) -> crate::Result<()> {
pub async fn key_set(
&self,
key: Vec<u8>,
value: Vec<u8>,
expires: Option<u64>,
) -> crate::Result<()> {
match &self.pool {
RedisPool::Single(pool) => self.key_set_(pool.get().await?.as_mut(), key, value).await,
RedisPool::Cluster(pool) => self.key_set_(pool.get().await?.as_mut(), key, value).await,
RedisPool::Single(pool) => {
self.key_set_(pool.get().await?.as_mut(), key, value, expires)
.await
}
RedisPool::Cluster(pool) => {
self.key_set_(pool.get().await?.as_mut(), key, value, expires)
.await
}
}
}
pub async fn key_incr(
&self,
key: Vec<u8>,
value: i64,
expires: Option<u64>,
) -> crate::Result<i64> {
match &self.pool {
RedisPool::Single(pool) => {
self.key_incr_(pool.get().await?.as_mut(), key, value, expires)
.await
}
RedisPool::Cluster(pool) => {
self.key_incr_(pool.get().await?.as_mut(), key, value, expires)
.await
}
}
}
pub async fn key_delete(&self, key: Vec<u8>) -> crate::Result<()> {
match &self.pool {
RedisPool::Single(pool) => self.key_delete_(pool.get().await?.as_mut(), key).await,
RedisPool::Cluster(pool) => self.key_delete_(pool.get().await?.as_mut(), key).await,
}
}
pub async fn key_get<T: Deserialize + std::fmt::Debug + 'static>(
&self,
key: LookupKey,
) -> crate::Result<LookupValue<T>> {
key: Vec<u8>,
) -> crate::Result<Option<T>> {
match &self.pool {
RedisPool::Single(pool) => self.key_get_(pool.get().await?.as_mut(), key).await,
RedisPool::Cluster(pool) => self.key_get_(pool.get().await?.as_mut(), key).await,
}
}
pub async fn counter_get(&self, key: Vec<u8>) -> crate::Result<i64> {
match &self.pool {
RedisPool::Single(pool) => self.counter_get_(pool.get().await?.as_mut(), key).await,
RedisPool::Cluster(pool) => self.counter_get_(pool.get().await?.as_mut(), key).await,
}
}
pub async fn key_exists(&self, key: Vec<u8>) -> crate::Result<bool> {
match &self.pool {
RedisPool::Single(pool) => self.key_exists_(pool.get().await?.as_mut(), key).await,
RedisPool::Cluster(pool) => self.key_exists_(pool.get().await?.as_mut(), key).await,
}
}
async fn key_get_<T: Deserialize + std::fmt::Debug + 'static>(
&self,
conn: &mut impl AsyncCommands,
key: LookupKey,
) -> crate::Result<LookupValue<T>> {
match key {
LookupKey::Key(key) => {
key: Vec<u8>,
) -> crate::Result<Option<T>> {
if let Some(value) = conn.get::<_, Option<Vec<u8>>>(key).await? {
T::deserialize(&value).map(|value| LookupValue::Value { value, expires: 0 })
T::deserialize(&value).map(Some)
} else {
Ok(LookupValue::None)
Ok(None)
}
}
LookupKey::Counter(key) => {
let value: Option<i64> = conn.get(key).await?;
Ok(LookupValue::Counter {
num: value.unwrap_or(0),
})
}
async fn counter_get_(
&self,
conn: &mut impl AsyncCommands,
key: Vec<u8>,
) -> crate::Result<i64> {
conn.get::<_, Option<i64>>(key)
.await
.map(|x| x.unwrap_or(0))
.map_err(Into::into)
}
async fn key_exists_(
&self,
conn: &mut impl AsyncCommands,
key: Vec<u8>,
) -> crate::Result<bool> {
conn.exists(key).await.map_err(Into::into)
}
async fn key_set_(
&self,
conn: &mut impl AsyncCommands,
key: Vec<u8>,
value: LookupValue<Vec<u8>>,
value: Vec<u8>,
expires: Option<u64>,
) -> crate::Result<()> {
match value {
LookupValue::Value { value, expires } => {
if expires > 0 {
conn.set_ex(key, value, expires).await?;
if let Some(expires) = expires {
conn.set_ex(key, value, expires).await.map_err(Into::into)
} else {
conn.set(key, value).await?;
conn.set(key, value).await.map_err(Into::into)
}
}
LookupValue::Counter { num } => conn.incr(key, num).await?,
LookupValue::None => (),
}
Ok(())
async fn key_incr_(
&self,
conn: &mut impl AsyncCommands,
key: Vec<u8>,
value: i64,
expires: Option<u64>,
) -> crate::Result<i64> {
if let Some(expires) = expires {
redis::pipe()
.atomic()
.incr(&key, value)
.expire(&key, expires as i64)
.ignore()
.query_async(conn)
.await
.map_err(Into::into)
} else {
conn.incr(&key, value).await.map_err(Into::into)
}
}
async fn key_delete_(&self, conn: &mut impl AsyncCommands, key: Vec<u8>) -> crate::Result<()> {
conn.del(key).await.map_err(Into::into)
}
}

View file

@ -21,17 +21,16 @@
* for more details.
*/
use utils::expr;
use utils::{config::Rate, expr};
use crate::{backend::memory::MemoryStore, Row};
use crate::{backend::memory::MemoryStore, write::LookupClass, Row};
#[allow(unused_imports)]
use crate::{
write::{
key::{DeserializeBigEndian, KeySerializer},
now, BatchBuilder, Operation, ValueClass, ValueOp,
},
Deserialize, IterateParams, LookupKey, LookupStore, LookupValue, QueryResult, Store, Value,
ValueKey, U64_LEN,
Deserialize, IterateParams, LookupStore, QueryResult, Store, Value, ValueKey, U64_LEN,
};
impl LookupStore {
@ -59,33 +58,28 @@ impl LookupStore {
result
}
pub async fn key_set(&self, key: Vec<u8>, value: LookupValue<Vec<u8>>) -> crate::Result<()> {
pub async fn key_set(
&self,
key: Vec<u8>,
value: Vec<u8>,
expires: Option<u64>,
) -> crate::Result<()> {
match self {
LookupStore::Store(store) => {
let (class, op) = match value {
LookupValue::Value { value, expires } => (
ValueClass::Key(key),
ValueOp::Set(
let mut batch = BatchBuilder::new();
batch.ops.push(Operation::Value {
class: ValueClass::Lookup(LookupClass::Key(key)),
op: ValueOp::Set(
KeySerializer::new(value.len() + U64_LEN)
.write(if expires > 0 {
now() + expires
} else {
u64::MAX
})
.write(expires.map_or(u64::MAX, |expires| now() + expires))
.write(value.as_slice())
.finalize(),
),
),
LookupValue::Counter { num } => (ValueClass::Key(key), ValueOp::Add(num)),
LookupValue::None => return Ok(()),
};
let mut batch = BatchBuilder::new();
batch.ops.push(Operation::Value { class, op });
});
store.write(batch.build()).await
}
#[cfg(feature = "redis")]
LookupStore::Redis(store) => store.key_set(key, value).await,
LookupStore::Redis(store) => store.key_set(key, value, expires).await,
LookupStore::Query(lookup) => lookup
.store
.query::<usize>(
@ -100,83 +94,209 @@ impl LookupStore {
}
}
pub async fn counter_incr(
&self,
key: Vec<u8>,
value: i64,
expires: Option<u64>,
) -> crate::Result<i64> {
match self {
LookupStore::Store(store) => {
let mut batch = BatchBuilder::new();
if let Some(expires) = expires {
batch.ops.push(Operation::Value {
class: ValueClass::Lookup(LookupClass::CounterExpiry(key.clone())),
op: ValueOp::Set(
KeySerializer::new(U64_LEN)
.write(now() + expires)
.finalize(),
),
});
}
batch.ops.push(Operation::Value {
class: ValueClass::Lookup(LookupClass::Counter(key)),
op: ValueOp::Add(value),
});
store.write(batch.build()).await?;
Ok(0)
}
#[cfg(feature = "redis")]
LookupStore::Redis(store) => store.key_incr(key, value, expires).await,
LookupStore::Query(_) | LookupStore::Memory(_) => Err(crate::Error::InternalError(
"This store does not support counter_incr".into(),
)),
}
}
pub async fn key_delete(&self, key: Vec<u8>) -> crate::Result<()> {
match self {
LookupStore::Store(store) => {
let mut batch = BatchBuilder::new();
batch.ops.push(Operation::Value {
class: ValueClass::Lookup(LookupClass::Key(key)),
op: ValueOp::Clear,
});
store.write(batch.build()).await
}
#[cfg(feature = "redis")]
LookupStore::Redis(store) => store.key_delete(key).await,
LookupStore::Query(_) | LookupStore::Memory(_) => Err(crate::Error::InternalError(
"This store does not support key_set".into(),
)),
}
}
pub async fn counter_delete(&self, key: Vec<u8>) -> crate::Result<()> {
match self {
LookupStore::Store(store) => {
let mut batch = BatchBuilder::new();
batch.ops.push(Operation::Value {
class: ValueClass::Lookup(LookupClass::Counter(key)),
op: ValueOp::Clear,
});
store.write(batch.build()).await
}
#[cfg(feature = "redis")]
LookupStore::Redis(store) => store.key_delete(key).await,
LookupStore::Query(_) | LookupStore::Memory(_) => Err(crate::Error::InternalError(
"This store does not support key_set".into(),
)),
}
}
pub async fn key_get<T: Deserialize + From<Value<'static>> + std::fmt::Debug + 'static>(
&self,
key: LookupKey,
) -> crate::Result<LookupValue<T>> {
key: Vec<u8>,
) -> crate::Result<Option<T>> {
match self {
LookupStore::Store(store) => match key {
LookupKey::Key(key) => store
.get_value::<LookupValue<T>>(ValueKey {
account_id: 0,
collection: 0,
document_id: 0,
class: ValueClass::Key(key),
})
LookupStore::Store(store) => store
.get_value::<LookupValue<T>>(ValueKey::from(ValueClass::Lookup(LookupClass::Key(
key,
))))
.await
.map(|value| value.unwrap_or(LookupValue::None)),
LookupKey::Counter(key) => store
.get_counter(ValueKey {
account_id: 0,
collection: 0,
document_id: 0,
class: ValueClass::Key(key),
})
.await
.map(|num| LookupValue::Counter { num }),
},
.map(|value| value.and_then(|v| v.into())),
#[cfg(feature = "redis")]
LookupStore::Redis(store) => store.key_get(key).await,
LookupStore::Memory(store) => {
let key = String::from(key);
let key = String::from_utf8(key).unwrap_or_default();
match store.as_ref() {
MemoryStore::List(list) => Ok(if list.contains(&key) {
LookupValue::Value {
value: T::from(Value::Bool(true)),
expires: 0,
}
Some(T::from(Value::Bool(true)))
} else {
LookupValue::None
None
}),
MemoryStore::Map(map) => Ok(map
.get(&key)
.map(|value| LookupValue::Value {
value: T::from(value.to_owned()),
expires: 0,
})
.unwrap_or(LookupValue::None)),
MemoryStore::Map(map) => {
Ok(map.get(&key).map(|value| T::from(value.to_owned())))
}
}
}
LookupStore::Query(lookup) => lookup
.store
.query::<Option<Row>>(&lookup.query, vec![String::from(key).into()])
.query::<Option<Row>>(
&lookup.query,
vec![String::from_utf8(key).unwrap_or_default().into()],
)
.await
.map(|row| {
row.and_then(|row| row.values.into_iter().next())
.map(|value| LookupValue::Value {
value: T::from(value),
expires: 0,
})
.unwrap_or(LookupValue::None)
.map(|value| T::from(value))
}),
}
}
pub async fn counter_get(&self, key: Vec<u8>) -> crate::Result<i64> {
match self {
LookupStore::Store(store) => {
store
.get_counter(ValueKey::from(ValueClass::Lookup(LookupClass::Counter(
key,
))))
.await
}
#[cfg(feature = "redis")]
LookupStore::Redis(store) => store.counter_get(key).await,
LookupStore::Query(_) | LookupStore::Memory(_) => Err(crate::Error::InternalError(
"This store does not support counter_get".into(),
)),
}
}
pub async fn key_exists(&self, key: Vec<u8>) -> crate::Result<bool> {
match self {
LookupStore::Store(store) => store
.get_value::<LookupValue<()>>(ValueKey::from(ValueClass::Lookup(LookupClass::Key(
key,
))))
.await
.map(|value| matches!(value, Some(LookupValue::Value(())))),
#[cfg(feature = "redis")]
LookupStore::Redis(store) => store.key_exists(key).await,
LookupStore::Memory(store) => {
let key = String::from_utf8(key).unwrap_or_default();
match store.as_ref() {
MemoryStore::List(list) => Ok(list.contains(&key)),
MemoryStore::Map(map) => Ok(map.contains_key(&key)),
}
}
LookupStore::Query(lookup) => lookup
.store
.query::<Option<Row>>(
&lookup.query,
vec![String::from_utf8(key).unwrap_or_default().into()],
)
.await
.map(|row| row.is_some()),
}
}
pub async fn is_rate_allowed(
&self,
key: &[u8],
rate: &Rate,
soft_check: bool,
) -> crate::Result<Option<u64>> {
let now = now();
let range_start = now / rate.period.as_secs();
let range_end = (range_start * rate.period.as_secs()) + rate.period.as_secs();
let expires_in = range_end - now;
let mut bucket = Vec::with_capacity(key.len() + U64_LEN);
bucket.extend_from_slice(key);
bucket.extend_from_slice(range_start.to_be_bytes().as_slice());
let requests = if !soft_check {
let requests = self.counter_incr(bucket, 1, expires_in.into()).await?;
if requests > 0 {
requests - 1
} else {
// Increment and get not supported by store, fetch counter
let mut bucket = Vec::with_capacity(key.len() + U64_LEN);
bucket.extend_from_slice(key);
bucket.extend_from_slice(range_start.to_be_bytes().as_slice());
self.counter_get(bucket).await?.saturating_sub(1)
}
} else {
self.counter_get(bucket).await?
};
if requests < rate.requests as i64 {
Ok(None)
} else {
Ok(Some(expires_in))
}
}
pub async fn purge_expired(&self) -> crate::Result<()> {
match self {
LookupStore::Store(store) => {
let from_key = ValueKey {
account_id: 0,
collection: 0,
document_id: 0,
class: ValueClass::Key(vec![0u8]),
};
let to_key = ValueKey {
account_id: 0,
collection: 0,
document_id: 0,
class: ValueClass::Key(vec![u8::MAX; 10]),
};
// Delete expired keys
let from_key = ValueKey::from(ValueClass::Lookup(LookupClass::Key(vec![0u8])));
let to_key =
ValueKey::from(ValueClass::Lookup(LookupClass::Key(vec![u8::MAX; 10])));
let current_time = now();
let mut expired_keys = Vec::new();
@ -192,7 +312,46 @@ impl LookupStore {
let mut batch = BatchBuilder::new();
for key in expired_keys {
batch.ops.push(Operation::Value {
class: ValueClass::Key(key),
class: ValueClass::Lookup(LookupClass::Key(key)),
op: ValueOp::Clear,
});
if batch.ops.len() >= 1000 {
store.write(batch.build()).await?;
batch = BatchBuilder::new();
}
}
if !batch.ops.is_empty() {
store.write(batch.build()).await?;
}
}
// Delete expired counters
let from_key =
ValueKey::from(ValueClass::Lookup(LookupClass::CounterExpiry(vec![0u8])));
let to_key = ValueKey::from(ValueClass::Lookup(LookupClass::CounterExpiry(vec![
u8::MAX;
10
])));
let current_time = now();
let mut expired_keys = Vec::new();
store
.iterate(IterateParams::new(from_key, to_key), |key, value| {
if value.deserialize_be_u64(0)? < current_time {
expired_keys.push(key.get(1..).unwrap_or_default().to_vec());
}
Ok(true)
})
.await?;
if !expired_keys.is_empty() {
let mut batch = BatchBuilder::new();
for key in expired_keys {
batch.ops.push(Operation::Value {
class: ValueClass::Lookup(LookupClass::Counter(key.clone())),
op: ValueOp::Clear,
});
batch.ops.push(Operation::Value {
class: ValueClass::Lookup(LookupClass::CounterExpiry(key)),
op: ValueOp::Clear,
});
if batch.ops.len() >= 1000 {
@ -214,14 +373,16 @@ impl LookupStore {
}
}
enum LookupValue<T> {
Value(T),
None,
}
impl<T: Deserialize> Deserialize for LookupValue<T> {
fn deserialize(bytes: &[u8]) -> crate::Result<Self> {
bytes.deserialize_be_u64(0).and_then(|expires| {
Ok(if expires > now() {
LookupValue::Value {
value: T::deserialize(bytes.get(U64_LEN..).unwrap_or_default())?,
expires,
}
LookupValue::Value(T::deserialize(bytes.get(U64_LEN..).unwrap_or_default())?)
} else {
LookupValue::None
})
@ -229,6 +390,15 @@ impl<T: Deserialize> Deserialize for LookupValue<T> {
}
}
impl<T> From<LookupValue<T>> for Option<T> {
fn from(value: LookupValue<T>) -> Self {
match value {
LookupValue::Value(value) => Some(value),
LookupValue::None => None,
}
}
}
impl From<Value<'static>> for String {
fn from(value: Value<'static>) -> Self {
match value {

View file

@ -255,7 +255,8 @@ impl Store {
Self::RocksDb(store) => store.purge_bitmaps().await,
}
}
pub(crate) async fn delete_range(&self, from: impl Key, to: impl Key) -> crate::Result<()> {
pub async fn delete_range(&self, from: impl Key, to: impl Key) -> crate::Result<()> {
match self {
#[cfg(feature = "sqlite")]
Self::SQLite(store) => store.delete_range(from, to).await,
@ -395,9 +396,11 @@ impl Store {
#[cfg(feature = "test_mode")]
pub async fn blob_expire_all(&self) {
use utils::{BlobHash, BLOB_HASH_LEN};
use crate::{
write::{key::DeserializeBigEndian, BatchBuilder, BlobOp, Operation, ValueOp},
BlobHash, BLOB_HASH_LEN, U64_LEN,
U64_LEN,
};
// Delete all temporary hashes

View file

@ -119,13 +119,9 @@ pub struct LogKey {
pub change_id: u64,
}
pub const BLOB_HASH_LEN: usize = 32;
pub const U64_LEN: usize = std::mem::size_of::<u64>();
pub const U32_LEN: usize = std::mem::size_of::<u32>();
#[derive(Clone, Debug, Default, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)]
pub struct BlobHash([u8; BLOB_HASH_LEN]);
pub type Result<T> = std::result::Result<T, Error>;
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
@ -325,19 +321,6 @@ impl From<MemoryStore> for LookupStore {
}
}
#[derive(Clone, Debug)]
pub enum LookupKey {
Key(Vec<u8>),
Counter(Vec<u8>),
}
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum LookupValue<T> {
Value { value: T, expires: u64 },
Counter { num: i64 },
None,
}
#[derive(Clone, Debug, PartialEq)]
pub enum Value<'x> {
Integer(i64),
@ -363,16 +346,6 @@ impl<'x> Value<'x> {
}
}
impl From<LookupKey> for String {
fn from(value: LookupKey) -> Self {
let key = match value {
LookupKey::Key(key) | LookupKey::Counter(key) => key,
};
String::from_utf8(key)
.unwrap_or_else(|err| String::from_utf8_lossy(&err.into_bytes()).into_owned())
}
}
#[derive(Clone, Debug)]
pub struct Row {
pub values: Vec<Value<'static>>,
@ -558,6 +531,16 @@ impl<'x> From<i64> for Value<'x> {
}
}
impl From<Value<'static>> for i64 {
fn from(value: Value<'static>) -> Self {
if let Value::Integer(value) = value {
value
} else {
0
}
}
}
impl<'x> From<u64> for Value<'x> {
fn from(value: u64) -> Self {
Self::Integer(value as i64)

View file

@ -22,10 +22,11 @@
*/
use ahash::AHashSet;
use utils::{BlobHash, BLOB_HASH_LEN};
use crate::{
write::BatchBuilder, BlobClass, BlobHash, BlobStore, Deserialize, IterateParams, Store,
ValueKey, BLOB_HASH_LEN, U32_LEN, U64_LEN,
write::BatchBuilder, BlobClass, BlobStore, Deserialize, IterateParams, Store, ValueKey,
U32_LEN, U64_LEN,
};
use super::{key::DeserializeBigEndian, now, BlobOp, Operation, ValueClass, ValueOp};

View file

@ -22,15 +22,18 @@
*/
use std::convert::TryInto;
use utils::codec::leb128::Leb128_;
use utils::{codec::leb128::Leb128_, BLOB_HASH_LEN};
use crate::{
BitmapKey, IndexKey, IndexKeyPrefix, Key, LogKey, ValueKey, BLOB_HASH_LEN, SUBSPACE_BITMAPS,
SUBSPACE_INDEXES, SUBSPACE_LOGS, SUBSPACE_VALUES, U32_LEN, U64_LEN, WITHOUT_BLOCK_NUM,
WITH_SUBSPACE,
BitmapKey, Deserialize, IndexKey, IndexKeyPrefix, Key, LogKey, ValueKey, SUBSPACE_BITMAPS,
SUBSPACE_COUNTERS, SUBSPACE_INDEXES, SUBSPACE_LOGS, SUBSPACE_VALUES, U32_LEN, U64_LEN,
WITHOUT_BLOCK_NUM, WITH_SUBSPACE,
};
use super::{AnyKey, BitmapClass, BlobOp, DirectoryClass, TagValue, ValueClass};
use super::{
AnyKey, BitmapClass, BlobOp, DirectoryClass, LookupClass, QueueClass, ReportEvent, TagValue,
ValueClass,
};
pub struct KeySerializer {
pub buf: Vec<u8>,
@ -217,7 +220,16 @@ impl Key for LogKey {
impl<T: AsRef<ValueClass> + Sync + Send> Key for ValueKey<T> {
fn subspace(&self) -> u8 {
if !matches!(
self.class.as_ref(),
ValueClass::Directory(DirectoryClass::UsedQuota(_))
| ValueClass::Lookup(LookupClass::Counter(_))
| ValueClass::Queue(QueueClass::QuotaCount(_) | QueueClass::QuotaSize(_))
) {
SUBSPACE_VALUES
} else {
SUBSPACE_COUNTERS
}
}
fn serialize(&self, flags: u32) -> Vec<u8> {
@ -250,7 +262,6 @@ impl<T: AsRef<ValueClass> + Sync + Send> Key for ValueKey<T> {
.write(self.account_id)
.write(self.collection)
.write(self.document_id),
ValueClass::Key(key) => serializer.write(4u8).write(key.as_slice()),
ValueClass::IndexEmail(seq) => serializer
.write(5u8)
.write(*seq)
@ -276,6 +287,11 @@ impl<T: AsRef<ValueClass> + Sync + Send> Key for ValueKey<T> {
.write(self.document_id),
},
ValueClass::Config(key) => serializer.write(8u8).write(key.as_slice()),
ValueClass::Lookup(lookup) => match lookup {
LookupClass::Key(key) => serializer.write(4u8).write(key.as_slice()),
LookupClass::Counter(key) => serializer.write(9u8).write(key.as_slice()),
LookupClass::CounterExpiry(key) => serializer.write(10u8).write(key.as_slice()),
},
ValueClass::Directory(directory) => match directory {
DirectoryClass::NameToId(name) => serializer.write(20u8).write(name.as_slice()),
DirectoryClass::EmailToId(email) => serializer.write(21u8).write(email.as_slice()),
@ -297,6 +313,41 @@ impl<T: AsRef<ValueClass> + Sync + Send> Key for ValueKey<T> {
.write(*principal_id)
.write(*has_member),
},
ValueClass::Queue(queue) => match queue {
QueueClass::Message(queue_id) => serializer.write(50u8).write(*queue_id),
QueueClass::MessageEvent(event) => serializer
.write(51u8)
.write(event.due)
.write(event.queue_id),
QueueClass::DmarcReportHeader(event) => serializer
.write(52u8)
.write(event.due)
.write(event.domain.as_bytes())
.write(event.policy_hash)
.write(event.seq_id)
.write(0u8),
QueueClass::TlsReportHeader(event) => serializer
.write(52u8)
.write(event.due)
.write(event.domain.as_bytes())
.write(event.policy_hash)
.write(event.seq_id)
.write(1u8),
QueueClass::DmarcReportEvent(event) => serializer
.write(53u8)
.write(event.due)
.write(event.domain.as_bytes())
.write(event.policy_hash)
.write(event.seq_id),
QueueClass::TlsReportEvent(event) => serializer
.write(54u8)
.write(event.due)
.write(event.domain.as_bytes())
.write(event.policy_hash)
.write(event.seq_id),
QueueClass::QuotaCount(key) => serializer.write(55u8).write(key.as_slice()),
QueueClass::QuotaSize(key) => serializer.write(56u8).write(key.as_slice()),
},
}
.finalize()
}
@ -425,7 +476,10 @@ impl ValueClass {
U32_LEN * 2 + 3
}
ValueClass::Acl(_) => U32_LEN * 3 + 2,
ValueClass::Key(v) | ValueClass::Config(v) => v.len(),
ValueClass::Lookup(
LookupClass::Counter(v) | LookupClass::CounterExpiry(v) | LookupClass::Key(v),
)
| ValueClass::Config(v) => v.len(),
ValueClass::Directory(d) => match d {
DirectoryClass::NameToId(v)
| DirectoryClass::EmailToId(v)
@ -438,6 +492,17 @@ impl ValueClass {
BlobOp::Commit { .. } | BlobOp::Link { .. } => BLOB_HASH_LEN + U32_LEN * 2 + 2,
},
ValueClass::IndexEmail { .. } => U64_LEN * 2,
ValueClass::Queue(q) => match q {
QueueClass::Message(_) => U64_LEN,
QueueClass::MessageEvent(_) => U64_LEN * 2,
QueueClass::DmarcReportEvent(event) | QueueClass::TlsReportEvent(event) => {
event.domain.len() + U64_LEN * 3
}
QueueClass::DmarcReportHeader(event) | QueueClass::TlsReportHeader(event) => {
event.domain.len() + (U64_LEN * 3) + 1
}
QueueClass::QuotaCount(v) | QueueClass::QuotaSize(v) => v.len(),
},
}
}
}
@ -475,3 +540,20 @@ impl From<BlobOp> for ValueClass {
ValueClass::Blob(value)
}
}
impl Deserialize for ReportEvent {
fn deserialize(key: &[u8]) -> crate::Result<Self> {
Ok(ReportEvent {
due: key.deserialize_be_u64(1)?,
policy_hash: key.deserialize_be_u64(key.len() - (U64_LEN * 2 + 1))?,
seq_id: key.deserialize_be_u64(key.len() - (U64_LEN + 1))?,
domain: key
.get(U64_LEN + 1..key.len() - (U64_LEN * 2 + 1))
.and_then(|domain| std::str::from_utf8(domain).ok())
.map(|s| s.to_string())
.ok_or_else(|| {
crate::Error::InternalError("Failed to deserialize report domain".into())
})?,
})
}
}

View file

@ -29,12 +29,13 @@ use std::{
};
use nlp::tokenizers::word::WordTokenizer;
use utils::codec::leb128::{Leb128Iterator, Leb128Vec};
use crate::{
backend::MAX_TOKEN_LENGTH, BlobClass, BlobHash, Deserialize, Serialize, BLOB_HASH_LEN,
use utils::{
codec::leb128::{Leb128Iterator, Leb128Vec},
BlobHash,
};
use crate::{backend::MAX_TOKEN_LENGTH, BlobClass, Deserialize, Serialize};
use self::assert::AssertValue;
pub mod assert;
@ -131,13 +132,21 @@ pub enum TagValue {
pub enum ValueClass {
Property(u8),
Acl(u32),
Key(Vec<u8>),
Lookup(LookupClass),
TermIndex,
ReservedId,
Directory(DirectoryClass),
Blob(BlobOp),
IndexEmail(u64),
Config(Vec<u8>),
Queue(QueueClass),
}
#[derive(Debug, PartialEq, Clone, Eq, Hash)]
pub enum LookupClass {
Key(Vec<u8>),
Counter(Vec<u8>),
CounterExpiry(Vec<u8>),
}
#[derive(Debug, PartialEq, Clone, Eq, Hash)]
@ -151,6 +160,32 @@ pub enum DirectoryClass {
UsedQuota(u32),
}
#[derive(Debug, PartialEq, Clone, Eq, Hash)]
pub enum QueueClass {
Message(u64),
MessageEvent(QueueEvent),
DmarcReportHeader(ReportEvent),
DmarcReportEvent(ReportEvent),
TlsReportHeader(ReportEvent),
TlsReportEvent(ReportEvent),
QuotaCount(Vec<u8>),
QuotaSize(Vec<u8>),
}
#[derive(Debug, PartialEq, Clone, Eq, Hash)]
pub struct QueueEvent {
pub due: u64,
pub queue_id: u64,
}
#[derive(Debug, PartialEq, Clone, Eq, Hash)]
pub struct ReportEvent {
pub due: u64,
pub policy_hash: u64,
pub seq_id: u64,
pub domain: String,
}
#[derive(Debug, PartialEq, Eq, Hash, Default)]
pub enum ValueOp {
Set(Vec<u8>),
@ -264,6 +299,14 @@ impl Deserialize for u64 {
}
}
impl Deserialize for i64 {
fn deserialize(bytes: &[u8]) -> crate::Result<Self> {
Ok(i64::from_be_bytes(bytes.try_into().map_err(|_| {
crate::Error::InternalError("Failed to deserialize i64".to_string())
})?))
}
}
impl Deserialize for u32 {
fn deserialize(bytes: &[u8]) -> crate::Result<Self> {
Ok(u32::from_be_bytes(bytes.try_into().map_err(|_| {
@ -527,68 +570,12 @@ impl BitmapClass {
}
}
impl BlobHash {
pub fn new_max() -> Self {
BlobHash([u8::MAX; BLOB_HASH_LEN])
}
pub fn try_from_hash_slice(value: &[u8]) -> Result<BlobHash, std::array::TryFromSliceError> {
value.try_into().map(BlobHash)
}
pub fn as_slice(&self) -> &[u8] {
self.0.as_ref()
}
}
impl From<&[u8]> for BlobHash {
fn from(value: &[u8]) -> Self {
BlobHash(blake3::hash(value).into())
}
}
impl From<Vec<u8>> for BlobHash {
fn from(value: Vec<u8>) -> Self {
value.as_slice().into()
}
}
impl From<&Vec<u8>> for BlobHash {
fn from(value: &Vec<u8>) -> Self {
value.as_slice().into()
}
}
impl AsRef<BlobHash> for BlobHash {
fn as_ref(&self) -> &BlobHash {
self
}
}
impl AsRef<[u8]> for BlobHash {
fn as_ref(&self) -> &[u8] {
self.0.as_ref()
}
}
impl AsMut<[u8]> for BlobHash {
fn as_mut(&mut self) -> &mut [u8] {
self.0.as_mut()
}
}
impl AsRef<BlobClass> for BlobClass {
fn as_ref(&self) -> &BlobClass {
self
}
}
impl From<BlobHash> for Vec<u8> {
fn from(value: BlobHash) -> Self {
value.0.to_vec()
}
}
impl BlobClass {
pub fn account_id(&self) -> u32 {
match self {
@ -605,3 +592,57 @@ impl BlobClass {
}
}
}
pub struct Bincode<T: serde::Serialize + serde::de::DeserializeOwned> {
pub inner: T,
}
impl<T: serde::Serialize + serde::de::DeserializeOwned> Bincode<T> {
pub fn new(inner: T) -> Self {
Self { inner }
}
}
impl<T: serde::Serialize + serde::de::DeserializeOwned> Serialize for &Bincode<T> {
fn serialize(self) -> Vec<u8> {
lz4_flex::compress_prepend_size(&bincode::serialize(&self.inner).unwrap_or_default())
}
}
impl<T: serde::Serialize + serde::de::DeserializeOwned> Serialize for Bincode<T> {
fn serialize(self) -> Vec<u8> {
lz4_flex::compress_prepend_size(&bincode::serialize(&self.inner).unwrap_or_default())
}
}
impl<T: serde::Serialize + serde::de::DeserializeOwned + Sized + Sync + Send> Deserialize
for Bincode<T>
{
fn deserialize(bytes: &[u8]) -> crate::Result<Self> {
lz4_flex::decompress_size_prepended(bytes)
.map_err(|err| {
crate::Error::InternalError(format!("Bincode decompression failed: {err:?}"))
})
.and_then(|result| {
bincode::deserialize(&result).map_err(|err| {
crate::Error::InternalError(format!(
"Bincode deserialization failed (len {}): {err:?}",
result.len()
))
})
})
.map(|inner| Self { inner })
}
}
impl<T: serde::Serialize + serde::de::DeserializeOwned> ToBitmaps for Bincode<T> {
fn to_bitmaps(&self, _ops: &mut Vec<crate::write::Operation>, _field: u8, _set: bool) {
unreachable!()
}
}
impl<T: serde::Serialize + serde::de::DeserializeOwned> ToBitmaps for &Bincode<T> {
fn to_bitmaps(&self, _ops: &mut Vec<crate::write::Operation>, _field: u8, _set: bool) {
unreachable!()
}
}

View file

@ -39,6 +39,7 @@ arc-swap = "1.6.0"
futures = "0.3"
proxy-header = { version = "0.1.0", features = ["tokio"] }
regex = "1.7.0"
blake3 = "1.3.3"
[target.'cfg(unix)'.dependencies]
privdrop = "0.5.3"

View file

@ -21,9 +21,11 @@
* for more details.
*/
use std::{borrow::Cow, path::PathBuf};
use std::borrow::Cow;
use tokio::{fs, io::AsyncReadExt, sync::oneshot};
use tokio::sync::oneshot;
use crate::BlobHash;
#[derive(Debug)]
pub enum DeliveryEvent {
@ -38,7 +40,7 @@ pub enum DeliveryEvent {
pub struct IngestMessage {
pub sender_address: String,
pub recipients: Vec<String>,
pub message_path: PathBuf,
pub message_blob: BlobHash,
pub message_size: usize,
}
@ -53,29 +55,3 @@ pub enum DeliveryResult {
reason: Cow<'static, str>,
},
}
impl IngestMessage {
pub async fn read_message(&self) -> Result<Vec<u8>, ()> {
let mut raw_message = vec![0u8; self.message_size];
let mut file = fs::File::open(&self.message_path).await.map_err(|err| {
tracing::error!(
context = "read_message",
event = "error",
"Failed to open message file {}: {}",
self.message_path.display(),
err
);
})?;
file.read_exact(&mut raw_message).await.map_err(|err| {
tracing::error!(
context = "read_message",
event = "error",
"Failed to read {} bytes file {} from disk: {}",
self.message_size,
self.message_path.display(),
err
);
})?;
Ok(raw_message)
}
}

View file

@ -50,6 +50,66 @@ use rustls_pki_types::TrustAnchor;
use tracing_appender::non_blocking::WorkerGuard;
use tracing_subscriber::{prelude::__tracing_subscriber_SubscriberExt, EnvFilter};
pub const BLOB_HASH_LEN: usize = 32;
#[derive(Clone, Debug, Default, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)]
pub struct BlobHash([u8; BLOB_HASH_LEN]);
impl BlobHash {
pub fn new_max() -> Self {
BlobHash([u8::MAX; BLOB_HASH_LEN])
}
pub fn try_from_hash_slice(value: &[u8]) -> Result<BlobHash, std::array::TryFromSliceError> {
value.try_into().map(BlobHash)
}
pub fn as_slice(&self) -> &[u8] {
self.0.as_ref()
}
}
impl From<&[u8]> for BlobHash {
fn from(value: &[u8]) -> Self {
BlobHash(blake3::hash(value).into())
}
}
impl From<Vec<u8>> for BlobHash {
fn from(value: Vec<u8>) -> Self {
value.as_slice().into()
}
}
impl From<&Vec<u8>> for BlobHash {
fn from(value: &Vec<u8>) -> Self {
value.as_slice().into()
}
}
impl AsRef<BlobHash> for BlobHash {
fn as_ref(&self) -> &BlobHash {
self
}
}
impl From<BlobHash> for Vec<u8> {
fn from(value: BlobHash) -> Self {
value.0.to_vec()
}
}
impl AsRef<[u8]> for BlobHash {
fn as_ref(&self) -> &[u8] {
self.0.as_ref()
}
}
impl AsMut<[u8]> for BlobHash {
fn as_mut(&mut self) -> &mut [u8] {
self.0.as_mut()
}
}
pub trait UnwrapFailure<T> {
fn failed(self, action: &str) -> T;
}

View file

@ -6,7 +6,7 @@ resolver = "2"
[features]
#default = ["sqlite", "foundationdb", "postgres", "mysql", "rocks", "elastic", "s3", "redis"]
default = ["sqlite", "postgres", "mysql"]
default = ["sqlite", "postgres", "mysql", "redis"]
sqlite = ["store/sqlite"]
foundationdb = ["store/foundation"]
postgres = ["store/postgres"]

View file

@ -7,14 +7,14 @@ DOMAIN="example.org"
#STORE="foundationdb"
#FTS_STORE="foundationdb"
#BLOB_STORE="foundationdb"
STORE="rocksdb"
FTS_STORE="rocksdb"
BLOB_STORE="rocksdb"
#STORE="sqlite"
#FTS_STORE="sqlite"
#BLOB_STORE="sqlite"
FEATURES="foundationdb postgres mysql rocks elastic s3 redis"
#FEATURES="sqlite"
#STORE="rocksdb"
#FTS_STORE="rocksdb"
#BLOB_STORE="rocksdb"
STORE="sqlite"
FTS_STORE="sqlite"
BLOB_STORE="sqlite"
#FEATURES="foundationdb postgres mysql rocks elastic s3 redis"
FEATURES="sqlite"
# Directories
DIRECTORY="internal"