Re-register message handlers in RPC server

This commit is contained in:
Antonio Scandurra 2021-08-19 15:35:03 +02:00
parent d6412fdbde
commit d398b96f56

View file

@ -35,23 +35,26 @@ use zrpc::{
type ReplicaId = u16; type ReplicaId = u16;
type Handler = Box< type MessageHandler = Box<
dyn Send dyn Send
+ Sync + Sync
+ Fn(&mut Option<Box<dyn Any + Send + Sync>>, Arc<Server>) -> Option<BoxFuture<'static, ()>>, + Fn(
&mut Option<Box<dyn Any + Send + Sync>>,
Arc<Server>,
) -> Option<BoxFuture<'static, tide::Result<()>>>,
>; >;
#[derive(Default)] #[derive(Default)]
struct ServerBuilder { struct ServerBuilder {
handlers: Vec<Handler>, handlers: Vec<MessageHandler>,
handler_types: HashSet<TypeId>, handler_types: HashSet<TypeId>,
} }
impl ServerBuilder { impl ServerBuilder {
pub fn on_message<F, Fut, M>(&mut self, handler: F) -> &mut Self pub fn on_message<F, Fut, M>(mut self, handler: F) -> Self
where where
F: 'static + Send + Sync + Fn(Box<TypedEnvelope<M>>, Arc<Server>) -> Fut, F: 'static + Send + Sync + Fn(Box<TypedEnvelope<M>>, Arc<Server>) -> Fut,
Fut: 'static + Send + Future<Output = ()>, Fut: 'static + Send + Future<Output = tide::Result<()>>,
M: EnvelopedMessage, M: EnvelopedMessage,
{ {
if self.handler_types.insert(TypeId::of::<M>()) { if self.handler_types.insert(TypeId::of::<M>()) {
@ -87,7 +90,7 @@ impl ServerBuilder {
pub struct Server { pub struct Server {
rpc: Arc<Peer>, rpc: Arc<Peer>,
state: Arc<AppState>, state: Arc<AppState>,
handlers: Vec<Handler>, handlers: Vec<MessageHandler>,
} }
impl Server { impl Server {
@ -119,10 +122,16 @@ impl Server {
futures::select_biased! { futures::select_biased! {
message = next_message => { message = next_message => {
if let Some(message) = message { if let Some(message) = message {
let start_time = Instant::now();
log::info!("RPC message received");
let mut message = Some(message); let mut message = Some(message);
for handler in &this.handlers { for handler in &this.handlers {
if let Some(future) = (handler)(&mut message, this.clone()) { if let Some(future) = (handler)(&mut message, this.clone()) {
future.await; if let Err(err) = future.await {
log::error!("error handling message: {:?}", err);
} else {
log::info!("RPC message handled. duration:{:?}", start_time.elapsed());
}
break; break;
} }
} }
@ -336,26 +345,24 @@ impl State {
pub fn build_server(state: &Arc<AppState>, rpc: &Arc<Peer>) -> Arc<Server> { pub fn build_server(state: &Arc<AppState>, rpc: &Arc<Peer>) -> Arc<Server> {
ServerBuilder::default() ServerBuilder::default()
// .on_message(share_worktree) .on_message(share_worktree)
// .on_message(join_worktree) .on_message(join_worktree)
// .on_message(update_worktree) .on_message(update_worktree)
// .on_message(close_worktree) .on_message(close_worktree)
// .on_message(open_buffer) .on_message(open_buffer)
// .on_message(close_buffer) .on_message(close_buffer)
// .on_message(update_buffer) .on_message(update_buffer)
// .on_message(buffer_saved) .on_message(buffer_saved)
// .on_message(save_buffer) .on_message(save_buffer)
// .on_message(get_channels) .on_message(get_channels)
// .on_message(get_users) .on_message(get_users)
// .on_message(join_channel) .on_message(join_channel)
// .on_message(send_channel_message) .on_message(send_channel_message)
.build(rpc, state) .build(rpc, state)
} }
pub fn add_routes(app: &mut tide::Server<Arc<AppState>>, rpc: &Arc<Peer>) { pub fn add_routes(app: &mut tide::Server<Arc<AppState>>, rpc: &Arc<Peer>) {
let server = build_server(app.state(), rpc); let server = build_server(app.state(), rpc);
let rpc = rpc.clone();
app.at("/rpc").with(auth::VerifyToken).get(move |request: Request<Arc<AppState>>| { app.at("/rpc").with(auth::VerifyToken).get(move |request: Request<Arc<AppState>>| {
let user_id = request.ext::<UserId>().copied(); let user_id = request.ext::<UserId>().copied();
let server = server.clone(); let server = server.clone();
@ -399,11 +406,10 @@ pub fn add_routes(app: &mut tide::Server<Arc<AppState>>, rpc: &Arc<Peer>) {
} }
async fn share_worktree( async fn share_worktree(
mut request: TypedEnvelope<proto::ShareWorktree>, mut request: Box<TypedEnvelope<proto::ShareWorktree>>,
rpc: &Arc<Peer>, server: Arc<Server>,
state: &Arc<AppState>,
) -> tide::Result<()> { ) -> tide::Result<()> {
let mut state = state.rpc.write().await; let mut state = server.state.rpc.write().await;
let worktree_id = state.next_worktree_id; let worktree_id = state.next_worktree_id;
state.next_worktree_id += 1; state.next_worktree_id += 1;
let access_token = random_token(); let access_token = random_token();
@ -428,7 +434,9 @@ async fn share_worktree(
}, },
); );
rpc.respond( server
.rpc
.respond(
request.receipt(), request.receipt(),
proto::ShareWorktreeResponse { proto::ShareWorktreeResponse {
worktree_id, worktree_id,
@ -440,14 +448,13 @@ async fn share_worktree(
} }
async fn join_worktree( async fn join_worktree(
request: TypedEnvelope<proto::OpenWorktree>, request: Box<TypedEnvelope<proto::OpenWorktree>>,
rpc: &Arc<Peer>, server: Arc<Server>,
state: &Arc<AppState>,
) -> tide::Result<()> { ) -> tide::Result<()> {
let worktree_id = request.payload.worktree_id; let worktree_id = request.payload.worktree_id;
let access_token = &request.payload.access_token; let access_token = &request.payload.access_token;
let mut state = state.rpc.write().await; let mut state = server.state.rpc.write().await;
if let Some((peer_replica_id, worktree)) = if let Some((peer_replica_id, worktree)) =
state.join_worktree(request.sender_id, worktree_id, access_token) state.join_worktree(request.sender_id, worktree_id, access_token)
{ {
@ -468,7 +475,7 @@ async fn join_worktree(
} }
broadcast(request.sender_id, worktree.connection_ids(), |conn_id| { broadcast(request.sender_id, worktree.connection_ids(), |conn_id| {
rpc.send( server.rpc.send(
conn_id, conn_id,
proto::AddPeer { proto::AddPeer {
worktree_id, worktree_id,
@ -480,7 +487,9 @@ async fn join_worktree(
) )
}) })
.await?; .await?;
rpc.respond( server
.rpc
.respond(
request.receipt(), request.receipt(),
proto::OpenWorktreeResponse { proto::OpenWorktreeResponse {
worktree_id, worktree_id,
@ -494,7 +503,9 @@ async fn join_worktree(
) )
.await?; .await?;
} else { } else {
rpc.respond( server
.rpc
.respond(
request.receipt(), request.receipt(),
proto::OpenWorktreeResponse { proto::OpenWorktreeResponse {
worktree_id, worktree_id,
@ -510,12 +521,11 @@ async fn join_worktree(
} }
async fn update_worktree( async fn update_worktree(
request: TypedEnvelope<proto::UpdateWorktree>, request: Box<TypedEnvelope<proto::UpdateWorktree>>,
rpc: &Arc<Peer>, server: Arc<Server>,
state: &Arc<AppState>,
) -> tide::Result<()> { ) -> tide::Result<()> {
{ {
let mut state = state.rpc.write().await; let mut state = server.state.rpc.write().await;
let worktree = state.write_worktree(request.payload.worktree_id, request.sender_id)?; let worktree = state.write_worktree(request.payload.worktree_id, request.sender_id)?;
for entry_id in &request.payload.removed_entries { for entry_id in &request.payload.removed_entries {
worktree.entries.remove(&entry_id); worktree.entries.remove(&entry_id);
@ -526,18 +536,17 @@ async fn update_worktree(
} }
} }
broadcast_in_worktree(request.payload.worktree_id, request, rpc, state).await?; broadcast_in_worktree(request.payload.worktree_id, &request, &server).await?;
Ok(()) Ok(())
} }
async fn close_worktree( async fn close_worktree(
request: TypedEnvelope<proto::CloseWorktree>, request: Box<TypedEnvelope<proto::CloseWorktree>>,
rpc: &Arc<Peer>, server: Arc<Server>,
state: &Arc<AppState>,
) -> tide::Result<()> { ) -> tide::Result<()> {
let connection_ids; let connection_ids;
{ {
let mut state = state.rpc.write().await; let mut state = server.state.rpc.write().await;
let worktree = state.write_worktree(request.payload.worktree_id, request.sender_id)?; let worktree = state.write_worktree(request.payload.worktree_id, request.sender_id)?;
connection_ids = worktree.connection_ids(); connection_ids = worktree.connection_ids();
if worktree.host_connection_id == Some(request.sender_id) { if worktree.host_connection_id == Some(request.sender_id) {
@ -548,7 +557,7 @@ async fn close_worktree(
} }
broadcast(request.sender_id, connection_ids, |conn_id| { broadcast(request.sender_id, connection_ids, |conn_id| {
rpc.send( server.rpc.send(
conn_id, conn_id,
proto::RemovePeer { proto::RemovePeer {
worktree_id: request.payload.worktree_id, worktree_id: request.payload.worktree_id,
@ -562,53 +571,55 @@ async fn close_worktree(
} }
async fn open_buffer( async fn open_buffer(
request: TypedEnvelope<proto::OpenBuffer>, request: Box<TypedEnvelope<proto::OpenBuffer>>,
rpc: &Arc<Peer>, server: Arc<Server>,
state: &Arc<AppState>,
) -> tide::Result<()> { ) -> tide::Result<()> {
let receipt = request.receipt(); let receipt = request.receipt();
let worktree_id = request.payload.worktree_id; let worktree_id = request.payload.worktree_id;
let host_connection_id = state let host_connection_id = server
.state
.rpc .rpc
.read() .read()
.await .await
.read_worktree(worktree_id, request.sender_id)? .read_worktree(worktree_id, request.sender_id)?
.host_connection_id()?; .host_connection_id()?;
let response = rpc let response = server
.rpc
.forward_request(request.sender_id, host_connection_id, request.payload) .forward_request(request.sender_id, host_connection_id, request.payload)
.await?; .await?;
rpc.respond(receipt, response).await?; server.rpc.respond(receipt, response).await?;
Ok(()) Ok(())
} }
async fn close_buffer( async fn close_buffer(
request: TypedEnvelope<proto::CloseBuffer>, request: Box<TypedEnvelope<proto::CloseBuffer>>,
rpc: &Arc<Peer>, server: Arc<Server>,
state: &Arc<AppState>,
) -> tide::Result<()> { ) -> tide::Result<()> {
let host_connection_id = state let host_connection_id = server
.state
.rpc .rpc
.read() .read()
.await .await
.read_worktree(request.payload.worktree_id, request.sender_id)? .read_worktree(request.payload.worktree_id, request.sender_id)?
.host_connection_id()?; .host_connection_id()?;
rpc.forward_send(request.sender_id, host_connection_id, request.payload) server
.rpc
.forward_send(request.sender_id, host_connection_id, request.payload)
.await?; .await?;
Ok(()) Ok(())
} }
async fn save_buffer( async fn save_buffer(
request: TypedEnvelope<proto::SaveBuffer>, request: Box<TypedEnvelope<proto::SaveBuffer>>,
rpc: &Arc<Peer>, server: Arc<Server>,
state: &Arc<AppState>,
) -> tide::Result<()> { ) -> tide::Result<()> {
let host; let host;
let guests; let guests;
{ {
let state = state.rpc.read().await; let state = server.state.rpc.read().await;
let worktree = state.read_worktree(request.payload.worktree_id, request.sender_id)?; let worktree = state.read_worktree(request.payload.worktree_id, request.sender_id)?;
host = worktree.host_connection_id()?; host = worktree.host_connection_id()?;
guests = worktree guests = worktree
@ -620,17 +631,19 @@ async fn save_buffer(
let sender = request.sender_id; let sender = request.sender_id;
let receipt = request.receipt(); let receipt = request.receipt();
let response = rpc let response = server
.rpc
.forward_request(sender, host, request.payload.clone()) .forward_request(sender, host, request.payload.clone())
.await?; .await?;
broadcast(host, guests, |conn_id| { broadcast(host, guests, |conn_id| {
let response = response.clone(); let response = response.clone();
let server = &server;
async move { async move {
if conn_id == sender { if conn_id == sender {
rpc.respond(receipt, response).await server.rpc.respond(receipt, response).await
} else { } else {
rpc.forward_send(host, conn_id, response).await server.rpc.forward_send(host, conn_id, response).await
} }
} }
}) })
@ -640,33 +653,33 @@ async fn save_buffer(
} }
async fn update_buffer( async fn update_buffer(
request: TypedEnvelope<proto::UpdateBuffer>, request: Box<TypedEnvelope<proto::UpdateBuffer>>,
rpc: &Arc<Peer>, server: Arc<Server>,
state: &Arc<AppState>,
) -> tide::Result<()> { ) -> tide::Result<()> {
broadcast_in_worktree(request.payload.worktree_id, request, rpc, state).await broadcast_in_worktree(request.payload.worktree_id, &request, &server).await
} }
async fn buffer_saved( async fn buffer_saved(
request: TypedEnvelope<proto::BufferSaved>, request: Box<TypedEnvelope<proto::BufferSaved>>,
rpc: &Arc<Peer>, server: Arc<Server>,
state: &Arc<AppState>,
) -> tide::Result<()> { ) -> tide::Result<()> {
broadcast_in_worktree(request.payload.worktree_id, request, rpc, state).await broadcast_in_worktree(request.payload.worktree_id, &request, &server).await
} }
async fn get_channels( async fn get_channels(
request: TypedEnvelope<proto::GetChannels>, request: Box<TypedEnvelope<proto::GetChannels>>,
rpc: &Arc<Peer>, server: Arc<Server>,
state: &Arc<AppState>,
) -> tide::Result<()> { ) -> tide::Result<()> {
let user_id = state let user_id = server
.state
.rpc .rpc
.read() .read()
.await .await
.user_id_for_connection(request.sender_id)?; .user_id_for_connection(request.sender_id)?;
let channels = state.db.get_channels_for_user(user_id).await?; let channels = server.state.db.get_channels_for_user(user_id).await?;
rpc.respond( server
.rpc
.respond(
request.receipt(), request.receipt(),
proto::GetChannelsResponse { proto::GetChannelsResponse {
channels: channels channels: channels
@ -683,18 +696,19 @@ async fn get_channels(
} }
async fn get_users( async fn get_users(
request: TypedEnvelope<proto::GetUsers>, request: Box<TypedEnvelope<proto::GetUsers>>,
rpc: &Arc<Peer>, server: Arc<Server>,
state: &Arc<AppState>,
) -> tide::Result<()> { ) -> tide::Result<()> {
let user_id = state let user_id = server
.state
.rpc .rpc
.read() .read()
.await .await
.user_id_for_connection(request.sender_id)?; .user_id_for_connection(request.sender_id)?;
let receipt = request.receipt(); let receipt = request.receipt();
let user_ids = request.payload.user_ids.into_iter().map(UserId::from_proto); let user_ids = request.payload.user_ids.into_iter().map(UserId::from_proto);
let users = state let users = server
.state
.db .db
.get_users_by_ids(user_id, user_ids) .get_users_by_ids(user_id, user_ids)
.await? .await?
@ -705,23 +719,26 @@ async fn get_users(
avatar_url: String::new(), avatar_url: String::new(),
}) })
.collect(); .collect();
rpc.respond(receipt, proto::GetUsersResponse { users }) server
.rpc
.respond(receipt, proto::GetUsersResponse { users })
.await?; .await?;
Ok(()) Ok(())
} }
async fn join_channel( async fn join_channel(
request: TypedEnvelope<proto::JoinChannel>, request: Box<TypedEnvelope<proto::JoinChannel>>,
rpc: &Arc<Peer>, server: Arc<Server>,
state: &Arc<AppState>,
) -> tide::Result<()> { ) -> tide::Result<()> {
let user_id = state let user_id = server
.state
.rpc .rpc
.read() .read()
.await .await
.user_id_for_connection(request.sender_id)?; .user_id_for_connection(request.sender_id)?;
let channel_id = ChannelId::from_proto(request.payload.channel_id); let channel_id = ChannelId::from_proto(request.payload.channel_id);
if !state if !server
.state
.db .db
.can_user_access_channel(user_id, channel_id) .can_user_access_channel(user_id, channel_id)
.await? .await?
@ -729,12 +746,14 @@ async fn join_channel(
Err(anyhow!("access denied"))?; Err(anyhow!("access denied"))?;
} }
state server
.state
.rpc .rpc
.write() .write()
.await .await
.join_channel(request.sender_id, channel_id); .join_channel(request.sender_id, channel_id);
let messages = state let messages = server
.state
.db .db
.get_recent_channel_messages(channel_id, 50) .get_recent_channel_messages(channel_id, 50)
.await? .await?
@ -746,21 +765,22 @@ async fn join_channel(
sender_id: msg.sender_id.to_proto(), sender_id: msg.sender_id.to_proto(),
}) })
.collect(); .collect();
rpc.respond(request.receipt(), proto::JoinChannelResponse { messages }) server
.rpc
.respond(request.receipt(), proto::JoinChannelResponse { messages })
.await?; .await?;
Ok(()) Ok(())
} }
async fn send_channel_message( async fn send_channel_message(
request: TypedEnvelope<proto::SendChannelMessage>, request: Box<TypedEnvelope<proto::SendChannelMessage>>,
peer: &Arc<Peer>, server: Arc<Server>,
app: &Arc<AppState>,
) -> tide::Result<()> { ) -> tide::Result<()> {
let channel_id = ChannelId::from_proto(request.payload.channel_id); let channel_id = ChannelId::from_proto(request.payload.channel_id);
let user_id; let user_id;
let connection_ids; let connection_ids;
{ {
let state = app.rpc.read().await; let state = server.state.rpc.read().await;
user_id = state.user_id_for_connection(request.sender_id)?; user_id = state.user_id_for_connection(request.sender_id)?;
if let Some(channel) = state.channels.get(&channel_id) { if let Some(channel) = state.channels.get(&channel_id) {
connection_ids = channel.connection_ids(); connection_ids = channel.connection_ids();
@ -770,7 +790,8 @@ async fn send_channel_message(
} }
let timestamp = OffsetDateTime::now_utc(); let timestamp = OffsetDateTime::now_utc();
let message_id = app let message_id = server
.state
.db .db
.create_channel_message(channel_id, user_id, &request.payload.body, timestamp) .create_channel_message(channel_id, user_id, &request.payload.body, timestamp)
.await?; .await?;
@ -784,7 +805,7 @@ async fn send_channel_message(
}), }),
}; };
broadcast(request.sender_id, connection_ids, |conn_id| { broadcast(request.sender_id, connection_ids, |conn_id| {
peer.send(conn_id, message.clone()) server.rpc.send(conn_id, message.clone())
}) })
.await?; .await?;
@ -793,11 +814,11 @@ async fn send_channel_message(
async fn broadcast_in_worktree<T: proto::EnvelopedMessage>( async fn broadcast_in_worktree<T: proto::EnvelopedMessage>(
worktree_id: u64, worktree_id: u64,
request: TypedEnvelope<T>, request: &TypedEnvelope<T>,
rpc: &Arc<Peer>, server: &Arc<Server>,
state: &Arc<AppState>,
) -> tide::Result<()> { ) -> tide::Result<()> {
let connection_ids = state let connection_ids = server
.state
.rpc .rpc
.read() .read()
.await .await
@ -805,7 +826,9 @@ async fn broadcast_in_worktree<T: proto::EnvelopedMessage>(
.connection_ids(); .connection_ids();
broadcast(request.sender_id, connection_ids, |conn_id| { broadcast(request.sender_id, connection_ids, |conn_id| {
rpc.forward_send(request.sender_id, conn_id, request.payload.clone()) server
.rpc
.forward_send(request.sender_id, conn_id, request.payload.clone())
}) })
.await?; .await?;