diff --git a/server/src/rpc.rs b/server/src/rpc.rs index 009fc9bd1e..f8eae2c831 100644 --- a/server/src/rpc.rs +++ b/server/src/rpc.rs @@ -5,12 +5,11 @@ use super::{ db::{ChannelId, MessageId, UserId}, AppState, }; -use crate::errors::TideResultExt; use anyhow::anyhow; use async_std::{sync::RwLock, task}; use async_tungstenite::{tungstenite::protocol::Role, WebSocketStream}; use futures::{future::BoxFuture, FutureExt}; -use postage::{broadcast, mpsc, prelude::Sink as _, prelude::Stream as _}; +use postage::{mpsc, prelude::Sink as _, prelude::Stream as _}; use sha1::{Digest as _, Sha1}; use std::{ any::TypeId, @@ -20,7 +19,7 @@ use std::{ sync::Arc, time::Instant, }; -use store::{ReplicaId, Store, Worktree}; +use store::{Store, Worktree}; use surf::StatusCode; use tide::log; use tide::{ @@ -71,6 +70,7 @@ impl Server { .add_handler(Server::share_worktree) .add_handler(Server::unshare_worktree) .add_handler(Server::join_worktree) + .add_handler(Server::leave_worktree) .add_handler(Server::update_worktree) .add_handler(Server::open_buffer) .add_handler(Server::close_buffer) @@ -199,7 +199,7 @@ impl Server { } self.update_collaborators_for_users(removed_connection.collaborator_ids.iter()) - .await; + .await?; Ok(()) } @@ -420,22 +420,24 @@ impl Server { } async fn leave_worktree( - self: &Arc, - worktree_id: u64, - sender_conn_id: ConnectionId, + self: Arc, + request: TypedEnvelope, ) -> tide::Result<()> { + let sender_id = request.sender_id; + let worktree_id = request.payload.worktree_id; + if let Some((connection_ids, collaborator_ids)) = self .store .write() .await - .leave_worktree(sender_conn_id, worktree_id) + .leave_worktree(sender_id, worktree_id) { - broadcast(sender_conn_id, connection_ids, |conn_id| { + broadcast(sender_id, connection_ids, |conn_id| { self.peer.send( conn_id, proto::RemovePeer { worktree_id, - peer_id: sender_conn_id.0, + peer_id: sender_id.0, }, ) }) @@ -1550,19 +1552,23 @@ mod tests { .await; assert_eq!( - server.state().await.channels[&channel_id] + server + .state() + .await + .channel(channel_id) + .unwrap() .connection_ids .len(), 2 ); cx_b.update(|_| drop(channel_b)); server - .condition(|state| state.channels[&channel_id].connection_ids.len() == 1) + .condition(|state| state.channel(channel_id).unwrap().connection_ids.len() == 1) .await; cx_a.update(|_| drop(channel_a)); server - .condition(|state| !state.channels.contains_key(&channel_id)) + .condition(|state| state.channel(channel_id).is_none()) .await; } diff --git a/server/src/rpc/store.rs b/server/src/rpc/store.rs index c7a6c2b166..cbc691a8d7 100644 --- a/server/src/rpc/store.rs +++ b/server/src/rpc/store.rs @@ -1,4 +1,4 @@ -use crate::db::{ChannelId, MessageId, UserId}; +use crate::db::{ChannelId, UserId}; use crate::errors::TideResultExt; use anyhow::anyhow; use std::collections::{hash_map, HashMap, HashSet}; @@ -27,15 +27,15 @@ pub struct Worktree { pub share: Option, } -struct WorktreeShare { +pub struct WorktreeShare { pub guest_connection_ids: HashMap, pub active_replica_ids: HashSet, pub entries: HashMap, } #[derive(Default)] -struct Channel { - connection_ids: HashSet, +pub struct Channel { + pub connection_ids: HashSet, } pub type ReplicaId = u16; @@ -73,7 +73,7 @@ impl Store { return Err(anyhow!("no such connection"))?; }; - for channel_id in connection.channels { + for channel_id in &connection.channels { if let Some(channel) = self.channels.get_mut(&channel_id) { channel.connection_ids.remove(&connection_id); } @@ -89,12 +89,12 @@ impl Store { } let mut result = RemovedConnectionState::default(); - for worktree_id in connection.worktrees { + for worktree_id in connection.worktrees.clone() { if let Ok(worktree) = self.remove_worktree(worktree_id, connection_id) { - result.hosted_worktrees.insert(worktree_id, worktree); result .collaborator_ids .extend(worktree.collaborator_user_ids.iter().copied()); + result.hosted_worktrees.insert(worktree_id, worktree); } else { if let Some(worktree) = self.worktrees.get(&worktree_id) { result @@ -110,6 +110,10 @@ impl Store { Ok(result) } + pub fn channel(&self, id: ChannelId) -> Option<&Channel> { + self.channels.get(&id) + } + pub fn join_channel(&mut self, connection_id: ConnectionId, channel_id: ChannelId) { if let Some(connection) = self.connections.get_mut(&connection_id) { connection.channels.insert(channel_id); @@ -230,7 +234,7 @@ impl Store { connection.worktrees.remove(&worktree_id); } - if let Some(share) = worktree.share { + if let Some(share) = &worktree.share { for connection_id in share.guest_connection_ids.keys() { if let Some(connection) = self.connections.get_mut(connection_id) { connection.worktrees.remove(&worktree_id); @@ -238,7 +242,7 @@ impl Store { } } - for collaborator_user_id in worktree.collaborator_user_ids { + for collaborator_user_id in &worktree.collaborator_user_ids { if let Some(visible_worktrees) = self .visible_worktrees_by_user_id .get_mut(&collaborator_user_id) @@ -289,7 +293,7 @@ impl Store { let connection_ids = worktree.connection_ids(); - if let Some(share) = worktree.share.take() { + if let Some(_) = worktree.share.take() { for connection_id in &connection_ids { if let Some(connection) = self.connections.get_mut(connection_id) { connection.worktrees.remove(&worktree_id);