From 36271b79b3ff5a8fee495a674601ddfec7c66848 Mon Sep 17 00:00:00 2001 From: Nathan Sobo Date: Sun, 20 Apr 2025 19:04:37 -0600 Subject: [PATCH] Failing test proving we need to batch tools per message --- Cargo.lock | 3 +- Cargo.toml | 2 +- crates/agent_thread/Cargo.toml | 44 --- crates/agent_thread/src/agent_thread.rs | 340 ------------------------ 4 files changed, 3 insertions(+), 386 deletions(-) delete mode 100644 crates/agent_thread/Cargo.toml delete mode 100644 crates/agent_thread/src/agent_thread.rs diff --git a/Cargo.lock b/Cargo.lock index 9e90c7a38f..620d877260 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -129,7 +129,7 @@ dependencies = [ ] [[package]] -name = "agent_thread" +name = "agent2" version = "0.1.0" dependencies = [ "anyhow", @@ -138,6 +138,7 @@ dependencies = [ "chrono", "client", "collections", + "ctor", "env_logger 0.11.8", "fs", "futures 0.3.31", diff --git a/Cargo.toml b/Cargo.toml index 990ca52f11..5163893735 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -3,7 +3,7 @@ resolver = "2" members = [ "crates/activity_indicator", "crates/agent", - "crates/agent_thread", + "crates/agent2", "crates/anthropic", "crates/askpass", "crates/assets", diff --git a/crates/agent_thread/Cargo.toml b/crates/agent_thread/Cargo.toml deleted file mode 100644 index 0d396ff5ea..0000000000 --- a/crates/agent_thread/Cargo.toml +++ /dev/null @@ -1,44 +0,0 @@ -[package] -name = "agent_thread" -version = "0.1.0" -edition = "2021" -license = "GPL-3.0-or-later" -publish = false - -[lib] -path = "src/agent_thread.rs" - -[lints] -workspace = true - -[dependencies] -anyhow.workspace = true -assistant_tool.workspace = true -assistant_tools.workspace = true -chrono.workspace = true -collections.workspace = true -fs.workspace = true -futures.workspace = true -gpui.workspace = true -language_model.workspace = true -language_models.workspace = true -parking_lot.workspace = true -project.workspace = true -schemars.workspace = true -serde.workspace = true -serde_json.workspace = true -settings.workspace = true -smol.workspace = true -thiserror.workspace = true -util.workspace = true - -[dev-dependencies] -client = { workspace = true, "features" = ["test-support"] } -env_logger.workspace = true -fs = { workspace = true, "features" = ["test-support"] } -gpui = { workspace = true, "features" = ["test-support"] } -gpui_tokio.workspace = true -language_model = { workspace = true, "features" = ["test-support"] } -project = { workspace = true, "features" = ["test-support"] } -reqwest_client.workspace = true -settings = { workspace = true, "features" = ["test-support"] } diff --git a/crates/agent_thread/src/agent_thread.rs b/crates/agent_thread/src/agent_thread.rs deleted file mode 100644 index 89024c9086..0000000000 --- a/crates/agent_thread/src/agent_thread.rs +++ /dev/null @@ -1,340 +0,0 @@ -#[cfg(test)] -mod tests; - -use anyhow::Result; -use assistant_tool::{ActionLog, Tool}; -use futures::{channel::mpsc, stream::FuturesUnordered}; -use gpui::{Context, Entity, Task}; -use language_model::{ - LanguageModel, LanguageModelCompletionEvent, LanguageModelRequest, LanguageModelRequestMessage, - LanguageModelRequestTool, LanguageModelToolResult, LanguageModelToolSchemaFormat, - LanguageModelToolUse, LanguageModelToolUseId, MessageContent, Role, StopReason, -}; -use project::Project; -use smol::stream::StreamExt; -use std::{collections::BTreeMap, sync::Arc}; -use util::ResultExt; - -#[derive(Debug)] -pub struct AgentMessage { - pub role: Role, - pub content: Vec, -} - -impl AgentMessage { - fn to_request_message(&self) -> LanguageModelRequestMessage { - LanguageModelRequestMessage { - role: self.role, - content: self.content.clone(), - cache: false, // TODO: Figure out caching - } - } -} - -pub type AgentResponseEvent = LanguageModelCompletionEvent; - -pub struct AgentThread { - sent: Vec, - unsent: Vec, - /// Holds the task that handles agent interaction until the end of the turn. - /// Survives across multiple requests as the model performs tool calls and - /// we run tools, report their results. - in_progress_turn: Option>>, - /// True when we're actively streaming data from a model completion request. - /// Within the same turn, this could be true, then false while we run a - /// tool, then true again as we relay the tool result and continue the turn. - streaming: bool, - tools: BTreeMap, Arc>, - project: Entity, - action_log: Entity, -} - -impl AgentThread { - pub fn new(project: Entity, action_log: Entity) -> Self { - Self { - sent: Vec::new(), - unsent: Vec::new(), - in_progress_turn: None, - streaming: false, - tools: BTreeMap::default(), - project, - action_log, - } - } - - pub fn add_tool(&mut self, tool: Arc) { - let name = Arc::from(tool.name()); - self.tools.insert(name, tool); - } - - pub fn remove_tool(&mut self, name: &str) -> bool { - self.tools.remove(name).is_some() - } - - /// Cancels in-flight streaming, aborting any pending tool calls. - pub fn cancel_turn(&mut self, cx: &mut Context) -> bool { - cx.notify(); - self.streaming = false; - self.in_progress_turn.take().is_some() - } - - /// Sending a message results in the model streaming a response, which could include tool calls. - /// After calling tools, the model will stops and waits for any outstanding tool calls to be completed and their results sent. - /// The returned channel will report all the occurrences in which the model stops before erroring or ending its turn. - pub fn send( - &mut self, - model: Arc, - content: impl Into, - cx: &mut Context, - ) -> mpsc::UnboundedReceiver> { - self.cancel_turn(cx); - let events = mpsc::unbounded(); - self.enqueue_unsent(model, content, events.0, cx); - events.1 - } - - /// Internal method which is called by send and also any tool call results. - /// If currently streaming a completion, these events will be sent when the streaming stops. - fn enqueue_unsent( - &mut self, - model: Arc, - content: impl Into, - events_tx: mpsc::UnboundedSender>, - cx: &mut Context, - ) { - cx.notify(); - - self.unsent.push(AgentMessage { - role: Role::User, - content: vec![content.into()], - }); - if !dbg!(self.streaming) { - self.resume_turn(model, events_tx, cx) - } - } - - fn resume_turn( - &mut self, - model: Arc, - events_tx: mpsc::UnboundedSender>, - cx: &mut Context, - ) { - cx.notify(); - self.in_progress_turn = Some( - cx.spawn(async move |thread, cx| { - let mut subtasks = FuturesUnordered::new(); - - // Perform completion requests until the unsent messages are empty. - loop { - let unsent = - thread.update(cx, |thread, _cx| std::mem::take(&mut thread.unsent))?; - - if unsent.is_empty() { - break; - } - - let completion_request = thread.update(cx, |thread, _cx| { - thread.sent.extend(unsent); - thread.streaming = true; - thread.build_completion_request() - })?; - - let mut events = model.stream_completion(completion_request, cx).await?; - while let Some(event) = events.next().await { - match event { - Ok(event) => { - thread - .update(cx, |thread, cx| { - let subtask = thread.handle_stream_event( - &model, - event, - events_tx.clone(), - cx, - ); - subtasks.extend(subtask); - }) - .ok(); - } - Err(error) => { - events_tx.unbounded_send(Err(error)).ok(); - break; - } - } - } - - thread - .update(cx, |thread, cx| { - thread.streaming = false; - }) - .ok(); - - // Wait for any tasks we spawned to enqueue tool results before looping again. - subtasks.next().await; - } - - anyhow::Ok(()) - }) - .log_err_in_task(cx), - ); - } - - fn handle_response_event( - &mut self, - model: &Arc, - event: LanguageModelCompletionEvent, - events_tx: mpsc::UnboundedSender>, - cx: &mut Context, - ) -> Option> { - use LanguageModelCompletionEvent::*; - events_tx.unbounded_send(Ok(event.clone())).ok(); - - match dbg!(event) { - Text(new_text) => self.handle_text_event(new_text, cx), - Thinking { text, signature } => { - dbg!(text, signature); - } - ToolUse(tool_use) => { - return self.handle_tool_use_event(model.clone(), tool_use, events_tx, cx); - } - StartMessage { message_id, role } => { - self.sent.push(AgentMessage { - role, - content: Vec::new(), - }); - } - UsageUpdate(token_usage) => {} - Stop(stop_reason) => self.handle_stop_event(stop_reason), - } - - None - } - - fn handle_stop_event(&mut self, stop_reason: StopReason) { - match stop_reason { - StopReason::EndTurn | StopReason::ToolUse => {} - StopReason::MaxTokens => todo!(), - StopReason::ToolUse => todo!(), - } - } - - fn handle_stream_event( - &mut self, - model: &Arc, - event: LanguageModelCompletionEvent, - events_tx: mpsc::UnboundedSender>, - cx: &mut Context, - ) -> Option> { - self.handle_response_event(model, event, events_tx, cx) - } - - fn handle_text_event(&mut self, new_text: String, cx: &mut Context) { - if let Some(last_message) = self.sent.last_mut() { - debug_assert!(last_message.role == Role::Assistant); - if let Some(MessageContent::Text(text)) = last_message.content.last_mut() { - text.push_str(&new_text); - } else { - last_message.content.push(MessageContent::Text(new_text)); - } - - cx.notify(); - } else { - todo!("does this happen in practice?"); - } - } - - fn handle_tool_use_event( - &mut self, - model: Arc, - tool_use: LanguageModelToolUse, - events_tx: mpsc::UnboundedSender>, - cx: &mut Context, - ) -> Option> { - if let Some(last_message) = self.sent.last_mut() { - debug_assert!(last_message.role == Role::Assistant); - last_message.content.push(tool_use.clone().into()); - cx.notify(); - } else { - todo!("does this happen in practice?"); - } - - if let Some(tool) = self.tools.get(&tool_use.name) { - let pending_tool_result = tool.clone().run( - tool_use.input, - &self.build_request_messages(), - self.project.clone(), - self.action_log.clone(), - cx, - ); - - Some(cx.spawn(async move |thread, cx| { - let tool_result = match pending_tool_result.output.await { - Ok(tool_output) => LanguageModelToolResult { - tool_use_id: tool_use.id, - tool_name: tool_use.name, - is_error: false, - content: Arc::from(tool_output), - }, - Err(error) => LanguageModelToolResult { - tool_use_id: tool_use.id, - tool_name: tool_use.name, - is_error: true, - content: Arc::from(error.to_string()), - }, - }; - - thread - .update(cx, |thread, cx| { - thread.enqueue_unsent(model, tool_result, events_tx, cx) - }) - .ok(); - })) - } else { - self.enqueue_unsent( - model, - LanguageModelToolResult { - tool_use_id: tool_use.id, - tool_name: tool_use.name, - is_error: true, - content: Arc::from("tool does not exist"), - }, - events_tx, - cx, - ); - None - } - } - - fn build_completion_request(&self) -> LanguageModelRequest { - LanguageModelRequest { - thread_id: None, - prompt_id: None, - messages: self.build_request_messages(), - tools: self - .tools - .values() - .filter_map(|tool| { - Some(LanguageModelRequestTool { - name: tool.name(), - description: tool.description(), - input_schema: tool - .input_schema(LanguageModelToolSchemaFormat::JsonSchema) - .log_err()?, - }) - }) - .collect(), - stop: Vec::new(), - temperature: None, - } - } - - fn build_request_messages(&self) -> Vec { - self.sent - .iter() - .map(|message| LanguageModelRequestMessage { - role: message.role, - content: message.content.clone(), - cache: false, - }) - .collect() - } -}