#[cfg(any(test, feature = "test-support"))] pub mod test; pub mod channel; pub mod http; pub mod user; use anyhow::{anyhow, Context, Result}; use async_recursion::async_recursion; use async_tungstenite::tungstenite::{ error::Error as WebsocketError, http::{Request, StatusCode}, }; use futures::{future::LocalBoxFuture, FutureExt, SinkExt, StreamExt, TryStreamExt}; use gpui::{ actions, AnyModelHandle, AnyViewHandle, AnyWeakModelHandle, AnyWeakViewHandle, AsyncAppContext, Entity, ModelContext, ModelHandle, MutableAppContext, Task, View, ViewContext, ViewHandle, }; use http::HttpClient; use lazy_static::lazy_static; use parking_lot::RwLock; use postage::watch; use rand::prelude::*; use rpc::proto::{AnyTypedEnvelope, EntityMessage, EnvelopedMessage, RequestMessage}; use std::{ any::TypeId, collections::HashMap, convert::TryFrom, fmt::Write as _, future::Future, sync::{ atomic::{AtomicUsize, Ordering}, Arc, Weak, }, time::{Duration, Instant}, }; use thiserror::Error; use url::Url; use util::{ResultExt, TryFutureExt}; pub use channel::*; pub use rpc::*; pub use user::*; lazy_static! { pub static ref ZED_SERVER_URL: String = std::env::var("ZED_SERVER_URL").unwrap_or("https://zed.dev".to_string()); pub static ref IMPERSONATE_LOGIN: Option = std::env::var("ZED_IMPERSONATE") .ok() .and_then(|s| if s.is_empty() { None } else { Some(s) }); } pub const ZED_SECRET_CLIENT_TOKEN: &'static str = "618033988749894"; actions!(client, [Authenticate]); pub fn init(rpc: Arc, cx: &mut MutableAppContext) { cx.add_global_action(move |_: &Authenticate, cx| { let rpc = rpc.clone(); cx.spawn(|cx| async move { rpc.authenticate_and_connect(true, &cx).log_err().await }) .detach(); }); } pub struct Client { id: usize, peer: Arc, http: Arc, state: RwLock, authenticate: Option Task>>>, establish_connection: Option< Box< dyn 'static + Send + Sync + Fn( &Credentials, &AsyncAppContext, ) -> Task>, >, >, } #[derive(Error, Debug)] pub enum EstablishConnectionError { #[error("upgrade required")] UpgradeRequired, #[error("unauthorized")] Unauthorized, #[error("{0}")] Other(#[from] anyhow::Error), #[error("{0}")] Http(#[from] http::Error), #[error("{0}")] Io(#[from] std::io::Error), #[error("{0}")] Websocket(#[from] async_tungstenite::tungstenite::http::Error), } impl From for EstablishConnectionError { fn from(error: WebsocketError) -> Self { if let WebsocketError::Http(response) = &error { match response.status() { StatusCode::UNAUTHORIZED => return EstablishConnectionError::Unauthorized, StatusCode::UPGRADE_REQUIRED => return EstablishConnectionError::UpgradeRequired, _ => {} } } EstablishConnectionError::Other(error.into()) } } impl EstablishConnectionError { pub fn other(error: impl Into + Send + Sync) -> Self { Self::Other(error.into()) } } #[derive(Copy, Clone, Debug, Eq, PartialEq)] pub enum Status { SignedOut, UpgradeRequired, Authenticating, Connecting, ConnectionError, Connected { connection_id: ConnectionId }, ConnectionLost, Reauthenticating, Reconnecting, ReconnectionError { next_reconnection: Instant }, } impl Status { pub fn is_connected(&self) -> bool { matches!(self, Self::Connected { .. }) } } struct ClientState { credentials: Option, status: (watch::Sender, watch::Receiver), entity_id_extractors: HashMap u64>, _reconnect_task: Option>, reconnect_interval: Duration, entities_by_type_and_remote_id: HashMap<(TypeId, u64), AnyWeakEntityHandle>, models_by_message_type: HashMap, entity_types_by_message_type: HashMap, message_handlers: HashMap< TypeId, Arc< dyn Send + Sync + Fn( AnyEntityHandle, Box, &Arc, AsyncAppContext, ) -> LocalBoxFuture<'static, Result<()>>, >, >, } enum AnyWeakEntityHandle { Model(AnyWeakModelHandle), View(AnyWeakViewHandle), } enum AnyEntityHandle { Model(AnyModelHandle), View(AnyViewHandle), } #[derive(Clone, Debug)] pub struct Credentials { pub user_id: u64, pub access_token: String, } impl Default for ClientState { fn default() -> Self { Self { credentials: None, status: watch::channel_with(Status::SignedOut), entity_id_extractors: Default::default(), _reconnect_task: None, reconnect_interval: Duration::from_secs(5), models_by_message_type: Default::default(), entities_by_type_and_remote_id: Default::default(), entity_types_by_message_type: Default::default(), message_handlers: Default::default(), } } } pub enum Subscription { Entity { client: Weak, id: (TypeId, u64), }, Message { client: Weak, id: TypeId, }, } impl Drop for Subscription { fn drop(&mut self) { match self { Subscription::Entity { client, id } => { if let Some(client) = client.upgrade() { let mut state = client.state.write(); let _ = state.entities_by_type_and_remote_id.remove(id); } } Subscription::Message { client, id } => { if let Some(client) = client.upgrade() { let mut state = client.state.write(); let _ = state.entity_types_by_message_type.remove(id); let _ = state.message_handlers.remove(id); } } } } } impl Client { pub fn new(http: Arc) -> Arc { lazy_static! { static ref NEXT_CLIENT_ID: AtomicUsize = AtomicUsize::default(); } Arc::new(Self { id: NEXT_CLIENT_ID.fetch_add(1, Ordering::SeqCst), peer: Peer::new(), http, state: Default::default(), authenticate: None, establish_connection: None, }) } pub fn id(&self) -> usize { self.id } pub fn http_client(&self) -> Arc { self.http.clone() } #[cfg(any(test, feature = "test-support"))] pub fn tear_down(&self) { let mut state = self.state.write(); state._reconnect_task.take(); state.message_handlers.clear(); state.models_by_message_type.clear(); state.entities_by_type_and_remote_id.clear(); state.entity_id_extractors.clear(); self.peer.reset(); } #[cfg(any(test, feature = "test-support"))] pub fn override_authenticate(&mut self, authenticate: F) -> &mut Self where F: 'static + Send + Sync + Fn(&AsyncAppContext) -> Task>, { self.authenticate = Some(Box::new(authenticate)); self } #[cfg(any(test, feature = "test-support"))] pub fn override_establish_connection(&mut self, connect: F) -> &mut Self where F: 'static + Send + Sync + Fn(&Credentials, &AsyncAppContext) -> Task>, { self.establish_connection = Some(Box::new(connect)); self } pub fn user_id(&self) -> Option { self.state .read() .credentials .as_ref() .map(|credentials| credentials.user_id) } pub fn status(&self) -> watch::Receiver { self.state.read().status.1.clone() } fn set_status(self: &Arc, status: Status, cx: &AsyncAppContext) { log::info!("set status on client {}: {:?}", self.id, status); let mut state = self.state.write(); *state.status.0.borrow_mut() = status; match status { Status::Connected { .. } => { state._reconnect_task = None; } Status::ConnectionLost => { let this = self.clone(); let reconnect_interval = state.reconnect_interval; state._reconnect_task = Some(cx.spawn(|cx| async move { let mut rng = StdRng::from_entropy(); let mut delay = Duration::from_millis(100); while let Err(error) = this.authenticate_and_connect(true, &cx).await { log::error!("failed to connect {}", error); this.set_status( Status::ReconnectionError { next_reconnection: Instant::now() + delay, }, &cx, ); cx.background().timer(delay).await; delay = delay .mul_f32(rng.gen_range(1.0..=2.0)) .min(reconnect_interval); } })); } Status::SignedOut | Status::UpgradeRequired => { state._reconnect_task.take(); } _ => {} } } pub fn add_view_for_remote_entity( self: &Arc, remote_id: u64, cx: &mut ViewContext, ) -> Subscription { let id = (TypeId::of::(), remote_id); self.state .write() .entities_by_type_and_remote_id .insert(id, AnyWeakEntityHandle::View(cx.weak_handle().into())); Subscription::Entity { client: Arc::downgrade(self), id, } } pub fn add_model_for_remote_entity( self: &Arc, remote_id: u64, cx: &mut ModelContext, ) -> Subscription { let id = (TypeId::of::(), remote_id); self.state .write() .entities_by_type_and_remote_id .insert(id, AnyWeakEntityHandle::Model(cx.weak_handle().into())); Subscription::Entity { client: Arc::downgrade(self), id, } } pub fn add_message_handler( self: &Arc, model: ModelHandle, handler: H, ) -> Subscription where M: EnvelopedMessage, E: Entity, H: 'static + Send + Sync + Fn(ModelHandle, TypedEnvelope, Arc, AsyncAppContext) -> F, F: 'static + Future>, { let message_type_id = TypeId::of::(); let mut state = self.state.write(); state .models_by_message_type .insert(message_type_id, model.downgrade().into()); let prev_handler = state.message_handlers.insert( message_type_id, Arc::new(move |handle, envelope, client, cx| { let handle = if let AnyEntityHandle::Model(handle) = handle { handle } else { unreachable!(); }; let model = handle.downcast::().unwrap(); let envelope = envelope.into_any().downcast::>().unwrap(); handler(model, *envelope, client.clone(), cx).boxed_local() }), ); if prev_handler.is_some() { panic!("registered handler for the same message twice"); } Subscription::Message { client: Arc::downgrade(self), id: message_type_id, } } pub fn add_view_message_handler(self: &Arc, handler: H) where M: EntityMessage, E: View, H: 'static + Send + Sync + Fn(ViewHandle, TypedEnvelope, Arc, AsyncAppContext) -> F, F: 'static + Future>, { self.add_entity_message_handler::(move |handle, message, client, cx| { if let AnyEntityHandle::View(handle) = handle { handler(handle.downcast::().unwrap(), message, client, cx) } else { unreachable!(); } }) } pub fn add_model_message_handler(self: &Arc, handler: H) where M: EntityMessage, E: Entity, H: 'static + Send + Sync + Fn(ModelHandle, TypedEnvelope, Arc, AsyncAppContext) -> F, F: 'static + Future>, { self.add_entity_message_handler::(move |handle, message, client, cx| { if let AnyEntityHandle::Model(handle) = handle { handler(handle.downcast::().unwrap(), message, client, cx) } else { unreachable!(); } }) } fn add_entity_message_handler(self: &Arc, handler: H) where M: EntityMessage, E: Entity, H: 'static + Send + Sync + Fn(AnyEntityHandle, TypedEnvelope, Arc, AsyncAppContext) -> F, F: 'static + Future>, { let model_type_id = TypeId::of::(); let message_type_id = TypeId::of::(); let mut state = self.state.write(); state .entity_types_by_message_type .insert(message_type_id, model_type_id); state .entity_id_extractors .entry(message_type_id) .or_insert_with(|| { |envelope| { envelope .as_any() .downcast_ref::>() .unwrap() .payload .remote_entity_id() } }); let prev_handler = state.message_handlers.insert( message_type_id, Arc::new(move |handle, envelope, client, cx| { let envelope = envelope.into_any().downcast::>().unwrap(); handler(handle, *envelope, client.clone(), cx).boxed_local() }), ); if prev_handler.is_some() { panic!("registered handler for the same message twice"); } } pub fn add_model_request_handler(self: &Arc, handler: H) where M: EntityMessage + RequestMessage, E: Entity, H: 'static + Send + Sync + Fn(ModelHandle, TypedEnvelope, Arc, AsyncAppContext) -> F, F: 'static + Future>, { self.add_model_message_handler(move |entity, envelope, client, cx| { Self::respond_to_request::( envelope.receipt(), handler(entity, envelope, client.clone(), cx), client, ) }) } pub fn add_view_request_handler(self: &Arc, handler: H) where M: EntityMessage + RequestMessage, E: View, H: 'static + Send + Sync + Fn(ViewHandle, TypedEnvelope, Arc, AsyncAppContext) -> F, F: 'static + Future>, { self.add_view_message_handler(move |entity, envelope, client, cx| { Self::respond_to_request::( envelope.receipt(), handler(entity, envelope, client.clone(), cx), client, ) }) } async fn respond_to_request>>( receipt: Receipt, response: F, client: Arc, ) -> Result<()> { match response.await { Ok(response) => { client.respond(receipt, response)?; Ok(()) } Err(error) => { client.respond_with_error( receipt, proto::Error { message: error.to_string(), }, )?; Err(error) } } } pub fn has_keychain_credentials(&self, cx: &AsyncAppContext) -> bool { read_credentials_from_keychain(cx).is_some() } #[async_recursion(?Send)] pub async fn authenticate_and_connect( self: &Arc, try_keychain: bool, cx: &AsyncAppContext, ) -> anyhow::Result<()> { let was_disconnected = match *self.status().borrow() { Status::SignedOut => true, Status::ConnectionError | Status::ConnectionLost | Status::ReconnectionError { .. } => { false } Status::Connected { .. } | Status::Connecting { .. } | Status::Reconnecting { .. } | Status::Authenticating | Status::Reauthenticating => return Ok(()), Status::UpgradeRequired => return Err(EstablishConnectionError::UpgradeRequired)?, }; if was_disconnected { self.set_status(Status::Authenticating, cx); } else { self.set_status(Status::Reauthenticating, cx) } let mut read_from_keychain = false; let mut credentials = self.state.read().credentials.clone(); if credentials.is_none() && try_keychain { credentials = read_credentials_from_keychain(cx); read_from_keychain = credentials.is_some(); } if credentials.is_none() { credentials = Some(match self.authenticate(&cx).await { Ok(credentials) => credentials, Err(err) => { self.set_status(Status::ConnectionError, cx); return Err(err); } }); } let credentials = credentials.unwrap(); if was_disconnected { self.set_status(Status::Connecting, cx); } else { self.set_status(Status::Reconnecting, cx); } match self.establish_connection(&credentials, cx).await { Ok(conn) => { self.state.write().credentials = Some(credentials.clone()); if !read_from_keychain && IMPERSONATE_LOGIN.is_none() { write_credentials_to_keychain(&credentials, cx).log_err(); } self.set_connection(conn, cx).await; Ok(()) } Err(EstablishConnectionError::Unauthorized) => { self.state.write().credentials.take(); if read_from_keychain { cx.platform().delete_credentials(&ZED_SERVER_URL).log_err(); self.set_status(Status::SignedOut, cx); self.authenticate_and_connect(false, cx).await } else { self.set_status(Status::ConnectionError, cx); Err(EstablishConnectionError::Unauthorized)? } } Err(EstablishConnectionError::UpgradeRequired) => { self.set_status(Status::UpgradeRequired, cx); Err(EstablishConnectionError::UpgradeRequired)? } Err(error) => { self.set_status(Status::ConnectionError, cx); Err(error)? } } } async fn set_connection(self: &Arc, conn: Connection, cx: &AsyncAppContext) { let executor = cx.background(); log::info!("add connection to peer"); let (connection_id, handle_io, mut incoming) = self .peer .add_connection(conn, move |duration| executor.timer(duration)) .await; log::info!("set status to connected {}", connection_id); self.set_status(Status::Connected { connection_id }, cx); cx.foreground() .spawn({ let cx = cx.clone(); let this = self.clone(); async move { let mut message_id = 0_usize; while let Some(message) = incoming.next().await { let mut state = this.state.write(); message_id += 1; let type_name = message.payload_type_name(); let payload_type_id = message.payload_type_id(); let sender_id = message.original_sender_id().map(|id| id.0); let model = state .models_by_message_type .get(&payload_type_id) .and_then(|model| model.upgrade(&cx)) .map(AnyEntityHandle::Model) .or_else(|| { let entity_type_id = *state.entity_types_by_message_type.get(&payload_type_id)?; let entity_id = state .entity_id_extractors .get(&message.payload_type_id()) .map(|extract_entity_id| { (extract_entity_id)(message.as_ref()) })?; let entity = state .entities_by_type_and_remote_id .get(&(entity_type_id, entity_id))?; if let Some(entity) = entity.upgrade(&cx) { Some(entity) } else { state .entities_by_type_and_remote_id .remove(&(entity_type_id, entity_id)); None } }); let model = if let Some(model) = model { model } else { log::info!("unhandled message {}", type_name); continue; }; if let Some(handler) = state.message_handlers.get(&payload_type_id).cloned() { drop(state); // Avoid deadlocks if the handler interacts with rpc::Client let future = handler(model, message, &this, cx.clone()); let client_id = this.id; log::debug!( "rpc message received. client_id:{}, message_id:{}, sender_id:{:?}, type:{}", client_id, message_id, sender_id, type_name ); cx.foreground() .spawn(async move { match future.await { Ok(()) => { log::debug!( "rpc message handled. client_id:{}, message_id:{}, sender_id:{:?}, type:{}", client_id, message_id, sender_id, type_name ); } Err(error) => { log::error!( "error handling message. client_id:{}, message_id:{}, sender_id:{:?}, type:{}, error:{:?}", client_id, message_id, sender_id, type_name, error ); } } }) .detach(); } else { log::info!("unhandled message {}", type_name); } // Don't starve the main thread when receiving lots of messages at once. smol::future::yield_now().await; } } }) .detach(); let handle_io = cx.background().spawn(handle_io); let this = self.clone(); let cx = cx.clone(); cx.foreground() .spawn(async move { match handle_io.await { Ok(()) => { if *this.status().borrow() == (Status::Connected { connection_id }) { this.set_status(Status::SignedOut, &cx); } } Err(err) => { log::error!("connection error: {:?}", err); this.set_status(Status::ConnectionLost, &cx); } } }) .detach(); } fn authenticate(self: &Arc, cx: &AsyncAppContext) -> Task> { if let Some(callback) = self.authenticate.as_ref() { callback(cx) } else { self.authenticate_with_browser(cx) } } fn establish_connection( self: &Arc, credentials: &Credentials, cx: &AsyncAppContext, ) -> Task> { if let Some(callback) = self.establish_connection.as_ref() { callback(credentials, cx) } else { self.establish_websocket_connection(credentials, cx) } } fn establish_websocket_connection( self: &Arc, credentials: &Credentials, cx: &AsyncAppContext, ) -> Task> { let request = Request::builder() .header( "Authorization", format!("{} {}", credentials.user_id, credentials.access_token), ) .header("x-zed-protocol-version", rpc::PROTOCOL_VERSION); let http = self.http.clone(); cx.background().spawn(async move { let mut rpc_url = format!("{}/rpc", *ZED_SERVER_URL); let rpc_response = http.get(&rpc_url, Default::default(), false).await?; if rpc_response.status().is_redirection() { rpc_url = rpc_response .headers() .get("Location") .ok_or_else(|| anyhow!("missing location header in /rpc response"))? .to_str() .map_err(|error| EstablishConnectionError::other(error))? .to_string(); } // Until we switch the zed.dev domain to point to the new Next.js app, there // will be no redirect required, and the app will connect directly to // wss://zed.dev/rpc. else if rpc_response.status() != StatusCode::UPGRADE_REQUIRED { Err(anyhow!( "unexpected /rpc response status {}", rpc_response.status() ))? } let mut rpc_url = Url::parse(&rpc_url).context("invalid rpc url")?; let rpc_host = rpc_url .host_str() .zip(rpc_url.port_or_known_default()) .ok_or_else(|| anyhow!("missing host in rpc url"))?; let stream = smol::net::TcpStream::connect(rpc_host).await?; log::info!("connected to rpc endpoint {}", rpc_url); match rpc_url.scheme() { "https" => { rpc_url.set_scheme("wss").unwrap(); let request = request.uri(rpc_url.as_str()).body(())?; let (stream, _) = async_tungstenite::async_tls::client_async_tls(request, stream).await?; Ok(Connection::new( stream .map_err(|error| anyhow!(error)) .sink_map_err(|error| anyhow!(error)), )) } "http" => { rpc_url.set_scheme("ws").unwrap(); let request = request.uri(rpc_url.as_str()).body(())?; let (stream, _) = async_tungstenite::client_async(request, stream).await?; Ok(Connection::new( stream .map_err(|error| anyhow!(error)) .sink_map_err(|error| anyhow!(error)), )) } _ => Err(anyhow!("invalid rpc url: {}", rpc_url))?, } }) } pub fn authenticate_with_browser( self: &Arc, cx: &AsyncAppContext, ) -> Task> { let platform = cx.platform(); let executor = cx.background(); executor.clone().spawn(async move { // Generate a pair of asymmetric encryption keys. The public key will be used by the // zed server to encrypt the user's access token, so that it can'be intercepted by // any other app running on the user's device. let (public_key, private_key) = rpc::auth::keypair().expect("failed to generate keypair for auth"); let public_key_string = String::try_from(public_key).expect("failed to serialize public key for auth"); // Start an HTTP server to receive the redirect from Zed's sign-in page. let server = tiny_http::Server::http("127.0.0.1:0").expect("failed to find open port"); let port = server.server_addr().port(); // Open the Zed sign-in page in the user's browser, with query parameters that indicate // that the user is signing in from a Zed app running on the same device. let mut url = format!( "{}/native_app_signin?native_app_port={}&native_app_public_key={}", *ZED_SERVER_URL, port, public_key_string ); if let Some(impersonate_login) = IMPERSONATE_LOGIN.as_ref() { log::info!("impersonating user @{}", impersonate_login); write!(&mut url, "&impersonate={}", impersonate_login).unwrap(); } platform.open_url(&url); // Receive the HTTP request from the user's browser. Retrieve the user id and encrypted // access token from the query params. // // TODO - Avoid ever starting more than one HTTP server. Maybe switch to using a // custom URL scheme instead of this local HTTP server. let (user_id, access_token) = executor .spawn(async move { if let Some(req) = server.recv_timeout(Duration::from_secs(10 * 60))? { let path = req.url(); let mut user_id = None; let mut access_token = None; let url = Url::parse(&format!("http://example.com{}", path)) .context("failed to parse login notification url")?; for (key, value) in url.query_pairs() { if key == "access_token" { access_token = Some(value.to_string()); } else if key == "user_id" { user_id = Some(value.to_string()); } } let post_auth_url = format!("{}/native_app_signin_succeeded", *ZED_SERVER_URL); req.respond( tiny_http::Response::empty(302).with_header( tiny_http::Header::from_bytes( &b"Location"[..], post_auth_url.as_bytes(), ) .unwrap(), ), ) .context("failed to respond to login http request")?; Ok(( user_id.ok_or_else(|| anyhow!("missing user_id parameter"))?, access_token .ok_or_else(|| anyhow!("missing access_token parameter"))?, )) } else { Err(anyhow!("didn't receive login redirect")) } }) .await?; let access_token = private_key .decrypt_string(&access_token) .context("failed to decrypt access token")?; platform.activate(true); Ok(Credentials { user_id: user_id.parse()?, access_token, }) }) } pub fn disconnect(self: &Arc, cx: &AsyncAppContext) -> Result<()> { let conn_id = self.connection_id()?; self.peer.disconnect(conn_id); self.set_status(Status::SignedOut, cx); Ok(()) } fn connection_id(&self) -> Result { if let Status::Connected { connection_id, .. } = *self.status().borrow() { Ok(connection_id) } else { Err(anyhow!("not connected")) } } pub fn send(&self, message: T) -> Result<()> { log::debug!("rpc send. client_id:{}, name:{}", self.id, T::NAME); self.peer.send(self.connection_id()?, message) } pub fn request( &self, request: T, ) -> impl Future> { let client_id = self.id; log::debug!( "rpc request start. client_id:{}. name:{}", client_id, T::NAME ); let response = self .connection_id() .map(|conn_id| self.peer.request(conn_id, request)); async move { let response = response?.await; log::debug!( "rpc request finish. client_id:{}. name:{}", client_id, T::NAME ); response } } fn respond(&self, receipt: Receipt, response: T::Response) -> Result<()> { log::debug!("rpc respond. client_id:{}. name:{}", self.id, T::NAME); self.peer.respond(receipt, response) } fn respond_with_error( &self, receipt: Receipt, error: proto::Error, ) -> Result<()> { log::debug!("rpc respond. client_id:{}. name:{}", self.id, T::NAME); self.peer.respond_with_error(receipt, error) } } impl AnyWeakEntityHandle { fn upgrade(&self, cx: &AsyncAppContext) -> Option { match self { AnyWeakEntityHandle::Model(handle) => handle.upgrade(cx).map(AnyEntityHandle::Model), AnyWeakEntityHandle::View(handle) => handle.upgrade(cx).map(AnyEntityHandle::View), } } } fn read_credentials_from_keychain(cx: &AsyncAppContext) -> Option { if IMPERSONATE_LOGIN.is_some() { return None; } let (user_id, access_token) = cx .platform() .read_credentials(&ZED_SERVER_URL) .log_err() .flatten()?; Some(Credentials { user_id: user_id.parse().ok()?, access_token: String::from_utf8(access_token).ok()?, }) } fn write_credentials_to_keychain(credentials: &Credentials, cx: &AsyncAppContext) -> Result<()> { cx.platform().write_credentials( &ZED_SERVER_URL, &credentials.user_id.to_string(), credentials.access_token.as_bytes(), ) } const WORKTREE_URL_PREFIX: &'static str = "zed://worktrees/"; pub fn encode_worktree_url(id: u64, access_token: &str) -> String { format!("{}{}/{}", WORKTREE_URL_PREFIX, id, access_token) } pub fn decode_worktree_url(url: &str) -> Option<(u64, String)> { let path = url.trim().strip_prefix(WORKTREE_URL_PREFIX)?; let mut parts = path.split('/'); let id = parts.next()?.parse::().ok()?; let access_token = parts.next()?; if access_token.is_empty() { return None; } Some((id, access_token.to_string())) } #[cfg(test)] mod tests { use super::*; use crate::test::{FakeHttpClient, FakeServer}; use gpui::TestAppContext; #[gpui::test(iterations = 10)] async fn test_reconnection(cx: &mut TestAppContext) { cx.foreground().forbid_parking(); let user_id = 5; let mut client = Client::new(FakeHttpClient::with_404_response()); let server = FakeServer::for_client(user_id, &mut client, &cx).await; let mut status = client.status(); assert!(matches!( status.next().await, Some(Status::Connected { .. }) )); assert_eq!(server.auth_count(), 1); server.forbid_connections(); server.disconnect(); while !matches!(status.next().await, Some(Status::ReconnectionError { .. })) {} server.allow_connections(); cx.foreground().advance_clock(Duration::from_secs(10)); while !matches!(status.next().await, Some(Status::Connected { .. })) {} assert_eq!(server.auth_count(), 1); // Client reused the cached credentials when reconnecting server.forbid_connections(); server.disconnect(); while !matches!(status.next().await, Some(Status::ReconnectionError { .. })) {} // Clear cached credentials after authentication fails server.roll_access_token(); server.allow_connections(); cx.foreground().advance_clock(Duration::from_secs(10)); while !matches!(status.next().await, Some(Status::Connected { .. })) {} assert_eq!(server.auth_count(), 2); // Client re-authenticated due to an invalid token } #[test] fn test_encode_and_decode_worktree_url() { let url = encode_worktree_url(5, "deadbeef"); assert_eq!(decode_worktree_url(&url), Some((5, "deadbeef".to_string()))); assert_eq!( decode_worktree_url(&format!("\n {}\t", url)), Some((5, "deadbeef".to_string())) ); assert_eq!(decode_worktree_url("not://the-right-format"), None); } #[gpui::test] async fn test_subscribing_to_entity(cx: &mut TestAppContext) { cx.foreground().forbid_parking(); let user_id = 5; let mut client = Client::new(FakeHttpClient::with_404_response()); let server = FakeServer::for_client(user_id, &mut client, &cx).await; let (done_tx1, mut done_rx1) = smol::channel::unbounded(); let (done_tx2, mut done_rx2) = smol::channel::unbounded(); client.add_model_message_handler( move |model: ModelHandle, _: TypedEnvelope, _, cx| { match model.read_with(&cx, |model, _| model.id) { 1 => done_tx1.try_send(()).unwrap(), 2 => done_tx2.try_send(()).unwrap(), _ => unreachable!(), } async { Ok(()) } }, ); let model1 = cx.add_model(|_| Model { id: 1, subscription: None, }); let model2 = cx.add_model(|_| Model { id: 2, subscription: None, }); let model3 = cx.add_model(|_| Model { id: 3, subscription: None, }); let _subscription1 = model1.update(cx, |_, cx| client.add_model_for_remote_entity(1, cx)); let _subscription2 = model2.update(cx, |_, cx| client.add_model_for_remote_entity(2, cx)); // Ensure dropping a subscription for the same entity type still allows receiving of // messages for other entity IDs of the same type. let subscription3 = model3.update(cx, |_, cx| client.add_model_for_remote_entity(3, cx)); drop(subscription3); server.send(proto::JoinProject { project_id: 1 }); server.send(proto::JoinProject { project_id: 2 }); done_rx1.next().await.unwrap(); done_rx2.next().await.unwrap(); } #[gpui::test] async fn test_subscribing_after_dropping_subscription(cx: &mut TestAppContext) { cx.foreground().forbid_parking(); let user_id = 5; let mut client = Client::new(FakeHttpClient::with_404_response()); let server = FakeServer::for_client(user_id, &mut client, &cx).await; let model = cx.add_model(|_| Model::default()); let (done_tx1, _done_rx1) = smol::channel::unbounded(); let (done_tx2, mut done_rx2) = smol::channel::unbounded(); let subscription1 = client.add_message_handler( model.clone(), move |_, _: TypedEnvelope, _, _| { done_tx1.try_send(()).unwrap(); async { Ok(()) } }, ); drop(subscription1); let _subscription2 = client.add_message_handler(model, move |_, _: TypedEnvelope, _, _| { done_tx2.try_send(()).unwrap(); async { Ok(()) } }); server.send(proto::Ping {}); done_rx2.next().await.unwrap(); } #[gpui::test] async fn test_dropping_subscription_in_handler(cx: &mut TestAppContext) { cx.foreground().forbid_parking(); let user_id = 5; let mut client = Client::new(FakeHttpClient::with_404_response()); let server = FakeServer::for_client(user_id, &mut client, &cx).await; let model = cx.add_model(|_| Model::default()); let (done_tx, mut done_rx) = smol::channel::unbounded(); let subscription = client.add_message_handler( model.clone(), move |model, _: TypedEnvelope, _, mut cx| { model.update(&mut cx, |model, _| model.subscription.take()); done_tx.try_send(()).unwrap(); async { Ok(()) } }, ); model.update(cx, |model, _| { model.subscription = Some(subscription); }); server.send(proto::Ping {}); done_rx.next().await.unwrap(); } #[derive(Default)] struct Model { id: usize, subscription: Option, } impl Entity for Model { type Event = (); } }