Checkpoint

This commit is contained in:
Nathan Sobo
2025-04-20 17:28:44 -06:00
parent 101f3b100f
commit 3187f28405
6 changed files with 288 additions and 153 deletions

4
Cargo.lock generated
View File

@@ -133,9 +133,12 @@ name = "agent_thread"
version = "0.1.0"
dependencies = [
"anyhow",
"assistant_tool",
"assistant_tools",
"chrono",
"client",
"collections",
"env_logger 0.11.8",
"fs",
"futures 0.3.31",
"gpui",
@@ -143,6 +146,7 @@ dependencies = [
"language_model",
"language_models",
"parking_lot",
"project",
"reqwest_client",
"schemars",
"serde",

View File

@@ -13,6 +13,8 @@ workspace = true
[dependencies]
anyhow.workspace = true
assistant_tool.workspace = true
assistant_tools.workspace = true
chrono.workspace = true
collections.workspace = true
fs.workspace = true
@@ -21,6 +23,7 @@ gpui.workspace = true
language_model.workspace = true
language_models.workspace = true
parking_lot.workspace = true
project.workspace = true
schemars.workspace = true
serde.workspace = true
serde_json.workspace = true
@@ -31,9 +34,11 @@ util.workspace = true
[dev-dependencies]
client = { workspace = true, "features" = ["test-support"] }
env_logger.workspace = true
fs = { workspace = true, "features" = ["test-support"] }
gpui = { workspace = true, "features" = ["test-support"] }
gpui_tokio.workspace = true
language_model = { workspace = true, "features" = ["test-support"] }
project = { workspace = true, "features" = ["test-support"] }
reqwest_client.workspace = true
settings = { workspace = true, "features" = ["test-support"] }

View File

@@ -1,196 +1,303 @@
use futures::channel::oneshot;
use gpui::{Context, Task};
#[cfg(test)]
mod tests;
use anyhow::Result;
use assistant_tool::{ActionLog, Tool};
use futures::{channel::mpsc, stream::FuturesUnordered};
use gpui::{Context, Entity, Task};
use language_model::{
LanguageModel, LanguageModelCompletionEvent, LanguageModelRequest, LanguageModelRequestMessage,
MessageContent, Role,
LanguageModelRequestTool, LanguageModelToolResult, LanguageModelToolSchemaFormat,
LanguageModelToolUse, LanguageModelToolUseId, MessageContent, Role, StopReason,
};
use project::Project;
use smol::stream::StreamExt;
use std::{future::Future, sync::Arc};
use std::{collections::BTreeMap, sync::Arc};
use util::ResultExt;
pub struct ThreadMessage {
#[derive(Debug)]
pub struct AgentMessage {
pub role: Role,
pub content: Vec<MessageContent>,
}
pub struct Thread {
messages: Vec<ThreadMessage>,
streaming_completion: Option<Task<Option<()>>>,
impl AgentMessage {
fn to_request_message(&self) -> LanguageModelRequestMessage {
LanguageModelRequestMessage {
role: self.role,
content: self.content.clone(),
cache: false, // TODO: Figure out caching
}
}
}
impl Thread {
pub fn new() -> Self {
/// Event type streamed back to `send` callers; currently a direct alias of
/// the model's completion events.
pub type AgentResponseEvent = LanguageModelCompletionEvent;
/// A conversation with a language model that can invoke registered tools.
pub struct AgentThread {
    // History already included in completion requests (user, assistant, and
    // tool-result messages accumulated so far).
    sent: Vec<AgentMessage>,
    // Messages queued while a completion is streaming; the streaming loop
    // drains these on its next iteration.
    unsent: Vec<AgentMessage>,
    // The in-flight streaming task, if any; `None` when idle.
    streaming: Option<Task<Option<()>>>,
    // Tools available to the model, keyed by tool name.
    tools: BTreeMap<Arc<str>, Arc<dyn Tool>>,
    project: Entity<Project>,
    action_log: Entity<ActionLog>,
}
impl AgentThread {
pub fn new(project: Entity<Project>, action_log: Entity<ActionLog>) -> Self {
Self {
messages: Vec::new(),
streaming_completion: None,
sent: Vec::new(),
unsent: Vec::new(),
streaming: None,
tools: BTreeMap::default(),
project,
action_log,
}
}
pub fn push_user_message(&mut self, text: impl Into<String>, cx: &mut Context<Self>) {
self.messages.push(ThreadMessage {
role: Role::User,
content: vec![MessageContent::Text(text.into())],
});
cx.notify();
pub fn add_tool(&mut self, tool: Arc<dyn Tool>) {
let name = Arc::from(tool.name());
self.tools.insert(name, tool);
}
pub fn stream_completion(
/// Cancels in-flight streaming, aborting any pending tool calls.
pub fn cancel_streaming(&mut self, cx: &mut Context<Self>) -> bool {
self.unsent.clear();
self.streaming.take().is_some()
}
/// Sending a message results in the model streaming a response, which could include tool calls.
/// After calling tools, the model stops and waits for any outstanding tool calls to complete and their results to be sent.
/// The returned channel reports each time the model stops, before it errors or ends its turn.
pub fn send(
&mut self,
model: Arc<dyn LanguageModel>,
content: impl Into<MessageContent>,
cx: &mut Context<Self>,
) -> impl Future<Output = ()> {
let request = self.to_completion_request();
let (done_tx, done_rx) = futures::channel::oneshot::channel();
let mut done_tx = Some(done_tx);
self.streaming_completion = Some(
cx.spawn(async move |thread, cx| {
let mut events = model.stream_completion(request, cx).await?;
) -> mpsc::UnboundedReceiver<Result<AgentResponseEvent>> {
self.cancel_streaming(cx);
let events = mpsc::unbounded();
self.enqueue_unsent(model, content, events.0, cx);
events.1
}
while let Some(event) = events.next().await {
if let Some(event) = event.log_err() {
thread
.update(cx, |thread, cx| {
thread.handle_streamed_event(event, &mut done_tx, cx)
})
.ok();
/// Internal method which is called by send and also any tool call results.
/// If currently streaming a completion, these events will be sent when the streaming stops.
fn enqueue_unsent(
&mut self,
model: Arc<dyn LanguageModel>,
content: impl Into<MessageContent>,
events_tx: mpsc::UnboundedSender<Result<AgentResponseEvent>>,
cx: &mut Context<Self>,
) {
cx.notify();
self.unsent.push(AgentMessage {
role: Role::User,
content: vec![content.into()],
});
if !dbg!(self.streaming.is_some()) {
self.flush_unsent_messages(model, events_tx, cx)
}
}
    /// Drives completion requests until no unsent messages remain.
    ///
    /// Runs as the single `streaming` task: each iteration moves `unsent`
    /// into `sent`, issues one completion request, and applies the streamed
    /// events. Tool calls spawn subtasks that re-enqueue their results,
    /// which keeps the loop running for another round.
    fn flush_unsent_messages(
        &mut self,
        model: Arc<dyn LanguageModel>,
        events_tx: mpsc::UnboundedSender<Result<AgentResponseEvent>>,
        cx: &mut Context<Self>,
    ) {
        cx.notify();
        self.streaming = Some(
            cx.spawn(async move |thread, cx| {
                let mut subtasks = FuturesUnordered::new();
                // Perform completion requests until the unsent messages are empty.
                loop {
                    // Take the whole queue atomically so messages enqueued
                    // during this iteration land in the next one.
                    let unsent =
                        thread.update(cx, |thread, _cx| std::mem::take(&mut thread.unsent))?;
                    if unsent.is_empty() {
                        // Nothing left to send: clear `streaming` so the next
                        // `send` starts a fresh flush, then end the task.
                        thread.update(cx, |thread, _cx| thread.streaming.take())?;
                        break;
                    }
                    let completion_request = thread.update(cx, |thread, _cx| {
                        thread.sent.extend(unsent);
                        thread.build_completion_request()
                    })?;
                    let mut events = model.stream_completion(completion_request, cx).await?;
                    while let Some(event) = events.next().await {
                        match event {
                            Ok(event) => {
                                thread
                                    .update(cx, |thread, cx| {
                                        // Tool-use events may return a subtask
                                        // that produces the tool's result.
                                        let subtask = thread.handle_stream_event(
                                            &model,
                                            event,
                                            events_tx.clone(),
                                            cx,
                                        );
                                        subtasks.extend(subtask);
                                    })
                                    .ok();
                            }
                            Err(error) => {
                                // Surface the stream error to the caller and
                                // abandon this completion.
                                events_tx.unbounded_send(Err(error)).ok();
                                break;
                            }
                        }
                    }
                    // Wait for any tasks we spawned to enqueue tool results before looping again.
                    // NOTE(review): this awaits only ONE subtask per iteration;
                    // remaining subtasks are awaited on later iterations —
                    // confirm every tool result is eventually flushed.
                    subtasks.next().await;
                }
                anyhow::Ok(())
            })
            .log_err_in_task(cx),
        );
    }
cx.notify();
async move {
done_rx.await.ok();
fn handle_stream_event(
&mut self,
model: &Arc<dyn LanguageModel>,
event: LanguageModelCompletionEvent,
events_tx: mpsc::UnboundedSender<Result<AgentResponseEvent>>,
cx: &mut Context<Self>,
) -> Option<Task<()>> {
use LanguageModelCompletionEvent::*;
events_tx.unbounded_send(Ok(event.clone())).ok();
match dbg!(event) {
Text(new_text) => self.handle_text_event(new_text, cx),
Thinking { text, signature } => {
dbg!(text, signature);
}
ToolUse(tool_use) => {
return self.handle_tool_use_event(model.clone(), tool_use, events_tx, cx);
}
StartMessage { message_id, role } => {
self.sent.push(AgentMessage {
role,
content: Vec::new(),
});
}
UsageUpdate(token_usage) => {}
Stop(stop_reason) => {}
}
None
}
fn handle_text_event(&mut self, new_text: String, cx: &mut Context<Self>) {
if let Some(last_message) = self.sent.last_mut() {
debug_assert!(last_message.role == Role::Assistant);
if let Some(MessageContent::Text(text)) = last_message.content.last_mut() {
text.push_str(&new_text);
} else {
last_message.content.push(MessageContent::Text(new_text));
}
cx.notify();
} else {
todo!("does this happen in practice?");
}
}
fn to_completion_request(&self) -> LanguageModelRequest {
    /// Records a tool-use request on the current assistant message and, when
    /// the named tool is registered, runs it; the eventual result is
    /// re-enqueued as a user-role message so the model sees it.
    ///
    /// Returns a task that completes once the tool result has been enqueued,
    /// or `None` when the tool is unknown (an error result is enqueued
    /// synchronously instead).
    fn handle_tool_use_event(
        &mut self,
        model: Arc<dyn LanguageModel>,
        tool_use: LanguageModelToolUse,
        events_tx: mpsc::UnboundedSender<Result<AgentResponseEvent>>,
        cx: &mut Context<Self>,
    ) -> Option<Task<()>> {
        // Record the tool call on the in-progress assistant message.
        if let Some(last_message) = self.sent.last_mut() {
            debug_assert!(last_message.role == Role::Assistant);
            last_message.content.push(tool_use.clone().into());
            cx.notify();
        } else {
            todo!("does this happen in practice?");
        }
        if let Some(tool) = self.tools.get(&tool_use.name) {
            let pending_tool_result = tool.clone().run(
                tool_use.input,
                &self.build_request_messages(),
                self.project.clone(),
                self.action_log.clone(),
                cx,
            );
            Some(cx.spawn(async move |thread, cx| {
                // Convert success or failure into a tool result; errors are
                // reported to the model via `is_error` rather than dropped.
                let tool_result = match pending_tool_result.output.await {
                    Ok(tool_output) => LanguageModelToolResult {
                        tool_use_id: tool_use.id,
                        tool_name: tool_use.name,
                        is_error: false,
                        content: Arc::from(tool_output),
                    },
                    Err(error) => LanguageModelToolResult {
                        tool_use_id: tool_use.id,
                        tool_name: tool_use.name,
                        is_error: true,
                        content: Arc::from(error.to_string()),
                    },
                };
                // Feed the result back into the conversation; the streaming
                // loop will include it in the next completion request.
                thread
                    .update(cx, |thread, cx| {
                        thread.enqueue_unsent(model, tool_result, events_tx, cx)
                    })
                    .ok();
            }))
        } else {
            // Unknown tool: tell the model immediately via an error result.
            self.enqueue_unsent(
                model,
                LanguageModelToolResult {
                    tool_use_id: tool_use.id,
                    tool_name: tool_use.name,
                    is_error: true,
                    content: Arc::from("tool does not exist"),
                },
                events_tx,
                cx,
            );
            None
        }
    }
fn build_completion_request(&self) -> LanguageModelRequest {
LanguageModelRequest {
thread_id: None,
prompt_id: None,
messages: self
.messages
.iter()
.map(|message| LanguageModelRequestMessage {
role: message.role,
content: message.content.clone(),
cache: false,
messages: self.build_request_messages(),
tools: self
.tools
.values()
.filter_map(|tool| {
Some(LanguageModelRequestTool {
name: tool.name(),
description: tool.description(),
input_schema: tool
.input_schema(LanguageModelToolSchemaFormat::JsonSchema)
.log_err()?,
})
})
.collect(),
tools: Vec::new(),
stop: Vec::new(),
temperature: None,
}
}
fn handle_streamed_event(
&mut self,
event: LanguageModelCompletionEvent,
done_tx: &mut Option<oneshot::Sender<()>>,
cx: &mut Context<Self>,
) {
use LanguageModelCompletionEvent::*;
match event {
Stop(stop_reason) => {
done_tx.take().map(|tx| tx.send(()));
}
Text(new_text) => {
if let Some(last_message) = self.messages.last_mut() {
debug_assert!(last_message.role == Role::Assistant);
if let Some(MessageContent::Text(text)) = last_message.content.last_mut() {
text.push_str(&new_text);
} else {
last_message.content.push(MessageContent::Text(new_text));
}
cx.notify();
} else {
todo!("does this happen in practice?")
}
}
Thinking { text, signature } => {
dbg!(text, signature);
}
ToolUse(language_model_tool_use) => {
dbg!(language_model_tool_use);
}
StartMessage { message_id, role } => {
dbg!(message_id, role);
self.messages.push(ThreadMessage {
role,
content: Vec::new(),
});
}
UsageUpdate(token_usage) => {
dbg!(token_usage);
}
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use client::{Client, UserStore};
use fs::FakeFs;
use gpui::{App, AppContext, TestAppContext};
use language_model::LanguageModelRegistry;
use reqwest_client::ReqwestClient;
#[gpui::test]
async fn test_basic_threads(cx: &mut TestAppContext) {
let model = init_test(cx).await;
let thread = cx.new(|_cx| Thread::new());
thread
.update(cx, |thread, cx| {
thread.push_user_message("Testing: Reply with 'Hello'", cx);
thread.stream_completion(model, cx)
})
.await;
thread.update(cx, |thread, _cx| {
assert_eq!(
thread.messages.last().unwrap().content,
vec![MessageContent::Text("Hello".to_string())]
);
});
}
fn init_test(cx: &mut TestAppContext) -> Task<Arc<dyn LanguageModel>> {
cx.executor().allow_parking();
cx.update(|cx| {
gpui_tokio::init(cx);
let http_client = ReqwestClient::user_agent("agent thread tests").unwrap();
cx.set_http_client(Arc::new(http_client));
settings::init(cx);
client::init_settings(cx);
let client = Client::production(cx);
let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
let fs = FakeFs::new(cx.background_executor().clone());
language_model::init(client.clone(), cx);
language_models::init(user_store.clone(), client.clone(), fs.clone(), cx);
let registry = LanguageModelRegistry::read_global(cx);
let model = registry
.available_models(cx)
.find(|model| model.id().0 == "claude-3-7-sonnet-latest")
.unwrap();
let provider = registry.provider(&model.provider_id()).unwrap();
let authenticated = provider.authenticate(cx);
cx.spawn(async move |_cx| {
authenticated.await.unwrap();
model
})
})
fn build_request_messages(&self) -> Vec<LanguageModelRequestMessage> {
self.sent
.iter()
.map(|message| LanguageModelRequestMessage {
role: message.role,
content: message.content.clone(),
cache: false,
})
.collect()
}
}

View File

@@ -14,10 +14,10 @@ use gpui::Context;
use gpui::IntoElement;
use gpui::Window;
use gpui::{App, Entity, SharedString, Task};
use icons::IconName;
pub use icons::IconName;
use language_model::LanguageModelRequestMessage;
use language_model::LanguageModelToolSchemaFormat;
use project::Project;
pub use project::Project;
pub use crate::action_log::*;
pub use crate::tool_registry::*;

View File

@@ -54,6 +54,7 @@ use crate::rename_tool::RenameTool;
use crate::symbol_info_tool::SymbolInfoTool;
use crate::terminal_tool::TerminalTool;
use crate::thinking_tool::ThinkingTool;
pub use schema::json_schema_for;
pub fn init(http_client: Arc<HttpClientWithUrl>, cx: &mut App) {
assistant_tool::init(cx);

View File

@@ -197,6 +197,24 @@ impl From<&str> for MessageContent {
}
}
/// Lets a tool-use block be pushed directly into message content.
impl From<LanguageModelToolUse> for MessageContent {
    fn from(value: LanguageModelToolUse) -> Self {
        MessageContent::ToolUse(value)
    }
}

/// Lets an image be pushed directly into message content.
impl From<LanguageModelImage> for MessageContent {
    fn from(value: LanguageModelImage) -> Self {
        MessageContent::Image(value)
    }
}

/// Lets a tool result be pushed directly into message content.
impl From<LanguageModelToolResult> for MessageContent {
    fn from(value: LanguageModelToolResult) -> Self {
        MessageContent::ToolResult(value)
    }
}
#[derive(Clone, Serialize, Deserialize, Debug, PartialEq, Hash)]
pub struct LanguageModelRequestMessage {
pub role: Role,