This commit is contained in:
Max Brunsfeld
2021-09-20 18:05:46 -07:00
parent 8b1a2c8cd2
commit 5dfd4be174
5 changed files with 752 additions and 573 deletions

View File

@@ -1,3 +1,5 @@
mod store;
use super::{
auth,
db::{ChannelId, MessageId, UserId},
@@ -8,16 +10,17 @@ use anyhow::anyhow;
use async_std::{sync::RwLock, task};
use async_tungstenite::{tungstenite::protocol::Role, WebSocketStream};
use futures::{future::BoxFuture, FutureExt};
use postage::{mpsc, prelude::Sink as _, prelude::Stream as _};
use postage::{broadcast, mpsc, prelude::Sink as _, prelude::Stream as _};
use sha1::{Digest as _, Sha1};
use std::{
any::TypeId,
collections::{hash_map, HashMap, HashSet},
collections::{HashMap, HashSet},
future::Future,
mem,
sync::Arc,
time::Instant,
};
use store::{ReplicaId, Store, Worktree};
use surf::StatusCode;
use tide::log;
use tide::{
@@ -30,8 +33,6 @@ use zrpc::{
Connection, ConnectionId, Peer, TypedEnvelope,
};
type ReplicaId = u16;
type MessageHandler = Box<
dyn Send
+ Sync
@@ -40,46 +41,12 @@ type MessageHandler = Box<
pub struct Server {
peer: Arc<Peer>,
state: RwLock<ServerState>,
store: RwLock<Store>,
app_state: Arc<AppState>,
handlers: HashMap<TypeId, MessageHandler>,
notifications: Option<mpsc::Sender<()>>,
}
#[derive(Default)]
struct ServerState {
connections: HashMap<ConnectionId, ConnectionState>,
connections_by_user_id: HashMap<UserId, HashSet<ConnectionId>>,
worktrees: HashMap<u64, Worktree>,
visible_worktrees_by_user_id: HashMap<UserId, HashSet<u64>>,
channels: HashMap<ChannelId, Channel>,
next_worktree_id: u64,
}
struct ConnectionState {
user_id: UserId,
worktrees: HashSet<u64>,
channels: HashSet<ChannelId>,
}
struct Worktree {
host_connection_id: ConnectionId,
collaborator_user_ids: Vec<UserId>,
root_name: String,
share: Option<WorktreeShare>,
}
struct WorktreeShare {
guest_connection_ids: HashMap<ConnectionId, ReplicaId>,
active_replica_ids: HashSet<ReplicaId>,
entries: HashMap<u64, proto::Entry>,
}
#[derive(Default)]
struct Channel {
connection_ids: HashSet<ConnectionId>,
}
const MESSAGE_COUNT_PER_PAGE: usize = 100;
const MAX_MESSAGE_LEN: usize = 1024;
@@ -92,7 +59,7 @@ impl Server {
let mut server = Self {
peer,
app_state,
state: Default::default(),
store: Default::default(),
handlers: Default::default(),
notifications,
};
@@ -100,7 +67,7 @@ impl Server {
server
.add_handler(Server::ping)
.add_handler(Server::open_worktree)
.add_handler(Server::handle_close_worktree)
.add_handler(Server::close_worktree)
.add_handler(Server::share_worktree)
.add_handler(Server::unshare_worktree)
.add_handler(Server::join_worktree)
@@ -149,7 +116,10 @@ impl Server {
async move {
let (connection_id, handle_io, mut incoming_rx) =
this.peer.add_connection(connection).await;
this.add_connection(connection_id, user_id).await;
this.store
.write()
.await
.add_connection(connection_id, user_id);
if let Err(err) = this.update_collaborators_for_users(&[user_id]).await {
log::error!("error updating collaborators for {:?}: {}", user_id, err);
}
@@ -197,61 +167,40 @@ impl Server {
}
}
async fn sign_out(self: &Arc<Self>, connection_id: zrpc::ConnectionId) -> tide::Result<()> {
async fn sign_out(self: &Arc<Self>, connection_id: ConnectionId) -> tide::Result<()> {
self.peer.disconnect(connection_id).await;
self.remove_connection(connection_id).await?;
Ok(())
}
let removed_connection = self.store.write().await.remove_connection(connection_id)?;
// Add a new connection associated with a given user.
async fn add_connection(&self, connection_id: ConnectionId, user_id: UserId) {
let mut state = self.state.write().await;
state.connections.insert(
connection_id,
ConnectionState {
user_id,
worktrees: Default::default(),
channels: Default::default(),
},
);
state
.connections_by_user_id
.entry(user_id)
.or_default()
.insert(connection_id);
}
// Remove the given connection and its association with any worktrees.
async fn remove_connection(
self: &Arc<Server>,
connection_id: ConnectionId,
) -> tide::Result<()> {
let mut worktree_ids = Vec::new();
let mut state = self.state.write().await;
if let Some(connection) = state.connections.remove(&connection_id) {
worktree_ids = connection.worktrees.into_iter().collect();
for channel_id in connection.channels {
if let Some(channel) = state.channels.get_mut(&channel_id) {
channel.connection_ids.remove(&connection_id);
}
}
let user_connections = state
.connections_by_user_id
.get_mut(&connection.user_id)
.unwrap();
user_connections.remove(&connection_id);
if user_connections.is_empty() {
state.connections_by_user_id.remove(&connection.user_id);
for (worktree_id, worktree) in removed_connection.hosted_worktrees {
if let Some(share) = worktree.share {
broadcast(
connection_id,
share.guest_connection_ids.keys().copied().collect(),
|conn_id| {
self.peer
.send(conn_id, proto::UnshareWorktree { worktree_id })
},
)
.await?;
}
}
drop(state);
for worktree_id in worktree_ids {
self.close_worktree(worktree_id, connection_id).await?;
for (worktree_id, peer_ids) in removed_connection.guest_worktree_ids {
broadcast(connection_id, peer_ids, |conn_id| {
self.peer.send(
conn_id,
proto::RemovePeer {
worktree_id,
peer_id: connection_id.0,
},
)
})
.await?;
}
self.update_collaborators_for_users(removed_connection.collaborator_ids.iter())
.await;
Ok(())
}
@@ -266,7 +215,7 @@ impl Server {
) -> tide::Result<()> {
let receipt = request.receipt();
let host_user_id = self
.state
.store
.read()
.await
.user_id_for_connection(request.sender_id)?;
@@ -289,7 +238,7 @@ impl Server {
}
let collaborator_user_ids = collaborator_user_ids.into_iter().collect::<Vec<_>>();
let worktree_id = self.state.write().await.add_worktree(Worktree {
let worktree_id = self.store.write().await.add_worktree(Worktree {
host_connection_id: request.sender_id,
collaborator_user_ids: collaborator_user_ids.clone(),
root_name: request.payload.root_name,
@@ -305,6 +254,33 @@ impl Server {
Ok(())
}
async fn close_worktree(
self: Arc<Server>,
request: TypedEnvelope<proto::CloseWorktree>,
) -> tide::Result<()> {
let worktree_id = request.payload.worktree_id;
let worktree = self
.store
.write()
.await
.remove_worktree(worktree_id, request.sender_id)?;
if let Some(share) = worktree.share {
broadcast(
request.sender_id,
share.guest_connection_ids.keys().copied().collect(),
|conn_id| {
self.peer
.send(conn_id, proto::UnshareWorktree { worktree_id })
},
)
.await?;
}
self.update_collaborators_for_users(&worktree.collaborator_user_ids)
.await?;
Ok(())
}
async fn share_worktree(
self: Arc<Server>,
mut request: TypedEnvelope<proto::ShareWorktree>,
@@ -319,16 +295,12 @@ impl Server {
.map(|entry| (entry.id, entry))
.collect();
let mut state = self.state.write().await;
if let Some(worktree) = state.worktrees.get_mut(&worktree.id) {
worktree.share = Some(WorktreeShare {
guest_connection_ids: Default::default(),
active_replica_ids: Default::default(),
entries,
});
let collaborator_user_ids = worktree.collaborator_user_ids.clone();
drop(state);
if let Some(collaborator_user_ids) =
self.store
.write()
.await
.share_worktree(worktree.id, request.sender_id, entries)
{
self.peer
.respond(request.receipt(), proto::ShareWorktreeResponse {})
.await?;
@@ -352,26 +324,11 @@ impl Server {
request: TypedEnvelope<proto::UnshareWorktree>,
) -> tide::Result<()> {
let worktree_id = request.payload.worktree_id;
let connection_ids;
let collaborator_user_ids;
{
let mut state = self.state.write().await;
let worktree = state.write_worktree(worktree_id, request.sender_id)?;
if worktree.host_connection_id != request.sender_id {
return Err(anyhow!("no such worktree"))?;
}
connection_ids = worktree.connection_ids();
collaborator_user_ids = worktree.collaborator_user_ids.clone();
worktree.share.take();
for connection_id in &connection_ids {
if let Some(connection) = state.connections.get_mut(connection_id) {
connection.worktrees.remove(&worktree_id);
}
}
}
let (connection_ids, collaborator_user_ids) = self
.store
.write()
.await
.unshare_worktree(worktree_id, request.sender_id)?;
broadcast(request.sender_id, connection_ids, |conn_id| {
self.peer
@@ -390,7 +347,7 @@ impl Server {
) -> tide::Result<()> {
let worktree_id = request.payload.worktree_id;
let user_id = self
.state
.store
.read()
.await
.user_id_for_connection(request.sender_id)?;
@@ -398,7 +355,7 @@ impl Server {
let response;
let connection_ids;
let collaborator_user_ids;
let mut state = self.state.write().await;
let mut state = self.store.write().await;
match state.join_worktree(request.sender_id, user_id, worktree_id) {
Ok((peer_replica_id, worktree)) => {
let share = worktree.share()?;
@@ -462,48 +419,17 @@ impl Server {
Ok(())
}
async fn handle_close_worktree(
self: Arc<Server>,
request: TypedEnvelope<proto::CloseWorktree>,
) -> tide::Result<()> {
self.close_worktree(request.payload.worktree_id, request.sender_id)
.await
}
async fn close_worktree(
async fn leave_worktree(
self: &Arc<Server>,
worktree_id: u64,
sender_conn_id: ConnectionId,
) -> tide::Result<()> {
let connection_ids;
let collaborator_user_ids;
let mut is_host = false;
let mut is_guest = false;
if let Some((connection_ids, collaborator_ids)) = self
.store
.write()
.await
.leave_worktree(sender_conn_id, worktree_id)
{
let mut state = self.state.write().await;
let worktree = state.write_worktree(worktree_id, sender_conn_id)?;
connection_ids = worktree.connection_ids();
collaborator_user_ids = worktree.collaborator_user_ids.clone();
if worktree.host_connection_id == sender_conn_id {
is_host = true;
state.remove_worktree(worktree_id);
} else {
let share = worktree.share_mut()?;
if let Some(replica_id) = share.guest_connection_ids.remove(&sender_conn_id) {
is_guest = true;
share.active_replica_ids.remove(&replica_id);
}
}
}
if is_host {
broadcast(sender_conn_id, connection_ids, |conn_id| {
self.peer
.send(conn_id, proto::UnshareWorktree { worktree_id })
})
.await?;
} else if is_guest {
broadcast(sender_conn_id, connection_ids, |conn_id| {
self.peer.send(
conn_id,
@@ -513,10 +439,10 @@ impl Server {
},
)
})
.await?
}
self.update_collaborators_for_users(&collaborator_user_ids)
.await?;
self.update_collaborators_for_users(&collaborator_ids)
.await?;
}
Ok(())
}
@@ -524,22 +450,19 @@ impl Server {
self: Arc<Server>,
request: TypedEnvelope<proto::UpdateWorktree>,
) -> tide::Result<()> {
{
let mut state = self.state.write().await;
let worktree = state.write_worktree(request.payload.worktree_id, request.sender_id)?;
let share = worktree.share_mut()?;
let connection_ids = self.store.write().await.update_worktree(
request.sender_id,
request.payload.worktree_id,
&request.payload.removed_entries,
&request.payload.updated_entries,
)?;
for entry_id in &request.payload.removed_entries {
share.entries.remove(&entry_id);
}
broadcast(request.sender_id, connection_ids, |connection_id| {
self.peer
.forward_send(request.sender_id, connection_id, request.payload.clone())
})
.await?;
for entry in &request.payload.updated_entries {
share.entries.insert(entry.id, entry.clone());
}
}
self.broadcast_in_worktree(request.payload.worktree_id, &request)
.await?;
Ok(())
}
@@ -548,14 +471,11 @@ impl Server {
request: TypedEnvelope<proto::OpenBuffer>,
) -> tide::Result<()> {
let receipt = request.receipt();
let worktree_id = request.payload.worktree_id;
let host_connection_id = self
.state
.store
.read()
.await
.read_worktree(worktree_id, request.sender_id)?
.host_connection_id;
.worktree_host_connection_id(request.sender_id, request.payload.worktree_id)?;
let response = self
.peer
.forward_request(request.sender_id, host_connection_id, request.payload)
@@ -569,16 +489,13 @@ impl Server {
request: TypedEnvelope<proto::CloseBuffer>,
) -> tide::Result<()> {
let host_connection_id = self
.state
.store
.read()
.await
.read_worktree(request.payload.worktree_id, request.sender_id)?
.host_connection_id;
.worktree_host_connection_id(request.sender_id, request.payload.worktree_id)?;
self.peer
.forward_send(request.sender_id, host_connection_id, request.payload)
.await?;
Ok(())
}
@@ -589,15 +506,11 @@ impl Server {
let host;
let guests;
{
let state = self.state.read().await;
let worktree = state.read_worktree(request.payload.worktree_id, request.sender_id)?;
host = worktree.host_connection_id;
guests = worktree
.share()?
.guest_connection_ids
.keys()
.copied()
.collect::<Vec<_>>();
let state = self.store.read().await;
host = state
.worktree_host_connection_id(request.sender_id, request.payload.worktree_id)?;
guests = state
.worktree_guest_connection_ids(request.sender_id, request.payload.worktree_id)?;
}
let sender = request.sender_id;
@@ -627,8 +540,18 @@ impl Server {
self: Arc<Server>,
request: TypedEnvelope<proto::UpdateBuffer>,
) -> tide::Result<()> {
self.broadcast_in_worktree(request.payload.worktree_id, &request)
.await?;
broadcast(
request.sender_id,
self.store
.read()
.await
.worktree_connection_ids(request.sender_id, request.payload.worktree_id)?,
|connection_id| {
self.peer
.forward_send(request.sender_id, connection_id, request.payload.clone())
},
)
.await?;
self.peer.respond(request.receipt(), proto::Ack {}).await?;
Ok(())
}
@@ -637,8 +560,19 @@ impl Server {
self: Arc<Server>,
request: TypedEnvelope<proto::BufferSaved>,
) -> tide::Result<()> {
self.broadcast_in_worktree(request.payload.worktree_id, &request)
.await
broadcast(
request.sender_id,
self.store
.read()
.await
.worktree_connection_ids(request.sender_id, request.payload.worktree_id)?,
|connection_id| {
self.peer
.forward_send(request.sender_id, connection_id, request.payload.clone())
},
)
.await?;
Ok(())
}
async fn get_channels(
@@ -646,7 +580,7 @@ impl Server {
request: TypedEnvelope<proto::GetChannels>,
) -> tide::Result<()> {
let user_id = self
.state
.store
.read()
.await
.user_id_for_connection(request.sender_id)?;
@@ -698,45 +632,10 @@ impl Server {
) -> tide::Result<()> {
let mut send_futures = Vec::new();
let state = self.state.read().await;
let state = self.store.read().await;
for user_id in user_ids {
let mut collaborators = HashMap::new();
for worktree_id in state
.visible_worktrees_by_user_id
.get(&user_id)
.unwrap_or(&HashSet::new())
{
let worktree = &state.worktrees[worktree_id];
let mut guests = HashSet::new();
if let Ok(share) = worktree.share() {
for guest_connection_id in share.guest_connection_ids.keys() {
let user_id = state
.user_id_for_connection(*guest_connection_id)
.context("stale worktree guest connection")?;
guests.insert(user_id.to_proto());
}
}
let host_user_id = state
.user_id_for_connection(worktree.host_connection_id)
.context("stale worktree host connection")?;
let host =
collaborators
.entry(host_user_id)
.or_insert_with(|| proto::Collaborator {
user_id: host_user_id.to_proto(),
worktrees: Vec::new(),
});
host.worktrees.push(proto::WorktreeMetadata {
root_name: worktree.root_name.clone(),
is_shared: worktree.share().is_ok(),
participants: guests.into_iter().collect(),
});
}
let collaborators = collaborators.into_values().collect::<Vec<_>>();
for connection_id in state.user_connection_ids(*user_id) {
let collaborators = state.collaborators_for_user(*user_id);
for connection_id in state.connection_ids_for_user(*user_id) {
send_futures.push(self.peer.send(
connection_id,
proto::UpdateCollaborators {
@@ -757,7 +656,7 @@ impl Server {
request: TypedEnvelope<proto::JoinChannel>,
) -> tide::Result<()> {
let user_id = self
.state
.store
.read()
.await
.user_id_for_connection(request.sender_id)?;
@@ -771,7 +670,7 @@ impl Server {
Err(anyhow!("access denied"))?;
}
self.state
self.store
.write()
.await
.join_channel(request.sender_id, channel_id);
@@ -806,7 +705,7 @@ impl Server {
request: TypedEnvelope<proto::LeaveChannel>,
) -> tide::Result<()> {
let user_id = self
.state
.store
.read()
.await
.user_id_for_connection(request.sender_id)?;
@@ -820,7 +719,7 @@ impl Server {
Err(anyhow!("access denied"))?;
}
self.state
self.store
.write()
.await
.leave_channel(request.sender_id, channel_id);
@@ -837,10 +736,10 @@ impl Server {
let user_id;
let connection_ids;
{
let state = self.state.read().await;
let state = self.store.read().await;
user_id = state.user_id_for_connection(request.sender_id)?;
if let Some(channel) = state.channels.get(&channel_id) {
connection_ids = channel.connection_ids();
if let Some(ids) = state.channel_connection_ids(channel_id) {
connection_ids = ids;
} else {
return Ok(());
}
@@ -925,7 +824,7 @@ impl Server {
request: TypedEnvelope<proto::GetChannelMessages>,
) -> tide::Result<()> {
let user_id = self
.state
.store
.read()
.await
.user_id_for_connection(request.sender_id)?;
@@ -968,27 +867,6 @@ impl Server {
.await?;
Ok(())
}
async fn broadcast_in_worktree<T: proto::EnvelopedMessage>(
&self,
worktree_id: u64,
message: &TypedEnvelope<T>,
) -> tide::Result<()> {
let connection_ids = self
.state
.read()
.await
.read_worktree(worktree_id, message.sender_id)?
.connection_ids();
broadcast(message.sender_id, connection_ids, |conn_id| {
self.peer
.forward_send(message.sender_id, conn_id, message.payload.clone())
})
.await?;
Ok(())
}
}
pub async fn broadcast<F, T>(
@@ -1008,292 +886,6 @@ where
Ok(())
}
impl ServerState {
fn join_channel(&mut self, connection_id: ConnectionId, channel_id: ChannelId) {
if let Some(connection) = self.connections.get_mut(&connection_id) {
connection.channels.insert(channel_id);
self.channels
.entry(channel_id)
.or_default()
.connection_ids
.insert(connection_id);
}
}
fn leave_channel(&mut self, connection_id: ConnectionId, channel_id: ChannelId) {
if let Some(connection) = self.connections.get_mut(&connection_id) {
connection.channels.remove(&channel_id);
if let hash_map::Entry::Occupied(mut entry) = self.channels.entry(channel_id) {
entry.get_mut().connection_ids.remove(&connection_id);
if entry.get_mut().connection_ids.is_empty() {
entry.remove();
}
}
}
}
fn user_id_for_connection(&self, connection_id: ConnectionId) -> tide::Result<UserId> {
Ok(self
.connections
.get(&connection_id)
.ok_or_else(|| anyhow!("unknown connection"))?
.user_id)
}
fn user_connection_ids<'a>(
&'a self,
user_id: UserId,
) -> impl 'a + Iterator<Item = ConnectionId> {
self.connections_by_user_id
.get(&user_id)
.into_iter()
.flatten()
.copied()
}
// Add the given connection as a guest of the given worktree
fn join_worktree(
&mut self,
connection_id: ConnectionId,
user_id: UserId,
worktree_id: u64,
) -> tide::Result<(ReplicaId, &Worktree)> {
let connection = self
.connections
.get_mut(&connection_id)
.ok_or_else(|| anyhow!("no such connection"))?;
let worktree = self
.worktrees
.get_mut(&worktree_id)
.ok_or_else(|| anyhow!("no such worktree"))?;
if !worktree.collaborator_user_ids.contains(&user_id) {
Err(anyhow!("no such worktree"))?;
}
let share = worktree.share_mut()?;
connection.worktrees.insert(worktree_id);
let mut replica_id = 1;
while share.active_replica_ids.contains(&replica_id) {
replica_id += 1;
}
share.active_replica_ids.insert(replica_id);
share.guest_connection_ids.insert(connection_id, replica_id);
return Ok((replica_id, worktree));
}
fn read_worktree(
&self,
worktree_id: u64,
connection_id: ConnectionId,
) -> tide::Result<&Worktree> {
let worktree = self
.worktrees
.get(&worktree_id)
.ok_or_else(|| anyhow!("worktree not found"))?;
if worktree.host_connection_id == connection_id
|| worktree
.share()?
.guest_connection_ids
.contains_key(&connection_id)
{
Ok(worktree)
} else {
Err(anyhow!(
"{} is not a member of worktree {}",
connection_id,
worktree_id
))?
}
}
fn write_worktree(
&mut self,
worktree_id: u64,
connection_id: ConnectionId,
) -> tide::Result<&mut Worktree> {
let worktree = self
.worktrees
.get_mut(&worktree_id)
.ok_or_else(|| anyhow!("worktree not found"))?;
if worktree.host_connection_id == connection_id
|| worktree.share.as_ref().map_or(false, |share| {
share.guest_connection_ids.contains_key(&connection_id)
})
{
Ok(worktree)
} else {
Err(anyhow!(
"{} is not a member of worktree {}",
connection_id,
worktree_id
))?
}
}
fn add_worktree(&mut self, worktree: Worktree) -> u64 {
let worktree_id = self.next_worktree_id;
for collaborator_user_id in &worktree.collaborator_user_ids {
self.visible_worktrees_by_user_id
.entry(*collaborator_user_id)
.or_default()
.insert(worktree_id);
}
self.next_worktree_id += 1;
if let Some(connection) = self.connections.get_mut(&worktree.host_connection_id) {
connection.worktrees.insert(worktree_id);
}
self.worktrees.insert(worktree_id, worktree);
#[cfg(test)]
self.check_invariants();
worktree_id
}
fn remove_worktree(&mut self, worktree_id: u64) {
let worktree = self.worktrees.remove(&worktree_id).unwrap();
if let Some(connection) = self.connections.get_mut(&worktree.host_connection_id) {
connection.worktrees.remove(&worktree_id);
}
if let Some(share) = worktree.share {
for connection_id in share.guest_connection_ids.keys() {
if let Some(connection) = self.connections.get_mut(connection_id) {
connection.worktrees.remove(&worktree_id);
}
}
}
for collaborator_user_id in worktree.collaborator_user_ids {
if let Some(visible_worktrees) = self
.visible_worktrees_by_user_id
.get_mut(&collaborator_user_id)
{
visible_worktrees.remove(&worktree_id);
}
}
#[cfg(test)]
self.check_invariants();
}
#[cfg(test)]
fn check_invariants(&self) {
for (connection_id, connection) in &self.connections {
for worktree_id in &connection.worktrees {
let worktree = &self.worktrees.get(&worktree_id).unwrap();
if worktree.host_connection_id != *connection_id {
assert!(worktree
.share()
.unwrap()
.guest_connection_ids
.contains_key(connection_id));
}
}
for channel_id in &connection.channels {
let channel = self.channels.get(channel_id).unwrap();
assert!(channel.connection_ids.contains(connection_id));
}
assert!(self
.connections_by_user_id
.get(&connection.user_id)
.unwrap()
.contains(connection_id));
}
for (user_id, connection_ids) in &self.connections_by_user_id {
for connection_id in connection_ids {
assert_eq!(
self.connections.get(connection_id).unwrap().user_id,
*user_id
);
}
}
for (worktree_id, worktree) in &self.worktrees {
let host_connection = self.connections.get(&worktree.host_connection_id).unwrap();
assert!(host_connection.worktrees.contains(worktree_id));
for collaborator_id in &worktree.collaborator_user_ids {
let visible_worktree_ids = self
.visible_worktrees_by_user_id
.get(collaborator_id)
.unwrap();
assert!(visible_worktree_ids.contains(worktree_id));
}
if let Some(share) = &worktree.share {
for guest_connection_id in share.guest_connection_ids.keys() {
let guest_connection = self.connections.get(guest_connection_id).unwrap();
assert!(guest_connection.worktrees.contains(worktree_id));
}
assert_eq!(
share.active_replica_ids.len(),
share.guest_connection_ids.len(),
);
assert_eq!(
share.active_replica_ids,
share
.guest_connection_ids
.values()
.copied()
.collect::<HashSet<_>>(),
);
}
}
for (user_id, visible_worktree_ids) in &self.visible_worktrees_by_user_id {
for worktree_id in visible_worktree_ids {
let worktree = self.worktrees.get(worktree_id).unwrap();
assert!(worktree.collaborator_user_ids.contains(user_id));
}
}
for (channel_id, channel) in &self.channels {
for connection_id in &channel.connection_ids {
let connection = self.connections.get(connection_id).unwrap();
assert!(connection.channels.contains(channel_id));
}
}
}
}
impl Worktree {
pub fn connection_ids(&self) -> Vec<ConnectionId> {
if let Some(share) = &self.share {
share
.guest_connection_ids
.keys()
.copied()
.chain(Some(self.host_connection_id))
.collect()
} else {
vec![self.host_connection_id]
}
}
fn share(&self) -> tide::Result<&WorktreeShare> {
Ok(self
.share
.as_ref()
.ok_or_else(|| anyhow!("worktree is not shared"))?)
}
fn share_mut(&mut self) -> tide::Result<&mut WorktreeShare> {
Ok(self
.share
.as_mut()
.ok_or_else(|| anyhow!("worktree is not shared"))?)
}
}
impl Channel {
fn connection_ids(&self) -> Vec<ConnectionId> {
self.connection_ids.iter().copied().collect()
}
}
pub fn add_routes(app: &mut tide::Server<Arc<AppState>>, rpc: &Arc<Peer>) {
let server = Server::new(app.state().clone(), rpc.clone(), None);
app.at("/rpc").with(auth::VerifyToken).get(move |request: Request<Arc<AppState>>| {
@@ -2477,16 +2069,16 @@ mod tests {
})
}
async fn state<'a>(&'a self) -> RwLockReadGuard<'a, ServerState> {
self.server.state.read().await
async fn state<'a>(&'a self) -> RwLockReadGuard<'a, Store> {
self.server.store.read().await
}
async fn condition<F>(&mut self, mut predicate: F)
where
F: FnMut(&ServerState) -> bool,
F: FnMut(&Store) -> bool,
{
async_std::future::timeout(Duration::from_millis(500), async {
while !(predicate)(&*self.server.state.read().await) {
while !(predicate)(&*self.server.store.read().await) {
self.notifications.recv().await;
}
})

574
server/src/rpc/store.rs Normal file
View File

@@ -0,0 +1,574 @@
use crate::db::{ChannelId, MessageId, UserId};
use crate::errors::TideResultExt;
use anyhow::anyhow;
use std::collections::{hash_map, HashMap, HashSet};
use zrpc::{proto, ConnectionId};
#[derive(Default)]
pub struct Store {
connections: HashMap<ConnectionId, ConnectionState>,
connections_by_user_id: HashMap<UserId, HashSet<ConnectionId>>,
worktrees: HashMap<u64, Worktree>,
visible_worktrees_by_user_id: HashMap<UserId, HashSet<u64>>,
channels: HashMap<ChannelId, Channel>,
next_worktree_id: u64,
}
struct ConnectionState {
user_id: UserId,
worktrees: HashSet<u64>,
channels: HashSet<ChannelId>,
}
pub struct Worktree {
pub host_connection_id: ConnectionId,
pub collaborator_user_ids: Vec<UserId>,
pub root_name: String,
pub share: Option<WorktreeShare>,
}
struct WorktreeShare {
pub guest_connection_ids: HashMap<ConnectionId, ReplicaId>,
pub active_replica_ids: HashSet<ReplicaId>,
pub entries: HashMap<u64, proto::Entry>,
}
#[derive(Default)]
struct Channel {
connection_ids: HashSet<ConnectionId>,
}
pub type ReplicaId = u16;
#[derive(Default)]
pub struct RemovedConnectionState {
pub hosted_worktrees: HashMap<u64, Worktree>,
pub guest_worktree_ids: HashMap<u64, Vec<ConnectionId>>,
pub collaborator_ids: HashSet<UserId>,
}
impl Store {
pub fn add_connection(&mut self, connection_id: ConnectionId, user_id: UserId) {
self.connections.insert(
connection_id,
ConnectionState {
user_id,
worktrees: Default::default(),
channels: Default::default(),
},
);
self.connections_by_user_id
.entry(user_id)
.or_default()
.insert(connection_id);
}
pub fn remove_connection(
&mut self,
connection_id: ConnectionId,
) -> tide::Result<RemovedConnectionState> {
let connection = if let Some(connection) = self.connections.get(&connection_id) {
connection
} else {
return Err(anyhow!("no such connection"))?;
};
for channel_id in connection.channels {
if let Some(channel) = self.channels.get_mut(&channel_id) {
channel.connection_ids.remove(&connection_id);
}
}
let user_connections = self
.connections_by_user_id
.get_mut(&connection.user_id)
.unwrap();
user_connections.remove(&connection_id);
if user_connections.is_empty() {
self.connections_by_user_id.remove(&connection.user_id);
}
let mut result = RemovedConnectionState::default();
for worktree_id in connection.worktrees {
if let Ok(worktree) = self.remove_worktree(worktree_id, connection_id) {
result.hosted_worktrees.insert(worktree_id, worktree);
result
.collaborator_ids
.extend(worktree.collaborator_user_ids.iter().copied());
} else {
if let Some(worktree) = self.worktrees.get(&worktree_id) {
result
.guest_worktree_ids
.insert(worktree_id, worktree.connection_ids());
result
.collaborator_ids
.extend(worktree.collaborator_user_ids.iter().copied());
}
}
}
Ok(result)
}
pub fn join_channel(&mut self, connection_id: ConnectionId, channel_id: ChannelId) {
if let Some(connection) = self.connections.get_mut(&connection_id) {
connection.channels.insert(channel_id);
self.channels
.entry(channel_id)
.or_default()
.connection_ids
.insert(connection_id);
}
}
pub fn leave_channel(&mut self, connection_id: ConnectionId, channel_id: ChannelId) {
if let Some(connection) = self.connections.get_mut(&connection_id) {
connection.channels.remove(&channel_id);
if let hash_map::Entry::Occupied(mut entry) = self.channels.entry(channel_id) {
entry.get_mut().connection_ids.remove(&connection_id);
if entry.get_mut().connection_ids.is_empty() {
entry.remove();
}
}
}
}
pub fn user_id_for_connection(&self, connection_id: ConnectionId) -> tide::Result<UserId> {
Ok(self
.connections
.get(&connection_id)
.ok_or_else(|| anyhow!("unknown connection"))?
.user_id)
}
pub fn connection_ids_for_user<'a>(
&'a self,
user_id: UserId,
) -> impl 'a + Iterator<Item = ConnectionId> {
self.connections_by_user_id
.get(&user_id)
.into_iter()
.flatten()
.copied()
}
pub fn collaborators_for_user(&self, user_id: UserId) -> Vec<proto::Collaborator> {
let mut collaborators = HashMap::new();
for worktree_id in self
.visible_worktrees_by_user_id
.get(&user_id)
.unwrap_or(&HashSet::new())
{
let worktree = &self.worktrees[worktree_id];
let mut guests = HashSet::new();
if let Ok(share) = worktree.share() {
for guest_connection_id in share.guest_connection_ids.keys() {
if let Ok(user_id) = self.user_id_for_connection(*guest_connection_id) {
guests.insert(user_id.to_proto());
}
}
}
if let Ok(host_user_id) = self
.user_id_for_connection(worktree.host_connection_id)
.context("stale worktree host connection")
{
let host =
collaborators
.entry(host_user_id)
.or_insert_with(|| proto::Collaborator {
user_id: host_user_id.to_proto(),
worktrees: Vec::new(),
});
host.worktrees.push(proto::WorktreeMetadata {
root_name: worktree.root_name.clone(),
is_shared: worktree.share().is_ok(),
participants: guests.into_iter().collect(),
});
}
}
collaborators.into_values().collect()
}
pub fn add_worktree(&mut self, worktree: Worktree) -> u64 {
let worktree_id = self.next_worktree_id;
for collaborator_user_id in &worktree.collaborator_user_ids {
self.visible_worktrees_by_user_id
.entry(*collaborator_user_id)
.or_default()
.insert(worktree_id);
}
self.next_worktree_id += 1;
if let Some(connection) = self.connections.get_mut(&worktree.host_connection_id) {
connection.worktrees.insert(worktree_id);
}
self.worktrees.insert(worktree_id, worktree);
#[cfg(test)]
self.check_invariants();
worktree_id
}
pub fn remove_worktree(
&mut self,
worktree_id: u64,
acting_connection_id: ConnectionId,
) -> tide::Result<Worktree> {
let worktree = if let hash_map::Entry::Occupied(e) = self.worktrees.entry(worktree_id) {
if e.get().host_connection_id != acting_connection_id {
Err(anyhow!("not your worktree"))?;
}
e.remove()
} else {
return Err(anyhow!("no such worktree"))?;
};
if let Some(connection) = self.connections.get_mut(&worktree.host_connection_id) {
connection.worktrees.remove(&worktree_id);
}
if let Some(share) = worktree.share {
for connection_id in share.guest_connection_ids.keys() {
if let Some(connection) = self.connections.get_mut(connection_id) {
connection.worktrees.remove(&worktree_id);
}
}
}
for collaborator_user_id in worktree.collaborator_user_ids {
if let Some(visible_worktrees) = self
.visible_worktrees_by_user_id
.get_mut(&collaborator_user_id)
{
visible_worktrees.remove(&worktree_id);
}
}
#[cfg(test)]
self.check_invariants();
Ok(worktree)
}
pub fn share_worktree(
&mut self,
worktree_id: u64,
connection_id: ConnectionId,
entries: HashMap<u64, proto::Entry>,
) -> Option<Vec<UserId>> {
if let Some(worktree) = self.worktrees.get_mut(&worktree_id) {
if worktree.host_connection_id == connection_id {
worktree.share = Some(WorktreeShare {
guest_connection_ids: Default::default(),
active_replica_ids: Default::default(),
entries,
});
return Some(worktree.collaborator_user_ids.clone());
}
}
None
}
pub fn unshare_worktree(
&mut self,
worktree_id: u64,
acting_connection_id: ConnectionId,
) -> tide::Result<(Vec<ConnectionId>, Vec<UserId>)> {
let worktree = if let Some(worktree) = self.worktrees.get_mut(&worktree_id) {
worktree
} else {
return Err(anyhow!("no such worktree"))?;
};
if worktree.host_connection_id != acting_connection_id {
return Err(anyhow!("not your worktree"))?;
}
let connection_ids = worktree.connection_ids();
if let Some(share) = worktree.share.take() {
for connection_id in &connection_ids {
if let Some(connection) = self.connections.get_mut(connection_id) {
connection.worktrees.remove(&worktree_id);
}
}
Ok((connection_ids, worktree.collaborator_user_ids.clone()))
} else {
Err(anyhow!("worktree is not shared"))?
}
}
pub fn join_worktree(
&mut self,
connection_id: ConnectionId,
user_id: UserId,
worktree_id: u64,
) -> tide::Result<(ReplicaId, &Worktree)> {
let connection = self
.connections
.get_mut(&connection_id)
.ok_or_else(|| anyhow!("no such connection"))?;
let worktree = self
.worktrees
.get_mut(&worktree_id)
.and_then(|worktree| {
if worktree.collaborator_user_ids.contains(&user_id) {
Some(worktree)
} else {
None
}
})
.ok_or_else(|| anyhow!("no such worktree"))?;
let share = worktree.share_mut()?;
connection.worktrees.insert(worktree_id);
let mut replica_id = 1;
while share.active_replica_ids.contains(&replica_id) {
replica_id += 1;
}
share.active_replica_ids.insert(replica_id);
share.guest_connection_ids.insert(connection_id, replica_id);
return Ok((replica_id, worktree));
}
pub fn leave_worktree(
&mut self,
connection_id: ConnectionId,
worktree_id: u64,
) -> Option<(Vec<ConnectionId>, Vec<UserId>)> {
let worktree = self.worktrees.get_mut(&worktree_id)?;
let share = worktree.share.as_mut()?;
let replica_id = share.guest_connection_ids.remove(&connection_id)?;
share.active_replica_ids.remove(&replica_id);
Some((
worktree.connection_ids(),
worktree.collaborator_user_ids.clone(),
))
}
pub fn update_worktree(
&mut self,
connection_id: ConnectionId,
worktree_id: u64,
removed_entries: &[u64],
updated_entries: &[proto::Entry],
) -> tide::Result<Vec<ConnectionId>> {
let worktree = self.write_worktree(worktree_id, connection_id)?;
let share = worktree.share_mut()?;
for entry_id in removed_entries {
share.entries.remove(&entry_id);
}
for entry in updated_entries {
share.entries.insert(entry.id, entry.clone());
}
Ok(worktree.connection_ids())
}
pub fn worktree_host_connection_id(
&self,
connection_id: ConnectionId,
worktree_id: u64,
) -> tide::Result<ConnectionId> {
Ok(self
.read_worktree(worktree_id, connection_id)?
.host_connection_id)
}
pub fn worktree_guest_connection_ids(
&self,
connection_id: ConnectionId,
worktree_id: u64,
) -> tide::Result<Vec<ConnectionId>> {
Ok(self
.read_worktree(worktree_id, connection_id)?
.share()?
.guest_connection_ids
.keys()
.copied()
.collect())
}
pub fn worktree_connection_ids(
&self,
connection_id: ConnectionId,
worktree_id: u64,
) -> tide::Result<Vec<ConnectionId>> {
Ok(self
.read_worktree(worktree_id, connection_id)?
.connection_ids())
}
pub fn channel_connection_ids(&self, channel_id: ChannelId) -> Option<Vec<ConnectionId>> {
Some(self.channels.get(&channel_id)?.connection_ids())
}
fn read_worktree(
&self,
worktree_id: u64,
connection_id: ConnectionId,
) -> tide::Result<&Worktree> {
let worktree = self
.worktrees
.get(&worktree_id)
.ok_or_else(|| anyhow!("worktree not found"))?;
if worktree.host_connection_id == connection_id
|| worktree
.share()?
.guest_connection_ids
.contains_key(&connection_id)
{
Ok(worktree)
} else {
Err(anyhow!(
"{} is not a member of worktree {}",
connection_id,
worktree_id
))?
}
}
fn write_worktree(
&mut self,
worktree_id: u64,
connection_id: ConnectionId,
) -> tide::Result<&mut Worktree> {
let worktree = self
.worktrees
.get_mut(&worktree_id)
.ok_or_else(|| anyhow!("worktree not found"))?;
if worktree.host_connection_id == connection_id
|| worktree.share.as_ref().map_or(false, |share| {
share.guest_connection_ids.contains_key(&connection_id)
})
{
Ok(worktree)
} else {
Err(anyhow!(
"{} is not a member of worktree {}",
connection_id,
worktree_id
))?
}
}
#[cfg(test)]
fn check_invariants(&self) {
for (connection_id, connection) in &self.connections {
for worktree_id in &connection.worktrees {
let worktree = &self.worktrees.get(&worktree_id).unwrap();
if worktree.host_connection_id != *connection_id {
assert!(worktree
.share()
.unwrap()
.guest_connection_ids
.contains_key(connection_id));
}
}
for channel_id in &connection.channels {
let channel = self.channels.get(channel_id).unwrap();
assert!(channel.connection_ids.contains(connection_id));
}
assert!(self
.connections_by_user_id
.get(&connection.user_id)
.unwrap()
.contains(connection_id));
}
for (user_id, connection_ids) in &self.connections_by_user_id {
for connection_id in connection_ids {
assert_eq!(
self.connections.get(connection_id).unwrap().user_id,
*user_id
);
}
}
for (worktree_id, worktree) in &self.worktrees {
let host_connection = self.connections.get(&worktree.host_connection_id).unwrap();
assert!(host_connection.worktrees.contains(worktree_id));
for collaborator_id in &worktree.collaborator_user_ids {
let visible_worktree_ids = self
.visible_worktrees_by_user_id
.get(collaborator_id)
.unwrap();
assert!(visible_worktree_ids.contains(worktree_id));
}
if let Some(share) = &worktree.share {
for guest_connection_id in share.guest_connection_ids.keys() {
let guest_connection = self.connections.get(guest_connection_id).unwrap();
assert!(guest_connection.worktrees.contains(worktree_id));
}
assert_eq!(
share.active_replica_ids.len(),
share.guest_connection_ids.len(),
);
assert_eq!(
share.active_replica_ids,
share
.guest_connection_ids
.values()
.copied()
.collect::<HashSet<_>>(),
);
}
}
for (user_id, visible_worktree_ids) in &self.visible_worktrees_by_user_id {
for worktree_id in visible_worktree_ids {
let worktree = self.worktrees.get(worktree_id).unwrap();
assert!(worktree.collaborator_user_ids.contains(user_id));
}
}
for (channel_id, channel) in &self.channels {
for connection_id in &channel.connection_ids {
let connection = self.connections.get(connection_id).unwrap();
assert!(connection.channels.contains(channel_id));
}
}
}
}
impl Worktree {
pub fn connection_ids(&self) -> Vec<ConnectionId> {
if let Some(share) = &self.share {
share
.guest_connection_ids
.keys()
.copied()
.chain(Some(self.host_connection_id))
.collect()
} else {
vec![self.host_connection_id]
}
}
pub fn share(&self) -> tide::Result<&WorktreeShare> {
Ok(self
.share
.as_ref()
.ok_or_else(|| anyhow!("worktree is not shared"))?)
}
fn share_mut(&mut self) -> tide::Result<&mut WorktreeShare> {
Ok(self
.share
.as_mut()
.ok_or_else(|| anyhow!("worktree is not shared"))?)
}
}
impl Channel {
fn connection_ids(&self) -> Vec<ConnectionId> {
self.connection_ids.iter().copied().collect()
}
}

View File

@@ -66,21 +66,28 @@ impl Entity for Worktree {
type Event = ();
fn release(&mut self, cx: &mut MutableAppContext) {
let rpc = match self {
Self::Local(tree) => tree
.remote_id
.borrow()
.map(|remote_id| (tree.rpc.clone(), remote_id)),
Self::Remote(tree) => Some((tree.rpc.clone(), tree.remote_id)),
};
if let Some((rpc, worktree_id)) = rpc {
cx.spawn(|_| async move {
if let Err(err) = rpc.send(proto::CloseWorktree { worktree_id }).await {
log::error!("error closing worktree {}: {}", worktree_id, err);
match self {
Self::Local(tree) => {
if let Some(worktree_id) = *tree.remote_id.borrow() {
let rpc = tree.rpc.clone();
cx.spawn(|_| async move {
if let Err(err) = rpc.send(proto::CloseWorktree { worktree_id }).await {
log::error!("error closing worktree: {}", err);
}
})
.detach();
}
})
.detach();
}
Self::Remote(tree) => {
let rpc = tree.rpc.clone();
let worktree_id = tree.remote_id;
cx.spawn(|_| async move {
if let Err(err) = rpc.send(proto::LeaveWorktree { worktree_id }).await {
log::error!("error closing worktree: {}", err);
}
})
.detach();
}
}
}
}

View File

@@ -39,6 +39,7 @@ message Envelope {
OpenWorktreeResponse open_worktree_response = 34;
UnshareWorktree unshare_worktree = 35;
UpdateCollaborators update_collaborators = 36;
LeaveWorktree leave_worktree = 37;
}
}
@@ -75,6 +76,10 @@ message JoinWorktree {
uint64 worktree_id = 1;
}
message LeaveWorktree {
uint64 worktree_id = 1;
}
message JoinWorktreeResponse {
Worktree worktree = 2;
uint32 replica_id = 3;

View File

@@ -139,6 +139,7 @@ messages!(
JoinWorktree,
JoinWorktreeResponse,
LeaveChannel,
LeaveWorktree,
OpenBuffer,
OpenBufferResponse,
OpenWorktree,