Return project collaborators and connection IDs in a RoomGuard

This commit is contained in:
Antonio Scandurra 2022-12-05 18:37:01 +01:00
parent be3fb1e985
commit 5443d9cffe
2 changed files with 57 additions and 44 deletions

View file

@ -1981,8 +1981,12 @@ impl Database {
&self,
project_id: ProjectId,
connection_id: ConnectionId,
) -> Result<Vec<project_collaborator::Model>> {
self.transaction(|tx| async move {
) -> Result<RoomGuard<Vec<project_collaborator::Model>>> {
self.room_transaction(|tx| async move {
let project = project::Entity::find_by_id(project_id)
.one(&*tx)
.await?
.ok_or_else(|| anyhow!("no such project"))?;
let collaborators = project_collaborator::Entity::find()
.filter(project_collaborator::Column::ProjectId.eq(project_id))
.all(&*tx)
@ -1992,7 +1996,7 @@ impl Database {
.iter()
.any(|collaborator| collaborator.connection_id == connection_id.0 as i32)
{
Ok(collaborators)
Ok((project.room_id, collaborators))
} else {
Err(anyhow!("no such project"))?
}
@ -2004,13 +2008,17 @@ impl Database {
&self,
project_id: ProjectId,
connection_id: ConnectionId,
) -> Result<HashSet<ConnectionId>> {
self.transaction(|tx| async move {
) -> Result<RoomGuard<HashSet<ConnectionId>>> {
self.room_transaction(|tx| async move {
#[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)]
enum QueryAs {
ConnectionId,
}
let project = project::Entity::find_by_id(project_id)
.one(&*tx)
.await?
.ok_or_else(|| anyhow!("no such project"))?;
let mut db_connection_ids = project_collaborator::Entity::find()
.select_only()
.column_as(
@ -2028,7 +2036,7 @@ impl Database {
}
if connection_ids.contains(&connection_id) {
Ok(connection_ids)
Ok((project.room_id, connection_ids))
} else {
Err(anyhow!("no such project"))?
}

View file

@ -1245,7 +1245,7 @@ async fn update_language_server(
.await?;
broadcast(
session.connection_id,
project_connection_ids,
project_connection_ids.iter().copied(),
|connection_id| {
session
.peer
@ -1264,23 +1264,24 @@ where
T: EntityMessage + RequestMessage,
{
let project_id = ProjectId::from_proto(request.remote_entity_id());
let collaborators = session
.db()
.await
.project_collaborators(project_id, session.connection_id)
.await?;
let host = collaborators
.iter()
.find(|collaborator| collaborator.is_host)
.ok_or_else(|| anyhow!("host not found"))?;
let host_connection_id = {
let collaborators = session
.db()
.await
.project_collaborators(project_id, session.connection_id)
.await?;
ConnectionId(
collaborators
.iter()
.find(|collaborator| collaborator.is_host)
.ok_or_else(|| anyhow!("host not found"))?
.connection_id as u32,
)
};
let payload = session
.peer
.forward_request(
session.connection_id,
ConnectionId(host.connection_id as u32),
request,
)
.forward_request(session.connection_id, host_connection_id, request)
.await?;
response.send(payload)?;
@ -1293,16 +1294,18 @@ async fn save_buffer(
session: Session,
) -> Result<()> {
let project_id = ProjectId::from_proto(request.project_id);
let collaborators = session
.db()
.await
.project_collaborators(project_id, session.connection_id)
.await?;
let host = collaborators
.into_iter()
.find(|collaborator| collaborator.is_host)
.ok_or_else(|| anyhow!("host not found"))?;
let host_connection_id = ConnectionId(host.connection_id as u32);
let host_connection_id = {
let collaborators = session
.db()
.await
.project_collaborators(project_id, session.connection_id)
.await?;
let host = collaborators
.iter()
.find(|collaborator| collaborator.is_host)
.ok_or_else(|| anyhow!("host not found"))?;
ConnectionId(host.connection_id as u32)
};
let response_payload = session
.peer
.forward_request(session.connection_id, host_connection_id, request.clone())
@ -1316,7 +1319,7 @@ async fn save_buffer(
collaborators
.retain(|collaborator| collaborator.connection_id != session.connection_id.0 as i32);
let project_connection_ids = collaborators
.into_iter()
.iter()
.map(|collaborator| ConnectionId(collaborator.connection_id as u32));
broadcast(host_connection_id, project_connection_ids, |conn_id| {
session
@ -1353,7 +1356,7 @@ async fn update_buffer(
broadcast(
session.connection_id,
project_connection_ids,
project_connection_ids.iter().copied(),
|connection_id| {
session
.peer
@ -1374,7 +1377,7 @@ async fn update_buffer_file(request: proto::UpdateBufferFile, session: Session)
broadcast(
session.connection_id,
project_connection_ids,
project_connection_ids.iter().copied(),
|connection_id| {
session
.peer
@ -1393,7 +1396,7 @@ async fn buffer_reloaded(request: proto::BufferReloaded, session: Session) -> Re
.await?;
broadcast(
session.connection_id,
project_connection_ids,
project_connection_ids.iter().copied(),
|connection_id| {
session
.peer
@ -1412,7 +1415,7 @@ async fn buffer_saved(request: proto::BufferSaved, session: Session) -> Result<(
.await?;
broadcast(
session.connection_id,
project_connection_ids,
project_connection_ids.iter().copied(),
|connection_id| {
session
.peer
@ -1430,14 +1433,16 @@ async fn follow(
let project_id = ProjectId::from_proto(request.project_id);
let leader_id = ConnectionId(request.leader_id);
let follower_id = session.connection_id;
let project_connection_ids = session
.db()
.await
.project_connection_ids(project_id, session.connection_id)
.await?;
{
let project_connection_ids = session
.db()
.await
.project_connection_ids(project_id, session.connection_id)
.await?;
if !project_connection_ids.contains(&leader_id) {
Err(anyhow!("no such peer"))?;
if !project_connection_ids.contains(&leader_id) {
Err(anyhow!("no such peer"))?;
}
}
let mut response_payload = session
@ -1691,7 +1696,7 @@ async fn update_diff_base(request: proto::UpdateDiffBase, session: Session) -> R
.await?;
broadcast(
session.connection_id,
project_connection_ids,
project_connection_ids.iter().copied(),
|connection_id| {
session
.peer