diff --git a/crates/collab/src/rpc.rs b/crates/collab/src/rpc.rs index 4375056c9a..19d45e221d 100644 --- a/crates/collab/src/rpc.rs +++ b/crates/collab/src/rpc.rs @@ -68,21 +68,20 @@ lazy_static! { } type MessageHandler = Box< - dyn Send + Sync + Fn(Arc, UserId, Box) -> BoxFuture<'static, ()>, + dyn Send + Sync + Fn(Arc, Box, Session) -> BoxFuture<'static, ()>, >; -struct Message { - sender_user_id: UserId, - sender_connection_id: ConnectionId, - payload: T, -} - struct Response { server: Arc, receipt: Receipt, responded: Arc, } +struct Session { + user_id: UserId, + connection_id: ConnectionId, +} + impl Response { fn send(self, payload: R::Response) -> Result<()> { self.responded.store(true, SeqCst); @@ -201,13 +200,13 @@ impl Server { fn add_handler(&mut self, handler: F) -> &mut Self where - F: 'static + Send + Sync + Fn(Arc, UserId, TypedEnvelope) -> Fut, + F: 'static + Send + Sync + Fn(Arc, TypedEnvelope, Session) -> Fut, Fut: 'static + Send + Future>, M: EnvelopedMessage, { let prev_handler = self.handlers.insert( TypeId::of::(), - Box::new(move |server, sender_user_id, envelope| { + Box::new(move |server, envelope, session| { let envelope = envelope.into_any().downcast::>().unwrap(); let span = info_span!( "handle message", @@ -219,7 +218,7 @@ impl Server { "message received" ); }); - let future = (handler)(server, sender_user_id, *envelope); + let future = (handler)(server, *envelope, session); async move { if let Err(error) = future.await { tracing::error!(%error, "error handling message"); @@ -237,19 +236,12 @@ impl Server { fn add_message_handler(&mut self, handler: F) -> &mut Self where - F: 'static + Send + Sync + Fn(Arc, Message) -> Fut, + F: 'static + Send + Sync + Fn(Arc, M, Session) -> Fut, Fut: 'static + Send + Future>, M: EnvelopedMessage, { - self.add_handler(move |server, sender_user_id, envelope| { - handler( - server, - Message { - sender_user_id, - sender_connection_id: envelope.sender_id, - payload: envelope.payload, - }, - ) + self.add_handler(move |server, envelope, session| { + handler(server, envelope.payload, session) }); self } @@ -258,27 +250,22 @@ impl Server { /// a connection but we want to respond on the connection before anybody else can send on it. fn add_request_handler(&mut self, handler: F) -> &mut Self where - F: 'static + Send + Sync + Fn(Arc, Message, Response) -> Fut, + F: 'static + Send + Sync + Fn(Arc, M, Response, Session) -> Fut, Fut: Send + Future>, M: RequestMessage, { let handler = Arc::new(handler); - self.add_handler(move |server, sender_user_id, envelope| { + self.add_handler(move |server, envelope, session| { let receipt = envelope.receipt(); let handler = handler.clone(); async move { - let request = Message { - sender_user_id, - sender_connection_id: envelope.sender_id, - payload: envelope.payload, - }; let responded = Arc::new(AtomicBool::default()); let response = Response { server: server.clone(), responded: responded.clone(), receipt, }; - match (handler)(server.clone(), request, response).await { + match (handler)(server.clone(), envelope.payload, response, session).await { Ok(()) => { if responded.load(std::sync::atomic::Ordering::SeqCst) { Ok(()) @@ -392,7 +379,11 @@ impl Server { let span_enter = span.enter(); if let Some(handler) = this.handlers.get(&message.payload_type_id()) { let is_background = message.is_background(); - let handle_message = (handler)(this.clone(), user_id, message); + let session = Session { + user_id, + connection_id, + }; + let handle_message = (handler)(this.clone(), message, session); drop(span_enter); let handle_message = handle_message.instrument(span); @@ -509,8 +500,9 @@ impl Server { async fn ping( self: Arc, - _: Message, + _: proto::Ping, response: Response, + _session: Session, ) -> Result<()> { response.send(proto::Ack {})?; Ok(()) @@ -518,13 +510,14 @@ impl Server { async fn create_room( self: Arc, - request: Message, + _request: proto::CreateRoom, response: Response, + session: Session, ) -> Result<()> { let room = self .app_state .db - .create_room(request.sender_user_id, request.sender_connection_id) + .create_room(session.user_id, session.connection_id) .await?; let live_kit_connection_info = @@ -535,10 +528,7 @@ impl Server { .trace_err() { if let Some(token) = live_kit - .room_token( - &room.live_kit_room, - &request.sender_connection_id.to_string(), - ) + .room_token(&room.live_kit_room, &session.connection_id.to_string()) .trace_err() { Some(proto::LiveKitConnectionInfo { @@ -559,29 +549,26 @@ impl Server { room: Some(room), live_kit_connection_info, })?; - self.update_user_contacts(request.sender_user_id).await?; + self.update_user_contacts(session.user_id).await?; Ok(()) } async fn join_room( self: Arc, - request: Message, + request: proto::JoinRoom, response: Response, + session: Session, ) -> Result<()> { let room = self .app_state .db .join_room( - RoomId::from_proto(request.payload.id), - request.sender_user_id, - request.sender_connection_id, + RoomId::from_proto(request.id), + session.user_id, + session.connection_id, ) .await?; - for connection_id in self - .store() - .await - .connection_ids_for_user(request.sender_user_id) - { + for connection_id in self.store().await.connection_ids_for_user(session.user_id) { self.peer .send(connection_id, proto::CallCanceled {}) .trace_err(); @@ -590,10 +577,7 @@ impl Server { let live_kit_connection_info = if let Some(live_kit) = self.app_state.live_kit_client.as_ref() { if let Some(token) = live_kit - .room_token( - &room.live_kit_room, - &request.sender_connection_id.to_string(), - ) + .room_token(&room.live_kit_room, &session.connection_id.to_string()) .trace_err() { Some(proto::LiveKitConnectionInfo { @@ -613,12 +597,16 @@ impl Server { live_kit_connection_info, })?; - self.update_user_contacts(request.sender_user_id).await?; + self.update_user_contacts(session.user_id).await?; Ok(()) } - async fn leave_room(self: Arc, message: Message) -> Result<()> { - self.leave_room_for_connection(message.sender_connection_id, message.sender_user_id) + async fn leave_room( + self: Arc, + _message: proto::LeaveRoom, + session: Session, + ) -> Result<()> { + self.leave_room_for_connection(session.connection_id, session.user_id) .await } @@ -707,17 +695,15 @@ impl Server { async fn call( self: Arc, - request: Message, + request: proto::Call, response: Response, + session: Session, ) -> Result<()> { - let room_id = RoomId::from_proto(request.payload.room_id); - let calling_user_id = request.sender_user_id; - let calling_connection_id = request.sender_connection_id; - let called_user_id = UserId::from_proto(request.payload.called_user_id); - let initial_project_id = request - .payload - .initial_project_id - .map(ProjectId::from_proto); + let room_id = RoomId::from_proto(request.room_id); + let calling_user_id = session.user_id; + let calling_connection_id = session.connection_id; + let called_user_id = UserId::from_proto(request.called_user_id); + let initial_project_id = request.initial_project_id.map(ProjectId::from_proto); if !self .app_state .db @@ -773,15 +759,16 @@ impl Server { async fn cancel_call( self: Arc, - request: Message, + request: proto::CancelCall, response: Response, + session: Session, ) -> Result<()> { - let called_user_id = UserId::from_proto(request.payload.called_user_id); - let room_id = RoomId::from_proto(request.payload.room_id); + let called_user_id = UserId::from_proto(request.called_user_id); + let room_id = RoomId::from_proto(request.room_id); let room = self .app_state .db - .cancel_call(Some(room_id), request.sender_connection_id, called_user_id) + .cancel_call(Some(room_id), session.connection_id, called_user_id) .await?; for connection_id in self.store().await.connection_ids_for_user(called_user_id) { self.peer @@ -795,41 +782,41 @@ impl Server { Ok(()) } - async fn decline_call(self: Arc, message: Message) -> Result<()> { - let room_id = RoomId::from_proto(message.payload.room_id); + async fn decline_call( + self: Arc, + message: proto::DeclineCall, + session: Session, + ) -> Result<()> { + let room_id = RoomId::from_proto(message.room_id); let room = self .app_state .db - .decline_call(Some(room_id), message.sender_user_id) + .decline_call(Some(room_id), session.user_id) .await?; - for connection_id in self - .store() - .await - .connection_ids_for_user(message.sender_user_id) - { + for connection_id in self.store().await.connection_ids_for_user(session.user_id) { self.peer .send(connection_id, proto::CallCanceled {}) .trace_err(); } self.room_updated(&room); - self.update_user_contacts(message.sender_user_id).await?; + self.update_user_contacts(session.user_id).await?; Ok(()) } async fn update_participant_location( self: Arc, - request: Message, + request: proto::UpdateParticipantLocation, response: Response, + session: Session, ) -> Result<()> { - let room_id = RoomId::from_proto(request.payload.room_id); + let room_id = RoomId::from_proto(request.room_id); let location = request - .payload .location .ok_or_else(|| anyhow!("invalid location"))?; let room = self .app_state .db - .update_room_participant_location(room_id, request.sender_connection_id, location) + .update_room_participant_location(room_id, session.connection_id, location) .await?; self.room_updated(&room); response.send(proto::Ack {})?; @@ -851,16 +838,17 @@ impl Server { async fn share_project( self: Arc, - request: Message, + request: proto::ShareProject, response: Response, + session: Session, ) -> Result<()> { let (project_id, room) = self .app_state .db .share_project( - RoomId::from_proto(request.payload.room_id), - request.sender_connection_id, - &request.payload.worktrees, + RoomId::from_proto(request.room_id), + session.connection_id, + &request.worktrees, ) .await?; response.send(proto::ShareProjectResponse { @@ -873,21 +861,20 @@ impl Server { async fn unshare_project( self: Arc, - message: Message, + message: proto::UnshareProject, + session: Session, ) -> Result<()> { - let project_id = ProjectId::from_proto(message.payload.project_id); + let project_id = ProjectId::from_proto(message.project_id); let (room, guest_connection_ids) = self .app_state .db - .unshare_project(project_id, message.sender_connection_id) + .unshare_project(project_id, session.connection_id) .await?; - broadcast( - message.sender_connection_id, - guest_connection_ids, - |conn_id| self.peer.send(conn_id, message.payload.clone()), - ); + broadcast(session.connection_id, guest_connection_ids, |conn_id| { + self.peer.send(conn_id, message.clone()) + }); self.room_updated(&room); Ok(()) @@ -926,26 +913,25 @@ impl Server { async fn join_project( self: Arc, - request: Message, + request: proto::JoinProject, response: Response, + session: Session, ) -> Result<()> { - let project_id = ProjectId::from_proto(request.payload.project_id); - let guest_user_id = request.sender_user_id; + let project_id = ProjectId::from_proto(request.project_id); + let guest_user_id = session.user_id; tracing::info!(%project_id, "join project"); let (project, replica_id) = self .app_state .db - .join_project(project_id, request.sender_connection_id) + .join_project(project_id, session.connection_id) .await?; let collaborators = project .collaborators .iter() - .filter(|collaborator| { - collaborator.connection_id != request.sender_connection_id.0 as i32 - }) + .filter(|collaborator| collaborator.connection_id != session.connection_id.0 as i32) .map(|collaborator| proto::Collaborator { peer_id: collaborator.connection_id as u32, replica_id: collaborator.replica_id.0 as u32, @@ -970,7 +956,7 @@ impl Server { proto::AddProjectCollaborator { project_id: project_id.to_proto(), collaborator: Some(proto::Collaborator { - peer_id: request.sender_connection_id.0, + peer_id: session.connection_id.0, replica_id: replica_id.0 as u32, user_id: guest_user_id.to_proto(), }), @@ -1005,14 +991,13 @@ impl Server { is_last_update: worktree.is_complete, }; for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) { - self.peer - .send(request.sender_connection_id, update.clone())?; + self.peer.send(session.connection_id, update.clone())?; } // Stream this worktree's diagnostics. for summary in worktree.diagnostic_summaries { self.peer.send( - request.sender_connection_id, + session.connection_id, proto::UpdateDiagnosticSummary { project_id: project_id.to_proto(), worktree_id: worktree.id.to_proto(), @@ -1024,7 +1009,7 @@ impl Server { for language_server in &project.language_servers { self.peer.send( - request.sender_connection_id, + session.connection_id, proto::UpdateLanguageServer { project_id: project_id.to_proto(), language_server_id: language_server.id, @@ -1040,9 +1025,13 @@ impl Server { Ok(()) } - async fn leave_project(self: Arc, request: Message) -> Result<()> { - let sender_id = request.sender_connection_id; - let project_id = ProjectId::from_proto(request.payload.project_id); + async fn leave_project( + self: Arc, + request: proto::LeaveProject, + session: Session, + ) -> Result<()> { + let sender_id = session.connection_id; + let project_id = ProjectId::from_proto(request.project_id); let project; { project = self @@ -1073,28 +1062,22 @@ impl Server { async fn update_project( self: Arc, - request: Message, + request: proto::UpdateProject, response: Response, + session: Session, ) -> Result<()> { - let project_id = ProjectId::from_proto(request.payload.project_id); + let project_id = ProjectId::from_proto(request.project_id); let (room, guest_connection_ids) = self .app_state .db - .update_project( - project_id, - request.sender_connection_id, - &request.payload.worktrees, - ) + .update_project(project_id, session.connection_id, &request.worktrees) .await?; broadcast( - request.sender_connection_id, + session.connection_id, guest_connection_ids, |connection_id| { - self.peer.forward_send( - request.sender_connection_id, - connection_id, - request.payload.clone(), - ) + self.peer + .forward_send(session.connection_id, connection_id, request.clone()) }, ); self.room_updated(&room); @@ -1105,24 +1088,22 @@ impl Server { async fn update_worktree( self: Arc, - request: Message, + request: proto::UpdateWorktree, response: Response, + session: Session, ) -> Result<()> { let guest_connection_ids = self .app_state .db - .update_worktree(&request.payload, request.sender_connection_id) + .update_worktree(&request, session.connection_id) .await?; broadcast( - request.sender_connection_id, + session.connection_id, guest_connection_ids, |connection_id| { - self.peer.forward_send( - request.sender_connection_id, - connection_id, - request.payload.clone(), - ) + self.peer + .forward_send(session.connection_id, connection_id, request.clone()) }, ); response.send(proto::Ack {})?; @@ -1131,24 +1112,22 @@ impl Server { async fn update_diagnostic_summary( self: Arc, - request: Message, + request: proto::UpdateDiagnosticSummary, response: Response, + session: Session, ) -> Result<()> { let guest_connection_ids = self .app_state .db - .update_diagnostic_summary(&request.payload, request.sender_connection_id) + .update_diagnostic_summary(&request, session.connection_id) .await?; broadcast( - request.sender_connection_id, + session.connection_id, guest_connection_ids, |connection_id| { - self.peer.forward_send( - request.sender_connection_id, - connection_id, - request.payload.clone(), - ) + self.peer + .forward_send(session.connection_id, connection_id, request.clone()) }, ); @@ -1158,23 +1137,21 @@ impl Server { async fn start_language_server( self: Arc, - request: Message, + request: proto::StartLanguageServer, + session: Session, ) -> Result<()> { let guest_connection_ids = self .app_state .db - .start_language_server(&request.payload, request.sender_connection_id) + .start_language_server(&request, session.connection_id) .await?; broadcast( - request.sender_connection_id, + session.connection_id, guest_connection_ids, |connection_id| { - self.peer.forward_send( - request.sender_connection_id, - connection_id, - request.payload.clone(), - ) + self.peer + .forward_send(session.connection_id, connection_id, request.clone()) }, ); Ok(()) @@ -1182,23 +1159,21 @@ impl Server { async fn update_language_server( self: Arc, - request: Message, + request: proto::UpdateLanguageServer, + session: Session, ) -> Result<()> { - let project_id = ProjectId::from_proto(request.payload.project_id); + let project_id = ProjectId::from_proto(request.project_id); let project_connection_ids = self .app_state .db - .project_connection_ids(project_id, request.sender_connection_id) + .project_connection_ids(project_id, session.connection_id) .await?; broadcast( - request.sender_connection_id, + session.connection_id, project_connection_ids, |connection_id| { - self.peer.forward_send( - request.sender_connection_id, - connection_id, - request.payload.clone(), - ) + self.peer + .forward_send(session.connection_id, connection_id, request.clone()) }, ); Ok(()) @@ -1206,17 +1181,18 @@ impl Server { async fn forward_project_request( self: Arc, - request: Message, + request: T, response: Response, + session: Session, ) -> Result<()> where T: EntityMessage + RequestMessage, { - let project_id = ProjectId::from_proto(request.payload.remote_entity_id()); + let project_id = ProjectId::from_proto(request.remote_entity_id()); let collaborators = self .app_state .db - .project_collaborators(project_id, request.sender_connection_id) + .project_collaborators(project_id, session.connection_id) .await?; let host = collaborators .iter() @@ -1226,9 +1202,9 @@ impl Server { let payload = self .peer .forward_request( - request.sender_connection_id, + session.connection_id, ConnectionId(host.connection_id as u32), - request.payload, + request, ) .await?; @@ -1238,14 +1214,15 @@ impl Server { async fn save_buffer( self: Arc, - request: Message, + request: proto::SaveBuffer, response: Response, + session: Session, ) -> Result<()> { - let project_id = ProjectId::from_proto(request.payload.project_id); + let project_id = ProjectId::from_proto(request.project_id); let collaborators = self .app_state .db - .project_collaborators(project_id, request.sender_connection_id) + .project_collaborators(project_id, session.connection_id) .await?; let host = collaborators .into_iter() @@ -1254,21 +1231,16 @@ impl Server { let host_connection_id = ConnectionId(host.connection_id as u32); let response_payload = self .peer - .forward_request( - request.sender_connection_id, - host_connection_id, - request.payload.clone(), - ) + .forward_request(session.connection_id, host_connection_id, request.clone()) .await?; let mut collaborators = self .app_state .db - .project_collaborators(project_id, request.sender_connection_id) + .project_collaborators(project_id, session.connection_id) .await?; - collaborators.retain(|collaborator| { - collaborator.connection_id != request.sender_connection_id.0 as i32 - }); + collaborators + .retain(|collaborator| collaborator.connection_id != session.connection_id.0 as i32); let project_connection_ids = collaborators .into_iter() .map(|collaborator| ConnectionId(collaborator.connection_id as u32)); @@ -1282,37 +1254,36 @@ impl Server { async fn create_buffer_for_peer( self: Arc, - request: Message, + request: proto::CreateBufferForPeer, + session: Session, ) -> Result<()> { self.peer.forward_send( - request.sender_connection_id, - ConnectionId(request.payload.peer_id), - request.payload, + session.connection_id, + ConnectionId(request.peer_id), + request, )?; Ok(()) } async fn update_buffer( self: Arc, - request: Message, + request: proto::UpdateBuffer, response: Response, + session: Session, ) -> Result<()> { - let project_id = ProjectId::from_proto(request.payload.project_id); + let project_id = ProjectId::from_proto(request.project_id); let project_connection_ids = self .app_state .db - .project_connection_ids(project_id, request.sender_connection_id) + .project_connection_ids(project_id, session.connection_id) .await?; broadcast( - request.sender_connection_id, + session.connection_id, project_connection_ids, |connection_id| { - self.peer.forward_send( - request.sender_connection_id, - connection_id, - request.payload.clone(), - ) + self.peer + .forward_send(session.connection_id, connection_id, request.clone()) }, ); response.send(proto::Ack {})?; @@ -1321,24 +1292,22 @@ impl Server { async fn update_buffer_file( self: Arc, - request: Message, + request: proto::UpdateBufferFile, + session: Session, ) -> Result<()> { - let project_id = ProjectId::from_proto(request.payload.project_id); + let project_id = ProjectId::from_proto(request.project_id); let project_connection_ids = self .app_state .db - .project_connection_ids(project_id, request.sender_connection_id) + .project_connection_ids(project_id, session.connection_id) .await?; broadcast( - request.sender_connection_id, + session.connection_id, project_connection_ids, |connection_id| { - self.peer.forward_send( - request.sender_connection_id, - connection_id, - request.payload.clone(), - ) + self.peer + .forward_send(session.connection_id, connection_id, request.clone()) }, ); Ok(()) @@ -1346,44 +1315,43 @@ impl Server { async fn buffer_reloaded( self: Arc, - request: Message, + request: proto::BufferReloaded, + session: Session, ) -> Result<()> { - let project_id = ProjectId::from_proto(request.payload.project_id); + let project_id = ProjectId::from_proto(request.project_id); let project_connection_ids = self .app_state .db - .project_connection_ids(project_id, request.sender_connection_id) + .project_connection_ids(project_id, session.connection_id) .await?; broadcast( - request.sender_connection_id, + session.connection_id, project_connection_ids, |connection_id| { - self.peer.forward_send( - request.sender_connection_id, - connection_id, - request.payload.clone(), - ) + self.peer + .forward_send(session.connection_id, connection_id, request.clone()) }, ); Ok(()) } - async fn buffer_saved(self: Arc, request: Message) -> Result<()> { - let project_id = ProjectId::from_proto(request.payload.project_id); + async fn buffer_saved( + self: Arc, + request: proto::BufferSaved, + session: Session, + ) -> Result<()> { + let project_id = ProjectId::from_proto(request.project_id); let project_connection_ids = self .app_state .db - .project_connection_ids(project_id, request.sender_connection_id) + .project_connection_ids(project_id, session.connection_id) .await?; broadcast( - request.sender_connection_id, + session.connection_id, project_connection_ids, |connection_id| { - self.peer.forward_send( - request.sender_connection_id, - connection_id, - request.payload.clone(), - ) + self.peer + .forward_send(session.connection_id, connection_id, request.clone()) }, ); Ok(()) @@ -1391,16 +1359,17 @@ impl Server { async fn follow( self: Arc, - request: Message, + request: proto::Follow, response: Response, + session: Session, ) -> Result<()> { - let project_id = ProjectId::from_proto(request.payload.project_id); - let leader_id = ConnectionId(request.payload.leader_id); - let follower_id = request.sender_connection_id; + let project_id = ProjectId::from_proto(request.project_id); + let leader_id = ConnectionId(request.leader_id); + let follower_id = session.connection_id; let project_connection_ids = self .app_state .db - .project_connection_ids(project_id, request.sender_connection_id) + .project_connection_ids(project_id, session.connection_id) .await?; if !project_connection_ids.contains(&leader_id) { @@ -1409,7 +1378,7 @@ impl Server { let mut response_payload = self .peer - .forward_request(request.sender_connection_id, leader_id, request.payload) + .forward_request(session.connection_id, leader_id, request) .await?; response_payload .views @@ -1418,50 +1387,44 @@ impl Server { Ok(()) } - async fn unfollow(self: Arc, request: Message) -> Result<()> { - let project_id = ProjectId::from_proto(request.payload.project_id); - let leader_id = ConnectionId(request.payload.leader_id); + async fn unfollow(self: Arc, request: proto::Unfollow, session: Session) -> Result<()> { + let project_id = ProjectId::from_proto(request.project_id); + let leader_id = ConnectionId(request.leader_id); let project_connection_ids = self .app_state .db - .project_connection_ids(project_id, request.sender_connection_id) + .project_connection_ids(project_id, session.connection_id) .await?; if !project_connection_ids.contains(&leader_id) { Err(anyhow!("no such peer"))?; } self.peer - .forward_send(request.sender_connection_id, leader_id, request.payload)?; + .forward_send(session.connection_id, leader_id, request)?; Ok(()) } async fn update_followers( self: Arc, - request: Message, + request: proto::UpdateFollowers, + session: Session, ) -> Result<()> { - let project_id = ProjectId::from_proto(request.payload.project_id); + let project_id = ProjectId::from_proto(request.project_id); let project_connection_ids = self .app_state .db - .project_connection_ids(project_id, request.sender_connection_id) + .project_connection_ids(project_id, session.connection_id) .await?; - let leader_id = request - .payload - .variant - .as_ref() - .and_then(|variant| match variant { - proto::update_followers::Variant::CreateView(payload) => payload.leader_id, - proto::update_followers::Variant::UpdateView(payload) => payload.leader_id, - proto::update_followers::Variant::UpdateActiveView(payload) => payload.leader_id, - }); - for follower_id in &request.payload.follower_ids { + let leader_id = request.variant.as_ref().and_then(|variant| match variant { + proto::update_followers::Variant::CreateView(payload) => payload.leader_id, + proto::update_followers::Variant::UpdateView(payload) => payload.leader_id, + proto::update_followers::Variant::UpdateActiveView(payload) => payload.leader_id, + }); + for follower_id in &request.follower_ids { let follower_id = ConnectionId(*follower_id); if project_connection_ids.contains(&follower_id) && Some(follower_id.0) != leader_id { - self.peer.forward_send( - request.sender_connection_id, - follower_id, - request.payload.clone(), - )?; + self.peer + .forward_send(session.connection_id, follower_id, request.clone())?; } } Ok(()) @@ -1469,11 +1432,11 @@ impl Server { async fn get_users( self: Arc, - request: Message, + request: proto::GetUsers, response: Response, + _session: Session, ) -> Result<()> { let user_ids = request - .payload .user_ids .into_iter() .map(UserId::from_proto) @@ -1496,10 +1459,11 @@ impl Server { async fn fuzzy_search_users( self: Arc, - request: Message, + request: proto::FuzzySearchUsers, response: Response, + session: Session, ) -> Result<()> { - let query = request.payload.query; + let query = request.query; let db = &self.app_state.db; let users = match query.len() { 0 => vec![], @@ -1512,7 +1476,7 @@ impl Server { }; let users = users .into_iter() - .filter(|user| user.id != request.sender_user_id) + .filter(|user| user.id != session.user_id) .map(|user| proto::User { id: user.id.to_proto(), avatar_url: format!("https://github.com/{}.png?size=128", user.github_login), @@ -1525,11 +1489,12 @@ impl Server { async fn request_contact( self: Arc, - request: Message, + request: proto::RequestContact, response: Response, + session: Session, ) -> Result<()> { - let requester_id = request.sender_user_id; - let responder_id = UserId::from_proto(request.payload.responder_id); + let requester_id = session.user_id; + let responder_id = UserId::from_proto(request.responder_id); if requester_id == responder_id { return Err(anyhow!("cannot add yourself as a contact"))?; } @@ -1564,18 +1529,19 @@ impl Server { async fn respond_to_contact_request( self: Arc, - request: Message, + request: proto::RespondToContactRequest, response: Response, + session: Session, ) -> Result<()> { - let responder_id = request.sender_user_id; - let requester_id = UserId::from_proto(request.payload.requester_id); - if request.payload.response == proto::ContactRequestResponse::Dismiss as i32 { + let responder_id = session.user_id; + let requester_id = UserId::from_proto(request.requester_id); + if request.response == proto::ContactRequestResponse::Dismiss as i32 { self.app_state .db .dismiss_contact_notification(responder_id, requester_id) .await?; } else { - let accept = request.payload.response == proto::ContactRequestResponse::Accept as i32; + let accept = request.response == proto::ContactRequestResponse::Accept as i32; self.app_state .db .respond_to_contact_request(responder_id, requester_id, accept) @@ -1618,11 +1584,12 @@ impl Server { async fn remove_contact( self: Arc, - request: Message, + request: proto::RemoveContact, response: Response, + session: Session, ) -> Result<()> { - let requester_id = request.sender_user_id; - let responder_id = UserId::from_proto(request.payload.user_id); + let requester_id = session.user_id; + let responder_id = UserId::from_proto(request.user_id); self.app_state .db .remove_contact(requester_id, responder_id) @@ -1652,23 +1619,21 @@ impl Server { async fn update_diff_base( self: Arc, - request: Message, + request: proto::UpdateDiffBase, + session: Session, ) -> Result<()> { - let project_id = ProjectId::from_proto(request.payload.project_id); + let project_id = ProjectId::from_proto(request.project_id); let project_connection_ids = self .app_state .db - .project_connection_ids(project_id, request.sender_connection_id) + .project_connection_ids(project_id, session.connection_id) .await?; broadcast( - request.sender_connection_id, + session.connection_id, project_connection_ids, |connection_id| { - self.peer.forward_send( - request.sender_connection_id, - connection_id, - request.payload.clone(), - ) + self.peer + .forward_send(session.connection_id, connection_id, request.clone()) }, ); Ok(()) @@ -1676,18 +1641,19 @@ impl Server { async fn get_private_user_info( self: Arc, - request: Message, + _request: proto::GetPrivateUserInfo, response: Response, + session: Session, ) -> Result<()> { let metrics_id = self .app_state .db - .get_user_metrics_id(request.sender_user_id) + .get_user_metrics_id(session.user_id) .await?; let user = self .app_state .db - .get_user_by_id(request.sender_user_id) + .get_user_by_id(session.user_id) .await? .ok_or_else(|| anyhow!("user not found"))?; response.send(proto::GetPrivateUserInfoResponse {