If block parser implementation.

This commit is contained in:
Mauro D 2022-12-11 18:05:34 +00:00
parent 2a7189deb9
commit 37ea712cc2
11 changed files with 842 additions and 408 deletions

View file

@ -12,7 +12,7 @@ tls.implicit = false
id = "smtps"
bind = ["0.0.0.0:465"]
tls.implicit = true
tls.sni = [{subject = "domain.org", certificate = "abc"}]
tls.sni = [{subject = "domain.org", pki = "abc"}]
socket.backlog = 1024
[[server.listener]]
@ -24,8 +24,8 @@ tls = {implicit = true}
enable = true,
implicit = true,
timeout = 300
certificate = default
sni = [{subject = "domain.org", certificate = "abc"}]
pki = default
sni = [{subject = "domain.org", pki = "abc"}]
protocols = ["TLSv1.2", TLSv1.3"]
ciphers = ["cipher1", "cipher2"]
ignore_client_order = true
@ -45,18 +45,15 @@ backlog = 1024,
[stage.connect]
script = "connect.sieve"
concurrency = 10000
throttle = [
{key = ["remoteip", "localip"], concurrency = 10000, rate = "3/5m"},
]
[[stage.connect.rule]]
match-if = {remote-ip = "127.0.0.1"}
throttle = {key = ["remoteip", "localip"], concurrency = 1000, rate = [3, 10]}
[[stage.connect.throttle]]
key = ["remoteip", "localip"]
concurrency = 10000
rate = "3/5m"
[stage.ehlo]
require = true
script = ehlo.sieve
authenticate = {spf = "strict"}
timeout = "5m"
max-ehlos = 1
@ -65,39 +62,37 @@ pipelining = true
chunking = true
requiretls = true
no-soliciting = ""
auth = []
dsn = true
future-release = false
future-release = [
{if = "is-submission", then = "5d"},
{else = false}
]
deliver-by = false
mt-priority = false
size = 100000
expn = false
[[stage.ehlo.rule]]
match-if = {listener = "submission"}
authenticate = {spf = "none"}
capabilities = {expn = true, auth = ["mechanism"], dsn = true, future-release = "5d", deliver-by = "30m", mt-priority = "mixer"}
[stage.auth]
mechanisms = [
{if = "is-submission", then = ["plain", "login"]},
{else = []}
]
require = [
{if = "is-submission", then = true},
{else = false}
]
auth-host = "auth-server"
require = false
timeout = 10
errors = {total = 3, wait = "5s"}
[[stage.auth.rule]]
match-if = {listener = "submission"}
require = true
[stage.mail]
auth = {spf = "strict"}
script = mail-from.sieve
timeout = 10
throttle = {key = "mail-from", concurrency = 10, rate = [3, 60]}
[[stage.mail.rule]]
match-if = {id = "submission"}
authenticate = {spf = "disable"}
#capabilities = {deliver-by = "1000s", future-release = "1000s"}
[[stage.mail.throttle]]
key = "mail-from"
concurrency = 10
rate = [3, 60]
[stage.rcpt]
script = rcpt-to.sieve
@ -109,15 +104,9 @@ errors = {total = 3, wait = "5s"}
timeout = "5m"
max-recipients = 100
[[stage.rcpt.rule]]
match-if = {id = "submission"}
max-recipients = 1000
[stage.data]
authenticate = {dkim = "true", arc = "true", dmarc = "true"}
script = data.sieve
timeout = "10s"
add-headers = ["arc-seal", "received-spf", "authentication-results", "return-path", "received"]
[stage.data.limits]
messages = 10
@ -126,13 +115,55 @@ received-headers = 50
mime-parts = 50
nested-messages = 3
[[stage.data]]
match-if = {id = "submission"}
authenticate = {dkim = "false", arc = "false", dmarc = "false"}
add-headers = ["dkim-signature", "arc-seal", "return-path", "received"]
[state.data.add-headers]
received = true
received-spf = true
return-path = true
auth-results = true
message-id = true
date = true
[stage.disconnect]
script = disconnect.sieve
[stage.queue]
script = queue.sieve
queue-id = "local"
[[rule]]
id = ""
[dkim]
verify = "strict"
[dkim.sign]
pki = "cert"
domain = ""
selector = ""
headers = ["From", "To", "Date", "Message-ID"]
algorithm = "rsa-sha256"
canonicalization = "simple/relaxed"
expire = "10d"
third-party = ""
third-party-algo = ""
auid = ""
set-body-length = false
reporting = true
[spf]
verify-ehlo = ""
verify-mail-from = ""
[arc]
verify = "strict"
[arc.seal]
pki = "cert"
domain = ""
selector = ""
headers = ["From", "To", "Date", "Message-ID", "DKIM-Signature"]
algorithm = "rsa-sha256"
canonicalization = "simple/relaxed"
expire = "10d"
set-body-length = false
[external]
@ -144,7 +175,7 @@ auth.username = "hello"
auth.password = "world"
tls = "optional, require, dane, dane-fallback-require, dane-require
[queue]
[queue."out"]
retry = [0, 1, 15, 60, 90]
notify = [9, 10]
prefer = ipv6
@ -153,15 +184,13 @@ tls = optional, require, dane, dane-fallback-require, dane-require
limits = { attempts = 100, time = 3600, queued-messages = 10000, queue-size = 1000000 }
throttle = { rate = 1/60, concurrency = 1000, key = all }
[[queue.virtual]]
match-if = {rcpt-domain = "*.example.org"}
id = "local"
[queue."local"]
relay-host = "lmtp"
[rules]
[rule]
[rules."is-local"]
[rule."is-local"]
rcpt-domain = ["*.example.org"]
rcpt = [""]
server-id = "relay"
@ -221,11 +250,12 @@ send-reports = true
report-frequency = requested, 86400
incoming-address = "dmarc@*"
[certificate.default]
[key]
[key."default"]
type = "rsa"
cert = '''
certificate = '''
-----BEGIN CERTIFICATE-----
'''
pki = '''
private-key = '''
-----BEGIN PRIVATE KEY-----
'''

View file

@ -38,9 +38,9 @@ impl KeyLog for KeyLogger {
impl Config {
pub fn rustls_certificate(&self, cert_id: &str) -> super::Result<Certificate> {
certs(&mut Cursor::new(self.file_contents((
"certificate",
"key",
cert_id,
"cert",
"certificate",
))?))
.map_err(|err| {
format!(
@ -61,9 +61,9 @@ impl Config {
pub fn rustls_private_key(&self, cert_id: &str) -> super::Result<PrivateKey> {
match read_one(&mut Cursor::new(self.file_contents((
"certificate",
"key",
cert_id,
"pki",
"private-key",
))?))
.map_err(|err| {
format!(

View file

@ -1,12 +1,27 @@
use std::{net::IpAddr, str::FromStr};
use std::{collections::hash_map::Entry, net::IpAddr};
use super::{
utils::{AsKey, ParseKey},
Condition, Config, IpAddrMask, StringMatch,
utils::{AsKey, ParseKey, ParseValue},
Condition, Conditions, Config, ConfigContext, IpAddrMask, StringMatch,
};
impl Config {
pub fn parse_conditions(&self, key: impl AsKey) -> super::Result<Vec<Condition>> {
pub fn parse_rules(&self, ctx: &mut ConfigContext) -> super::Result<()> {
for rule_name in self.sub_keys("rule") {
match ctx.rules.entry(rule_name.to_string()) {
Entry::Vacant(e) => {
e.insert(self.parse_conditions(("rule", rule_name))?.into());
}
Entry::Occupied(_) => {
return Err(format!("Duplicate rule {:?} found.", rule_name));
}
}
}
Ok(())
}
fn parse_conditions(&self, key: impl AsKey) -> super::Result<Conditions> {
let mut conditions = Vec::new();
let prefix = key.as_prefix();
@ -19,76 +34,77 @@ impl Config {
};
match property {
"rcpt" => {
let value = value.parse_key(key_.as_str())?;
if let Some(Condition::Recipient(values)) = conditions.last_mut() {
values.push(value.parse_key(key_.as_str())?);
values.push(value);
} else {
conditions
.push(Condition::Recipient(vec![value.parse_key(key_.as_str())?]));
conditions.push(Condition::Recipient(vec![value]));
}
}
"rcpt-domain" => {
let value = value.parse_key(key_.as_str())?;
if let Some(Condition::RecipientDomain(values)) = conditions.last_mut() {
values.push(value.parse_key(key_.as_str())?);
values.push(value);
} else {
conditions.push(Condition::RecipientDomain(vec![
value.parse_key(key_.as_str())?
]));
conditions.push(Condition::RecipientDomain(vec![value]));
}
}
"sender" => {
let value = value.parse_key(key_.as_str())?;
if let Some(Condition::Sender(values)) = conditions.last_mut() {
values.push(value.parse_key(key_.as_str())?);
values.push(value);
} else {
conditions
.push(Condition::Sender(vec![value.parse_key(key_.as_str())?]));
conditions.push(Condition::Sender(vec![value]));
}
}
"sender-domain" => {
let value = value.parse_key(key_.as_str())?;
if let Some(Condition::SenderDomain(values)) = conditions.last_mut() {
values.push(value.parse_key(key_.as_str())?);
values.push(value);
} else {
conditions.push(Condition::SenderDomain(vec![
value.parse_key(key_.as_str())?
]));
conditions.push(Condition::SenderDomain(vec![value]));
}
}
"mx" => {
let value = value.parse_key(key_.as_str())?;
if let Some(Condition::Mx(values)) = conditions.last_mut() {
values.push(value.parse_key(key_.as_str())?);
values.push(value);
} else {
conditions.push(Condition::Mx(vec![value.parse_key(key_.as_str())?]));
conditions.push(Condition::Mx(vec![value]));
}
}
"priority" => {
let value = value.parse_key(key_.as_str())?;
if let Some(Condition::Priority(values)) = conditions.last_mut() {
values.push(value.parse_key(key_.as_str())?);
values.push(value);
} else {
conditions
.push(Condition::Priority(vec![value.parse_key(key_.as_str())?]));
conditions.push(Condition::Priority(vec![value]));
}
}
"listener" => {
let value = value.to_string();
if let Some(Condition::Listener(values)) = conditions.last_mut() {
values.push(value.parse_key(key_.as_str())?);
values.push(value);
} else {
conditions
.push(Condition::Listener(vec![value.parse_key(key_.as_str())?]));
conditions.push(Condition::Listener(vec![value]));
}
}
"local-ip" => {
let value = value.parse_key(key_.as_str())?;
if let Some(Condition::LocalIp(values)) = conditions.last_mut() {
values.push(value.parse_key(key_.as_str())?);
values.push(value);
} else {
conditions
.push(Condition::LocalIp(vec![value.parse_key(key_.as_str())?]));
conditions.push(Condition::LocalIp(vec![value]));
}
}
"remote-ip" => {
let value = value.parse_key(key_.as_str())?;
if let Some(Condition::RemoteIp(values)) = conditions.last_mut() {
values.push(value.parse_key(key_.as_str())?);
values.push(value);
} else {
conditions
.push(Condition::RemoteIp(vec![value.parse_key(key_.as_str())?]));
conditions.push(Condition::RemoteIp(vec![value]));
}
}
_ => {
@ -98,56 +114,70 @@ impl Config {
}
}
Ok(conditions)
Ok(Conditions { conditions })
}
}
impl FromStr for StringMatch {
type Err = ();
fn from_str(s: &str) -> Result<Self, Self::Err> {
Ok(if let Some(value) = s.strip_prefix("list:") {
impl ParseValue for StringMatch {
fn parse_value(_key: impl AsKey, value: &str) -> super::Result<Self> {
Ok(if let Some(value) = value.strip_prefix("list:") {
StringMatch::InList(value.into())
} else if let Some(value) = s.strip_prefix("regex:") {
} else if let Some(value) = value.strip_prefix("regex:") {
StringMatch::RegexMatch(value.into())
} else if let Some(value) = s.strip_prefix('*') {
} else if let Some(value) = value.strip_prefix('*') {
StringMatch::StartsWith(value.into())
} else if let Some(value) = s.strip_suffix('*') {
} else if let Some(value) = value.strip_suffix('*') {
StringMatch::EndsWith(value.into())
} else {
StringMatch::EqualTo(s.into())
StringMatch::EqualTo(value.into())
})
}
}
impl FromStr for IpAddrMask {
type Err = ();
fn from_str(s: &str) -> Result<Self, Self::Err> {
if let Some((addr, mask)) = s.rsplit_once('/') {
let mask = mask.trim().parse::<u32>().map_err(|_| ())?;
match addr.trim().parse::<IpAddr>().map_err(|_| ())? {
IpAddr::V4(addr) if (8..=32).contains(&mask) => Ok(IpAddrMask::V4 {
addr,
mask: u32::MAX << (32 - mask),
}),
IpAddr::V6(addr) if (8..=128).contains(&mask) => Ok(IpAddrMask::V6 {
addr,
mask: u128::MAX << (128 - mask),
}),
_ => Err(()),
impl ParseValue for IpAddrMask {
fn parse_value(key: impl AsKey, value: &str) -> super::Result<Self> {
if let Some((addr, mask)) = value.rsplit_once('/') {
if let (Ok(addr), Ok(mask)) =
(addr.trim().parse::<IpAddr>(), mask.trim().parse::<u32>())
{
match addr {
IpAddr::V4(addr) if (8..=32).contains(&mask) => {
return Ok(IpAddrMask::V4 {
addr,
mask: u32::MAX << (32 - mask),
})
}
IpAddr::V6(addr) if (8..=128).contains(&mask) => {
return Ok(IpAddrMask::V6 {
addr,
mask: u128::MAX << (128 - mask),
})
}
_ => (),
}
}
} else {
Ok(match s.trim().parse::<IpAddr>().map_err(|_| ())? {
IpAddr::V4(addr) => IpAddrMask::V4 {
addr,
mask: u32::MAX,
},
IpAddr::V6(addr) => IpAddrMask::V6 {
addr,
mask: u128::MAX,
},
})
match value.trim().parse::<IpAddr>() {
Ok(IpAddr::V4(addr)) => {
return Ok(IpAddrMask::V4 {
addr,
mask: u32::MAX,
})
}
Ok(IpAddr::V6(addr)) => {
return Ok(IpAddrMask::V6 {
addr,
mask: u128::MAX,
})
}
_ => (),
}
}
Err(format!(
"Invalid IP address {:?} for property {:?}.",
value,
key.as_key()
))
}
}

144
src/config/if_block.rs Normal file
View file

@ -0,0 +1,144 @@
use super::{
utils::{AsKey, ParseValues},
Config, ConfigContext, IfBlock, IfThen,
};
impl Config {
pub fn parse_if_block<T: Default + ParseValues>(
&self,
prefix: impl AsKey,
ctx: &ConfigContext,
) -> super::Result<Option<IfBlock<T>>> {
let prefix = prefix.as_prefix();
let key = prefix.as_key();
let mut found_if = false;
let mut found_else = false;
let mut found_then = false;
// Parse conditions
let mut if_block = IfBlock::new(T::default());
let mut last_array_pos = usize::MAX;
for (item, value) in &self.keys {
if let Some(suffix_) = item.strip_prefix(&prefix) {
if let Some((array_pos, suffix)) =
suffix_.split_once('.').and_then(|(array_pos, suffix)| {
(array_pos.parse::<usize>().ok()?, suffix).into()
})
{
if suffix == ".if" || suffix.starts_with(".if.") {
if array_pos != last_array_pos {
if last_array_pos != usize::MAX && !found_then {
return Err(format!(
"Missing 'then' in 'if' condition {} for property {:?}.",
last_array_pos + 1,
key
));
}
if_block.if_then.push(IfThen {
conditions: Vec::new(),
then: T::default(),
});
found_then = false;
last_array_pos = array_pos;
}
if let Some(conditions) = ctx.rules.get(value) {
if_block
.if_then
.last_mut()
.unwrap()
.conditions
.push(conditions.clone());
} else {
return Err(format!(
"Rule {:?} does not exist for property {:?}.",
value, key
));
}
found_if = true;
} else if suffix == ".else" || suffix.starts_with(".else.") {
if !found_else {
if found_if {
if_block.default = self.parse_values((
prefix.as_str(),
suffix_.split_once(".else").unwrap().0,
"else",
))?;
found_else = true;
} else {
return Err(format!(
"Found 'else' before 'if' for property {:?}.",
key
));
}
}
} else if suffix == ".then" || suffix.starts_with(".then.") {
if !found_else {
if array_pos == last_array_pos {
if !found_then {
if_block.if_then.last_mut().unwrap().then =
self.parse_values((
prefix.as_str(),
suffix_.split_once(".then").unwrap().0,
"then",
))?;
found_then = true;
}
} else {
return Err(format!(
"Found 'then' without 'if' for property {:?}.",
key
));
}
} else {
return Err(format!(
"Found 'then' in 'else' block for property {:?}.",
key
));
}
} else {
return Err(format!("Invalid key {:?} found in 'if' block.", item));
}
} else if !found_if {
// Found probably a multi-value, parse and return
if_block.default = self.parse_values(key.as_str())?;
return Ok(Some(if_block));
} else {
return Err(format!("Invalid key {:?} found in 'if' block.", item));
}
} else if item == &key {
// There is a single value, parse and return
if_block.default = self.parse_values(key.as_str())?;
return Ok(Some(if_block));
}
}
if !found_if {
Ok(None)
} else if !found_then {
Err(format!(
"Missing 'then' in 'if' condition {} for property {:?}.",
last_array_pos + 1,
key
))
} else if !found_else {
Err(format!("Missing 'else' for property {:?}.", key))
} else {
Ok(Some(if_block))
}
}
}
impl<T: Default> IfBlock<T> {
pub fn new(value: T) -> Self {
Self {
if_then: Vec::with_capacity(0),
default: value,
}
}
}

View file

@ -1,5 +1,6 @@
pub mod certificate;
pub mod condition;
pub mod if_block;
pub mod parser;
pub mod server;
pub mod stage;
@ -9,8 +10,11 @@ pub mod utils;
use std::{
collections::BTreeMap,
net::{Ipv4Addr, Ipv6Addr},
sync::Arc,
time::Duration,
};
use ahash::AHashMap;
use rustls::ServerConfig;
use smtp_proto::MtPriority;
use tokio::net::TcpListener;
@ -42,6 +46,10 @@ pub enum Condition {
Priority(Vec<i64>),
}
pub struct Conditions {
pub conditions: Vec<Condition>,
}
pub enum StringMatch {
EqualTo(String),
StartsWith(String),
@ -59,11 +67,27 @@ pub enum ThrottleKey {
LocalIp,
}
pub struct IfThen<T: Default> {
pub conditions: Vec<Arc<Conditions>>,
pub then: T,
}
#[derive(Default)]
pub struct IfBlock<T: Default> {
pub if_then: Vec<IfThen<T>>,
pub default: T,
}
pub struct Throttle {
pub key: Vec<ThrottleKey>,
pub concurrency: u64,
pub rate_requests: u64,
pub rate_period: u64,
pub concurrency: IfBlock<u64>,
pub rate: IfBlock<ThrottleRate>,
}
#[derive(Default)]
pub struct ThrottleRate {
pub requests: u64,
pub period: Duration,
}
pub enum IpAddrMask {
@ -71,17 +95,15 @@ pub enum IpAddrMask {
V6 { addr: Ipv6Addr, mask: u128 },
}
pub struct ConnectStage {
pub script: String,
pub concurrency: u64,
pub pipelining: bool,
pub struct Connect {
pub script: Option<IfBlock<String>>,
pub concurrency: IfBlock<u64>,
pub throttle: Vec<Throttle>,
}
pub struct EhloStage {
pub struct Ehlo {
pub script: String,
pub require: bool,
pub spf: AuthLevel,
pub timeout: u64,
pub max_commands: u64,
@ -90,7 +112,6 @@ pub struct EhloStage {
pub chunking: bool,
pub requiretls: bool,
pub no_soliciting: Option<String>,
pub auth: u64,
pub future_release: Option<u64>,
pub deliver_by: Option<u64>,
pub mt_priority: MtPriority,
@ -98,22 +119,24 @@ pub struct EhloStage {
pub expn: bool,
}
pub struct AuthStage {
pub struct Auth {
pub script: String,
pub require: bool,
pub auth_host: usize,
pub mechanisms: u64,
pub timeout: u64,
pub errors_max: usize,
pub errors_wait: u64,
}
pub struct MailStage {
pub struct Mail {
pub script: String,
pub spf: AuthLevel,
pub timeout: u64,
pub throttle: Vec<Throttle>,
}
pub struct RcptStage {
pub struct Rcpt {
pub script: String,
pub timeout: u64,
@ -132,17 +155,15 @@ pub struct RcptStage {
// Limits
pub max_recipients: usize,
// Throttle
pub throttle: Vec<Throttle>,
}
pub struct DataStage {
pub struct Data {
pub script: String,
pub timeout: u64,
// Message Authentication
pub dkim: AuthLevel,
pub arc: AuthLevel,
pub dmarc: AuthLevel,
// Limits
pub max_messages: usize,
pub max_message_size: usize,
@ -153,13 +174,27 @@ pub struct DataStage {
// Headers
pub add_received: bool,
pub add_received_spf: bool,
pub add_return_path: bool,
pub add_auth_results: bool,
pub add_dkim_signature: bool,
pub add_arc_seal: bool,
pub add_message_id: bool,
pub add_date: bool,
}
pub struct Queue {
pub script: String,
pub queue_id: usize,
}
pub struct Stage {
pub connect: Connect,
pub ehlo: Ehlo,
pub auth: Auth,
pub mail: Mail,
pub rcpt: Rcpt,
pub data: Data,
pub queue: Queue,
}
pub enum AuthLevel {
Enable,
Disable,
@ -171,4 +206,8 @@ pub struct Config {
keys: BTreeMap<String, String>,
}
pub struct ConfigContext {
pub rules: AHashMap<String, Arc<Conditions>>,
}
pub type Result<T> = std::result::Result<T, String>;

View file

@ -245,7 +245,7 @@ impl<'x> TomlParser<'x> {
']' => break,
ch => {
return Err(format!(
"Unexpected character {:?} found in array for key {:?} at line {}.",
"Unexpected character {:?} found in array for property {:?} at line {}.",
ch, key, self.line
));
}
@ -263,7 +263,7 @@ impl<'x> TomlParser<'x> {
'}' => break,
ch => {
return Err(format!(
"Unexpected character {:?} found in inline table for key {:?} at line {}.",
"Unexpected character {:?} found in inline table for property {:?} at line {}.",
ch, key, self.line
));
}

View file

@ -1,4 +1,4 @@
use std::{net::SocketAddr, str::FromStr, sync::Arc, time::Duration};
use std::{net::SocketAddr, sync::Arc, time::Duration};
use rustls::{
cipher_suite::{
@ -9,13 +9,13 @@ use rustls::{
},
server::{NoClientAuth, ResolvesServerCertUsingSni},
sign::{any_supported_type, CertifiedKey},
ServerConfig, ALL_CIPHER_SUITES, ALL_KX_GROUPS, ALL_VERSIONS,
ServerConfig, SupportedCipherSuite, ALL_CIPHER_SUITES, ALL_KX_GROUPS, ALL_VERSIONS,
};
use tokio::net::TcpSocket;
use super::{
certificate::{CertificateResolver, TLS12_VERSION, TLS13_VERSION},
utils::ParseKey,
utils::{AsKey, ParseKey, ParseValue},
Config, Server, ServerProtocol,
};
@ -41,7 +41,7 @@ impl Config {
fn build_server(&self, array_pos: &str) -> super::Result<Server> {
// Obtain server id
let id = self.property_require::<String>(("server.listeners", array_pos, "id"))?;
let id = self.value_require(("server.listeners", array_pos, "id"))?;
// Build TLS config
let (tls, tls_implicit) = if self
@ -54,11 +54,11 @@ impl Config {
// Parse protocol versions
let mut tls_v2 = false;
let mut tls_v3 = false;
for (key, protocol) in self.properties_or_default::<String>(
for (key, protocol) in self.values_or_default(
("server.listener", array_pos, "tls.protocols"),
"server.tls.protocols",
) {
match protocol?.as_str() {
match protocol {
"TLSv1.2" | "0x0303" => tls_v2 = true,
"TLSv1.3" | "0x0304" => tls_v3 = true,
protocol => {
@ -75,75 +75,42 @@ impl Config {
// Parse cipher suites
let mut ciphers = Vec::new();
for (key, protocol) in self.properties_or_default::<String>(
for (key, protocol) in self.values_or_default(
("server.listener", array_pos, "tls.cipher"),
"server.tls.cipher",
) {
ciphers.push(match protocol?.as_str() {
// TLS1.3 suites
"TLS13_AES_256_GCM_SHA384" => TLS13_AES_256_GCM_SHA384,
"TLS13_AES_128_GCM_SHA256" => TLS13_AES_128_GCM_SHA256,
"TLS13_CHACHA20_POLY1305_SHA256" => TLS13_CHACHA20_POLY1305_SHA256,
// TLS1.2 suites
"TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384" => {
TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384
}
"TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256" => {
TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256
}
"TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256" => {
TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256
}
"TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384" => {
TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384
}
"TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256" => {
TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256
}
"TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256" => {
TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256
}
cipher => {
return Err(format!(
"Unsupported TLS cipher suite {:?} found in key {:?}",
cipher, key
))
}
});
ciphers.push(protocol.parse_key(key)?);
}
// Obtain default certificate
let cert_id = self
.property_or_default::<String>(
.value_or_default(
("server.listener", array_pos, "tls.certificate"),
"server.tls.certificate",
)?
)
.ok_or_else(|| format!("Undefined certificate id for listener {:?}.", id))?;
let cert = self.rustls_certificate(&cert_id)?;
let pki = self.rustls_private_key(&cert_id)?;
let cert = self.rustls_certificate(cert_id)?;
let pki = self.rustls_private_key(cert_id)?;
// Add SNI certificates
let mut resolver = ResolvesServerCertUsingSni::new();
for (key, value) in self.properties_or_default::<String>(
("server.listener", array_pos, "tls.sni"),
"server.tls.sni",
) {
for (key, value) in
self.values_or_default(("server.listener", array_pos, "tls.sni"), "server.tls.sni")
{
if let Some(prefix) = key.strip_suffix(".subject") {
resolver
.add(
value?.as_str(),
match self.property::<String>((prefix, "cert"))? {
value,
match self.value((prefix, "cert")) {
Some(sni_cert_id) if sni_cert_id != cert_id => CertifiedKey {
cert: vec![self.rustls_certificate(&sni_cert_id)?],
key: any_supported_type(
&self.rustls_private_key(&sni_cert_id)?,
)
.map_err(|err| {
format!(
"Failed to sign SNI certificate for {:?}: {}",
key, err
)
})?,
cert: vec![self.rustls_certificate(sni_cert_id)?],
key: any_supported_type(&self.rustls_private_key(sni_cert_id)?)
.map_err(|err| {
format!(
"Failed to sign SNI certificate for {:?}: {}",
key, err
)
})?,
ocsp: None,
sct_list: None,
},
@ -219,9 +186,9 @@ impl Config {
// Build listeners
let mut listeners = Vec::new();
for (_, addr) in self.properties::<SocketAddr>(("server.listener", array_pos, "bind")) {
for result in self.properties::<SocketAddr>(("server.listener", array_pos, "bind")) {
// Parse bind address and build socket
let addr = addr?;
let (_, addr) = result?;
let socket = if addr.is_ipv4() {
TcpSocket::new_v4()
} else {
@ -230,35 +197,35 @@ impl Config {
.map_err(|err| format!("Failed to create socket: {}", err))?;
// Set socket options
for option in [
"reuse-addr",
"reuse-port",
"send-buffer-size",
"recv-buffer-size",
"linger",
"tos",
] {
if let Some(value) = self.property_or_default::<String>(
("server.listener", array_pos, "socket", option),
("server.socket", option),
)? {
match option {
"reuse-addr" => socket.set_reuseaddr(value.parse_key(option)?),
"reuse-port" => socket.set_reuseport(value.parse_key(option)?),
"send-buffer-size" => socket.set_send_buffer_size(value.parse_key(option)?),
"recv-buffer-size" => socket.set_recv_buffer_size(value.parse_key(option)?),
"linger" => socket
.set_linger(Duration::from_millis(value.parse_key(option)?).into()),
"tos" => socket.set_tos(value.parse_key(option)?),
_ => unreachable!(),
for (key, value) in
self.values_or_default(("server.listener", array_pos, "socket"), "server.socket")
{
let option = key
.rsplit_once('.')
.map(|(_, option)| option)
.unwrap_or_default();
match option {
"reuse-addr" => socket.set_reuseaddr(value.parse_key(key)?),
"reuse-port" => socket.set_reuseport(value.parse_key(key)?),
"send-buffer-size" => socket.set_send_buffer_size(value.parse_key(key)?),
"recv-buffer-size" => socket.set_recv_buffer_size(value.parse_key(key)?),
"linger" => {
socket.set_linger(Duration::from_millis(value.parse_key(key)?).into())
}
"tos" => socket.set_tos(value.parse_key(key)?),
_ => {
return Err(format!(
"Invalid socket option {} for listener '{}'.",
option, id
))
}
.map_err(|err| {
format!(
"Failed to set socket option '{}' for listener '{}': {}",
option, id, err
)
})?;
}
.map_err(|err| {
format!(
"Failed to set socket option '{}' for listener '{}': {}",
option, id, err
)
})?;
}
// Bind socket
@ -308,19 +275,21 @@ impl Config {
}
Ok(Server {
id,
id: id.to_string(),
hostname: self
.property_or_default(
.value_or_default(
("server.listener", array_pos, "hostname"),
"server.hostname",
)?
.ok_or("Hostname directive not found.")?,
)
.ok_or("Hostname directive not found.")?
.to_string(),
greeting: self
.property_or_default(
.value_or_default(
("server.listener", array_pos, "greeting"),
"server.greeting",
)?
.unwrap_or_else(|| "Stalwart SMTP at your service".to_string()),
)
.unwrap_or("Stalwart SMTP at your service")
.to_string(),
protocol: self
.property_or_default(
("server.listener", array_pos, "protocol"),
@ -334,16 +303,59 @@ impl Config {
}
}
impl FromStr for ServerProtocol {
type Err = ();
fn from_str(s: &str) -> Result<Self, Self::Err> {
if s.eq_ignore_ascii_case("smtp") {
impl ParseValue for ServerProtocol {
fn parse_value(key: impl AsKey, value: &str) -> super::Result<Self> {
if value.eq_ignore_ascii_case("smtp") {
Ok(Self::Smtp)
} else if s.eq_ignore_ascii_case("lmtp") {
} else if value.eq_ignore_ascii_case("lmtp") {
Ok(Self::Lmtp)
} else {
Err(())
Err(format!(
"Invalid server protocol type {:?} for property {:?}.",
value,
key.as_key()
))
}
}
}
impl ParseValue for SocketAddr {
fn parse_value(key: impl AsKey, value: &str) -> super::Result<Self> {
value.parse().map_err(|_| {
format!(
"Invalid socket address {:?} for property {:?}.",
value,
key.as_key()
)
})
}
}
impl ParseValue for SupportedCipherSuite {
fn parse_value(key: impl AsKey, value: &str) -> super::Result<Self> {
Ok(match value {
// TLS1.3 suites
"TLS13_AES_256_GCM_SHA384" => TLS13_AES_256_GCM_SHA384,
"TLS13_AES_128_GCM_SHA256" => TLS13_AES_128_GCM_SHA256,
"TLS13_CHACHA20_POLY1305_SHA256" => TLS13_CHACHA20_POLY1305_SHA256,
// TLS1.2 suites
"TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384" => TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
"TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256" => TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
"TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256" => {
TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256
}
"TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384" => TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
"TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256" => TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
"TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256" => {
TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256
}
cipher => {
return Err(format!(
"Unsupported TLS cipher suite {:?} found in key {:?}",
cipher,
key.as_key()
))
}
})
}
}

View file

@ -1 +1,13 @@
use super::{Config, ConfigContext, Connect, IfBlock};
impl Config {
fn parse_stage_connect(&self, ctx: &ConfigContext) -> super::Result<Connect> {
Ok(Connect {
script: self.parse_if_block::<String>("stage.connect.script", ctx)?,
concurrency: self
.parse_if_block::<u64>("stage.connect.concurrency", ctx)?
.unwrap_or_else(|| IfBlock::new(10000)),
throttle: self.parse_throttle_list("stage.connect.throttle", ctx)?,
})
}
}

View file

@ -1,98 +1,109 @@
use std::str::FromStr;
use super::{
utils::{AsKey, ParseKey},
Config, Throttle, ThrottleKey,
utils::{AsKey, ParseKey, ParseValue},
Config, ConfigContext, IfBlock, Throttle, ThrottleKey, ThrottleRate,
};
impl Config {
pub fn parse_throttle(&self, prefix: impl AsKey) -> super::Result<Vec<Throttle>> {
let mut throttles = Vec::new();
pub fn parse_throttle_list(
&self,
prefix: impl AsKey,
ctx: &ConfigContext,
) -> super::Result<Vec<Throttle>> {
let mut result = Vec::new();
let key = prefix.as_key();
for array_pos in self.sub_keys(prefix) {
let mut throttle = Throttle {
key: Vec::new(),
concurrency: 0,
rate_requests: 0,
rate_period: 0,
};
// Parse keys
for (_, throttle_key) in
self.properties::<ThrottleKey>((key.as_str(), array_pos, "key"))
{
throttle.key.push(throttle_key?);
}
// Parse concurrency
if let Some(concurrency) =
self.property::<u64>((key.as_str(), array_pos, "concurrency"))?
{
throttle.concurrency = concurrency;
}
// Parse rate
if let Some(rate) = self.value((key.as_str(), array_pos, "rate")) {
if let Some((requests, period)) = rate.split_once('/') {
throttle.rate_requests = requests
.trim()
.parse::<u64>()
.ok()
.and_then(|r| if r > 0 { Some(r) } else { None })
.ok_or_else(|| {
format!(
"Invalid rate value {:?} for key {:?}.",
rate,
(key.as_str(), array_pos, "rate").as_key()
)
})?;
throttle.rate_period =
period.parse_duration((key.as_str(), array_pos, "rate"))?;
} else {
return Err(format!(
"Invalid rate value {:?} for key {:?}.",
rate,
(key.as_str(), array_pos, "rate").as_key()
));
}
}
// Validate
if throttle.key.is_empty() {
return Err(format!(
"No throttle keys found in {:?}",
(key.as_str(), array_pos, "key").as_key()
));
} else if throttle.rate_requests == 0 && throttle.concurrency == 0 {
return Err(format!(
concat!(
"Throttle {:?} needs to define a ",
"valid 'rate' or 'concurrency' property."
),
(key.as_str(), array_pos).as_key()
));
}
throttles.push(throttle);
result.push(self.parse_throttle((key.as_str(), array_pos), ctx)?);
}
Ok(throttles)
Ok(result)
}
pub fn parse_throttle(
&self,
prefix: impl AsKey,
ctx: &ConfigContext,
) -> super::Result<Throttle> {
let prefix = prefix.as_key();
let throttle = Throttle {
key: self.parse_values((prefix.as_str(), "key"))?,
concurrency: if let Some(concurrency) =
self.parse_if_block::<u64>((prefix.as_str(), "concurrency"), ctx)?
{
concurrency
} else {
IfBlock::default()
},
rate: if let Some(rate) =
self.parse_if_block::<ThrottleRate>((prefix.as_str(), "rate"), ctx)?
{
rate
} else {
IfBlock::default()
},
};
// Validate
if throttle.key.is_empty() {
Err(format!("No throttle keys found in {:?}", prefix))
} else if throttle.rate.default.requests == 0 && throttle.concurrency.default == 0 {
Err(format!(
concat!(
"Throttle {:?} needs to define a ",
"valid 'rate' or 'concurrency' property."
),
prefix
))
} else {
Ok(throttle)
}
}
}
impl FromStr for ThrottleKey {
type Err = ();
impl ParseValue for ThrottleRate {
fn parse_value(key: impl AsKey, value: &str) -> super::Result<Self> {
if let Some((requests, period)) = value.split_once('/') {
Ok(ThrottleRate {
requests: requests
.trim()
.parse::<u64>()
.ok()
.and_then(|r| if r > 0 { Some(r) } else { None })
.ok_or_else(|| {
format!(
"Invalid rate value {:?} for property {:?}.",
value,
key.as_key()
)
})?,
period: period.parse_key(key)?,
})
} else {
Err(format!(
"Invalid rate value {:?} for property {:?}.",
value,
key.as_key()
))
}
}
}
fn from_str(s: &str) -> Result<Self, Self::Err> {
Ok(match s {
impl ParseValue for ThrottleKey {
fn parse_value(key: impl AsKey, value: &str) -> super::Result<Self> {
Ok(match value {
"rcpt-domain" => ThrottleKey::RecipientDomain,
"sender-domain" => ThrottleKey::SenderDomain,
"listener" => ThrottleKey::Listener,
"mx" => ThrottleKey::Mx,
"remote-ip" => ThrottleKey::RemoteIp,
"local-ip" => ThrottleKey::LocalIp,
_ => return Err(()),
_ => {
return Err(format!(
"Invalid throttle key {:?} for property {:?}.",
value,
key.as_key()
))
}
})
}
}

View file

@ -1,21 +1,18 @@
use std::str::FromStr;
use std::{net::IpAddr, time::Duration};
use super::Config;
impl Config {
pub fn property<T: FromStr>(&self, key: impl AsKey) -> super::Result<Option<T>> {
pub fn property<T: ParseValue>(&self, key: impl AsKey) -> super::Result<Option<T>> {
let key = key.as_key();
if let Some(value) = self.keys.get(&key) {
match T::from_str(value) {
Ok(result) => Ok(Some(result)),
Err(_) => Err(format!("Invalid value {:?} for key {:?}.", value, key)),
}
T::parse_value(key, value).map(Some)
} else {
Ok(None)
}
}
pub fn property_or_default<T: FromStr>(
pub fn property_or_default<T: ParseValue>(
&self,
key: impl AsKey,
default: impl AsKey,
@ -26,7 +23,7 @@ impl Config {
}
}
pub fn property_require<T: FromStr>(&self, key: impl AsKey) -> super::Result<T> {
pub fn property_require<T: ParseValue>(&self, key: impl AsKey) -> super::Result<T> {
match self.property(key.clone()) {
Ok(Some(result)) => Ok(result),
Ok(None) => Err(format!("Missing property {:?}.", key.as_key())),
@ -54,18 +51,17 @@ impl Config {
})
}
pub fn properties<T: FromStr>(
pub fn properties<T: ParseValue>(
&self,
prefix: impl AsKey,
) -> impl Iterator<Item = (&str, super::Result<T>)> {
) -> impl Iterator<Item = super::Result<(&str, T)>> {
let full_prefix = prefix.as_key();
let prefix = prefix.as_prefix();
self.keys.iter().filter_map(move |(key, value)| {
if key.starts_with(&prefix) {
(
key.as_str(),
T::from_str(value)
.map_err(|_| format!("Invalid value {:?} for key {:?}.", value, key)),
)
if key.starts_with(&prefix) || key == &full_prefix {
T::parse_value(key.as_str(), value)
.map(|value| (key.as_str(), value))
.into()
} else {
None
@ -73,14 +69,45 @@ impl Config {
})
}
pub fn properties_or_default<T: FromStr>(
pub fn value(&self, key: impl AsKey) -> Option<&str> {
self.keys.get(&key.as_key()).map(|s| s.as_str())
}
pub fn value_require(&self, key: impl AsKey) -> super::Result<&str> {
self.keys
.get(&key.as_key())
.map(|s| s.as_str())
.ok_or_else(|| format!("Missing property {:?}.", key.as_key()))
}
pub fn value_or_default(&self, key: impl AsKey, default: impl AsKey) -> Option<&str> {
self.keys
.get(&key.as_key())
.or_else(|| self.keys.get(&default.as_key()))
.map(|s| s.as_str())
}
pub fn values(&self, prefix: impl AsKey) -> impl Iterator<Item = (&str, &str)> {
let full_prefix = prefix.as_key();
let prefix = prefix.as_prefix();
self.keys.iter().filter_map(move |(key, value)| {
if key.starts_with(&prefix) || key == &full_prefix {
(key.as_str(), value.as_str()).into()
} else {
None
}
})
}
pub fn values_or_default(
&self,
prefix: impl AsKey,
default: impl AsKey,
) -> impl Iterator<Item = (&str, super::Result<T>)> {
) -> impl Iterator<Item = (&str, &str)> {
let mut prefix = prefix.as_prefix();
self.properties(if self.keys.keys().any(|k| k.starts_with(&prefix)) {
self.values(if self.keys.keys().any(|k| k.starts_with(&prefix)) {
prefix.truncate(prefix.len() - 1);
prefix
} else {
@ -88,57 +115,185 @@ impl Config {
})
}
pub fn value(&self, key: impl AsKey) -> Option<&str> {
self.keys.get(&key.as_key()).map(|s| s.as_str())
}
pub fn take_value(&mut self, key: &str) -> Option<String> {
self.keys.remove(key)
}
pub fn file_contents(&self, key: impl AsKey) -> super::Result<Vec<u8>> {
let key_ = key.clone();
if let Some(value) = self.property::<String>(key_)? {
let key = key.as_key();
if let Some(value) = self.keys.get(&key) {
if value.starts_with("file://") {
std::fs::read(&value).map_err(|err| {
std::fs::read(value).map_err(|err| {
format!(
"Failed to read file {:?} for key {:?}: {}",
value,
key.as_key(),
err
"Failed to read file {:?} for property {:?}: {}",
value, key, err
)
})
} else {
Ok(value.into_bytes())
Ok(value.to_string().into_bytes())
}
} else {
Err(format!(
"Property {:?} not found in configuration file.",
key.as_key()
key
))
}
}
}
pub trait ParseKey {
fn parse_key<T: FromStr>(&self, key: impl AsKey) -> super::Result<T>;
fn parse_duration(&self, key: impl AsKey) -> super::Result<u64>;
}
impl ParseKey for &str {
fn parse_key<T: FromStr>(&self, key: impl AsKey) -> super::Result<T> {
match T::from_str(self) {
Ok(result) => Ok(result),
Err(_) => Err(format!(
"Invalid value {:?} for key {:?}.",
self,
key.as_key()
)),
pub fn parse_values<T: ParseValues>(&self, prefix: impl AsKey) -> super::Result<T> {
let mut result = T::default();
for (pos, (key, value)) in self.values(prefix.clone()).enumerate() {
if pos == 0 || T::is_multivalue() {
result.add_value(T::Item::parse_value(key, value)?);
} else {
return Err(format!(
"Property {:?} cannot have multiple values.",
prefix.as_key()
));
}
}
Ok(result)
}
}
pub trait ParseValues: Sized + Default {
type Item: ParseValue;
fn add_value(&mut self, value: Self::Item);
fn is_multivalue() -> bool;
}
pub trait ParseValue: Sized {
fn parse_value(key: impl AsKey, value: &str) -> super::Result<Self>;
}
pub trait ParseKey<T: ParseValue> {
fn parse_key(&self, key: impl AsKey) -> super::Result<T>;
}
impl<T: ParseValue> ParseKey<T> for &str {
fn parse_key(&self, key: impl AsKey) -> super::Result<T> {
T::parse_value(key, self)
}
}
impl<T: ParseValue> ParseKey<T> for String {
fn parse_key(&self, key: impl AsKey) -> super::Result<T> {
T::parse_value(key, self.as_str())
}
}
impl<T: ParseValue> ParseKey<T> for &String {
fn parse_key(&self, key: impl AsKey) -> super::Result<T> {
T::parse_value(key, self.as_str())
}
}
impl<T: ParseValue> ParseValues for Vec<T> {
type Item = T;
fn add_value(&mut self, value: Self::Item) {
self.push(value);
}
fn parse_duration(&self, key: impl AsKey) -> super::Result<u64> {
let duration = self.trim().to_ascii_uppercase();
fn is_multivalue() -> bool {
true
}
}
impl<T: ParseValue + Default> ParseValues for T {
type Item = T;
fn add_value(&mut self, value: Self::Item) {
*self = value;
}
fn is_multivalue() -> bool {
false
}
}
impl ParseValue for String {
fn parse_value(_key: impl AsKey, value: &str) -> super::Result<Self> {
Ok(value.to_string())
}
}
impl ParseValue for u64 {
fn parse_value(key: impl AsKey, value: &str) -> super::Result<Self> {
value.parse().map_err(|_| {
format!(
"Invalid integer value {:?} for property {:?}.",
value,
key.as_key()
)
})
}
}
impl ParseValue for i64 {
fn parse_value(key: impl AsKey, value: &str) -> super::Result<Self> {
value.parse().map_err(|_| {
format!(
"Invalid integer value {:?} for property {:?}.",
value,
key.as_key()
)
})
}
}
impl ParseValue for u32 {
fn parse_value(key: impl AsKey, value: &str) -> super::Result<Self> {
value.parse().map_err(|_| {
format!(
"Invalid integer value {:?} for property {:?}.",
value,
key.as_key()
)
})
}
}
impl ParseValue for IpAddr {
fn parse_value(key: impl AsKey, value: &str) -> super::Result<Self> {
value.parse().map_err(|_| {
format!(
"Invalid IP address value {:?} for property {:?}.",
value,
key.as_key()
)
})
}
}
impl ParseValue for usize {
fn parse_value(key: impl AsKey, value: &str) -> super::Result<Self> {
value.parse().map_err(|_| {
format!(
"Invalid integer value {:?} for property {:?}.",
value,
key.as_key()
)
})
}
}
impl ParseValue for bool {
fn parse_value(key: impl AsKey, value: &str) -> super::Result<Self> {
value.parse().map_err(|_| {
format!(
"Invalid boolean value {:?} for property {:?}.",
value,
key.as_key()
)
})
}
}
impl ParseValue for Duration {
fn parse_value(key: impl AsKey, value: &str) -> super::Result<Self> {
let duration = value.trim().to_ascii_uppercase();
let (num, multiplier) = if let Some(num) = duration.strip_prefix('d') {
(num, 24 * 60 * 60 * 1000)
} else if let Some(num) = duration.strip_prefix('h') {
@ -156,31 +311,21 @@ impl ParseKey for &str {
.ok()
.and_then(|num| {
if num > 0 {
Some(num * multiplier)
Some(Duration::from_millis(num * multiplier))
} else {
None
}
})
.ok_or_else(|| {
format!(
"Invalid duration value {:?} for key {:?}.",
self,
"Invalid duration value {:?} for property {:?}.",
value,
key.as_key()
)
})
}
}
impl ParseKey for String {
fn parse_key<T: FromStr>(&self, key: impl AsKey) -> super::Result<T> {
self.as_str().parse_key(key)
}
fn parse_duration(&self, key: impl AsKey) -> super::Result<u64> {
self.as_str().parse_duration(key)
}
}
pub trait AsKey: Clone {
fn as_key(&self) -> String;
fn as_prefix(&self) -> String;

View file

@ -12,8 +12,12 @@ use parking_lot::Mutex;
pub struct RateLimiter {
max_requests: f64,
max_interval: f64,
max_concurrent: u64,
limiter: Arc<Mutex<(Instant, f64)>>,
}
#[derive(Debug)]
pub struct ConcurrencyLimiter {
max_concurrent: u64,
concurrent: Arc<AtomicU64>,
}
@ -28,35 +32,46 @@ impl Drop for InFlightRequest {
}
impl RateLimiter {
pub fn new(max_concurrent: u64, max_requests: u64, max_interval: u64) -> Self {
pub fn new(max_requests: u64, max_interval: u64) -> Self {
RateLimiter {
max_concurrent,
max_requests: max_requests as f64,
max_interval: max_interval as f64,
limiter: Arc::new(Mutex::new((Instant::now(), max_requests as f64))),
}
}
pub fn is_allowed(&self) -> bool {
// Check rate limit
let mut limiter = self.limiter.lock();
let elapsed = limiter.0.elapsed().as_secs_f64();
limiter.0 = Instant::now();
limiter.1 += elapsed * (self.max_requests / self.max_interval);
if limiter.1 > self.max_requests {
limiter.1 = self.max_requests;
}
if limiter.1 >= 1.0 {
limiter.1 -= 1.0;
true
} else {
false
}
}
pub fn reset(&self) {
*self.limiter.lock() = (Instant::now(), self.max_requests);
}
}
impl ConcurrencyLimiter {
pub fn new(max_concurrent: u64) -> Self {
ConcurrencyLimiter {
max_concurrent,
concurrent: Arc::new(0.into()),
}
}
pub fn is_allowed(&self) -> Option<InFlightRequest> {
if self.max_concurrent == 0 || self.concurrent.load(Ordering::Relaxed) < self.max_concurrent
{
// Check rate limit
if self.max_requests > 0.0 {
let mut limiter = self.limiter.lock();
let elapsed = limiter.0.elapsed().as_secs_f64();
limiter.0 = Instant::now();
limiter.1 += elapsed * (self.max_requests / self.max_interval);
if limiter.1 > self.max_requests {
limiter.1 = self.max_requests;
}
if limiter.1 >= 1.0 {
limiter.1 -= 1.0;
} else {
return None;
}
}
if self.concurrent.load(Ordering::Relaxed) < self.max_concurrent {
// Return in-flight request
self.concurrent.fetch_add(1, Ordering::Relaxed);
Some(InFlightRequest {
@ -66,8 +81,4 @@ impl RateLimiter {
None
}
}
pub fn reset(&self) {
*self.limiter.lock() = (Instant::now(), self.max_requests);
}
}