Compare commits

...

4 Commits

Author SHA1 Message Date
Antonio Scandurra
de8f1c3c84 WIP 2025-06-30 21:01:37 +02:00
Antonio Scandurra
9351e959a5 WIP 2025-06-30 19:32:42 +02:00
Antonio Scandurra
dac0838a80 WIP 2025-06-30 14:24:24 +02:00
Antonio Scandurra
38d5f36d38 WIP 2025-06-30 12:15:19 +02:00
20 changed files with 3977 additions and 6505 deletions

1
Cargo.lock generated
View File

@@ -78,6 +78,7 @@ dependencies = [
"language", "language",
"language_model", "language_model",
"log", "log",
"markdown",
"parking_lot", "parking_lot",
"paths", "paths",
"postage", "postage",

View File

@@ -42,6 +42,7 @@ itertools.workspace = true
language.workspace = true language.workspace = true
language_model.workspace = true language_model.workspace = true
log.workspace = true log.workspace = true
markdown.workspace = true
paths.workspace = true paths.workspace = true
postage.workspace = true postage.workspace = true
project.workspace = true project.workspace = true

View File

@@ -1,3 +1,4 @@
mod agent2;
pub mod agent_profile; pub mod agent_profile;
pub mod context; pub mod context;
pub mod context_server_tool; pub mod context_server_tool;
@@ -5,15 +6,17 @@ pub mod context_store;
pub mod history_store; pub mod history_store;
pub mod thread; pub mod thread;
pub mod thread_store; pub mod thread_store;
pub mod tool_use; mod zed_agent;
pub use agent2::*;
pub use context::{AgentContext, ContextId, ContextLoadResult}; pub use context::{AgentContext, ContextId, ContextLoadResult};
pub use context_store::ContextStore; pub use context_store::ContextStore;
pub use thread::{ pub use thread::{
LastRestoreCheckpoint, Message, MessageCrease, MessageId, MessageSegment, Thread, ThreadError, LastRestoreCheckpoint, Message, MessageCrease, MessageSegment, Thread, ThreadError,
ThreadEvent, ThreadFeedback, ThreadId, ThreadSummary, TokenUsageRatio, ThreadEvent, ThreadFeedback, ThreadTitle, TokenUsageRatio,
}; };
pub use thread_store::{SerializedThread, TextThreadStore, ThreadStore}; pub use thread_store::{SerializedThread, TextThreadStore, ThreadStore};
pub use zed_agent::*;
pub fn init(cx: &mut gpui::App) { pub fn init(cx: &mut gpui::App) {
thread_store::init(cx); thread_store::init(cx);

117
crates/agent/src/agent2.rs Normal file
View File

@@ -0,0 +1,117 @@
use anyhow::Result;
use assistant_tool::{Tool, ToolResultOutput};
use futures::{channel::oneshot, future::BoxFuture, stream::BoxStream};
use gpui::SharedString;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use std::{
fmt::{self, Display},
sync::Arc,
};
#[derive(
Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize, JsonSchema,
)]
pub struct ThreadId(SharedString);
impl ThreadId {
pub fn as_str(&self) -> &str {
&self.0
}
pub fn to_string(&self) -> String {
self.0.to_string()
}
}
impl From<&str> for ThreadId {
fn from(value: &str) -> Self {
ThreadId(SharedString::from(value.to_string()))
}
}
impl Display for ThreadId {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "{}", self.0)
}
}
#[derive(Debug, Clone, Copy, Ord, PartialOrd, Eq, PartialEq, Hash, Serialize, Deserialize)]
pub struct MessageId(pub usize);
#[derive(Debug, Clone)]
pub struct AgentThreadToolCallId(SharedString);
pub enum AgentThreadResponseEvent {
Text(String),
Thinking(String),
ToolCallChunk {
id: AgentThreadToolCallId,
tool: Arc<dyn Tool>,
input: serde_json::Value,
},
ToolCall {
id: AgentThreadToolCallId,
tool: Arc<dyn Tool>,
input: serde_json::Value,
response_tx: oneshot::Sender<Result<ToolResultOutput>>,
},
}
pub enum AgentThreadMessage {
User {
id: MessageId,
chunks: Vec<AgentThreadUserMessageChunk>,
},
Assistant {
id: MessageId,
chunks: Vec<AgentThreadAssistantMessageChunk>,
},
}
pub enum AgentThreadUserMessageChunk {
Text(String),
// here's where we would put mentions, etc.
}
pub enum AgentThreadAssistantMessageChunk {
Text(String),
Thinking(String),
ToolResult {
id: AgentThreadToolCallId,
tool: Arc<dyn Tool>,
input: serde_json::Value,
output: Result<ToolResultOutput>,
},
}
pub struct AgentThreadResponse {
pub user_message_id: MessageId,
pub assistant_message_id: MessageId,
pub events: BoxStream<'static, Result<AgentThreadResponseEvent>>,
}
pub trait Agent {
fn create_thread() -> BoxFuture<'static, Result<Arc<dyn AgentThread>>>;
fn list_threads();
fn load_thread(&self, thread_id: ThreadId) -> BoxFuture<'static, Result<Arc<dyn AgentThread>>>;
}
pub trait AgentThread {
fn id(&self) -> ThreadId;
fn title(&self) -> BoxFuture<'static, Result<String>>;
fn summary(&self) -> BoxFuture<'static, Result<String>>;
fn messages(&self) -> BoxFuture<'static, Result<Vec<AgentThreadMessage>>>;
fn truncate(&self, message_id: MessageId) -> BoxFuture<'static, Result<()>>;
fn edit(
&self,
message_id: MessageId,
content: Vec<AgentThreadUserMessageChunk>,
max_iterations: usize,
) -> BoxFuture<'static, Result<AgentThreadResponse>>;
fn send(
&self,
content: Vec<AgentThreadUserMessageChunk>,
max_iterations: usize,
) -> BoxFuture<'static, Result<AgentThreadResponse>>;
}

