diff --git a/Cargo.lock b/Cargo.lock index 2b51a98764..e1cfb2c518 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -9119,6 +9119,7 @@ name = "remote" version = "0.1.0" dependencies = [ "anyhow", + "async-trait", "collections", "fs", "futures 0.3.30", diff --git a/crates/collab/src/tests/remote_editing_collaboration_tests.rs b/crates/collab/src/tests/remote_editing_collaboration_tests.rs index dae3345755..504b2e11d1 100644 --- a/crates/collab/src/tests/remote_editing_collaboration_tests.rs +++ b/crates/collab/src/tests/remote_editing_collaboration_tests.rs @@ -26,7 +26,7 @@ async fn test_sharing_an_ssh_remote_project( .await; // Set up project on remote FS - let (client_ssh, server_ssh) = SshRemoteClient::fake(cx_a, server_cx); + let (port, server_ssh) = SshRemoteClient::fake_server(cx_a, server_cx); let remote_fs = FakeFs::new(server_cx.executor()); remote_fs .insert_tree( @@ -67,6 +67,7 @@ async fn test_sharing_an_ssh_remote_project( ) }); + let client_ssh = SshRemoteClient::fake_client(port, cx_a).await; let (project_a, worktree_id) = client_a .build_ssh_project("/code/project1", client_ssh, cx_a) .await; diff --git a/crates/project/src/project.rs b/crates/project/src/project.rs index 70d5962647..8ea9e78cb7 100644 --- a/crates/project/src/project.rs +++ b/crates/project/src/project.rs @@ -1243,6 +1243,10 @@ impl Project { self.client.clone() } + pub fn ssh_client(&self) -> Option> { + self.ssh_client.clone() + } + pub fn user_store(&self) -> Model { self.user_store.clone() } diff --git a/crates/proto/proto/zed.proto b/crates/proto/proto/zed.proto index 226b2474b2..7ba144a73a 100644 --- a/crates/proto/proto/zed.proto +++ b/crates/proto/proto/zed.proto @@ -12,6 +12,7 @@ message Envelope { uint32 id = 1; optional uint32 responding_to = 2; optional PeerId original_sender_id = 3; + optional uint32 ack_id = 266; oneof payload { Hello hello = 4; @@ -295,7 +296,9 @@ message Envelope { OpenServerSettings open_server_settings = 263; GetPermalinkToLine get_permalink_to_line = 264; - GetPermalinkToLineResponse get_permalink_to_line_response = 265; // current max + GetPermalinkToLineResponse get_permalink_to_line_response = 265; + + FlushBufferedMessages flush_buffered_messages = 267; } reserved 87 to 88; @@ -2522,3 +2525,6 @@ message GetPermalinkToLine { message GetPermalinkToLineResponse { string permalink = 1; } + +message FlushBufferedMessages {} +message FlushBufferedMessagesResponse {} diff --git a/crates/proto/src/macros.rs b/crates/proto/src/macros.rs index 4fdbfff81b..2ce0c0df25 100644 --- a/crates/proto/src/macros.rs +++ b/crates/proto/src/macros.rs @@ -32,6 +32,7 @@ macro_rules! messages { responding_to, original_sender_id, payload: Some(envelope::Payload::$name(self)), + ack_id: None, } } diff --git a/crates/proto/src/proto.rs b/crates/proto/src/proto.rs index ffbbeb49c2..8179473fea 100644 --- a/crates/proto/src/proto.rs +++ b/crates/proto/src/proto.rs @@ -372,6 +372,7 @@ messages!( (OpenServerSettings, Foreground), (GetPermalinkToLine, Foreground), (GetPermalinkToLineResponse, Foreground), + (FlushBufferedMessages, Foreground), ); request_messages!( @@ -498,6 +499,7 @@ request_messages!( (RemoveWorktree, Ack), (OpenServerSettings, OpenBufferResponse), (GetPermalinkToLine, GetPermalinkToLineResponse), + (FlushBufferedMessages, Ack), ); entity_messages!( diff --git a/crates/remote/Cargo.toml b/crates/remote/Cargo.toml index b8c5f34cc5..937a69ee59 100644 --- a/crates/remote/Cargo.toml +++ b/crates/remote/Cargo.toml @@ -19,6 +19,7 @@ test-support = ["fs/test-support"] [dependencies] anyhow.workspace = true +async-trait.workspace = true collections.workspace = true fs.workspace = true futures.workspace = true diff --git a/crates/remote/src/ssh_session.rs b/crates/remote/src/ssh_session.rs index 1d8006a060..5926a0b896 100644 --- a/crates/remote/src/ssh_session.rs +++ b/crates/remote/src/ssh_session.rs @@ -6,6 +6,7 @@ use crate::{ proxy::ProxyLaunchError, }; use anyhow::{anyhow, Context as _, Result}; +use async_trait::async_trait; use collections::HashMap; use futures::{ channel::{ @@ -13,7 +14,7 @@ use futures::{ oneshot, }, future::BoxFuture, - select_biased, AsyncReadExt as _, Future, FutureExt as _, SinkExt, StreamExt as _, + select_biased, AsyncReadExt as _, Future, FutureExt as _, StreamExt as _, }; use gpui::{ AppContext, AsyncAppContext, Context, EventEmitter, Model, ModelContext, SemanticVersion, Task, @@ -30,13 +31,14 @@ use smol::{ }; use std::{ any::TypeId, + collections::VecDeque, ffi::OsStr, fmt, ops::ControlFlow, path::{Path, PathBuf}, sync::{ atomic::{AtomicU32, Ordering::SeqCst}, - Arc, + Arc, Weak, }, time::{Duration, Instant}, }; @@ -275,68 +277,6 @@ async fn run_cmd(command: &mut process::Command) -> Result { } } -struct ChannelForwarder { - quit_tx: UnboundedSender<()>, - forwarding_task: Task<(UnboundedSender, UnboundedReceiver)>, -} - -impl ChannelForwarder { - fn new( - mut incoming_tx: UnboundedSender, - mut outgoing_rx: UnboundedReceiver, - cx: &AsyncAppContext, - ) -> (Self, UnboundedSender, UnboundedReceiver) { - let (quit_tx, mut quit_rx) = mpsc::unbounded::<()>(); - - let (proxy_incoming_tx, mut proxy_incoming_rx) = mpsc::unbounded::(); - let (mut proxy_outgoing_tx, proxy_outgoing_rx) = mpsc::unbounded::(); - - let forwarding_task = cx.background_executor().spawn(async move { - loop { - select_biased! { - _ = quit_rx.next().fuse() => { - break; - }, - incoming_envelope = proxy_incoming_rx.next().fuse() => { - if let Some(envelope) = incoming_envelope { - if incoming_tx.send(envelope).await.is_err() { - break; - } - } else { - break; - } - } - outgoing_envelope = outgoing_rx.next().fuse() => { - if let Some(envelope) = outgoing_envelope { - if proxy_outgoing_tx.send(envelope).await.is_err() { - break; - } - } else { - break; - } - } - } - } - - (incoming_tx, outgoing_rx) - }); - - ( - Self { - forwarding_task, - quit_tx, - }, - proxy_incoming_tx, - proxy_outgoing_rx, - ) - } - - async fn into_channels(mut self) -> (UnboundedSender, UnboundedReceiver) { - let _ = self.quit_tx.send(()).await; - self.forwarding_task.await - } -} - const MAX_MISSED_HEARTBEATS: usize = 5; const HEARTBEAT_INTERVAL: Duration = Duration::from_secs(5); const HEARTBEAT_TIMEOUT: Duration = Duration::from_secs(5); @@ -346,9 +286,8 @@ const MAX_RECONNECT_ATTEMPTS: usize = 3; enum State { Connecting, Connected { - ssh_connection: SshRemoteConnection, + ssh_connection: Box, delegate: Arc, - forwarder: ChannelForwarder, multiplex_task: Task>, heartbeat_task: Task>, @@ -356,18 +295,16 @@ enum State { HeartbeatMissed { missed_heartbeats: usize, - ssh_connection: SshRemoteConnection, + ssh_connection: Box, delegate: Arc, - forwarder: ChannelForwarder, multiplex_task: Task>, heartbeat_task: Task>, }, Reconnecting, ReconnectFailed { - ssh_connection: SshRemoteConnection, + ssh_connection: Box, delegate: Arc, - forwarder: ChannelForwarder, error: anyhow::Error, attempts: usize, @@ -391,11 +328,11 @@ impl fmt::Display for State { } impl State { - fn ssh_connection(&self) -> Option<&SshRemoteConnection> { + fn ssh_connection(&self) -> Option<&dyn SshRemoteProcess> { match self { - Self::Connected { ssh_connection, .. } => Some(ssh_connection), - Self::HeartbeatMissed { ssh_connection, .. } => Some(ssh_connection), - Self::ReconnectFailed { ssh_connection, .. } => Some(ssh_connection), + Self::Connected { ssh_connection, .. } => Some(ssh_connection.as_ref()), + Self::HeartbeatMissed { ssh_connection, .. } => Some(ssh_connection.as_ref()), + Self::ReconnectFailed { ssh_connection, .. } => Some(ssh_connection.as_ref()), _ => None, } } @@ -429,14 +366,12 @@ impl State { Self::HeartbeatMissed { ssh_connection, delegate, - forwarder, multiplex_task, heartbeat_task, .. } => Self::Connected { ssh_connection, delegate, - forwarder, multiplex_task, heartbeat_task, }, @@ -449,14 +384,12 @@ impl State { Self::Connected { ssh_connection, delegate, - forwarder, multiplex_task, heartbeat_task, } => Self::HeartbeatMissed { missed_heartbeats: 1, ssh_connection, delegate, - forwarder, multiplex_task, heartbeat_task, }, @@ -464,14 +397,12 @@ impl State { missed_heartbeats, ssh_connection, delegate, - forwarder, multiplex_task, heartbeat_task, } => Self::HeartbeatMissed { missed_heartbeats: missed_heartbeats + 1, ssh_connection, delegate, - forwarder, multiplex_task, heartbeat_task, }, @@ -529,7 +460,8 @@ impl SshRemoteClient { let (incoming_tx, incoming_rx) = mpsc::unbounded::(); let (connection_activity_tx, connection_activity_rx) = mpsc::channel::<()>(1); - let client = cx.update(|cx| ChannelClient::new(incoming_rx, outgoing_tx, cx))?; + let client = + cx.update(|cx| ChannelClient::new(incoming_rx, outgoing_tx, cx, "client"))?; let this = cx.new_model(|_| Self { client: client.clone(), unique_identifier: unique_identifier.clone(), @@ -537,26 +469,19 @@ impl SshRemoteClient { state: Arc::new(Mutex::new(Some(State::Connecting))), })?; - let (proxy, proxy_incoming_tx, proxy_outgoing_rx) = - ChannelForwarder::new(incoming_tx, outgoing_rx, &mut cx); - - let (ssh_connection, ssh_proxy_process) = Self::establish_connection( + let (ssh_connection, io_task) = Self::establish_connection( unique_identifier, false, connection_options, + incoming_tx, + outgoing_rx, + connection_activity_tx, delegate.clone(), &mut cx, ) .await?; - let multiplex_task = Self::multiplex( - this.downgrade(), - ssh_proxy_process, - proxy_incoming_tx, - proxy_outgoing_rx, - connection_activity_tx, - &mut cx, - ); + let multiplex_task = Self::monitor(this.downgrade(), io_task, &cx); if let Err(error) = client.ping(HEARTBEAT_TIMEOUT).await { log::error!("failed to establish connection: {}", error); @@ -570,7 +495,6 @@ impl SshRemoteClient { *this.state.lock() = Some(State::Connected { ssh_connection, delegate, - forwarder: proxy, multiplex_task, heartbeat_task, }); @@ -592,7 +516,6 @@ impl SshRemoteClient { heartbeat_task, ssh_connection, delegate, - forwarder, } = state else { return None; @@ -616,7 +539,6 @@ impl SshRemoteClient { drop(heartbeat_task); drop(ssh_connection); drop(delegate); - drop(forwarder); }) } @@ -638,33 +560,30 @@ impl SshRemoteClient { } let state = lock.take().unwrap(); - let (attempts, mut ssh_connection, delegate, forwarder) = match state { + let (attempts, mut ssh_connection, delegate) = match state { State::Connected { ssh_connection, delegate, - forwarder, multiplex_task, heartbeat_task, } | State::HeartbeatMissed { ssh_connection, delegate, - forwarder, multiplex_task, heartbeat_task, .. } => { drop(multiplex_task); drop(heartbeat_task); - (0, ssh_connection, delegate, forwarder) + (0, ssh_connection, delegate) } State::ReconnectFailed { attempts, ssh_connection, delegate, - forwarder, .. - } => (attempts, ssh_connection, delegate, forwarder), + } => (attempts, ssh_connection, delegate), State::Connecting | State::Reconnecting | State::ReconnectExhausted @@ -691,41 +610,37 @@ impl SshRemoteClient { let client = self.client.clone(); let reconnect_task = cx.spawn(|this, mut cx| async move { macro_rules! failed { - ($error:expr, $attempts:expr, $ssh_connection:expr, $delegate:expr, $forwarder:expr) => { + ($error:expr, $attempts:expr, $ssh_connection:expr, $delegate:expr) => { return State::ReconnectFailed { error: anyhow!($error), attempts: $attempts, ssh_connection: $ssh_connection, delegate: $delegate, - forwarder: $forwarder, }; }; } - if let Err(error) = ssh_connection.master_process.kill() { - failed!(error, attempts, ssh_connection, delegate, forwarder); - }; - if let Err(error) = ssh_connection - .master_process - .status() + .kill() .await .context("Failed to kill ssh process") { - failed!(error, attempts, ssh_connection, delegate, forwarder); - } + failed!(error, attempts, ssh_connection, delegate); + }; - let connection_options = ssh_connection.socket.connection_options.clone(); + let connection_options = ssh_connection.connection_options(); - let (incoming_tx, outgoing_rx) = forwarder.into_channels().await; - let (forwarder, proxy_incoming_tx, proxy_outgoing_rx) = - ChannelForwarder::new(incoming_tx, outgoing_rx, &mut cx); + let (outgoing_tx, outgoing_rx) = mpsc::unbounded::(); + let (incoming_tx, incoming_rx) = mpsc::unbounded::(); let (connection_activity_tx, connection_activity_rx) = mpsc::channel::<()>(1); - let (ssh_connection, ssh_process) = match Self::establish_connection( + let (ssh_connection, io_task) = match Self::establish_connection( identifier, true, connection_options, + incoming_tx, + outgoing_rx, + connection_activity_tx, delegate.clone(), &mut cx, ) @@ -733,27 +648,20 @@ impl SshRemoteClient { { Ok((ssh_connection, ssh_process)) => (ssh_connection, ssh_process), Err(error) => { - failed!(error, attempts, ssh_connection, delegate, forwarder); + failed!(error, attempts, ssh_connection, delegate); } }; - let multiplex_task = Self::multiplex( - this.clone(), - ssh_process, - proxy_incoming_tx, - proxy_outgoing_rx, - connection_activity_tx, - &mut cx, - ); + let multiplex_task = Self::monitor(this.clone(), io_task, &cx); + client.reconnect(incoming_rx, outgoing_tx, &cx); - if let Err(error) = client.ping(HEARTBEAT_TIMEOUT).await { - failed!(error, attempts, ssh_connection, delegate, forwarder); + if let Err(error) = client.resync(HEARTBEAT_TIMEOUT).await { + failed!(error, attempts, ssh_connection, delegate); }; State::Connected { ssh_connection, delegate, - forwarder, multiplex_task, heartbeat_task: Self::heartbeat(this.clone(), connection_activity_rx, &mut cx), } @@ -797,7 +705,7 @@ impl SshRemoteClient { cx.emit(SshRemoteEvent::Disconnected); Ok(()) } else { - log::debug!("State has transition from Reconnecting into new state while attempting reconnect. Ignoring new state."); + log::debug!("State has transition from Reconnecting into new state while attempting reconnect."); Ok(()) } }) @@ -910,13 +818,12 @@ impl SshRemoteClient { } fn multiplex( - this: WeakModel, mut ssh_proxy_process: Child, incoming_tx: UnboundedSender, mut outgoing_rx: UnboundedReceiver, mut connection_activity_tx: Sender<()>, cx: &AsyncAppContext, - ) -> Task> { + ) -> Task> { let mut child_stderr = ssh_proxy_process.stderr.take().unwrap(); let mut child_stdout = ssh_proxy_process.stdout.take().unwrap(); let mut child_stdin = ssh_proxy_process.stdin.take().unwrap(); @@ -988,7 +895,7 @@ impl SshRemoteClient { } }); - cx.spawn(|mut cx| async move { + cx.spawn(|_| async move { let result = futures::select! { result = stdin_task.fuse() => { result.context("stdin") @@ -1002,9 +909,22 @@ impl SshRemoteClient { }; match result { - Ok(_) => { - let exit_code = ssh_proxy_process.status().await?.code().unwrap_or(1); + Ok(_) => Ok(ssh_proxy_process.status().await?.code().unwrap_or(1)), + Err(error) => Err(error), + } + }) + } + fn monitor( + this: WeakModel, + io_task: Task>, + cx: &AsyncAppContext, + ) -> Task> { + cx.spawn(|mut cx| async move { + let result = io_task.await; + + match result { + Ok(exit_code) => { if let Some(error) = ProxyLaunchError::from_exit_code(exit_code) { match error { ProxyLaunchError::ServerNotRunning => { @@ -1058,21 +978,40 @@ impl SshRemoteClient { cx.notify(); } + #[allow(clippy::too_many_arguments)] async fn establish_connection( unique_identifier: String, reconnect: bool, connection_options: SshConnectionOptions, + incoming_tx: UnboundedSender, + outgoing_rx: UnboundedReceiver, + connection_activity_tx: Sender<()>, delegate: Arc, cx: &mut AsyncAppContext, - ) -> Result<(SshRemoteConnection, Child)> { + ) -> Result<(Box, Task>)> { + #[cfg(any(test, feature = "test-support"))] + if let Some(fake) = fake::SshRemoteConnection::new(&connection_options) { + let io_task = fake::SshRemoteConnection::multiplex( + fake.connection_options(), + incoming_tx, + outgoing_rx, + connection_activity_tx, + cx, + ) + .await; + return Ok((fake, io_task)); + } + let ssh_connection = SshRemoteConnection::new(connection_options, delegate.clone(), cx).await?; let platform = ssh_connection.query_platform().await?; let remote_binary_path = delegate.remote_server_binary_path(platform, cx)?; - ssh_connection - .ensure_server_binary(&delegate, &remote_binary_path, platform, cx) - .await?; + if !reconnect { + ssh_connection + .ensure_server_binary(&delegate, &remote_binary_path, platform, cx) + .await?; + } let socket = ssh_connection.socket.clone(); run_cmd(socket.ssh_command(&remote_binary_path).arg("version")).await?; @@ -1097,7 +1036,15 @@ impl SshRemoteClient { .spawn() .context("failed to spawn remote server")?; - Ok((ssh_connection, ssh_proxy_process)) + let io_task = Self::multiplex( + ssh_proxy_process, + incoming_tx, + outgoing_rx, + connection_activity_tx, + &cx, + ); + + Ok((Box::new(ssh_connection), io_task)) } pub fn subscribe_to_entity(&self, remote_id: u64, entity: &Model) { @@ -1109,7 +1056,7 @@ impl SshRemoteClient { .lock() .as_ref() .and_then(|state| state.ssh_connection()) - .map(|ssh_connection| ssh_connection.socket.ssh_args()) + .map(|ssh_connection| ssh_connection.ssh_args()) } pub fn proto_client(&self) -> AnyProtoClient { @@ -1124,7 +1071,6 @@ impl SshRemoteClient { self.connection_options.clone() } - #[cfg(not(any(test, feature = "test-support")))] pub fn connection_state(&self) -> ConnectionState { self.state .lock() @@ -1133,37 +1079,59 @@ impl SshRemoteClient { .unwrap_or(ConnectionState::Disconnected) } - #[cfg(any(test, feature = "test-support"))] - pub fn connection_state(&self) -> ConnectionState { - ConnectionState::Connected - } - pub fn is_disconnected(&self) -> bool { self.connection_state() == ConnectionState::Disconnected } #[cfg(any(test, feature = "test-support"))] - pub fn fake( + pub fn simulate_disconnect(&self, client_cx: &mut AppContext) -> Task<()> { + let port = self.connection_options().port.unwrap(); + client_cx.spawn(|cx| async move { + let (channel, server_cx) = cx + .update_global(|c: &mut fake::ServerConnections, _| c.get(port)) + .unwrap(); + + let (outgoing_tx, _) = mpsc::unbounded::(); + let (_, incoming_rx) = mpsc::unbounded::(); + channel.reconnect(incoming_rx, outgoing_tx, &server_cx); + }) + } + + #[cfg(any(test, feature = "test-support"))] + pub fn fake_server( client_cx: &mut gpui::TestAppContext, server_cx: &mut gpui::TestAppContext, - ) -> (Model, Arc) { - use gpui::Context; + ) -> (u16, Arc) { + use gpui::BorrowAppContext; + let (outgoing_tx, _) = mpsc::unbounded::(); + let (_, incoming_rx) = mpsc::unbounded::(); + let server_client = + server_cx.update(|cx| ChannelClient::new(incoming_rx, outgoing_tx, cx, "fake-server")); + let port = client_cx.update(|cx| { + cx.update_default_global(|c: &mut fake::ServerConnections, _| { + c.push(server_client.clone(), server_cx.to_async()) + }) + }); + (port, server_client) + } - let (server_to_client_tx, server_to_client_rx) = mpsc::unbounded(); - let (client_to_server_tx, client_to_server_rx) = mpsc::unbounded(); - - ( - client_cx.update(|cx| { - let client = ChannelClient::new(server_to_client_rx, client_to_server_tx, cx); - cx.new_model(|_| Self { - client, - unique_identifier: "fake".to_string(), - connection_options: SshConnectionOptions::default(), - state: Arc::new(Mutex::new(None)), - }) - }), - server_cx.update(|cx| ChannelClient::new(client_to_server_rx, server_to_client_tx, cx)), - ) + #[cfg(any(test, feature = "test-support"))] + pub async fn fake_client(port: u16, client_cx: &mut gpui::TestAppContext) -> Model { + client_cx + .update(|cx| { + Self::new( + "fake".to_string(), + SshConnectionOptions { + host: "".to_string(), + port: Some(port), + ..Default::default() + }, + Arc::new(fake::Delegate), + cx, + ) + }) + .await + .unwrap() } } @@ -1173,6 +1141,13 @@ impl From for AnyProtoClient { } } +#[async_trait] +trait SshRemoteProcess: Send + Sync { + async fn kill(&mut self) -> Result<()>; + fn ssh_args(&self) -> Vec; + fn connection_options(&self) -> SshConnectionOptions; +} + struct SshRemoteConnection { socket: SshSocket, master_process: process::Child, @@ -1187,6 +1162,25 @@ impl Drop for SshRemoteConnection { } } +#[async_trait] +impl SshRemoteProcess for SshRemoteConnection { + async fn kill(&mut self) -> Result<()> { + self.master_process.kill()?; + + self.master_process.status().await?; + + Ok(()) + } + + fn ssh_args(&self) -> Vec { + self.socket.ssh_args() + } + + fn connection_options(&self) -> SshConnectionOptions { + self.socket.connection_options.clone() + } +} + impl SshRemoteConnection { #[cfg(not(unix))] async fn new( @@ -1469,9 +1463,13 @@ type ResponseChannels = Mutex, - response_channels: ResponseChannels, // Lock - message_handlers: Mutex, // Lock + outgoing_tx: Mutex>, + buffer: Mutex>, + response_channels: ResponseChannels, + message_handlers: Mutex, + max_received: AtomicU32, + name: &'static str, + task: Mutex>>, } impl ChannelClient { @@ -1479,32 +1477,59 @@ impl ChannelClient { incoming_rx: mpsc::UnboundedReceiver, outgoing_tx: mpsc::UnboundedSender, cx: &AppContext, + name: &'static str, ) -> Arc { - let this = Arc::new(Self { - outgoing_tx, + Arc::new_cyclic(|this| Self { + outgoing_tx: Mutex::new(outgoing_tx), next_message_id: AtomicU32::new(0), + max_received: AtomicU32::new(0), response_channels: ResponseChannels::default(), message_handlers: Default::default(), - }); - - Self::start_handling_messages(this.clone(), incoming_rx, cx); - - this + buffer: Mutex::new(VecDeque::new()), + name, + task: Mutex::new(Self::start_handling_messages( + this.clone(), + incoming_rx, + &cx.to_async(), + )), + }) } fn start_handling_messages( - this: Arc, + this: Weak, mut incoming_rx: mpsc::UnboundedReceiver, - cx: &AppContext, - ) { + cx: &AsyncAppContext, + ) -> Task> { cx.spawn(|cx| { - let this = Arc::downgrade(&this); async move { let peer_id = PeerId { owner_id: 0, id: 0 }; while let Some(incoming) = incoming_rx.next().await { let Some(this) = this.upgrade() else { return anyhow::Ok(()); }; + if let Some(ack_id) = incoming.ack_id { + let mut buffer = this.buffer.lock(); + while buffer.front().is_some_and(|msg| msg.id <= ack_id) { + buffer.pop_front(); + } + } + if let Some(proto::envelope::Payload::FlushBufferedMessages(_)) = + &incoming.payload + { + log::debug!("{}:ssh message received. name:FlushBufferedMessages", this.name); + { + let buffer = this.buffer.lock(); + for envelope in buffer.iter() { + this.outgoing_tx.lock().unbounded_send(envelope.clone()).ok(); + } + } + let mut envelope = proto::Ack{}.into_envelope(0, Some(incoming.id), None); + envelope.id = this.next_message_id.fetch_add(1, SeqCst); + this.outgoing_tx.lock().unbounded_send(envelope).ok(); + continue; + } + + this.max_received.store(incoming.id, SeqCst); if let Some(request_id) = incoming.responding_to { let request_id = MessageId(request_id); @@ -1526,26 +1551,37 @@ impl ChannelClient { this.clone().into(), cx.clone(), ) { - log::debug!("ssh message received. name:{type_name}"); - match future.await { - Ok(_) => { - log::debug!("ssh message handled. name:{type_name}"); + log::debug!("{}:ssh message received. name:{type_name}", this.name); + cx.foreground_executor().spawn(async move { + match future.await { + Ok(_) => { + log::debug!("{}:ssh message handled. name:{type_name}", this.name); + } + Err(error) => { + log::error!( + "{}:error handling message. type:{type_name}, error:{error}", this.name, + ); + } } - Err(error) => { - log::error!( - "error handling message. type:{type_name}, error:{error}", - ); - } - } + }).detach() } else { - log::error!("unhandled ssh message name:{type_name}"); + log::error!("{}:unhandled ssh message name:{type_name}", this.name); } } } anyhow::Ok(()) } }) - .detach(); + } + + pub fn reconnect( + self: &Arc, + incoming_rx: UnboundedReceiver, + outgoing_tx: UnboundedSender, + cx: &AsyncAppContext, + ) { + *self.outgoing_tx.lock() = outgoing_tx; + *self.task.lock() = Self::start_handling_messages(Arc::downgrade(self), incoming_rx, cx); } pub fn subscribe_to_entity(&self, remote_id: u64, entity: &Model) { @@ -1581,6 +1617,26 @@ impl ChannelClient { } } + pub async fn resync(&self, timeout: Duration) -> Result<()> { + smol::future::or( + async { + self.request(proto::FlushBufferedMessages {}).await?; + for envelope in self.buffer.lock().iter() { + self.outgoing_tx + .lock() + .unbounded_send(envelope.clone()) + .ok(); + } + Ok(()) + }, + async { + smol::Timer::after(timeout).await; + Err(anyhow!("Timeout detected")) + }, + ) + .await + } + pub async fn ping(&self, timeout: Duration) -> Result<()> { smol::future::or( async { @@ -1610,7 +1666,8 @@ impl ChannelClient { let mut response_channels_lock = self.response_channels.lock(); response_channels_lock.insert(MessageId(envelope.id), tx); drop(response_channels_lock); - let result = self.outgoing_tx.unbounded_send(envelope); + + let result = self.send_buffered(envelope); async move { if let Err(error) = &result { log::error!("failed to send message: {}", error); @@ -1627,7 +1684,15 @@ impl ChannelClient { pub fn send_dynamic(&self, mut envelope: proto::Envelope) -> Result<()> { envelope.id = self.next_message_id.fetch_add(1, SeqCst); - self.outgoing_tx.unbounded_send(envelope)?; + self.send_buffered(envelope) + } + + pub fn send_buffered(&self, mut envelope: proto::Envelope) -> Result<()> { + envelope.ack_id = Some(self.max_received.load(SeqCst)); + self.buffer.lock().push_back(envelope.clone()); + // ignore errors on send (happen while we're reconnecting) + // assume that the global "disconnected" overlay is sufficient. + self.outgoing_tx.lock().unbounded_send(envelope).ok(); Ok(()) } } @@ -1657,3 +1722,148 @@ impl ProtoClient for ChannelClient { false } } + +#[cfg(any(test, feature = "test-support"))] +mod fake { + use std::{path::PathBuf, sync::Arc}; + + use anyhow::Result; + use async_trait::async_trait; + use futures::{ + channel::{ + mpsc::{self, Sender}, + oneshot, + }, + select_biased, FutureExt, SinkExt, StreamExt, + }; + use gpui::{AsyncAppContext, BorrowAppContext, Global, SemanticVersion, Task}; + use rpc::proto::Envelope; + + use super::{ + ChannelClient, SshClientDelegate, SshConnectionOptions, SshPlatform, SshRemoteProcess, + }; + + pub(super) struct SshRemoteConnection { + connection_options: SshConnectionOptions, + } + + impl SshRemoteConnection { + pub(super) fn new( + connection_options: &SshConnectionOptions, + ) -> Option> { + if connection_options.host == "" { + return Some(Box::new(Self { + connection_options: connection_options.clone(), + })); + } + return None; + } + pub(super) async fn multiplex( + connection_options: SshConnectionOptions, + mut client_incoming_tx: mpsc::UnboundedSender, + mut client_outgoing_rx: mpsc::UnboundedReceiver, + mut connection_activity_tx: Sender<()>, + cx: &mut AsyncAppContext, + ) -> Task> { + let (mut server_incoming_tx, server_incoming_rx) = mpsc::unbounded::(); + let (server_outgoing_tx, mut server_outgoing_rx) = mpsc::unbounded::(); + + let (channel, server_cx) = cx + .update(|cx| { + cx.update_global(|conns: &mut ServerConnections, _| { + conns.get(connection_options.port.unwrap()) + }) + }) + .unwrap(); + channel.reconnect(server_incoming_rx, server_outgoing_tx, &server_cx); + + // send to proxy_tx to get to the server. + // receive from + + cx.background_executor().spawn(async move { + loop { + select_biased! { + server_to_client = server_outgoing_rx.next().fuse() => { + let Some(server_to_client) = server_to_client else { + return Ok(1) + }; + connection_activity_tx.try_send(()).ok(); + client_incoming_tx.send(server_to_client).await.ok(); + } + client_to_server = client_outgoing_rx.next().fuse() => { + let Some(client_to_server) = client_to_server else { + return Ok(1) + }; + server_incoming_tx.send(client_to_server).await.ok(); + } + } + } + }) + } + } + + #[async_trait] + impl SshRemoteProcess for SshRemoteConnection { + async fn kill(&mut self) -> Result<()> { + Ok(()) + } + + fn ssh_args(&self) -> Vec { + Vec::new() + } + + fn connection_options(&self) -> SshConnectionOptions { + self.connection_options.clone() + } + } + + #[derive(Default)] + pub(super) struct ServerConnections(Vec<(Arc, AsyncAppContext)>); + impl Global for ServerConnections {} + + impl ServerConnections { + pub(super) fn push(&mut self, server: Arc, cx: AsyncAppContext) -> u16 { + self.0.push((server.clone(), cx)); + self.0.len() as u16 - 1 + } + + pub(super) fn get(&mut self, port: u16) -> (Arc, AsyncAppContext) { + self.0 + .get(port as usize) + .expect("no fake server for port") + .clone() + } + } + + pub(super) struct Delegate; + + impl SshClientDelegate for Delegate { + fn ask_password( + &self, + _: String, + _: &mut AsyncAppContext, + ) -> oneshot::Receiver> { + unreachable!() + } + fn remote_server_binary_path( + &self, + _: SshPlatform, + _: &mut AsyncAppContext, + ) -> Result { + unreachable!() + } + fn get_server_binary( + &self, + _: SshPlatform, + _: &mut AsyncAppContext, + ) -> oneshot::Receiver> { + unreachable!() + } + fn set_status(&self, _: Option<&str>, _: &mut AsyncAppContext) { + unreachable!() + } + fn set_error(&self, _: String, _: &mut AsyncAppContext) { + unreachable!() + } + } +} diff --git a/crates/remote_server/src/remote_editing_tests.rs b/crates/remote_server/src/remote_editing_tests.rs index 41065ad550..51d8d67887 100644 --- a/crates/remote_server/src/remote_editing_tests.rs +++ b/crates/remote_server/src/remote_editing_tests.rs @@ -641,6 +641,47 @@ async fn test_open_server_settings(cx: &mut TestAppContext, server_cx: &mut Test }) } +#[gpui::test(iterations = 20)] +async fn test_reconnect(cx: &mut TestAppContext, server_cx: &mut TestAppContext) { + let (project, _headless, fs) = init_test(cx, server_cx).await; + + let (worktree, _) = project + .update(cx, |project, cx| { + project.find_or_create_worktree("/code/project1", true, cx) + }) + .await + .unwrap(); + + let worktree_id = worktree.read_with(cx, |worktree, _| worktree.id()); + let buffer = project + .update(cx, |project, cx| { + project.open_buffer((worktree_id, Path::new("src/lib.rs")), cx) + }) + .await + .unwrap(); + + buffer.update(cx, |buffer, cx| { + assert_eq!(buffer.text(), "fn one() -> usize { 1 }"); + let ix = buffer.text().find('1').unwrap(); + buffer.edit([(ix..ix + 1, "100")], None, cx); + }); + + let client = cx.read(|cx| project.read(cx).ssh_client().unwrap()); + client + .update(cx, |client, cx| client.simulate_disconnect(cx)) + .detach(); + + project + .update(cx, |project, cx| project.save_buffer(buffer.clone(), cx)) + .await + .unwrap(); + + assert_eq!( + fs.load("/code/project1/src/lib.rs".as_ref()).await.unwrap(), + "fn one() -> usize { 100 }" + ); +} + fn init_logger() { if std::env::var("RUST_LOG").is_ok() { env_logger::try_init().ok(); @@ -651,9 +692,9 @@ async fn init_test( cx: &mut TestAppContext, server_cx: &mut TestAppContext, ) -> (Model, Model, Arc) { - let (ssh_remote_client, ssh_server_client) = SshRemoteClient::fake(cx, server_cx); init_logger(); + let (forwarder, ssh_server_client) = SshRemoteClient::fake_server(cx, server_cx); let fs = FakeFs::new(server_cx.executor()); fs.insert_tree( "/code", @@ -694,8 +735,9 @@ async fn init_test( cx, ) }); - let project = build_project(ssh_remote_client, cx); + let ssh = SshRemoteClient::fake_client(forwarder, cx).await; + let project = build_project(ssh, cx); project .update(cx, { let headless = headless.clone(); diff --git a/crates/remote_server/src/unix.rs b/crates/remote_server/src/unix.rs index 30b2bacd0a..abfdb798f4 100644 --- a/crates/remote_server/src/unix.rs +++ b/crates/remote_server/src/unix.rs @@ -279,7 +279,7 @@ fn start_server( }) .detach(); - ChannelClient::new(incoming_rx, outgoing_tx, cx) + ChannelClient::new(incoming_rx, outgoing_tx, cx, "server") } fn init_paths() -> anyhow::Result<()> {