From db9f0795986e28fe41d89361c81bd478ecb486e4 Mon Sep 17 00:00:00 2001 From: Nathan Sobo Date: Sun, 9 Jun 2024 21:06:57 -0600 Subject: [PATCH] Move MessageStream to its own module Split the MessageStream type into a separate module to keep the proto module focused on protobuf message definitions. The new message_stream module contains the MessageStream type and its implementation for reading and writing protobuf messages over a WebSocket connection. Refactored imports and updated references to MessageStream in the peer and proto modules to use the new path. --- crates/rpc/src/message_stream.rs | 136 ++++++++++++++++++++++++++++++ crates/rpc/src/peer.rs | 4 +- crates/rpc/src/proto.rs | 137 +------------------------------ crates/rpc/src/rpc.rs | 2 + 4 files changed, 144 insertions(+), 135 deletions(-) create mode 100644 crates/rpc/src/message_stream.rs diff --git a/crates/rpc/src/message_stream.rs b/crates/rpc/src/message_stream.rs new file mode 100644 index 0000000000..5ae2def0b3 --- /dev/null +++ b/crates/rpc/src/message_stream.rs @@ -0,0 +1,136 @@ +use crate::proto::{Envelope, Message}; +use anyhow::{anyhow, Result}; +use async_tungstenite::tungstenite::Message as WebSocketMessage; +use futures::{SinkExt, StreamExt}; +use prost::Message as _; +use std::{io, time::Instant}; + +const KIB: usize = 1024; +const MIB: usize = KIB * 1024; +const MAX_BUFFER_LEN: usize = MIB; + +/// A stream of protobuf messages. +pub struct MessageStream { + stream: S, + encoding_buffer: Vec, +} + +impl MessageStream { + pub fn new(stream: S) -> Self { + Self { + stream, + encoding_buffer: Vec::new(), + } + } + + pub fn inner_mut(&mut self) -> &mut S { + &mut self.stream + } +} + +impl MessageStream +where + S: futures::Sink + Unpin, +{ + pub async fn write(&mut self, message: Message) -> Result<(), anyhow::Error> { + #[cfg(any(test, feature = "test-support"))] + const COMPRESSION_LEVEL: i32 = -7; + + #[cfg(not(any(test, feature = "test-support")))] + const COMPRESSION_LEVEL: i32 = 4; + + match message { + Message::Envelope(message) => { + self.encoding_buffer.reserve(message.encoded_len()); + message + .encode(&mut self.encoding_buffer) + .map_err(io::Error::from)?; + let buffer = + zstd::stream::encode_all(self.encoding_buffer.as_slice(), COMPRESSION_LEVEL) + .unwrap(); + + self.encoding_buffer.clear(); + self.encoding_buffer.shrink_to(MAX_BUFFER_LEN); + self.stream.send(WebSocketMessage::Binary(buffer)).await?; + } + Message::Ping => { + self.stream + .send(WebSocketMessage::Ping(Default::default())) + .await?; + } + Message::Pong => { + self.stream + .send(WebSocketMessage::Pong(Default::default())) + .await?; + } + } + + Ok(()) + } +} + +impl MessageStream +where + S: futures::Stream> + Unpin, +{ + pub async fn read(&mut self) -> Result<(Message, Instant), anyhow::Error> { + while let Some(bytes) = self.stream.next().await { + let received_at = Instant::now(); + match bytes? { + WebSocketMessage::Binary(bytes) => { + zstd::stream::copy_decode(bytes.as_slice(), &mut self.encoding_buffer).unwrap(); + let envelope = Envelope::decode(self.encoding_buffer.as_slice()) + .map_err(io::Error::from)?; + + self.encoding_buffer.clear(); + self.encoding_buffer.shrink_to(MAX_BUFFER_LEN); + return Ok((Message::Envelope(envelope), received_at)); + } + WebSocketMessage::Ping(_) => return Ok((Message::Ping, received_at)), + WebSocketMessage::Pong(_) => return Ok((Message::Pong, received_at)), + WebSocketMessage::Close(_) => break, + _ => {} + } + } + Err(anyhow!("connection closed")) + } +} + +#[cfg(test)] +mod tests { + use crate::proto::{envelope, Envelope, UpdateWorktree}; + + use super::*; + + #[gpui::test] + async fn test_buffer_size() { + let (tx, rx) = futures::channel::mpsc::unbounded(); + let mut sink = MessageStream::new(tx.sink_map_err(|_| anyhow!(""))); + sink.write(Message::Envelope(Envelope { + payload: Some(envelope::Payload::UpdateWorktree(UpdateWorktree { + root_name: "abcdefg".repeat(10), + ..Default::default() + })), + ..Default::default() + })) + .await + .unwrap(); + assert!(sink.encoding_buffer.capacity() <= MAX_BUFFER_LEN); + sink.write(Message::Envelope(Envelope { + payload: Some(envelope::Payload::UpdateWorktree(UpdateWorktree { + root_name: "abcdefg".repeat(1000000), + ..Default::default() + })), + ..Default::default() + })) + .await + .unwrap(); + assert!(sink.encoding_buffer.capacity() <= MAX_BUFFER_LEN); + + let mut stream = MessageStream::new(rx.map(anyhow::Ok)); + stream.read().await.unwrap(); + assert!(stream.encoding_buffer.capacity() <= MAX_BUFFER_LEN); + stream.read().await.unwrap(); + assert!(stream.encoding_buffer.capacity() <= MAX_BUFFER_LEN); + } +} diff --git a/crates/rpc/src/peer.rs b/crates/rpc/src/peer.rs index 8e026953c1..3e347e62e3 100644 --- a/crates/rpc/src/peer.rs +++ b/crates/rpc/src/peer.rs @@ -1,8 +1,8 @@ use crate::{ErrorCode, ErrorCodeExt, ErrorExt, RpcError}; use super::{ - proto::{self, AnyTypedEnvelope, EnvelopedMessage, MessageStream, PeerId, RequestMessage}, - Connection, + proto::{self, AnyTypedEnvelope, EnvelopedMessage, PeerId, RequestMessage}, + Connection, MessageStream, }; use anyhow::{anyhow, Context, Result}; use collections::HashMap; diff --git a/crates/rpc/src/proto.rs b/crates/rpc/src/proto.rs index 4844ea6aba..6cbd59d394 100644 --- a/crates/rpc/src/proto.rs +++ b/crates/rpc/src/proto.rs @@ -1,21 +1,15 @@ #![allow(non_snake_case)] use super::{entity_messages, messages, request_messages, ConnectionId, TypedEnvelope}; -use anyhow::{anyhow, Result}; -use async_tungstenite::tungstenite::Message as WebSocketMessage; use collections::HashMap; -use futures::{SinkExt as _, StreamExt as _}; -use prost::Message as _; use serde::Serialize; -use std::any::{Any, TypeId}; -use std::time::Instant; use std::{ + any::{Any, TypeId}, cmp, - fmt::Debug, - io, iter, - time::{Duration, SystemTime, UNIX_EPOCH}, + fmt::{self, Debug}, + iter, mem, + time::{Duration, Instant, SystemTime, UNIX_EPOCH}, }; -use std::{fmt, mem}; include!(concat!(env!("OUT_DIR"), "/zed.messages.rs")); @@ -516,16 +510,6 @@ entity_messages!( UpdateChannelBufferCollaborators, ); -const KIB: usize = 1024; -const MIB: usize = KIB * 1024; -const MAX_BUFFER_LEN: usize = MIB; - -/// A stream of protobuf messages. -pub struct MessageStream { - stream: S, - encoding_buffer: Vec, -} - #[allow(clippy::large_enum_variant)] #[derive(Debug)] pub enum Message { @@ -534,87 +518,6 @@ pub enum Message { Pong, } -impl MessageStream { - pub fn new(stream: S) -> Self { - Self { - stream, - encoding_buffer: Vec::new(), - } - } - - pub fn inner_mut(&mut self) -> &mut S { - &mut self.stream - } -} - -impl MessageStream -where - S: futures::Sink + Unpin, -{ - pub async fn write(&mut self, message: Message) -> Result<(), anyhow::Error> { - #[cfg(any(test, feature = "test-support"))] - const COMPRESSION_LEVEL: i32 = -7; - - #[cfg(not(any(test, feature = "test-support")))] - const COMPRESSION_LEVEL: i32 = 4; - - match message { - Message::Envelope(message) => { - self.encoding_buffer.reserve(message.encoded_len()); - message - .encode(&mut self.encoding_buffer) - .map_err(io::Error::from)?; - let buffer = - zstd::stream::encode_all(self.encoding_buffer.as_slice(), COMPRESSION_LEVEL) - .unwrap(); - - self.encoding_buffer.clear(); - self.encoding_buffer.shrink_to(MAX_BUFFER_LEN); - self.stream.send(WebSocketMessage::Binary(buffer)).await?; - } - Message::Ping => { - self.stream - .send(WebSocketMessage::Ping(Default::default())) - .await?; - } - Message::Pong => { - self.stream - .send(WebSocketMessage::Pong(Default::default())) - .await?; - } - } - - Ok(()) - } -} - -impl MessageStream -where - S: futures::Stream> + Unpin, -{ - pub async fn read(&mut self) -> Result<(Message, Instant), anyhow::Error> { - while let Some(bytes) = self.stream.next().await { - let received_at = Instant::now(); - match bytes? { - WebSocketMessage::Binary(bytes) => { - zstd::stream::copy_decode(bytes.as_slice(), &mut self.encoding_buffer).unwrap(); - let envelope = Envelope::decode(self.encoding_buffer.as_slice()) - .map_err(io::Error::from)?; - - self.encoding_buffer.clear(); - self.encoding_buffer.shrink_to(MAX_BUFFER_LEN); - return Ok((Message::Envelope(envelope), received_at)); - } - WebSocketMessage::Ping(_) => return Ok((Message::Ping, received_at)), - WebSocketMessage::Pong(_) => return Ok((Message::Pong, received_at)), - WebSocketMessage::Close(_) => break, - _ => {} - } - } - Err(anyhow!("connection closed")) - } -} - impl From for SystemTime { fn from(val: Timestamp) -> Self { UNIX_EPOCH @@ -722,38 +625,6 @@ pub fn split_worktree_update( mod tests { use super::*; - #[gpui::test] - async fn test_buffer_size() { - let (tx, rx) = futures::channel::mpsc::unbounded(); - let mut sink = MessageStream::new(tx.sink_map_err(|_| anyhow!(""))); - sink.write(Message::Envelope(Envelope { - payload: Some(envelope::Payload::UpdateWorktree(UpdateWorktree { - root_name: "abcdefg".repeat(10), - ..Default::default() - })), - ..Default::default() - })) - .await - .unwrap(); - assert!(sink.encoding_buffer.capacity() <= MAX_BUFFER_LEN); - sink.write(Message::Envelope(Envelope { - payload: Some(envelope::Payload::UpdateWorktree(UpdateWorktree { - root_name: "abcdefg".repeat(1000000), - ..Default::default() - })), - ..Default::default() - })) - .await - .unwrap(); - assert!(sink.encoding_buffer.capacity() <= MAX_BUFFER_LEN); - - let mut stream = MessageStream::new(rx.map(anyhow::Ok)); - stream.read().await.unwrap(); - assert!(stream.encoding_buffer.capacity() <= MAX_BUFFER_LEN); - stream.read().await.unwrap(); - assert!(stream.encoding_buffer.capacity() <= MAX_BUFFER_LEN); - } - #[gpui::test] fn test_converting_peer_id_from_and_to_u64() { let peer_id = PeerId { diff --git a/crates/rpc/src/rpc.rs b/crates/rpc/src/rpc.rs index 880102e8d3..c0b23233b3 100644 --- a/crates/rpc/src/rpc.rs +++ b/crates/rpc/src/rpc.rs @@ -2,6 +2,7 @@ pub mod auth; mod conn; mod error; mod extension; +mod message_stream; mod notification; mod peer; pub mod proto; @@ -9,6 +10,7 @@ pub mod proto; pub use conn::Connection; pub use error::*; pub use extension::*; +pub use message_stream::*; pub use notification::*; pub use peer::*; mod macros;