mirror of
https://github.com/zed-industries/zed.git
synced 2025-01-11 21:13:02 +00:00
Return project collaborators and connection IDs in a RoomGuard
This commit is contained in:
parent
be3fb1e985
commit
5443d9cffe
2 changed files with 57 additions and 44 deletions
|
@ -1981,8 +1981,12 @@ impl Database {
|
|||
&self,
|
||||
project_id: ProjectId,
|
||||
connection_id: ConnectionId,
|
||||
) -> Result<Vec<project_collaborator::Model>> {
|
||||
self.transaction(|tx| async move {
|
||||
) -> Result<RoomGuard<Vec<project_collaborator::Model>>> {
|
||||
self.room_transaction(|tx| async move {
|
||||
let project = project::Entity::find_by_id(project_id)
|
||||
.one(&*tx)
|
||||
.await?
|
||||
.ok_or_else(|| anyhow!("no such project"))?;
|
||||
let collaborators = project_collaborator::Entity::find()
|
||||
.filter(project_collaborator::Column::ProjectId.eq(project_id))
|
||||
.all(&*tx)
|
||||
|
@ -1992,7 +1996,7 @@ impl Database {
|
|||
.iter()
|
||||
.any(|collaborator| collaborator.connection_id == connection_id.0 as i32)
|
||||
{
|
||||
Ok(collaborators)
|
||||
Ok((project.room_id, collaborators))
|
||||
} else {
|
||||
Err(anyhow!("no such project"))?
|
||||
}
|
||||
|
@ -2004,13 +2008,17 @@ impl Database {
|
|||
&self,
|
||||
project_id: ProjectId,
|
||||
connection_id: ConnectionId,
|
||||
) -> Result<HashSet<ConnectionId>> {
|
||||
self.transaction(|tx| async move {
|
||||
) -> Result<RoomGuard<HashSet<ConnectionId>>> {
|
||||
self.room_transaction(|tx| async move {
|
||||
#[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)]
|
||||
enum QueryAs {
|
||||
ConnectionId,
|
||||
}
|
||||
|
||||
let project = project::Entity::find_by_id(project_id)
|
||||
.one(&*tx)
|
||||
.await?
|
||||
.ok_or_else(|| anyhow!("no such project"))?;
|
||||
let mut db_connection_ids = project_collaborator::Entity::find()
|
||||
.select_only()
|
||||
.column_as(
|
||||
|
@ -2028,7 +2036,7 @@ impl Database {
|
|||
}
|
||||
|
||||
if connection_ids.contains(&connection_id) {
|
||||
Ok(connection_ids)
|
||||
Ok((project.room_id, connection_ids))
|
||||
} else {
|
||||
Err(anyhow!("no such project"))?
|
||||
}
|
||||
|
|
|
@ -1245,7 +1245,7 @@ async fn update_language_server(
|
|||
.await?;
|
||||
broadcast(
|
||||
session.connection_id,
|
||||
project_connection_ids,
|
||||
project_connection_ids.iter().copied(),
|
||||
|connection_id| {
|
||||
session
|
||||
.peer
|
||||
|
@ -1264,23 +1264,24 @@ where
|
|||
T: EntityMessage + RequestMessage,
|
||||
{
|
||||
let project_id = ProjectId::from_proto(request.remote_entity_id());
|
||||
let collaborators = session
|
||||
.db()
|
||||
.await
|
||||
.project_collaborators(project_id, session.connection_id)
|
||||
.await?;
|
||||
let host = collaborators
|
||||
.iter()
|
||||
.find(|collaborator| collaborator.is_host)
|
||||
.ok_or_else(|| anyhow!("host not found"))?;
|
||||
let host_connection_id = {
|
||||
let collaborators = session
|
||||
.db()
|
||||
.await
|
||||
.project_collaborators(project_id, session.connection_id)
|
||||
.await?;
|
||||
ConnectionId(
|
||||
collaborators
|
||||
.iter()
|
||||
.find(|collaborator| collaborator.is_host)
|
||||
.ok_or_else(|| anyhow!("host not found"))?
|
||||
.connection_id as u32,
|
||||
)
|
||||
};
|
||||
|
||||
let payload = session
|
||||
.peer
|
||||
.forward_request(
|
||||
session.connection_id,
|
||||
ConnectionId(host.connection_id as u32),
|
||||
request,
|
||||
)
|
||||
.forward_request(session.connection_id, host_connection_id, request)
|
||||
.await?;
|
||||
|
||||
response.send(payload)?;
|
||||
|
@ -1293,16 +1294,18 @@ async fn save_buffer(
|
|||
session: Session,
|
||||
) -> Result<()> {
|
||||
let project_id = ProjectId::from_proto(request.project_id);
|
||||
let collaborators = session
|
||||
.db()
|
||||
.await
|
||||
.project_collaborators(project_id, session.connection_id)
|
||||
.await?;
|
||||
let host = collaborators
|
||||
.into_iter()
|
||||
.find(|collaborator| collaborator.is_host)
|
||||
.ok_or_else(|| anyhow!("host not found"))?;
|
||||
let host_connection_id = ConnectionId(host.connection_id as u32);
|
||||
let host_connection_id = {
|
||||
let collaborators = session
|
||||
.db()
|
||||
.await
|
||||
.project_collaborators(project_id, session.connection_id)
|
||||
.await?;
|
||||
let host = collaborators
|
||||
.iter()
|
||||
.find(|collaborator| collaborator.is_host)
|
||||
.ok_or_else(|| anyhow!("host not found"))?;
|
||||
ConnectionId(host.connection_id as u32)
|
||||
};
|
||||
let response_payload = session
|
||||
.peer
|
||||
.forward_request(session.connection_id, host_connection_id, request.clone())
|
||||
|
@ -1316,7 +1319,7 @@ async fn save_buffer(
|
|||
collaborators
|
||||
.retain(|collaborator| collaborator.connection_id != session.connection_id.0 as i32);
|
||||
let project_connection_ids = collaborators
|
||||
.into_iter()
|
||||
.iter()
|
||||
.map(|collaborator| ConnectionId(collaborator.connection_id as u32));
|
||||
broadcast(host_connection_id, project_connection_ids, |conn_id| {
|
||||
session
|
||||
|
@ -1353,7 +1356,7 @@ async fn update_buffer(
|
|||
|
||||
broadcast(
|
||||
session.connection_id,
|
||||
project_connection_ids,
|
||||
project_connection_ids.iter().copied(),
|
||||
|connection_id| {
|
||||
session
|
||||
.peer
|
||||
|
@ -1374,7 +1377,7 @@ async fn update_buffer_file(request: proto::UpdateBufferFile, session: Session)
|
|||
|
||||
broadcast(
|
||||
session.connection_id,
|
||||
project_connection_ids,
|
||||
project_connection_ids.iter().copied(),
|
||||
|connection_id| {
|
||||
session
|
||||
.peer
|
||||
|
@ -1393,7 +1396,7 @@ async fn buffer_reloaded(request: proto::BufferReloaded, session: Session) -> Re
|
|||
.await?;
|
||||
broadcast(
|
||||
session.connection_id,
|
||||
project_connection_ids,
|
||||
project_connection_ids.iter().copied(),
|
||||
|connection_id| {
|
||||
session
|
||||
.peer
|
||||
|
@ -1412,7 +1415,7 @@ async fn buffer_saved(request: proto::BufferSaved, session: Session) -> Result<(
|
|||
.await?;
|
||||
broadcast(
|
||||
session.connection_id,
|
||||
project_connection_ids,
|
||||
project_connection_ids.iter().copied(),
|
||||
|connection_id| {
|
||||
session
|
||||
.peer
|
||||
|
@ -1430,14 +1433,16 @@ async fn follow(
|
|||
let project_id = ProjectId::from_proto(request.project_id);
|
||||
let leader_id = ConnectionId(request.leader_id);
|
||||
let follower_id = session.connection_id;
|
||||
let project_connection_ids = session
|
||||
.db()
|
||||
.await
|
||||
.project_connection_ids(project_id, session.connection_id)
|
||||
.await?;
|
||||
{
|
||||
let project_connection_ids = session
|
||||
.db()
|
||||
.await
|
||||
.project_connection_ids(project_id, session.connection_id)
|
||||
.await?;
|
||||
|
||||
if !project_connection_ids.contains(&leader_id) {
|
||||
Err(anyhow!("no such peer"))?;
|
||||
if !project_connection_ids.contains(&leader_id) {
|
||||
Err(anyhow!("no such peer"))?;
|
||||
}
|
||||
}
|
||||
|
||||
let mut response_payload = session
|
||||
|
@ -1691,7 +1696,7 @@ async fn update_diff_base(request: proto::UpdateDiffBase, session: Session) -> R
|
|||
.await?;
|
||||
broadcast(
|
||||
session.connection_id,
|
||||
project_connection_ids,
|
||||
project_connection_ids.iter().copied(),
|
||||
|connection_id| {
|
||||
session
|
||||
.peer
|
||||
|
|
Loading…
Reference in a new issue