From d398b96f564cf80b87c667e7330630ac804ee823 Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Thu, 19 Aug 2021 15:35:03 +0200 Subject: [PATCH] Re-register message handlers in RPC server --- server/src/rpc.rs | 305 +++++++++++++++++++++++++--------------------- 1 file changed, 164 insertions(+), 141 deletions(-) diff --git a/server/src/rpc.rs b/server/src/rpc.rs index e628ef2c81..465cd96a98 100644 --- a/server/src/rpc.rs +++ b/server/src/rpc.rs @@ -35,23 +35,26 @@ use zrpc::{ type ReplicaId = u16; -type Handler = Box< +type MessageHandler = Box< dyn Send + Sync - + Fn(&mut Option>, Arc) -> Option>, + + Fn( + &mut Option>, + Arc, + ) -> Option>>, >; #[derive(Default)] struct ServerBuilder { - handlers: Vec, + handlers: Vec, handler_types: HashSet, } impl ServerBuilder { - pub fn on_message(&mut self, handler: F) -> &mut Self + pub fn on_message(mut self, handler: F) -> Self where F: 'static + Send + Sync + Fn(Box>, Arc) -> Fut, - Fut: 'static + Send + Future, + Fut: 'static + Send + Future>, M: EnvelopedMessage, { if self.handler_types.insert(TypeId::of::()) { @@ -87,7 +90,7 @@ impl ServerBuilder { pub struct Server { rpc: Arc, state: Arc, - handlers: Vec, + handlers: Vec, } impl Server { @@ -119,10 +122,16 @@ impl Server { futures::select_biased! { message = next_message => { if let Some(message) = message { + let start_time = Instant::now(); + log::info!("RPC message received"); let mut message = Some(message); for handler in &this.handlers { if let Some(future) = (handler)(&mut message, this.clone()) { - future.await; + if let Err(err) = future.await { + log::error!("error handling message: {:?}", err); + } else { + log::info!("RPC message handled. duration:{:?}", start_time.elapsed()); + } break; } } @@ -336,26 +345,24 @@ impl State { pub fn build_server(state: &Arc, rpc: &Arc) -> Arc { ServerBuilder::default() - // .on_message(share_worktree) - // .on_message(join_worktree) - // .on_message(update_worktree) - // .on_message(close_worktree) - // .on_message(open_buffer) - // .on_message(close_buffer) - // .on_message(update_buffer) - // .on_message(buffer_saved) - // .on_message(save_buffer) - // .on_message(get_channels) - // .on_message(get_users) - // .on_message(join_channel) - // .on_message(send_channel_message) + .on_message(share_worktree) + .on_message(join_worktree) + .on_message(update_worktree) + .on_message(close_worktree) + .on_message(open_buffer) + .on_message(close_buffer) + .on_message(update_buffer) + .on_message(buffer_saved) + .on_message(save_buffer) + .on_message(get_channels) + .on_message(get_users) + .on_message(join_channel) + .on_message(send_channel_message) .build(rpc, state) } pub fn add_routes(app: &mut tide::Server>, rpc: &Arc) { let server = build_server(app.state(), rpc); - - let rpc = rpc.clone(); app.at("/rpc").with(auth::VerifyToken).get(move |request: Request>| { let user_id = request.ext::().copied(); let server = server.clone(); @@ -399,11 +406,10 @@ pub fn add_routes(app: &mut tide::Server>, rpc: &Arc) { } async fn share_worktree( - mut request: TypedEnvelope, - rpc: &Arc, - state: &Arc, + mut request: Box>, + server: Arc, ) -> tide::Result<()> { - let mut state = state.rpc.write().await; + let mut state = server.state.rpc.write().await; let worktree_id = state.next_worktree_id; state.next_worktree_id += 1; let access_token = random_token(); @@ -428,26 +434,27 @@ async fn share_worktree( }, ); - rpc.respond( - request.receipt(), - proto::ShareWorktreeResponse { - worktree_id, - access_token, - }, - ) - .await?; + server + .rpc + .respond( + request.receipt(), + proto::ShareWorktreeResponse { + worktree_id, + access_token, + }, + ) + .await?; Ok(()) } async fn join_worktree( - request: TypedEnvelope, - rpc: &Arc, - state: &Arc, + request: Box>, + server: Arc, ) -> tide::Result<()> { let worktree_id = request.payload.worktree_id; let access_token = &request.payload.access_token; - let mut state = state.rpc.write().await; + let mut state = server.state.rpc.write().await; if let Some((peer_replica_id, worktree)) = state.join_worktree(request.sender_id, worktree_id, access_token) { @@ -468,7 +475,7 @@ async fn join_worktree( } broadcast(request.sender_id, worktree.connection_ids(), |conn_id| { - rpc.send( + server.rpc.send( conn_id, proto::AddPeer { worktree_id, @@ -480,42 +487,45 @@ async fn join_worktree( ) }) .await?; - rpc.respond( - request.receipt(), - proto::OpenWorktreeResponse { - worktree_id, - worktree: Some(proto::Worktree { - root_name: worktree.root_name.clone(), - entries: worktree.entries.values().cloned().collect(), - }), - replica_id: peer_replica_id as u32, - peers, - }, - ) - .await?; + server + .rpc + .respond( + request.receipt(), + proto::OpenWorktreeResponse { + worktree_id, + worktree: Some(proto::Worktree { + root_name: worktree.root_name.clone(), + entries: worktree.entries.values().cloned().collect(), + }), + replica_id: peer_replica_id as u32, + peers, + }, + ) + .await?; } else { - rpc.respond( - request.receipt(), - proto::OpenWorktreeResponse { - worktree_id, - worktree: None, - replica_id: 0, - peers: Vec::new(), - }, - ) - .await?; + server + .rpc + .respond( + request.receipt(), + proto::OpenWorktreeResponse { + worktree_id, + worktree: None, + replica_id: 0, + peers: Vec::new(), + }, + ) + .await?; } Ok(()) } async fn update_worktree( - request: TypedEnvelope, - rpc: &Arc, - state: &Arc, + request: Box>, + server: Arc, ) -> tide::Result<()> { { - let mut state = state.rpc.write().await; + let mut state = server.state.rpc.write().await; let worktree = state.write_worktree(request.payload.worktree_id, request.sender_id)?; for entry_id in &request.payload.removed_entries { worktree.entries.remove(&entry_id); @@ -526,18 +536,17 @@ async fn update_worktree( } } - broadcast_in_worktree(request.payload.worktree_id, request, rpc, state).await?; + broadcast_in_worktree(request.payload.worktree_id, &request, &server).await?; Ok(()) } async fn close_worktree( - request: TypedEnvelope, - rpc: &Arc, - state: &Arc, + request: Box>, + server: Arc, ) -> tide::Result<()> { let connection_ids; { - let mut state = state.rpc.write().await; + let mut state = server.state.rpc.write().await; let worktree = state.write_worktree(request.payload.worktree_id, request.sender_id)?; connection_ids = worktree.connection_ids(); if worktree.host_connection_id == Some(request.sender_id) { @@ -548,7 +557,7 @@ async fn close_worktree( } broadcast(request.sender_id, connection_ids, |conn_id| { - rpc.send( + server.rpc.send( conn_id, proto::RemovePeer { worktree_id: request.payload.worktree_id, @@ -562,53 +571,55 @@ async fn close_worktree( } async fn open_buffer( - request: TypedEnvelope, - rpc: &Arc, - state: &Arc, + request: Box>, + server: Arc, ) -> tide::Result<()> { let receipt = request.receipt(); let worktree_id = request.payload.worktree_id; - let host_connection_id = state + let host_connection_id = server + .state .rpc .read() .await .read_worktree(worktree_id, request.sender_id)? .host_connection_id()?; - let response = rpc + let response = server + .rpc .forward_request(request.sender_id, host_connection_id, request.payload) .await?; - rpc.respond(receipt, response).await?; + server.rpc.respond(receipt, response).await?; Ok(()) } async fn close_buffer( - request: TypedEnvelope, - rpc: &Arc, - state: &Arc, + request: Box>, + server: Arc, ) -> tide::Result<()> { - let host_connection_id = state + let host_connection_id = server + .state .rpc .read() .await .read_worktree(request.payload.worktree_id, request.sender_id)? .host_connection_id()?; - rpc.forward_send(request.sender_id, host_connection_id, request.payload) + server + .rpc + .forward_send(request.sender_id, host_connection_id, request.payload) .await?; Ok(()) } async fn save_buffer( - request: TypedEnvelope, - rpc: &Arc, - state: &Arc, + request: Box>, + server: Arc, ) -> tide::Result<()> { let host; let guests; { - let state = state.rpc.read().await; + let state = server.state.rpc.read().await; let worktree = state.read_worktree(request.payload.worktree_id, request.sender_id)?; host = worktree.host_connection_id()?; guests = worktree @@ -620,17 +631,19 @@ async fn save_buffer( let sender = request.sender_id; let receipt = request.receipt(); - let response = rpc + let response = server + .rpc .forward_request(sender, host, request.payload.clone()) .await?; broadcast(host, guests, |conn_id| { let response = response.clone(); + let server = &server; async move { if conn_id == sender { - rpc.respond(receipt, response).await + server.rpc.respond(receipt, response).await } else { - rpc.forward_send(host, conn_id, response).await + server.rpc.forward_send(host, conn_id, response).await } } }) @@ -640,61 +653,62 @@ async fn save_buffer( } async fn update_buffer( - request: TypedEnvelope, - rpc: &Arc, - state: &Arc, + request: Box>, + server: Arc, ) -> tide::Result<()> { - broadcast_in_worktree(request.payload.worktree_id, request, rpc, state).await + broadcast_in_worktree(request.payload.worktree_id, &request, &server).await } async fn buffer_saved( - request: TypedEnvelope, - rpc: &Arc, - state: &Arc, + request: Box>, + server: Arc, ) -> tide::Result<()> { - broadcast_in_worktree(request.payload.worktree_id, request, rpc, state).await + broadcast_in_worktree(request.payload.worktree_id, &request, &server).await } async fn get_channels( - request: TypedEnvelope, - rpc: &Arc, - state: &Arc, + request: Box>, + server: Arc, ) -> tide::Result<()> { - let user_id = state + let user_id = server + .state .rpc .read() .await .user_id_for_connection(request.sender_id)?; - let channels = state.db.get_channels_for_user(user_id).await?; - rpc.respond( - request.receipt(), - proto::GetChannelsResponse { - channels: channels - .into_iter() - .map(|chan| proto::Channel { - id: chan.id.to_proto(), - name: chan.name, - }) - .collect(), - }, - ) - .await?; + let channels = server.state.db.get_channels_for_user(user_id).await?; + server + .rpc + .respond( + request.receipt(), + proto::GetChannelsResponse { + channels: channels + .into_iter() + .map(|chan| proto::Channel { + id: chan.id.to_proto(), + name: chan.name, + }) + .collect(), + }, + ) + .await?; Ok(()) } async fn get_users( - request: TypedEnvelope, - rpc: &Arc, - state: &Arc, + request: Box>, + server: Arc, ) -> tide::Result<()> { - let user_id = state + let user_id = server + .state .rpc .read() .await .user_id_for_connection(request.sender_id)?; let receipt = request.receipt(); let user_ids = request.payload.user_ids.into_iter().map(UserId::from_proto); - let users = state + let users = server + .state .db .get_users_by_ids(user_id, user_ids) .await? @@ -705,23 +719,26 @@ async fn get_users( avatar_url: String::new(), }) .collect(); - rpc.respond(receipt, proto::GetUsersResponse { users }) + server + .rpc + .respond(receipt, proto::GetUsersResponse { users }) .await?; Ok(()) } async fn join_channel( - request: TypedEnvelope, - rpc: &Arc, - state: &Arc, + request: Box>, + server: Arc, ) -> tide::Result<()> { - let user_id = state + let user_id = server + .state .rpc .read() .await .user_id_for_connection(request.sender_id)?; let channel_id = ChannelId::from_proto(request.payload.channel_id); - if !state + if !server + .state .db .can_user_access_channel(user_id, channel_id) .await? @@ -729,12 +746,14 @@ async fn join_channel( Err(anyhow!("access denied"))?; } - state + server + .state .rpc .write() .await .join_channel(request.sender_id, channel_id); - let messages = state + let messages = server + .state .db .get_recent_channel_messages(channel_id, 50) .await? @@ -746,21 +765,22 @@ async fn join_channel( sender_id: msg.sender_id.to_proto(), }) .collect(); - rpc.respond(request.receipt(), proto::JoinChannelResponse { messages }) + server + .rpc + .respond(request.receipt(), proto::JoinChannelResponse { messages }) .await?; Ok(()) } async fn send_channel_message( - request: TypedEnvelope, - peer: &Arc, - app: &Arc, + request: Box>, + server: Arc, ) -> tide::Result<()> { let channel_id = ChannelId::from_proto(request.payload.channel_id); let user_id; let connection_ids; { - let state = app.rpc.read().await; + let state = server.state.rpc.read().await; user_id = state.user_id_for_connection(request.sender_id)?; if let Some(channel) = state.channels.get(&channel_id) { connection_ids = channel.connection_ids(); @@ -770,7 +790,8 @@ async fn send_channel_message( } let timestamp = OffsetDateTime::now_utc(); - let message_id = app + let message_id = server + .state .db .create_channel_message(channel_id, user_id, &request.payload.body, timestamp) .await?; @@ -784,7 +805,7 @@ async fn send_channel_message( }), }; broadcast(request.sender_id, connection_ids, |conn_id| { - peer.send(conn_id, message.clone()) + server.rpc.send(conn_id, message.clone()) }) .await?; @@ -793,11 +814,11 @@ async fn send_channel_message( async fn broadcast_in_worktree( worktree_id: u64, - request: TypedEnvelope, - rpc: &Arc, - state: &Arc, + request: &TypedEnvelope, + server: &Arc, ) -> tide::Result<()> { - let connection_ids = state + let connection_ids = server + .state .rpc .read() .await @@ -805,7 +826,9 @@ async fn broadcast_in_worktree( .connection_ids(); broadcast(request.sender_id, connection_ids, |conn_id| { - rpc.forward_send(request.sender_id, conn_id, request.payload.clone()) + server + .rpc + .forward_send(request.sender_id, conn_id, request.payload.clone()) }) .await?;