From 3187f284053105a807f6600a92aca6202adfdc84 Mon Sep 17 00:00:00 2001 From: Nathan Sobo Date: Sun, 20 Apr 2025 17:28:44 -0600 Subject: [PATCH] Checkpoint --- Cargo.lock | 4 + crates/agent_thread/Cargo.toml | 5 + crates/agent_thread/src/agent_thread.rs | 409 +++++++++++------- crates/assistant_tool/src/assistant_tool.rs | 4 +- crates/assistant_tools/src/assistant_tools.rs | 1 + crates/language_model/src/request.rs | 18 + 6 files changed, 288 insertions(+), 153 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 0f1c32abe5..9e90c7a38f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -133,9 +133,12 @@ name = "agent_thread" version = "0.1.0" dependencies = [ "anyhow", + "assistant_tool", + "assistant_tools", "chrono", "client", "collections", + "env_logger 0.11.8", "fs", "futures 0.3.31", "gpui", @@ -143,6 +146,7 @@ dependencies = [ "language_model", "language_models", "parking_lot", + "project", "reqwest_client", "schemars", "serde", diff --git a/crates/agent_thread/Cargo.toml b/crates/agent_thread/Cargo.toml index b50450c497..0d396ff5ea 100644 --- a/crates/agent_thread/Cargo.toml +++ b/crates/agent_thread/Cargo.toml @@ -13,6 +13,8 @@ workspace = true [dependencies] anyhow.workspace = true +assistant_tool.workspace = true +assistant_tools.workspace = true chrono.workspace = true collections.workspace = true fs.workspace = true @@ -21,6 +23,7 @@ gpui.workspace = true language_model.workspace = true language_models.workspace = true parking_lot.workspace = true +project.workspace = true schemars.workspace = true serde.workspace = true serde_json.workspace = true @@ -31,9 +34,11 @@ util.workspace = true [dev-dependencies] client = { workspace = true, "features" = ["test-support"] } +env_logger.workspace = true fs = { workspace = true, "features" = ["test-support"] } gpui = { workspace = true, "features" = ["test-support"] } gpui_tokio.workspace = true language_model = { workspace = true, "features" = ["test-support"] } +project = { workspace = true, "features" = ["test-support"] } reqwest_client.workspace = true settings = { workspace = true, "features" = ["test-support"] } diff --git a/crates/agent_thread/src/agent_thread.rs b/crates/agent_thread/src/agent_thread.rs index eafa4bf5ae..1dcbb75a12 100644 --- a/crates/agent_thread/src/agent_thread.rs +++ b/crates/agent_thread/src/agent_thread.rs @@ -1,196 +1,303 @@ -use futures::channel::oneshot; -use gpui::{Context, Task}; +#[cfg(test)] +mod tests; + +use anyhow::Result; +use assistant_tool::{ActionLog, Tool}; +use futures::{channel::mpsc, stream::FuturesUnordered}; +use gpui::{Context, Entity, Task}; use language_model::{ LanguageModel, LanguageModelCompletionEvent, LanguageModelRequest, LanguageModelRequestMessage, - MessageContent, Role, + LanguageModelRequestTool, LanguageModelToolResult, LanguageModelToolSchemaFormat, + LanguageModelToolUse, LanguageModelToolUseId, MessageContent, Role, StopReason, }; +use project::Project; use smol::stream::StreamExt; -use std::{future::Future, sync::Arc}; +use std::{collections::BTreeMap, sync::Arc}; use util::ResultExt; -pub struct ThreadMessage { +#[derive(Debug)] +pub struct AgentMessage { pub role: Role, pub content: Vec, } -pub struct Thread { - messages: Vec, - streaming_completion: Option>>, +impl AgentMessage { + fn to_request_message(&self) -> LanguageModelRequestMessage { + LanguageModelRequestMessage { + role: self.role, + content: self.content.clone(), + cache: false, // TODO: Figure out caching + } + } } -impl Thread { - pub fn new() -> Self { +pub type AgentResponseEvent = LanguageModelCompletionEvent; + +pub struct AgentThread { + sent: Vec, + unsent: Vec, + streaming: Option>>, + tools: BTreeMap, Arc>, + project: Entity, + action_log: Entity, +} + +impl AgentThread { + pub fn new(project: Entity, action_log: Entity) -> Self { Self { - messages: Vec::new(), - streaming_completion: None, + sent: Vec::new(), + unsent: Vec::new(), + streaming: None, + tools: BTreeMap::default(), + project, + action_log, } } - pub fn push_user_message(&mut self, text: impl Into, cx: &mut Context) { - self.messages.push(ThreadMessage { - role: Role::User, - content: vec![MessageContent::Text(text.into())], - }); - - cx.notify(); + pub fn add_tool(&mut self, tool: Arc) { + let name = Arc::from(tool.name()); + self.tools.insert(name, tool); } - pub fn stream_completion( + /// Cancels in-flight streaming, aborting any pending tool calls. + pub fn cancel_streaming(&mut self, cx: &mut Context) -> bool { + self.unsent.clear(); + self.streaming.take().is_some() + } + + /// Sending a message results in the model streaming a response, which could include tool calls. + /// After calling tools, the model will stops and waits for any outstanding tool calls to be completed and their results sent. + /// The returned channel will report all the occurrences in which the model stops before erroring or ending its turn. + pub fn send( &mut self, model: Arc, + content: impl Into, cx: &mut Context, - ) -> impl Future { - let request = self.to_completion_request(); - let (done_tx, done_rx) = futures::channel::oneshot::channel(); - let mut done_tx = Some(done_tx); - self.streaming_completion = Some( - cx.spawn(async move |thread, cx| { - let mut events = model.stream_completion(request, cx).await?; + ) -> mpsc::UnboundedReceiver> { + self.cancel_streaming(cx); + let events = mpsc::unbounded(); + self.enqueue_unsent(model, content, events.0, cx); + events.1 + } - while let Some(event) = events.next().await { - if let Some(event) = event.log_err() { - thread - .update(cx, |thread, cx| { - thread.handle_streamed_event(event, &mut done_tx, cx) - }) - .ok(); + /// Internal method which is called by send and also any tool call results. + /// If currently streaming a completion, these events will be sent when the streaming stops. + fn enqueue_unsent( + &mut self, + model: Arc, + content: impl Into, + events_tx: mpsc::UnboundedSender>, + cx: &mut Context, + ) { + cx.notify(); + + self.unsent.push(AgentMessage { + role: Role::User, + content: vec![content.into()], + }); + if !dbg!(self.streaming.is_some()) { + self.flush_unsent_messages(model, events_tx, cx) + } + } + + fn flush_unsent_messages( + &mut self, + model: Arc, + events_tx: mpsc::UnboundedSender>, + cx: &mut Context, + ) { + cx.notify(); + self.streaming = Some( + cx.spawn(async move |thread, cx| { + let mut subtasks = FuturesUnordered::new(); + + // Perform completion requests until the unsent messages are empty. + loop { + let unsent = + thread.update(cx, |thread, _cx| std::mem::take(&mut thread.unsent))?; + + if unsent.is_empty() { + thread.update(cx, |thread, _cx| thread.streaming.take())?; + break; } + + let completion_request = thread.update(cx, |thread, _cx| { + thread.sent.extend(unsent); + thread.build_completion_request() + })?; + + let mut events = model.stream_completion(completion_request, cx).await?; + while let Some(event) = events.next().await { + match event { + Ok(event) => { + thread + .update(cx, |thread, cx| { + let subtask = thread.handle_stream_event( + &model, + event, + events_tx.clone(), + cx, + ); + subtasks.extend(subtask); + }) + .ok(); + } + Err(error) => { + events_tx.unbounded_send(Err(error)).ok(); + break; + } + } + } + + // Wait for any tasks we spawned to enqueue tool results before looping again. + subtasks.next().await; } anyhow::Ok(()) }) .log_err_in_task(cx), ); + } - cx.notify(); - async move { - done_rx.await.ok(); + fn handle_stream_event( + &mut self, + model: &Arc, + event: LanguageModelCompletionEvent, + events_tx: mpsc::UnboundedSender>, + cx: &mut Context, + ) -> Option> { + use LanguageModelCompletionEvent::*; + events_tx.unbounded_send(Ok(event.clone())).ok(); + + match dbg!(event) { + Text(new_text) => self.handle_text_event(new_text, cx), + Thinking { text, signature } => { + dbg!(text, signature); + } + ToolUse(tool_use) => { + return self.handle_tool_use_event(model.clone(), tool_use, events_tx, cx); + } + StartMessage { message_id, role } => { + self.sent.push(AgentMessage { + role, + content: Vec::new(), + }); + } + UsageUpdate(token_usage) => {} + Stop(stop_reason) => {} + } + + None + } + + fn handle_text_event(&mut self, new_text: String, cx: &mut Context) { + if let Some(last_message) = self.sent.last_mut() { + debug_assert!(last_message.role == Role::Assistant); + if let Some(MessageContent::Text(text)) = last_message.content.last_mut() { + text.push_str(&new_text); + } else { + last_message.content.push(MessageContent::Text(new_text)); + } + + cx.notify(); + } else { + todo!("does this happen in practice?"); } } - fn to_completion_request(&self) -> LanguageModelRequest { + fn handle_tool_use_event( + &mut self, + model: Arc, + tool_use: LanguageModelToolUse, + events_tx: mpsc::UnboundedSender>, + cx: &mut Context, + ) -> Option> { + if let Some(last_message) = self.sent.last_mut() { + debug_assert!(last_message.role == Role::Assistant); + last_message.content.push(tool_use.clone().into()); + cx.notify(); + } else { + todo!("does this happen in practice?"); + } + + if let Some(tool) = self.tools.get(&tool_use.name) { + let pending_tool_result = tool.clone().run( + tool_use.input, + &self.build_request_messages(), + self.project.clone(), + self.action_log.clone(), + cx, + ); + + Some(cx.spawn(async move |thread, cx| { + let tool_result = match pending_tool_result.output.await { + Ok(tool_output) => LanguageModelToolResult { + tool_use_id: tool_use.id, + tool_name: tool_use.name, + is_error: false, + content: Arc::from(tool_output), + }, + Err(error) => LanguageModelToolResult { + tool_use_id: tool_use.id, + tool_name: tool_use.name, + is_error: true, + content: Arc::from(error.to_string()), + }, + }; + + thread + .update(cx, |thread, cx| { + thread.enqueue_unsent(model, tool_result, events_tx, cx) + }) + .ok(); + })) + } else { + self.enqueue_unsent( + model, + LanguageModelToolResult { + tool_use_id: tool_use.id, + tool_name: tool_use.name, + is_error: true, + content: Arc::from("tool does not exist"), + }, + events_tx, + cx, + ); + None + } + } + + fn build_completion_request(&self) -> LanguageModelRequest { LanguageModelRequest { thread_id: None, prompt_id: None, - messages: self - .messages - .iter() - .map(|message| LanguageModelRequestMessage { - role: message.role, - content: message.content.clone(), - cache: false, + messages: self.build_request_messages(), + tools: self + .tools + .values() + .filter_map(|tool| { + Some(LanguageModelRequestTool { + name: tool.name(), + description: tool.description(), + input_schema: tool + .input_schema(LanguageModelToolSchemaFormat::JsonSchema) + .log_err()?, + }) }) .collect(), - tools: Vec::new(), stop: Vec::new(), temperature: None, } } - fn handle_streamed_event( - &mut self, - event: LanguageModelCompletionEvent, - done_tx: &mut Option>, - cx: &mut Context, - ) { - use LanguageModelCompletionEvent::*; - - match event { - Stop(stop_reason) => { - done_tx.take().map(|tx| tx.send(())); - } - Text(new_text) => { - if let Some(last_message) = self.messages.last_mut() { - debug_assert!(last_message.role == Role::Assistant); - if let Some(MessageContent::Text(text)) = last_message.content.last_mut() { - text.push_str(&new_text); - } else { - last_message.content.push(MessageContent::Text(new_text)); - } - - cx.notify(); - } else { - todo!("does this happen in practice?") - } - } - Thinking { text, signature } => { - dbg!(text, signature); - } - ToolUse(language_model_tool_use) => { - dbg!(language_model_tool_use); - } - StartMessage { message_id, role } => { - dbg!(message_id, role); - - self.messages.push(ThreadMessage { - role, - content: Vec::new(), - }); - } - UsageUpdate(token_usage) => { - dbg!(token_usage); - } - } - } -} - -#[cfg(test)] -mod tests { - use super::*; - use client::{Client, UserStore}; - use fs::FakeFs; - use gpui::{App, AppContext, TestAppContext}; - use language_model::LanguageModelRegistry; - use reqwest_client::ReqwestClient; - - #[gpui::test] - async fn test_basic_threads(cx: &mut TestAppContext) { - let model = init_test(cx).await; - let thread = cx.new(|_cx| Thread::new()); - - thread - .update(cx, |thread, cx| { - thread.push_user_message("Testing: Reply with 'Hello'", cx); - thread.stream_completion(model, cx) - }) - .await; - - thread.update(cx, |thread, _cx| { - assert_eq!( - thread.messages.last().unwrap().content, - vec![MessageContent::Text("Hello".to_string())] - ); - }); - } - - fn init_test(cx: &mut TestAppContext) -> Task> { - cx.executor().allow_parking(); - cx.update(|cx| { - gpui_tokio::init(cx); - let http_client = ReqwestClient::user_agent("agent thread tests").unwrap(); - cx.set_http_client(Arc::new(http_client)); - - settings::init(cx); - client::init_settings(cx); - let client = Client::production(cx); - let user_store = cx.new(|cx| UserStore::new(client.clone(), cx)); - let fs = FakeFs::new(cx.background_executor().clone()); - language_model::init(client.clone(), cx); - language_models::init(user_store.clone(), client.clone(), fs.clone(), cx); - - let registry = LanguageModelRegistry::read_global(cx); - let model = registry - .available_models(cx) - .find(|model| model.id().0 == "claude-3-7-sonnet-latest") - .unwrap(); - - let provider = registry.provider(&model.provider_id()).unwrap(); - let authenticated = provider.authenticate(cx); - - cx.spawn(async move |_cx| { - authenticated.await.unwrap(); - model - }) - }) + fn build_request_messages(&self) -> Vec { + self.sent + .iter() + .map(|message| LanguageModelRequestMessage { + role: message.role, + content: message.content.clone(), + cache: false, + }) + .collect() } } diff --git a/crates/assistant_tool/src/assistant_tool.rs b/crates/assistant_tool/src/assistant_tool.rs index cb7f0ff518..f948851493 100644 --- a/crates/assistant_tool/src/assistant_tool.rs +++ b/crates/assistant_tool/src/assistant_tool.rs @@ -14,10 +14,10 @@ use gpui::Context; use gpui::IntoElement; use gpui::Window; use gpui::{App, Entity, SharedString, Task}; -use icons::IconName; +pub use icons::IconName; use language_model::LanguageModelRequestMessage; use language_model::LanguageModelToolSchemaFormat; -use project::Project; +pub use project::Project; pub use crate::action_log::*; pub use crate::tool_registry::*; diff --git a/crates/assistant_tools/src/assistant_tools.rs b/crates/assistant_tools/src/assistant_tools.rs index 92e388407c..78521518b8 100644 --- a/crates/assistant_tools/src/assistant_tools.rs +++ b/crates/assistant_tools/src/assistant_tools.rs @@ -54,6 +54,7 @@ use crate::rename_tool::RenameTool; use crate::symbol_info_tool::SymbolInfoTool; use crate::terminal_tool::TerminalTool; use crate::thinking_tool::ThinkingTool; +pub use schema::json_schema_for; pub fn init(http_client: Arc, cx: &mut App) { assistant_tool::init(cx); diff --git a/crates/language_model/src/request.rs b/crates/language_model/src/request.rs index 0f1e97af5a..f6d1b302fb 100644 --- a/crates/language_model/src/request.rs +++ b/crates/language_model/src/request.rs @@ -197,6 +197,24 @@ impl From<&str> for MessageContent { } } +impl From for MessageContent { + fn from(value: LanguageModelToolUse) -> Self { + MessageContent::ToolUse(value) + } +} + +impl From for MessageContent { + fn from(value: LanguageModelImage) -> Self { + MessageContent::Image(value) + } +} + +impl From for MessageContent { + fn from(value: LanguageModelToolResult) -> Self { + MessageContent::ToolResult(value) + } +} + #[derive(Clone, Serialize, Deserialize, Debug, PartialEq, Hash)] pub struct LanguageModelRequestMessage { pub role: Role,