diff --git a/crates/collab/src/db.rs b/crates/collab/src/db.rs index c97c82c656..6cb5373881 100644 --- a/crates/collab/src/db.rs +++ b/crates/collab/src/db.rs @@ -1464,6 +1464,21 @@ where // projects + pub async fn project_count_excluding_admins(&self) -> Result { + self.transact(|mut tx| async move { + Ok(sqlx::query_scalar::<_, i32>( + " + SELECT COUNT(*) + FROM projects, users + WHERE projects.host_user_id = users.id AND users.admin IS FALSE + ", + ) + .fetch_one(&mut tx) + .await? as usize) + }) + .await + } + pub async fn share_project( &self, expected_room_id: RoomId, diff --git a/crates/collab/src/main.rs b/crates/collab/src/main.rs index dc98a2ee68..20fae38c16 100644 --- a/crates/collab/src/main.rs +++ b/crates/collab/src/main.rs @@ -9,7 +9,6 @@ mod db_tests; #[cfg(test)] mod integration_tests; -use crate::rpc::ResultExt as _; use anyhow::anyhow; use axum::{routing::get, Router}; use collab::{Error, Result}; @@ -20,9 +19,7 @@ use std::{ net::{SocketAddr, TcpListener}, path::{Path, PathBuf}, sync::Arc, - time::Duration, }; -use tokio::signal; use tracing_log::LogTracer; use tracing_subscriber::{filter::EnvFilter, fmt::format::JsonFields, Layer}; use util::ResultExt; @@ -129,7 +126,6 @@ async fn main() -> Result<()> { axum::Server::from_tcp(listener)? .serve(app.into_make_service_with_connect_info::()) - .with_graceful_shutdown(graceful_shutdown(rpc_server, state)) .await?; } _ => { @@ -174,52 +170,3 @@ pub fn init_tracing(config: &Config) -> Option<()> { None } - -async fn graceful_shutdown(rpc_server: Arc, state: Arc) { - let ctrl_c = async { - signal::ctrl_c() - .await - .expect("failed to install Ctrl+C handler"); - }; - - #[cfg(unix)] - let terminate = async { - signal::unix::signal(signal::unix::SignalKind::terminate()) - .expect("failed to install signal handler") - .recv() - .await; - }; - - #[cfg(not(unix))] - let terminate = std::future::pending::<()>(); - - tokio::select! { - _ = ctrl_c => {}, - _ = terminate => {}, - } - - if let Some(live_kit) = state.live_kit_client.as_ref() { - let deletions = rpc_server - .store() - .await - .rooms() - .values() - .map(|room| { - let name = room.live_kit_room.clone(); - async { - live_kit.delete_room(name).await.trace_err(); - } - }) - .collect::>(); - - tracing::info!("deleting all live-kit rooms"); - if let Err(_) = tokio::time::timeout( - Duration::from_secs(10), - futures::future::join_all(deletions), - ) - .await - { - tracing::error!("timed out waiting for live-kit room deletion"); - } - } -} diff --git a/crates/collab/src/rpc.rs b/crates/collab/src/rpc.rs index 0c559239f5..58870163f5 100644 --- a/crates/collab/src/rpc.rs +++ b/crates/collab/src/rpc.rs @@ -49,7 +49,7 @@ use std::{ }, time::Duration, }; -pub use store::{Store, Worktree}; +pub use store::Store; use tokio::{ sync::{Mutex, MutexGuard}, time::Sleep, @@ -437,7 +437,7 @@ impl Server { let decline_calls = { let mut store = self.store().await; store.remove_connection(connection_id)?; - let mut connections = store.connection_ids_for_user(user_id); + let mut connections = store.user_connection_ids(user_id); connections.next().is_none() }; @@ -470,7 +470,7 @@ impl Server { if let Some(code) = &user.invite_code { let store = self.store().await; let invitee_contact = store.contact_for_user(invitee_id, true, false); - for connection_id in store.connection_ids_for_user(inviter_id) { + for connection_id in store.user_connection_ids(inviter_id) { self.peer.send( connection_id, proto::UpdateContacts { @@ -495,7 +495,7 @@ impl Server { if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? { if let Some(invite_code) = &user.invite_code { let store = self.store().await; - for connection_id in store.connection_ids_for_user(user_id) { + for connection_id in store.user_connection_ids(user_id) { self.peer.send( connection_id, proto::UpdateInviteInfo { @@ -582,7 +582,7 @@ impl Server { session.connection_id, ) .await?; - for connection_id in self.store().await.connection_ids_for_user(session.user_id) { + for connection_id in self.store().await.user_connection_ids(session.user_id) { self.peer .send(connection_id, proto::CallCanceled {}) .trace_err(); @@ -674,7 +674,7 @@ impl Server { { let store = self.store().await; for canceled_user_id in left_room.canceled_calls_to_user_ids { - for connection_id in store.connection_ids_for_user(canceled_user_id) { + for connection_id in store.user_connection_ids(canceled_user_id) { self.peer .send(connection_id, proto::CallCanceled {}) .trace_err(); @@ -744,7 +744,7 @@ impl Server { let mut calls = self .store() .await - .connection_ids_for_user(called_user_id) + .user_connection_ids(called_user_id) .map(|connection_id| self.peer.request(connection_id, incoming_call.clone())) .collect::>(); @@ -784,7 +784,7 @@ impl Server { .db .cancel_call(Some(room_id), session.connection_id, called_user_id) .await?; - for connection_id in self.store().await.connection_ids_for_user(called_user_id) { + for connection_id in self.store().await.user_connection_ids(called_user_id) { self.peer .send(connection_id, proto::CallCanceled {}) .trace_err(); @@ -807,7 +807,7 @@ impl Server { .db .decline_call(Some(room_id), session.user_id) .await?; - for connection_id in self.store().await.connection_ids_for_user(session.user_id) { + for connection_id in self.store().await.user_connection_ids(session.user_id) { self.peer .send(connection_id, proto::CallCanceled {}) .trace_err(); @@ -905,7 +905,7 @@ impl Server { .. } = contact { - for contact_conn_id in store.connection_ids_for_user(contact_user_id) { + for contact_conn_id in store.user_connection_ids(contact_user_id) { self.peer .send( contact_conn_id, @@ -1522,7 +1522,7 @@ impl Server { // Update outgoing contact requests of requester let mut update = proto::UpdateContacts::default(); update.outgoing_requests.push(responder_id.to_proto()); - for connection_id in self.store().await.connection_ids_for_user(requester_id) { + for connection_id in self.store().await.user_connection_ids(requester_id) { self.peer.send(connection_id, update.clone())?; } @@ -1534,7 +1534,7 @@ impl Server { requester_id: requester_id.to_proto(), should_notify: true, }); - for connection_id in self.store().await.connection_ids_for_user(responder_id) { + for connection_id in self.store().await.user_connection_ids(responder_id) { self.peer.send(connection_id, update.clone())?; } @@ -1574,7 +1574,7 @@ impl Server { update .remove_incoming_requests .push(requester_id.to_proto()); - for connection_id in store.connection_ids_for_user(responder_id) { + for connection_id in store.user_connection_ids(responder_id) { self.peer.send(connection_id, update.clone())?; } @@ -1588,7 +1588,7 @@ impl Server { update .remove_outgoing_requests .push(responder_id.to_proto()); - for connection_id in store.connection_ids_for_user(requester_id) { + for connection_id in store.user_connection_ids(requester_id) { self.peer.send(connection_id, update.clone())?; } } @@ -1615,7 +1615,7 @@ impl Server { update .remove_outgoing_requests .push(responder_id.to_proto()); - for connection_id in self.store().await.connection_ids_for_user(requester_id) { + for connection_id in self.store().await.user_connection_ids(requester_id) { self.peer.send(connection_id, update.clone())?; } @@ -1624,7 +1624,7 @@ impl Server { update .remove_incoming_requests .push(requester_id.to_proto()); - for connection_id in self.store().await.connection_ids_for_user(responder_id) { + for connection_id in self.store().await.user_connection_ids(responder_id) { self.peer.send(connection_id, update.clone())?; } @@ -1819,21 +1819,25 @@ pub async fn handle_websocket_request( }) } -pub async fn handle_metrics(Extension(server): Extension>) -> axum::response::Response { - let metrics = server.store().await.metrics(); - METRIC_CONNECTIONS.set(metrics.connections as _); - METRIC_SHARED_PROJECTS.set(metrics.shared_projects as _); +pub async fn handle_metrics(Extension(server): Extension>) -> Result { + let connections = server + .store() + .await + .connections() + .filter(|connection| !connection.admin) + .count(); + + METRIC_CONNECTIONS.set(connections as _); + + let shared_projects = server.app_state.db.project_count_excluding_admins().await?; + METRIC_SHARED_PROJECTS.set(shared_projects as _); let encoder = prometheus::TextEncoder::new(); let metric_families = prometheus::gather(); - match encoder.encode_to_string(&metric_families) { - Ok(string) => (StatusCode::OK, string).into_response(), - Err(error) => ( - StatusCode::INTERNAL_SERVER_ERROR, - format!("failed to encode metrics {:?}", error), - ) - .into_response(), - } + let encoded_metrics = encoder + .encode_to_string(&metric_families) + .map_err(|err| anyhow!("{}", err))?; + Ok(encoded_metrics) } fn to_axum_message(message: TungsteniteMessage) -> AxumMessage { diff --git a/crates/collab/src/rpc/store.rs b/crates/collab/src/rpc/store.rs index 1aa9c709b7..2bb6d89f40 100644 --- a/crates/collab/src/rpc/store.rs +++ b/crates/collab/src/rpc/store.rs @@ -1,111 +1,32 @@ -use crate::db::{self, ProjectId, UserId}; +use crate::db::{self, UserId}; use anyhow::{anyhow, Result}; -use collections::{BTreeMap, BTreeSet, HashMap, HashSet}; +use collections::{BTreeMap, HashSet}; use rpc::{proto, ConnectionId}; use serde::Serialize; -use std::path::PathBuf; use tracing::instrument; -pub type RoomId = u64; - #[derive(Default, Serialize)] pub struct Store { - connections: BTreeMap, + connections: BTreeMap, connected_users: BTreeMap, - next_room_id: RoomId, - rooms: BTreeMap, - projects: BTreeMap, } #[derive(Default, Serialize)] struct ConnectedUser { connection_ids: HashSet, - active_call: Option, } #[derive(Serialize)] -struct ConnectionState { - user_id: UserId, - admin: bool, - projects: BTreeSet, -} - -#[derive(Copy, Clone, Eq, PartialEq, Serialize)] -pub struct Call { - pub calling_user_id: UserId, - pub room_id: RoomId, - pub connection_id: Option, - pub initial_project_id: Option, -} - -#[derive(Serialize)] -pub struct Project { - pub id: ProjectId, - pub room_id: RoomId, - pub host_connection_id: ConnectionId, - pub host: Collaborator, - pub guests: HashMap, - pub active_replica_ids: HashSet, - pub worktrees: BTreeMap, - pub language_servers: Vec, -} - -#[derive(Serialize)] -pub struct Collaborator { - pub replica_id: ReplicaId, +pub struct Connection { pub user_id: UserId, pub admin: bool, } -#[derive(Default, Serialize)] -pub struct Worktree { - pub abs_path: PathBuf, - pub root_name: String, - pub visible: bool, - #[serde(skip)] - pub entries: BTreeMap, - #[serde(skip)] - pub diagnostic_summaries: BTreeMap, - pub scan_id: u64, - pub is_complete: bool, -} - -pub type ReplicaId = u16; - -#[derive(Copy, Clone)] -pub struct Metrics { - pub connections: usize, - pub shared_projects: usize, -} - impl Store { - pub fn metrics(&self) -> Metrics { - let connections = self.connections.values().filter(|c| !c.admin).count(); - let mut shared_projects = 0; - for project in self.projects.values() { - if let Some(connection) = self.connections.get(&project.host_connection_id) { - if !connection.admin { - shared_projects += 1; - } - } - } - - Metrics { - connections, - shared_projects, - } - } - #[instrument(skip(self))] pub fn add_connection(&mut self, connection_id: ConnectionId, user_id: UserId, admin: bool) { - self.connections.insert( - connection_id, - ConnectionState { - user_id, - admin, - projects: Default::default(), - }, - ); + self.connections + .insert(connection_id, Connection { user_id, admin }); let connected_user = self.connected_users.entry(user_id).or_default(); connected_user.connection_ids.insert(connection_id); } @@ -127,10 +48,11 @@ impl Store { Ok(()) } - pub fn connection_ids_for_user( - &self, - user_id: UserId, - ) -> impl Iterator + '_ { + pub fn connections(&self) -> impl Iterator { + self.connections.values() + } + + pub fn user_connection_ids(&self, user_id: UserId) -> impl Iterator + '_ { self.connected_users .get(&user_id) .into_iter() @@ -197,35 +119,9 @@ impl Store { } } - pub fn rooms(&self) -> &BTreeMap { - &self.rooms - } - #[cfg(test)] pub fn check_invariants(&self) { for (connection_id, connection) in &self.connections { - for project_id in &connection.projects { - let project = &self.projects.get(project_id).unwrap(); - if project.host_connection_id != *connection_id { - assert!(project.guests.contains_key(connection_id)); - } - - for (worktree_id, worktree) in project.worktrees.iter() { - let mut paths = HashMap::default(); - for entry in worktree.entries.values() { - let prev_entry = paths.insert(&entry.path, entry); - assert_eq!( - prev_entry, - None, - "worktree {:?}, duplicate path for entries {:?} and {:?}", - worktree_id, - prev_entry.unwrap(), - entry - ); - } - } - } - assert!(self .connected_users .get(&connection.user_id) @@ -241,85 +137,6 @@ impl Store { *user_id ); } - - if let Some(active_call) = state.active_call.as_ref() { - if let Some(active_call_connection_id) = active_call.connection_id { - assert!( - state.connection_ids.contains(&active_call_connection_id), - "call is active on a dead connection" - ); - assert!( - state.connection_ids.contains(&active_call_connection_id), - "call is active on a dead connection" - ); - } - } - } - - for (room_id, room) in &self.rooms { - // for pending_user_id in &room.pending_participant_user_ids { - // assert!( - // self.connected_users - // .contains_key(&UserId::from_proto(*pending_user_id)), - // "call is active on a user that has disconnected" - // ); - // } - - for participant in &room.participants { - assert!( - self.connections - .contains_key(&ConnectionId(participant.peer_id)), - "room {} contains participant {:?} that has disconnected", - room_id, - participant - ); - - for participant_project in &participant.projects { - let project = &self.projects[&ProjectId::from_proto(participant_project.id)]; - assert_eq!( - project.room_id, *room_id, - "project was shared on a different room" - ); - } - } - - // assert!( - // !room.pending_participant_user_ids.is_empty() || !room.participants.is_empty(), - // "room can't be empty" - // ); - } - - for (project_id, project) in &self.projects { - let host_connection = self.connections.get(&project.host_connection_id).unwrap(); - assert!(host_connection.projects.contains(project_id)); - - for guest_connection_id in project.guests.keys() { - let guest_connection = self.connections.get(guest_connection_id).unwrap(); - assert!(guest_connection.projects.contains(project_id)); - } - assert_eq!(project.active_replica_ids.len(), project.guests.len()); - assert_eq!( - project.active_replica_ids, - project - .guests - .values() - .map(|guest| guest.replica_id) - .collect::>(), - ); - - let room = &self.rooms[&project.room_id]; - let room_participant = room - .participants - .iter() - .find(|participant| participant.peer_id == project.host_connection_id.0) - .unwrap(); - assert!( - room_participant - .projects - .iter() - .any(|project| project.id == project_id.to_proto()), - "project was not shared in room" - ); } } }