Maintain server state consistency when removing a connection

Co-Authored-By: Nathan Sobo <nathan@zed.dev>
This commit is contained in:
Max Brunsfeld 2021-09-20 15:45:33 -07:00
parent 8de9c362c9
commit 8f578e7521

View file

@ -3,6 +3,7 @@ use super::{
db::{ChannelId, MessageId, UserId},
AppState,
};
use crate::errors::TideResultExt;
use anyhow::anyhow;
use async_std::{sync::RwLock, task};
use async_tungstenite::{tungstenite::protocol::Role, WebSocketStream};
@ -49,7 +50,7 @@ pub struct Server {
struct ServerState {
connections: HashMap<ConnectionId, ConnectionState>,
connections_by_user_id: HashMap<UserId, HashSet<ConnectionId>>,
pub worktrees: HashMap<u64, Worktree>,
worktrees: HashMap<u64, Worktree>,
visible_worktrees_by_user_id: HashMap<UserId, HashSet<u64>>,
channels: HashMap<ChannelId, Channel>,
next_worktree_id: u64,
@ -707,15 +708,19 @@ impl Server {
{
let worktree = &state.worktrees[worktree_id];
let mut participants = HashSet::new();
let mut guests = HashSet::new();
if let Ok(share) = worktree.share() {
for guest_connection_id in share.guest_connection_ids.keys() {
let user_id = state.user_id_for_connection(*guest_connection_id)?;
participants.insert(user_id.to_proto());
let user_id = state
.user_id_for_connection(*guest_connection_id)
.context("stale worktree guest connection")?;
guests.insert(user_id.to_proto());
}
}
let host_user_id = state.user_id_for_connection(worktree.host_connection_id)?;
let host_user_id = state
.user_id_for_connection(worktree.host_connection_id)
.context("stale worktree host connection")?;
let host =
collaborators
.entry(host_user_id)
@ -726,7 +731,7 @@ impl Server {
host.worktrees.push(proto::WorktreeMetadata {
root_name: worktree.root_name.clone(),
is_shared: worktree.share().is_ok(),
participants: participants.into_iter().collect(),
participants: guests.into_iter().collect(),
});
}
@ -1137,7 +1142,14 @@ impl ServerState {
.insert(worktree_id);
}
self.next_worktree_id += 1;
if let Some(connection) = self.connections.get_mut(&worktree.host_connection_id) {
connection.worktrees.insert(worktree_id);
}
self.worktrees.insert(worktree_id, worktree);
#[cfg(test)]
self.check_invariants();
worktree_id
}
@ -1161,6 +1173,89 @@ impl ServerState {
visible_worktrees.remove(&worktree_id);
}
}
#[cfg(test)]
self.check_invariants();
}
#[cfg(test)]
fn check_invariants(&self) {
for (connection_id, connection) in &self.connections {
for worktree_id in &connection.worktrees {
let worktree = &self.worktrees.get(&worktree_id).unwrap();
if worktree.host_connection_id != *connection_id {
assert!(worktree
.share()
.unwrap()
.guest_connection_ids
.contains_key(connection_id));
}
}
for channel_id in &connection.channels {
let channel = self.channels.get(channel_id).unwrap();
assert!(channel.connection_ids.contains(connection_id));
}
assert!(self
.connections_by_user_id
.get(&connection.user_id)
.unwrap()
.contains(connection_id));
}
for (user_id, connection_ids) in &self.connections_by_user_id {
for connection_id in connection_ids {
assert_eq!(
self.connections.get(connection_id).unwrap().user_id,
*user_id
);
}
}
for (worktree_id, worktree) in &self.worktrees {
let host_connection = self.connections.get(&worktree.host_connection_id).unwrap();
assert!(host_connection.worktrees.contains(worktree_id));
for collaborator_id in &worktree.collaborator_user_ids {
let visible_worktree_ids = self
.visible_worktrees_by_user_id
.get(collaborator_id)
.unwrap();
assert!(visible_worktree_ids.contains(worktree_id));
}
if let Some(share) = &worktree.share {
for guest_connection_id in share.guest_connection_ids.keys() {
let guest_connection = self.connections.get(guest_connection_id).unwrap();
assert!(guest_connection.worktrees.contains(worktree_id));
}
assert_eq!(
share.active_replica_ids.len(),
share.guest_connection_ids.len(),
);
assert_eq!(
share.active_replica_ids,
share
.guest_connection_ids
.values()
.copied()
.collect::<HashSet<_>>(),
);
}
}
for (user_id, visible_worktree_ids) in &self.visible_worktrees_by_user_id {
for worktree_id in visible_worktree_ids {
let worktree = self.worktrees.get(worktree_id).unwrap();
assert!(worktree.collaborator_user_ids.contains(user_id));
}
}
for (channel_id, channel) in &self.channels {
for connection_id in &channel.connection_ids {
let connection = self.connections.get(connection_id).unwrap();
assert!(connection.channels.contains(channel_id));
}
}
}
}