Add tests for bind in the handler

This commit is contained in:
Valentin Tolmer 2021-04-11 22:01:24 +02:00
parent 49404b24d7
commit 71045b08fe
3 changed files with 40 additions and 31 deletions

View file

@ -27,6 +27,7 @@ pub struct ListUsersRequest {
pub filters: Option<RequestFilter>,
}
#[derive(sqlx::FromRow)]
#[cfg_attr(test, derive(PartialEq, Eq, Debug))]
pub struct User {
pub user_id: String,
@ -35,29 +36,25 @@ pub struct User {
pub first_name: String,
pub last_name: String,
// pub avatar: ?,
pub creation_date: chrono::NaiveDateTime,
// TODO: wait until supported for Any
// pub creation_date: chrono::NaiveDateTime,
}
#[async_trait]
pub trait BackendHandler: Clone + Send {
async fn bind(&mut self, request: BindRequest) -> Result<()>;
async fn list_users(&mut self, request: ListUsersRequest) -> Result<Vec<User>>;
async fn bind(&self, request: BindRequest) -> Result<()>;
async fn list_users(&self, request: ListUsersRequest) -> Result<Vec<User>>;
}
#[derive(Debug, Clone)]
pub struct SqlBackendHandler {
config: Configuration,
sql_pool: AnyPool,
authenticated: bool,
}
impl SqlBackendHandler {
pub fn new(config: Configuration, sql_pool: AnyPool) -> Self {
SqlBackendHandler {
config,
sql_pool,
authenticated: false,
}
SqlBackendHandler { config, sql_pool }
}
}
@ -88,10 +85,9 @@ fn get_filter_expr(filter: RequestFilter) -> SimpleExpr {
#[async_trait]
impl BackendHandler for SqlBackendHandler {
async fn bind(&mut self, request: BindRequest) -> Result<()> {
async fn bind(&self, request: BindRequest) -> Result<()> {
if request.name == self.config.ldap_user_dn {
if request.password == self.config.ldap_user_pass {
self.authenticated = true;
return Ok(());
} else {
bail!(r#"Authentication error for "{}""#, request.name)
@ -110,7 +106,7 @@ impl BackendHandler for SqlBackendHandler {
bail!(r#"Authentication error for "{}""#, request.name)
}
async fn list_users(&mut self, request: ListUsersRequest) -> Result<Vec<User>> {
async fn list_users(&self, request: ListUsersRequest) -> Result<Vec<User>> {
let query = {
let mut query_builder = Query::select()
.column(Users::UserId)
@ -133,15 +129,7 @@ impl BackendHandler for SqlBackendHandler {
query_builder.to_string(MysqlQueryBuilder)
};
let results = sqlx::query(&query)
.map(|row: sqlx::any::AnyRow| User {
user_id: row.get::<String, _>("user_id"),
email: row.get::<String, _>("email"),
display_name: row.get::<String, _>("display_name"),
first_name: row.get::<String, _>("first_name"),
last_name: row.get::<String, _>("last_name"),
creation_date: chrono::NaiveDateTime::from_timestamp(0, 0), // TODO: wait until datetime is supported for Any.
})
let results = sqlx::query_as::<_, User>(&query)
.fetch(&self.sql_pool)
.collect::<Vec<sqlx::Result<User>>>()
.await;
@ -158,7 +146,32 @@ mockall::mock! {
}
#[async_trait]
impl BackendHandler for TestBackendHandler {
async fn bind(&mut self, request: BindRequest) -> Result<()>;
async fn list_users(&mut self, request: ListUsersRequest) -> Result<Vec<User>>;
async fn bind(&self, request: BindRequest) -> Result<()>;
async fn list_users(&self, request: ListUsersRequest) -> Result<Vec<User>>;
}
}
#[cfg(test)]
mod tests {
use super::*;
#[tokio::test]
async fn test_bind_admin() {
let sql_pool = sqlx::any::AnyPoolOptions::new()
.connect("sqlite::memory:")
.await
.unwrap();
let mut config = Configuration::default();
config.ldap_user_dn = "admin".to_string();
config.ldap_user_pass = "test".to_string();
let handler = SqlBackendHandler::new(config, sql_pool);
assert!(true);
assert!(handler
.bind(BindRequest {
name: "admin".to_string(),
password: "test".to_string()
})
.await
.is_ok());
}
}

View file

@ -57,7 +57,7 @@ pub async fn init_table(pool: &AnyPool) -> sqlx::Result<()> {
.col(ColumnDef::new(Users::Password).string_len(255).not_null())
.col(ColumnDef::new(Users::TotpSecret).string_len(64))
.col(ColumnDef::new(Users::MfaType).string_len(64))
.to_string(MysqlQueryBuilder),
.to_string(SqliteQueryBuilder),
)
.execute(pool)
.await?;
@ -69,7 +69,6 @@ pub async fn init_table(pool: &AnyPool) -> sqlx::Result<()> {
ColumnDef::new(Groups::GroupId)
.integer()
.not_null()
.auto_increment()
.primary_key(),
)
.col(
@ -77,7 +76,7 @@ pub async fn init_table(pool: &AnyPool) -> sqlx::Result<()> {
.string_len(255)
.not_null(),
)
.to_string(MysqlQueryBuilder),
.to_string(SqliteQueryBuilder),
)
.execute(pool)
.await?;
@ -94,8 +93,7 @@ pub async fn init_table(pool: &AnyPool) -> sqlx::Result<()> {
.col(
ColumnDef::new(Memberships::GroupId)
.integer()
.not_null()
.auto_increment(),
.not_null(),
)
.foreign_key(
ForeignKey::create()
@ -113,7 +111,7 @@ pub async fn init_table(pool: &AnyPool) -> sqlx::Result<()> {
.on_delete(ForeignKeyAction::Cascade)
.on_update(ForeignKeyAction::Cascade),
)
.to_string(MysqlQueryBuilder),
.to_string(SqliteQueryBuilder),
)
.execute(pool)
.await?;

View file

@ -358,7 +358,6 @@ mod tests {
display_name: "Bôb Böbberson".to_string(),
first_name: "Bôb".to_string(),
last_name: "Böbberson".to_string(),
creation_date: NaiveDateTime::from_timestamp(1_000_000_000, 0),
},
User {
user_id: "jim".to_string(),
@ -366,7 +365,6 @@ mod tests {
display_name: "Jimminy Cricket".to_string(),
first_name: "Jim".to_string(),
last_name: "Cricket".to_string(),
creation_date: NaiveDateTime::from_timestamp(1_003_000_000, 0),
},
])
});