diff --git a/server/src/rpc.rs b/server/src/rpc.rs index e2cf625857..d6bb42d256 100644 --- a/server/src/rpc.rs +++ b/server/src/rpc.rs @@ -1289,7 +1289,7 @@ mod tests { github, AppState, Config, }; use async_std::{sync::RwLockReadGuard, task}; - use gpui::TestAppContext; + use gpui::{ModelHandle, TestAppContext}; use parking_lot::Mutex; use postage::{mpsc, watch}; use serde_json::json; @@ -1780,24 +1780,24 @@ mod tests { // Create an org that includes these 2 users. let db = &server.app_state.db; let org_id = db.create_org("Test Org", "test-org").await.unwrap(); - db.add_org_member(org_id, current_user_id(&user_store_a), false) + db.add_org_member(org_id, current_user_id(&user_store_a, &cx_a), false) .await .unwrap(); - db.add_org_member(org_id, current_user_id(&user_store_b), false) + db.add_org_member(org_id, current_user_id(&user_store_b, &cx_b), false) .await .unwrap(); // Create a channel that includes all the users. let channel_id = db.create_org_channel(org_id, "test-channel").await.unwrap(); - db.add_channel_member(channel_id, current_user_id(&user_store_a), false) + db.add_channel_member(channel_id, current_user_id(&user_store_a, &cx_a), false) .await .unwrap(); - db.add_channel_member(channel_id, current_user_id(&user_store_b), false) + db.add_channel_member(channel_id, current_user_id(&user_store_b, &cx_b), false) .await .unwrap(); db.create_channel_message( channel_id, - current_user_id(&user_store_b), + current_user_id(&user_store_b, &cx_b), "hello A, it's B.", OffsetDateTime::now_utc(), 1, @@ -1912,10 +1912,10 @@ mod tests { let db = &server.app_state.db; let org_id = db.create_org("Test Org", "test-org").await.unwrap(); let channel_id = db.create_org_channel(org_id, "test-channel").await.unwrap(); - db.add_org_member(org_id, current_user_id(&user_store_a), false) + db.add_org_member(org_id, current_user_id(&user_store_a, &cx_a), false) .await .unwrap(); - db.add_channel_member(channel_id, current_user_id(&user_store_a), false) + db.add_channel_member(channel_id, current_user_id(&user_store_a, &cx_a), false) .await .unwrap(); @@ -1964,7 +1964,6 @@ mod tests { #[gpui::test] async fn test_chat_reconnection(mut cx_a: TestAppContext, mut cx_b: TestAppContext) { cx_a.foreground().forbid_parking(); - let http = FakeHttpClient::new(|_| async move { Ok(surf::http::Response::new(404)) }); // Connect to a server as 2 clients. let mut server = TestServer::start().await; @@ -1975,24 +1974,24 @@ mod tests { // Create an org that includes these 2 users. let db = &server.app_state.db; let org_id = db.create_org("Test Org", "test-org").await.unwrap(); - db.add_org_member(org_id, current_user_id(&user_store_a), false) + db.add_org_member(org_id, current_user_id(&user_store_a, &cx_a), false) .await .unwrap(); - db.add_org_member(org_id, current_user_id(&user_store_b), false) + db.add_org_member(org_id, current_user_id(&user_store_b, &cx_b), false) .await .unwrap(); // Create a channel that includes all the users. let channel_id = db.create_org_channel(org_id, "test-channel").await.unwrap(); - db.add_channel_member(channel_id, current_user_id(&user_store_a), false) + db.add_channel_member(channel_id, current_user_id(&user_store_a, &cx_a), false) .await .unwrap(); - db.add_channel_member(channel_id, current_user_id(&user_store_b), false) + db.add_channel_member(channel_id, current_user_id(&user_store_b, &cx_b), false) .await .unwrap(); db.create_channel_message( channel_id, - current_user_id(&user_store_b), + current_user_id(&user_store_b, &cx_b), "hello A, it's B.", OffsetDateTime::now_utc(), 2, @@ -2000,8 +1999,6 @@ mod tests { .await .unwrap(); - let user_store_a = - UserStore::new(client_a.clone(), http.clone(), cx_a.background().as_ref()); let channels_a = cx_a.add_model(|cx| ChannelList::new(user_store_a, client_a, cx)); channels_a .condition(&mut cx_a, |list, _| list.available_channels().is_some()) @@ -2054,7 +2051,7 @@ mod tests { // Disconnect client B, ensuring we can still access its cached channel data. server.forbid_connections(); - server.disconnect_client(current_user_id(&user_store_b)); + server.disconnect_client(current_user_id(&user_store_b, &cx_b)); while !matches!( status_b.recv().await, Some(rpc::Status::ReconnectionError { .. }) @@ -2206,7 +2203,7 @@ mod tests { &mut self, cx: &mut TestAppContext, name: &str, - ) -> (Arc, Arc) { + ) -> (Arc, ModelHandle) { let user_id = self.app_state.db.create_user(name, false).await.unwrap(); let client_name = name.to_string(); let mut client = Client::new(); @@ -2254,8 +2251,9 @@ mod tests { .await .unwrap(); - let user_store = UserStore::new(client.clone(), http, &cx.background()); - let mut authed_user = user_store.watch_current_user(); + let user_store = cx.add_model(|cx| UserStore::new(client.clone(), http, cx)); + let mut authed_user = + user_store.read_with(cx, |user_store, _| user_store.watch_current_user()); while authed_user.recv().await.unwrap().is_none() {} (client, user_store) @@ -2314,8 +2312,10 @@ mod tests { } } - fn current_user_id(user_store: &Arc) -> UserId { - UserId::from_proto(user_store.current_user().unwrap().id) + fn current_user_id(user_store: &ModelHandle, cx: &TestAppContext) -> UserId { + UserId::from_proto( + user_store.read_with(cx, |user_store, _| user_store.current_user().unwrap().id), + ) } fn channel_messages(channel: &Channel) -> Vec<(String, String, bool)> { diff --git a/zed/src/channel.rs b/zed/src/channel.rs index 2eb904915c..f042a2e508 100644 --- a/zed/src/channel.rs +++ b/zed/src/channel.rs @@ -6,7 +6,7 @@ use crate::{ use anyhow::{anyhow, Context, Result}; use gpui::{ sum_tree::{self, Bias, SumTree}, - Entity, ModelContext, ModelHandle, MutableAppContext, Task, WeakModelHandle, + AsyncAppContext, Entity, ModelContext, ModelHandle, MutableAppContext, Task, WeakModelHandle, }; use postage::prelude::Stream; use rand::prelude::*; @@ -26,7 +26,7 @@ pub struct ChannelList { available_channels: Option>, channels: HashMap>, rpc: Arc, - user_store: Arc, + user_store: ModelHandle, _task: Task>, } @@ -41,7 +41,7 @@ pub struct Channel { messages: SumTree, loaded_all_messages: bool, next_pending_message_id: usize, - user_store: Arc, + user_store: ModelHandle, rpc: Arc, rng: StdRng, _subscription: rpc::Subscription, @@ -87,7 +87,7 @@ impl Entity for ChannelList { impl ChannelList { pub fn new( - user_store: Arc, + user_store: ModelHandle, rpc: Arc, cx: &mut ModelContext, ) -> Self { @@ -186,7 +186,7 @@ impl Entity for Channel { impl Channel { pub fn new( details: ChannelDetails, - user_store: Arc, + user_store: ModelHandle, rpc: Arc, cx: &mut ModelContext, ) -> Self { @@ -199,7 +199,8 @@ impl Channel { cx.spawn(|channel, mut cx| { async move { let response = rpc.request(proto::JoinChannel { channel_id }).await?; - let messages = messages_from_proto(response.messages, &user_store).await?; + let messages = + messages_from_proto(response.messages, &user_store, &mut cx).await?; let loaded_all_messages = response.done; channel.update(&mut cx, |channel, cx| { @@ -241,6 +242,7 @@ impl Channel { let current_user = self .user_store + .read(cx) .current_user() .ok_or_else(|| anyhow!("current_user is not present"))?; @@ -272,6 +274,7 @@ impl Channel { let message = ChannelMessage::from_proto( response.message.ok_or_else(|| anyhow!("invalid message"))?, &user_store, + &mut cx, ) .await?; this.update(&mut cx, |this, cx| { @@ -301,7 +304,8 @@ impl Channel { }) .await?; let loaded_all_messages = response.done; - let messages = messages_from_proto(response.messages, &user_store).await?; + let messages = + messages_from_proto(response.messages, &user_store, &mut cx).await?; this.update(&mut cx, |this, cx| { this.loaded_all_messages = loaded_all_messages; this.insert_messages(messages, cx); @@ -324,7 +328,7 @@ impl Channel { cx.spawn(|this, mut cx| { async move { let response = rpc.request(proto::JoinChannel { channel_id }).await?; - let messages = messages_from_proto(response.messages, &user_store).await?; + let messages = messages_from_proto(response.messages, &user_store, &mut cx).await?; let loaded_all_messages = response.done; let pending_messages = this.update(&mut cx, |this, cx| { @@ -359,6 +363,7 @@ impl Channel { let message = ChannelMessage::from_proto( response.message.ok_or_else(|| anyhow!("invalid message"))?, &user_store, + &mut cx, ) .await?; this.update(&mut cx, |this, cx| { @@ -413,7 +418,7 @@ impl Channel { cx.spawn(|this, mut cx| { async move { - let message = ChannelMessage::from_proto(message, &user_store).await?; + let message = ChannelMessage::from_proto(message, &user_store, &mut cx).await?; this.update(&mut cx, |this, cx| { this.insert_messages(SumTree::from_item(message, &()), cx) }); @@ -486,7 +491,8 @@ impl Channel { async fn messages_from_proto( proto_messages: Vec, - user_store: &UserStore, + user_store: &ModelHandle, + cx: &mut AsyncAppContext, ) -> Result> { let unique_user_ids = proto_messages .iter() @@ -494,11 +500,15 @@ async fn messages_from_proto( .collect::>() .into_iter() .collect(); - user_store.load_users(unique_user_ids).await?; + user_store + .update(cx, |user_store, cx| { + user_store.load_users(unique_user_ids, cx) + }) + .await?; let mut messages = Vec::with_capacity(proto_messages.len()); for message in proto_messages { - messages.push(ChannelMessage::from_proto(message, &user_store).await?); + messages.push(ChannelMessage::from_proto(message, user_store, cx).await?); } let mut result = SumTree::new(); result.extend(messages, &()); @@ -517,9 +527,14 @@ impl From for ChannelDetails { impl ChannelMessage { pub async fn from_proto( message: proto::ChannelMessage, - user_store: &UserStore, + user_store: &ModelHandle, + cx: &mut AsyncAppContext, ) -> Result { - let sender = user_store.fetch_user(message.sender_id).await?; + let sender = user_store + .update(cx, |user_store, cx| { + user_store.fetch_user(message.sender_id, cx) + }) + .await?; Ok(ChannelMessage { id: ChannelMessageId::Saved(message.id), body: message.body, @@ -595,7 +610,7 @@ mod tests { let mut client = Client::new(); let http_client = FakeHttpClient::new(|_| async move { Ok(Response::new(404)) }); let server = FakeServer::for_client(user_id, &mut client, &cx).await; - let user_store = UserStore::new(client.clone(), http_client, cx.background().as_ref()); + let user_store = cx.add_model(|cx| UserStore::new(client.clone(), http_client, cx)); let channel_list = cx.add_model(|cx| ChannelList::new(user_store, client.clone(), cx)); channel_list.read_with(&cx, |list, _| assert_eq!(list.available_channels(), None)); diff --git a/zed/src/lib.rs b/zed/src/lib.rs index e07d04a5a4..397eed486b 100644 --- a/zed/src/lib.rs +++ b/zed/src/lib.rs @@ -28,7 +28,6 @@ use channel::ChannelList; use gpui::{action, keymap::Binding, ModelHandle}; use parking_lot::Mutex; use postage::watch; -use presence::Presence; use std::sync::Arc; pub use settings::Settings; @@ -46,10 +45,9 @@ pub struct AppState { pub languages: Arc, pub themes: Arc, pub rpc: Arc, - pub user_store: Arc, + pub user_store: ModelHandle, pub fs: Arc, pub channel_list: ModelHandle, - pub presence: ModelHandle, } pub fn init(app_state: &Arc, cx: &mut gpui::MutableAppContext) { diff --git a/zed/src/people_panel.rs b/zed/src/people_panel.rs index 7e6d675489..246a5e0491 100644 --- a/zed/src/people_panel.rs +++ b/zed/src/people_panel.rs @@ -1,16 +1,17 @@ -use crate::presence::Presence; use gpui::{ elements::Empty, Element, ElementBox, Entity, ModelHandle, RenderContext, View, ViewContext, }; +use crate::user::UserStore; + pub struct PeoplePanel { - presence: ModelHandle, + user_store: ModelHandle, } impl PeoplePanel { - pub fn new(presence: ModelHandle, cx: &mut ViewContext) -> Self { - cx.observe(&presence, |_, _, cx| cx.notify()); - Self { presence } + pub fn new(user_store: ModelHandle, cx: &mut ViewContext) -> Self { + cx.observe(&user_store, |_, _, cx| cx.notify()); + Self { user_store } } } diff --git a/zed/src/presence.rs b/zed/src/presence.rs index 356baa22b4..2cdde21e2f 100644 --- a/zed/src/presence.rs +++ b/zed/src/presence.rs @@ -106,24 +106,25 @@ impl Entity for Presence { type Event = Event; } -impl Collaborator { - async fn from_proto( - collaborator: proto::Collaborator, - user_store: &Arc, - ) -> Result { - let user = user_store.fetch_user(collaborator.user_id).await?; - let mut worktrees = Vec::new(); - for worktree in collaborator.worktrees { - let mut participants = Vec::new(); - for participant_id in worktree.participants { - participants.push(user_store.fetch_user(participant_id).await?); - } - worktrees.push(WorktreeMetadata { - root_name: worktree.root_name, - is_shared: worktree.is_shared, - participants, - }); - } - Ok(Self { user, worktrees }) - } -} +// impl Collaborator { +// async fn from_proto( +// collaborator: proto::Collaborator, +// user_store: &Arc, +// cx: &mut AsyncAppContext, +// ) -> Result { +// let user = user_store.fetch_user(collaborator.user_id).await?; +// let mut worktrees = Vec::new(); +// for worktree in collaborator.worktrees { +// let mut participants = Vec::new(); +// for participant_id in worktree.participants { +// participants.push(user_store.fetch_user(participant_id).await?); +// } +// worktrees.push(WorktreeMetadata { +// root_name: worktree.root_name, +// is_shared: worktree.is_shared, +// participants, +// }); +// } +// Ok(Self { user, worktrees }) +// } +// } diff --git a/zed/src/test.rs b/zed/src/test.rs index 77b284b245..9e02bb74db 100644 --- a/zed/src/test.rs +++ b/zed/src/test.rs @@ -169,14 +169,13 @@ pub fn test_app_state(cx: &mut MutableAppContext) -> Arc { let themes = ThemeRegistry::new(Assets, cx.font_cache().clone()); let rpc = rpc::Client::new(); let http = FakeHttpClient::new(|_| async move { Ok(ServerResponse::new(404)) }); - let user_store = UserStore::new(rpc.clone(), http, cx.background()); + let user_store = cx.add_model(|cx| UserStore::new(rpc.clone(), http, cx)); Arc::new(AppState { settings_tx: Arc::new(Mutex::new(settings_tx)), settings, themes, languages: languages.clone(), channel_list: cx.add_model(|cx| ChannelList::new(user_store.clone(), rpc.clone(), cx)), - presence: cx.add_model(|cx| Presence::new(user_store.clone(), rpc.clone(), cx)), rpc, user_store, fs: Arc::new(RealFs), diff --git a/zed/src/user.rs b/zed/src/user.rs index 54e84d756f..8a050d5164 100644 --- a/zed/src/user.rs +++ b/zed/src/user.rs @@ -5,13 +5,9 @@ use crate::{ }; use anyhow::{anyhow, Context, Result}; use futures::future; -use gpui::{executor, ImageData, Task}; -use parking_lot::Mutex; -use postage::{oneshot, prelude::Stream, sink::Sink, watch}; -use std::{ - collections::HashMap, - sync::{Arc, Weak}, -}; +use gpui::{Entity, ImageData, ModelContext, Task}; +use postage::{prelude::Stream, sink::Sink, watch}; +use std::{collections::HashMap, sync::Arc}; use zrpc::proto; #[derive(Debug)] @@ -22,41 +18,38 @@ pub struct User { } pub struct UserStore { - users: Mutex>>, + users: HashMap>, current_user: watch::Receiver>>, rpc: Arc, http: Arc, _maintain_current_user: Task<()>, } +pub enum Event {} + +impl Entity for UserStore { + type Event = Event; +} + impl UserStore { - pub fn new( - rpc: Arc, - http: Arc, - executor: &executor::Background, - ) -> Arc { + pub fn new(rpc: Arc, http: Arc, cx: &mut ModelContext) -> Self { let (mut current_user_tx, current_user_rx) = watch::channel(); - let (mut this_tx, mut this_rx) = oneshot::channel::>(); - let this = Arc::new(Self { + Self { users: Default::default(), current_user: current_user_rx, rpc: rpc.clone(), http, - _maintain_current_user: executor.spawn(async move { - let this = if let Some(this) = this_rx.recv().await { - this - } else { - return; - }; + _maintain_current_user: cx.spawn_weak(|this, mut cx| async move { let mut status = rpc.status(); while let Some(status) = status.recv().await { match status { Status::Connected { .. } => { - if let Some((this, user_id)) = this.upgrade().zip(rpc.user_id()) { - current_user_tx - .send(this.fetch_user(user_id).log_err().await) - .await - .ok(); + if let Some((this, user_id)) = this.upgrade(&cx).zip(rpc.user_id()) { + let user = this + .update(&mut cx, |this, cx| this.fetch_user(user_id, cx)) + .log_err() + .await; + current_user_tx.send(user).await.ok(); } } Status::SignedOut => { @@ -66,49 +59,60 @@ impl UserStore { } } }), - }); - let weak = Arc::downgrade(&this); - executor - .spawn(async move { this_tx.send(weak).await }) - .detach(); - this + } } - pub async fn load_users(&self, mut user_ids: Vec) -> Result<()> { - { - let users = self.users.lock(); - user_ids.retain(|id| !users.contains_key(id)); - } + pub fn load_users( + &mut self, + mut user_ids: Vec, + cx: &mut ModelContext, + ) -> Task> { + let rpc = self.rpc.clone(); + let http = self.http.clone(); + user_ids.retain(|id| !self.users.contains_key(id)); + cx.spawn_weak(|this, mut cx| async move { + if !user_ids.is_empty() { + let response = rpc.request(proto::GetUsers { user_ids }).await?; + let new_users = future::join_all( + response + .users + .into_iter() + .map(|user| User::new(user, http.as_ref())), + ) + .await; - if !user_ids.is_empty() { - let response = self.rpc.request(proto::GetUsers { user_ids }).await?; - let new_users = future::join_all( - response - .users - .into_iter() - .map(|user| User::new(user, self.http.as_ref())), - ) - .await; - let mut users = self.users.lock(); - for user in new_users { - users.insert(user.id, Arc::new(user)); + if let Some(this) = this.upgrade(&cx) { + this.update(&mut cx, |this, _| { + for user in new_users { + this.users.insert(user.id, Arc::new(user)); + } + }); + } } - } - Ok(()) + Ok(()) + }) } - pub async fn fetch_user(&self, user_id: u64) -> Result> { - if let Some(user) = self.users.lock().get(&user_id).cloned() { - return Ok(user); + pub fn fetch_user( + &mut self, + user_id: u64, + cx: &mut ModelContext, + ) -> Task>> { + if let Some(user) = self.users.get(&user_id).cloned() { + return cx.spawn_weak(|_, _| async move { Ok(user) }); } - self.load_users(vec![user_id]).await?; - self.users - .lock() - .get(&user_id) - .cloned() - .ok_or_else(|| anyhow!("server responded with no users")) + let load_users = self.load_users(vec![user_id], cx); + cx.spawn(|this, mut cx| async move { + load_users.await?; + this.update(&mut cx, |this, _| { + this.users + .get(&user_id) + .cloned() + .ok_or_else(|| anyhow!("server responded with no users")) + }) + }) } pub fn current_user(&self) -> Option> { diff --git a/zed/src/workspace.rs b/zed/src/workspace.rs index bcdbe17a54..3182fc3aab 100644 --- a/zed/src/workspace.rs +++ b/zed/src/workspace.rs @@ -333,7 +333,7 @@ pub struct Workspace { pub settings: watch::Receiver, languages: Arc, rpc: Arc, - user_store: Arc, + user_store: ModelHandle, fs: Arc, modal: Option, center: PaneGroup, @@ -381,11 +381,11 @@ impl Workspace { ); right_sidebar.add_item( "icons/user-16.svg", - cx.add_view(|cx| PeoplePanel::new(app_state.presence.clone(), cx)) + cx.add_view(|cx| PeoplePanel::new(app_state.user_store.clone(), cx)) .into(), ); - let mut current_user = app_state.user_store.watch_current_user().clone(); + let mut current_user = app_state.user_store.read(cx).watch_current_user().clone(); let mut connection_status = app_state.rpc.status().clone(); let _observe_current_user = cx.spawn_weak(|this, mut cx| async move { current_user.recv().await; @@ -965,6 +965,7 @@ impl Workspace { let theme = &self.settings.borrow().theme; let avatar = if let Some(avatar) = self .user_store + .read(cx) .current_user() .and_then(|user| user.avatar.clone()) {