Add Server::{state,state_mut} to catch most deadlocks statically

This commit is contained in:
Antonio Scandurra
2021-09-21 12:19:52 +02:00
parent 0b11192fe3
commit d381020a60

View File

@@ -112,12 +112,11 @@ impl Server {
addr: String,
user_id: UserId,
) -> impl Future<Output = ()> {
let this = self.clone();
let mut this = self.clone();
async move {
let (connection_id, handle_io, mut incoming_rx) =
this.peer.add_connection(connection).await;
this.store
.write()
this.state_mut()
.await
.add_connection(connection_id, user_id);
if let Err(err) = this.update_collaborators_for_users(&[user_id]).await {
@@ -167,9 +166,9 @@ impl Server {
}
}
async fn sign_out(self: &Arc<Self>, connection_id: ConnectionId) -> tide::Result<()> {
async fn sign_out(self: &mut Arc<Self>, connection_id: ConnectionId) -> tide::Result<()> {
self.peer.disconnect(connection_id).await;
let removed_connection = self.store.write().await.remove_connection(connection_id)?;
let removed_connection = self.state_mut().await.remove_connection(connection_id)?;
for (worktree_id, worktree) in removed_connection.hosted_worktrees {
if let Some(share) = worktree.share {
@@ -210,13 +209,12 @@ impl Server {
}
async fn open_worktree(
self: Arc<Server>,
mut self: Arc<Server>,
request: TypedEnvelope<proto::OpenWorktree>,
) -> tide::Result<()> {
let receipt = request.receipt();
let host_user_id = self
.store
.read()
.state()
.await
.user_id_for_connection(request.sender_id)?;
@@ -238,7 +236,7 @@ impl Server {
}
let collaborator_user_ids = collaborator_user_ids.into_iter().collect::<Vec<_>>();
let worktree_id = self.store.write().await.add_worktree(Worktree {
let worktree_id = self.state_mut().await.add_worktree(Worktree {
host_connection_id: request.sender_id,
collaborator_user_ids: collaborator_user_ids.clone(),
root_name: request.payload.root_name,
@@ -255,13 +253,12 @@ impl Server {
}
async fn close_worktree(
self: Arc<Server>,
mut self: Arc<Server>,
request: TypedEnvelope<proto::CloseWorktree>,
) -> tide::Result<()> {
let worktree_id = request.payload.worktree_id;
let worktree = self
.store
.write()
.state_mut()
.await
.remove_worktree(worktree_id, request.sender_id)?;
@@ -282,7 +279,7 @@ impl Server {
}
async fn share_worktree(
self: Arc<Server>,
mut self: Arc<Server>,
mut request: TypedEnvelope<proto::ShareWorktree>,
) -> tide::Result<()> {
let worktree = request
@@ -296,8 +293,7 @@ impl Server {
.collect();
let collaborator_user_ids =
self.store
.write()
self.state_mut()
.await
.share_worktree(worktree.id, request.sender_id, entries);
if let Some(collaborator_user_ids) = collaborator_user_ids {
@@ -320,13 +316,12 @@ impl Server {
}
async fn unshare_worktree(
self: Arc<Server>,
mut self: Arc<Server>,
request: TypedEnvelope<proto::UnshareWorktree>,
) -> tide::Result<()> {
let worktree_id = request.payload.worktree_id;
let worktree = self
.store
.write()
.state_mut()
.await
.unshare_worktree(worktree_id, request.sender_id)?;
@@ -342,20 +337,16 @@ impl Server {
}
async fn join_worktree(
self: Arc<Server>,
mut self: Arc<Server>,
request: TypedEnvelope<proto::JoinWorktree>,
) -> tide::Result<()> {
let worktree_id = request.payload.worktree_id;
let user_id = self
.store
.read()
.state()
.await
.user_id_for_connection(request.sender_id)?;
let response;
let connection_ids;
let collaborator_user_ids;
let mut state = self.store.write().await;
let mut state = self.state_mut().await;
match state.join_worktree(request.sender_id, user_id, worktree_id) {
Ok(JoinedWorktree {
replica_id,
@@ -376,7 +367,7 @@ impl Server {
});
}
}
response = proto::JoinWorktreeResponse {
let response = proto::JoinWorktreeResponse {
worktree: Some(proto::Worktree {
id: worktree_id,
root_name: worktree.root_name.clone(),
@@ -385,10 +376,29 @@ impl Server {
replica_id: replica_id as u32,
peers,
};
connection_ids = worktree.connection_ids();
collaborator_user_ids = worktree.collaborator_user_ids.clone();
let connection_ids = worktree.connection_ids();
let collaborator_user_ids = worktree.collaborator_user_ids.clone();
drop(state);
broadcast(request.sender_id, connection_ids, |conn_id| {
self.peer.send(
conn_id,
proto::AddPeer {
worktree_id,
peer: Some(proto::Peer {
peer_id: request.sender_id.0,
replica_id: response.replica_id,
}),
},
)
})
.await?;
self.peer.respond(request.receipt(), response).await?;
self.update_collaborators_for_users(&collaborator_user_ids)
.await?;
}
Err(error) => {
drop(state);
self.peer
.respond_with_error(
request.receipt(),
@@ -397,44 +407,23 @@ impl Server {
},
)
.await?;
return Ok(());
}
}
drop(state);
broadcast(request.sender_id, connection_ids, |conn_id| {
self.peer.send(
conn_id,
proto::AddPeer {
worktree_id,
peer: Some(proto::Peer {
peer_id: request.sender_id.0,
replica_id: response.replica_id,
}),
},
)
})
.await?;
self.peer.respond(request.receipt(), response).await?;
self.update_collaborators_for_users(&collaborator_user_ids)
.await?;
Ok(())
}
async fn leave_worktree(
self: Arc<Server>,
mut self: Arc<Server>,
request: TypedEnvelope<proto::LeaveWorktree>,
) -> tide::Result<()> {
let sender_id = request.sender_id;
let worktree_id = request.payload.worktree_id;
if let Some(worktree) = self
.store
.write()
let worktree = self
.state_mut()
.await
.leave_worktree(sender_id, worktree_id)
{
.leave_worktree(sender_id, worktree_id);
if let Some(worktree) = worktree {
broadcast(sender_id, worktree.connection_ids, |conn_id| {
self.peer.send(
conn_id,
@@ -452,10 +441,10 @@ impl Server {
}
async fn update_worktree(
self: Arc<Server>,
mut self: Arc<Server>,
request: TypedEnvelope<proto::UpdateWorktree>,
) -> tide::Result<()> {
let connection_ids = self.store.write().await.update_worktree(
let connection_ids = self.state_mut().await.update_worktree(
request.sender_id,
request.payload.worktree_id,
&request.payload.removed_entries,
@@ -477,8 +466,7 @@ impl Server {
) -> tide::Result<()> {
let receipt = request.receipt();
let host_connection_id = self
.store
.read()
.state()
.await
.worktree_host_connection_id(request.sender_id, request.payload.worktree_id)?;
let response = self
@@ -494,8 +482,7 @@ impl Server {
request: TypedEnvelope<proto::CloseBuffer>,
) -> tide::Result<()> {
let host_connection_id = self
.store
.read()
.state()
.await
.worktree_host_connection_id(request.sender_id, request.payload.worktree_id)?;
self.peer
@@ -511,7 +498,7 @@ impl Server {
let host;
let guests;
{
let state = self.store.read().await;
let state = self.state().await;
host = state
.worktree_host_connection_id(request.sender_id, request.payload.worktree_id)?;
guests = state
@@ -547,8 +534,7 @@ impl Server {
) -> tide::Result<()> {
broadcast(
request.sender_id,
self.store
.read()
self.state()
.await
.worktree_connection_ids(request.sender_id, request.payload.worktree_id)?,
|connection_id| {
@@ -585,8 +571,7 @@ impl Server {
request: TypedEnvelope<proto::GetChannels>,
) -> tide::Result<()> {
let user_id = self
.store
.read()
.state()
.await
.user_id_for_connection(request.sender_id)?;
let channels = self.app_state.db.get_accessible_channels(user_id).await?;
@@ -637,7 +622,7 @@ impl Server {
) -> tide::Result<()> {
let mut send_futures = Vec::new();
let state = self.store.read().await;
let state = self.state().await;
for user_id in user_ids {
let collaborators = state.collaborators_for_user(*user_id);
for connection_id in state.connection_ids_for_user(*user_id) {
@@ -657,12 +642,11 @@ impl Server {
}
async fn join_channel(
self: Arc<Self>,
mut self: Arc<Self>,
request: TypedEnvelope<proto::JoinChannel>,
) -> tide::Result<()> {
let user_id = self
.store
.read()
.state()
.await
.user_id_for_connection(request.sender_id)?;
let channel_id = ChannelId::from_proto(request.payload.channel_id);
@@ -675,8 +659,7 @@ impl Server {
Err(anyhow!("access denied"))?;
}
self.store
.write()
self.state_mut()
.await
.join_channel(request.sender_id, channel_id);
let messages = self
@@ -706,12 +689,11 @@ impl Server {
}
async fn leave_channel(
self: Arc<Self>,
mut self: Arc<Self>,
request: TypedEnvelope<proto::LeaveChannel>,
) -> tide::Result<()> {
let user_id = self
.store
.read()
.state()
.await
.user_id_for_connection(request.sender_id)?;
let channel_id = ChannelId::from_proto(request.payload.channel_id);
@@ -724,8 +706,7 @@ impl Server {
Err(anyhow!("access denied"))?;
}
self.store
.write()
self.state_mut()
.await
.leave_channel(request.sender_id, channel_id);
@@ -741,7 +722,7 @@ impl Server {
let user_id;
let connection_ids;
{
let state = self.store.read().await;
let state = self.state().await;
user_id = state.user_id_for_connection(request.sender_id)?;
if let Some(ids) = state.channel_connection_ids(channel_id) {
connection_ids = ids;
@@ -829,8 +810,7 @@ impl Server {
request: TypedEnvelope<proto::GetChannelMessages>,
) -> tide::Result<()> {
let user_id = self
.store
.read()
.state()
.await
.user_id_for_connection(request.sender_id)?;
let channel_id = ChannelId::from_proto(request.payload.channel_id);
@@ -872,6 +852,18 @@ impl Server {
.await?;
Ok(())
}
fn state<'a>(
self: &'a Arc<Self>,
) -> impl Future<Output = async_std::sync::RwLockReadGuard<'a, Store>> {
self.store.read()
}
fn state_mut<'a>(
self: &'a mut Arc<Self>,
) -> impl Future<Output = async_std::sync::RwLockWriteGuard<'a, Store>> {
self.store.write()
}
}
pub async fn broadcast<F, T>(