Compare commits

...

1 Commits

Author SHA1 Message Date
Conrad Irwin
e1aeda24ba WIP connection pool 2024-10-23 16:42:09 -06:00
3 changed files with 324 additions and 157 deletions

View File

@@ -0,0 +1,297 @@
use std::{path::PathBuf, sync::Arc};
use anyhow::{anyhow, Result};
use collections::HashMap;
use futures::{channel::{mpsc::{Sender, UnboundedReceiver, UnboundedSender}, oneshot}, AsyncReadExt, FutureExt, StreamExt};
use gpui::{AppContext, AsyncAppContext, BorrowAppContext, Context, Global, Model, Task, WeakModel};
use smol::process::Child;
use rpc::{proto::Envelope, ErrorExt};
use crate::{
protocol::{
message_len_from_buffer, read_message_with_len, write_message, MessageId, MESSAGE_LEN_SIZE,
}, ssh_session::{run_cmd, SshRemoteConnection, SshRemoteProcess, SshSocket}, SshClientDelegate, SshConnectionOptions
};
pub(crate) struct ConnectionPool {
connections: HashMap<SshConnectionOptions, WeakModel<ConnectionState>>,
}
struct ConnectionState {
refcount: usize,
options: SshConnectionOptions,
connecting: Task<()>,
connected: Option<Connected>,
waiters: Vec<oneshot::Sender<Result<()>>>,
};
struct Connected {
connection: SshRemoteConnection,
remote_binary_path: PathBuf,
}
impl ConnectionState {
pub(crate) async fn connect(
&mut self,
unique_identifier: String,
reconnect: bool,
incoming_tx: UnboundedSender<rpc::proto::Envelope>,
outgoing_rx: UnboundedReceiver<Envelope>,
connection_activity_tx: Sender<()>,
delegate: Arc<dyn SshClientDelegate>,
cx: &mut AppContext,
) -> Result<(Box<dyn SshRemoteProcess>, Task<Result<i32>>)> {
let Some(Connected { connection, remote_binary_path }) = connection.connected.as_ref() else {
let (tx, rx) = oneshot::channel();
self.waiters.push(tx);
return cx.spawn(|this, cx| async move {
rx.await?;
this.update(|this, cx| this.connect(
unique_identifier,
reconnect,
incoming_tx,
outgoing_rx,
connection_activity_tx,
delegate,
cx,
))?
})
};
delegate.set_status(Some("Starting proxy"), cx);
let mut start_proxy_command = format!(
"RUST_LOG={} RUST_BACKTRACE={} {:?} proxy --identifier {}",
std::env::var("RUST_LOG").unwrap_or_default(),
std::env::var("RUST_BACKTRACE").unwrap_or_default(),
remote_binary_path,
unique_identifier,
);
if reconnect {
start_proxy_command.push_str(" --reconnect");
}
let ssh_proxy_process = connection.socket
.ssh_command(start_proxy_command)
// IMPORTANT: we kill this process when we drop the task that uses it.
.kill_on_drop(true)
.spawn()
.context("failed to spawn remote server")?;
let io_task = Self::multiplex(
ssh_proxy_process,
incoming_tx,
outgoing_rx,
connection_activity_tx,
&cx,
);
Ok((Box::new(handle) as _, io_task))
}
}
impl Global for ConnectionPool {}
impl ConnectionPool {
pub(crate) fn connection(&mut self, opts: SshConnectionOptions, delegate: &Arc<dyn SshClientDelegate>, cx: &mut AppContext) -> Model<ConnectionState> {
if let Some(connection) = self.connections.get(&opts).and_then(|connection| connection.upgrade()) {
return connection
}
let connection = cx.new_model(|cx| {
ConnectionState {
refcount: 0,
options: opts.clone(),
connecting: Self::create_master_process(opts.clone(), delegate.clone(), &mut cx.to_async()),
connected: None,
waiters: vec![],
}
});
cx.observe_release(&connection, |c, cx| {
cx.update_global(|pool: &mut Self, _| {
pool.connections.remove(&c.options);
});
});
self.connections.insert(opts, connection.downgrade());
connection
}
}
pub(crate) async fn connect(
&mut self,
unique_identifier: String,
reconnect: bool,
connection_options: SshConnectionOptions,
incoming_tx: UnboundedSender<rpc::proto::Envelope>,
outgoing_rx: UnboundedReceiver<Envelope>,
connection_activity_tx: Sender<()>,
delegate: Arc<dyn SshClientDelegate>,
cx: &mut AppContext,
) -> Task<Result<(Box<dyn SshRemoteProcess>, Task<Result<i32>>)>> {
let connection = self.connections.entry(connection_options.clone()).or_insert_with(|| {
cx.new_model(|cx| {
ConnectionState {
refcount: 0,
options: connection_options.clone(),
connecting: Self::create_master_process(connection_options, delegate.clone(), cx),
connected: None,
waiters: vec![],
}
})
});
}
fn create_master_process(
connection_options: SshConnectionOptions,
delegate: Arc<dyn SshClientDelegate>,
cx: &mut AsyncAppContext,
) -> Task<()> {
let task: Task<Result<Connected>> = cx.spawn({
let connection_options = connection_options.clone();
|mut cx| async move {
let ssh_connection = SshRemoteConnection::new(connection_options, delegate.clone(), &mut cx).await?;
let platform = ssh_connection.query_platform().await?;
let remote_binary_path = delegate.remote_server_binary_path(platform, &mut cx)?;
ssh_connection
.ensure_server_binary(&delegate, &remote_binary_path, platform, &mut cx)
.await?;
let socket = ssh_connection.socket.clone();
// do this as part of ensure server binary?
run_cmd(socket.ssh_command(&remote_binary_path).arg("version")).await?;
Ok(Connected{
connection: ssh_connection,
remote_binary_path,
})
}});
cx.spawn(|cx| async move {
let result = task.await;
cx.update_global(|connection_pool: &mut Self, _| {
let Some(connection_state) = connection_pool.connections.get_mut(&connection_options) else {
log::error!("connection dropped while connecting");
return;
};
match result {
Ok(connection) => {
connection_state.connected = Some(connection);
for tx in connection_state.waiters.drain(..) {
tx.send(Ok(())).ok();
}
},
Err(e) => {
for tx in connection_state.waiters.drain(..) {
tx.send(Err(e.cloned())).ok();
}
connection_pool.connections.remove(&connection_options);
}
}
}).ok();
})
}
fn multiplex(
mut ssh_proxy_process: Child,
incoming_tx: UnboundedSender<Envelope>,
mut outgoing_rx: UnboundedReceiver<Envelope>,
mut connection_activity_tx: Sender<()>,
cx: &AsyncAppContext,
) -> Task<Result<i32>> {
let mut child_stderr = ssh_proxy_process.stderr.take().unwrap();
let mut child_stdout = ssh_proxy_process.stdout.take().unwrap();
let mut child_stdin = ssh_proxy_process.stdin.take().unwrap();
let mut stdin_buffer = Vec::new();
let mut stdout_buffer = Vec::new();
let mut stderr_buffer = Vec::new();
let mut stderr_offset = 0;
let stdin_task = cx.background_executor().spawn(async move {
while let Some(outgoing) = outgoing_rx.next().await {
write_message(&mut child_stdin, &mut stdin_buffer, outgoing).await?;
}
anyhow::Ok(())
});
let stdout_task = cx.background_executor().spawn({
let mut connection_activity_tx = connection_activity_tx.clone();
async move {
loop {
stdout_buffer.resize(MESSAGE_LEN_SIZE, 0);
let len = child_stdout.read(&mut stdout_buffer).await?;
if len == 0 {
return anyhow::Ok(());
}
if len < MESSAGE_LEN_SIZE {
child_stdout.read_exact(&mut stdout_buffer[len..]).await?;
}
let message_len = message_len_from_buffer(&stdout_buffer);
let envelope =
read_message_with_len(&mut child_stdout, &mut stdout_buffer, message_len)
.await?;
connection_activity_tx.try_send(()).ok();
incoming_tx.unbounded_send(envelope).ok();
}
}
});
let stderr_task: Task<anyhow::Result<()>> = cx.background_executor().spawn(async move {
loop {
stderr_buffer.resize(stderr_offset + 1024, 0);
let len = child_stderr
.read(&mut stderr_buffer[stderr_offset..])
.await?;
if len == 0 {
return anyhow::Ok(());
}
stderr_offset += len;
let mut start_ix = 0;
while let Some(ix) = stderr_buffer[start_ix..stderr_offset]
.iter()
.position(|b| b == &b'\n')
{
let line_ix = start_ix + ix;
let content = &stderr_buffer[start_ix..line_ix];
start_ix = line_ix + 1;
if let Ok(record) = serde_json::from_slice::<LogRecord>(content) {
record.log(log::logger())
} else {
eprintln!("(remote) {}", String::from_utf8_lossy(content));
}
}
stderr_buffer.drain(0..start_ix);
stderr_offset -= start_ix;
connection_activity_tx.try_send(()).ok();
}
});
cx.spawn(|_| async move {
let result = futures::select! {
result = stdin_task.fuse() => {
result.context("stdin")
}
result = stdout_task.fuse() => {
result.context("stdout")
}
result = stderr_task.fuse() => {
result.context("stderr")
}
};
let status = ssh_proxy_process.status().await?.code().unwrap_or(1);
drop(handle);
match result {
Ok(_) => Ok(status),
Err(error) => Err(error),
}
})
}
}

