diff --git a/server/src/bin/seed.rs b/server/src/bin/seed.rs index d2427d495c..4d3fb978db 100644 --- a/server/src/bin/seed.rs +++ b/server/src/bin/seed.rs @@ -27,8 +27,12 @@ async fn main() { let zed_users = ["nathansobo", "maxbrunsfeld", "as-cii", "iamnbutler"]; let mut zed_user_ids = Vec::::new(); for zed_user in zed_users { - if let Some(user_id) = db.get_user(zed_user).await.expect("failed to fetch user") { - zed_user_ids.push(user_id); + if let Some(user) = db + .get_user_by_github_login(zed_user) + .await + .expect("failed to fetch user") + { + zed_user_ids.push(user.id); } else { zed_user_ids.push( db.create_user(zed_user, true) diff --git a/server/src/db.rs b/server/src/db.rs index 14ad85b68a..a826220b11 100644 --- a/server/src/db.rs +++ b/server/src/db.rs @@ -84,27 +84,12 @@ impl Db { // users - #[allow(unused)] // Help rust-analyzer - #[cfg(any(test, feature = "seed-support"))] - pub async fn get_user(&self, github_login: &str) -> Result> { - test_support!(self, { - let query = " - SELECT id - FROM users - WHERE github_login = $1 - "; - sqlx::query_scalar(query) - .bind(github_login) - .fetch_optional(&self.pool) - .await - }) - } - pub async fn create_user(&self, github_login: &str, admin: bool) -> Result { test_support!(self, { let query = " INSERT INTO users (github_login, admin) VALUES ($1, $2) + ON CONFLICT (github_login) DO UPDATE SET github_login = excluded.github_login RETURNING id "; sqlx::query_scalar(query) diff --git a/server/src/rpc.rs b/server/src/rpc.rs index bbc0b090ca..950b6749dc 100644 --- a/server/src/rpc.rs +++ b/server/src/rpc.rs @@ -48,8 +48,9 @@ pub struct Server { #[derive(Default)] struct ServerState { connections: HashMap, + connections_by_user_id: HashMap>, pub worktrees: HashMap, - visible_worktrees_by_github_login: HashMap>, + visible_worktrees_by_user_id: HashMap>, channels: HashMap, next_worktree_id: u64, } @@ -62,7 +63,7 @@ struct ConnectionState { struct Worktree { host_connection_id: ConnectionId, - collaborator_github_logins: Vec, + collaborator_user_ids: Vec, root_name: String, share: Option, } @@ -113,7 +114,8 @@ impl Server { .add_handler(Server::join_channel) .add_handler(Server::leave_channel) .add_handler(Server::send_channel_message) - .add_handler(Server::get_channel_messages); + .add_handler(Server::get_channel_messages) + .add_handler(Server::get_collaborators); Arc::new(server) } @@ -215,7 +217,8 @@ impl Server { // Add a new connection associated with a given user. async fn add_connection(&self, connection_id: ConnectionId, user_id: UserId) { - self.state.write().await.connections.insert( + let mut state = self.state.write().await; + state.connections.insert( connection_id, ConnectionState { user_id, @@ -223,6 +226,11 @@ impl Server { channels: Default::default(), }, ); + state + .connections_by_user_id + .entry(user_id) + .or_default() + .insert(connection_id); } // Remove the given connection and its association with any worktrees. @@ -249,6 +257,15 @@ impl Server { } } } + + let user_connections = state + .connections_by_user_id + .get_mut(&connection.user_id) + .unwrap(); + user_connections.remove(&connection_id); + if user_connections.is_empty() { + state.connections_by_user_id.remove(&connection.user_id); + } } worktree_ids } @@ -264,10 +281,24 @@ impl Server { ) -> tide::Result<()> { let receipt = request.receipt(); + let mut collaborator_user_ids = Vec::new(); + for github_login in request.payload.collaborator_logins { + match self.app_state.db.create_user(&github_login, false).await { + Ok(user_id) => collaborator_user_ids.push(user_id), + Err(err) => { + let message = err.to_string(); + self.peer + .respond_with_error(receipt, proto::Error { message }) + .await?; + return Ok(()); + } + } + } + let mut state = self.state.write().await; let worktree_id = state.add_worktree(Worktree { host_connection_id: request.sender_id, - collaborator_github_logins: request.payload.collaborator_logins, + collaborator_user_ids, root_name: request.payload.root_name, share: None, }); @@ -351,12 +382,16 @@ impl Server { request: TypedEnvelope, ) -> tide::Result<()> { let worktree_id = request.payload.worktree_id; - let user = self.user_for_connection(request.sender_id).await?; + let user_id = self + .state + .read() + .await + .user_id_for_connection(request.sender_id)?; let response; let connection_ids; let mut state = self.state.write().await; - match state.join_worktree(request.sender_id, &user, worktree_id) { + match state.join_worktree(request.sender_id, user_id, worktree_id) { Ok((peer_replica_id, worktree)) => { let share = worktree.share()?; let peer_count = share.guest_connection_ids.len(); @@ -639,6 +674,66 @@ impl Server { Ok(()) } + async fn get_collaborators( + self: Arc, + request: TypedEnvelope, + ) -> tide::Result<()> { + let mut collaborators = HashMap::new(); + { + let state = self.state.read().await; + let user_id = state.user_id_for_connection(request.sender_id)?; + for worktree_id in state + .visible_worktrees_by_user_id + .get(&user_id) + .unwrap_or(&HashSet::new()) + { + let worktree = &state.worktrees[worktree_id]; + + let mut participants = Vec::new(); + for collaborator_user_id in &worktree.collaborator_user_ids { + collaborators + .entry(*collaborator_user_id) + .or_insert_with(|| proto::Collaborator { + user_id: collaborator_user_id.to_proto(), + worktrees: Vec::new(), + is_online: state.is_online(*collaborator_user_id), + }); + + if let Ok(share) = worktree.share() { + let mut conn_ids = state.user_connection_ids(*collaborator_user_id); + if conn_ids.any(|c| share.guest_connection_ids.contains_key(&c)) { + participants.push(collaborator_user_id.to_proto()); + } + } + } + + let host_user_id = state.user_id_for_connection(worktree.host_connection_id)?; + let host = + collaborators + .entry(host_user_id) + .or_insert_with(|| proto::Collaborator { + user_id: host_user_id.to_proto(), + worktrees: Vec::new(), + is_online: true, + }); + host.worktrees.push(proto::CollaboratorWorktree { + is_shared: worktree.share().is_ok(), + participants, + }); + } + } + + self.peer + .respond( + request.receipt(), + proto::GetCollaboratorsResponse { + collaborators: collaborators.into_values().collect(), + }, + ) + .await?; + Ok(()) + } + async fn join_channel( self: Arc, request: TypedEnvelope, @@ -856,24 +951,6 @@ impl Server { Ok(()) } - async fn user_for_connection(&self, connection_id: ConnectionId) -> tide::Result { - let user_id = self - .state - .read() - .await - .connections - .get(&connection_id) - .ok_or_else(|| anyhow!("no such connection"))? - .user_id; - Ok(self - .app_state - .db - .get_users_by_ids(user_id, Some(user_id).into_iter()) - .await? - .pop() - .ok_or_else(|| anyhow!("no such user"))?) - } - async fn broadcast_in_worktree( &self, worktree_id: u64, @@ -945,11 +1022,26 @@ impl ServerState { .user_id) } + fn user_connection_ids<'a>( + &'a self, + user_id: UserId, + ) -> impl 'a + Iterator { + self.connections_by_user_id + .get(&user_id) + .into_iter() + .flatten() + .copied() + } + + fn is_online(&self, user_id: UserId) -> bool { + self.connections_by_user_id.contains_key(&user_id) + } + // Add the given connection as a guest of the given worktree fn join_worktree( &mut self, connection_id: ConnectionId, - user: &User, + user_id: UserId, worktree_id: u64, ) -> tide::Result<(ReplicaId, &Worktree)> { let connection = self @@ -960,10 +1052,7 @@ impl ServerState { .worktrees .get_mut(&worktree_id) .ok_or_else(|| anyhow!("no such worktree"))?; - if !worktree - .collaborator_github_logins - .contains(&user.github_login) - { + if !worktree.collaborator_user_ids.contains(&user_id) { Err(anyhow!("no such worktree"))?; } @@ -1032,9 +1121,9 @@ impl ServerState { fn add_worktree(&mut self, worktree: Worktree) -> u64 { let worktree_id = self.next_worktree_id; - for collaborator_login in &worktree.collaborator_github_logins { - self.visible_worktrees_by_github_login - .entry(collaborator_login.clone()) + for collaborator_user_id in &worktree.collaborator_user_ids { + self.visible_worktrees_by_user_id + .entry(*collaborator_user_id) .or_default() .insert(worktree_id); } @@ -1055,10 +1144,10 @@ impl ServerState { } } } - for collaborator_login in worktree.collaborator_github_logins { + for collaborator_user_id in worktree.collaborator_user_ids { if let Some(visible_worktrees) = self - .visible_worktrees_by_github_login - .get_mut(&collaborator_login) + .visible_worktrees_by_user_id + .get_mut(&collaborator_user_id) { visible_worktrees.remove(&worktree_id); } diff --git a/zrpc/proto/zed.proto b/zrpc/proto/zed.proto index b5fe1604ae..30a282f6b5 100644 --- a/zrpc/proto/zed.proto +++ b/zrpc/proto/zed.proto @@ -38,6 +38,8 @@ message Envelope { OpenWorktree open_worktree = 33; OpenWorktreeResponse open_worktree_response = 34; UnshareWorktree unshare_worktree = 35; + GetCollaborators get_collaborators = 36; + GetCollaboratorsResponse get_collaborators_response = 37; } } @@ -184,6 +186,12 @@ message GetChannelMessagesResponse { bool done = 2; } +message GetCollaborators {} + +message GetCollaboratorsResponse { + repeated Collaborator collaborators = 1; +} + // Entities message Peer { @@ -326,3 +334,14 @@ message ChannelMessage { uint64 sender_id = 4; Nonce nonce = 5; } + +message Collaborator { + uint64 user_id = 1; + repeated CollaboratorWorktree worktrees = 2; + bool is_online = 3; +} + +message CollaboratorWorktree { + bool is_shared = 1; + repeated uint64 participants = 2; +} diff --git a/zrpc/src/proto.rs b/zrpc/src/proto.rs index 09348a0fc5..282710c64c 100644 --- a/zrpc/src/proto.rs +++ b/zrpc/src/proto.rs @@ -131,6 +131,8 @@ messages!( GetChannelMessagesResponse, GetChannels, GetChannelsResponse, + GetCollaborators, + GetCollaboratorsResponse, GetUsers, GetUsersResponse, JoinChannel, @@ -168,6 +170,7 @@ request_messages!( (UnshareWorktree, Ack), (SendChannelMessage, SendChannelMessageResponse), (GetChannelMessages, GetChannelMessagesResponse), + (GetCollaborators, GetCollaboratorsResponse), ); entity_messages!(