diff --git a/Cargo.lock b/Cargo.lock index a24b7b81ed..817a0e7c91 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4699,6 +4699,7 @@ dependencies = [ "sqlx-rt 0.5.5", "stringprep", "thiserror", + "time 0.2.25", "url", "webpki", "webpki-roots", @@ -5866,6 +5867,7 @@ dependencies = [ "surf", "tide", "tide-compress", + "time 0.2.25", "toml 0.5.8", "zed", "zrpc", diff --git a/server/Cargo.toml b/server/Cargo.toml index 6d26f66054..aad43e5b6e 100644 --- a/server/Cargo.toml +++ b/server/Cargo.toml @@ -31,6 +31,7 @@ sha-1 = "0.9" surf = "2.2.0" tide = "0.16.0" tide-compress = "0.9.0" +time = "0.2" toml = "0.5.8" zrpc = { path = "../zrpc" } @@ -41,7 +42,7 @@ default-features = false [dependencies.sqlx] version = "0.5.2" -features = ["runtime-async-std-rustls", "postgres"] +features = ["runtime-async-std-rustls", "postgres", "time"] [dev-dependencies] gpui = { path = "../gpui" } diff --git a/server/src/auth.rs b/server/src/auth.rs index 9dde8212ff..d61428fa37 100644 --- a/server/src/auth.rs +++ b/server/src/auth.rs @@ -257,7 +257,7 @@ async fn get_auth_callback(mut request: Request) -> tide::Result { // When signing in from the native app, generate a new access token for the current user. Return // a redirect so that the user's browser sends this access token to the locally-running app. if let Some((user, app_sign_in_params)) = user.zip(query.native_app_sign_in_params) { - let access_token = create_access_token(request.db(), user.id()).await?; + let access_token = create_access_token(request.db(), user.id).await?; let native_app_public_key = zed_auth::PublicKey::try_from(app_sign_in_params.native_app_public_key.clone()) .context("failed to parse app public key")?; @@ -267,9 +267,7 @@ async fn get_auth_callback(mut request: Request) -> tide::Result { return Ok(tide::Redirect::new(&format!( "http://127.0.0.1:{}?user_id={}&access_token={}", - app_sign_in_params.native_app_port, - user.id().0, - encrypted_access_token, + app_sign_in_params.native_app_port, user.id.0, encrypted_access_token, )) .into()); } diff --git a/server/src/db.rs b/server/src/db.rs index 300c8de6d5..de196766e1 100644 --- a/server/src/db.rs +++ b/server/src/db.rs @@ -1,5 +1,6 @@ use serde::Serialize; use sqlx::{FromRow, Result}; +use time::OffsetDateTime; pub use async_sqlx_session::PostgresSessionStore as SessionStore; pub use sqlx::postgres::PgPoolOptions as DbOptions; @@ -8,14 +9,14 @@ pub struct Db(pub sqlx::PgPool); #[derive(Debug, FromRow, Serialize)] pub struct User { - id: i32, + pub id: UserId, pub github_login: String, pub admin: bool, } #[derive(Debug, FromRow, Serialize)] pub struct Signup { - id: i32, + pub id: SignupId, pub github_login: String, pub email_address: String, pub about: String, @@ -23,33 +24,18 @@ pub struct Signup { #[derive(Debug, FromRow, Serialize)] pub struct Channel { - id: i32, + pub id: ChannelId, pub name: String, } #[derive(Debug, FromRow)] pub struct ChannelMessage { - id: i32, - sender_id: i32, - body: String, - sent_at: i64, + pub id: MessageId, + pub sender_id: UserId, + pub body: String, + pub sent_at: OffsetDateTime, } -#[derive(Clone, Copy)] -pub struct UserId(pub i32); - -#[derive(Clone, Copy)] -pub struct OrgId(pub i32); - -#[derive(Clone, Copy)] -pub struct ChannelId(pub i32); - -#[derive(Clone, Copy)] -pub struct SignupId(pub i32); - -#[derive(Clone, Copy)] -pub struct MessageId(pub i32); - impl Db { // signups @@ -108,6 +94,33 @@ impl Db { sqlx::query_as(query).fetch_all(&self.0).await } + pub async fn get_users_by_ids( + &self, + requester_id: UserId, + ids: impl Iterator, + ) -> Result> { + // Only return users that are in a common channel with the requesting user. + let query = " + SELECT users.* + FROM + users, channel_memberships + WHERE + users.id IN $1 AND + channel_memberships.user_id = users.id AND + channel_memberships.channel_id IN ( + SELECT channel_id + FROM channel_memberships + WHERE channel_memberships.user_id = $2 + ) + "; + + sqlx::query_as(query) + .bind(&ids.map(|id| id.0).collect::>()) + .bind(requester_id) + .fetch_all(&self.0) + .await + } + pub async fn get_user_by_github_login(&self, github_login: &str) -> Result> { let query = "SELECT * FROM users WHERE github_login = $1 LIMIT 1"; sqlx::query_as(query) @@ -147,7 +160,7 @@ impl Db { VALUES ($1, $2) "; sqlx::query(query) - .bind(user_id.0 as i32) + .bind(user_id.0) .bind(access_token_hash) .execute(&self.0) .await @@ -156,8 +169,8 @@ impl Db { pub async fn get_access_token_hashes(&self, user_id: UserId) -> Result> { let query = "SELECT hash FROM access_tokens WHERE user_id = $1"; - sqlx::query_scalar::<_, String>(query) - .bind(user_id.0 as i32) + sqlx::query_scalar(query) + .bind(user_id.0) .fetch_all(&self.0) .await } @@ -180,14 +193,20 @@ impl Db { } #[cfg(test)] - pub async fn add_org_member(&self, org_id: OrgId, user_id: UserId) -> Result<()> { + pub async fn add_org_member( + &self, + org_id: OrgId, + user_id: UserId, + is_admin: bool, + ) -> Result<()> { let query = " - INSERT INTO org_memberships (org_id, user_id) - VALUES ($1, $2) + INSERT INTO org_memberships (org_id, user_id, admin) + VALUES ($1, $2, $3) "; sqlx::query(query) .bind(org_id.0) .bind(user_id.0) + .bind(is_admin) .execute(&self.0) .await .map(drop) @@ -272,16 +291,18 @@ impl Db { channel_id: ChannelId, sender_id: UserId, body: &str, + timestamp: OffsetDateTime, ) -> Result { let query = " INSERT INTO channel_messages (channel_id, sender_id, body, sent_at) - VALUES ($1, $2, $3, NOW()::timestamp) + VALUES ($1, $2, $3, $4) RETURNING id "; sqlx::query_scalar(query) .bind(channel_id.0) .bind(sender_id.0) .bind(body) + .bind(timestamp) .fetch_one(&self.0) .await .map(MessageId) @@ -292,12 +313,15 @@ impl Db { channel_id: ChannelId, count: usize, ) -> Result> { - let query = " - SELECT id, sender_id, body, sent_at - FROM channel_messages - WHERE channel_id = $1 + let query = r#" + SELECT + id, sender_id, body, sent_at AT TIME ZONE 'UTC' as sent_at + FROM + channel_messages + WHERE + channel_id = $1 LIMIT $2 - "; + "#; sqlx::query_as(query) .bind(channel_id.0) .bind(count as i64) @@ -314,14 +338,29 @@ impl std::ops::Deref for Db { } } -impl Channel { - pub fn id(&self) -> ChannelId { - ChannelId(self.id) - } +macro_rules! id_type { + ($name:ident) => { + #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, sqlx::Type, Serialize)] + #[sqlx(transparent)] + #[serde(transparent)] + pub struct $name(pub i32); + + impl $name { + #[allow(unused)] + pub fn from_proto(value: u64) -> Self { + Self(value as i32) + } + + #[allow(unused)] + pub fn to_proto(&self) -> u64 { + self.0 as u64 + } + } + }; } -impl User { - pub fn id(&self) -> UserId { - UserId(self.id) - } -} +id_type!(UserId); +id_type!(OrgId); +id_type!(ChannelId); +id_type!(SignupId); +id_type!(MessageId); diff --git a/server/src/rpc.rs b/server/src/rpc.rs index f1ebf605a2..8696f03691 100644 --- a/server/src/rpc.rs +++ b/server/src/rpc.rs @@ -23,6 +23,7 @@ use tide::{ http::headers::{HeaderName, CONNECTION, UPGRADE}, Request, Response, }; +use time::OffsetDateTime; use zrpc::{ auth::random_token, proto::{self, EnvelopedMessage}, @@ -33,17 +34,19 @@ type ReplicaId = u16; #[derive(Default)] pub struct State { - connections: HashMap, - pub worktrees: HashMap, + connections: HashMap, + pub worktrees: HashMap, + channels: HashMap, next_worktree_id: u64, } -struct ConnectionState { +struct Connection { user_id: UserId, worktrees: HashSet, + channels: HashSet, } -pub struct WorktreeState { +pub struct Worktree { host_connection_id: Option, guest_connection_ids: HashMap, active_replica_ids: HashSet, @@ -52,7 +55,12 @@ pub struct WorktreeState { entries: HashMap, } -impl WorktreeState { +#[derive(Default)] +struct Channel { + connection_ids: HashSet, +} + +impl Worktree { pub fn connection_ids(&self) -> Vec { self.guest_connection_ids .keys() @@ -68,14 +76,21 @@ impl WorktreeState { } } +impl Channel { + fn connection_ids(&self) -> Vec { + self.connection_ids.iter().copied().collect() + } +} + impl State { // Add a new connection associated with a given user. pub fn add_connection(&mut self, connection_id: ConnectionId, user_id: UserId) { self.connections.insert( connection_id, - ConnectionState { + Connection { user_id, worktrees: Default::default(), + channels: Default::default(), }, ); } @@ -83,8 +98,13 @@ impl State { // Remove the given connection and its association with any worktrees. pub fn remove_connection(&mut self, connection_id: ConnectionId) -> Vec { let mut worktree_ids = Vec::new(); - if let Some(connection_state) = self.connections.remove(&connection_id) { - for worktree_id in connection_state.worktrees { + if let Some(connection) = self.connections.remove(&connection_id) { + for channel_id in connection.channels { + if let Some(channel) = self.channels.get_mut(&channel_id) { + channel.connection_ids.remove(&connection_id); + } + } + for worktree_id in connection.worktrees { if let Some(worktree) = self.worktrees.get_mut(&worktree_id) { if worktree.host_connection_id == Some(connection_id) { worktree_ids.push(worktree_id); @@ -100,28 +120,39 @@ impl State { worktree_ids } + fn join_channel(&mut self, connection_id: ConnectionId, channel_id: ChannelId) { + if let Some(connection) = self.connections.get_mut(&connection_id) { + connection.channels.insert(channel_id); + self.channels + .entry(channel_id) + .or_default() + .connection_ids + .insert(connection_id); + } + } + // Add the given connection as a guest of the given worktree pub fn join_worktree( &mut self, connection_id: ConnectionId, worktree_id: u64, access_token: &str, - ) -> Option<(ReplicaId, &WorktreeState)> { - if let Some(worktree_state) = self.worktrees.get_mut(&worktree_id) { - if access_token == worktree_state.access_token { - if let Some(connection_state) = self.connections.get_mut(&connection_id) { - connection_state.worktrees.insert(worktree_id); + ) -> Option<(ReplicaId, &Worktree)> { + if let Some(worktree) = self.worktrees.get_mut(&worktree_id) { + if access_token == worktree.access_token { + if let Some(connection) = self.connections.get_mut(&connection_id) { + connection.worktrees.insert(worktree_id); } let mut replica_id = 1; - while worktree_state.active_replica_ids.contains(&replica_id) { + while worktree.active_replica_ids.contains(&replica_id) { replica_id += 1; } - worktree_state.active_replica_ids.insert(replica_id); - worktree_state + worktree.active_replica_ids.insert(replica_id); + worktree .guest_connection_ids .insert(connection_id, replica_id); - Some((replica_id, worktree_state)) + Some((replica_id, worktree)) } else { None } @@ -142,7 +173,7 @@ impl State { &self, worktree_id: u64, connection_id: ConnectionId, - ) -> tide::Result<&WorktreeState> { + ) -> tide::Result<&Worktree> { let worktree = self .worktrees .get(&worktree_id) @@ -165,7 +196,7 @@ impl State { &mut self, worktree_id: u64, connection_id: ConnectionId, - ) -> tide::Result<&mut WorktreeState> { + ) -> tide::Result<&mut Worktree> { let worktree = self .worktrees .get_mut(&worktree_id) @@ -263,7 +294,9 @@ pub fn add_rpc_routes(router: &mut Router, state: &Arc, rpc: &Arc>, rpc: &Arc) { @@ -373,7 +406,7 @@ async fn share_worktree( .collect(); state.worktrees.insert( worktree_id, - WorktreeState { + Worktree { host_connection_id: Some(request.sender_id), guest_connection_ids: Default::default(), active_replica_ids: Default::default(), @@ -627,7 +660,7 @@ async fn get_channels( channels: channels .into_iter() .map(|chan| proto::Channel { - id: chan.id().0 as u64, + id: chan.id.to_proto(), name: chan.name, }) .collect(), @@ -637,6 +670,34 @@ async fn get_channels( Ok(()) } +async fn get_users( + request: TypedEnvelope, + rpc: &Arc, + state: &Arc, +) -> tide::Result<()> { + let user_id = state + .rpc + .read() + .await + .user_id_for_connection(request.sender_id)?; + let receipt = request.receipt(); + let user_ids = request.payload.user_ids.into_iter().map(UserId::from_proto); + let users = state + .db + .get_users_by_ids(user_id, user_ids) + .await? + .into_iter() + .map(|user| proto::User { + id: user.id.to_proto(), + github_login: user.github_login, + avatar_url: String::new(), + }) + .collect(); + rpc.respond(receipt, proto::GetUsersResponse { users }) + .await?; + Ok(()) +} + async fn join_channel( request: TypedEnvelope, rpc: &Arc, @@ -647,14 +708,74 @@ async fn join_channel( .read() .await .user_id_for_connection(request.sender_id)?; + let channel_id = ChannelId::from_proto(request.payload.channel_id); if !state .db - .can_user_access_channel(user_id, ChannelId(request.payload.channel_id as i32)) + .can_user_access_channel(user_id, channel_id) .await? { Err(anyhow!("access denied"))?; } + state + .rpc + .write() + .await + .join_channel(request.sender_id, channel_id); + let messages = state + .db + .get_recent_channel_messages(channel_id, 50) + .await? + .into_iter() + .map(|msg| proto::ChannelMessage { + id: msg.id.to_proto(), + body: msg.body, + timestamp: msg.sent_at.unix_timestamp() as u64, + sender_id: msg.sender_id.to_proto(), + }) + .collect(); + rpc.respond(request.receipt(), proto::JoinChannelResponse { messages }) + .await?; + Ok(()) +} + +async fn send_channel_message( + request: TypedEnvelope, + peer: &Arc, + app: &Arc, +) -> tide::Result<()> { + let channel_id = ChannelId::from_proto(request.payload.channel_id); + let user_id; + let connection_ids; + { + let state = app.rpc.read().await; + user_id = state.user_id_for_connection(request.sender_id)?; + if let Some(channel) = state.channels.get(&channel_id) { + connection_ids = channel.connection_ids(); + } else { + return Ok(()); + } + } + + let timestamp = OffsetDateTime::now_utc(); + let message_id = app + .db + .create_channel_message(channel_id, user_id, &request.payload.body, timestamp) + .await?; + let message = proto::ChannelMessageSent { + channel_id: channel_id.to_proto(), + message: Some(proto::ChannelMessage { + sender_id: user_id.to_proto(), + id: message_id.to_proto(), + body: request.payload.body, + timestamp: timestamp.unix_timestamp() as u64, + }), + }; + broadcast(request.sender_id, connection_ids, |conn_id| { + peer.send(conn_id, message.clone()) + }) + .await?; + Ok(()) } diff --git a/server/src/tests.rs b/server/src/tests.rs index 653e2ae59a..86980e8673 100644 --- a/server/src/tests.rs +++ b/server/src/tests.rs @@ -11,9 +11,10 @@ use rand::prelude::*; use serde_json::json; use sqlx::{ migrate::{MigrateDatabase, Migrator}, + types::time::OffsetDateTime, Executor as _, Postgres, }; -use std::{path::Path, sync::Arc}; +use std::{path::Path, sync::Arc, time::SystemTime}; use zed::{ editor::Editor, fs::{FakeFs, Fs as _}, @@ -485,10 +486,15 @@ async fn test_basic_chat(mut cx_a: TestAppContext, cx_b: TestAppContext) { let (user_id_a, client_a) = server.create_client(&mut cx_a, "user_a").await; let (user_id_b, client_b) = server.create_client(&mut cx_a, "user_b").await; - // Create a channel that includes these 2 users and 1 other user. + // Create an org that includes these 2 users and 1 other user. let db = &server.app_state.db; let user_id_c = db.create_user("user_c", false).await.unwrap(); let org_id = db.create_org("Test Org", "test-org").await.unwrap(); + db.add_org_member(org_id, user_id_a, false).await.unwrap(); + db.add_org_member(org_id, user_id_b, false).await.unwrap(); + db.add_org_member(org_id, user_id_c, false).await.unwrap(); + + // Create a channel that includes all the users. let channel_id = db.create_org_channel(org_id, "test-channel").await.unwrap(); db.add_channel_member(channel_id, user_id_a, false) .await @@ -499,11 +505,21 @@ async fn test_basic_chat(mut cx_a: TestAppContext, cx_b: TestAppContext) { db.add_channel_member(channel_id, user_id_c, false) .await .unwrap(); - db.create_channel_message(channel_id, user_id_c, "first message!") - .await - .unwrap(); - - // let chatroom_a = ChatRoom:: + db.create_channel_message( + channel_id, + user_id_c, + "first message!", + OffsetDateTime::now_utc(), + ) + .await + .unwrap(); + assert_eq!( + db.get_recent_channel_messages(channel_id, 50) + .await + .unwrap()[0] + .body, + "first message!" + ); } struct TestServer { diff --git a/zrpc/proto/zed.proto b/zrpc/proto/zed.proto index 65360ae885..7d8e3ff742 100644 --- a/zrpc/proto/zed.proto +++ b/zrpc/proto/zed.proto @@ -24,10 +24,10 @@ message Envelope { RemovePeer remove_peer = 19; GetChannels get_channels = 20; GetChannelsResponse get_channels_response = 21; - JoinChannel join_channel = 22; - JoinChannelResponse join_channel_response = 23; - GetUsers get_users = 24; - GetUsersResponse get_users_response = 25; + GetUsers get_users = 22; + GetUsersResponse get_users_response = 23; + JoinChannel join_channel = 24; + JoinChannelResponse join_channel_response = 25; SendChannelMessage send_channel_message = 26; ChannelMessageSent channel_message_sent = 27; }