View File

@@ -1,3 +1,4 @@
pub mod connection_pool;
pub mod json_log;
pub mod protocol;
pub mod proxy;

View File

@@ -1,4 +1,5 @@
use crate::{
connection_pool::ConnectionPool,
json_log::LogRecord,
protocol::{
message_len_from_buffer, read_message_with_len, write_message, MessageId, MESSAGE_LEN_SIZE,
@@ -56,7 +57,7 @@ pub struct SshSocket {
socket_path: PathBuf,
}
#[derive(Debug, Default, Clone, PartialEq, Eq)]
#[derive(Debug, Default, Clone, PartialEq, Eq, Hash)]
pub struct SshConnectionOptions {
pub host: String,
pub username: Option<String>,
@@ -241,7 +242,7 @@ pub trait SshClientDelegate: Send + Sync {
}
impl SshSocket {
fn ssh_command<S: AsRef<OsStr>>(&self, program: S) -> process::Command {
pub(crate) fn ssh_command<S: AsRef<OsStr>>(&self, program: S) -> process::Command {
let mut command = process::Command::new("ssh");
self.ssh_options(&mut command)
.arg(self.connection_options.ssh_url())
@@ -258,7 +259,7 @@ impl SshSocket {
.arg(format!("ControlPath={}", self.socket_path.display()))
}
fn ssh_args(&self) -> Vec<String> {
pub(crate) fn ssh_args(&self) -> Vec<String> {
vec![
"-o".to_string(),
"ControlMaster=no".to_string(),
@@ -269,7 +270,7 @@ impl SshSocket {
}
}
async fn run_cmd(command: &mut process::Command) -> Result<String> {
pub(crate) async fn run_cmd(command: &mut process::Command) -> Result<String> {
let output = command.output().await?;
if output.status.success() {
Ok(String::from_utf8_lossy(&output.stdout).to_string())
@@ -834,108 +835,6 @@ impl SshRemoteClient {
}
}
fn multiplex(
mut ssh_proxy_process: Child,
incoming_tx: UnboundedSender<Envelope>,
mut outgoing_rx: UnboundedReceiver<Envelope>,
mut connection_activity_tx: Sender<()>,
cx: &AsyncAppContext,
) -> Task<Result<i32>> {
let mut child_stderr = ssh_proxy_process.stderr.take().unwrap();
let mut child_stdout = ssh_proxy_process.stdout.take().unwrap();
let mut child_stdin = ssh_proxy_process.stdin.take().unwrap();
let mut stdin_buffer = Vec::new();
let mut stdout_buffer = Vec::new();
let mut stderr_buffer = Vec::new();
let mut stderr_offset = 0;
let stdin_task = cx.background_executor().spawn(async move {
while let Some(outgoing) = outgoing_rx.next().await {
write_message(&mut child_stdin, &mut stdin_buffer, outgoing).await?;
}
anyhow::Ok(())
});
let stdout_task = cx.background_executor().spawn({
let mut connection_activity_tx = connection_activity_tx.clone();
async move {
loop {
stdout_buffer.resize(MESSAGE_LEN_SIZE, 0);
let len = child_stdout.read(&mut stdout_buffer).await?;
if len == 0 {
return anyhow::Ok(());
}
if len < MESSAGE_LEN_SIZE {
child_stdout.read_exact(&mut stdout_buffer[len..]).await?;
}
let message_len = message_len_from_buffer(&stdout_buffer);
let envelope =
read_message_with_len(&mut child_stdout, &mut stdout_buffer, message_len)
.await?;
connection_activity_tx.try_send(()).ok();
incoming_tx.unbounded_send(envelope).ok();
}
}
});
let stderr_task: Task<anyhow::Result<()>> = cx.background_executor().spawn(async move {
loop {
stderr_buffer.resize(stderr_offset + 1024, 0);
let len = child_stderr
.read(&mut stderr_buffer[stderr_offset..])
.await?;
if len == 0 {
return anyhow::Ok(());
}
stderr_offset += len;
let mut start_ix = 0;
while let Some(ix) = stderr_buffer[start_ix..stderr_offset]
.iter()
.position(|b| b == &b'\n')
{
let line_ix = start_ix + ix;
let content = &stderr_buffer[start_ix..line_ix];
start_ix = line_ix + 1;
if let Ok(record) = serde_json::from_slice::<LogRecord>(content) {
record.log(log::logger())
} else {
eprintln!("(remote) {}", String::from_utf8_lossy(content));
}
}
stderr_buffer.drain(0..start_ix);
stderr_offset -= start_ix;
connection_activity_tx.try_send(()).ok();
}
});
cx.spawn(|_| async move {
let result = futures::select! {
result = stdin_task.fuse() => {
result.context("stdin")
}
result = stdout_task.fuse() => {
result.context("stdout")
}
result = stderr_task.fuse() => {
result.context("stderr")
}
};
let status = ssh_proxy_process.status().await?.code().unwrap_or(1);
match result {
Ok(_) => Ok(status),
Err(error) => Err(error),
}
})
}
fn monitor(
this: WeakModel<Self>,
io_task: Task<Result<i32>>,
@@ -1029,49 +928,19 @@ impl SshRemoteClient {
return Ok((fake, io_task));
}
let ssh_connection =
SshRemoteConnection::new(connection_options, delegate.clone(), cx).await?;
let platform = ssh_connection.query_platform().await?;
let remote_binary_path = delegate.remote_server_binary_path(platform, cx)?;
if !reconnect {
ssh_connection
.ensure_server_binary(&delegate, &remote_binary_path, platform, cx)
.await?;
}
let socket = ssh_connection.socket.clone();
run_cmd(socket.ssh_command(&remote_binary_path).arg("version")).await?;
delegate.set_status(Some("Starting proxy"), cx);
let mut start_proxy_command = format!(
"RUST_LOG={} RUST_BACKTRACE={} {:?} proxy --identifier {}",
std::env::var("RUST_LOG").unwrap_or_default(),
std::env::var("RUST_BACKTRACE").unwrap_or_default(),
remote_binary_path,
unique_identifier,
);
if reconnect {
start_proxy_command.push_str(" --reconnect");
}
let ssh_proxy_process = socket
.ssh_command(start_proxy_command)
// IMPORTANT: we kill this process when we drop the task that uses it.
.kill_on_drop(true)
.spawn()
.context("failed to spawn remote server")?;
let io_task = Self::multiplex(
ssh_proxy_process,
incoming_tx,
outgoing_rx,
connection_activity_tx,
&cx,
);
Ok((Box::new(ssh_connection), io_task))
cx.update_global(|pool: &mut ConnectionPool, _| {
pool.connect(
unique_identifier,
reconnect,
connection_options,
incoming_tx,
outgoing_rx,
connection_activity_tx,
delegate,
cx,
)
})?
.await
}
pub fn subscribe_to_entity<E: 'static>(&self, remote_id: u64, entity: &Model<E>) {
@@ -1172,16 +1041,16 @@ impl From<SshRemoteClient> for AnyProtoClient {
}
#[async_trait]
trait SshRemoteProcess: Send + Sync {
pub(crate) trait SshRemoteProcess {
async fn kill(&mut self) -> Result<()>;
fn ssh_args(&self) -> Vec<String>;
fn connection_options(&self) -> SshConnectionOptions;
}
struct SshRemoteConnection {
socket: SshSocket,
master_process: process::Child,
_temp_dir: TempDir,
pub(crate) struct SshRemoteConnection {
pub(crate) socket: SshSocket,
pub(crate) master_process: process::Child,
pub(crate) _temp_dir: TempDir,
}
impl Drop for SshRemoteConnection {
@@ -1222,7 +1091,7 @@ impl SshRemoteConnection {
}
#[cfg(unix)]
async fn new(
pub(crate) async fn new(
connection_options: SshConnectionOptions,
delegate: Arc<dyn SshClientDelegate>,
cx: &mut AsyncAppContext,
@@ -1358,7 +1227,7 @@ impl SshRemoteConnection {
})
}
async fn ensure_server_binary(
pub(crate) async fn ensure_server_binary(
&self,
delegate: &Arc<dyn SshClientDelegate>,
dst_path: &Path,
@@ -1621,7 +1490,7 @@ impl SshRemoteConnection {
Ok(())
}
async fn query_platform(&self) -> Result<SshPlatform> {
pub(crate) async fn query_platform(&self) -> Result<SshPlatform> {
let os = run_cmd(self.socket.ssh_command("uname").arg("-s")).await?;
let arch = run_cmd(self.socket.ssh_command("uname").arg("-m")).await?;