Compare commits

...

10 Commits

Author SHA1 Message Date
Nathan Sobo
447eb8e1c9 Checkpoint 2025-04-21 07:08:03 -06:00
Nathan Sobo
e434117018 Checkpoint: Still a failing test for concurrent tool calls.
Seems like I'm surfacing a bug in Anthropic.
2025-04-20 20:50:20 -06:00
Nathan Sobo
36271b79b3 Failing test proving we need to batch tools per message 2025-04-20 19:04:37 -06:00
Nathan Sobo
41644a53cc Checkpoint 2025-04-20 17:56:42 -06:00
Nathan Sobo
08a9c4af09 Checkpoint 2025-04-20 17:54:33 -06:00
Nathan Sobo
3187f28405 Checkpoint 2025-04-20 17:28:44 -06:00
Nathan Sobo
101f3b100f Get a basic request/reply tested in AgentThread 2025-04-20 00:41:03 -06:00
Nathan Sobo
39c8b7bf5f Add agent_thread crate
Experimental for now, I want to try really integration testing it
against the real APIs in a more "eval style", meaning embrace the
stochastic nature of it.
2025-04-20 00:17:38 -06:00
Nathan Sobo
08b41252f6 Include role in start message 2025-04-20 00:16:49 -06:00
Nathan Sobo
152bbca238 Add gpui helpers 2025-04-20 00:16:08 -06:00
15 changed files with 692 additions and 4 deletions

30
Cargo.lock generated
View File

@@ -128,6 +128,36 @@ dependencies = [
"zed_llm_client",
]
[[package]]
name = "agent2"
version = "0.1.0"
dependencies = [
"anyhow",
"assistant_tool",
"assistant_tools",
"chrono",
"client",
"collections",
"ctor",
"env_logger 0.11.8",
"fs",
"futures 0.3.31",
"gpui",
"gpui_tokio",
"language_model",
"language_models",
"parking_lot",
"project",
"reqwest_client",
"schemars",
"serde",
"serde_json",
"settings",
"smol",
"thiserror 2.0.12",
"util",
]
[[package]]
name = "ahash"
version = "0.7.8"

View File

