Use a synchronous mutex for ConnectionPool

This commit is contained in:
Antonio Scandurra 2022-12-13 13:50:51 +01:00
parent a594ba8f8a
commit d4c8fa3090
3 changed files with 20 additions and 20 deletions

View file

@ -6062,7 +6062,6 @@ async fn test_random_collaboration(
let user_connection_ids = server
.connection_pool
.lock()
.await
.user_connection_ids(removed_guest_id)
.collect::<Vec<_>>();
assert_eq!(user_connection_ids.len(), 1);
@ -6083,7 +6082,7 @@ async fn test_random_collaboration(
}
for user_id in &user_ids {
let contacts = server.app_state.db.get_contacts(*user_id).await.unwrap();
let pool = server.connection_pool.lock().await;
let pool = server.connection_pool.lock();
for contact in contacts {
if let db::Contact::Accepted { user_id, .. } = contact {
if pool.is_user_online(user_id) {
@ -6112,7 +6111,6 @@ async fn test_random_collaboration(
let user_connection_ids = server
.connection_pool
.lock()
.await
.user_connection_ids(user_id)
.collect::<Vec<_>>();
assert_eq!(user_connection_ids.len(), 1);

View file

@ -53,7 +53,7 @@ use std::{
},
time::Duration,
};
use tokio::sync::{watch, Mutex, MutexGuard};
use tokio::sync::watch;
use tower::ServiceBuilder;
use tracing::{info_span, instrument, Instrument};
@ -90,14 +90,14 @@ impl<R: RequestMessage> Response<R> {
struct Session {
user_id: UserId,
connection_id: ConnectionId,
db: Arc<Mutex<DbHandle>>,
db: Arc<tokio::sync::Mutex<DbHandle>>,
peer: Arc<Peer>,
connection_pool: Arc<Mutex<ConnectionPool>>,
connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
}
impl Session {
async fn db(&self) -> MutexGuard<DbHandle> {
async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
#[cfg(test)]
tokio::task::yield_now().await;
let guard = self.db.lock().await;
@ -109,9 +109,7 @@ impl Session {
async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
#[cfg(test)]
tokio::task::yield_now().await;
let guard = self.connection_pool.lock().await;
#[cfg(test)]
tokio::task::yield_now().await;
let guard = self.connection_pool.lock();
ConnectionPoolGuard {
guard,
_not_send: PhantomData,
@ -140,7 +138,7 @@ impl Deref for DbHandle {
pub struct Server {
peer: Arc<Peer>,
pub(crate) connection_pool: Arc<Mutex<ConnectionPool>>,
pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
app_state: Arc<AppState>,
executor: Executor,
handlers: HashMap<TypeId, MessageHandler>,
@ -148,7 +146,7 @@ pub struct Server {
}
pub(crate) struct ConnectionPoolGuard<'a> {
guard: MutexGuard<'a, ConnectionPool>,
guard: parking_lot::MutexGuard<'a, ConnectionPool>,
_not_send: PhantomData<Rc<()>>,
}
@ -268,7 +266,7 @@ impl Server {
}
{
let pool = pool.lock().await;
let pool = pool.lock();
for canceled_user_id in canceled_calls_to_user_ids {
for connection_id in pool.user_connection_ids(canceled_user_id) {
peer.send(
@ -286,7 +284,7 @@ impl Server {
let busy = db.is_user_busy(user_id).await.trace_err();
let contacts = db.get_contacts(user_id).await.trace_err();
if let Some((busy, contacts)) = busy.zip(contacts) {
let pool = pool.lock().await;
let pool = pool.lock();
let updated_contact = contact_for_user(user_id, false, busy, &pool);
for contact in contacts {
if let db::Contact::Accepted {
@ -456,7 +454,7 @@ impl Server {
).await?;
{
let mut pool = this.connection_pool.lock().await;
let mut pool = this.connection_pool.lock();
pool.add_connection(connection_id, user_id, user.admin);
this.peer.send(connection_id, build_initial_contacts_update(contacts, &pool))?;
@ -475,7 +473,7 @@ impl Server {
let session = Session {
user_id,
connection_id,
db: Arc::new(Mutex::new(DbHandle(this.app_state.db.clone()))),
db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
peer: this.peer.clone(),
connection_pool: this.connection_pool.clone(),
live_kit_client: this.app_state.live_kit_client.clone()
@ -550,7 +548,7 @@ impl Server {
) -> Result<()> {
if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
if let Some(code) = &user.invite_code {
let pool = self.connection_pool.lock().await;
let pool = self.connection_pool.lock();
let invitee_contact = contact_for_user(invitee_id, true, false, &pool);
for connection_id in pool.user_connection_ids(inviter_id) {
self.peer.send(
@ -576,7 +574,7 @@ impl Server {
pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
if let Some(invite_code) = &user.invite_code {
let pool = self.connection_pool.lock().await;
let pool = self.connection_pool.lock();
for connection_id in pool.user_connection_ids(user_id) {
self.peer.send(
connection_id,
@ -597,7 +595,7 @@ impl Server {
pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
ServerSnapshot {
connection_pool: ConnectionPoolGuard {
guard: self.connection_pool.lock().await,
guard: self.connection_pool.lock(),
_not_send: PhantomData,
},
peer: &self.peer,
@ -718,7 +716,6 @@ pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result
let connections = server
.connection_pool
.lock()
.await
.connections()
.filter(|connection| !connection.admin)
.count();

View file

@ -23,6 +23,11 @@ pub struct Connection {
}
impl ConnectionPool {
pub fn reset(&mut self) {
self.connections.clear();
self.connected_users.clear();
}
#[instrument(skip(self))]
pub fn add_connection(&mut self, connection_id: ConnectionId, user_id: UserId, admin: bool) {
self.connections