View File

@@ -581,7 +581,7 @@ impl ThreadContextHandle {
} }
pub fn title(&self, cx: &App) -> SharedString { pub fn title(&self, cx: &App) -> SharedString {
self.thread.read(cx).summary().or_default() self.thread.read(cx).title().or_default()
} }
fn load(self, cx: &App) -> Task<Option<(AgentContext, Vec<Entity<Buffer>>)>> { fn load(self, cx: &App) -> Task<Option<(AgentContext, Vec<Entity<Buffer>>)>> {
@@ -589,7 +589,7 @@ impl ThreadContextHandle {
let text = Thread::wait_for_detailed_summary_or_text(&self.thread, cx).await?; let text = Thread::wait_for_detailed_summary_or_text(&self.thread, cx).await?;
let title = self let title = self
.thread .thread
.read_with(cx, |thread, _cx| thread.summary().or_default()) .read_with(cx, |thread, _cx| thread.title().or_default())
.ok()?; .ok()?;
let context = AgentContext::Thread(ThreadContext { let context = AgentContext::Thread(ThreadContext {
title, title,

View File

@@ -1,10 +1,11 @@
use crate::{ use crate::{
MessageId, ThreadId,
context::{ context::{
AgentContextHandle, AgentContextKey, ContextId, ContextKind, DirectoryContextHandle, AgentContextHandle, AgentContextKey, ContextId, ContextKind, DirectoryContextHandle,
FetchedUrlContext, FileContextHandle, ImageContext, RulesContextHandle, FetchedUrlContext, FileContextHandle, ImageContext, RulesContextHandle,
SelectionContextHandle, SymbolContextHandle, TextThreadContextHandle, ThreadContextHandle, SelectionContextHandle, SymbolContextHandle, TextThreadContextHandle, ThreadContextHandle,
}, },
thread::{MessageId, Thread, ThreadId}, thread::Thread,
thread_store::ThreadStore, thread_store::ThreadStore,
}; };
use anyhow::{Context as _, Result, anyhow}; use anyhow::{Context as _, Result, anyhow};
@@ -71,7 +72,8 @@ impl ContextStore {
) -> Vec<AgentContextHandle> { ) -> Vec<AgentContextHandle> {
let existing_context = thread let existing_context = thread
.messages() .messages()
.take_while(|message| exclude_messages_from_id.is_none_or(|id| message.id != id)) .iter()
.take_while(|message| exclude_messages_from_id.is_none_or(|id| message.id() != id))
.flat_map(|message| { .flat_map(|message| {
message message
.loaded_context .loaded_context
@@ -441,7 +443,7 @@ impl ContextStore {
match context { match context {
AgentContextHandle::Thread(thread_context) => { AgentContextHandle::Thread(thread_context) => {
self.context_thread_ids self.context_thread_ids
.remove(thread_context.thread.read(cx).id()); .remove(&thread_context.thread.read(cx).id());
} }
AgentContextHandle::TextThread(text_thread_context) => { AgentContextHandle::TextThread(text_thread_context) => {
if let Some(path) = text_thread_context.context.read(cx).path() { if let Some(path) = text_thread_context.context.read(cx).path() {

File diff suppressed because it is too large Load Diff

View File

@@ -1,8 +1,7 @@
use crate::{ use crate::{
MessageId, ThreadId,
context_server_tool::ContextServerTool, context_server_tool::ContextServerTool,
thread::{ thread::{DetailedSummaryState, ExceededWindowError, ProjectSnapshot, Thread},
DetailedSummaryState, ExceededWindowError, MessageId, ProjectSnapshot, Thread, ThreadId,
},
}; };
use agent_settings::{AgentProfileId, CompletionMode}; use agent_settings::{AgentProfileId, CompletionMode};
use anyhow::{Context as _, Result, anyhow}; use anyhow::{Context as _, Result, anyhow};
@@ -400,35 +399,17 @@ impl ThreadStore {
self.threads.iter() self.threads.iter()
} }
pub fn create_thread(&mut self, cx: &mut Context<Self>) -> Entity<Thread> { pub fn create_thread(&mut self, cx: &mut Context<Self>) -> Task<Result<Entity<Thread>>> {
cx.new(|cx| { todo!()
Thread::new( // cx.new(|cx| {
self.project.clone(), // Thread::new(
self.tools.clone(), // self.project.clone(),
self.prompt_builder.clone(), // self.tools.clone(),
self.project_context.clone(), // self.prompt_builder.clone(),
cx, // self.project_context.clone(),
) // cx,
}) // )
} // })
pub fn create_thread_from_serialized(
&mut self,
serialized: SerializedThread,
cx: &mut Context<Self>,
) -> Entity<Thread> {
cx.new(|cx| {
Thread::deserialize(
ThreadId::new(),
serialized,
self.project.clone(),
self.tools.clone(),
self.prompt_builder.clone(),
self.project_context.clone(),
None,
cx,
)
})
} }
pub fn open_thread( pub fn open_thread(
@@ -447,51 +428,53 @@ impl ThreadStore {
.await? .await?
.with_context(|| format!("no thread found with ID: {id:?}"))?; .with_context(|| format!("no thread found with ID: {id:?}"))?;
let thread = this.update_in(cx, |this, window, cx| { todo!();
cx.new(|cx| { // let thread = this.update_in(cx, |this, window, cx| {
Thread::deserialize( // cx.new(|cx| {
id.clone(), // Thread::deserialize(
thread, // id.clone(),
this.project.clone(), // thread,
this.tools.clone(), // this.project.clone(),
this.prompt_builder.clone(), // this.tools.clone(),
this.project_context.clone(), // this.prompt_builder.clone(),
Some(window), // this.project_context.clone(),
cx, // Some(window),
) // cx,
}) // )
})?; // })
// })?;
Ok(thread) // Ok(thread)
}) })
} }
pub fn save_thread(&self, thread: &Entity<Thread>, cx: &mut Context<Self>) -> Task<Result<()>> { pub fn save_thread(&self, thread: &Entity<Thread>, cx: &mut Context<Self>) -> Task<Result<()>> {
let (metadata, serialized_thread) = todo!()
thread.update(cx, |thread, cx| (thread.id().clone(), thread.serialize(cx))); // let (metadata, serialized_thread) =
// thread.update(cx, |thread, cx| (thread.id().clone(), thread.serialize(cx)));
let database_future = ThreadsDatabase::global_future(cx); // let database_future = ThreadsDatabase::global_future(cx);
cx.spawn(async move |this, cx| { // cx.spawn(async move |this, cx| {
let serialized_thread = serialized_thread.await?; // let serialized_thread = serialized_thread.await?;
let database = database_future.await.map_err(|err| anyhow!(err))?; // let database = database_future.await.map_err(|err| anyhow!(err))?;
database.save_thread(metadata, serialized_thread).await?; // database.save_thread(metadata, serialized_thread).await?;
this.update(cx, |this, cx| this.reload(cx))?.await // this.update(cx, |this, cx| this.reload(cx))?.await
}) // })
} }
pub fn delete_thread(&mut self, id: &ThreadId, cx: &mut Context<Self>) -> Task<Result<()>> { pub fn delete_thread(&mut self, id: &ThreadId, cx: &mut Context<Self>) -> Task<Result<()>> {
let id = id.clone(); todo!()
let database_future = ThreadsDatabase::global_future(cx); // let id = id.clone();
cx.spawn(async move |this, cx| { // let database_future = ThreadsDatabase::global_future(cx);
let database = database_future.await.map_err(|err| anyhow!(err))?; // cx.spawn(async move |this, cx| {
database.delete_thread(id.clone()).await?; // let database = database_future.await.map_err(|err| anyhow!(err))?;
// database.delete_thread(id.clone()).await?;
this.update(cx, |this, cx| { // this.update(cx, |this, cx| {
this.threads.retain(|thread| thread.id != id); // this.threads.retain(|thread| thread.id != id);
cx.notify(); // cx.notify();
}) // })
}) // })
} }
pub fn reload(&self, cx: &mut Context<Self>) -> Task<Result<()>> { pub fn reload(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
@@ -1067,7 +1050,7 @@ impl ThreadsDatabase {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use crate::thread::{DetailedSummaryState, MessageId}; use crate::{MessageId, thread::DetailedSummaryState};
use chrono::Utc; use chrono::Utc;
use language_model::{Role, TokenUsage}; use language_model::{Role, TokenUsage};
use pretty_assertions::assert_eq; use pretty_assertions::assert_eq;

View File

@@ -1,567 +0,0 @@
use crate::{
thread::{MessageId, PromptId, ThreadId},
thread_store::SerializedMessage,
};
use anyhow::Result;
use assistant_tool::{
AnyToolCard, Tool, ToolResultContent, ToolResultOutput, ToolUseStatus, ToolWorkingSet,
};
use collections::HashMap;
use futures::{FutureExt as _, future::Shared};
use gpui::{App, Entity, SharedString, Task, Window};
use icons::IconName;
use language_model::{
ConfiguredModel, LanguageModel, LanguageModelRequest, LanguageModelToolResult,
LanguageModelToolResultContent, LanguageModelToolUse, LanguageModelToolUseId, Role,
};
use project::Project;
use std::sync::Arc;
use util::truncate_lines_to_byte_limit;
#[derive(Debug)]
pub struct ToolUse {
pub id: LanguageModelToolUseId,
pub name: SharedString,
pub ui_text: SharedString,
pub status: ToolUseStatus,
pub input: serde_json::Value,
pub icon: icons::IconName,
pub needs_confirmation: bool,
}
pub struct ToolUseState {
tools: Entity<ToolWorkingSet>,
tool_uses_by_assistant_message: HashMap<MessageId, Vec<LanguageModelToolUse>>,
tool_results: HashMap<LanguageModelToolUseId, LanguageModelToolResult>,
pending_tool_uses_by_id: HashMap<LanguageModelToolUseId, PendingToolUse>,
tool_result_cards: HashMap<LanguageModelToolUseId, AnyToolCard>,
tool_use_metadata_by_id: HashMap<LanguageModelToolUseId, ToolUseMetadata>,
}
impl ToolUseState {
pub fn new(tools: Entity<ToolWorkingSet>) -> Self {
Self {
tools,
tool_uses_by_assistant_message: HashMap::default(),
tool_results: HashMap::default(),
pending_tool_uses_by_id: HashMap::default(),
tool_result_cards: HashMap::default(),
tool_use_metadata_by_id: HashMap::default(),
}
}
/// Constructs a [`ToolUseState`] from the given list of [`SerializedMessage`]s.
///
/// Accepts a function to filter the tools that should be used to populate the state.
///
/// If `window` is `None` (e.g., when in headless mode or when running evals),
/// tool cards won't be deserialized
pub fn from_serialized_messages(
tools: Entity<ToolWorkingSet>,
messages: &[SerializedMessage],
project: Entity<Project>,
window: Option<&mut Window>, // None in headless mode
cx: &mut App,
) -> Self {
let mut this = Self::new(tools);
let mut tool_names_by_id = HashMap::default();
let mut window = window;
for message in messages {
match message.role {
Role::Assistant => {
if !message.tool_uses.is_empty() {
let tool_uses = message
.tool_uses
.iter()
.map(|tool_use| LanguageModelToolUse {
id: tool_use.id.clone(),
name: tool_use.name.clone().into(),
raw_input: tool_use.input.to_string(),
input: tool_use.input.clone(),
is_input_complete: true,
})
.collect::<Vec<_>>();
tool_names_by_id.extend(
tool_uses
.iter()
.map(|tool_use| (tool_use.id.clone(), tool_use.name.clone())),
);
this.tool_uses_by_assistant_message
.insert(message.id, tool_uses);
for tool_result in &message.tool_results {
let tool_use_id = tool_result.tool_use_id.clone();
let Some(tool_use) = tool_names_by_id.get(&tool_use_id) else {
log::warn!("no tool name found for tool use: {tool_use_id:?}");
continue;
};
this.tool_results.insert(
tool_use_id.clone(),
LanguageModelToolResult {
tool_use_id: tool_use_id.clone(),
tool_name: tool_use.clone(),
is_error: tool_result.is_error,
content: tool_result.content.clone(),
output: tool_result.output.clone(),
},
);
if let Some(window) = &mut window {
if let Some(tool) = this.tools.read(cx).tool(tool_use, cx) {
if let Some(output) = tool_result.output.clone() {
if let Some(card) = tool.deserialize_card(
output,
project.clone(),
window,
cx,
) {
this.tool_result_cards.insert(tool_use_id, card);
}
}
}
}
}
}
}
Role::System | Role::User => {}
}
}
this
}
pub fn cancel_pending(&mut self) -> Vec<PendingToolUse> {
let mut cancelled_tool_uses = Vec::new();
self.pending_tool_uses_by_id
.retain(|tool_use_id, tool_use| {
if matches!(tool_use.status, PendingToolUseStatus::Error { .. }) {
return true;
}
let content = "Tool canceled by user".into();
self.tool_results.insert(
tool_use_id.clone(),
LanguageModelToolResult {
tool_use_id: tool_use_id.clone(),
tool_name: tool_use.name.clone(),
content,
output: None,
is_error: true,
},
);
cancelled_tool_uses.push(tool_use.clone());
false
});
cancelled_tool_uses
}
pub fn pending_tool_uses(&self) -> Vec<&PendingToolUse> {
self.pending_tool_uses_by_id.values().collect()
}
pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
let Some(tool_uses_for_message) = &self.tool_uses_by_assistant_message.get(&id) else {
return Vec::new();
};
let mut tool_uses = Vec::new();
for tool_use in tool_uses_for_message.iter() {
let tool_result = self.tool_results.get(&tool_use.id);
let status = (|| {
if let Some(tool_result) = tool_result {
let content = tool_result
.content
.to_str()
.map(|str| str.to_owned().into())
.unwrap_or_default();
return if tool_result.is_error {
ToolUseStatus::Error(content)
} else {
ToolUseStatus::Finished(content)
};
}
if let Some(pending_tool_use) = self.pending_tool_uses_by_id.get(&tool_use.id) {
match pending_tool_use.status {
PendingToolUseStatus::Idle => ToolUseStatus::Pending,
PendingToolUseStatus::NeedsConfirmation { .. } => {
ToolUseStatus::NeedsConfirmation
}
PendingToolUseStatus::Running { .. } => ToolUseStatus::Running,
PendingToolUseStatus::Error(ref err) => {
ToolUseStatus::Error(err.clone().into())
}
PendingToolUseStatus::InputStillStreaming => {
ToolUseStatus::InputStillStreaming
}
}
} else {
ToolUseStatus::Pending
}
})();
let (icon, needs_confirmation) =
if let Some(tool) = self.tools.read(cx).tool(&tool_use.name, cx) {
(tool.icon(), tool.needs_confirmation(&tool_use.input, cx))
} else {
(IconName::Cog, false)
};
tool_uses.push(ToolUse {
id: tool_use.id.clone(),
name: tool_use.name.clone().into(),
ui_text: self.tool_ui_label(
&tool_use.name,
&tool_use.input,
tool_use.is_input_complete,
cx,
),
input: tool_use.input.clone(),
status,
icon,
needs_confirmation,
})
}
tool_uses
}
pub fn tool_ui_label(
&self,
tool_name: &str,
input: &serde_json::Value,
is_input_complete: bool,
cx: &App,
) -> SharedString {
if let Some(tool) = self.tools.read(cx).tool(tool_name, cx) {
if is_input_complete {
tool.ui_text(input).into()
} else {
tool.still_streaming_ui_text(input).into()
}
} else {
format!("Unknown tool {tool_name:?}").into()
}
}
pub fn tool_results_for_message(
&self,
assistant_message_id: MessageId,
) -> Vec<&LanguageModelToolResult> {
let Some(tool_uses) = self
.tool_uses_by_assistant_message
.get(&assistant_message_id)
else {
return Vec::new();
};
tool_uses
.iter()
.filter_map(|tool_use| self.tool_results.get(&tool_use.id))
.collect()
}
pub fn message_has_tool_results(&self, assistant_message_id: MessageId) -> bool {
self.tool_uses_by_assistant_message
.get(&assistant_message_id)
.map_or(false, |results| !results.is_empty())
}
pub fn tool_result(
&self,
tool_use_id: &LanguageModelToolUseId,
) -> Option<&LanguageModelToolResult> {
self.tool_results.get(tool_use_id)
}
pub fn tool_result_card(&self, tool_use_id: &LanguageModelToolUseId) -> Option<&AnyToolCard> {
self.tool_result_cards.get(tool_use_id)
}
pub fn insert_tool_result_card(
&mut self,
tool_use_id: LanguageModelToolUseId,
card: AnyToolCard,
) {
self.tool_result_cards.insert(tool_use_id, card);
}
pub fn request_tool_use(
&mut self,
assistant_message_id: MessageId,
tool_use: LanguageModelToolUse,
metadata: ToolUseMetadata,
cx: &App,
) -> Arc<str> {
let tool_uses = self
.tool_uses_by_assistant_message
.entry(assistant_message_id)
.or_default();
let mut existing_tool_use_found = false;
for existing_tool_use in tool_uses.iter_mut() {
if existing_tool_use.id == tool_use.id {
*existing_tool_use = tool_use.clone();
existing_tool_use_found = true;
}
}
if !existing_tool_use_found {
tool_uses.push(tool_use.clone());
}
let status = if tool_use.is_input_complete {
self.tool_use_metadata_by_id
.insert(tool_use.id.clone(), metadata);
PendingToolUseStatus::Idle
} else {
PendingToolUseStatus::InputStillStreaming
};
let ui_text: Arc<str> = self
.tool_ui_label(
&tool_use.name,
&tool_use.input,
tool_use.is_input_complete,
cx,
)
.into();
let may_perform_edits = self
.tools
.read(cx)
.tool(&tool_use.name, cx)
.is_some_and(|tool| tool.may_perform_edits());
self.pending_tool_uses_by_id.insert(
tool_use.id.clone(),
PendingToolUse {
assistant_message_id,
id: tool_use.id,
name: tool_use.name.clone(),
ui_text: ui_text.clone(),
input: tool_use.input,
may_perform_edits,
status,
},
);
ui_text
}
pub fn run_pending_tool(
&mut self,
tool_use_id: LanguageModelToolUseId,
ui_text: SharedString,
task: Task<()>,
) {
if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
tool_use.ui_text = ui_text.into();
tool_use.status = PendingToolUseStatus::Running {
_task: task.shared(),
};
}
}
pub fn confirm_tool_use(
&mut self,
tool_use_id: LanguageModelToolUseId,
ui_text: impl Into<Arc<str>>,
input: serde_json::Value,
request: Arc<LanguageModelRequest>,
tool: Arc<dyn Tool>,
) {
if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
let ui_text = ui_text.into();
tool_use.ui_text = ui_text.clone();
let confirmation = Confirmation {
tool_use_id,
input,
request,
tool,
ui_text,
};
tool_use.status = PendingToolUseStatus::NeedsConfirmation(Arc::new(confirmation));
}
}
pub fn insert_tool_output(
&mut self,
tool_use_id: LanguageModelToolUseId,
tool_name: Arc<str>,
output: Result<ToolResultOutput>,
configured_model: Option<&ConfiguredModel>,
) -> Option<PendingToolUse> {
let metadata = self.tool_use_metadata_by_id.remove(&tool_use_id);
telemetry::event!(
"Agent Tool Finished",
model = metadata
.as_ref()
.map(|metadata| metadata.model.telemetry_id()),
model_provider = metadata
.as_ref()
.map(|metadata| metadata.model.provider_id().to_string()),
thread_id = metadata.as_ref().map(|metadata| metadata.thread_id.clone()),
prompt_id = metadata.as_ref().map(|metadata| metadata.prompt_id.clone()),
tool_name,
success = output.is_ok()
);
match output {
Ok(output) => {
let tool_result = output.content;
const BYTES_PER_TOKEN_ESTIMATE: usize = 3;
let old_use = self.pending_tool_uses_by_id.remove(&tool_use_id);
// Protect from overly large output
let tool_output_limit = configured_model
.map(|model| model.model.max_token_count() as usize * BYTES_PER_TOKEN_ESTIMATE)
.unwrap_or(usize::MAX);
let content = match tool_result {
ToolResultContent::Text(text) => {
let text = if text.len() < tool_output_limit {
text
} else {
let truncated = truncate_lines_to_byte_limit(&text, tool_output_limit);
format!(
"Tool result too long. The first {} bytes:\n\n{}",
truncated.len(),
truncated
)
};
LanguageModelToolResultContent::Text(text.into())
}
ToolResultContent::Image(language_model_image) => {
if language_model_image.estimate_tokens() < tool_output_limit {
LanguageModelToolResultContent::Image(language_model_image)
} else {
self.tool_results.insert(
tool_use_id.clone(),
LanguageModelToolResult {
tool_use_id: tool_use_id.clone(),
tool_name,
content: "Tool responded with an image that would exceeded the remaining tokens".into(),
is_error: true,
output: None,
},
);
return old_use;
}
}
};
self.tool_results.insert(
tool_use_id.clone(),
LanguageModelToolResult {
tool_use_id: tool_use_id.clone(),
tool_name,
content,
is_error: false,
output: output.output,
},
);
old_use
}
Err(err) => {
self.tool_results.insert(
tool_use_id.clone(),
LanguageModelToolResult {
tool_use_id: tool_use_id.clone(),
tool_name,
content: LanguageModelToolResultContent::Text(err.to_string().into()),
is_error: true,
output: None,
},
);
if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
tool_use.status = PendingToolUseStatus::Error(err.to_string().into());
}
self.pending_tool_uses_by_id.get(&tool_use_id).cloned()
}
}
}
pub fn has_tool_results(&self, assistant_message_id: MessageId) -> bool {
self.tool_uses_by_assistant_message
.contains_key(&assistant_message_id)
}
pub fn tool_results(
&self,
assistant_message_id: MessageId,
) -> impl Iterator<Item = (&LanguageModelToolUse, Option<&LanguageModelToolResult>)> {
self.tool_uses_by_assistant_message
.get(&assistant_message_id)
.into_iter()
.flatten()
.map(|tool_use| (tool_use, self.tool_results.get(&tool_use.id)))
}
}
#[derive(Debug, Clone)]
pub struct PendingToolUse {
pub id: LanguageModelToolUseId,
/// The ID of the Assistant message in which the tool use was requested.
#[allow(unused)]
pub assistant_message_id: MessageId,
pub name: Arc<str>,
pub ui_text: Arc<str>,
pub input: serde_json::Value,
pub status: PendingToolUseStatus,
pub may_perform_edits: bool,
}
#[derive(Debug, Clone)]
pub struct Confirmation {
pub tool_use_id: LanguageModelToolUseId,
pub input: serde_json::Value,
pub ui_text: Arc<str>,
pub request: Arc<LanguageModelRequest>,
pub tool: Arc<dyn Tool>,
}
#[derive(Debug, Clone)]
pub enum PendingToolUseStatus {
InputStillStreaming,
Idle,
NeedsConfirmation(Arc<Confirmation>),
Running { _task: Shared<Task<()>> },
Error(#[allow(unused)] Arc<str>),
}
impl PendingToolUseStatus {
pub fn is_idle(&self) -> bool {
matches!(self, PendingToolUseStatus::Idle)
}
pub fn is_error(&self) -> bool {
matches!(self, PendingToolUseStatus::Error(_))
}
pub fn needs_confirmation(&self) -> bool {
matches!(self, PendingToolUseStatus::NeedsConfirmation { .. })
}
}
#[derive(Clone)]
pub struct ToolUseMetadata {
pub model: Arc<dyn LanguageModel>,
pub thread_id: ThreadId,
pub prompt_id: PromptId,
}

View File

@@ -0,0 +1 @@
pub struct ZedAgentThread {}

File diff suppressed because it is too large Load Diff

View File

@@ -211,7 +211,7 @@ impl AgentDiffPane {
} }
fn update_title(&mut self, cx: &mut Context<Self>) { fn update_title(&mut self, cx: &mut Context<Self>) {
let new_title = self.thread.read(cx).summary().unwrap_or("Agent Changes"); let new_title = self.thread.read(cx).title().unwrap_or("Agent Changes");
if new_title != self.title { if new_title != self.title {
self.title = new_title; self.title = new_title;
cx.emit(EditorEvent::TitleChanged); cx.emit(EditorEvent::TitleChanged);
@@ -461,7 +461,7 @@ impl Item for AgentDiffPane {
} }
fn tab_content(&self, params: TabContentParams, _window: &Window, cx: &App) -> AnyElement { fn tab_content(&self, params: TabContentParams, _window: &Window, cx: &App) -> AnyElement {
let summary = self.thread.read(cx).summary().unwrap_or("Agent Changes"); let summary = self.thread.read(cx).title().unwrap_or("Agent Changes");
Label::new(format!("Review: {}", summary)) Label::new(format!("Review: {}", summary))
.color(if params.selected { .color(if params.selected {
Color::Default Color::Default
@@ -1369,8 +1369,6 @@ impl AgentDiff {
| ThreadEvent::MessageDeleted(_) | ThreadEvent::MessageDeleted(_)
| ThreadEvent::SummaryGenerated | ThreadEvent::SummaryGenerated
| ThreadEvent::SummaryChanged | ThreadEvent::SummaryChanged
| ThreadEvent::UsePendingTools { .. }
| ThreadEvent::ToolFinished { .. }
| ThreadEvent::CheckpointChanged | ThreadEvent::CheckpointChanged
| ThreadEvent::ToolConfirmationNeeded | ThreadEvent::ToolConfirmationNeeded
| ThreadEvent::ToolUseLimitReached | ThreadEvent::ToolUseLimitReached
@@ -1801,7 +1799,10 @@ mod tests {
}) })
.await .await
.unwrap(); .unwrap();
let thread = thread_store.update(cx, |store, cx| store.create_thread(cx)); let thread = thread_store
.update(cx, |store, cx| store.create_thread(cx))
.await
.unwrap();
let action_log = thread.read_with(cx, |thread, _| thread.action_log().clone()); let action_log = thread.read_with(cx, |thread, _| thread.action_log().clone());
let (workspace, cx) = let (workspace, cx) =
@@ -1966,7 +1967,10 @@ mod tests {
}) })
.await .await
.unwrap(); .unwrap();
let thread = thread_store.update(cx, |store, cx| store.create_thread(cx)); let thread = thread_store
.update(cx, |store, cx| store.create_thread(cx))
.await
.unwrap();
let action_log = thread.read_with(cx, |thread, _| thread.action_log().clone()); let action_log = thread.read_with(cx, |thread, _| thread.action_log().clone());
let (workspace, cx) = let (workspace, cx) =

View File

@@ -45,7 +45,7 @@ impl AgentModelSelector {
let registry = LanguageModelRegistry::read_global(cx); let registry = LanguageModelRegistry::read_global(cx);
if let Some(provider) = registry.provider(&model.provider_id()) if let Some(provider) = registry.provider(&model.provider_id())
{ {
thread.set_configured_model( thread.set_model(
Some(ConfiguredModel { Some(ConfiguredModel {
provider, provider,
model: model.clone(), model: model.clone(),

View File

@@ -26,7 +26,7 @@ use crate::{
ui::AgentOnboardingModal, ui::AgentOnboardingModal,
}; };
use agent::{ use agent::{
Thread, ThreadError, ThreadEvent, ThreadId, ThreadSummary, TokenUsageRatio, Thread, ThreadError, ThreadEvent, ThreadId, ThreadTitle, TokenUsageRatio,
context_store::ContextStore, context_store::ContextStore,
history_store::{HistoryEntryId, HistoryStore}, history_store::{HistoryEntryId, HistoryStore},
thread_store::{TextThreadStore, ThreadStore}, thread_store::{TextThreadStore, ThreadStore},
@@ -72,7 +72,7 @@ use zed_actions::{
agent::{OpenConfiguration, OpenOnboardingModal, ResetOnboarding}, agent::{OpenConfiguration, OpenOnboardingModal, ResetOnboarding},
assistant::{OpenRulesLibrary, ToggleFocus}, assistant::{OpenRulesLibrary, ToggleFocus},
}; };
use zed_llm_client::{CompletionIntent, UsageLimit}; use zed_llm_client::UsageLimit;
const AGENT_PANEL_KEY: &str = "agent_panel"; const AGENT_PANEL_KEY: &str = "agent_panel";
@@ -252,7 +252,7 @@ impl ActiveView {
thread.update(cx, |thread, cx| { thread.update(cx, |thread, cx| {
thread.thread().update(cx, |thread, cx| { thread.thread().update(cx, |thread, cx| {
thread.set_summary(new_summary, cx); thread.set_title(new_summary, cx);
}); });
}) })
} }
@@ -278,7 +278,7 @@ impl ActiveView {
let editor = editor.clone(); let editor = editor.clone();
move |_, thread, event, window, cx| match event { move |_, thread, event, window, cx| match event {
ThreadEvent::SummaryGenerated => { ThreadEvent::SummaryGenerated => {
let summary = thread.read(cx).summary().or_default(); let summary = thread.read(cx).title().or_default();
editor.update(cx, |editor, cx| { editor.update(cx, |editor, cx| {
editor.set_text(summary, window, cx); editor.set_text(summary, window, cx);
@@ -492,10 +492,15 @@ impl AgentPanel {
None None
}; };
let thread = thread_store
.update(cx, |this, cx| this.create_thread(cx))?
.await?;
let panel = workspace.update_in(cx, |workspace, window, cx| { let panel = workspace.update_in(cx, |workspace, window, cx| {
let panel = cx.new(|cx| { let panel = cx.new(|cx| {
Self::new( Self::new(
workspace, workspace,
thread,
thread_store, thread_store,
context_store, context_store,
prompt_store, prompt_store,
@@ -518,13 +523,13 @@ impl AgentPanel {
fn new( fn new(
workspace: &Workspace, workspace: &Workspace,
thread: Entity<Thread>,
thread_store: Entity<ThreadStore>, thread_store: Entity<ThreadStore>,
context_store: Entity<TextThreadStore>, context_store: Entity<TextThreadStore>,
prompt_store: Option<Entity<PromptStore>>, prompt_store: Option<Entity<PromptStore>>,
window: &mut Window, window: &mut Window,
cx: &mut Context<Self>, cx: &mut Context<Self>,
) -> Self { ) -> Self {
let thread = thread_store.update(cx, |this, cx| this.create_thread(cx));
let fs = workspace.app_state().fs.clone(); let fs = workspace.app_state().fs.clone();
let user_store = workspace.app_state().user_store.clone(); let user_store = workspace.app_state().user_store.clone();
let project = workspace.project(); let project = workspace.project();
@@ -647,11 +652,12 @@ impl AgentPanel {
|this, _, event: &language_model::Event, cx| match event { |this, _, event: &language_model::Event, cx| match event {
language_model::Event::DefaultModelChanged => match &this.active_view { language_model::Event::DefaultModelChanged => match &this.active_view {
ActiveView::Thread { thread, .. } => { ActiveView::Thread { thread, .. } => {
thread // todo!(do we need this?);
.read(cx) // thread
.thread() // .read(cx)
.clone() // .thread()
.update(cx, |thread, cx| thread.get_or_init_configured_model(cx)); // .clone()
// .update(cx, |thread, cx| thread.get_or_init_configured_model(cx));
} }
ActiveView::TextThread { .. } ActiveView::TextThread { .. }
| ActiveView::History | ActiveView::History
@@ -784,46 +790,61 @@ impl AgentPanel {
.detach_and_log_err(cx); .detach_and_log_err(cx);
} }
let active_thread = cx.new(|cx| { let fs = self.fs.clone();
ActiveThread::new( let user_store = self.user_store.clone();
thread.clone(), let thread_store = self.thread_store.clone();
self.thread_store.clone(), let text_thread_store = self.context_store.clone();
self.context_store.clone(), let prompt_store = self.prompt_store.clone();
context_store.clone(), let language_registry = self.language_registry.clone();
self.language_registry.clone(), let workspace = self.workspace.clone();
self.workspace.clone(),
window,
cx,
)
});
let message_editor = cx.new(|cx| { cx.spawn_in(window, async move |this, cx| {
MessageEditor::new( let thread = thread.await?;
self.fs.clone(), let active_thread = cx.new_window_entity(|window, cx| {
self.workspace.clone(), ActiveThread::new(
self.user_store.clone(), thread.clone(),
context_store.clone(), thread_store.clone(),
self.prompt_store.clone(), text_thread_store.clone(),
self.thread_store.downgrade(), context_store.clone(),
self.context_store.downgrade(), language_registry.clone(),
thread.clone(), workspace.clone(),
window, window,
cx, cx,
) )
}); })?;
if let Some(text) = preserved_text { let message_editor = cx.new_window_entity(|window, cx| {
message_editor.update(cx, |editor, cx| { MessageEditor::new(
editor.set_text(text, window, cx); fs.clone(),
}); workspace.clone(),
} user_store.clone(),
context_store.clone(),
prompt_store.clone(),
thread_store.downgrade(),
text_thread_store.downgrade(),
thread.clone(),
window,
cx,
)
})?;
message_editor.focus_handle(cx).focus(window); if let Some(text) = preserved_text {
message_editor.update_in(cx, |editor, window, cx| {
editor.set_text(text, window, cx);
});
}
let thread_view = ActiveView::thread(active_thread.clone(), message_editor, window, cx); this.update_in(cx, |this, window, cx| {
self.set_active_view(thread_view, window, cx); message_editor.focus_handle(cx).focus(window);
AgentDiff::set_active_thread(&self.workspace, &thread, window, cx); let thread_view =
ActiveView::thread(active_thread.clone(), message_editor, window, cx);
this.set_active_view(thread_view, window, cx);
AgentDiff::set_active_thread(&this.workspace, &thread, window, cx);
})
})
.detach_and_log_err(cx);
} }
fn new_prompt_editor(&mut self, window: &mut Window, cx: &mut Context<Self>) { fn new_prompt_editor(&mut self, window: &mut Window, cx: &mut Context<Self>) {
@@ -1254,23 +1275,11 @@ impl AgentPanel {
return; return;
} }
let model = thread_state.configured_model().map(|cm| cm.model.clone()); thread.update(cx, |active_thread, cx| {
if let Some(model) = model { active_thread
thread.update(cx, |active_thread, cx| { .thread()
active_thread.thread().update(cx, |thread, cx| { .update(cx, |thread, cx| thread.resume(window, cx))
thread.insert_invisible_continue_message(cx); });
thread.advance_prompt_id();
thread.send_to_model(
model,
CompletionIntent::UserPrompt,
Some(window.window_handle()),
cx,
);
});
});
} else {
log::warn!("No configured model available for continuation");
}
} }
fn toggle_burn_mode( fn toggle_burn_mode(
@@ -1552,24 +1561,24 @@ impl AgentPanel {
let state = { let state = {
let active_thread = active_thread.read(cx); let active_thread = active_thread.read(cx);
if active_thread.is_empty() { if active_thread.is_empty() {
&ThreadSummary::Pending &ThreadTitle::Pending
} else { } else {
active_thread.summary(cx) active_thread.summary(cx)
} }
}; };
match state { match state {
ThreadSummary::Pending => Label::new(ThreadSummary::DEFAULT.clone()) ThreadTitle::Pending => Label::new(ThreadTitle::DEFAULT.clone())
.truncate() .truncate()
.into_any_element(), .into_any_element(),
ThreadSummary::Generating => Label::new(LOADING_SUMMARY_PLACEHOLDER) ThreadTitle::Generating => Label::new(LOADING_SUMMARY_PLACEHOLDER)
.truncate() .truncate()
.into_any_element(), .into_any_element(),
ThreadSummary::Ready(_) => div() ThreadTitle::Ready(_) => div()
.w_full() .w_full()
.child(change_title_editor.clone()) .child(change_title_editor.clone())
.into_any_element(), .into_any_element(),
ThreadSummary::Error => h_flex() ThreadTitle::Error => h_flex()
.w_full() .w_full()
.child(change_title_editor.clone()) .child(change_title_editor.clone())
.child( .child(
@@ -2024,7 +2033,7 @@ impl AgentPanel {
.read(cx) .read(cx)
.thread() .thread()
.read(cx) .read(cx)
.configured_model() .model()
.map_or(false, |model| { .map_or(false, |model| {
model.provider.id().0 == ZED_CLOUD_PROVIDER_ID model.provider.id().0 == ZED_CLOUD_PROVIDER_ID
}); });
@@ -2629,7 +2638,7 @@ impl AgentPanel {
return None; return None;
} }
let model = thread.configured_model()?.model; let model = thread.model()?.model;
let focus_handle = self.focus_handle(cx); let focus_handle = self.focus_handle(cx);

View File

@@ -121,7 +121,7 @@ pub(crate) enum ModelUsageContext {
impl ModelUsageContext { impl ModelUsageContext {
pub fn configured_model(&self, cx: &App) -> Option<ConfiguredModel> { pub fn configured_model(&self, cx: &App) -> Option<ConfiguredModel> {
match self { match self {
Self::Thread(thread) => thread.read(cx).configured_model(), Self::Thread(thread) => thread.read(cx).model(),
Self::InlineAssistant => { Self::InlineAssistant => {
LanguageModelRegistry::read_global(cx).inline_assistant_model() LanguageModelRegistry::read_global(cx).inline_assistant_model()
} }

View File

@@ -670,7 +670,7 @@ fn recent_context_picker_entries(
let mut threads = unordered_thread_entries(thread_store, text_thread_store, cx) let mut threads = unordered_thread_entries(thread_store, text_thread_store, cx)
.filter(|(_, thread)| match thread { .filter(|(_, thread)| match thread {
ThreadContextEntry::Thread { id, .. } => { ThreadContextEntry::Thread { id, .. } => {
Some(id) != active_thread_id && !current_threads.contains(id) Some(id) != active_thread_id.as_ref() && !current_threads.contains(id)
} }
ThreadContextEntry::Context { .. } => true, ThreadContextEntry::Context { .. } => true,
}) })

View File

@@ -169,13 +169,13 @@ impl ContextStrip {
if self if self
.context_store .context_store
.read(cx) .read(cx)
.includes_thread(active_thread.id()) .includes_thread(&active_thread.id())
{ {
return None; return None;
} }
Some(SuggestedContext::Thread { Some(SuggestedContext::Thread {
name: active_thread.summary().or_default(), name: active_thread.title().or_default(),
thread: weak_active_thread, thread: weak_active_thread,
}) })
} else if let Some(active_context_editor) = panel.active_context_editor() { } else if let Some(active_context_editor) = panel.active_context_editor() {

View File

@@ -387,25 +387,14 @@ impl MessageEditor {
thread thread
.update(cx, |thread, cx| { .update(cx, |thread, cx| {
thread.insert_user_message( todo!();
user_message, // thread.send(
loaded_context, // user_message,
checkpoint.ok(), // loaded_context,
user_message_creases, // checkpoint.ok(),
cx, // user_message_creases,
); // cx,
}) // );
.log_err();
thread
.update(cx, |thread, cx| {
thread.advance_prompt_id();
thread.send_to_model(
model,
CompletionIntent::UserPrompt,
Some(window_handle),
cx,
);
}) })
.log_err(); .log_err();
}) })

View File

@@ -156,7 +156,7 @@ impl Render for ProfileSelector {
.map(|profile| profile.name.clone()) .map(|profile| profile.name.clone())
.unwrap_or_else(|| "Unknown".into()); .unwrap_or_else(|| "Unknown".into());
let configured_model = self.thread.read(cx).configured_model().or_else(|| { let configured_model = self.thread.read(cx).model().or_else(|| {
let model_registry = LanguageModelRegistry::read_global(cx); let model_registry = LanguageModelRegistry::read_global(cx);
model_registry.default_model() model_registry.default_model()
}); });

View File

@@ -498,7 +498,7 @@ impl AddedContext {
render_hover: { render_hover: {
let thread = handle.thread.clone(); let thread = handle.thread.clone();
Some(Rc::new(move |_, cx| { Some(Rc::new(move |_, cx| {
let text = thread.read(cx).latest_detailed_summary_or_text(); let text = thread.read(cx).latest_detailed_summary_or_text(cx);
ContextPillHover::new_text(text.clone(), cx).into() ContextPillHover::new_text(text.clone(), cx).into()
})) }))
}, },