diff --git a/crates/collab/src/db2.rs b/crates/collab/src/db2.rs index 1d50437a9c..e2a03931d8 100644 --- a/crates/collab/src/db2.rs +++ b/crates/collab/src/db2.rs @@ -1,3 +1,4 @@ +mod access_token; mod project; mod project_collaborator; mod room; @@ -17,8 +18,8 @@ use sea_orm::{ entity::prelude::*, ConnectOptions, DatabaseConnection, DatabaseTransaction, DbErr, TransactionTrait, }; -use sea_orm::{ActiveValue, IntoActiveModel}; -use sea_query::OnConflict; +use sea_orm::{ActiveValue, ConnectionTrait, IntoActiveModel, QueryOrder, QuerySelect}; +use sea_query::{OnConflict, Query}; use serde::{Deserialize, Serialize}; use sqlx::migrate::{Migrate, Migration, MigrationSource}; use sqlx::Connection; @@ -336,6 +337,63 @@ impl Database { }) } + pub async fn create_access_token_hash( + &self, + user_id: UserId, + access_token_hash: &str, + max_access_token_count: usize, + ) -> Result<()> { + self.transact(|tx| async { + let tx = tx; + + access_token::ActiveModel { + user_id: ActiveValue::set(user_id), + hash: ActiveValue::set(access_token_hash.into()), + ..Default::default() + } + .insert(&tx) + .await?; + + access_token::Entity::delete_many() + .filter( + access_token::Column::Id.in_subquery( + Query::select() + .column(access_token::Column::Id) + .from(access_token::Entity) + .and_where(access_token::Column::UserId.eq(user_id)) + .order_by(access_token::Column::Id, sea_orm::Order::Desc) + .limit(10000) + .offset(max_access_token_count as u64) + .to_owned(), + ), + ) + .exec(&tx) + .await?; + tx.commit().await?; + Ok(()) + }) + .await + } + + pub async fn get_access_token_hashes(&self, user_id: UserId) -> Result> { + #[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)] + enum QueryAs { + Hash, + } + + self.transact(|tx| async move { + Ok(access_token::Entity::find() + .select_only() + .column(access_token::Column::Hash) + .filter(access_token::Column::UserId.eq(user_id)) + .order_by_desc(access_token::Column::Id) + .into_values::<_, QueryAs>() + .all(&tx) + .await?) + }) + .await + } + async fn transact(&self, f: F) -> Result where F: Send + Fn(DatabaseTransaction) -> Fut, @@ -344,6 +402,16 @@ impl Database { let body = async { loop { let tx = self.pool.begin().await?; + + // In Postgres, serializable transactions are opt-in + if let sea_orm::DatabaseBackend::Postgres = self.pool.get_database_backend() { + tx.execute(sea_orm::Statement::from_string( + sea_orm::DatabaseBackend::Postgres, + "SET TRANSACTION ISOLATION LEVEL SERIALIZABLE;".into(), + )) + .await?; + } + match f(tx).await { Ok(result) => return Ok(result), Err(error) => match error { @@ -544,6 +612,7 @@ macro_rules! id_type { }; } +id_type!(AccessTokenId); id_type!(UserId); id_type!(RoomId); id_type!(RoomParticipantId); diff --git a/crates/collab/src/db2/access_token.rs b/crates/collab/src/db2/access_token.rs new file mode 100644 index 0000000000..f5caa4843d --- /dev/null +++ b/crates/collab/src/db2/access_token.rs @@ -0,0 +1,29 @@ +use super::{AccessTokenId, UserId}; +use sea_orm::entity::prelude::*; + +#[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel)] +#[sea_orm(table_name = "access_tokens")] +pub struct Model { + #[sea_orm(primary_key)] + pub id: AccessTokenId, + pub user_id: UserId, + pub hash: String, +} + +#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] +pub enum Relation { + #[sea_orm( + belongs_to = "super::user::Entity", + from = "Column::UserId", + to = "super::user::Column::Id" + )] + User, +} + +impl Related for Entity { + fn to() -> RelationDef { + Relation::User.def() + } +} + +impl ActiveModelBehavior for ActiveModel {} diff --git a/crates/collab/src/db2/tests.rs b/crates/collab/src/db2/tests.rs index 60d3fa64b0..e26ffee7a8 100644 --- a/crates/collab/src/db2/tests.rs +++ b/crates/collab/src/db2/tests.rs @@ -146,51 +146,51 @@ test_both_dbs!( } ); -// test_both_dbs!( -// test_create_access_tokens_postgres, -// test_create_access_tokens_sqlite, -// db, -// { -// let user = db -// .create_user( -// "u1@example.com", -// false, -// NewUserParams { -// github_login: "u1".into(), -// github_user_id: 1, -// invite_count: 0, -// }, -// ) -// .await -// .unwrap() -// .user_id; +test_both_dbs!( + test_create_access_tokens_postgres, + test_create_access_tokens_sqlite, + db, + { + let user = db + .create_user( + "u1@example.com", + false, + NewUserParams { + github_login: "u1".into(), + github_user_id: 1, + invite_count: 0, + }, + ) + .await + .unwrap() + .user_id; -// db.create_access_token_hash(user, "h1", 3).await.unwrap(); -// db.create_access_token_hash(user, "h2", 3).await.unwrap(); -// assert_eq!( -// db.get_access_token_hashes(user).await.unwrap(), -// &["h2".to_string(), "h1".to_string()] -// ); + db.create_access_token_hash(user, "h1", 3).await.unwrap(); + db.create_access_token_hash(user, "h2", 3).await.unwrap(); + assert_eq!( + db.get_access_token_hashes(user).await.unwrap(), + &["h2".to_string(), "h1".to_string()] + ); -// db.create_access_token_hash(user, "h3", 3).await.unwrap(); -// assert_eq!( -// db.get_access_token_hashes(user).await.unwrap(), -// &["h3".to_string(), "h2".to_string(), "h1".to_string(),] -// ); + db.create_access_token_hash(user, "h3", 3).await.unwrap(); + assert_eq!( + db.get_access_token_hashes(user).await.unwrap(), + &["h3".to_string(), "h2".to_string(), "h1".to_string(),] + ); -// db.create_access_token_hash(user, "h4", 3).await.unwrap(); -// assert_eq!( -// db.get_access_token_hashes(user).await.unwrap(), -// &["h4".to_string(), "h3".to_string(), "h2".to_string(),] -// ); + db.create_access_token_hash(user, "h4", 3).await.unwrap(); + assert_eq!( + db.get_access_token_hashes(user).await.unwrap(), + &["h4".to_string(), "h3".to_string(), "h2".to_string(),] + ); -// db.create_access_token_hash(user, "h5", 3).await.unwrap(); -// assert_eq!( -// db.get_access_token_hashes(user).await.unwrap(), -// &["h5".to_string(), "h4".to_string(), "h3".to_string()] -// ); -// } -// ); + db.create_access_token_hash(user, "h5", 3).await.unwrap(); + assert_eq!( + db.get_access_token_hashes(user).await.unwrap(), + &["h5".to_string(), "h4".to_string(), "h3".to_string()] + ); + } +); // test_both_dbs!(test_add_contacts_postgres, test_add_contacts_sqlite, db, { // let mut user_ids = Vec::new(); diff --git a/crates/collab/src/db2/user.rs b/crates/collab/src/db2/user.rs index a0e21f9811..5e8a484571 100644 --- a/crates/collab/src/db2/user.rs +++ b/crates/collab/src/db2/user.rs @@ -17,6 +17,15 @@ pub struct Model { } #[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] -pub enum Relation {} +pub enum Relation { + #[sea_orm(has_many = "super::access_token::Entity")] + AccessToken, +} + +impl Related for Entity { + fn to() -> RelationDef { + Relation::AccessToken.def() + } +} impl ActiveModelBehavior for ActiveModel {}