From 37ea712cc2399b433fe31e2b55786eaec42cd052 Mon Sep 17 00:00:00 2001 From: Mauro D Date: Sun, 11 Dec 2022 18:05:34 +0000 Subject: [PATCH] If block parser implementation. --- resources/config/config.toml | 128 ++++++++++------- src/config/certificate.rs | 8 +- src/config/condition.rs | 166 ++++++++++++--------- src/config/if_block.rs | 144 +++++++++++++++++++ src/config/mod.rs | 81 ++++++++--- src/config/parser.rs | 4 +- src/config/server.rs | 214 +++++++++++++++------------- src/config/stage.rs | 12 ++ src/config/throttle.rs | 163 +++++++++++---------- src/config/utils.rs | 269 +++++++++++++++++++++++++++-------- src/core/rate_limit.rs | 61 ++++---- 11 files changed, 842 insertions(+), 408 deletions(-) create mode 100644 src/config/if_block.rs diff --git a/resources/config/config.toml b/resources/config/config.toml index 3b3bc85..de1b613 100644 --- a/resources/config/config.toml +++ b/resources/config/config.toml @@ -12,7 +12,7 @@ tls.implicit = false id = "smtps" bind = ["0.0.0.0:465"] tls.implicit = true -tls.sni = [{subject = "domain.org", certificate = "abc"}] +tls.sni = [{subject = "domain.org", pki = "abc"}] socket.backlog = 1024 [[server.listener]] @@ -24,8 +24,8 @@ tls = {implicit = true} enable = true, implicit = true, timeout = 300 -certificate = default -sni = [{subject = "domain.org", certificate = "abc"}] +pki = default +sni = [{subject = "domain.org", pki = "abc"}] protocols = ["TLSv1.2", TLSv1.3"] ciphers = ["cipher1", "cipher2"] ignore_client_order = true @@ -45,18 +45,15 @@ backlog = 1024, [stage.connect] script = "connect.sieve" concurrency = 10000 -throttle = [ - {key = ["remoteip", "localip"], concurrency = 10000, rate = "3/5m"}, -] -[[stage.connect.rule]] -match-if = {remote-ip = "127.0.0.1"} -throttle = {key = ["remoteip", "localip"], concurrency = 1000, rate = [3, 10]} +[[stage.connect.throttle]] +key = ["remoteip", "localip"] +concurrency = 10000 +rate = "3/5m" [stage.ehlo] require = true script = ehlo.sieve -authenticate = {spf = "strict"} timeout = "5m" max-ehlos = 1 @@ -65,39 +62,37 @@ pipelining = true chunking = true requiretls = true no-soliciting = "" -auth = [] dsn = true -future-release = false +future-release = [ + {if = "is-submission", then = "5d"}, + {else = false} +] deliver-by = false mt-priority = false size = 100000 expn = false -[[stage.ehlo.rule]] -match-if = {listener = "submission"} -authenticate = {spf = "none"} -capabilities = {expn = true, auth = ["mechanism"], dsn = true, future-release = "5d", deliver-by = "30m", mt-priority = "mixer"} - [stage.auth] +mechanisms = [ + {if = "is-submission", then = ["plain", "login"]}, + {else = []} +] +require = [ + {if = "is-submission", then = true}, + {else = false} +] auth-host = "auth-server" -require = false timeout = 10 errors = {total = 3, wait = "5s"} -[[stage.auth.rule]] -match-if = {listener = "submission"} -require = true - [stage.mail] -auth = {spf = "strict"} script = mail-from.sieve timeout = 10 -throttle = {key = "mail-from", concurrency = 10, rate = [3, 60]} -[[stage.mail.rule]] -match-if = {id = "submission"} -authenticate = {spf = "disable"} -#capabilities = {deliver-by = "1000s", future-release = "1000s"} +[[stage.mail.throttle]] +key = "mail-from" +concurrency = 10 +rate = [3, 60] [stage.rcpt] script = rcpt-to.sieve @@ -109,15 +104,9 @@ errors = {total = 3, wait = "5s"} timeout = "5m" max-recipients = 100 -[[stage.rcpt.rule]] -match-if = {id = "submission"} -max-recipients = 1000 - [stage.data] -authenticate = {dkim = "true", arc = "true", dmarc = "true"} script = data.sieve timeout = "10s" -add-headers = ["arc-seal", "received-spf", "authentication-results", "return-path", "received"] [stage.data.limits] messages = 10 @@ -126,13 +115,55 @@ received-headers = 50 mime-parts = 50 nested-messages = 3 -[[stage.data]] -match-if = {id = "submission"} -authenticate = {dkim = "false", arc = "false", dmarc = "false"} -add-headers = ["dkim-signature", "arc-seal", "return-path", "received"] +[state.data.add-headers] +received = true +received-spf = true +return-path = true +auth-results = true +message-id = true +date = true -[stage.disconnect] -script = disconnect.sieve +[stage.queue] +script = queue.sieve +queue-id = "local" + +[[rule]] +id = "" + + +[dkim] +verify = "strict" + +[dkim.sign] +pki = "cert" +domain = "" +selector = "" +headers = ["From", "To", "Date", "Message-ID"] +algorithm = "rsa-sha256" +canonicalization = "simple/relaxed" +expire = "10d" +third-party = "" +third-party-algo = "" +auid = "" +set-body-length = false +reporting = true + +[spf] +verify-ehlo = "" +verify-mail-from = "" + +[arc] +verify = "strict" + +[arc.seal] +pki = "cert" +domain = "" +selector = "" +headers = ["From", "To", "Date", "Message-ID", "DKIM-Signature"] +algorithm = "rsa-sha256" +canonicalization = "simple/relaxed" +expire = "10d" +set-body-length = false [external] @@ -144,7 +175,7 @@ auth.username = "hello" auth.password = "world" tls = "optional, require, dane, dane-fallback-require, dane-require -[queue] +[queue."out"] retry = [0, 1, 15, 60, 90] notify = [9, 10] prefer = ipv6 @@ -153,15 +184,13 @@ tls = optional, require, dane, dane-fallback-require, dane-require limits = { attempts = 100, time = 3600, queued-messages = 10000, queue-size = 1000000 } throttle = { rate = 1/60, concurrency = 1000, key = all } -[[queue.virtual]] -match-if = {rcpt-domain = "*.example.org"} -id = "local" +[queue."local"] relay-host = "lmtp" -[rules] +[rule] -[rules."is-local"] +[rule."is-local"] rcpt-domain = ["*.example.org"] rcpt = [""] server-id = "relay" @@ -221,11 +250,12 @@ send-reports = true report-frequency = requested, 86400 incoming-address = "dmarc@*" -[certificate.default] +[key] +[key."default"] type = "rsa" -cert = ''' +certificate = ''' -----BEGIN CERTIFICATE----- ''' -pki = ''' +private-key = ''' -----BEGIN PRIVATE KEY----- ''' diff --git a/src/config/certificate.rs b/src/config/certificate.rs index 681a7f7..82c1b7b 100644 --- a/src/config/certificate.rs +++ b/src/config/certificate.rs @@ -38,9 +38,9 @@ impl KeyLog for KeyLogger { impl Config { pub fn rustls_certificate(&self, cert_id: &str) -> super::Result { certs(&mut Cursor::new(self.file_contents(( - "certificate", + "key", cert_id, - "cert", + "certificate", ))?)) .map_err(|err| { format!( @@ -61,9 +61,9 @@ impl Config { pub fn rustls_private_key(&self, cert_id: &str) -> super::Result { match read_one(&mut Cursor::new(self.file_contents(( - "certificate", + "key", cert_id, - "pki", + "private-key", ))?)) .map_err(|err| { format!( diff --git a/src/config/condition.rs b/src/config/condition.rs index 493d911..9394956 100644 --- a/src/config/condition.rs +++ b/src/config/condition.rs @@ -1,12 +1,27 @@ -use std::{net::IpAddr, str::FromStr}; +use std::{collections::hash_map::Entry, net::IpAddr}; use super::{ - utils::{AsKey, ParseKey}, - Condition, Config, IpAddrMask, StringMatch, + utils::{AsKey, ParseKey, ParseValue}, + Condition, Conditions, Config, ConfigContext, IpAddrMask, StringMatch, }; impl Config { - pub fn parse_conditions(&self, key: impl AsKey) -> super::Result> { + pub fn parse_rules(&self, ctx: &mut ConfigContext) -> super::Result<()> { + for rule_name in self.sub_keys("rule") { + match ctx.rules.entry(rule_name.to_string()) { + Entry::Vacant(e) => { + e.insert(self.parse_conditions(("rule", rule_name))?.into()); + } + Entry::Occupied(_) => { + return Err(format!("Duplicate rule {:?} found.", rule_name)); + } + } + } + + Ok(()) + } + + fn parse_conditions(&self, key: impl AsKey) -> super::Result { let mut conditions = Vec::new(); let prefix = key.as_prefix(); @@ -19,76 +34,77 @@ impl Config { }; match property { "rcpt" => { + let value = value.parse_key(key_.as_str())?; if let Some(Condition::Recipient(values)) = conditions.last_mut() { - values.push(value.parse_key(key_.as_str())?); + values.push(value); } else { - conditions - .push(Condition::Recipient(vec![value.parse_key(key_.as_str())?])); + conditions.push(Condition::Recipient(vec![value])); } } "rcpt-domain" => { + let value = value.parse_key(key_.as_str())?; if let Some(Condition::RecipientDomain(values)) = conditions.last_mut() { - values.push(value.parse_key(key_.as_str())?); + values.push(value); } else { - conditions.push(Condition::RecipientDomain(vec![ - value.parse_key(key_.as_str())? - ])); + conditions.push(Condition::RecipientDomain(vec![value])); } } "sender" => { + let value = value.parse_key(key_.as_str())?; if let Some(Condition::Sender(values)) = conditions.last_mut() { - values.push(value.parse_key(key_.as_str())?); + values.push(value); } else { - conditions - .push(Condition::Sender(vec![value.parse_key(key_.as_str())?])); + conditions.push(Condition::Sender(vec![value])); } } "sender-domain" => { + let value = value.parse_key(key_.as_str())?; if let Some(Condition::SenderDomain(values)) = conditions.last_mut() { - values.push(value.parse_key(key_.as_str())?); + values.push(value); } else { - conditions.push(Condition::SenderDomain(vec![ - value.parse_key(key_.as_str())? - ])); + conditions.push(Condition::SenderDomain(vec![value])); } } "mx" => { + let value = value.parse_key(key_.as_str())?; if let Some(Condition::Mx(values)) = conditions.last_mut() { - values.push(value.parse_key(key_.as_str())?); + values.push(value); } else { - conditions.push(Condition::Mx(vec![value.parse_key(key_.as_str())?])); + conditions.push(Condition::Mx(vec![value])); } } "priority" => { + let value = value.parse_key(key_.as_str())?; if let Some(Condition::Priority(values)) = conditions.last_mut() { - values.push(value.parse_key(key_.as_str())?); + values.push(value); } else { - conditions - .push(Condition::Priority(vec![value.parse_key(key_.as_str())?])); + conditions.push(Condition::Priority(vec![value])); } } "listener" => { + let value = value.to_string(); if let Some(Condition::Listener(values)) = conditions.last_mut() { - values.push(value.parse_key(key_.as_str())?); + values.push(value); } else { - conditions - .push(Condition::Listener(vec![value.parse_key(key_.as_str())?])); + conditions.push(Condition::Listener(vec![value])); } } "local-ip" => { + let value = value.parse_key(key_.as_str())?; + if let Some(Condition::LocalIp(values)) = conditions.last_mut() { - values.push(value.parse_key(key_.as_str())?); + values.push(value); } else { - conditions - .push(Condition::LocalIp(vec![value.parse_key(key_.as_str())?])); + conditions.push(Condition::LocalIp(vec![value])); } } "remote-ip" => { + let value = value.parse_key(key_.as_str())?; + if let Some(Condition::RemoteIp(values)) = conditions.last_mut() { - values.push(value.parse_key(key_.as_str())?); + values.push(value); } else { - conditions - .push(Condition::RemoteIp(vec![value.parse_key(key_.as_str())?])); + conditions.push(Condition::RemoteIp(vec![value])); } } _ => { @@ -98,56 +114,70 @@ impl Config { } } - Ok(conditions) + Ok(Conditions { conditions }) } } -impl FromStr for StringMatch { - type Err = (); - - fn from_str(s: &str) -> Result { - Ok(if let Some(value) = s.strip_prefix("list:") { +impl ParseValue for StringMatch { + fn parse_value(_key: impl AsKey, value: &str) -> super::Result { + Ok(if let Some(value) = value.strip_prefix("list:") { StringMatch::InList(value.into()) - } else if let Some(value) = s.strip_prefix("regex:") { + } else if let Some(value) = value.strip_prefix("regex:") { StringMatch::RegexMatch(value.into()) - } else if let Some(value) = s.strip_prefix('*') { + } else if let Some(value) = value.strip_prefix('*') { StringMatch::StartsWith(value.into()) - } else if let Some(value) = s.strip_suffix('*') { + } else if let Some(value) = value.strip_suffix('*') { StringMatch::EndsWith(value.into()) } else { - StringMatch::EqualTo(s.into()) + StringMatch::EqualTo(value.into()) }) } } -impl FromStr for IpAddrMask { - type Err = (); - - fn from_str(s: &str) -> Result { - if let Some((addr, mask)) = s.rsplit_once('/') { - let mask = mask.trim().parse::().map_err(|_| ())?; - match addr.trim().parse::().map_err(|_| ())? { - IpAddr::V4(addr) if (8..=32).contains(&mask) => Ok(IpAddrMask::V4 { - addr, - mask: u32::MAX << (32 - mask), - }), - IpAddr::V6(addr) if (8..=128).contains(&mask) => Ok(IpAddrMask::V6 { - addr, - mask: u128::MAX << (128 - mask), - }), - _ => Err(()), +impl ParseValue for IpAddrMask { + fn parse_value(key: impl AsKey, value: &str) -> super::Result { + if let Some((addr, mask)) = value.rsplit_once('/') { + if let (Ok(addr), Ok(mask)) = + (addr.trim().parse::(), mask.trim().parse::()) + { + match addr { + IpAddr::V4(addr) if (8..=32).contains(&mask) => { + return Ok(IpAddrMask::V4 { + addr, + mask: u32::MAX << (32 - mask), + }) + } + IpAddr::V6(addr) if (8..=128).contains(&mask) => { + return Ok(IpAddrMask::V6 { + addr, + mask: u128::MAX << (128 - mask), + }) + } + _ => (), + } } } else { - Ok(match s.trim().parse::().map_err(|_| ())? { - IpAddr::V4(addr) => IpAddrMask::V4 { - addr, - mask: u32::MAX, - }, - IpAddr::V6(addr) => IpAddrMask::V6 { - addr, - mask: u128::MAX, - }, - }) + match value.trim().parse::() { + Ok(IpAddr::V4(addr)) => { + return Ok(IpAddrMask::V4 { + addr, + mask: u32::MAX, + }) + } + Ok(IpAddr::V6(addr)) => { + return Ok(IpAddrMask::V6 { + addr, + mask: u128::MAX, + }) + } + _ => (), + } } + + Err(format!( + "Invalid IP address {:?} for property {:?}.", + value, + key.as_key() + )) } } diff --git a/src/config/if_block.rs b/src/config/if_block.rs new file mode 100644 index 0000000..94b94e1 --- /dev/null +++ b/src/config/if_block.rs @@ -0,0 +1,144 @@ +use super::{ + utils::{AsKey, ParseValues}, + Config, ConfigContext, IfBlock, IfThen, +}; + +impl Config { + pub fn parse_if_block( + &self, + prefix: impl AsKey, + ctx: &ConfigContext, + ) -> super::Result>> { + let prefix = prefix.as_prefix(); + let key = prefix.as_key(); + + let mut found_if = false; + let mut found_else = false; + let mut found_then = false; + + // Parse conditions + let mut if_block = IfBlock::new(T::default()); + let mut last_array_pos = usize::MAX; + + for (item, value) in &self.keys { + if let Some(suffix_) = item.strip_prefix(&prefix) { + if let Some((array_pos, suffix)) = + suffix_.split_once('.').and_then(|(array_pos, suffix)| { + (array_pos.parse::().ok()?, suffix).into() + }) + { + if suffix == ".if" || suffix.starts_with(".if.") { + if array_pos != last_array_pos { + if last_array_pos != usize::MAX && !found_then { + return Err(format!( + "Missing 'then' in 'if' condition {} for property {:?}.", + last_array_pos + 1, + key + )); + } + + if_block.if_then.push(IfThen { + conditions: Vec::new(), + then: T::default(), + }); + + found_then = false; + last_array_pos = array_pos; + } + + if let Some(conditions) = ctx.rules.get(value) { + if_block + .if_then + .last_mut() + .unwrap() + .conditions + .push(conditions.clone()); + } else { + return Err(format!( + "Rule {:?} does not exist for property {:?}.", + value, key + )); + } + + found_if = true; + } else if suffix == ".else" || suffix.starts_with(".else.") { + if !found_else { + if found_if { + if_block.default = self.parse_values(( + prefix.as_str(), + suffix_.split_once(".else").unwrap().0, + "else", + ))?; + found_else = true; + } else { + return Err(format!( + "Found 'else' before 'if' for property {:?}.", + key + )); + } + } + } else if suffix == ".then" || suffix.starts_with(".then.") { + if !found_else { + if array_pos == last_array_pos { + if !found_then { + if_block.if_then.last_mut().unwrap().then = + self.parse_values(( + prefix.as_str(), + suffix_.split_once(".then").unwrap().0, + "then", + ))?; + found_then = true; + } + } else { + return Err(format!( + "Found 'then' without 'if' for property {:?}.", + key + )); + } + } else { + return Err(format!( + "Found 'then' in 'else' block for property {:?}.", + key + )); + } + } else { + return Err(format!("Invalid key {:?} found in 'if' block.", item)); + } + } else if !found_if { + // Found probably a multi-value, parse and return + if_block.default = self.parse_values(key.as_str())?; + return Ok(Some(if_block)); + } else { + return Err(format!("Invalid key {:?} found in 'if' block.", item)); + } + } else if item == &key { + // There is a single value, parse and return + if_block.default = self.parse_values(key.as_str())?; + return Ok(Some(if_block)); + } + } + + if !found_if { + Ok(None) + } else if !found_then { + Err(format!( + "Missing 'then' in 'if' condition {} for property {:?}.", + last_array_pos + 1, + key + )) + } else if !found_else { + Err(format!("Missing 'else' for property {:?}.", key)) + } else { + Ok(Some(if_block)) + } + } +} + +impl IfBlock { + pub fn new(value: T) -> Self { + Self { + if_then: Vec::with_capacity(0), + default: value, + } + } +} diff --git a/src/config/mod.rs b/src/config/mod.rs index 81abd05..982e6d3 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -1,5 +1,6 @@ pub mod certificate; pub mod condition; +pub mod if_block; pub mod parser; pub mod server; pub mod stage; @@ -9,8 +10,11 @@ pub mod utils; use std::{ collections::BTreeMap, net::{Ipv4Addr, Ipv6Addr}, + sync::Arc, + time::Duration, }; +use ahash::AHashMap; use rustls::ServerConfig; use smtp_proto::MtPriority; use tokio::net::TcpListener; @@ -42,6 +46,10 @@ pub enum Condition { Priority(Vec), } +pub struct Conditions { + pub conditions: Vec, +} + pub enum StringMatch { EqualTo(String), StartsWith(String), @@ -59,11 +67,27 @@ pub enum ThrottleKey { LocalIp, } +pub struct IfThen { + pub conditions: Vec>, + pub then: T, +} + +#[derive(Default)] +pub struct IfBlock { + pub if_then: Vec>, + pub default: T, +} + pub struct Throttle { pub key: Vec, - pub concurrency: u64, - pub rate_requests: u64, - pub rate_period: u64, + pub concurrency: IfBlock, + pub rate: IfBlock, +} + +#[derive(Default)] +pub struct ThrottleRate { + pub requests: u64, + pub period: Duration, } pub enum IpAddrMask { @@ -71,17 +95,15 @@ pub enum IpAddrMask { V6 { addr: Ipv6Addr, mask: u128 }, } -pub struct ConnectStage { - pub script: String, - pub concurrency: u64, - pub pipelining: bool, +pub struct Connect { + pub script: Option>, + pub concurrency: IfBlock, pub throttle: Vec, } -pub struct EhloStage { +pub struct Ehlo { pub script: String, pub require: bool, - pub spf: AuthLevel, pub timeout: u64, pub max_commands: u64, @@ -90,7 +112,6 @@ pub struct EhloStage { pub chunking: bool, pub requiretls: bool, pub no_soliciting: Option, - pub auth: u64, pub future_release: Option, pub deliver_by: Option, pub mt_priority: MtPriority, @@ -98,22 +119,24 @@ pub struct EhloStage { pub expn: bool, } -pub struct AuthStage { +pub struct Auth { pub script: String, pub require: bool, pub auth_host: usize, + pub mechanisms: u64, pub timeout: u64, pub errors_max: usize, pub errors_wait: u64, } -pub struct MailStage { +pub struct Mail { pub script: String, pub spf: AuthLevel, pub timeout: u64, + pub throttle: Vec, } -pub struct RcptStage { +pub struct Rcpt { pub script: String, pub timeout: u64, @@ -132,17 +155,15 @@ pub struct RcptStage { // Limits pub max_recipients: usize, + + // Throttle + pub throttle: Vec, } -pub struct DataStage { +pub struct Data { pub script: String, pub timeout: u64, - // Message Authentication - pub dkim: AuthLevel, - pub arc: AuthLevel, - pub dmarc: AuthLevel, - // Limits pub max_messages: usize, pub max_message_size: usize, @@ -153,13 +174,27 @@ pub struct DataStage { // Headers pub add_received: bool, pub add_received_spf: bool, + pub add_return_path: bool, pub add_auth_results: bool, - pub add_dkim_signature: bool, - pub add_arc_seal: bool, pub add_message_id: bool, pub add_date: bool, } +pub struct Queue { + pub script: String, + pub queue_id: usize, +} + +pub struct Stage { + pub connect: Connect, + pub ehlo: Ehlo, + pub auth: Auth, + pub mail: Mail, + pub rcpt: Rcpt, + pub data: Data, + pub queue: Queue, +} + pub enum AuthLevel { Enable, Disable, @@ -171,4 +206,8 @@ pub struct Config { keys: BTreeMap, } +pub struct ConfigContext { + pub rules: AHashMap>, +} + pub type Result = std::result::Result; diff --git a/src/config/parser.rs b/src/config/parser.rs index 901ca62..4f7f70c 100644 --- a/src/config/parser.rs +++ b/src/config/parser.rs @@ -245,7 +245,7 @@ impl<'x> TomlParser<'x> { ']' => break, ch => { return Err(format!( - "Unexpected character {:?} found in array for key {:?} at line {}.", + "Unexpected character {:?} found in array for property {:?} at line {}.", ch, key, self.line )); } @@ -263,7 +263,7 @@ impl<'x> TomlParser<'x> { '}' => break, ch => { return Err(format!( - "Unexpected character {:?} found in inline table for key {:?} at line {}.", + "Unexpected character {:?} found in inline table for property {:?} at line {}.", ch, key, self.line )); } diff --git a/src/config/server.rs b/src/config/server.rs index bfc471a..74f49d0 100644 --- a/src/config/server.rs +++ b/src/config/server.rs @@ -1,4 +1,4 @@ -use std::{net::SocketAddr, str::FromStr, sync::Arc, time::Duration}; +use std::{net::SocketAddr, sync::Arc, time::Duration}; use rustls::{ cipher_suite::{ @@ -9,13 +9,13 @@ use rustls::{ }, server::{NoClientAuth, ResolvesServerCertUsingSni}, sign::{any_supported_type, CertifiedKey}, - ServerConfig, ALL_CIPHER_SUITES, ALL_KX_GROUPS, ALL_VERSIONS, + ServerConfig, SupportedCipherSuite, ALL_CIPHER_SUITES, ALL_KX_GROUPS, ALL_VERSIONS, }; use tokio::net::TcpSocket; use super::{ certificate::{CertificateResolver, TLS12_VERSION, TLS13_VERSION}, - utils::ParseKey, + utils::{AsKey, ParseKey, ParseValue}, Config, Server, ServerProtocol, }; @@ -41,7 +41,7 @@ impl Config { fn build_server(&self, array_pos: &str) -> super::Result { // Obtain server id - let id = self.property_require::(("server.listeners", array_pos, "id"))?; + let id = self.value_require(("server.listeners", array_pos, "id"))?; // Build TLS config let (tls, tls_implicit) = if self @@ -54,11 +54,11 @@ impl Config { // Parse protocol versions let mut tls_v2 = false; let mut tls_v3 = false; - for (key, protocol) in self.properties_or_default::( + for (key, protocol) in self.values_or_default( ("server.listener", array_pos, "tls.protocols"), "server.tls.protocols", ) { - match protocol?.as_str() { + match protocol { "TLSv1.2" | "0x0303" => tls_v2 = true, "TLSv1.3" | "0x0304" => tls_v3 = true, protocol => { @@ -75,75 +75,42 @@ impl Config { // Parse cipher suites let mut ciphers = Vec::new(); - for (key, protocol) in self.properties_or_default::( + for (key, protocol) in self.values_or_default( ("server.listener", array_pos, "tls.cipher"), "server.tls.cipher", ) { - ciphers.push(match protocol?.as_str() { - // TLS1.3 suites - "TLS13_AES_256_GCM_SHA384" => TLS13_AES_256_GCM_SHA384, - "TLS13_AES_128_GCM_SHA256" => TLS13_AES_128_GCM_SHA256, - "TLS13_CHACHA20_POLY1305_SHA256" => TLS13_CHACHA20_POLY1305_SHA256, - // TLS1.2 suites - "TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384" => { - TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384 - } - "TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256" => { - TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256 - } - "TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256" => { - TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256 - } - "TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384" => { - TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384 - } - "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256" => { - TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256 - } - "TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256" => { - TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256 - } - cipher => { - return Err(format!( - "Unsupported TLS cipher suite {:?} found in key {:?}", - cipher, key - )) - } - }); + ciphers.push(protocol.parse_key(key)?); } // Obtain default certificate let cert_id = self - .property_or_default::( + .value_or_default( ("server.listener", array_pos, "tls.certificate"), "server.tls.certificate", - )? + ) .ok_or_else(|| format!("Undefined certificate id for listener {:?}.", id))?; - let cert = self.rustls_certificate(&cert_id)?; - let pki = self.rustls_private_key(&cert_id)?; + let cert = self.rustls_certificate(cert_id)?; + let pki = self.rustls_private_key(cert_id)?; // Add SNI certificates let mut resolver = ResolvesServerCertUsingSni::new(); - for (key, value) in self.properties_or_default::( - ("server.listener", array_pos, "tls.sni"), - "server.tls.sni", - ) { + for (key, value) in + self.values_or_default(("server.listener", array_pos, "tls.sni"), "server.tls.sni") + { if let Some(prefix) = key.strip_suffix(".subject") { resolver .add( - value?.as_str(), - match self.property::((prefix, "cert"))? { + value, + match self.value((prefix, "cert")) { Some(sni_cert_id) if sni_cert_id != cert_id => CertifiedKey { - cert: vec![self.rustls_certificate(&sni_cert_id)?], - key: any_supported_type( - &self.rustls_private_key(&sni_cert_id)?, - ) - .map_err(|err| { - format!( - "Failed to sign SNI certificate for {:?}: {}", - key, err - ) - })?, + cert: vec![self.rustls_certificate(sni_cert_id)?], + key: any_supported_type(&self.rustls_private_key(sni_cert_id)?) + .map_err(|err| { + format!( + "Failed to sign SNI certificate for {:?}: {}", + key, err + ) + })?, ocsp: None, sct_list: None, }, @@ -219,9 +186,9 @@ impl Config { // Build listeners let mut listeners = Vec::new(); - for (_, addr) in self.properties::(("server.listener", array_pos, "bind")) { + for result in self.properties::(("server.listener", array_pos, "bind")) { // Parse bind address and build socket - let addr = addr?; + let (_, addr) = result?; let socket = if addr.is_ipv4() { TcpSocket::new_v4() } else { @@ -230,35 +197,35 @@ impl Config { .map_err(|err| format!("Failed to create socket: {}", err))?; // Set socket options - for option in [ - "reuse-addr", - "reuse-port", - "send-buffer-size", - "recv-buffer-size", - "linger", - "tos", - ] { - if let Some(value) = self.property_or_default::( - ("server.listener", array_pos, "socket", option), - ("server.socket", option), - )? { - match option { - "reuse-addr" => socket.set_reuseaddr(value.parse_key(option)?), - "reuse-port" => socket.set_reuseport(value.parse_key(option)?), - "send-buffer-size" => socket.set_send_buffer_size(value.parse_key(option)?), - "recv-buffer-size" => socket.set_recv_buffer_size(value.parse_key(option)?), - "linger" => socket - .set_linger(Duration::from_millis(value.parse_key(option)?).into()), - "tos" => socket.set_tos(value.parse_key(option)?), - _ => unreachable!(), + for (key, value) in + self.values_or_default(("server.listener", array_pos, "socket"), "server.socket") + { + let option = key + .rsplit_once('.') + .map(|(_, option)| option) + .unwrap_or_default(); + match option { + "reuse-addr" => socket.set_reuseaddr(value.parse_key(key)?), + "reuse-port" => socket.set_reuseport(value.parse_key(key)?), + "send-buffer-size" => socket.set_send_buffer_size(value.parse_key(key)?), + "recv-buffer-size" => socket.set_recv_buffer_size(value.parse_key(key)?), + "linger" => { + socket.set_linger(Duration::from_millis(value.parse_key(key)?).into()) + } + "tos" => socket.set_tos(value.parse_key(key)?), + _ => { + return Err(format!( + "Invalid socket option {} for listener '{}'.", + option, id + )) } - .map_err(|err| { - format!( - "Failed to set socket option '{}' for listener '{}': {}", - option, id, err - ) - })?; } + .map_err(|err| { + format!( + "Failed to set socket option '{}' for listener '{}': {}", + option, id, err + ) + })?; } // Bind socket @@ -308,19 +275,21 @@ impl Config { } Ok(Server { - id, + id: id.to_string(), hostname: self - .property_or_default( + .value_or_default( ("server.listener", array_pos, "hostname"), "server.hostname", - )? - .ok_or("Hostname directive not found.")?, + ) + .ok_or("Hostname directive not found.")? + .to_string(), greeting: self - .property_or_default( + .value_or_default( ("server.listener", array_pos, "greeting"), "server.greeting", - )? - .unwrap_or_else(|| "Stalwart SMTP at your service".to_string()), + ) + .unwrap_or("Stalwart SMTP at your service") + .to_string(), protocol: self .property_or_default( ("server.listener", array_pos, "protocol"), @@ -334,16 +303,59 @@ impl Config { } } -impl FromStr for ServerProtocol { - type Err = (); - - fn from_str(s: &str) -> Result { - if s.eq_ignore_ascii_case("smtp") { +impl ParseValue for ServerProtocol { + fn parse_value(key: impl AsKey, value: &str) -> super::Result { + if value.eq_ignore_ascii_case("smtp") { Ok(Self::Smtp) - } else if s.eq_ignore_ascii_case("lmtp") { + } else if value.eq_ignore_ascii_case("lmtp") { Ok(Self::Lmtp) } else { - Err(()) + Err(format!( + "Invalid server protocol type {:?} for property {:?}.", + value, + key.as_key() + )) } } } + +impl ParseValue for SocketAddr { + fn parse_value(key: impl AsKey, value: &str) -> super::Result { + value.parse().map_err(|_| { + format!( + "Invalid socket address {:?} for property {:?}.", + value, + key.as_key() + ) + }) + } +} + +impl ParseValue for SupportedCipherSuite { + fn parse_value(key: impl AsKey, value: &str) -> super::Result { + Ok(match value { + // TLS1.3 suites + "TLS13_AES_256_GCM_SHA384" => TLS13_AES_256_GCM_SHA384, + "TLS13_AES_128_GCM_SHA256" => TLS13_AES_128_GCM_SHA256, + "TLS13_CHACHA20_POLY1305_SHA256" => TLS13_CHACHA20_POLY1305_SHA256, + // TLS1.2 suites + "TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384" => TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, + "TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256" => TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, + "TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256" => { + TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256 + } + "TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384" => TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, + "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256" => TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, + "TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256" => { + TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256 + } + cipher => { + return Err(format!( + "Unsupported TLS cipher suite {:?} found in key {:?}", + cipher, + key.as_key() + )) + } + }) + } +} diff --git a/src/config/stage.rs b/src/config/stage.rs index 8b13789..f5433fd 100644 --- a/src/config/stage.rs +++ b/src/config/stage.rs @@ -1 +1,13 @@ +use super::{Config, ConfigContext, Connect, IfBlock}; +impl Config { + fn parse_stage_connect(&self, ctx: &ConfigContext) -> super::Result { + Ok(Connect { + script: self.parse_if_block::("stage.connect.script", ctx)?, + concurrency: self + .parse_if_block::("stage.connect.concurrency", ctx)? + .unwrap_or_else(|| IfBlock::new(10000)), + throttle: self.parse_throttle_list("stage.connect.throttle", ctx)?, + }) + } +} diff --git a/src/config/throttle.rs b/src/config/throttle.rs index 1e733cc..f617b1e 100644 --- a/src/config/throttle.rs +++ b/src/config/throttle.rs @@ -1,98 +1,109 @@ -use std::str::FromStr; - use super::{ - utils::{AsKey, ParseKey}, - Config, Throttle, ThrottleKey, + utils::{AsKey, ParseKey, ParseValue}, + Config, ConfigContext, IfBlock, Throttle, ThrottleKey, ThrottleRate, }; impl Config { - pub fn parse_throttle(&self, prefix: impl AsKey) -> super::Result> { - let mut throttles = Vec::new(); + pub fn parse_throttle_list( + &self, + prefix: impl AsKey, + ctx: &ConfigContext, + ) -> super::Result> { + let mut result = Vec::new(); let key = prefix.as_key(); for array_pos in self.sub_keys(prefix) { - let mut throttle = Throttle { - key: Vec::new(), - concurrency: 0, - rate_requests: 0, - rate_period: 0, - }; - - // Parse keys - for (_, throttle_key) in - self.properties::((key.as_str(), array_pos, "key")) - { - throttle.key.push(throttle_key?); - } - - // Parse concurrency - if let Some(concurrency) = - self.property::((key.as_str(), array_pos, "concurrency"))? - { - throttle.concurrency = concurrency; - } - - // Parse rate - if let Some(rate) = self.value((key.as_str(), array_pos, "rate")) { - if let Some((requests, period)) = rate.split_once('/') { - throttle.rate_requests = requests - .trim() - .parse::() - .ok() - .and_then(|r| if r > 0 { Some(r) } else { None }) - .ok_or_else(|| { - format!( - "Invalid rate value {:?} for key {:?}.", - rate, - (key.as_str(), array_pos, "rate").as_key() - ) - })?; - throttle.rate_period = - period.parse_duration((key.as_str(), array_pos, "rate"))?; - } else { - return Err(format!( - "Invalid rate value {:?} for key {:?}.", - rate, - (key.as_str(), array_pos, "rate").as_key() - )); - } - } - - // Validate - if throttle.key.is_empty() { - return Err(format!( - "No throttle keys found in {:?}", - (key.as_str(), array_pos, "key").as_key() - )); - } else if throttle.rate_requests == 0 && throttle.concurrency == 0 { - return Err(format!( - concat!( - "Throttle {:?} needs to define a ", - "valid 'rate' or 'concurrency' property." - ), - (key.as_str(), array_pos).as_key() - )); - } - - throttles.push(throttle); + result.push(self.parse_throttle((key.as_str(), array_pos), ctx)?); } - Ok(throttles) + Ok(result) + } + + pub fn parse_throttle( + &self, + prefix: impl AsKey, + ctx: &ConfigContext, + ) -> super::Result { + let prefix = prefix.as_key(); + let throttle = Throttle { + key: self.parse_values((prefix.as_str(), "key"))?, + concurrency: if let Some(concurrency) = + self.parse_if_block::((prefix.as_str(), "concurrency"), ctx)? + { + concurrency + } else { + IfBlock::default() + }, + rate: if let Some(rate) = + self.parse_if_block::((prefix.as_str(), "rate"), ctx)? + { + rate + } else { + IfBlock::default() + }, + }; + + // Validate + if throttle.key.is_empty() { + Err(format!("No throttle keys found in {:?}", prefix)) + } else if throttle.rate.default.requests == 0 && throttle.concurrency.default == 0 { + Err(format!( + concat!( + "Throttle {:?} needs to define a ", + "valid 'rate' or 'concurrency' property." + ), + prefix + )) + } else { + Ok(throttle) + } } } -impl FromStr for ThrottleKey { - type Err = (); +impl ParseValue for ThrottleRate { + fn parse_value(key: impl AsKey, value: &str) -> super::Result { + if let Some((requests, period)) = value.split_once('/') { + Ok(ThrottleRate { + requests: requests + .trim() + .parse::() + .ok() + .and_then(|r| if r > 0 { Some(r) } else { None }) + .ok_or_else(|| { + format!( + "Invalid rate value {:?} for property {:?}.", + value, + key.as_key() + ) + })?, + period: period.parse_key(key)?, + }) + } else { + Err(format!( + "Invalid rate value {:?} for property {:?}.", + value, + key.as_key() + )) + } + } +} - fn from_str(s: &str) -> Result { - Ok(match s { +impl ParseValue for ThrottleKey { + fn parse_value(key: impl AsKey, value: &str) -> super::Result { + Ok(match value { "rcpt-domain" => ThrottleKey::RecipientDomain, "sender-domain" => ThrottleKey::SenderDomain, "listener" => ThrottleKey::Listener, "mx" => ThrottleKey::Mx, "remote-ip" => ThrottleKey::RemoteIp, "local-ip" => ThrottleKey::LocalIp, - _ => return Err(()), + _ => { + return Err(format!( + "Invalid throttle key {:?} for property {:?}.", + value, + key.as_key() + )) + } }) } } diff --git a/src/config/utils.rs b/src/config/utils.rs index 2738fbc..2574a02 100644 --- a/src/config/utils.rs +++ b/src/config/utils.rs @@ -1,21 +1,18 @@ -use std::str::FromStr; +use std::{net::IpAddr, time::Duration}; use super::Config; impl Config { - pub fn property(&self, key: impl AsKey) -> super::Result> { + pub fn property(&self, key: impl AsKey) -> super::Result> { let key = key.as_key(); if let Some(value) = self.keys.get(&key) { - match T::from_str(value) { - Ok(result) => Ok(Some(result)), - Err(_) => Err(format!("Invalid value {:?} for key {:?}.", value, key)), - } + T::parse_value(key, value).map(Some) } else { Ok(None) } } - pub fn property_or_default( + pub fn property_or_default( &self, key: impl AsKey, default: impl AsKey, @@ -26,7 +23,7 @@ impl Config { } } - pub fn property_require(&self, key: impl AsKey) -> super::Result { + pub fn property_require(&self, key: impl AsKey) -> super::Result { match self.property(key.clone()) { Ok(Some(result)) => Ok(result), Ok(None) => Err(format!("Missing property {:?}.", key.as_key())), @@ -54,18 +51,17 @@ impl Config { }) } - pub fn properties( + pub fn properties( &self, prefix: impl AsKey, - ) -> impl Iterator)> { + ) -> impl Iterator> { + let full_prefix = prefix.as_key(); let prefix = prefix.as_prefix(); + self.keys.iter().filter_map(move |(key, value)| { - if key.starts_with(&prefix) { - ( - key.as_str(), - T::from_str(value) - .map_err(|_| format!("Invalid value {:?} for key {:?}.", value, key)), - ) + if key.starts_with(&prefix) || key == &full_prefix { + T::parse_value(key.as_str(), value) + .map(|value| (key.as_str(), value)) .into() } else { None @@ -73,14 +69,45 @@ impl Config { }) } - pub fn properties_or_default( + pub fn value(&self, key: impl AsKey) -> Option<&str> { + self.keys.get(&key.as_key()).map(|s| s.as_str()) + } + + pub fn value_require(&self, key: impl AsKey) -> super::Result<&str> { + self.keys + .get(&key.as_key()) + .map(|s| s.as_str()) + .ok_or_else(|| format!("Missing property {:?}.", key.as_key())) + } + + pub fn value_or_default(&self, key: impl AsKey, default: impl AsKey) -> Option<&str> { + self.keys + .get(&key.as_key()) + .or_else(|| self.keys.get(&default.as_key())) + .map(|s| s.as_str()) + } + + pub fn values(&self, prefix: impl AsKey) -> impl Iterator { + let full_prefix = prefix.as_key(); + let prefix = prefix.as_prefix(); + + self.keys.iter().filter_map(move |(key, value)| { + if key.starts_with(&prefix) || key == &full_prefix { + (key.as_str(), value.as_str()).into() + } else { + None + } + }) + } + + pub fn values_or_default( &self, prefix: impl AsKey, default: impl AsKey, - ) -> impl Iterator)> { + ) -> impl Iterator { let mut prefix = prefix.as_prefix(); - self.properties(if self.keys.keys().any(|k| k.starts_with(&prefix)) { + self.values(if self.keys.keys().any(|k| k.starts_with(&prefix)) { prefix.truncate(prefix.len() - 1); prefix } else { @@ -88,57 +115,185 @@ impl Config { }) } - pub fn value(&self, key: impl AsKey) -> Option<&str> { - self.keys.get(&key.as_key()).map(|s| s.as_str()) - } - pub fn take_value(&mut self, key: &str) -> Option { self.keys.remove(key) } pub fn file_contents(&self, key: impl AsKey) -> super::Result> { - let key_ = key.clone(); - if let Some(value) = self.property::(key_)? { + let key = key.as_key(); + if let Some(value) = self.keys.get(&key) { if value.starts_with("file://") { - std::fs::read(&value).map_err(|err| { + std::fs::read(value).map_err(|err| { format!( - "Failed to read file {:?} for key {:?}: {}", - value, - key.as_key(), - err + "Failed to read file {:?} for property {:?}: {}", + value, key, err ) }) } else { - Ok(value.into_bytes()) + Ok(value.to_string().into_bytes()) } } else { Err(format!( "Property {:?} not found in configuration file.", - key.as_key() + key )) } } -} -pub trait ParseKey { - fn parse_key(&self, key: impl AsKey) -> super::Result; - fn parse_duration(&self, key: impl AsKey) -> super::Result; -} - -impl ParseKey for &str { - fn parse_key(&self, key: impl AsKey) -> super::Result { - match T::from_str(self) { - Ok(result) => Ok(result), - Err(_) => Err(format!( - "Invalid value {:?} for key {:?}.", - self, - key.as_key() - )), + pub fn parse_values(&self, prefix: impl AsKey) -> super::Result { + let mut result = T::default(); + for (pos, (key, value)) in self.values(prefix.clone()).enumerate() { + if pos == 0 || T::is_multivalue() { + result.add_value(T::Item::parse_value(key, value)?); + } else { + return Err(format!( + "Property {:?} cannot have multiple values.", + prefix.as_key() + )); + } } + Ok(result) + } +} + +pub trait ParseValues: Sized + Default { + type Item: ParseValue; + + fn add_value(&mut self, value: Self::Item); + fn is_multivalue() -> bool; +} + +pub trait ParseValue: Sized { + fn parse_value(key: impl AsKey, value: &str) -> super::Result; +} + +pub trait ParseKey { + fn parse_key(&self, key: impl AsKey) -> super::Result; +} + +impl ParseKey for &str { + fn parse_key(&self, key: impl AsKey) -> super::Result { + T::parse_value(key, self) + } +} + +impl ParseKey for String { + fn parse_key(&self, key: impl AsKey) -> super::Result { + T::parse_value(key, self.as_str()) + } +} + +impl ParseKey for &String { + fn parse_key(&self, key: impl AsKey) -> super::Result { + T::parse_value(key, self.as_str()) + } +} + +impl ParseValues for Vec { + type Item = T; + + fn add_value(&mut self, value: Self::Item) { + self.push(value); } - fn parse_duration(&self, key: impl AsKey) -> super::Result { - let duration = self.trim().to_ascii_uppercase(); + fn is_multivalue() -> bool { + true + } +} + +impl ParseValues for T { + type Item = T; + + fn add_value(&mut self, value: Self::Item) { + *self = value; + } + + fn is_multivalue() -> bool { + false + } +} + +impl ParseValue for String { + fn parse_value(_key: impl AsKey, value: &str) -> super::Result { + Ok(value.to_string()) + } +} + +impl ParseValue for u64 { + fn parse_value(key: impl AsKey, value: &str) -> super::Result { + value.parse().map_err(|_| { + format!( + "Invalid integer value {:?} for property {:?}.", + value, + key.as_key() + ) + }) + } +} + +impl ParseValue for i64 { + fn parse_value(key: impl AsKey, value: &str) -> super::Result { + value.parse().map_err(|_| { + format!( + "Invalid integer value {:?} for property {:?}.", + value, + key.as_key() + ) + }) + } +} + +impl ParseValue for u32 { + fn parse_value(key: impl AsKey, value: &str) -> super::Result { + value.parse().map_err(|_| { + format!( + "Invalid integer value {:?} for property {:?}.", + value, + key.as_key() + ) + }) + } +} + +impl ParseValue for IpAddr { + fn parse_value(key: impl AsKey, value: &str) -> super::Result { + value.parse().map_err(|_| { + format!( + "Invalid IP address value {:?} for property {:?}.", + value, + key.as_key() + ) + }) + } +} + +impl ParseValue for usize { + fn parse_value(key: impl AsKey, value: &str) -> super::Result { + value.parse().map_err(|_| { + format!( + "Invalid integer value {:?} for property {:?}.", + value, + key.as_key() + ) + }) + } +} + +impl ParseValue for bool { + fn parse_value(key: impl AsKey, value: &str) -> super::Result { + value.parse().map_err(|_| { + format!( + "Invalid boolean value {:?} for property {:?}.", + value, + key.as_key() + ) + }) + } +} + +impl ParseValue for Duration { + fn parse_value(key: impl AsKey, value: &str) -> super::Result { + let duration = value.trim().to_ascii_uppercase(); let (num, multiplier) = if let Some(num) = duration.strip_prefix('d') { (num, 24 * 60 * 60 * 1000) } else if let Some(num) = duration.strip_prefix('h') { @@ -156,31 +311,21 @@ impl ParseKey for &str { .ok() .and_then(|num| { if num > 0 { - Some(num * multiplier) + Some(Duration::from_millis(num * multiplier)) } else { None } }) .ok_or_else(|| { format!( - "Invalid duration value {:?} for key {:?}.", - self, + "Invalid duration value {:?} for property {:?}.", + value, key.as_key() ) }) } } -impl ParseKey for String { - fn parse_key(&self, key: impl AsKey) -> super::Result { - self.as_str().parse_key(key) - } - - fn parse_duration(&self, key: impl AsKey) -> super::Result { - self.as_str().parse_duration(key) - } -} - pub trait AsKey: Clone { fn as_key(&self) -> String; fn as_prefix(&self) -> String; diff --git a/src/core/rate_limit.rs b/src/core/rate_limit.rs index ddcab7a..02cf599 100644 --- a/src/core/rate_limit.rs +++ b/src/core/rate_limit.rs @@ -12,8 +12,12 @@ use parking_lot::Mutex; pub struct RateLimiter { max_requests: f64, max_interval: f64, - max_concurrent: u64, limiter: Arc>, +} + +#[derive(Debug)] +pub struct ConcurrencyLimiter { + max_concurrent: u64, concurrent: Arc, } @@ -28,35 +32,46 @@ impl Drop for InFlightRequest { } impl RateLimiter { - pub fn new(max_concurrent: u64, max_requests: u64, max_interval: u64) -> Self { + pub fn new(max_requests: u64, max_interval: u64) -> Self { RateLimiter { - max_concurrent, max_requests: max_requests as f64, max_interval: max_interval as f64, limiter: Arc::new(Mutex::new((Instant::now(), max_requests as f64))), + } + } + + pub fn is_allowed(&self) -> bool { + // Check rate limit + let mut limiter = self.limiter.lock(); + let elapsed = limiter.0.elapsed().as_secs_f64(); + limiter.0 = Instant::now(); + limiter.1 += elapsed * (self.max_requests / self.max_interval); + if limiter.1 > self.max_requests { + limiter.1 = self.max_requests; + } + if limiter.1 >= 1.0 { + limiter.1 -= 1.0; + true + } else { + false + } + } + + pub fn reset(&self) { + *self.limiter.lock() = (Instant::now(), self.max_requests); + } +} + +impl ConcurrencyLimiter { + pub fn new(max_concurrent: u64) -> Self { + ConcurrencyLimiter { + max_concurrent, concurrent: Arc::new(0.into()), } } pub fn is_allowed(&self) -> Option { - if self.max_concurrent == 0 || self.concurrent.load(Ordering::Relaxed) < self.max_concurrent - { - // Check rate limit - if self.max_requests > 0.0 { - let mut limiter = self.limiter.lock(); - let elapsed = limiter.0.elapsed().as_secs_f64(); - limiter.0 = Instant::now(); - limiter.1 += elapsed * (self.max_requests / self.max_interval); - if limiter.1 > self.max_requests { - limiter.1 = self.max_requests; - } - if limiter.1 >= 1.0 { - limiter.1 -= 1.0; - } else { - return None; - } - } - + if self.concurrent.load(Ordering::Relaxed) < self.max_concurrent { // Return in-flight request self.concurrent.fetch_add(1, Ordering::Relaxed); Some(InFlightRequest { @@ -66,8 +81,4 @@ impl RateLimiter { None } } - - pub fn reset(&self) { - *self.limiter.lock() = (Instant::now(), self.max_requests); - } }