Add Server::{state,state_mut} to catch most deadlocks statically
This commit is contained in:
@@ -112,12 +112,11 @@ impl Server {
|
||||
addr: String,
|
||||
user_id: UserId,
|
||||
) -> impl Future<Output = ()> {
|
||||
let this = self.clone();
|
||||
let mut this = self.clone();
|
||||
async move {
|
||||
let (connection_id, handle_io, mut incoming_rx) =
|
||||
this.peer.add_connection(connection).await;
|
||||
this.store
|
||||
.write()
|
||||
this.state_mut()
|
||||
.await
|
||||
.add_connection(connection_id, user_id);
|
||||
if let Err(err) = this.update_collaborators_for_users(&[user_id]).await {
|
||||
@@ -167,9 +166,9 @@ impl Server {
|
||||
}
|
||||
}
|
||||
|
||||
async fn sign_out(self: &Arc<Self>, connection_id: ConnectionId) -> tide::Result<()> {
|
||||
async fn sign_out(self: &mut Arc<Self>, connection_id: ConnectionId) -> tide::Result<()> {
|
||||
self.peer.disconnect(connection_id).await;
|
||||
let removed_connection = self.store.write().await.remove_connection(connection_id)?;
|
||||
let removed_connection = self.state_mut().await.remove_connection(connection_id)?;
|
||||
|
||||
for (worktree_id, worktree) in removed_connection.hosted_worktrees {
|
||||
if let Some(share) = worktree.share {
|
||||
@@ -210,13 +209,12 @@ impl Server {
|
||||
}
|
||||
|
||||
async fn open_worktree(
|
||||
self: Arc<Server>,
|
||||
mut self: Arc<Server>,
|
||||
request: TypedEnvelope<proto::OpenWorktree>,
|
||||
) -> tide::Result<()> {
|
||||
let receipt = request.receipt();
|
||||
let host_user_id = self
|
||||
.store
|
||||
.read()
|
||||
.state()
|
||||
.await
|
||||
.user_id_for_connection(request.sender_id)?;
|
||||
|
||||
@@ -238,7 +236,7 @@ impl Server {
|
||||
}
|
||||
|
||||
let collaborator_user_ids = collaborator_user_ids.into_iter().collect::<Vec<_>>();
|
||||
let worktree_id = self.store.write().await.add_worktree(Worktree {
|
||||
let worktree_id = self.state_mut().await.add_worktree(Worktree {
|
||||
host_connection_id: request.sender_id,
|
||||
collaborator_user_ids: collaborator_user_ids.clone(),
|
||||
root_name: request.payload.root_name,
|
||||
@@ -255,13 +253,12 @@ impl Server {
|
||||
}
|
||||
|
||||
async fn close_worktree(
|
||||
self: Arc<Server>,
|
||||
mut self: Arc<Server>,
|
||||
request: TypedEnvelope<proto::CloseWorktree>,
|
||||
) -> tide::Result<()> {
|
||||
let worktree_id = request.payload.worktree_id;
|
||||
let worktree = self
|
||||
.store
|
||||
.write()
|
||||
.state_mut()
|
||||
.await
|
||||
.remove_worktree(worktree_id, request.sender_id)?;
|
||||
|
||||
@@ -282,7 +279,7 @@ impl Server {
|
||||
}
|
||||
|
||||
async fn share_worktree(
|
||||
self: Arc<Server>,
|
||||
mut self: Arc<Server>,
|
||||
mut request: TypedEnvelope<proto::ShareWorktree>,
|
||||
) -> tide::Result<()> {
|
||||
let worktree = request
|
||||
@@ -296,8 +293,7 @@ impl Server {
|
||||
.collect();
|
||||
|
||||
let collaborator_user_ids =
|
||||
self.store
|
||||
.write()
|
||||
self.state_mut()
|
||||
.await
|
||||
.share_worktree(worktree.id, request.sender_id, entries);
|
||||
if let Some(collaborator_user_ids) = collaborator_user_ids {
|
||||
@@ -320,13 +316,12 @@ impl Server {
|
||||
}
|
||||
|
||||
async fn unshare_worktree(
|
||||
self: Arc<Server>,
|
||||
mut self: Arc<Server>,
|
||||
request: TypedEnvelope<proto::UnshareWorktree>,
|
||||
) -> tide::Result<()> {
|
||||
let worktree_id = request.payload.worktree_id;
|
||||
let worktree = self
|
||||
.store
|
||||
.write()
|
||||
.state_mut()
|
||||
.await
|
||||
.unshare_worktree(worktree_id, request.sender_id)?;
|
||||
|
||||
@@ -342,20 +337,16 @@ impl Server {
|
||||
}
|
||||
|
||||
async fn join_worktree(
|
||||
self: Arc<Server>,
|
||||
mut self: Arc<Server>,
|
||||
request: TypedEnvelope<proto::JoinWorktree>,
|
||||
) -> tide::Result<()> {
|
||||
let worktree_id = request.payload.worktree_id;
|
||||
let user_id = self
|
||||
.store
|
||||
.read()
|
||||
.state()
|
||||
.await
|
||||
.user_id_for_connection(request.sender_id)?;
|
||||
|
||||
let response;
|
||||
let connection_ids;
|
||||
let collaborator_user_ids;
|
||||
let mut state = self.store.write().await;
|
||||
let mut state = self.state_mut().await;
|
||||
match state.join_worktree(request.sender_id, user_id, worktree_id) {
|
||||
Ok(JoinedWorktree {
|
||||
replica_id,
|
||||
@@ -376,7 +367,7 @@ impl Server {
|
||||
});
|
||||
}
|
||||
}
|
||||
response = proto::JoinWorktreeResponse {
|
||||
let response = proto::JoinWorktreeResponse {
|
||||
worktree: Some(proto::Worktree {
|
||||
id: worktree_id,
|
||||
root_name: worktree.root_name.clone(),
|
||||
@@ -385,10 +376,29 @@ impl Server {
|
||||
replica_id: replica_id as u32,
|
||||
peers,
|
||||
};
|
||||
connection_ids = worktree.connection_ids();
|
||||
collaborator_user_ids = worktree.collaborator_user_ids.clone();
|
||||
let connection_ids = worktree.connection_ids();
|
||||
let collaborator_user_ids = worktree.collaborator_user_ids.clone();
|
||||
drop(state);
|
||||
|
||||
broadcast(request.sender_id, connection_ids, |conn_id| {
|
||||
self.peer.send(
|
||||
conn_id,
|
||||
proto::AddPeer {
|
||||
worktree_id,
|
||||
peer: Some(proto::Peer {
|
||||
peer_id: request.sender_id.0,
|
||||
replica_id: response.replica_id,
|
||||
}),
|
||||
},
|
||||
)
|
||||
})
|
||||
.await?;
|
||||
self.peer.respond(request.receipt(), response).await?;
|
||||
self.update_collaborators_for_users(&collaborator_user_ids)
|
||||
.await?;
|
||||
}
|
||||
Err(error) => {
|
||||
drop(state);
|
||||
self.peer
|
||||
.respond_with_error(
|
||||
request.receipt(),
|
||||
@@ -397,44 +407,23 @@ impl Server {
|
||||
},
|
||||
)
|
||||
.await?;
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
|
||||
drop(state);
|
||||
broadcast(request.sender_id, connection_ids, |conn_id| {
|
||||
self.peer.send(
|
||||
conn_id,
|
||||
proto::AddPeer {
|
||||
worktree_id,
|
||||
peer: Some(proto::Peer {
|
||||
peer_id: request.sender_id.0,
|
||||
replica_id: response.replica_id,
|
||||
}),
|
||||
},
|
||||
)
|
||||
})
|
||||
.await?;
|
||||
self.peer.respond(request.receipt(), response).await?;
|
||||
self.update_collaborators_for_users(&collaborator_user_ids)
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn leave_worktree(
|
||||
self: Arc<Server>,
|
||||
mut self: Arc<Server>,
|
||||
request: TypedEnvelope<proto::LeaveWorktree>,
|
||||
) -> tide::Result<()> {
|
||||
let sender_id = request.sender_id;
|
||||
let worktree_id = request.payload.worktree_id;
|
||||
|
||||
if let Some(worktree) = self
|
||||
.store
|
||||
.write()
|
||||
let worktree = self
|
||||
.state_mut()
|
||||
.await
|
||||
.leave_worktree(sender_id, worktree_id)
|
||||
{
|
||||
.leave_worktree(sender_id, worktree_id);
|
||||
if let Some(worktree) = worktree {
|
||||
broadcast(sender_id, worktree.connection_ids, |conn_id| {
|
||||
self.peer.send(
|
||||
conn_id,
|
||||
@@ -452,10 +441,10 @@ impl Server {
|
||||
}
|
||||
|
||||
async fn update_worktree(
|
||||
self: Arc<Server>,
|
||||
mut self: Arc<Server>,
|
||||
request: TypedEnvelope<proto::UpdateWorktree>,
|
||||
) -> tide::Result<()> {
|
||||
let connection_ids = self.store.write().await.update_worktree(
|
||||
let connection_ids = self.state_mut().await.update_worktree(
|
||||
request.sender_id,
|
||||
request.payload.worktree_id,
|
||||
&request.payload.removed_entries,
|
||||
@@ -477,8 +466,7 @@ impl Server {
|
||||
) -> tide::Result<()> {
|
||||
let receipt = request.receipt();
|
||||
let host_connection_id = self
|
||||
.store
|
||||
.read()
|
||||
.state()
|
||||
.await
|
||||
.worktree_host_connection_id(request.sender_id, request.payload.worktree_id)?;
|
||||
let response = self
|
||||
@@ -494,8 +482,7 @@ impl Server {
|
||||
request: TypedEnvelope<proto::CloseBuffer>,
|
||||
) -> tide::Result<()> {
|
||||
let host_connection_id = self
|
||||
.store
|
||||
.read()
|
||||
.state()
|
||||
.await
|
||||
.worktree_host_connection_id(request.sender_id, request.payload.worktree_id)?;
|
||||
self.peer
|
||||
@@ -511,7 +498,7 @@ impl Server {
|
||||
let host;
|
||||
let guests;
|
||||
{
|
||||
let state = self.store.read().await;
|
||||
let state = self.state().await;
|
||||
host = state
|
||||
.worktree_host_connection_id(request.sender_id, request.payload.worktree_id)?;
|
||||
guests = state
|
||||
@@ -547,8 +534,7 @@ impl Server {
|
||||
) -> tide::Result<()> {
|
||||
broadcast(
|
||||
request.sender_id,
|
||||
self.store
|
||||
.read()
|
||||
self.state()
|
||||
.await
|
||||
.worktree_connection_ids(request.sender_id, request.payload.worktree_id)?,
|
||||
|connection_id| {
|
||||
@@ -585,8 +571,7 @@ impl Server {
|
||||
request: TypedEnvelope<proto::GetChannels>,
|
||||
) -> tide::Result<()> {
|
||||
let user_id = self
|
||||
.store
|
||||
.read()
|
||||
.state()
|
||||
.await
|
||||
.user_id_for_connection(request.sender_id)?;
|
||||
let channels = self.app_state.db.get_accessible_channels(user_id).await?;
|
||||
@@ -637,7 +622,7 @@ impl Server {
|
||||
) -> tide::Result<()> {
|
||||
let mut send_futures = Vec::new();
|
||||
|
||||
let state = self.store.read().await;
|
||||
let state = self.state().await;
|
||||
for user_id in user_ids {
|
||||
let collaborators = state.collaborators_for_user(*user_id);
|
||||
for connection_id in state.connection_ids_for_user(*user_id) {
|
||||
@@ -657,12 +642,11 @@ impl Server {
|
||||
}
|
||||
|
||||
async fn join_channel(
|
||||
self: Arc<Self>,
|
||||
mut self: Arc<Self>,
|
||||
request: TypedEnvelope<proto::JoinChannel>,
|
||||
) -> tide::Result<()> {
|
||||
let user_id = self
|
||||
.store
|
||||
.read()
|
||||
.state()
|
||||
.await
|
||||
.user_id_for_connection(request.sender_id)?;
|
||||
let channel_id = ChannelId::from_proto(request.payload.channel_id);
|
||||
@@ -675,8 +659,7 @@ impl Server {
|
||||
Err(anyhow!("access denied"))?;
|
||||
}
|
||||
|
||||
self.store
|
||||
.write()
|
||||
self.state_mut()
|
||||
.await
|
||||
.join_channel(request.sender_id, channel_id);
|
||||
let messages = self
|
||||
@@ -706,12 +689,11 @@ impl Server {
|
||||
}
|
||||
|
||||
async fn leave_channel(
|
||||
self: Arc<Self>,
|
||||
mut self: Arc<Self>,
|
||||
request: TypedEnvelope<proto::LeaveChannel>,
|
||||
) -> tide::Result<()> {
|
||||
let user_id = self
|
||||
.store
|
||||
.read()
|
||||
.state()
|
||||
.await
|
||||
.user_id_for_connection(request.sender_id)?;
|
||||
let channel_id = ChannelId::from_proto(request.payload.channel_id);
|
||||
@@ -724,8 +706,7 @@ impl Server {
|
||||
Err(anyhow!("access denied"))?;
|
||||
}
|
||||
|
||||
self.store
|
||||
.write()
|
||||
self.state_mut()
|
||||
.await
|
||||
.leave_channel(request.sender_id, channel_id);
|
||||
|
||||
@@ -741,7 +722,7 @@ impl Server {
|
||||
let user_id;
|
||||
let connection_ids;
|
||||
{
|
||||
let state = self.store.read().await;
|
||||
let state = self.state().await;
|
||||
user_id = state.user_id_for_connection(request.sender_id)?;
|
||||
if let Some(ids) = state.channel_connection_ids(channel_id) {
|
||||
connection_ids = ids;
|
||||
@@ -829,8 +810,7 @@ impl Server {
|
||||
request: TypedEnvelope<proto::GetChannelMessages>,
|
||||
) -> tide::Result<()> {
|
||||
let user_id = self
|
||||
.store
|
||||
.read()
|
||||
.state()
|
||||
.await
|
||||
.user_id_for_connection(request.sender_id)?;
|
||||
let channel_id = ChannelId::from_proto(request.payload.channel_id);
|
||||
@@ -872,6 +852,18 @@ impl Server {
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn state<'a>(
|
||||
self: &'a Arc<Self>,
|
||||
) -> impl Future<Output = async_std::sync::RwLockReadGuard<'a, Store>> {
|
||||
self.store.read()
|
||||
}
|
||||
|
||||
fn state_mut<'a>(
|
||||
self: &'a mut Arc<Self>,
|
||||
) -> impl Future<Output = async_std::sync::RwLockWriteGuard<'a, Store>> {
|
||||
self.store.write()
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn broadcast<F, T>(
|
||||
|
||||
Reference in New Issue
Block a user