diff --git a/crates/collab/src/db.rs b/crates/collab/src/db.rs index fd1ed7d50f..e667930cad 100644 --- a/crates/collab/src/db.rs +++ b/crates/collab/src/db.rs @@ -106,10 +106,10 @@ impl Database { } pub async fn clear_stale_data(&self) -> Result<()> { - self.transact(|tx| async { + self.transaction(|tx| async move { project_collaborator::Entity::delete_many() .filter(project_collaborator::Column::ConnectionEpoch.ne(self.epoch)) - .exec(&tx) + .exec(&*tx) .await?; room_participant::Entity::delete_many() .filter( @@ -117,11 +117,11 @@ impl Database { .ne(self.epoch) .or(room_participant::Column::CallingConnectionEpoch.ne(self.epoch)), ) - .exec(&tx) + .exec(&*tx) .await?; project::Entity::delete_many() .filter(project::Column::HostConnectionEpoch.ne(self.epoch)) - .exec(&tx) + .exec(&*tx) .await?; room::Entity::delete_many() .filter( @@ -133,9 +133,8 @@ impl Database { .to_owned(), ), ) - .exec(&tx) + .exec(&*tx) .await?; - tx.commit().await?; Ok(()) }) .await @@ -149,7 +148,8 @@ impl Database { admin: bool, params: NewUserParams, ) -> Result { - self.transact(|tx| async { + self.transaction(|tx| async { + let tx = tx; let user = user::Entity::insert(user::ActiveModel { email_address: ActiveValue::set(Some(email_address.into())), github_login: ActiveValue::set(params.github_login.clone()), @@ -163,11 +163,9 @@ impl Database { .update_column(user::Column::GithubLogin) .to_owned(), ) - .exec_with_returning(&tx) + .exec_with_returning(&*tx) .await?; - tx.commit().await?; - Ok(NewUserResult { user_id: user.id, metrics_id: user.metrics_id.to_string(), @@ -179,16 +177,16 @@ impl Database { } pub async fn get_user_by_id(&self, id: UserId) -> Result> { - self.transact(|tx| async move { Ok(user::Entity::find_by_id(id).one(&tx).await?) }) + self.transaction(|tx| async move { Ok(user::Entity::find_by_id(id).one(&*tx).await?) }) .await } pub async fn get_users_by_ids(&self, ids: Vec) -> Result> { - self.transact(|tx| async { + self.transaction(|tx| async { let tx = tx; Ok(user::Entity::find() .filter(user::Column::Id.is_in(ids.iter().copied())) - .all(&tx) + .all(&*tx) .await?) }) .await @@ -199,32 +197,32 @@ impl Database { github_login: &str, github_user_id: Option, ) -> Result> { - self.transact(|tx| async { - let tx = tx; + self.transaction(|tx| async move { + let tx = &*tx; if let Some(github_user_id) = github_user_id { if let Some(user_by_github_user_id) = user::Entity::find() .filter(user::Column::GithubUserId.eq(github_user_id)) - .one(&tx) + .one(tx) .await? { let mut user_by_github_user_id = user_by_github_user_id.into_active_model(); user_by_github_user_id.github_login = ActiveValue::set(github_login.into()); - Ok(Some(user_by_github_user_id.update(&tx).await?)) + Ok(Some(user_by_github_user_id.update(tx).await?)) } else if let Some(user_by_github_login) = user::Entity::find() .filter(user::Column::GithubLogin.eq(github_login)) - .one(&tx) + .one(tx) .await? { let mut user_by_github_login = user_by_github_login.into_active_model(); user_by_github_login.github_user_id = ActiveValue::set(Some(github_user_id)); - Ok(Some(user_by_github_login.update(&tx).await?)) + Ok(Some(user_by_github_login.update(tx).await?)) } else { Ok(None) } } else { Ok(user::Entity::find() .filter(user::Column::GithubLogin.eq(github_login)) - .one(&tx) + .one(tx) .await?) } }) @@ -232,12 +230,12 @@ impl Database { } pub async fn get_all_users(&self, page: u32, limit: u32) -> Result> { - self.transact(|tx| async move { + self.transaction(|tx| async move { Ok(user::Entity::find() .order_by_asc(user::Column::GithubLogin) .limit(limit as u64) .offset(page as u64 * limit as u64) - .all(&tx) + .all(&*tx) .await?) }) .await @@ -247,7 +245,7 @@ impl Database { &self, invited_by_another_user: bool, ) -> Result> { - self.transact(|tx| async move { + self.transaction(|tx| async move { Ok(user::Entity::find() .filter( user::Column::InviteCount @@ -258,7 +256,7 @@ impl Database { user::Column::InviterId.is_null() }), ) - .all(&tx) + .all(&*tx) .await?) }) .await @@ -270,12 +268,12 @@ impl Database { MetricsId, } - self.transact(|tx| async move { + self.transaction(|tx| async move { let metrics_id: Uuid = user::Entity::find_by_id(id) .select_only() .column(user::Column::MetricsId) .into_values::<_, QueryAs>() - .one(&tx) + .one(&*tx) .await? .ok_or_else(|| anyhow!("could not find user"))?; Ok(metrics_id.to_string()) @@ -284,45 +282,42 @@ impl Database { } pub async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()> { - self.transact(|tx| async move { + self.transaction(|tx| async move { user::Entity::update_many() .filter(user::Column::Id.eq(id)) .set(user::ActiveModel { admin: ActiveValue::set(is_admin), ..Default::default() }) - .exec(&tx) + .exec(&*tx) .await?; - tx.commit().await?; Ok(()) }) .await } pub async fn set_user_connected_once(&self, id: UserId, connected_once: bool) -> Result<()> { - self.transact(|tx| async move { + self.transaction(|tx| async move { user::Entity::update_many() .filter(user::Column::Id.eq(id)) .set(user::ActiveModel { connected_once: ActiveValue::set(connected_once), ..Default::default() }) - .exec(&tx) + .exec(&*tx) .await?; - tx.commit().await?; Ok(()) }) .await } pub async fn destroy_user(&self, id: UserId) -> Result<()> { - self.transact(|tx| async move { + self.transaction(|tx| async move { access_token::Entity::delete_many() .filter(access_token::Column::UserId.eq(id)) - .exec(&tx) + .exec(&*tx) .await?; - user::Entity::delete_by_id(id).exec(&tx).await?; - tx.commit().await?; + user::Entity::delete_by_id(id).exec(&*tx).await?; Ok(()) }) .await @@ -342,7 +337,7 @@ impl Database { user_b_busy: bool, } - self.transact(|tx| async move { + self.transaction(|tx| async move { let user_a_participant = Alias::new("user_a_participant"); let user_b_participant = Alias::new("user_b_participant"); let mut db_contacts = contact::Entity::find() @@ -372,7 +367,7 @@ impl Database { user_b_participant, ) .into_model::() - .stream(&tx) + .stream(&*tx) .await?; let mut contacts = Vec::new(); @@ -421,10 +416,10 @@ impl Database { } pub async fn is_user_busy(&self, user_id: UserId) -> Result { - self.transact(|tx| async move { + self.transaction(|tx| async move { let participant = room_participant::Entity::find() .filter(room_participant::Column::UserId.eq(user_id)) - .one(&tx) + .one(&*tx) .await?; Ok(participant.is_some()) }) @@ -432,7 +427,7 @@ impl Database { } pub async fn has_contact(&self, user_id_1: UserId, user_id_2: UserId) -> Result { - self.transact(|tx| async move { + self.transaction(|tx| async move { let (id_a, id_b) = if user_id_1 < user_id_2 { (user_id_1, user_id_2) } else { @@ -446,7 +441,7 @@ impl Database { .and(contact::Column::UserIdB.eq(id_b)) .and(contact::Column::Accepted.eq(true)), ) - .one(&tx) + .one(&*tx) .await? .is_some()) }) @@ -454,7 +449,7 @@ impl Database { } pub async fn send_contact_request(&self, sender_id: UserId, receiver_id: UserId) -> Result<()> { - self.transact(|tx| async move { + self.transaction(|tx| async move { let (id_a, id_b, a_to_b) = if sender_id < receiver_id { (sender_id, receiver_id, true) } else { @@ -487,11 +482,10 @@ impl Database { ) .to_owned(), ) - .exec_without_returning(&tx) + .exec_without_returning(&*tx) .await?; if rows_affected == 1 { - tx.commit().await?; Ok(()) } else { Err(anyhow!("contact already requested"))? @@ -501,7 +495,7 @@ impl Database { } pub async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()> { - self.transact(|tx| async move { + self.transaction(|tx| async move { let (id_a, id_b) = if responder_id < requester_id { (responder_id, requester_id) } else { @@ -514,11 +508,10 @@ impl Database { .eq(id_a) .and(contact::Column::UserIdB.eq(id_b)), ) - .exec(&tx) + .exec(&*tx) .await?; if result.rows_affected == 1 { - tx.commit().await?; Ok(()) } else { Err(anyhow!("no such contact"))? @@ -532,7 +525,7 @@ impl Database { user_id: UserId, contact_user_id: UserId, ) -> Result<()> { - self.transact(|tx| async move { + self.transaction(|tx| async move { let (id_a, id_b, a_to_b) = if user_id < contact_user_id { (user_id, contact_user_id, true) } else { @@ -557,12 +550,11 @@ impl Database { .and(contact::Column::Accepted.eq(false))), ), ) - .exec(&tx) + .exec(&*tx) .await?; if result.rows_affected == 0 { Err(anyhow!("no such contact request"))? } else { - tx.commit().await?; Ok(()) } }) @@ -575,7 +567,7 @@ impl Database { requester_id: UserId, accept: bool, ) -> Result<()> { - self.transact(|tx| async move { + self.transaction(|tx| async move { let (id_a, id_b, a_to_b) = if responder_id < requester_id { (responder_id, requester_id, false) } else { @@ -594,7 +586,7 @@ impl Database { .and(contact::Column::UserIdB.eq(id_b)) .and(contact::Column::AToB.eq(a_to_b)), ) - .exec(&tx) + .exec(&*tx) .await?; result.rows_affected } else { @@ -606,14 +598,13 @@ impl Database { .and(contact::Column::AToB.eq(a_to_b)) .and(contact::Column::Accepted.eq(false)), ) - .exec(&tx) + .exec(&*tx) .await?; result.rows_affected }; if rows_affected == 1 { - tx.commit().await?; Ok(()) } else { Err(anyhow!("no such contact request"))? @@ -635,7 +626,7 @@ impl Database { } pub async fn fuzzy_search_users(&self, name_query: &str, limit: u32) -> Result> { - self.transact(|tx| async { + self.transaction(|tx| async { let tx = tx; let like_string = Self::fuzzy_like_string(name_query); let query = " @@ -652,7 +643,7 @@ impl Database { query.into(), vec![like_string.into(), name_query.into(), limit.into()], )) - .all(&tx) + .all(&*tx) .await?) }) .await @@ -661,7 +652,7 @@ impl Database { // signups pub async fn create_signup(&self, signup: &NewSignup) -> Result<()> { - self.transact(|tx| async { + self.transaction(|tx| async move { signup::Entity::insert(signup::ActiveModel { email_address: ActiveValue::set(signup.email_address.clone()), email_confirmation_code: ActiveValue::set(random_email_confirmation_code()), @@ -681,16 +672,15 @@ impl Database { .update_column(signup::Column::EmailAddress) .to_owned(), ) - .exec(&tx) + .exec(&*tx) .await?; - tx.commit().await?; Ok(()) }) .await } pub async fn get_waitlist_summary(&self) -> Result { - self.transact(|tx| async move { + self.transaction(|tx| async move { let query = " SELECT COUNT(*) as count, @@ -711,7 +701,7 @@ impl Database { query.into(), vec![], )) - .one(&tx) + .one(&*tx) .await? .ok_or_else(|| anyhow!("invalid result"))?, ) @@ -724,23 +714,23 @@ impl Database { .iter() .map(|s| s.email_address.as_str()) .collect::>(); - self.transact(|tx| async { + self.transaction(|tx| async { + let tx = tx; signup::Entity::update_many() .filter(signup::Column::EmailAddress.is_in(emails.iter().copied())) .set(signup::ActiveModel { email_confirmation_sent: ActiveValue::set(true), ..Default::default() }) - .exec(&tx) + .exec(&*tx) .await?; - tx.commit().await?; Ok(()) }) .await } pub async fn get_unsent_invites(&self, count: usize) -> Result> { - self.transact(|tx| async move { + self.transaction(|tx| async move { Ok(signup::Entity::find() .select_only() .column(signup::Column::EmailAddress) @@ -755,7 +745,7 @@ impl Database { .order_by_asc(signup::Column::CreatedAt) .limit(count as u64) .into_model() - .all(&tx) + .all(&*tx) .await?) }) .await @@ -769,10 +759,10 @@ impl Database { email_address: &str, device_id: Option<&str>, ) -> Result { - self.transact(|tx| async move { + self.transaction(|tx| async move { let existing_user = user::Entity::find() .filter(user::Column::EmailAddress.eq(email_address)) - .one(&tx) + .one(&*tx) .await?; if existing_user.is_some() { @@ -785,7 +775,7 @@ impl Database { .eq(code) .and(user::Column::InviteCount.gt(0)), ) - .one(&tx) + .one(&*tx) .await? { Some(inviting_user) => inviting_user, @@ -806,7 +796,7 @@ impl Database { user::Column::InviteCount, Expr::col(user::Column::InviteCount).sub(1), ) - .exec(&tx) + .exec(&*tx) .await?; let signup = signup::Entity::insert(signup::ActiveModel { @@ -826,9 +816,8 @@ impl Database { .update_column(signup::Column::InvitingUserId) .to_owned(), ) - .exec_with_returning(&tx) + .exec_with_returning(&*tx) .await?; - tx.commit().await?; Ok(Invite { email_address: signup.email_address, @@ -843,7 +832,7 @@ impl Database { invite: &Invite, user: NewUserParams, ) -> Result> { - self.transact(|tx| async { + self.transaction(|tx| async { let tx = tx; let signup = signup::Entity::find() .filter( @@ -854,7 +843,7 @@ impl Database { .eq(invite.email_confirmation_code.as_str()), ), ) - .one(&tx) + .one(&*tx) .await? .ok_or_else(|| Error::Http(StatusCode::NOT_FOUND, "no such invite".to_string()))?; @@ -881,12 +870,12 @@ impl Database { ]) .to_owned(), ) - .exec_with_returning(&tx) + .exec_with_returning(&*tx) .await?; let mut signup = signup.into_active_model(); signup.user_id = ActiveValue::set(Some(user.id)); - let signup = signup.update(&tx).await?; + let signup = signup.update(&*tx).await?; if let Some(inviting_user_id) = signup.inviting_user_id { contact::Entity::insert(contact::ActiveModel { @@ -898,11 +887,10 @@ impl Database { ..Default::default() }) .on_conflict(OnConflict::new().do_nothing().to_owned()) - .exec_without_returning(&tx) + .exec_without_returning(&*tx) .await?; } - tx.commit().await?; Ok(Some(NewUserResult { user_id: user.id, metrics_id: user.metrics_id.to_string(), @@ -914,7 +902,7 @@ impl Database { } pub async fn set_invite_count_for_user(&self, id: UserId, count: i32) -> Result<()> { - self.transact(|tx| async move { + self.transaction(|tx| async move { if count > 0 { user::Entity::update_many() .filter( @@ -926,7 +914,7 @@ impl Database { invite_code: ActiveValue::set(Some(random_invite_code())), ..Default::default() }) - .exec(&tx) + .exec(&*tx) .await?; } @@ -936,17 +924,16 @@ impl Database { invite_count: ActiveValue::set(count), ..Default::default() }) - .exec(&tx) + .exec(&*tx) .await?; - tx.commit().await?; Ok(()) }) .await } pub async fn get_invite_code_for_user(&self, id: UserId) -> Result> { - self.transact(|tx| async move { - match user::Entity::find_by_id(id).one(&tx).await? { + self.transaction(|tx| async move { + match user::Entity::find_by_id(id).one(&*tx).await? { Some(user) if user.invite_code.is_some() => { Ok(Some((user.invite_code.unwrap(), user.invite_count))) } @@ -957,10 +944,10 @@ impl Database { } pub async fn get_user_for_invite_code(&self, code: &str) -> Result { - self.transact(|tx| async move { + self.transaction(|tx| async move { user::Entity::find() .filter(user::Column::InviteCode.eq(code)) - .one(&tx) + .one(&*tx) .await? .ok_or_else(|| { Error::Http( @@ -978,14 +965,14 @@ impl Database { &self, user_id: UserId, ) -> Result> { - self.transact(|tx| async move { + self.transaction(|tx| async move { let pending_participant = room_participant::Entity::find() .filter( room_participant::Column::UserId .eq(user_id) .and(room_participant::Column::AnsweringConnectionId.is_null()), ) - .one(&tx) + .one(&*tx) .await?; if let Some(pending_participant) = pending_participant { @@ -1004,12 +991,12 @@ impl Database { connection_id: ConnectionId, live_kit_room: &str, ) -> Result> { - self.transact(|tx| async move { + self.room_transaction(|tx| async move { let room = room::ActiveModel { live_kit_room: ActiveValue::set(live_kit_room.into()), ..Default::default() } - .insert(&tx) + .insert(&*tx) .await?; let room_id = room.id; @@ -1023,11 +1010,11 @@ impl Database { calling_connection_epoch: ActiveValue::set(self.epoch), ..Default::default() } - .insert(&tx) + .insert(&*tx) .await?; let room = self.get_room(room_id, &tx).await?; - self.commit_room_transaction(room_id, tx, room).await + Ok((room_id, room)) }) .await } @@ -1040,7 +1027,7 @@ impl Database { called_user_id: UserId, initial_project_id: Option, ) -> Result> { - self.transact(|tx| async move { + self.room_transaction(|tx| async move { room_participant::ActiveModel { room_id: ActiveValue::set(room_id), user_id: ActiveValue::set(called_user_id), @@ -1050,14 +1037,13 @@ impl Database { initial_project_id: ActiveValue::set(initial_project_id), ..Default::default() } - .insert(&tx) + .insert(&*tx) .await?; let room = self.get_room(room_id, &tx).await?; let incoming_call = Self::build_incoming_call(&room, called_user_id) .ok_or_else(|| anyhow!("failed to build incoming call"))?; - self.commit_room_transaction(room_id, tx, (room, incoming_call)) - .await + Ok((room_id, (room, incoming_call))) }) .await } @@ -1067,17 +1053,17 @@ impl Database { room_id: RoomId, called_user_id: UserId, ) -> Result> { - self.transact(|tx| async move { + self.room_transaction(|tx| async move { room_participant::Entity::delete_many() .filter( room_participant::Column::RoomId .eq(room_id) .and(room_participant::Column::UserId.eq(called_user_id)), ) - .exec(&tx) + .exec(&*tx) .await?; let room = self.get_room(room_id, &tx).await?; - self.commit_room_transaction(room_id, tx, room).await + Ok((room_id, room)) }) .await } @@ -1087,14 +1073,14 @@ impl Database { expected_room_id: Option, user_id: UserId, ) -> Result> { - self.transact(|tx| async move { + self.room_transaction(|tx| async move { let participant = room_participant::Entity::find() .filter( room_participant::Column::UserId .eq(user_id) .and(room_participant::Column::AnsweringConnectionId.is_null()), ) - .one(&tx) + .one(&*tx) .await? .ok_or_else(|| anyhow!("could not decline call"))?; let room_id = participant.room_id; @@ -1104,11 +1090,11 @@ impl Database { } room_participant::Entity::delete(participant.into_active_model()) - .exec(&tx) + .exec(&*tx) .await?; let room = self.get_room(room_id, &tx).await?; - self.commit_room_transaction(room_id, tx, room).await + Ok((room_id, room)) }) .await } @@ -1119,7 +1105,7 @@ impl Database { calling_connection_id: ConnectionId, called_user_id: UserId, ) -> Result> { - self.transact(|tx| async move { + self.room_transaction(|tx| async move { let participant = room_participant::Entity::find() .filter( room_participant::Column::UserId @@ -1130,7 +1116,7 @@ impl Database { ) .and(room_participant::Column::AnsweringConnectionId.is_null()), ) - .one(&tx) + .one(&*tx) .await? .ok_or_else(|| anyhow!("could not cancel call"))?; let room_id = participant.room_id; @@ -1139,11 +1125,11 @@ impl Database { } room_participant::Entity::delete(participant.into_active_model()) - .exec(&tx) + .exec(&*tx) .await?; let room = self.get_room(room_id, &tx).await?; - self.commit_room_transaction(room_id, tx, room).await + Ok((room_id, room)) }) .await } @@ -1154,7 +1140,7 @@ impl Database { user_id: UserId, connection_id: ConnectionId, ) -> Result> { - self.transact(|tx| async move { + self.room_transaction(|tx| async move { let result = room_participant::Entity::update_many() .filter( room_participant::Column::RoomId @@ -1167,33 +1153,30 @@ impl Database { answering_connection_epoch: ActiveValue::set(Some(self.epoch)), ..Default::default() }) - .exec(&tx) + .exec(&*tx) .await?; if result.rows_affected == 0 { Err(anyhow!("room does not exist or was already joined"))? } else { let room = self.get_room(room_id, &tx).await?; - self.commit_room_transaction(room_id, tx, room).await + Ok((room_id, room)) } }) .await } - pub async fn leave_room( - &self, - connection_id: ConnectionId, - ) -> Result>> { - self.transact(|tx| async move { + pub async fn leave_room(&self, connection_id: ConnectionId) -> Result> { + self.room_transaction(|tx| async move { let leaving_participant = room_participant::Entity::find() .filter(room_participant::Column::AnsweringConnectionId.eq(connection_id.0)) - .one(&tx) + .one(&*tx) .await?; if let Some(leaving_participant) = leaving_participant { // Leave room. let room_id = leaving_participant.room_id; room_participant::Entity::delete_by_id(leaving_participant.id) - .exec(&tx) + .exec(&*tx) .await?; // Cancel pending calls initiated by the leaving user. @@ -1203,14 +1186,14 @@ impl Database { .eq(connection_id.0) .and(room_participant::Column::AnsweringConnectionId.is_null()), ) - .all(&tx) + .all(&*tx) .await?; room_participant::Entity::delete_many() .filter( room_participant::Column::Id .is_in(called_participants.iter().map(|participant| participant.id)), ) - .exec(&tx) + .exec(&*tx) .await?; let canceled_calls_to_user_ids = called_participants .into_iter() @@ -1230,12 +1213,12 @@ impl Database { ) .filter(project_collaborator::Column::ConnectionId.eq(connection_id.0)) .into_values::<_, QueryProjectIds>() - .all(&tx) + .all(&*tx) .await?; let mut left_projects = HashMap::default(); let mut collaborators = project_collaborator::Entity::find() .filter(project_collaborator::Column::ProjectId.is_in(project_ids)) - .stream(&tx) + .stream(&*tx) .await?; while let Some(collaborator) = collaborators.next().await { let collaborator = collaborator?; @@ -1266,7 +1249,7 @@ impl Database { // Leave projects. project_collaborator::Entity::delete_many() .filter(project_collaborator::Column::ConnectionId.eq(connection_id.0)) - .exec(&tx) + .exec(&*tx) .await?; // Unshare projects. @@ -1276,33 +1259,27 @@ impl Database { .eq(room_id) .and(project::Column::HostConnectionId.eq(connection_id.0)), ) - .exec(&tx) + .exec(&*tx) .await?; let room = self.get_room(room_id, &tx).await?; if room.participants.is_empty() { - room::Entity::delete_by_id(room_id).exec(&tx).await?; + room::Entity::delete_by_id(room_id).exec(&*tx).await?; } - let left_room = self - .commit_room_transaction( - room_id, - tx, - LeftRoom { - room, - left_projects, - canceled_calls_to_user_ids, - }, - ) - .await?; + let left_room = LeftRoom { + room, + left_projects, + canceled_calls_to_user_ids, + }; if left_room.room.participants.is_empty() { self.rooms.remove(&room_id); } - Ok(Some(left_room)) + Ok((room_id, left_room)) } else { - Ok(None) + Err(anyhow!("could not leave room"))? } }) .await @@ -1314,8 +1291,8 @@ impl Database { connection_id: ConnectionId, location: proto::ParticipantLocation, ) -> Result> { - self.transact(|tx| async { - let mut tx = tx; + self.room_transaction(|tx| async { + let tx = tx; let location_kind; let location_project_id; match location @@ -1348,12 +1325,12 @@ impl Database { location_project_id: ActiveValue::set(location_project_id), ..Default::default() }) - .exec(&tx) + .exec(&*tx) .await?; if result.rows_affected == 1 { - let room = self.get_room(room_id, &mut tx).await?; - self.commit_room_transaction(room_id, tx, room).await + let room = self.get_room(room_id, &tx).await?; + Ok((room_id, room)) } else { Err(anyhow!("could not update room participant location"))? } @@ -1478,22 +1455,6 @@ impl Database { }) } - async fn commit_room_transaction( - &self, - room_id: RoomId, - tx: DatabaseTransaction, - data: T, - ) -> Result> { - let lock = self.rooms.entry(room_id).or_default().clone(); - let _guard = lock.lock_owned().await; - tx.commit().await?; - Ok(RoomGuard { - data, - _guard, - _not_send: PhantomData, - }) - } - // projects pub async fn project_count_excluding_admins(&self) -> Result { @@ -1502,14 +1463,14 @@ impl Database { Count, } - self.transact(|tx| async move { + self.transaction(|tx| async move { Ok(project::Entity::find() .select_only() .column_as(project::Column::Id.count(), QueryAs::Count) .inner_join(user::Entity) .filter(user::Column::Admin.eq(false)) .into_values::<_, QueryAs>() - .one(&tx) + .one(&*tx) .await? .unwrap_or(0) as usize) }) @@ -1522,10 +1483,10 @@ impl Database { connection_id: ConnectionId, worktrees: &[proto::WorktreeMetadata], ) -> Result> { - self.transact(|tx| async move { + self.room_transaction(|tx| async move { let participant = room_participant::Entity::find() .filter(room_participant::Column::AnsweringConnectionId.eq(connection_id.0)) - .one(&tx) + .one(&*tx) .await? .ok_or_else(|| anyhow!("could not find participant"))?; if participant.room_id != room_id { @@ -1539,7 +1500,7 @@ impl Database { host_connection_epoch: ActiveValue::set(self.epoch), ..Default::default() } - .insert(&tx) + .insert(&*tx) .await?; if !worktrees.is_empty() { @@ -1554,7 +1515,7 @@ impl Database { is_complete: ActiveValue::set(false), } })) - .exec(&tx) + .exec(&*tx) .await?; } @@ -1567,12 +1528,11 @@ impl Database { is_host: ActiveValue::set(true), ..Default::default() } - .insert(&tx) + .insert(&*tx) .await?; let room = self.get_room(room_id, &tx).await?; - self.commit_room_transaction(room_id, tx, (project.id, room)) - .await + Ok((room_id, (project.id, room))) }) .await } @@ -1582,21 +1542,20 @@ impl Database { project_id: ProjectId, connection_id: ConnectionId, ) -> Result)>> { - self.transact(|tx| async move { + self.room_transaction(|tx| async move { let guest_connection_ids = self.project_guest_connection_ids(project_id, &tx).await?; let project = project::Entity::find_by_id(project_id) - .one(&tx) + .one(&*tx) .await? .ok_or_else(|| anyhow!("project not found"))?; if project.host_connection_id == connection_id.0 as i32 { let room_id = project.room_id; project::Entity::delete(project.into_active_model()) - .exec(&tx) + .exec(&*tx) .await?; let room = self.get_room(room_id, &tx).await?; - self.commit_room_transaction(room_id, tx, (room, guest_connection_ids)) - .await + Ok((room_id, (room, guest_connection_ids))) } else { Err(anyhow!("cannot unshare a project hosted by another user"))? } @@ -1610,10 +1569,10 @@ impl Database { connection_id: ConnectionId, worktrees: &[proto::WorktreeMetadata], ) -> Result)>> { - self.transact(|tx| async move { + self.room_transaction(|tx| async move { let project = project::Entity::find_by_id(project_id) .filter(project::Column::HostConnectionId.eq(connection_id.0)) - .one(&tx) + .one(&*tx) .await? .ok_or_else(|| anyhow!("no such project"))?; @@ -1634,7 +1593,7 @@ impl Database { .update_column(worktree::Column::RootName) .to_owned(), ) - .exec(&tx) + .exec(&*tx) .await?; } @@ -1645,13 +1604,12 @@ impl Database { .is_not_in(worktrees.iter().map(|worktree| worktree.id as i64)), ), ) - .exec(&tx) + .exec(&*tx) .await?; let guest_connection_ids = self.project_guest_connection_ids(project.id, &tx).await?; let room = self.get_room(project.room_id, &tx).await?; - self.commit_room_transaction(project.room_id, tx, (room, guest_connection_ids)) - .await + Ok((project.room_id, (room, guest_connection_ids))) }) .await } @@ -1661,14 +1619,14 @@ impl Database { update: &proto::UpdateWorktree, connection_id: ConnectionId, ) -> Result>> { - self.transact(|tx| async move { + self.room_transaction(|tx| async move { let project_id = ProjectId::from_proto(update.project_id); let worktree_id = update.worktree_id as i64; // Ensure the update comes from the host. let project = project::Entity::find_by_id(project_id) .filter(project::Column::HostConnectionId.eq(connection_id.0)) - .one(&tx) + .one(&*tx) .await? .ok_or_else(|| anyhow!("no such project"))?; let room_id = project.room_id; @@ -1683,7 +1641,7 @@ impl Database { abs_path: ActiveValue::set(update.abs_path.clone()), ..Default::default() }) - .exec(&tx) + .exec(&*tx) .await?; if !update.updated_entries.is_empty() { @@ -1719,7 +1677,7 @@ impl Database { ]) .to_owned(), ) - .exec(&tx) + .exec(&*tx) .await?; } @@ -1734,13 +1692,12 @@ impl Database { .is_in(update.removed_entries.iter().map(|id| *id as i64)), ), ) - .exec(&tx) + .exec(&*tx) .await?; } let connection_ids = self.project_guest_connection_ids(project_id, &tx).await?; - self.commit_room_transaction(room_id, tx, connection_ids) - .await + Ok((room_id, connection_ids)) }) .await } @@ -1750,7 +1707,7 @@ impl Database { update: &proto::UpdateDiagnosticSummary, connection_id: ConnectionId, ) -> Result>> { - self.transact(|tx| async { + self.room_transaction(|tx| async move { let project_id = ProjectId::from_proto(update.project_id); let worktree_id = update.worktree_id as i64; let summary = update @@ -1760,7 +1717,7 @@ impl Database { // Ensure the update comes from the host. let project = project::Entity::find_by_id(project_id) - .one(&tx) + .one(&*tx) .await? .ok_or_else(|| anyhow!("no such project"))?; if project.host_connection_id != connection_id.0 as i32 { @@ -1790,12 +1747,11 @@ impl Database { ]) .to_owned(), ) - .exec(&tx) + .exec(&*tx) .await?; let connection_ids = self.project_guest_connection_ids(project_id, &tx).await?; - self.commit_room_transaction(project.room_id, tx, connection_ids) - .await + Ok((project.room_id, connection_ids)) }) .await } @@ -1805,7 +1761,7 @@ impl Database { update: &proto::StartLanguageServer, connection_id: ConnectionId, ) -> Result>> { - self.transact(|tx| async { + self.room_transaction(|tx| async move { let project_id = ProjectId::from_proto(update.project_id); let server = update .server @@ -1814,7 +1770,7 @@ impl Database { // Ensure the update comes from the host. let project = project::Entity::find_by_id(project_id) - .one(&tx) + .one(&*tx) .await? .ok_or_else(|| anyhow!("no such project"))?; if project.host_connection_id != connection_id.0 as i32 { @@ -1836,12 +1792,11 @@ impl Database { .update_column(language_server::Column::Name) .to_owned(), ) - .exec(&tx) + .exec(&*tx) .await?; let connection_ids = self.project_guest_connection_ids(project_id, &tx).await?; - self.commit_room_transaction(project.room_id, tx, connection_ids) - .await + Ok((project.room_id, connection_ids)) }) .await } @@ -1851,15 +1806,15 @@ impl Database { project_id: ProjectId, connection_id: ConnectionId, ) -> Result> { - self.transact(|tx| async move { + self.room_transaction(|tx| async move { let participant = room_participant::Entity::find() .filter(room_participant::Column::AnsweringConnectionId.eq(connection_id.0)) - .one(&tx) + .one(&*tx) .await? .ok_or_else(|| anyhow!("must join a room first"))?; let project = project::Entity::find_by_id(project_id) - .one(&tx) + .one(&*tx) .await? .ok_or_else(|| anyhow!("no such project"))?; if project.room_id != participant.room_id { @@ -1868,7 +1823,7 @@ impl Database { let mut collaborators = project .find_related(project_collaborator::Entity) - .all(&tx) + .all(&*tx) .await?; let replica_ids = collaborators .iter() @@ -1887,11 +1842,11 @@ impl Database { is_host: ActiveValue::set(false), ..Default::default() } - .insert(&tx) + .insert(&*tx) .await?; collaborators.push(new_collaborator); - let db_worktrees = project.find_related(worktree::Entity).all(&tx).await?; + let db_worktrees = project.find_related(worktree::Entity).all(&*tx).await?; let mut worktrees = db_worktrees .into_iter() .map(|db_worktree| { @@ -1915,7 +1870,7 @@ impl Database { { let mut db_entries = worktree_entry::Entity::find() .filter(worktree_entry::Column::ProjectId.eq(project_id)) - .stream(&tx) + .stream(&*tx) .await?; while let Some(db_entry) = db_entries.next().await { let db_entry = db_entry?; @@ -1940,7 +1895,7 @@ impl Database { { let mut db_summaries = worktree_diagnostic_summary::Entity::find() .filter(worktree_diagnostic_summary::Column::ProjectId.eq(project_id)) - .stream(&tx) + .stream(&*tx) .await?; while let Some(db_summary) = db_summaries.next().await { let db_summary = db_summary?; @@ -1960,28 +1915,22 @@ impl Database { // Populate language servers. let language_servers = project .find_related(language_server::Entity) - .all(&tx) + .all(&*tx) .await?; - self.commit_room_transaction( - project.room_id, - tx, - ( - Project { - collaborators, - worktrees, - language_servers: language_servers - .into_iter() - .map(|language_server| proto::LanguageServer { - id: language_server.id as u64, - name: language_server.name, - }) - .collect(), - }, - replica_id as ReplicaId, - ), - ) - .await + let room_id = project.room_id; + let project = Project { + collaborators, + worktrees, + language_servers: language_servers + .into_iter() + .map(|language_server| proto::LanguageServer { + id: language_server.id as u64, + name: language_server.name, + }) + .collect(), + }; + Ok((room_id, (project, replica_id as ReplicaId))) }) .await } @@ -1991,43 +1940,39 @@ impl Database { project_id: ProjectId, connection_id: ConnectionId, ) -> Result> { - self.transact(|tx| async move { + self.room_transaction(|tx| async move { let result = project_collaborator::Entity::delete_many() .filter( project_collaborator::Column::ProjectId .eq(project_id) .and(project_collaborator::Column::ConnectionId.eq(connection_id.0)), ) - .exec(&tx) + .exec(&*tx) .await?; if result.rows_affected == 0 { Err(anyhow!("not a collaborator on this project"))?; } let project = project::Entity::find_by_id(project_id) - .one(&tx) + .one(&*tx) .await? .ok_or_else(|| anyhow!("no such project"))?; let collaborators = project .find_related(project_collaborator::Entity) - .all(&tx) + .all(&*tx) .await?; let connection_ids = collaborators .into_iter() .map(|collaborator| ConnectionId(collaborator.connection_id as u32)) .collect(); - self.commit_room_transaction( - project.room_id, - tx, - LeftProject { - id: project_id, - host_user_id: project.host_user_id, - host_connection_id: ConnectionId(project.host_connection_id as u32), - connection_ids, - }, - ) - .await + let left_project = LeftProject { + id: project_id, + host_user_id: project.host_user_id, + host_connection_id: ConnectionId(project.host_connection_id as u32), + connection_ids, + }; + Ok((project.room_id, left_project)) }) .await } @@ -2037,10 +1982,10 @@ impl Database { project_id: ProjectId, connection_id: ConnectionId, ) -> Result> { - self.transact(|tx| async move { + self.transaction(|tx| async move { let collaborators = project_collaborator::Entity::find() .filter(project_collaborator::Column::ProjectId.eq(project_id)) - .all(&tx) + .all(&*tx) .await?; if collaborators @@ -2060,7 +2005,7 @@ impl Database { project_id: ProjectId, connection_id: ConnectionId, ) -> Result> { - self.transact(|tx| async move { + self.transaction(|tx| async move { #[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)] enum QueryAs { ConnectionId, @@ -2074,7 +2019,7 @@ impl Database { ) .filter(project_collaborator::Column::ProjectId.eq(project_id)) .into_values::() - .stream(&tx) + .stream(&*tx) .await?; let mut connection_ids = HashSet::default(); @@ -2131,7 +2076,7 @@ impl Database { access_token_hash: &str, max_access_token_count: usize, ) -> Result<()> { - self.transact(|tx| async { + self.transaction(|tx| async { let tx = tx; access_token::ActiveModel { @@ -2139,7 +2084,7 @@ impl Database { hash: ActiveValue::set(access_token_hash.into()), ..Default::default() } - .insert(&tx) + .insert(&*tx) .await?; access_token::Entity::delete_many() @@ -2155,9 +2100,8 @@ impl Database { .to_owned(), ), ) - .exec(&tx) + .exec(&*tx) .await?; - tx.commit().await?; Ok(()) }) .await @@ -2169,22 +2113,22 @@ impl Database { Hash, } - self.transact(|tx| async move { + self.transaction(|tx| async move { Ok(access_token::Entity::find() .select_only() .column(access_token::Column::Hash) .filter(access_token::Column::UserId.eq(user_id)) .order_by_desc(access_token::Column::Id) .into_values::<_, QueryAs>() - .all(&tx) + .all(&*tx) .await?) }) .await } - async fn transact(&self, f: F) -> Result + async fn transaction(&self, f: F) -> Result where - F: Send + Fn(DatabaseTransaction) -> Fut, + F: Send + Fn(TransactionHandle) -> Fut, Fut: Send + Future>, { let body = async { @@ -2200,22 +2144,32 @@ impl Database { .await?; } - match f(tx).await { - Ok(result) => return Ok(result), - Err(error) => match error { - Error::Database( - DbErr::Exec(sea_orm::RuntimeErr::SqlxError(error)) - | DbErr::Query(sea_orm::RuntimeErr::SqlxError(error)), - ) if error - .as_database_error() - .and_then(|error| error.code()) - .as_deref() - == Some("40001") => - { - // Retry (don't break the loop) + let mut tx = Arc::new(Some(tx)); + let result = f(TransactionHandle(tx.clone())).await; + let tx = Arc::get_mut(&mut tx).unwrap().take().unwrap(); + + match result { + Ok(result) => { + tx.commit().await?; + return Ok(result); + } + Err(error) => { + tx.rollback().await?; + match error { + Error::Database( + DbErr::Exec(sea_orm::RuntimeErr::SqlxError(error)) + | DbErr::Query(sea_orm::RuntimeErr::SqlxError(error)), + ) if error + .as_database_error() + .and_then(|error| error.code()) + .as_deref() + == Some("40001") => + { + // Retry (don't break the loop) + } + error @ _ => return Err(error), } - error @ _ => return Err(error), - }, + } } } }; @@ -2234,6 +2188,85 @@ impl Database { body.await } } + + async fn room_transaction(&self, f: F) -> Result> + where + F: Send + Fn(TransactionHandle) -> Fut, + Fut: Send + Future>, + { + let body = async { + loop { + let tx = self.pool.begin().await?; + + // In Postgres, serializable transactions are opt-in + if let DatabaseBackend::Postgres = self.pool.get_database_backend() { + tx.execute(Statement::from_string( + DatabaseBackend::Postgres, + "SET TRANSACTION ISOLATION LEVEL SERIALIZABLE;".into(), + )) + .await?; + } + + let mut tx = Arc::new(Some(tx)); + let result = f(TransactionHandle(tx.clone())).await; + let tx = Arc::get_mut(&mut tx).unwrap().take().unwrap(); + + match result { + Ok((room_id, data)) => { + let lock = self.rooms.entry(room_id).or_default().clone(); + let _guard = lock.lock_owned().await; + tx.commit().await?; + return Ok(RoomGuard { + data, + _guard, + _not_send: PhantomData, + }); + } + Err(error) => { + tx.rollback().await?; + match error { + Error::Database( + DbErr::Exec(sea_orm::RuntimeErr::SqlxError(error)) + | DbErr::Query(sea_orm::RuntimeErr::SqlxError(error)), + ) if error + .as_database_error() + .and_then(|error| error.code()) + .as_deref() + == Some("40001") => + { + // Retry (don't break the loop) + } + error @ _ => return Err(error), + } + } + } + } + }; + + #[cfg(test)] + { + if let Some(background) = self.background.as_ref() { + background.simulate_random_delay().await; + } + + self.runtime.as_ref().unwrap().block_on(body) + } + + #[cfg(not(test))] + { + body.await + } + } +} + +struct TransactionHandle(Arc>); + +impl Deref for TransactionHandle { + type Target = DatabaseTransaction; + + fn deref(&self) -> &Self::Target { + self.0.as_ref().as_ref().unwrap() + } } pub struct RoomGuard { diff --git a/crates/collab/src/rpc.rs b/crates/collab/src/rpc.rs index 9d3917a417..7f404feffe 100644 --- a/crates/collab/src/rpc.rs +++ b/crates/collab/src/rpc.rs @@ -1854,9 +1854,7 @@ async fn leave_room_for_session(session: &Session) -> Result<()> { let live_kit_room; let delete_live_kit_room; { - let Some(mut left_room) = session.db().await.leave_room(session.connection_id).await? else { - return Err(anyhow!("no room to leave"))?; - }; + let mut left_room = session.db().await.leave_room(session.connection_id).await?; contacts_to_update.insert(session.user_id); for project in left_room.left_projects.values() {