@@ -3,6 +3,7 @@ resolver = "2"
members = [
"crates/activity_indicator",
"crates/agent",
"crates/agent2",
"crates/anthropic",
"crates/askpass",
"crates/assets",

45
crates/agent2/Cargo.toml Normal file
View File

@@ -0,0 +1,45 @@
[package]
name = "agent2"
version = "0.1.0"
edition = "2021"
license = "GPL-3.0-or-later"
publish = false
[lib]
path = "src/agent2.rs"
[lints]
workspace = true
[dependencies]
anyhow.workspace = true
assistant_tool.workspace = true
assistant_tools.workspace = true
chrono.workspace = true
collections.workspace = true
fs.workspace = true
futures.workspace = true
gpui.workspace = true
language_model.workspace = true
language_models.workspace = true
parking_lot.workspace = true
project.workspace = true
schemars.workspace = true
serde.workspace = true
serde_json.workspace = true
settings.workspace = true
smol.workspace = true
thiserror.workspace = true
util.workspace = true
[dev-dependencies]
ctor.workspace = true
client = { workspace = true, "features" = ["test-support"] }
env_logger.workspace = true
fs = { workspace = true, "features" = ["test-support"] }
gpui = { workspace = true, "features" = ["test-support"] }
gpui_tokio.workspace = true
language_model = { workspace = true, "features" = ["test-support"] }
project = { workspace = true, "features" = ["test-support"] }
reqwest_client.workspace = true
settings = { workspace = true, "features" = ["test-support"] }

1
crates/agent2/LICENSE-GPL Symbolic link
View File

@@ -0,0 +1 @@
../../LICENSE-GPL

278
crates/agent2/src/agent2.rs Normal file
View File

@@ -0,0 +1,278 @@
#[cfg(test)]
mod tests;
use anyhow::Result;
use assistant_tool::{ActionLog, Tool};
use futures::{channel::mpsc, future};
use gpui::{Context, Entity, Task};
use language_model::{
LanguageModel, LanguageModelCompletionEvent, LanguageModelRequest, LanguageModelRequestMessage,
LanguageModelRequestTool, LanguageModelToolResult, LanguageModelToolSchemaFormat,
LanguageModelToolUse, MessageContent, Role, StopReason,
};
use project::Project;
use smol::stream::StreamExt;
use std::{collections::BTreeMap, sync::Arc};
use util::ResultExt;
#[derive(Debug)]
pub struct AgentMessage {
pub role: Role,
pub content: Vec<MessageContent>,
}
impl AgentMessage {
fn to_request_message(&self) -> LanguageModelRequestMessage {
LanguageModelRequestMessage {
role: self.role,
content: self.content.clone(),
cache: false, // TODO: Figure out caching
}
}
}
pub type AgentResponseEvent = LanguageModelCompletionEvent;
pub struct Agent {
messages: Vec<AgentMessage>,
/// Holds the task that handles agent interaction until the end of the turn.
/// Survives across multiple requests as the model performs tool calls and
/// we run tools, report their results.
running_turn: Option<Task<()>>,
tools: BTreeMap<Arc<str>, Arc<dyn Tool>>,
project: Entity<Project>,
action_log: Entity<ActionLog>,
}
impl Agent {
pub fn new(project: Entity<Project>, action_log: Entity<ActionLog>) -> Self {
Self {
messages: Vec::new(),
running_turn: None,
tools: BTreeMap::default(),
project,
action_log,
}
}
pub fn add_tool(&mut self, tool: Arc<dyn Tool>) {
let name = Arc::from(tool.name());
self.tools.insert(name, tool);
}
pub fn remove_tool(&mut self, name: &str) -> bool {
self.tools.remove(name).is_some()
}
/// Sending a message results in the model streaming a response, which could include tool calls.
/// After calling tools, the model will stops and waits for any outstanding tool calls to be completed and their results sent.
/// The returned channel will report all the occurrences in which the model stops before erroring or ending its turn.
pub fn send(
&mut self,
model: Arc<dyn LanguageModel>,
content: impl Into<MessageContent>,
cx: &mut Context<Self>,
) -> mpsc::UnboundedReceiver<Result<AgentResponseEvent>> {
cx.notify();
let (events_tx, events_rx) = mpsc::unbounded();
self.messages.push(AgentMessage {
role: Role::User,
content: vec![content.into()],
});
self.running_turn = Some(cx.spawn(async move |thread, cx| {
let turn_result = async {
// Perform one request, then keep looping if the model makes tool calls.
loop {
let request =
thread.update(cx, |thread, _cx| thread.build_completion_request())?;
println!(
"request: {}",
serde_json::to_string_pretty(&request).unwrap()
);
// Stream events, appending to messages and collecting up tool uses.
let mut events = model.stream_completion(request, cx).await?;
let mut tool_uses = Vec::new();
while let Some(event) = events.next().await {
match event {
Ok(event) => {
thread
.update(cx, |thread, cx| {
tool_uses.extend(thread.handle_response_event(
event,
events_tx.clone(),
cx,
));
})
.ok();
}
Err(error) => {
events_tx.unbounded_send(Err(error)).ok();
break;
}
}
}
// If there are no tool uses, the turn is done.
if tool_uses.is_empty() {
break;
}
// If there are tool uses, wait for their results to be
// computed, then send them together in a single message on
// the next loop iteration.
let tool_results = future::join_all(tool_uses).await;
thread
.update(cx, |thread, _cx| {
thread.messages.push(AgentMessage {
role: Role::User,
content: tool_results.into_iter().map(Into::into).collect(),
});
})
.ok();
}
anyhow::Ok(())
}
.await;
if let Err(error) = turn_result {
events_tx.unbounded_send(Err(error)).ok();
}
}));
events_rx
}
fn handle_response_event(
&mut self,
event: LanguageModelCompletionEvent,
events_tx: mpsc::UnboundedSender<Result<AgentResponseEvent>>,
cx: &mut Context<Self>,
) -> Option<Task<LanguageModelToolResult>> {
use LanguageModelCompletionEvent::*;
events_tx.unbounded_send(Ok(event.clone())).ok();
match event {
Text(new_text) => self.handle_text_event(new_text, cx),
Thinking { text, signature } => {}
ToolUse(tool_use) => {
return Some(self.handle_tool_use_event(tool_use, cx));
}
StartMessage { message_id, role } => {
self.messages.push(AgentMessage {
role,
content: Vec::new(),
});
}
UsageUpdate(token_usage) => {}
Stop(stop_reason) => self.handle_stop_event(stop_reason),
}
None
}
fn handle_stop_event(&mut self, stop_reason: StopReason) {
match stop_reason {
StopReason::EndTurn | StopReason::ToolUse => {}
StopReason::MaxTokens => todo!(),
}
}
fn handle_text_event(&mut self, new_text: String, cx: &mut Context<Self>) {
if let Some(last_message) = self.messages.last_mut() {
debug_assert!(last_message.role == Role::Assistant);
if let Some(MessageContent::Text(text)) = last_message.content.last_mut() {
text.push_str(&new_text);
} else {
last_message.content.push(MessageContent::Text(new_text));
}
cx.notify();
} else {
todo!("does this happen in practice?");
}
}
fn handle_tool_use_event(
&mut self,
tool_use: LanguageModelToolUse,
cx: &mut Context<Self>,
) -> Task<LanguageModelToolResult> {
if let Some(last_message) = self.messages.last_mut() {
debug_assert!(last_message.role == Role::Assistant);
last_message.content.push(tool_use.clone().into());
cx.notify();
} else {
todo!("does this happen in practice?");
}
if let Some(tool) = self.tools.get(&tool_use.name) {
let pending_tool_result = tool.clone().run(
tool_use.input,
&self.build_request_messages(),
self.project.clone(),
self.action_log.clone(),
cx,
);
cx.foreground_executor().spawn(async move {
match pending_tool_result.output.await {
Ok(tool_output) => LanguageModelToolResult {
tool_use_id: tool_use.id,
tool_name: tool_use.name,
is_error: false,
content: Arc::from(tool_output),
},
Err(error) => LanguageModelToolResult {
tool_use_id: tool_use.id,
tool_name: tool_use.name,
is_error: true,
content: Arc::from(error.to_string()),
},
}
})
} else {
Task::ready(LanguageModelToolResult {
content: Arc::from(format!("No tool named {} exists", tool_use.name)),
tool_use_id: tool_use.id,
tool_name: tool_use.name,
is_error: true,
})
}
}
fn build_completion_request(&self) -> LanguageModelRequest {
LanguageModelRequest {
thread_id: None,
prompt_id: None,
messages: self.build_request_messages(),
tools: self
.tools
.values()
.filter_map(|tool| {
Some(LanguageModelRequestTool {
name: tool.name(),
description: tool.description(),
input_schema: tool
.input_schema(LanguageModelToolSchemaFormat::JsonSchema)
.log_err()?,
})
})
.collect(),
stop: Vec::new(),
temperature: None,
}
}
fn build_request_messages(&self) -> Vec<LanguageModelRequestMessage> {
self.messages
.iter()
.map(|message| LanguageModelRequestMessage {
role: message.role,
content: message.content.clone(),
cache: false,
})
.collect()
}
}

191
crates/agent2/src/tests.rs Normal file
View File

@@ -0,0 +1,191 @@
use super::*;
use assistant_tool::{IconName, Project, ToolResult};
use client::{Client, UserStore};
use fs::FakeFs;
use gpui::{AppContext, TestAppContext};
use language_model::LanguageModelRegistry;
use reqwest_client::ReqwestClient;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use std::time::Duration;
mod tools;
use tools::*;
#[gpui::test]
async fn test_echo(cx: &mut TestAppContext) {
let AgentTest { model, agent, .. } = agent_test(cx).await;
let events = agent
.update(cx, |agent, cx| {
agent.send(model.clone(), "Testing: Reply with 'Hello'", cx)
})
.collect()
.await;
agent.update(cx, |agent, _cx| {
assert_eq!(
agent.messages.last().unwrap().content,
vec![MessageContent::Text("Hello".to_string())]
);
});
assert_eq!(stop_events(events), vec![StopReason::EndTurn]);
}
#[gpui::test]
async fn test_tool_calls(cx: &mut TestAppContext) {
let AgentTest { model, agent, .. } = agent_test(cx).await;
// Test a tool calls that's likely to complete before streaming stops.
let events = agent
.update(cx, |agent, cx| {
agent.add_tool(Arc::new(EchoTool));
agent.send(
model.clone(),
"Now test the echo tool with 'Hello'. Does it work? Say 'Yes' or 'No'.",
cx,
)
})
.collect()
.await;
assert_eq!(
stop_events(events),
vec![StopReason::ToolUse, StopReason::EndTurn]
);
// Test a tool calls that's likely to complete after streaming stops.
let events = agent
.update(cx, |agent, cx| {
agent.remove_tool(&EchoTool.name());
agent.add_tool(Arc::new(DelayTool));
agent.send(
model.clone(),
"Now call the delay tool with 200ms. When the timer goes off, then you echo the output of the tool.",
cx,
)
})
.collect()
.await;
assert_eq!(
stop_events(events),
vec![StopReason::ToolUse, StopReason::EndTurn]
);
agent.update(cx, |agent, _cx| {
assert!(agent
.messages
.last()
.unwrap()
.content
.iter()
.any(|content| {
if let MessageContent::Text(text) = content {
text.contains("Ding")
} else {
false
}
}));
});
}
#[gpui::test]
async fn test_concurrent_tool_calls(cx: &mut TestAppContext) {
let AgentTest { model, agent, .. } = agent_test(cx).await;
// Test concurrent tool calls with different delay times
let events = agent
.update(cx, |agent, cx| {
agent.add_tool(Arc::new(DelayTool));
agent.send(
model.clone(),
"Call the delay tool twice in the same message. Once with 100ms. Once with 300ms. When both timers are complete, describe the outputs.",
cx,
)
})
.map(|event| dbg!(event))
.collect()
.await;
let stop_reasons = stop_events(events);
assert_eq!(stop_reasons, vec![StopReason::ToolUse, StopReason::EndTurn]);
agent.update(cx, |agent, _cx| {
let last_message = agent.messages.last().unwrap();
let text = last_message
.content
.iter()
.filter_map(|content| {
if let MessageContent::Text(text) = content {
Some(text.as_str())
} else {
None
}
})
.collect::<String>();
assert!(text.contains("Ding"));
});
}
fn stop_events(result_events: Vec<Result<AgentResponseEvent>>) -> Vec<StopReason> {
result_events
.into_iter()
.filter_map(|event| match event.unwrap() {
LanguageModelCompletionEvent::Stop(stop_reason) => Some(stop_reason),
_ => None,
})
.collect()
}
struct AgentTest {
model: Arc<dyn LanguageModel>,
agent: Entity<Agent>,
}
async fn agent_test(cx: &mut TestAppContext) -> AgentTest {
cx.executor().allow_parking();
cx.update(settings::init);
let fs = FakeFs::new(cx.executor().clone());
let project = Project::test(fs.clone(), [], cx).await;
let action_log = cx.new(|_| ActionLog::new(project.clone()));
let agent = cx.new(|_| Agent::new(project.clone(), action_log.clone()));
let model = cx
.update(|cx| {
gpui_tokio::init(cx);
let http_client = ReqwestClient::user_agent("agent tests").unwrap();
cx.set_http_client(Arc::new(http_client));
client::init_settings(cx);
let client = Client::production(cx);
let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
language_model::init(client.clone(), cx);
language_models::init(user_store.clone(), client.clone(), fs.clone(), cx);
let models = LanguageModelRegistry::read_global(cx);
let model = models
.available_models(cx)
.find(|model| model.id().0 == "claude-3-7-sonnet-latest")
.unwrap();
let provider = models.provider(&model.provider_id()).unwrap();
let authenticated = provider.authenticate(cx);
cx.spawn(async move |cx| {
authenticated.await.unwrap();
model
})
})
.await;
AgentTest {
model,
agent: agent,
}
}
#[cfg(test)]
#[ctor::ctor]
fn init_logger() {
if std::env::var("RUST_LOG").is_ok() {
env_logger::init();
}
}

View File

@@ -0,0 +1,102 @@
use super::*;
#[derive(JsonSchema, Serialize, Deserialize)]
pub struct EchoToolInput {
text: String,
}
pub struct EchoTool;
impl Tool for EchoTool {
fn name(&self) -> String {
"echo".to_string()
}
fn description(&self) -> String {
"A tool that echoes its input".to_string()
}
fn icon(&self) -> IconName {
IconName::Ai
}
fn needs_confirmation(&self, _input: &serde_json::Value, _cx: &gpui::App) -> bool {
false
}
fn ui_text(&self, _input: &serde_json::Value) -> String {
"Echo".to_string()
}
fn run(
self: Arc<Self>,
input: serde_json::Value,
_messages: &[LanguageModelRequestMessage],
_project: gpui::Entity<Project>,
_action_log: gpui::Entity<assistant_tool::ActionLog>,
cx: &mut gpui::App,
) -> ToolResult {
ToolResult {
output: cx.foreground_executor().spawn(async move {
let input: EchoToolInput = serde_json::from_value(input)?;
Ok(input.text)
}),
card: None,
}
}
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
assistant_tools::json_schema_for::<EchoToolInput>(format)
}
}
#[derive(JsonSchema, Serialize, Deserialize)]
pub struct DelayToolInput {
ms: u64,
}
pub struct DelayTool;
impl Tool for DelayTool {
fn name(&self) -> String {
"delay".to_string()
}
fn description(&self) -> String {
"A tool that waits for a specified delay".to_string()
}
fn icon(&self) -> IconName {
IconName::Cog
}
fn needs_confirmation(&self, _input: &serde_json::Value, _cx: &gpui::App) -> bool {
false
}
fn ui_text(&self, _input: &serde_json::Value) -> String {
"Delay".to_string()
}
fn run(
self: Arc<Self>,
input: serde_json::Value,
_messages: &[LanguageModelRequestMessage],
_project: gpui::Entity<Project>,
_action_log: gpui::Entity<assistant_tool::ActionLog>,
cx: &mut gpui::App,
) -> ToolResult {
ToolResult {
output: cx.foreground_executor().spawn(async move {
let input: DelayToolInput = serde_json::from_value(input)?;
smol::Timer::after(Duration::from_millis(input.ms)).await;
Ok("Ding".to_string())
}),
card: None,
}
}
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
assistant_tools::json_schema_for::<DelayToolInput>(format)
}
}

