diff --git a/crates/client/Cargo.toml b/crates/client/Cargo.toml index ac7bbcff87..bfa9731739 100644 --- a/crates/client/Cargo.toml +++ b/crates/client/Cargo.toml @@ -7,7 +7,7 @@ edition = "2018" path = "src/client.rs" [features] -test-support = ["rpc/test-support"] +test-support = ["gpui/test-support", "rpc/test-support"] [dependencies] gpui = { path = "../gpui" } @@ -29,3 +29,7 @@ surf = "2.2" thiserror = "1.0.29" time = "0.3" tiny_http = "0.8" + +[dev-dependencies] +gpui = { path = "../gpui", features = ["test-support"] } +rpc = { path = "../rpc", features = ["test-support"] } diff --git a/crates/client/src/client.rs b/crates/client/src/client.rs index a8356bcea0..47f9aeb8e2 100644 --- a/crates/client/src/client.rs +++ b/crates/client/src/client.rs @@ -125,7 +125,7 @@ struct ClientState { entity_id_extractors: HashMap u64>>, model_handlers: HashMap< (TypeId, u64), - Box, &mut AsyncAppContext)>, + Option, &mut AsyncAppContext)>>, >, _maintain_connection: Option>, heartbeat_interval: Duration, @@ -158,14 +158,9 @@ pub struct Subscription { impl Drop for Subscription { fn drop(&mut self) { if let Some(client) = self.client.upgrade() { - drop( - client - .state - .write() - .model_handlers - .remove(&self.id) - .unwrap(), - ); + let mut state = client.state.write(); + let _ = state.entity_id_extractors.remove(&self.id.0).unwrap(); + let _ = state.model_handlers.remove(&self.id).unwrap(); } } } @@ -285,7 +280,7 @@ impl Client { state.model_handlers.insert( subscription_id, - Box::new(move |envelope, cx| { + Some(Box::new(move |envelope, cx| { if let Some(model) = model.upgrade(cx) { let envelope = envelope.into_any().downcast::>().unwrap(); model.update(cx, |model, cx| { @@ -294,7 +289,7 @@ impl Client { } }); } - }), + })), ); Subscription { @@ -335,7 +330,7 @@ impl Client { }); let prev_handler = state.model_handlers.insert( subscription_id, - Box::new(move |envelope, cx| { + Some(Box::new(move |envelope, cx| { if let Some(model) = model.upgrade(cx) { let envelope = envelope.into_any().downcast::>().unwrap(); model.update(cx, |model, cx| { @@ -344,7 +339,7 @@ impl Client { } }); } - }), + })), ); if prev_handler.is_some() { panic!("registered a handler for the same entity twice") @@ -450,7 +445,8 @@ impl Client { let payload_type_id = message.payload_type_id(); let entity_id = (extract_entity_id)(message.as_ref()); let handler_key = (payload_type_id, entity_id); - if let Some(mut handler) = state.model_handlers.remove(&handler_key) { + if let Some(handler) = state.model_handlers.get_mut(&handler_key) { + let mut handler = handler.take().unwrap(); drop(state); // Avoid deadlocks if the handler interacts with rpc::Client let start_time = Instant::now(); log::info!("RPC client message {}", message.payload_type_name()); @@ -459,10 +455,11 @@ impl Client { "RPC message handled. duration:{:?}", start_time.elapsed() ); - this.state - .write() - .model_handlers - .insert(handler_key, handler); + + let mut state = this.state.write(); + if state.model_handlers.contains_key(&handler_key) { + state.model_handlers.insert(handler_key, Some(handler)); + } } else { log::info!("unhandled message {}", message.payload_type_name()); } @@ -813,4 +810,64 @@ mod tests { ); assert_eq!(decode_worktree_url("not://the-right-format"), None); } + + #[gpui::test] + async fn test_subscribing_after_dropping_subscription(mut cx: TestAppContext) { + cx.foreground().forbid_parking(); + + let user_id = 5; + let mut client = Client::new(FakeHttpClient::with_404_response()); + let server = FakeServer::for_client(user_id, &mut client, &cx).await; + + let model = cx.add_model(|_| Model { subscription: None }); + let (mut done_tx1, _done_rx1) = postage::oneshot::channel(); + let (mut done_tx2, mut done_rx2) = postage::oneshot::channel(); + let subscription1 = model.update(&mut cx, |_, cx| { + client.subscribe(cx, move |_, _: TypedEnvelope, _, _| { + postage::sink::Sink::try_send(&mut done_tx1, ()).unwrap(); + Ok(()) + }) + }); + drop(subscription1); + let _subscription2 = model.update(&mut cx, |_, cx| { + client.subscribe(cx, move |_, _: TypedEnvelope, _, _| { + postage::sink::Sink::try_send(&mut done_tx2, ()).unwrap(); + Ok(()) + }) + }); + server.send(proto::Ping {}).await; + done_rx2.recv().await.unwrap(); + } + + #[gpui::test] + async fn test_dropping_subscription_in_handler(mut cx: TestAppContext) { + cx.foreground().forbid_parking(); + + let user_id = 5; + let mut client = Client::new(FakeHttpClient::with_404_response()); + let server = FakeServer::for_client(user_id, &mut client, &cx).await; + + let model = cx.add_model(|_| Model { subscription: None }); + let (mut done_tx, mut done_rx) = postage::oneshot::channel(); + model.update(&mut cx, |model, cx| { + model.subscription = Some(client.subscribe( + cx, + move |model, _: TypedEnvelope, _, _| { + model.subscription.take(); + postage::sink::Sink::try_send(&mut done_tx, ()).unwrap(); + Ok(()) + }, + )); + }); + server.send(proto::Ping {}).await; + done_rx.recv().await.unwrap(); + } + + struct Model { + subscription: Option, + } + + impl Entity for Model { + type Event = (); + } }