diff --git a/crates/collab/src/db.rs b/crates/collab/src/db.rs index 3066260bc4..bc074e30df 100644 --- a/crates/collab/src/db.rs +++ b/crates/collab/src/db.rs @@ -2131,47 +2131,30 @@ impl Database { F: Send + Fn(TransactionHandle) -> Fut, Fut: Send + Future>, { - let body = async { - loop { - let (tx, result) = self.with_transaction(&f).await?; - match result { - Ok(result) => { - tx.commit().await?; - return Ok(result); - } - Err(error) => { - tx.rollback().await?; - match error { - Error::Database( - DbErr::Exec(sea_orm::RuntimeErr::SqlxError(error)) - | DbErr::Query(sea_orm::RuntimeErr::SqlxError(error)), - ) if error - .as_database_error() - .and_then(|error| error.code()) - .as_deref() - == Some("40001") => - { + loop { + let (tx, result) = self.run(self.with_transaction(&f)).await?; + match result { + Ok(result) => { + match self.run(async move { Ok(tx.commit().await?) }).await { + Ok(()) => return Ok(result), + Err(error) => { + if is_serialization_error(&error) { // Retry (don't break the loop) + } else { + return Err(error); } - error @ _ => return Err(error), } } } + Err(error) => { + self.run(tx.rollback()).await?; + if is_serialization_error(&error) { + // Retry (don't break the loop) + } else { + return Err(error); + } + } } - }; - - #[cfg(test)] - { - if let Some(background) = self.background.as_ref() { - background.simulate_random_delay().await; - } - - self.runtime.as_ref().unwrap().block_on(body) - } - - #[cfg(not(test))] - { - body.await } } @@ -2180,53 +2163,38 @@ impl Database { F: Send + Fn(TransactionHandle) -> Fut, Fut: Send + Future>, { - let body = async { - loop { - let (tx, result) = self.with_transaction(&f).await?; - match result { - Ok((room_id, data)) => { - let lock = self.rooms.entry(room_id).or_default().clone(); - let _guard = lock.lock_owned().await; - tx.commit().await?; - return Ok(RoomGuard { - data, - _guard, - _not_send: PhantomData, - }); - } - Err(error) => { - tx.rollback().await?; - match error { - Error::Database( - DbErr::Exec(sea_orm::RuntimeErr::SqlxError(error)) - | DbErr::Query(sea_orm::RuntimeErr::SqlxError(error)), - ) if error - .as_database_error() - .and_then(|error| error.code()) - .as_deref() - == Some("40001") => - { + loop { + let (tx, result) = self.run(self.with_transaction(&f)).await?; + match result { + Ok((room_id, data)) => { + let lock = self.rooms.entry(room_id).or_default().clone(); + let _guard = lock.lock_owned().await; + match self.run(async move { Ok(tx.commit().await?) }).await { + Ok(()) => { + return Ok(RoomGuard { + data, + _guard, + _not_send: PhantomData, + }); + } + Err(error) => { + if is_serialization_error(&error) { // Retry (don't break the loop) + } else { + return Err(error); } - error @ _ => return Err(error), } } } + Err(error) => { + self.run(tx.rollback()).await?; + if is_serialization_error(&error) { + // Retry (don't break the loop) + } else { + return Err(error); + } + } } - }; - - #[cfg(test)] - { - if let Some(background) = self.background.as_ref() { - background.simulate_random_delay().await; - } - - self.runtime.as_ref().unwrap().block_on(body) - } - - #[cfg(not(test))] - { - body.await } } @@ -2254,6 +2222,49 @@ impl Database { Ok((tx, result)) } + + async fn run(&self, future: F) -> T + where + F: Future, + { + #[cfg(test)] + { + if let Some(background) = self.background.as_ref() { + background.simulate_random_delay().await; + } + + let result = self.runtime.as_ref().unwrap().block_on(future); + + if let Some(background) = self.background.as_ref() { + background.simulate_random_delay().await; + } + + result + } + + #[cfg(not(test))] + { + future.await + } + } +} + +fn is_serialization_error(error: &Error) -> bool { + const SERIALIZATION_FAILURE_CODE: &'static str = "40001"; + match error { + Error::Database( + DbErr::Exec(sea_orm::RuntimeErr::SqlxError(error)) + | DbErr::Query(sea_orm::RuntimeErr::SqlxError(error)), + ) if error + .as_database_error() + .and_then(|error| error.code()) + .as_deref() + == Some(SERIALIZATION_FAILURE_CODE) => + { + true + } + _ => false, + } } struct TransactionHandle(Arc>);