View File

@@ -416,6 +416,7 @@ pub async fn stream_completion_with_rate_limit_info(
let beta_headers = Model::from_id(&request.base.model)
.map(|model| model.beta_headers())
.unwrap_or_else(|_err| Model::DEFAULT_BETA_HEADERS.join(","));
let request_builder = HttpRequest::builder()
.method(Method::POST)
.uri(uri)
@@ -423,6 +424,7 @@ pub async fn stream_completion_with_rate_limit_info(
.header("Anthropic-Beta", beta_headers)
.header("X-Api-Key", api_key)
.header("Content-Type", "application/json");
let serialized_request =
serde_json::to_string(&request).context("failed to serialize request")?;
let request = request_builder

View File

@@ -14,10 +14,10 @@ use gpui::Context;
use gpui::IntoElement;
use gpui::Window;
use gpui::{App, Entity, SharedString, Task};
use icons::IconName;
pub use icons::IconName;
use language_model::LanguageModelRequestMessage;
use language_model::LanguageModelToolSchemaFormat;
use project::Project;
pub use project::Project;
pub use crate::action_log::*;
pub use crate::tool_registry::*;

View File

@@ -54,6 +54,7 @@ use crate::rename_tool::RenameTool;
use crate::symbol_info_tool::SymbolInfoTool;
use crate::terminal_tool::TerminalTool;
use crate::thinking_tool::ThinkingTool;
pub use schema::json_schema_for;
pub fn init(http_client: Arc<HttpClientWithUrl>, cx: &mut App) {
assistant_tool::init(cx);

View File

@@ -178,7 +178,14 @@ impl TestAppContext {
&self.foreground_executor
}
fn new<T: 'static>(&mut self, build_entity: impl FnOnce(&mut Context<T>) -> T) -> Entity<T> {
/// Builds an entity that is owned by the application.
///
/// The given function will be invoked with a [`Context`] and must return an object representing the entity. An
/// [`Entity`] handle will be returned, which can be used to access the entity in a context.
pub fn new<T: 'static>(
&mut self,
build_entity: impl FnOnce(&mut Context<T>) -> T,
) -> Entity<T> {
let mut cx = self.app.borrow_mut();
cx.new(build_entity)
}

View File

@@ -95,6 +95,13 @@ where
.spawn(self.log_tracked_err(*location))
.detach();
}
/// Convert a Task<Result<T, E>> to a Task<()> that logs all errors.
pub fn log_err_in_task(self, cx: &App) -> Task<Option<T>> {
let location = core::panic::Location::caller();
cx.foreground_executor()
.spawn(async move { self.log_tracked_err(*location).await })
}
}
impl<T> Future for Task<T> {

View File

@@ -72,6 +72,7 @@ pub enum LanguageModelCompletionEvent {
ToolUse(LanguageModelToolUse),
StartMessage {
message_id: String,
role: Role,
},
UsageUpdate(TokenUsage),
}
@@ -288,7 +289,7 @@ pub trait LanguageModel: Send + Sync {
if let Some(first_event) = events.next().await {
match first_event {
Ok(LanguageModelCompletionEvent::StartMessage { message_id: id }) => {
Ok(LanguageModelCompletionEvent::StartMessage { message_id: id, .. }) => {
message_id = Some(id.clone());
}
Ok(LanguageModelCompletionEvent::Text(text)) => {

View File

@@ -197,6 +197,24 @@ impl From<&str> for MessageContent {
}
}
impl From<LanguageModelToolUse> for MessageContent {
fn from(value: LanguageModelToolUse) -> Self {
MessageContent::ToolUse(value)
}
}
impl From<LanguageModelImage> for MessageContent {
fn from(value: LanguageModelImage) -> Self {
MessageContent::Image(value)
}
}
impl From<LanguageModelToolResult> for MessageContent {
fn from(value: LanguageModelToolResult) -> Self {
MessageContent::ToolResult(value)
}
}
#[derive(Clone, Serialize, Deserialize, Debug, PartialEq, Hash)]
pub struct LanguageModelRequestMessage {
pub role: Role,

View File

@@ -750,6 +750,10 @@ pub fn map_to_language_model_completion_events(
))),
Ok(LanguageModelCompletionEvent::StartMessage {
message_id: message.id,
role: match message.role {
anthropic::Role::User => Role::User,
anthropic::Role::Assistant => Role::Assistant,
},
}),
],
state,