Checkpoint
This commit is contained in:
4
Cargo.lock
generated
4
Cargo.lock
generated
@@ -133,9 +133,12 @@ name = "agent_thread"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"assistant_tool",
|
||||
"assistant_tools",
|
||||
"chrono",
|
||||
"client",
|
||||
"collections",
|
||||
"env_logger 0.11.8",
|
||||
"fs",
|
||||
"futures 0.3.31",
|
||||
"gpui",
|
||||
@@ -143,6 +146,7 @@ dependencies = [
|
||||
"language_model",
|
||||
"language_models",
|
||||
"parking_lot",
|
||||
"project",
|
||||
"reqwest_client",
|
||||
"schemars",
|
||||
"serde",
|
||||
|
||||
@@ -13,6 +13,8 @@ workspace = true
|
||||
|
||||
[dependencies]
|
||||
anyhow.workspace = true
|
||||
assistant_tool.workspace = true
|
||||
assistant_tools.workspace = true
|
||||
chrono.workspace = true
|
||||
collections.workspace = true
|
||||
fs.workspace = true
|
||||
@@ -21,6 +23,7 @@ gpui.workspace = true
|
||||
language_model.workspace = true
|
||||
language_models.workspace = true
|
||||
parking_lot.workspace = true
|
||||
project.workspace = true
|
||||
schemars.workspace = true
|
||||
serde.workspace = true
|
||||
serde_json.workspace = true
|
||||
@@ -31,9 +34,11 @@ util.workspace = true
|
||||
|
||||
[dev-dependencies]
|
||||
client = { workspace = true, "features" = ["test-support"] }
|
||||
env_logger.workspace = true
|
||||
fs = { workspace = true, "features" = ["test-support"] }
|
||||
gpui = { workspace = true, "features" = ["test-support"] }
|
||||
gpui_tokio.workspace = true
|
||||
language_model = { workspace = true, "features" = ["test-support"] }
|
||||
project = { workspace = true, "features" = ["test-support"] }
|
||||
reqwest_client.workspace = true
|
||||
settings = { workspace = true, "features" = ["test-support"] }
|
||||
|
||||
@@ -1,196 +1,303 @@
|
||||
use futures::channel::oneshot;
|
||||
use gpui::{Context, Task};
|
||||
#[cfg(test)]
|
||||
mod tests;
|
||||
|
||||
use anyhow::Result;
|
||||
use assistant_tool::{ActionLog, Tool};
|
||||
use futures::{channel::mpsc, stream::FuturesUnordered};
|
||||
use gpui::{Context, Entity, Task};
|
||||
use language_model::{
|
||||
LanguageModel, LanguageModelCompletionEvent, LanguageModelRequest, LanguageModelRequestMessage,
|
||||
MessageContent, Role,
|
||||
LanguageModelRequestTool, LanguageModelToolResult, LanguageModelToolSchemaFormat,
|
||||
LanguageModelToolUse, LanguageModelToolUseId, MessageContent, Role, StopReason,
|
||||
};
|
||||
use project::Project;
|
||||
use smol::stream::StreamExt;
|
||||
use std::{future::Future, sync::Arc};
|
||||
use std::{collections::BTreeMap, sync::Arc};
|
||||
use util::ResultExt;
|
||||
|
||||
pub struct ThreadMessage {
|
||||
#[derive(Debug)]
|
||||
pub struct AgentMessage {
|
||||
pub role: Role,
|
||||
pub content: Vec<MessageContent>,
|
||||
}
|
||||
|
||||
pub struct Thread {
|
||||
messages: Vec<ThreadMessage>,
|
||||
streaming_completion: Option<Task<Option<()>>>,
|
||||
impl AgentMessage {
|
||||
fn to_request_message(&self) -> LanguageModelRequestMessage {
|
||||
LanguageModelRequestMessage {
|
||||
role: self.role,
|
||||
content: self.content.clone(),
|
||||
cache: false, // TODO: Figure out caching
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Thread {
|
||||
pub fn new() -> Self {
|
||||
pub type AgentResponseEvent = LanguageModelCompletionEvent;
|
||||
|
||||
pub struct AgentThread {
|
||||
sent: Vec<AgentMessage>,
|
||||
unsent: Vec<AgentMessage>,
|
||||
streaming: Option<Task<Option<()>>>,
|
||||
tools: BTreeMap<Arc<str>, Arc<dyn Tool>>,
|
||||
project: Entity<Project>,
|
||||
action_log: Entity<ActionLog>,
|
||||
}
|
||||
|
||||
impl AgentThread {
|
||||
pub fn new(project: Entity<Project>, action_log: Entity<ActionLog>) -> Self {
|
||||
Self {
|
||||
messages: Vec::new(),
|
||||
streaming_completion: None,
|
||||
sent: Vec::new(),
|
||||
unsent: Vec::new(),
|
||||
streaming: None,
|
||||
tools: BTreeMap::default(),
|
||||
project,
|
||||
action_log,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn push_user_message(&mut self, text: impl Into<String>, cx: &mut Context<Self>) {
|
||||
self.messages.push(ThreadMessage {
|
||||
role: Role::User,
|
||||
content: vec![MessageContent::Text(text.into())],
|
||||
});
|
||||
|
||||
cx.notify();
|
||||
/// Registers `tool` under the name it reports, replacing any tool already stored under
/// that name.
pub fn add_tool(&mut self, tool: Arc<dyn Tool>) {
    self.tools.insert(Arc::from(tool.name()), tool);
}
|
||||
|
||||
pub fn stream_completion(
|
||||
/// Cancels in-flight streaming, aborting any pending tool calls.
|
||||
pub fn cancel_streaming(&mut self, cx: &mut Context<Self>) -> bool {
|
||||
self.unsent.clear();
|
||||
self.streaming.take().is_some()
|
||||
}
|
||||
|
||||
/// Sending a message results in the model streaming a response, which could include tool calls.
|
||||
/// After calling tools, the model stops and waits for any outstanding tool calls to be completed and their results sent.
|
||||
/// The returned channel will report all the occurrences in which the model stops before erroring or ending its turn.
|
||||
pub fn send(
|
||||
&mut self,
|
||||
model: Arc<dyn LanguageModel>,
|
||||
content: impl Into<MessageContent>,
|
||||
cx: &mut Context<Self>,
|
||||
) -> impl Future<Output = ()> {
|
||||
let request = self.to_completion_request();
|
||||
let (done_tx, done_rx) = futures::channel::oneshot::channel();
|
||||
let mut done_tx = Some(done_tx);
|
||||
self.streaming_completion = Some(
|
||||
cx.spawn(async move |thread, cx| {
|
||||
let mut events = model.stream_completion(request, cx).await?;
|
||||
) -> mpsc::UnboundedReceiver<Result<AgentResponseEvent>> {
|
||||
self.cancel_streaming(cx);
|
||||
let events = mpsc::unbounded();
|
||||
self.enqueue_unsent(model, content, events.0, cx);
|
||||
events.1
|
||||
}
|
||||
|
||||
while let Some(event) = events.next().await {
|
||||
if let Some(event) = event.log_err() {
|
||||
thread
|
||||
.update(cx, |thread, cx| {
|
||||
thread.handle_streamed_event(event, &mut done_tx, cx)
|
||||
})
|
||||
.ok();
|
||||
/// Internal method which is called by send and also any tool call results.
|
||||
/// If currently streaming a completion, these events will be sent when the streaming stops.
|
||||
fn enqueue_unsent(
|
||||
&mut self,
|
||||
model: Arc<dyn LanguageModel>,
|
||||
content: impl Into<MessageContent>,
|
||||
events_tx: mpsc::UnboundedSender<Result<AgentResponseEvent>>,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
cx.notify();
|
||||
|
||||
self.unsent.push(AgentMessage {
|
||||
role: Role::User,
|
||||
content: vec![content.into()],
|
||||
});
|
||||
if !dbg!(self.streaming.is_some()) {
|
||||
self.flush_unsent_messages(model, events_tx, cx)
|
||||
}
|
||||
}
|
||||
|
||||
fn flush_unsent_messages(
|
||||
&mut self,
|
||||
model: Arc<dyn LanguageModel>,
|
||||
events_tx: mpsc::UnboundedSender<Result<AgentResponseEvent>>,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
cx.notify();
|
||||
self.streaming = Some(
|
||||
cx.spawn(async move |thread, cx| {
|
||||
let mut subtasks = FuturesUnordered::new();
|
||||
|
||||
// Perform completion requests until the unsent messages are empty.
|
||||
loop {
|
||||
let unsent =
|
||||
thread.update(cx, |thread, _cx| std::mem::take(&mut thread.unsent))?;
|
||||
|
||||
if unsent.is_empty() {
|
||||
thread.update(cx, |thread, _cx| thread.streaming.take())?;
|
||||
break;
|
||||
}
|
||||
|
||||
let completion_request = thread.update(cx, |thread, _cx| {
|
||||
thread.sent.extend(unsent);
|
||||
thread.build_completion_request()
|
||||
})?;
|
||||
|
||||
let mut events = model.stream_completion(completion_request, cx).await?;
|
||||
while let Some(event) = events.next().await {
|
||||
match event {
|
||||
Ok(event) => {
|
||||
thread
|
||||
.update(cx, |thread, cx| {
|
||||
let subtask = thread.handle_stream_event(
|
||||
&model,
|
||||
event,
|
||||
events_tx.clone(),
|
||||
cx,
|
||||
);
|
||||
subtasks.extend(subtask);
|
||||
})
|
||||
.ok();
|
||||
}
|
||||
Err(error) => {
|
||||
events_tx.unbounded_send(Err(error)).ok();
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Wait for any tasks we spawned to enqueue tool results before looping again.
|
||||
subtasks.next().await;
|
||||
}
|
||||
|
||||
anyhow::Ok(())
|
||||
})
|
||||
.log_err_in_task(cx),
|
||||
);
|
||||
}
|
||||
|
||||
cx.notify();
|
||||
async move {
|
||||
done_rx.await.ok();
|
||||
fn handle_stream_event(
|
||||
&mut self,
|
||||
model: &Arc<dyn LanguageModel>,
|
||||
event: LanguageModelCompletionEvent,
|
||||
events_tx: mpsc::UnboundedSender<Result<AgentResponseEvent>>,
|
||||
cx: &mut Context<Self>,
|
||||
) -> Option<Task<()>> {
|
||||
use LanguageModelCompletionEvent::*;
|
||||
events_tx.unbounded_send(Ok(event.clone())).ok();
|
||||
|
||||
match dbg!(event) {
|
||||
Text(new_text) => self.handle_text_event(new_text, cx),
|
||||
Thinking { text, signature } => {
|
||||
dbg!(text, signature);
|
||||
}
|
||||
ToolUse(tool_use) => {
|
||||
return self.handle_tool_use_event(model.clone(), tool_use, events_tx, cx);
|
||||
}
|
||||
StartMessage { message_id, role } => {
|
||||
self.sent.push(AgentMessage {
|
||||
role,
|
||||
content: Vec::new(),
|
||||
});
|
||||
}
|
||||
UsageUpdate(token_usage) => {}
|
||||
Stop(stop_reason) => {}
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
|
||||
fn handle_text_event(&mut self, new_text: String, cx: &mut Context<Self>) {
|
||||
if let Some(last_message) = self.sent.last_mut() {
|
||||
debug_assert!(last_message.role == Role::Assistant);
|
||||
if let Some(MessageContent::Text(text)) = last_message.content.last_mut() {
|
||||
text.push_str(&new_text);
|
||||
} else {
|
||||
last_message.content.push(MessageContent::Text(new_text));
|
||||
}
|
||||
|
||||
cx.notify();
|
||||
} else {
|
||||
todo!("does this happen in practice?");
|
||||
}
|
||||
}
|
||||
|
||||
fn to_completion_request(&self) -> LanguageModelRequest {
|
||||
fn handle_tool_use_event(
|
||||
&mut self,
|
||||
model: Arc<dyn LanguageModel>,
|
||||
tool_use: LanguageModelToolUse,
|
||||
events_tx: mpsc::UnboundedSender<Result<AgentResponseEvent>>,
|
||||
cx: &mut Context<Self>,
|
||||
) -> Option<Task<()>> {
|
||||
if let Some(last_message) = self.sent.last_mut() {
|
||||
debug_assert!(last_message.role == Role::Assistant);
|
||||
last_message.content.push(tool_use.clone().into());
|
||||
cx.notify();
|
||||
} else {
|
||||
todo!("does this happen in practice?");
|
||||
}
|
||||
|
||||
if let Some(tool) = self.tools.get(&tool_use.name) {
|
||||
let pending_tool_result = tool.clone().run(
|
||||
tool_use.input,
|
||||
&self.build_request_messages(),
|
||||
self.project.clone(),
|
||||
self.action_log.clone(),
|
||||
cx,
|
||||
);
|
||||
|
||||
Some(cx.spawn(async move |thread, cx| {
|
||||
let tool_result = match pending_tool_result.output.await {
|
||||
Ok(tool_output) => LanguageModelToolResult {
|
||||
tool_use_id: tool_use.id,
|
||||
tool_name: tool_use.name,
|
||||
is_error: false,
|
||||
content: Arc::from(tool_output),
|
||||
},
|
||||
Err(error) => LanguageModelToolResult {
|
||||
tool_use_id: tool_use.id,
|
||||
tool_name: tool_use.name,
|
||||
is_error: true,
|
||||
content: Arc::from(error.to_string()),
|
||||
},
|
||||
};
|
||||
|
||||
thread
|
||||
.update(cx, |thread, cx| {
|
||||
thread.enqueue_unsent(model, tool_result, events_tx, cx)
|
||||
})
|
||||
.ok();
|
||||
}))
|
||||
} else {
|
||||
self.enqueue_unsent(
|
||||
model,
|
||||
LanguageModelToolResult {
|
||||
tool_use_id: tool_use.id,
|
||||
tool_name: tool_use.name,
|
||||
is_error: true,
|
||||
content: Arc::from("tool does not exist"),
|
||||
},
|
||||
events_tx,
|
||||
cx,
|
||||
);
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
fn build_completion_request(&self) -> LanguageModelRequest {
|
||||
LanguageModelRequest {
|
||||
thread_id: None,
|
||||
prompt_id: None,
|
||||
messages: self
|
||||
.messages
|
||||
.iter()
|
||||
.map(|message| LanguageModelRequestMessage {
|
||||
role: message.role,
|
||||
content: message.content.clone(),
|
||||
cache: false,
|
||||
messages: self.build_request_messages(),
|
||||
tools: self
|
||||
.tools
|
||||
.values()
|
||||
.filter_map(|tool| {
|
||||
Some(LanguageModelRequestTool {
|
||||
name: tool.name(),
|
||||
description: tool.description(),
|
||||
input_schema: tool
|
||||
.input_schema(LanguageModelToolSchemaFormat::JsonSchema)
|
||||
.log_err()?,
|
||||
})
|
||||
})
|
||||
.collect(),
|
||||
tools: Vec::new(),
|
||||
stop: Vec::new(),
|
||||
temperature: None,
|
||||
}
|
||||
}
|
||||
|
||||
fn handle_streamed_event(
|
||||
&mut self,
|
||||
event: LanguageModelCompletionEvent,
|
||||
done_tx: &mut Option<oneshot::Sender<()>>,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
use LanguageModelCompletionEvent::*;
|
||||
|
||||
match event {
|
||||
Stop(stop_reason) => {
|
||||
done_tx.take().map(|tx| tx.send(()));
|
||||
}
|
||||
Text(new_text) => {
|
||||
if let Some(last_message) = self.messages.last_mut() {
|
||||
debug_assert!(last_message.role == Role::Assistant);
|
||||
if let Some(MessageContent::Text(text)) = last_message.content.last_mut() {
|
||||
text.push_str(&new_text);
|
||||
} else {
|
||||
last_message.content.push(MessageContent::Text(new_text));
|
||||
}
|
||||
|
||||
cx.notify();
|
||||
} else {
|
||||
todo!("does this happen in practice?")
|
||||
}
|
||||
}
|
||||
Thinking { text, signature } => {
|
||||
dbg!(text, signature);
|
||||
}
|
||||
ToolUse(language_model_tool_use) => {
|
||||
dbg!(language_model_tool_use);
|
||||
}
|
||||
StartMessage { message_id, role } => {
|
||||
dbg!(message_id, role);
|
||||
|
||||
self.messages.push(ThreadMessage {
|
||||
role,
|
||||
content: Vec::new(),
|
||||
});
|
||||
}
|
||||
UsageUpdate(token_usage) => {
|
||||
dbg!(token_usage);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use client::{Client, UserStore};
|
||||
use fs::FakeFs;
|
||||
use gpui::{App, AppContext, TestAppContext};
|
||||
use language_model::LanguageModelRegistry;
|
||||
use reqwest_client::ReqwestClient;
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_basic_threads(cx: &mut TestAppContext) {
|
||||
let model = init_test(cx).await;
|
||||
let thread = cx.new(|_cx| Thread::new());
|
||||
|
||||
thread
|
||||
.update(cx, |thread, cx| {
|
||||
thread.push_user_message("Testing: Reply with 'Hello'", cx);
|
||||
thread.stream_completion(model, cx)
|
||||
})
|
||||
.await;
|
||||
|
||||
thread.update(cx, |thread, _cx| {
|
||||
assert_eq!(
|
||||
thread.messages.last().unwrap().content,
|
||||
vec![MessageContent::Text("Hello".to_string())]
|
||||
);
|
||||
});
|
||||
}
|
||||
|
||||
fn init_test(cx: &mut TestAppContext) -> Task<Arc<dyn LanguageModel>> {
|
||||
cx.executor().allow_parking();
|
||||
cx.update(|cx| {
|
||||
gpui_tokio::init(cx);
|
||||
let http_client = ReqwestClient::user_agent("agent thread tests").unwrap();
|
||||
cx.set_http_client(Arc::new(http_client));
|
||||
|
||||
settings::init(cx);
|
||||
client::init_settings(cx);
|
||||
let client = Client::production(cx);
|
||||
let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
|
||||
let fs = FakeFs::new(cx.background_executor().clone());
|
||||
language_model::init(client.clone(), cx);
|
||||
language_models::init(user_store.clone(), client.clone(), fs.clone(), cx);
|
||||
|
||||
let registry = LanguageModelRegistry::read_global(cx);
|
||||
let model = registry
|
||||
.available_models(cx)
|
||||
.find(|model| model.id().0 == "claude-3-7-sonnet-latest")
|
||||
.unwrap();
|
||||
|
||||
let provider = registry.provider(&model.provider_id()).unwrap();
|
||||
let authenticated = provider.authenticate(cx);
|
||||
|
||||
cx.spawn(async move |_cx| {
|
||||
authenticated.await.unwrap();
|
||||
model
|
||||
})
|
||||
})
|
||||
fn build_request_messages(&self) -> Vec<LanguageModelRequestMessage> {
|
||||
self.sent
|
||||
.iter()
|
||||
.map(|message| LanguageModelRequestMessage {
|
||||
role: message.role,
|
||||
content: message.content.clone(),
|
||||
cache: false,
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -14,10 +14,10 @@ use gpui::Context;
|
||||
use gpui::IntoElement;
|
||||
use gpui::Window;
|
||||
use gpui::{App, Entity, SharedString, Task};
|
||||
use icons::IconName;
|
||||
pub use icons::IconName;
|
||||
use language_model::LanguageModelRequestMessage;
|
||||
use language_model::LanguageModelToolSchemaFormat;
|
||||
use project::Project;
|
||||
pub use project::Project;
|
||||
|
||||
pub use crate::action_log::*;
|
||||
pub use crate::tool_registry::*;
|
||||
|
||||
@@ -54,6 +54,7 @@ use crate::rename_tool::RenameTool;
|
||||
use crate::symbol_info_tool::SymbolInfoTool;
|
||||
use crate::terminal_tool::TerminalTool;
|
||||
use crate::thinking_tool::ThinkingTool;
|
||||
pub use schema::json_schema_for;
|
||||
|
||||
pub fn init(http_client: Arc<HttpClientWithUrl>, cx: &mut App) {
|
||||
assistant_tool::init(cx);
|
||||
|
||||
@@ -197,6 +197,24 @@ impl From<&str> for MessageContent {
|
||||
}
|
||||
}
|
||||
|
||||
impl From<LanguageModelToolUse> for MessageContent {
    /// Wraps a tool use in the `ToolUse` variant.
    fn from(tool_use: LanguageModelToolUse) -> Self {
        Self::ToolUse(tool_use)
    }
}
|
||||
|
||||
impl From<LanguageModelImage> for MessageContent {
    /// Wraps an image in the `Image` variant.
    fn from(image: LanguageModelImage) -> Self {
        Self::Image(image)
    }
}
|
||||
|
||||
impl From<LanguageModelToolResult> for MessageContent {
    /// Wraps a tool result in the `ToolResult` variant.
    fn from(tool_result: LanguageModelToolResult) -> Self {
        Self::ToolResult(tool_result)
    }
}
|
||||
|
||||
#[derive(Clone, Serialize, Deserialize, Debug, PartialEq, Hash)]
|
||||
pub struct LanguageModelRequestMessage {
|
||||
pub role: Role,
|
||||
|
||||
Reference in New Issue
Block a user