Compare commits
10 Commits
logging
...
test-drive
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
447eb8e1c9 | ||
|
|
e434117018 | ||
|
|
36271b79b3 | ||
|
|
41644a53cc | ||
|
|
08a9c4af09 | ||
|
|
3187f28405 | ||
|
|
101f3b100f | ||
|
|
39c8b7bf5f | ||
|
|
08b41252f6 | ||
|
|
152bbca238 |
30
Cargo.lock
generated
30
Cargo.lock
generated
@@ -128,6 +128,36 @@ dependencies = [
|
||||
"zed_llm_client",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "agent2"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"assistant_tool",
|
||||
"assistant_tools",
|
||||
"chrono",
|
||||
"client",
|
||||
"collections",
|
||||
"ctor",
|
||||
"env_logger 0.11.8",
|
||||
"fs",
|
||||
"futures 0.3.31",
|
||||
"gpui",
|
||||
"gpui_tokio",
|
||||
"language_model",
|
||||
"language_models",
|
||||
"parking_lot",
|
||||
"project",
|
||||
"reqwest_client",
|
||||
"schemars",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"settings",
|
||||
"smol",
|
||||
"thiserror 2.0.12",
|
||||
"util",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "ahash"
|
||||
version = "0.7.8"
|
||||
|
||||
@@ -3,6 +3,7 @@ resolver = "2"
|
||||
members = [
|
||||
"crates/activity_indicator",
|
||||
"crates/agent",
|
||||
"crates/agent2",
|
||||
"crates/anthropic",
|
||||
"crates/askpass",
|
||||
"crates/assets",
|
||||
|
||||
45
crates/agent2/Cargo.toml
Normal file
45
crates/agent2/Cargo.toml
Normal file
@@ -0,0 +1,45 @@
|
||||
[package]
|
||||
name = "agent2"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
license = "GPL-3.0-or-later"
|
||||
publish = false
|
||||
|
||||
[lib]
|
||||
path = "src/agent2.rs"
|
||||
|
||||
[lints]
|
||||
workspace = true
|
||||
|
||||
[dependencies]
|
||||
anyhow.workspace = true
|
||||
assistant_tool.workspace = true
|
||||
assistant_tools.workspace = true
|
||||
chrono.workspace = true
|
||||
collections.workspace = true
|
||||
fs.workspace = true
|
||||
futures.workspace = true
|
||||
gpui.workspace = true
|
||||
language_model.workspace = true
|
||||
language_models.workspace = true
|
||||
parking_lot.workspace = true
|
||||
project.workspace = true
|
||||
schemars.workspace = true
|
||||
serde.workspace = true
|
||||
serde_json.workspace = true
|
||||
settings.workspace = true
|
||||
smol.workspace = true
|
||||
thiserror.workspace = true
|
||||
util.workspace = true
|
||||
|
||||
[dev-dependencies]
|
||||
ctor.workspace = true
|
||||
client = { workspace = true, "features" = ["test-support"] }
|
||||
env_logger.workspace = true
|
||||
fs = { workspace = true, "features" = ["test-support"] }
|
||||
gpui = { workspace = true, "features" = ["test-support"] }
|
||||
gpui_tokio.workspace = true
|
||||
language_model = { workspace = true, "features" = ["test-support"] }
|
||||
project = { workspace = true, "features" = ["test-support"] }
|
||||
reqwest_client.workspace = true
|
||||
settings = { workspace = true, "features" = ["test-support"] }
|
||||
1
crates/agent2/LICENSE-GPL
Symbolic link
1
crates/agent2/LICENSE-GPL
Symbolic link
@@ -0,0 +1 @@
|
||||
../../LICENSE-GPL
|
||||
278
crates/agent2/src/agent2.rs
Normal file
278
crates/agent2/src/agent2.rs
Normal file
@@ -0,0 +1,278 @@
|
||||
#[cfg(test)]
|
||||
mod tests;
|
||||
|
||||
use anyhow::Result;
|
||||
use assistant_tool::{ActionLog, Tool};
|
||||
use futures::{channel::mpsc, future};
|
||||
use gpui::{Context, Entity, Task};
|
||||
use language_model::{
|
||||
LanguageModel, LanguageModelCompletionEvent, LanguageModelRequest, LanguageModelRequestMessage,
|
||||
LanguageModelRequestTool, LanguageModelToolResult, LanguageModelToolSchemaFormat,
|
||||
LanguageModelToolUse, MessageContent, Role, StopReason,
|
||||
};
|
||||
use project::Project;
|
||||
use smol::stream::StreamExt;
|
||||
use std::{collections::BTreeMap, sync::Arc};
|
||||
use util::ResultExt;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct AgentMessage {
|
||||
pub role: Role,
|
||||
pub content: Vec<MessageContent>,
|
||||
}
|
||||
|
||||
impl AgentMessage {
|
||||
fn to_request_message(&self) -> LanguageModelRequestMessage {
|
||||
LanguageModelRequestMessage {
|
||||
role: self.role,
|
||||
content: self.content.clone(),
|
||||
cache: false, // TODO: Figure out caching
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub type AgentResponseEvent = LanguageModelCompletionEvent;
|
||||
|
||||
pub struct Agent {
|
||||
messages: Vec<AgentMessage>,
|
||||
/// Holds the task that handles agent interaction until the end of the turn.
|
||||
/// Survives across multiple requests as the model performs tool calls and
|
||||
/// we run tools, report their results.
|
||||
running_turn: Option<Task<()>>,
|
||||
tools: BTreeMap<Arc<str>, Arc<dyn Tool>>,
|
||||
project: Entity<Project>,
|
||||
action_log: Entity<ActionLog>,
|
||||
}
|
||||
|
||||
impl Agent {
|
||||
pub fn new(project: Entity<Project>, action_log: Entity<ActionLog>) -> Self {
|
||||
Self {
|
||||
messages: Vec::new(),
|
||||
running_turn: None,
|
||||
tools: BTreeMap::default(),
|
||||
project,
|
||||
action_log,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn add_tool(&mut self, tool: Arc<dyn Tool>) {
|
||||
let name = Arc::from(tool.name());
|
||||
self.tools.insert(name, tool);
|
||||
}
|
||||
|
||||
pub fn remove_tool(&mut self, name: &str) -> bool {
|
||||
self.tools.remove(name).is_some()
|
||||
}
|
||||
|
||||
/// Sending a message results in the model streaming a response, which could include tool calls.
|
||||
/// After calling tools, the model will stops and waits for any outstanding tool calls to be completed and their results sent.
|
||||
/// The returned channel will report all the occurrences in which the model stops before erroring or ending its turn.
|
||||
pub fn send(
|
||||
&mut self,
|
||||
model: Arc<dyn LanguageModel>,
|
||||
content: impl Into<MessageContent>,
|
||||
cx: &mut Context<Self>,
|
||||
) -> mpsc::UnboundedReceiver<Result<AgentResponseEvent>> {
|
||||
cx.notify();
|
||||
let (events_tx, events_rx) = mpsc::unbounded();
|
||||
self.messages.push(AgentMessage {
|
||||
role: Role::User,
|
||||
content: vec![content.into()],
|
||||
});
|
||||
self.running_turn = Some(cx.spawn(async move |thread, cx| {
|
||||
let turn_result = async {
|
||||
// Perform one request, then keep looping if the model makes tool calls.
|
||||
loop {
|
||||
let request =
|
||||
thread.update(cx, |thread, _cx| thread.build_completion_request())?;
|
||||
|
||||
println!(
|
||||
"request: {}",
|
||||
serde_json::to_string_pretty(&request).unwrap()
|
||||
);
|
||||
|
||||
// Stream events, appending to messages and collecting up tool uses.
|
||||
let mut events = model.stream_completion(request, cx).await?;
|
||||
let mut tool_uses = Vec::new();
|
||||
while let Some(event) = events.next().await {
|
||||
match event {
|
||||
Ok(event) => {
|
||||
thread
|
||||
.update(cx, |thread, cx| {
|
||||
tool_uses.extend(thread.handle_response_event(
|
||||
event,
|
||||
events_tx.clone(),
|
||||
cx,
|
||||
));
|
||||
})
|
||||
.ok();
|
||||
}
|
||||
Err(error) => {
|
||||
events_tx.unbounded_send(Err(error)).ok();
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// If there are no tool uses, the turn is done.
|
||||
if tool_uses.is_empty() {
|
||||
break;
|
||||
}
|
||||
|
||||
// If there are tool uses, wait for their results to be
|
||||
// computed, then send them together in a single message on
|
||||
// the next loop iteration.
|
||||
let tool_results = future::join_all(tool_uses).await;
|
||||
thread
|
||||
.update(cx, |thread, _cx| {
|
||||
thread.messages.push(AgentMessage {
|
||||
role: Role::User,
|
||||
content: tool_results.into_iter().map(Into::into).collect(),
|
||||
});
|
||||
})
|
||||
.ok();
|
||||
}
|
||||
|
||||
anyhow::Ok(())
|
||||
}
|
||||
.await;
|
||||
|
||||
if let Err(error) = turn_result {
|
||||
events_tx.unbounded_send(Err(error)).ok();
|
||||
}
|
||||
}));
|
||||
events_rx
|
||||
}
|
||||
|
||||
fn handle_response_event(
|
||||
&mut self,
|
||||
event: LanguageModelCompletionEvent,
|
||||
events_tx: mpsc::UnboundedSender<Result<AgentResponseEvent>>,
|
||||
cx: &mut Context<Self>,
|
||||
) -> Option<Task<LanguageModelToolResult>> {
|
||||
use LanguageModelCompletionEvent::*;
|
||||
events_tx.unbounded_send(Ok(event.clone())).ok();
|
||||
|
||||
match event {
|
||||
Text(new_text) => self.handle_text_event(new_text, cx),
|
||||
Thinking { text, signature } => {}
|
||||
ToolUse(tool_use) => {
|
||||
return Some(self.handle_tool_use_event(tool_use, cx));
|
||||
}
|
||||
StartMessage { message_id, role } => {
|
||||
self.messages.push(AgentMessage {
|
||||
role,
|
||||
content: Vec::new(),
|
||||
});
|
||||
}
|
||||
UsageUpdate(token_usage) => {}
|
||||
Stop(stop_reason) => self.handle_stop_event(stop_reason),
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
|
||||
fn handle_stop_event(&mut self, stop_reason: StopReason) {
|
||||
match stop_reason {
|
||||
StopReason::EndTurn | StopReason::ToolUse => {}
|
||||
StopReason::MaxTokens => todo!(),
|
||||
}
|
||||
}
|
||||
|
||||
fn handle_text_event(&mut self, new_text: String, cx: &mut Context<Self>) {
|
||||
if let Some(last_message) = self.messages.last_mut() {
|
||||
debug_assert!(last_message.role == Role::Assistant);
|
||||
if let Some(MessageContent::Text(text)) = last_message.content.last_mut() {
|
||||
text.push_str(&new_text);
|
||||
} else {
|
||||
last_message.content.push(MessageContent::Text(new_text));
|
||||
}
|
||||
|
||||
cx.notify();
|
||||
} else {
|
||||
todo!("does this happen in practice?");
|
||||
}
|
||||
}
|
||||
|
||||
fn handle_tool_use_event(
|
||||
&mut self,
|
||||
tool_use: LanguageModelToolUse,
|
||||
cx: &mut Context<Self>,
|
||||
) -> Task<LanguageModelToolResult> {
|
||||
if let Some(last_message) = self.messages.last_mut() {
|
||||
debug_assert!(last_message.role == Role::Assistant);
|
||||
last_message.content.push(tool_use.clone().into());
|
||||
cx.notify();
|
||||
} else {
|
||||
todo!("does this happen in practice?");
|
||||
}
|
||||
|
||||
if let Some(tool) = self.tools.get(&tool_use.name) {
|
||||
let pending_tool_result = tool.clone().run(
|
||||
tool_use.input,
|
||||
&self.build_request_messages(),
|
||||
self.project.clone(),
|
||||
self.action_log.clone(),
|
||||
cx,
|
||||
);
|
||||
|
||||
cx.foreground_executor().spawn(async move {
|
||||
match pending_tool_result.output.await {
|
||||
Ok(tool_output) => LanguageModelToolResult {
|
||||
tool_use_id: tool_use.id,
|
||||
tool_name: tool_use.name,
|
||||
is_error: false,
|
||||
content: Arc::from(tool_output),
|
||||
},
|
||||
Err(error) => LanguageModelToolResult {
|
||||
tool_use_id: tool_use.id,
|
||||
tool_name: tool_use.name,
|
||||
is_error: true,
|
||||
content: Arc::from(error.to_string()),
|
||||
},
|
||||
}
|
||||
})
|
||||
} else {
|
||||
Task::ready(LanguageModelToolResult {
|
||||
content: Arc::from(format!("No tool named {} exists", tool_use.name)),
|
||||
tool_use_id: tool_use.id,
|
||||
tool_name: tool_use.name,
|
||||
is_error: true,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
fn build_completion_request(&self) -> LanguageModelRequest {
|
||||
LanguageModelRequest {
|
||||
thread_id: None,
|
||||
prompt_id: None,
|
||||
messages: self.build_request_messages(),
|
||||
tools: self
|
||||
.tools
|
||||
.values()
|
||||
.filter_map(|tool| {
|
||||
Some(LanguageModelRequestTool {
|
||||
name: tool.name(),
|
||||
description: tool.description(),
|
||||
input_schema: tool
|
||||
.input_schema(LanguageModelToolSchemaFormat::JsonSchema)
|
||||
.log_err()?,
|
||||
})
|
||||
})
|
||||
.collect(),
|
||||
stop: Vec::new(),
|
||||
temperature: None,
|
||||
}
|
||||
}
|
||||
|
||||
fn build_request_messages(&self) -> Vec<LanguageModelRequestMessage> {
|
||||
self.messages
|
||||
.iter()
|
||||
.map(|message| LanguageModelRequestMessage {
|
||||
role: message.role,
|
||||
content: message.content.clone(),
|
||||
cache: false,
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
}
|
||||
191
crates/agent2/src/tests.rs
Normal file
191
crates/agent2/src/tests.rs
Normal file
@@ -0,0 +1,191 @@
|
||||
use super::*;
|
||||
use assistant_tool::{IconName, Project, ToolResult};
|
||||
use client::{Client, UserStore};
|
||||
use fs::FakeFs;
|
||||
use gpui::{AppContext, TestAppContext};
|
||||
use language_model::LanguageModelRegistry;
|
||||
use reqwest_client::ReqwestClient;
|
||||
use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::time::Duration;
|
||||
|
||||
mod tools;
|
||||
use tools::*;
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_echo(cx: &mut TestAppContext) {
|
||||
let AgentTest { model, agent, .. } = agent_test(cx).await;
|
||||
|
||||
let events = agent
|
||||
.update(cx, |agent, cx| {
|
||||
agent.send(model.clone(), "Testing: Reply with 'Hello'", cx)
|
||||
})
|
||||
.collect()
|
||||
.await;
|
||||
agent.update(cx, |agent, _cx| {
|
||||
assert_eq!(
|
||||
agent.messages.last().unwrap().content,
|
||||
vec![MessageContent::Text("Hello".to_string())]
|
||||
);
|
||||
});
|
||||
assert_eq!(stop_events(events), vec![StopReason::EndTurn]);
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_tool_calls(cx: &mut TestAppContext) {
|
||||
let AgentTest { model, agent, .. } = agent_test(cx).await;
|
||||
|
||||
// Test a tool calls that's likely to complete before streaming stops.
|
||||
let events = agent
|
||||
.update(cx, |agent, cx| {
|
||||
agent.add_tool(Arc::new(EchoTool));
|
||||
agent.send(
|
||||
model.clone(),
|
||||
"Now test the echo tool with 'Hello'. Does it work? Say 'Yes' or 'No'.",
|
||||
cx,
|
||||
)
|
||||
})
|
||||
.collect()
|
||||
.await;
|
||||
assert_eq!(
|
||||
stop_events(events),
|
||||
vec![StopReason::ToolUse, StopReason::EndTurn]
|
||||
);
|
||||
|
||||
// Test a tool calls that's likely to complete after streaming stops.
|
||||
let events = agent
|
||||
.update(cx, |agent, cx| {
|
||||
agent.remove_tool(&EchoTool.name());
|
||||
agent.add_tool(Arc::new(DelayTool));
|
||||
agent.send(
|
||||
model.clone(),
|
||||
"Now call the delay tool with 200ms. When the timer goes off, then you echo the output of the tool.",
|
||||
cx,
|
||||
)
|
||||
})
|
||||
.collect()
|
||||
.await;
|
||||
assert_eq!(
|
||||
stop_events(events),
|
||||
vec![StopReason::ToolUse, StopReason::EndTurn]
|
||||
);
|
||||
agent.update(cx, |agent, _cx| {
|
||||
assert!(agent
|
||||
.messages
|
||||
.last()
|
||||
.unwrap()
|
||||
.content
|
||||
.iter()
|
||||
.any(|content| {
|
||||
if let MessageContent::Text(text) = content {
|
||||
text.contains("Ding")
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}));
|
||||
});
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_concurrent_tool_calls(cx: &mut TestAppContext) {
|
||||
let AgentTest { model, agent, .. } = agent_test(cx).await;
|
||||
|
||||
// Test concurrent tool calls with different delay times
|
||||
let events = agent
|
||||
.update(cx, |agent, cx| {
|
||||
agent.add_tool(Arc::new(DelayTool));
|
||||
agent.send(
|
||||
model.clone(),
|
||||
"Call the delay tool twice in the same message. Once with 100ms. Once with 300ms. When both timers are complete, describe the outputs.",
|
||||
cx,
|
||||
)
|
||||
})
|
||||
.map(|event| dbg!(event))
|
||||
.collect()
|
||||
.await;
|
||||
|
||||
let stop_reasons = stop_events(events);
|
||||
assert_eq!(stop_reasons, vec![StopReason::ToolUse, StopReason::EndTurn]);
|
||||
|
||||
agent.update(cx, |agent, _cx| {
|
||||
let last_message = agent.messages.last().unwrap();
|
||||
let text = last_message
|
||||
.content
|
||||
.iter()
|
||||
.filter_map(|content| {
|
||||
if let MessageContent::Text(text) = content {
|
||||
Some(text.as_str())
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.collect::<String>();
|
||||
|
||||
assert!(text.contains("Ding"));
|
||||
});
|
||||
}
|
||||
|
||||
fn stop_events(result_events: Vec<Result<AgentResponseEvent>>) -> Vec<StopReason> {
|
||||
result_events
|
||||
.into_iter()
|
||||
.filter_map(|event| match event.unwrap() {
|
||||
LanguageModelCompletionEvent::Stop(stop_reason) => Some(stop_reason),
|
||||
_ => None,
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
struct AgentTest {
|
||||
model: Arc<dyn LanguageModel>,
|
||||
agent: Entity<Agent>,
|
||||
}
|
||||
|
||||
async fn agent_test(cx: &mut TestAppContext) -> AgentTest {
|
||||
cx.executor().allow_parking();
|
||||
cx.update(settings::init);
|
||||
let fs = FakeFs::new(cx.executor().clone());
|
||||
let project = Project::test(fs.clone(), [], cx).await;
|
||||
let action_log = cx.new(|_| ActionLog::new(project.clone()));
|
||||
let agent = cx.new(|_| Agent::new(project.clone(), action_log.clone()));
|
||||
|
||||
let model = cx
|
||||
.update(|cx| {
|
||||
gpui_tokio::init(cx);
|
||||
let http_client = ReqwestClient::user_agent("agent tests").unwrap();
|
||||
cx.set_http_client(Arc::new(http_client));
|
||||
|
||||
client::init_settings(cx);
|
||||
let client = Client::production(cx);
|
||||
let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
|
||||
language_model::init(client.clone(), cx);
|
||||
language_models::init(user_store.clone(), client.clone(), fs.clone(), cx);
|
||||
|
||||
let models = LanguageModelRegistry::read_global(cx);
|
||||
let model = models
|
||||
.available_models(cx)
|
||||
.find(|model| model.id().0 == "claude-3-7-sonnet-latest")
|
||||
.unwrap();
|
||||
|
||||
let provider = models.provider(&model.provider_id()).unwrap();
|
||||
let authenticated = provider.authenticate(cx);
|
||||
|
||||
cx.spawn(async move |cx| {
|
||||
authenticated.await.unwrap();
|
||||
model
|
||||
})
|
||||
})
|
||||
.await;
|
||||
|
||||
AgentTest {
|
||||
model,
|
||||
agent: agent,
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
#[ctor::ctor]
|
||||
fn init_logger() {
|
||||
if std::env::var("RUST_LOG").is_ok() {
|
||||
env_logger::init();
|
||||
}
|
||||
}
|
||||
102
crates/agent2/src/tests/tools.rs
Normal file
102
crates/agent2/src/tests/tools.rs
Normal file
@@ -0,0 +1,102 @@
|
||||
use super::*;
|
||||
|
||||
#[derive(JsonSchema, Serialize, Deserialize)]
|
||||
pub struct EchoToolInput {
|
||||
text: String,
|
||||
}
|
||||
|
||||
pub struct EchoTool;
|
||||
|
||||
impl Tool for EchoTool {
|
||||
fn name(&self) -> String {
|
||||
"echo".to_string()
|
||||
}
|
||||
|
||||
fn description(&self) -> String {
|
||||
"A tool that echoes its input".to_string()
|
||||
}
|
||||
|
||||
fn icon(&self) -> IconName {
|
||||
IconName::Ai
|
||||
}
|
||||
|
||||
fn needs_confirmation(&self, _input: &serde_json::Value, _cx: &gpui::App) -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
fn ui_text(&self, _input: &serde_json::Value) -> String {
|
||||
"Echo".to_string()
|
||||
}
|
||||
|
||||
fn run(
|
||||
self: Arc<Self>,
|
||||
input: serde_json::Value,
|
||||
_messages: &[LanguageModelRequestMessage],
|
||||
_project: gpui::Entity<Project>,
|
||||
_action_log: gpui::Entity<assistant_tool::ActionLog>,
|
||||
cx: &mut gpui::App,
|
||||
) -> ToolResult {
|
||||
ToolResult {
|
||||
output: cx.foreground_executor().spawn(async move {
|
||||
let input: EchoToolInput = serde_json::from_value(input)?;
|
||||
Ok(input.text)
|
||||
}),
|
||||
card: None,
|
||||
}
|
||||
}
|
||||
|
||||
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
|
||||
assistant_tools::json_schema_for::<EchoToolInput>(format)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(JsonSchema, Serialize, Deserialize)]
|
||||
pub struct DelayToolInput {
|
||||
ms: u64,
|
||||
}
|
||||
|
||||
pub struct DelayTool;
|
||||
|
||||
impl Tool for DelayTool {
|
||||
fn name(&self) -> String {
|
||||
"delay".to_string()
|
||||
}
|
||||
|
||||
fn description(&self) -> String {
|
||||
"A tool that waits for a specified delay".to_string()
|
||||
}
|
||||
|
||||
fn icon(&self) -> IconName {
|
||||
IconName::Cog
|
||||
}
|
||||
|
||||
fn needs_confirmation(&self, _input: &serde_json::Value, _cx: &gpui::App) -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
fn ui_text(&self, _input: &serde_json::Value) -> String {
|
||||
"Delay".to_string()
|
||||
}
|
||||
|
||||
fn run(
|
||||
self: Arc<Self>,
|
||||
input: serde_json::Value,
|
||||
_messages: &[LanguageModelRequestMessage],
|
||||
_project: gpui::Entity<Project>,
|
||||
_action_log: gpui::Entity<assistant_tool::ActionLog>,
|
||||
cx: &mut gpui::App,
|
||||
) -> ToolResult {
|
||||
ToolResult {
|
||||
output: cx.foreground_executor().spawn(async move {
|
||||
let input: DelayToolInput = serde_json::from_value(input)?;
|
||||
smol::Timer::after(Duration::from_millis(input.ms)).await;
|
||||
Ok("Ding".to_string())
|
||||
}),
|
||||
card: None,
|
||||
}
|
||||
}
|
||||
|
||||
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
|
||||
assistant_tools::json_schema_for::<DelayToolInput>(format)
|
||||
}
|
||||
}
|
||||
@@ -416,6 +416,7 @@ pub async fn stream_completion_with_rate_limit_info(
|
||||
let beta_headers = Model::from_id(&request.base.model)
|
||||
.map(|model| model.beta_headers())
|
||||
.unwrap_or_else(|_err| Model::DEFAULT_BETA_HEADERS.join(","));
|
||||
|
||||
let request_builder = HttpRequest::builder()
|
||||
.method(Method::POST)
|
||||
.uri(uri)
|
||||
@@ -423,6 +424,7 @@ pub async fn stream_completion_with_rate_limit_info(
|
||||
.header("Anthropic-Beta", beta_headers)
|
||||
.header("X-Api-Key", api_key)
|
||||
.header("Content-Type", "application/json");
|
||||
|
||||
let serialized_request =
|
||||
serde_json::to_string(&request).context("failed to serialize request")?;
|
||||
let request = request_builder
|
||||
|
||||
@@ -14,10 +14,10 @@ use gpui::Context;
|
||||
use gpui::IntoElement;
|
||||
use gpui::Window;
|
||||
use gpui::{App, Entity, SharedString, Task};
|
||||
use icons::IconName;
|
||||
pub use icons::IconName;
|
||||
use language_model::LanguageModelRequestMessage;
|
||||
use language_model::LanguageModelToolSchemaFormat;
|
||||
use project::Project;
|
||||
pub use project::Project;
|
||||
|
||||
pub use crate::action_log::*;
|
||||
pub use crate::tool_registry::*;
|
||||
|
||||
@@ -54,6 +54,7 @@ use crate::rename_tool::RenameTool;
|
||||
use crate::symbol_info_tool::SymbolInfoTool;
|
||||
use crate::terminal_tool::TerminalTool;
|
||||
use crate::thinking_tool::ThinkingTool;
|
||||
pub use schema::json_schema_for;
|
||||
|
||||
pub fn init(http_client: Arc<HttpClientWithUrl>, cx: &mut App) {
|
||||
assistant_tool::init(cx);
|
||||
|
||||
@@ -178,7 +178,14 @@ impl TestAppContext {
|
||||
&self.foreground_executor
|
||||
}
|
||||
|
||||
fn new<T: 'static>(&mut self, build_entity: impl FnOnce(&mut Context<T>) -> T) -> Entity<T> {
|
||||
/// Builds an entity that is owned by the application.
|
||||
///
|
||||
/// The given function will be invoked with a [`Context`] and must return an object representing the entity. An
|
||||
/// [`Entity`] handle will be returned, which can be used to access the entity in a context.
|
||||
pub fn new<T: 'static>(
|
||||
&mut self,
|
||||
build_entity: impl FnOnce(&mut Context<T>) -> T,
|
||||
) -> Entity<T> {
|
||||
let mut cx = self.app.borrow_mut();
|
||||
cx.new(build_entity)
|
||||
}
|
||||
|
||||
@@ -95,6 +95,13 @@ where
|
||||
.spawn(self.log_tracked_err(*location))
|
||||
.detach();
|
||||
}
|
||||
|
||||
/// Convert a Task<Result<T, E>> to a Task<()> that logs all errors.
|
||||
pub fn log_err_in_task(self, cx: &App) -> Task<Option<T>> {
|
||||
let location = core::panic::Location::caller();
|
||||
cx.foreground_executor()
|
||||
.spawn(async move { self.log_tracked_err(*location).await })
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> Future for Task<T> {
|
||||
|
||||
@@ -72,6 +72,7 @@ pub enum LanguageModelCompletionEvent {
|
||||
ToolUse(LanguageModelToolUse),
|
||||
StartMessage {
|
||||
message_id: String,
|
||||
role: Role,
|
||||
},
|
||||
UsageUpdate(TokenUsage),
|
||||
}
|
||||
@@ -288,7 +289,7 @@ pub trait LanguageModel: Send + Sync {
|
||||
|
||||
if let Some(first_event) = events.next().await {
|
||||
match first_event {
|
||||
Ok(LanguageModelCompletionEvent::StartMessage { message_id: id }) => {
|
||||
Ok(LanguageModelCompletionEvent::StartMessage { message_id: id, .. }) => {
|
||||
message_id = Some(id.clone());
|
||||
}
|
||||
Ok(LanguageModelCompletionEvent::Text(text)) => {
|
||||
|
||||
@@ -197,6 +197,24 @@ impl From<&str> for MessageContent {
|
||||
}
|
||||
}
|
||||
|
||||
impl From<LanguageModelToolUse> for MessageContent {
|
||||
fn from(value: LanguageModelToolUse) -> Self {
|
||||
MessageContent::ToolUse(value)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<LanguageModelImage> for MessageContent {
|
||||
fn from(value: LanguageModelImage) -> Self {
|
||||
MessageContent::Image(value)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<LanguageModelToolResult> for MessageContent {
|
||||
fn from(value: LanguageModelToolResult) -> Self {
|
||||
MessageContent::ToolResult(value)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Serialize, Deserialize, Debug, PartialEq, Hash)]
|
||||
pub struct LanguageModelRequestMessage {
|
||||
pub role: Role,
|
||||
|
||||
@@ -750,6 +750,10 @@ pub fn map_to_language_model_completion_events(
|
||||
))),
|
||||
Ok(LanguageModelCompletionEvent::StartMessage {
|
||||
message_id: message.id,
|
||||
role: match message.role {
|
||||
anthropic::Role::User => Role::User,
|
||||
anthropic::Role::Assistant => Role::Assistant,
|
||||
},
|
||||
}),
|
||||
],
|
||||
state,
|
||||
|
||||
Reference in New Issue
Block a user