From 9e59056e7fdf7886ba31461543b5942089cca3fa Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Wed, 30 Nov 2022 14:18:46 +0100 Subject: [PATCH] Implement `db2::Database::get_user_by_github_account` --- crates/collab/src/db2.rs | 97 +++++++++++++++++++++++------ crates/collab/src/db2/tests.rs | 108 ++++++++++++++++----------------- 2 files changed, 133 insertions(+), 72 deletions(-) diff --git a/crates/collab/src/db2.rs b/crates/collab/src/db2.rs index 47ddf8cd22..1d50437a9c 100644 --- a/crates/collab/src/db2.rs +++ b/crates/collab/src/db2.rs @@ -13,11 +13,11 @@ use collections::HashMap; use dashmap::DashMap; use futures::StreamExt; use rpc::{proto, ConnectionId}; -use sea_orm::ActiveValue; use sea_orm::{ entity::prelude::*, ConnectOptions, DatabaseConnection, DatabaseTransaction, DbErr, TransactionTrait, }; +use sea_orm::{ActiveValue, IntoActiveModel}; use sea_query::OnConflict; use serde::{Deserialize, Serialize}; use sqlx::migrate::{Migrate, Migration, MigrationSource}; @@ -31,7 +31,7 @@ use tokio::sync::{Mutex, OwnedMutexGuard}; pub use user::Model as User; pub struct Database { - url: String, + options: ConnectOptions, pool: DatabaseConnection, rooms: DashMap>>, #[cfg(test)] @@ -41,11 +41,9 @@ pub struct Database { } impl Database { - pub async fn new(url: &str, max_connections: u32) -> Result { - let mut options = ConnectOptions::new(url.into()); - options.max_connections(max_connections); + pub async fn new(options: ConnectOptions) -> Result { Ok(Self { - url: url.into(), + options: options.clone(), pool: sea_orm::Database::connect(options).await?, rooms: DashMap::with_capacity(16384), #[cfg(test)] @@ -59,12 +57,12 @@ impl Database { &self, migrations_path: &Path, ignore_checksum_mismatch: bool, - ) -> anyhow::Result<(sqlx::AnyConnection, Vec<(Migration, Duration)>)> { + ) -> anyhow::Result> { let migrations = MigrationSource::resolve(migrations_path) .await .map_err(|err| anyhow!("failed to load migrations: {err:?}"))?; - let mut connection = sqlx::AnyConnection::connect(&self.url).await?; + let mut connection = sqlx::AnyConnection::connect(self.options.get_url()).await?; connection.ensure_migrations_table().await?; let applied_migrations: HashMap<_, _> = connection @@ -93,7 +91,7 @@ impl Database { } } - Ok((connection, new_migrations)) + Ok(new_migrations) } pub async fn create_user( @@ -142,6 +140,43 @@ impl Database { .await } + pub async fn get_user_by_github_account( + &self, + github_login: &str, + github_user_id: Option, + ) -> Result> { + self.transact(|tx| async { + let tx = tx; + if let Some(github_user_id) = github_user_id { + if let Some(user_by_github_user_id) = user::Entity::find() + .filter(user::Column::GithubUserId.eq(github_user_id)) + .one(&tx) + .await? + { + let mut user_by_github_user_id = user_by_github_user_id.into_active_model(); + user_by_github_user_id.github_login = ActiveValue::set(github_login.into()); + Ok(Some(user_by_github_user_id.update(&tx).await?)) + } else if let Some(user_by_github_login) = user::Entity::find() + .filter(user::Column::GithubLogin.eq(github_login)) + .one(&tx) + .await? + { + let mut user_by_github_login = user_by_github_login.into_active_model(); + user_by_github_login.github_user_id = ActiveValue::set(Some(github_user_id)); + Ok(Some(user_by_github_login.update(&tx).await?)) + } else { + Ok(None) + } + } else { + Ok(user::Entity::find() + .filter(user::Column::GithubLogin.eq(github_login)) + .one(&tx) + .await?) + } + }) + .await + } + pub async fn share_project( &self, room_id: RoomId, @@ -545,7 +580,9 @@ mod test { .unwrap(); let mut db = runtime.block_on(async { - let db = Database::new(&url, 5).await.unwrap(); + let mut options = ConnectOptions::new(url); + options.max_connections(5); + let db = Database::new(options).await.unwrap(); let sql = include_str!(concat!( env!("CARGO_MANIFEST_DIR"), "/migrations.sqlite/20221109000000_test_schema.sql" @@ -590,7 +627,11 @@ mod test { sqlx::Postgres::create_database(&url) .await .expect("failed to create test db"); - let db = Database::new(&url, 5).await.unwrap(); + let mut options = ConnectOptions::new(url); + options + .max_connections(5) + .idle_timeout(Duration::from_secs(0)); + let db = Database::new(options).await.unwrap(); let migrations_path = concat!(env!("CARGO_MANIFEST_DIR"), "/migrations"); db.migrate(Path::new(migrations_path), false).await.unwrap(); db @@ -610,11 +651,31 @@ mod test { } } - // TODO: Implement drop - // impl Drop for PostgresTestDb { - // fn drop(&mut self) { - // let db = self.db.take().unwrap(); - // db.teardown(&self.url); - // } - // } + impl Drop for TestDb { + fn drop(&mut self) { + let db = self.db.take().unwrap(); + if let sea_orm::DatabaseBackend::Postgres = db.pool.get_database_backend() { + db.runtime.as_ref().unwrap().block_on(async { + use util::ResultExt; + let query = " + SELECT pg_terminate_backend(pg_stat_activity.pid) + FROM pg_stat_activity + WHERE + pg_stat_activity.datname = current_database() AND + pid <> pg_backend_pid(); + "; + db.pool + .execute(sea_orm::Statement::from_string( + db.pool.get_database_backend(), + query.into(), + )) + .await + .log_err(); + sqlx::Postgres::drop_database(db.options.get_url()) + .await + .log_err(); + }) + } + } + } } diff --git a/crates/collab/src/db2/tests.rs b/crates/collab/src/db2/tests.rs index a5bac24140..60d3fa64b0 100644 --- a/crates/collab/src/db2/tests.rs +++ b/crates/collab/src/db2/tests.rs @@ -88,63 +88,63 @@ test_both_dbs!( } ); -// test_both_dbs!( -// test_get_user_by_github_account_postgres, -// test_get_user_by_github_account_sqlite, -// db, -// { -// let user_id1 = db -// .create_user( -// "user1@example.com", -// false, -// NewUserParams { -// github_login: "login1".into(), -// github_user_id: 101, -// invite_count: 0, -// }, -// ) -// .await -// .unwrap() -// .user_id; -// let user_id2 = db -// .create_user( -// "user2@example.com", -// false, -// NewUserParams { -// github_login: "login2".into(), -// github_user_id: 102, -// invite_count: 0, -// }, -// ) -// .await -// .unwrap() -// .user_id; +test_both_dbs!( + test_get_user_by_github_account_postgres, + test_get_user_by_github_account_sqlite, + db, + { + let user_id1 = db + .create_user( + "user1@example.com", + false, + NewUserParams { + github_login: "login1".into(), + github_user_id: 101, + invite_count: 0, + }, + ) + .await + .unwrap() + .user_id; + let user_id2 = db + .create_user( + "user2@example.com", + false, + NewUserParams { + github_login: "login2".into(), + github_user_id: 102, + invite_count: 0, + }, + ) + .await + .unwrap() + .user_id; -// let user = db -// .get_user_by_github_account("login1", None) -// .await -// .unwrap() -// .unwrap(); -// assert_eq!(user.id, user_id1); -// assert_eq!(&user.github_login, "login1"); -// assert_eq!(user.github_user_id, Some(101)); + let user = db + .get_user_by_github_account("login1", None) + .await + .unwrap() + .unwrap(); + assert_eq!(user.id, user_id1); + assert_eq!(&user.github_login, "login1"); + assert_eq!(user.github_user_id, Some(101)); -// assert!(db -// .get_user_by_github_account("non-existent-login", None) -// .await -// .unwrap() -// .is_none()); + assert!(db + .get_user_by_github_account("non-existent-login", None) + .await + .unwrap() + .is_none()); -// let user = db -// .get_user_by_github_account("the-new-login2", Some(102)) -// .await -// .unwrap() -// .unwrap(); -// assert_eq!(user.id, user_id2); -// assert_eq!(&user.github_login, "the-new-login2"); -// assert_eq!(user.github_user_id, Some(102)); -// } -// ); + let user = db + .get_user_by_github_account("the-new-login2", Some(102)) + .await + .unwrap() + .unwrap(); + assert_eq!(user.id, user_id2); + assert_eq!(&user.github_login, "the-new-login2"); + assert_eq!(user.github_user_id, Some(102)); + } +); // test_both_dbs!( // test_create_access_tokens_postgres,