Compare commits

...

4 Commits

Author SHA1 Message Date
Antonio Scandurra
de8f1c3c84 WIP 2025-06-30 21:01:37 +02:00
Antonio Scandurra
9351e959a5 WIP 2025-06-30 19:32:42 +02:00
Antonio Scandurra
dac0838a80 WIP 2025-06-30 14:24:24 +02:00
Antonio Scandurra
38d5f36d38 WIP 2025-06-30 12:15:19 +02:00
20 changed files with 3977 additions and 6505 deletions

1
Cargo.lock generated
View File

@@ -78,6 +78,7 @@ dependencies = [
"language",
"language_model",
"log",
"markdown",
"parking_lot",
"paths",
"postage",

View File

@@ -42,6 +42,7 @@ itertools.workspace = true
language.workspace = true
language_model.workspace = true
log.workspace = true
markdown.workspace = true
paths.workspace = true
postage.workspace = true
project.workspace = true

View File

@@ -1,3 +1,4 @@
mod agent2;
pub mod agent_profile;
pub mod context;
pub mod context_server_tool;
@@ -5,15 +6,17 @@ pub mod context_store;
pub mod history_store;
pub mod thread;
pub mod thread_store;
pub mod tool_use;
mod zed_agent;
pub use agent2::*;
pub use context::{AgentContext, ContextId, ContextLoadResult};
pub use context_store::ContextStore;
pub use thread::{
LastRestoreCheckpoint, Message, MessageCrease, MessageId, MessageSegment, Thread, ThreadError,
ThreadEvent, ThreadFeedback, ThreadId, ThreadSummary, TokenUsageRatio,
LastRestoreCheckpoint, Message, MessageCrease, MessageSegment, Thread, ThreadError,
ThreadEvent, ThreadFeedback, ThreadTitle, TokenUsageRatio,
};
pub use thread_store::{SerializedThread, TextThreadStore, ThreadStore};
pub use zed_agent::*;
pub fn init(cx: &mut gpui::App) {
thread_store::init(cx);

117
crates/agent/src/agent2.rs Normal file
View File

@@ -0,0 +1,117 @@
use anyhow::Result;
use assistant_tool::{Tool, ToolResultOutput};
use futures::{channel::oneshot, future::BoxFuture, stream::BoxStream};
use gpui::SharedString;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use std::{
fmt::{self, Display},
sync::Arc,
};
#[derive(
Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize, JsonSchema,
)]
pub struct ThreadId(SharedString);
impl ThreadId {
pub fn as_str(&self) -> &str {
&self.0
}
pub fn to_string(&self) -> String {
self.0.to_string()
}
}
impl From<&str> for ThreadId {
fn from(value: &str) -> Self {
ThreadId(SharedString::from(value.to_string()))
}
}
impl Display for ThreadId {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "{}", self.0)
}
}
#[derive(Debug, Clone, Copy, Ord, PartialOrd, Eq, PartialEq, Hash, Serialize, Deserialize)]
pub struct MessageId(pub usize);
#[derive(Debug, Clone)]
pub struct AgentThreadToolCallId(SharedString);
pub enum AgentThreadResponseEvent {
Text(String),
Thinking(String),
ToolCallChunk {
id: AgentThreadToolCallId,
tool: Arc<dyn Tool>,
input: serde_json::Value,
},
ToolCall {
id: AgentThreadToolCallId,
tool: Arc<dyn Tool>,
input: serde_json::Value,
response_tx: oneshot::Sender<Result<ToolResultOutput>>,
},
}
pub enum AgentThreadMessage {
User {
id: MessageId,
chunks: Vec<AgentThreadUserMessageChunk>,
},
Assistant {
id: MessageId,
chunks: Vec<AgentThreadAssistantMessageChunk>,
},
}
pub enum AgentThreadUserMessageChunk {
Text(String),
// here's where we would put mentions, etc.
}
pub enum AgentThreadAssistantMessageChunk {
Text(String),
Thinking(String),
ToolResult {
id: AgentThreadToolCallId,
tool: Arc<dyn Tool>,
input: serde_json::Value,
output: Result<ToolResultOutput>,
},
}
pub struct AgentThreadResponse {
pub user_message_id: MessageId,
pub assistant_message_id: MessageId,
pub events: BoxStream<'static, Result<AgentThreadResponseEvent>>,
}
pub trait Agent {
fn create_thread() -> BoxFuture<'static, Result<Arc<dyn AgentThread>>>;
fn list_threads();
fn load_thread(&self, thread_id: ThreadId) -> BoxFuture<'static, Result<Arc<dyn AgentThread>>>;
}
pub trait AgentThread {
fn id(&self) -> ThreadId;
fn title(&self) -> BoxFuture<'static, Result<String>>;
fn summary(&self) -> BoxFuture<'static, Result<String>>;
fn messages(&self) -> BoxFuture<'static, Result<Vec<AgentThreadMessage>>>;
fn truncate(&self, message_id: MessageId) -> BoxFuture<'static, Result<()>>;
fn edit(
&self,
message_id: MessageId,
content: Vec<AgentThreadUserMessageChunk>,
max_iterations: usize,
) -> BoxFuture<'static, Result<AgentThreadResponse>>;
fn send(
&self,
content: Vec<AgentThreadUserMessageChunk>,
max_iterations: usize,
) -> BoxFuture<'static, Result<AgentThreadResponse>>;
}

View File

@@ -581,7 +581,7 @@ impl ThreadContextHandle {
}
pub fn title(&self, cx: &App) -> SharedString {
self.thread.read(cx).summary().or_default()
self.thread.read(cx).title().or_default()
}
fn load(self, cx: &App) -> Task<Option<(AgentContext, Vec<Entity<Buffer>>)>> {
@@ -589,7 +589,7 @@ impl ThreadContextHandle {
let text = Thread::wait_for_detailed_summary_or_text(&self.thread, cx).await?;
let title = self
.thread
.read_with(cx, |thread, _cx| thread.summary().or_default())
.read_with(cx, |thread, _cx| thread.title().or_default())
.ok()?;
let context = AgentContext::Thread(ThreadContext {
title,

View File

@@ -1,10 +1,11 @@
use crate::{
MessageId, ThreadId,
context::{
AgentContextHandle, AgentContextKey, ContextId, ContextKind, DirectoryContextHandle,
FetchedUrlContext, FileContextHandle, ImageContext, RulesContextHandle,
SelectionContextHandle, SymbolContextHandle, TextThreadContextHandle, ThreadContextHandle,
},
thread::{MessageId, Thread, ThreadId},
thread::Thread,
thread_store::ThreadStore,
};
use anyhow::{Context as _, Result, anyhow};
@@ -71,7 +72,8 @@ impl ContextStore {
) -> Vec<AgentContextHandle> {
let existing_context = thread
.messages()
.take_while(|message| exclude_messages_from_id.is_none_or(|id| message.id != id))
.iter()
.take_while(|message| exclude_messages_from_id.is_none_or(|id| message.id() != id))
.flat_map(|message| {
message
.loaded_context
@@ -441,7 +443,7 @@ impl ContextStore {
match context {
AgentContextHandle::Thread(thread_context) => {
self.context_thread_ids
.remove(thread_context.thread.read(cx).id());
.remove(&thread_context.thread.read(cx).id());
}
AgentContextHandle::TextThread(text_thread_context) => {
if let Some(path) = text_thread_context.context.read(cx).path() {

File diff suppressed because it is too large Load Diff

View File

@@ -1,8 +1,7 @@
use crate::{
MessageId, ThreadId,
context_server_tool::ContextServerTool,
thread::{
DetailedSummaryState, ExceededWindowError, MessageId, ProjectSnapshot, Thread, ThreadId,
},
thread::{DetailedSummaryState, ExceededWindowError, ProjectSnapshot, Thread},
};
use agent_settings::{AgentProfileId, CompletionMode};
use anyhow::{Context as _, Result, anyhow};
@@ -400,35 +399,17 @@ impl ThreadStore {
self.threads.iter()
}
pub fn create_thread(&mut self, cx: &mut Context<Self>) -> Entity<Thread> {
cx.new(|cx| {
Thread::new(
self.project.clone(),
self.tools.clone(),
self.prompt_builder.clone(),
self.project_context.clone(),
cx,
)
})
}
pub fn create_thread_from_serialized(
&mut self,
serialized: SerializedThread,
cx: &mut Context<Self>,
) -> Entity<Thread> {
cx.new(|cx| {
Thread::deserialize(
ThreadId::new(),
serialized,
self.project.clone(),
self.tools.clone(),
self.prompt_builder.clone(),
self.project_context.clone(),
None,
cx,
)
})
pub fn create_thread(&mut self, cx: &mut Context<Self>) -> Task<Result<Entity<Thread>>> {
todo!()
// cx.new(|cx| {
// Thread::new(
// self.project.clone(),
// self.tools.clone(),
// self.prompt_builder.clone(),
// self.project_context.clone(),
// cx,
// )
// })
}
pub fn open_thread(
@@ -447,51 +428,53 @@ impl ThreadStore {
.await?
.with_context(|| format!("no thread found with ID: {id:?}"))?;
let thread = this.update_in(cx, |this, window, cx| {
cx.new(|cx| {
Thread::deserialize(
id.clone(),
thread,
this.project.clone(),
this.tools.clone(),
this.prompt_builder.clone(),
this.project_context.clone(),
Some(window),
cx,
)
})
})?;
Ok(thread)
todo!();
// let thread = this.update_in(cx, |this, window, cx| {
// cx.new(|cx| {
// Thread::deserialize(
// id.clone(),
// thread,
// this.project.clone(),
// this.tools.clone(),
// this.prompt_builder.clone(),
// this.project_context.clone(),
// Some(window),
// cx,
// )
// })
// })?;
// Ok(thread)
})
}
pub fn save_thread(&self, thread: &Entity<Thread>, cx: &mut Context<Self>) -> Task<Result<()>> {
let (metadata, serialized_thread) =
thread.update(cx, |thread, cx| (thread.id().clone(), thread.serialize(cx)));
todo!()
// let (metadata, serialized_thread) =
// thread.update(cx, |thread, cx| (thread.id().clone(), thread.serialize(cx)));
let database_future = ThreadsDatabase::global_future(cx);
cx.spawn(async move |this, cx| {
let serialized_thread = serialized_thread.await?;
let database = database_future.await.map_err(|err| anyhow!(err))?;
database.save_thread(metadata, serialized_thread).await?;
// let database_future = ThreadsDatabase::global_future(cx);
// cx.spawn(async move |this, cx| {
// let serialized_thread = serialized_thread.await?;
// let database = database_future.await.map_err(|err| anyhow!(err))?;
// database.save_thread(metadata, serialized_thread).await?;
this.update(cx, |this, cx| this.reload(cx))?.await
})
// this.update(cx, |this, cx| this.reload(cx))?.await
// })
}
pub fn delete_thread(&mut self, id: &ThreadId, cx: &mut Context<Self>) -> Task<Result<()>> {
let id = id.clone();
let database_future = ThreadsDatabase::global_future(cx);
cx.spawn(async move |this, cx| {
let database = database_future.await.map_err(|err| anyhow!(err))?;
database.delete_thread(id.clone()).await?;
todo!()
// let id = id.clone();
// let database_future = ThreadsDatabase::global_future(cx);
// cx.spawn(async move |this, cx| {
// let database = database_future.await.map_err(|err| anyhow!(err))?;
// database.delete_thread(id.clone()).await?;
this.update(cx, |this, cx| {
this.threads.retain(|thread| thread.id != id);
cx.notify();
})
})
// this.update(cx, |this, cx| {
// this.threads.retain(|thread| thread.id != id);
// cx.notify();
// })
// })
}
pub fn reload(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
@@ -1067,7 +1050,7 @@ impl ThreadsDatabase {
#[cfg(test)]
mod tests {
use super::*;
use crate::thread::{DetailedSummaryState, MessageId};
use crate::{MessageId, thread::DetailedSummaryState};
use chrono::Utc;
use language_model::{Role, TokenUsage};
use pretty_assertions::assert_eq;

View File

@@ -1,567 +0,0 @@
use crate::{
thread::{MessageId, PromptId, ThreadId},
thread_store::SerializedMessage,
};
use anyhow::Result;
use assistant_tool::{
AnyToolCard, Tool, ToolResultContent, ToolResultOutput, ToolUseStatus, ToolWorkingSet,
};
use collections::HashMap;
use futures::{FutureExt as _, future::Shared};
use gpui::{App, Entity, SharedString, Task, Window};
use icons::IconName;
use language_model::{
ConfiguredModel, LanguageModel, LanguageModelRequest, LanguageModelToolResult,
LanguageModelToolResultContent, LanguageModelToolUse, LanguageModelToolUseId, Role,
};
use project::Project;
use std::sync::Arc;
use util::truncate_lines_to_byte_limit;
#[derive(Debug)]
pub struct ToolUse {
pub id: LanguageModelToolUseId,
pub name: SharedString,
pub ui_text: SharedString,
pub status: ToolUseStatus,
pub input: serde_json::Value,
pub icon: icons::IconName,
pub needs_confirmation: bool,
}
pub struct ToolUseState {
tools: Entity<ToolWorkingSet>,
tool_uses_by_assistant_message: HashMap<MessageId, Vec<LanguageModelToolUse>>,
tool_results: HashMap<LanguageModelToolUseId, LanguageModelToolResult>,
pending_tool_uses_by_id: HashMap<LanguageModelToolUseId, PendingToolUse>,
tool_result_cards: HashMap<LanguageModelToolUseId, AnyToolCard>,
tool_use_metadata_by_id: HashMap<LanguageModelToolUseId, ToolUseMetadata>,
}
impl ToolUseState {
pub fn new(tools: Entity<ToolWorkingSet>) -> Self {
Self {
tools,
tool_uses_by_assistant_message: HashMap::default(),
tool_results: HashMap::default(),
pending_tool_uses_by_id: HashMap::default(),
tool_result_cards: HashMap::default(),
tool_use_metadata_by_id: HashMap::default(),
}
}
/// Constructs a [`ToolUseState`] from the given list of [`SerializedMessage`]s.
///
/// Accepts a function to filter the tools that should be used to populate the state.
///
/// If `window` is `None` (e.g., when in headless mode or when running evals),
/// tool cards won't be deserialized
pub fn from_serialized_messages(
tools: Entity<ToolWorkingSet>,
messages: &[SerializedMessage],
project: Entity<Project>,
window: Option<&mut Window>, // None in headless mode
cx: &mut App,
) -> Self {
let mut this = Self::new(tools);
let mut tool_names_by_id = HashMap::default();
let mut window = window;
for message in messages {
match message.role {
Role::Assistant => {
if !message.tool_uses.is_empty() {
let tool_uses = message
.tool_uses
.iter()
.map(|tool_use| LanguageModelToolUse {
id: tool_use.id.clone(),
name: tool_use.name.clone().into(),
raw_input: tool_use.input.to_string(),
input: tool_use.input.clone(),
is_input_complete: true,
})
.collect::<Vec<_>>();
tool_names_by_id.extend(
tool_uses
.iter()
.map(|tool_use| (tool_use.id.clone(), tool_use.name.clone())),
);
this.tool_uses_by_assistant_message
.insert(message.id, tool_uses);
for tool_result in &message.tool_results {
let tool_use_id = tool_result.tool_use_id.clone();
let Some(tool_use) = tool_names_by_id.get(&tool_use_id) else {
log::warn!("no tool name found for tool use: {tool_use_id:?}");
continue;
};
this.tool_results.insert(
tool_use_id.clone(),
LanguageModelToolResult {
tool_use_id: tool_use_id.clone(),
tool_name: tool_use.clone(),
is_error: tool_result.is_error,
content: tool_result.content.clone(),
output: tool_result.output.clone(),
},
);
if let Some(window) = &mut window {
if let Some(tool) = this.tools.read(cx).tool(tool_use, cx) {
if let Some(output) = tool_result.output.clone() {
if let Some(card) = tool.deserialize_card(
output,
project.clone(),
window,
cx,
) {
this.tool_result_cards.insert(tool_use_id, card);
}
}
}
}
}
}
}
Role::System | Role::User => {}
}
}
this
}
pub fn cancel_pending(&mut self) -> Vec<PendingToolUse> {
let mut cancelled_tool_uses = Vec::new();
self.pending_tool_uses_by_id
.retain(|tool_use_id, tool_use| {
if matches!(tool_use.status, PendingToolUseStatus::Error { .. }) {
return true;
}
let content = "Tool canceled by user".into();
self.tool_results.insert(
tool_use_id.clone(),
LanguageModelToolResult {
tool_use_id: tool_use_id.clone(),
tool_name: tool_use.name.clone(),
content,
output: None,
is_error: true,
},
);
cancelled_tool_uses.push(tool_use.clone());
false
});
cancelled_tool_uses
}
pub fn pending_tool_uses(&self) -> Vec<&PendingToolUse> {
self.pending_tool_uses_by_id.values().collect()
}
pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
let Some(tool_uses_for_message) = &self.tool_uses_by_assistant_message.get(&id) else {
return Vec::new();
};
let mut tool_uses = Vec::new();
for tool_use in tool_uses_for_message.iter() {
let tool_result = self.tool_results.get(&tool_use.id);
let status = (|| {
if let Some(tool_result) = tool_result {
let content = tool_result
.content
.to_str()
.map(|str| str.to_owned().into())
.unwrap_or_default();
return if tool_result.is_error {
ToolUseStatus::Error(content)
} else {
ToolUseStatus::Finished(content)
};
}
if let Some(pending_tool_use) = self.pending_tool_uses_by_id.get(&tool_use.id) {
match pending_tool_use.status {
PendingToolUseStatus::Idle => ToolUseStatus::Pending,
PendingToolUseStatus::NeedsConfirmation { .. } => {
ToolUseStatus::NeedsConfirmation
}
PendingToolUseStatus::Running { .. } => ToolUseStatus::Running,
PendingToolUseStatus::Error(ref err) => {
ToolUseStatus::Error(err.clone().into())
}
PendingToolUseStatus::InputStillStreaming => {
ToolUseStatus::InputStillStreaming
}
}
} else {
ToolUseStatus::Pending
}
})();
let (icon, needs_confirmation) =
if let Some(tool) = self.tools.read(cx).tool(&tool_use.name, cx) {
(tool.icon(), tool.needs_confirmation(&tool_use.input, cx))
} else {
(IconName::Cog, false)
};
tool_uses.push(ToolUse {
id: tool_use.id.clone(),
name: tool_use.name.clone().into(),
ui_text: self.tool_ui_label(
&tool_use.name,
&tool_use.input,
tool_use.is_input_complete,
cx,
),
input: tool_use.input.clone(),
status,
icon,
needs_confirmation,
})
}
tool_uses
}
pub fn tool_ui_label(
&self,
tool_name: &str,
input: &serde_json::Value,
is_input_complete: bool,
cx: &App,
) -> SharedString {
if let Some(tool) = self.tools.read(cx).tool(tool_name, cx) {
if is_input_complete {
tool.ui_text(input).into()
} else {
tool.still_streaming_ui_text(input).into()
}
} else {
format!("Unknown tool {tool_name:?}").into()
}
}
pub fn tool_results_for_message(
&self,
assistant_message_id: MessageId,
) -> Vec<&LanguageModelToolResult> {
let Some(tool_uses) = self
.tool_uses_by_assistant_message
.get(&assistant_message_id)
else {
return Vec::new();
};
tool_uses
.iter()
.filter_map(|tool_use| self.tool_results.get(&tool_use.id))
.collect()
}
pub fn message_has_tool_results(&self, assistant_message_id: MessageId) -> bool {
self.tool_uses_by_assistant_message
.get(&assistant_message_id)
.map_or(false, |results| !results.is_empty())
}
pub fn tool_result(
&self,
tool_use_id: &LanguageModelToolUseId,
) -> Option<&LanguageModelToolResult> {
self.tool_results.get(tool_use_id)
}
pub fn tool_result_card(&self, tool_use_id: &LanguageModelToolUseId) -> Option<&AnyToolCard> {
self.tool_result_cards.get(tool_use_id)
}
pub fn insert_tool_result_card(
&mut self,
tool_use_id: LanguageModelToolUseId,
card: AnyToolCard,
) {
self.tool_result_cards.insert(tool_use_id, card);
}
pub fn request_tool_use(
&mut self,
assistant_message_id: MessageId,
tool_use: LanguageModelToolUse,
metadata: ToolUseMetadata,
cx: &App,
) -> Arc<str> {
let tool_uses = self
.tool_uses_by_assistant_message
.entry(assistant_message_id)
.or_default();
let mut existing_tool_use_found = false;
for existing_tool_use in tool_uses.iter_mut() {
if existing_tool_use.id == tool_use.id {
*existing_tool_use = tool_use.clone();
existing_tool_use_found = true;
}
}
if !existing_tool_use_found {
tool_uses.push(tool_use.clone());
}
let status = if tool_use.is_input_complete {
self.tool_use_metadata_by_id
.insert(tool_use.id.clone(), metadata);
PendingToolUseStatus::Idle
} else {
PendingToolUseStatus::InputStillStreaming
};
let ui_text: Arc<str> = self
.tool_ui_label(
&tool_use.name,
&tool_use.input,
tool_use.is_input_complete,
cx,
)
.into();
let may_perform_edits = self
.tools
.read(cx)
.tool(&tool_use.name, cx)
.is_some_and(|tool| tool.may_perform_edits());
self.pending_tool_uses_by_id.insert(
tool_use.id.clone(),
PendingToolUse {
assistant_message_id,
id: tool_use.id,
name: tool_use.name.clone(),
ui_text: ui_text.clone(),
input: tool_use.input,
may_perform_edits,
status,
},
);
ui_text
}
pub fn run_pending_tool(
&mut self,
tool_use_id: LanguageModelToolUseId,
ui_text: SharedString,
task: Task<()>,
) {
if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
tool_use.ui_text = ui_text.into();
tool_use.status = PendingToolUseStatus::Running {
_task: task.shared(),
};
}
}
pub fn confirm_tool_use(
&mut self,
tool_use_id: LanguageModelToolUseId,
ui_text: impl Into<Arc<str>>,
input: serde_json::Value,
request: Arc<LanguageModelRequest>,
tool: Arc<dyn Tool>,
) {
if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
let ui_text = ui_text.into();
tool_use.ui_text = ui_text.clone();
let confirmation = Confirmation {
tool_use_id,
input,
request,
tool,
ui_text,
};
tool_use.status = PendingToolUseStatus::NeedsConfirmation(Arc::new(confirmation));
}
}
pub fn insert_tool_output(
&mut self,
tool_use_id: LanguageModelToolUseId,
tool_name: Arc<str>,
output: Result<ToolResultOutput>,
configured_model: Option<&ConfiguredModel>,
) -> Option<PendingToolUse> {
let metadata = self.tool_use_metadata_by_id.remove(&tool_use_id);
telemetry::event!(
"Agent Tool Finished",
model = metadata
.as_ref()
.map(|metadata| metadata.model.telemetry_id()),
model_provider = metadata
.as_ref()
.map(|metadata| metadata.model.provider_id().to_string()),
thread_id = metadata.as_ref().map(|metadata| metadata.thread_id.clone()),
prompt_id = metadata.as_ref().map(|metadata| metadata.prompt_id.clone()),
tool_name,
success = output.is_ok()
);
match output {
Ok(output) => {
let tool_result = output.content;
const BYTES_PER_TOKEN_ESTIMATE: usize = 3;
let old_use = self.pending_tool_uses_by_id.remove(&tool_use_id);
// Protect from overly large output
let tool_output_limit = configured_model
.map(|model| model.model.max_token_count() as usize * BYTES_PER_TOKEN_ESTIMATE)
.unwrap_or(usize::MAX);
let content = match tool_result {
ToolResultContent::Text(text) => {
let text = if text.len() < tool_output_limit {
text
} else {
let truncated = truncate_lines_to_byte_limit(&text, tool_output_limit);
format!(
"Tool result too long. The first {} bytes:\n\n{}",
truncated.len(),
truncated
)
};
LanguageModelToolResultContent::Text(text.into())
}
ToolResultContent::Image(language_model_image) => {
if language_model_image.estimate_tokens() < tool_output_limit {
LanguageModelToolResultContent::Image(language_model_image)
} else {
self.tool_results.insert(
tool_use_id.clone(),
LanguageModelToolResult {
tool_use_id: tool_use_id.clone(),
tool_name,
content: "Tool responded with an image that would exceeded the remaining tokens".into(),
is_error: true,
output: None,
},
);
return old_use;
}
}
};
self.tool_results.insert(
tool_use_id.clone(),
LanguageModelToolResult {
tool_use_id: tool_use_id.clone(),
tool_name,
content,
is_error: false,
output: output.output,
},
);
old_use
}
Err(err) => {
self.tool_results.insert(
tool_use_id.clone(),
LanguageModelToolResult {
tool_use_id: tool_use_id.clone(),
tool_name,
content: LanguageModelToolResultContent::Text(err.to_string().into()),
is_error: true,
output: None,
},
);
if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
tool_use.status = PendingToolUseStatus::Error(err.to_string().into());
}
self.pending_tool_uses_by_id.get(&tool_use_id).cloned()
}
}
}
pub fn has_tool_results(&self, assistant_message_id: MessageId) -> bool {
self.tool_uses_by_assistant_message
.contains_key(&assistant_message_id)
}
pub fn tool_results(
&self,
assistant_message_id: MessageId,
) -> impl Iterator<Item = (&LanguageModelToolUse, Option<&LanguageModelToolResult>)> {
self.tool_uses_by_assistant_message
.get(&assistant_message_id)
.into_iter()
.flatten()
.map(|tool_use| (tool_use, self.tool_results.get(&tool_use.id)))
}
}
#[derive(Debug, Clone)]
pub struct PendingToolUse {
pub id: LanguageModelToolUseId,
/// The ID of the Assistant message in which the tool use was requested.
#[allow(unused)]
pub assistant_message_id: MessageId,
pub name: Arc<str>,
pub ui_text: Arc<str>,
pub input: serde_json::Value,
pub status: PendingToolUseStatus,
pub may_perform_edits: bool,
}
#[derive(Debug, Clone)]
pub struct Confirmation {
pub tool_use_id: LanguageModelToolUseId,
pub input: serde_json::Value,
pub ui_text: Arc<str>,
pub request: Arc<LanguageModelRequest>,
pub tool: Arc<dyn Tool>,
}
#[derive(Debug, Clone)]
pub enum PendingToolUseStatus {
InputStillStreaming,
Idle,
NeedsConfirmation(Arc<Confirmation>),
Running { _task: Shared<Task<()>> },
Error(#[allow(unused)] Arc<str>),
}
impl PendingToolUseStatus {
pub fn is_idle(&self) -> bool {
matches!(self, PendingToolUseStatus::Idle)
}
pub fn is_error(&self) -> bool {
matches!(self, PendingToolUseStatus::Error(_))
}
pub fn needs_confirmation(&self) -> bool {
matches!(self, PendingToolUseStatus::NeedsConfirmation { .. })
}
}
#[derive(Clone)]
pub struct ToolUseMetadata {
pub model: Arc<dyn LanguageModel>,
pub thread_id: ThreadId,
pub prompt_id: PromptId,
}

View File

@@ -0,0 +1 @@
pub struct ZedAgentThread {}

File diff suppressed because it is too large Load Diff

View File

@@ -211,7 +211,7 @@ impl AgentDiffPane {
}
fn update_title(&mut self, cx: &mut Context<Self>) {
let new_title = self.thread.read(cx).summary().unwrap_or("Agent Changes");
let new_title = self.thread.read(cx).title().unwrap_or("Agent Changes");
if new_title != self.title {
self.title = new_title;
cx.emit(EditorEvent::TitleChanged);
@@ -461,7 +461,7 @@ impl Item for AgentDiffPane {
}
fn tab_content(&self, params: TabContentParams, _window: &Window, cx: &App) -> AnyElement {
let summary = self.thread.read(cx).summary().unwrap_or("Agent Changes");
let summary = self.thread.read(cx).title().unwrap_or("Agent Changes");
Label::new(format!("Review: {}", summary))
.color(if params.selected {
Color::Default
@@ -1369,8 +1369,6 @@ impl AgentDiff {
| ThreadEvent::MessageDeleted(_)
| ThreadEvent::SummaryGenerated
| ThreadEvent::SummaryChanged
| ThreadEvent::UsePendingTools { .. }
| ThreadEvent::ToolFinished { .. }
| ThreadEvent::CheckpointChanged
| ThreadEvent::ToolConfirmationNeeded
| ThreadEvent::ToolUseLimitReached
@@ -1801,7 +1799,10 @@ mod tests {
})
.await
.unwrap();
let thread = thread_store.update(cx, |store, cx| store.create_thread(cx));
let thread = thread_store
.update(cx, |store, cx| store.create_thread(cx))
.await
.unwrap();
let action_log = thread.read_with(cx, |thread, _| thread.action_log().clone());
let (workspace, cx) =
@@ -1966,7 +1967,10 @@ mod tests {
})
.await
.unwrap();
let thread = thread_store.update(cx, |store, cx| store.create_thread(cx));
let thread = thread_store
.update(cx, |store, cx| store.create_thread(cx))
.await
.unwrap();
let action_log = thread.read_with(cx, |thread, _| thread.action_log().clone());
let (workspace, cx) =

View File

@@ -45,7 +45,7 @@ impl AgentModelSelector {
let registry = LanguageModelRegistry::read_global(cx);
if let Some(provider) = registry.provider(&model.provider_id())
{
thread.set_configured_model(
thread.set_model(
Some(ConfiguredModel {
provider,
model: model.clone(),

View File

@@ -26,7 +26,7 @@ use crate::{
ui::AgentOnboardingModal,
};
use agent::{
Thread, ThreadError, ThreadEvent, ThreadId, ThreadSummary, TokenUsageRatio,
Thread, ThreadError, ThreadEvent, ThreadId, ThreadTitle, TokenUsageRatio,
context_store::ContextStore,
history_store::{HistoryEntryId, HistoryStore},
thread_store::{TextThreadStore, ThreadStore},
@@ -72,7 +72,7 @@ use zed_actions::{
agent::{OpenConfiguration, OpenOnboardingModal, ResetOnboarding},
assistant::{OpenRulesLibrary, ToggleFocus},
};
use zed_llm_client::{CompletionIntent, UsageLimit};
use zed_llm_client::UsageLimit;
const AGENT_PANEL_KEY: &str = "agent_panel";
@@ -252,7 +252,7 @@ impl ActiveView {
thread.update(cx, |thread, cx| {
thread.thread().update(cx, |thread, cx| {
thread.set_summary(new_summary, cx);
thread.set_title(new_summary, cx);
});
})
}
@@ -278,7 +278,7 @@ impl ActiveView {
let editor = editor.clone();
move |_, thread, event, window, cx| match event {
ThreadEvent::SummaryGenerated => {
let summary = thread.read(cx).summary().or_default();
let summary = thread.read(cx).title().or_default();
editor.update(cx, |editor, cx| {
editor.set_text(summary, window, cx);
@@ -492,10 +492,15 @@ impl AgentPanel {
None
};
let thread = thread_store
.update(cx, |this, cx| this.create_thread(cx))?
.await?;
let panel = workspace.update_in(cx, |workspace, window, cx| {
let panel = cx.new(|cx| {
Self::new(
workspace,
thread,
thread_store,
context_store,
prompt_store,
@@ -518,13 +523,13 @@ impl AgentPanel {
fn new(
workspace: &Workspace,
thread: Entity<Thread>,
thread_store: Entity<ThreadStore>,
context_store: Entity<TextThreadStore>,
prompt_store: Option<Entity<PromptStore>>,
window: &mut Window,
cx: &mut Context<Self>,
) -> Self {
let thread = thread_store.update(cx, |this, cx| this.create_thread(cx));
let fs = workspace.app_state().fs.clone();
let user_store = workspace.app_state().user_store.clone();
let project = workspace.project();
@@ -647,11 +652,12 @@ impl AgentPanel {
|this, _, event: &language_model::Event, cx| match event {
language_model::Event::DefaultModelChanged => match &this.active_view {
ActiveView::Thread { thread, .. } => {
thread
.read(cx)
.thread()
.clone()
.update(cx, |thread, cx| thread.get_or_init_configured_model(cx));
// todo!(do we need this?);
// thread
// .read(cx)
// .thread()
// .clone()
// .update(cx, |thread, cx| thread.get_or_init_configured_model(cx));
}
ActiveView::TextThread { .. }
| ActiveView::History
@@ -784,46 +790,61 @@ impl AgentPanel {
.detach_and_log_err(cx);
}
let active_thread = cx.new(|cx| {
ActiveThread::new(
thread.clone(),
self.thread_store.clone(),
self.context_store.clone(),
context_store.clone(),
self.language_registry.clone(),
self.workspace.clone(),
window,
cx,
)
});
let fs = self.fs.clone();
let user_store = self.user_store.clone();
let thread_store = self.thread_store.clone();
let text_thread_store = self.context_store.clone();
let prompt_store = self.prompt_store.clone();
let language_registry = self.language_registry.clone();
let workspace = self.workspace.clone();
let message_editor = cx.new(|cx| {
MessageEditor::new(
self.fs.clone(),
self.workspace.clone(),
self.user_store.clone(),
context_store.clone(),
self.prompt_store.clone(),
self.thread_store.downgrade(),
self.context_store.downgrade(),
thread.clone(),
window,
cx,
)
});
cx.spawn_in(window, async move |this, cx| {
let thread = thread.await?;
let active_thread = cx.new_window_entity(|window, cx| {
ActiveThread::new(
thread.clone(),
thread_store.clone(),
text_thread_store.clone(),
context_store.clone(),
language_registry.clone(),
workspace.clone(),
window,
cx,
)
})?;
if let Some(text) = preserved_text {
message_editor.update(cx, |editor, cx| {
editor.set_text(text, window, cx);
});
}
let message_editor = cx.new_window_entity(|window, cx| {
MessageEditor::new(
fs.clone(),
workspace.clone(),
user_store.clone(),
context_store.clone(),
prompt_store.clone(),
thread_store.downgrade(),
text_thread_store.downgrade(),
thread.clone(),
window,
cx,
)
})?;
message_editor.focus_handle(cx).focus(window);
if let Some(text) = preserved_text {
message_editor.update_in(cx, |editor, window, cx| {
editor.set_text(text, window, cx);
});
}
let thread_view = ActiveView::thread(active_thread.clone(), message_editor, window, cx);
self.set_active_view(thread_view, window, cx);
this.update_in(cx, |this, window, cx| {
message_editor.focus_handle(cx).focus(window);
AgentDiff::set_active_thread(&self.workspace, &thread, window, cx);
let thread_view =
ActiveView::thread(active_thread.clone(), message_editor, window, cx);
this.set_active_view(thread_view, window, cx);
AgentDiff::set_active_thread(&this.workspace, &thread, window, cx);
})
})
.detach_and_log_err(cx);
}
fn new_prompt_editor(&mut self, window: &mut Window, cx: &mut Context<Self>) {
@@ -1254,23 +1275,11 @@ impl AgentPanel {
return;
}
let model = thread_state.configured_model().map(|cm| cm.model.clone());
if let Some(model) = model {
thread.update(cx, |active_thread, cx| {
active_thread.thread().update(cx, |thread, cx| {
thread.insert_invisible_continue_message(cx);
thread.advance_prompt_id();
thread.send_to_model(
model,
CompletionIntent::UserPrompt,
Some(window.window_handle()),
cx,
);
});
});
} else {
log::warn!("No configured model available for continuation");
}
thread.update(cx, |active_thread, cx| {
active_thread
.thread()
.update(cx, |thread, cx| thread.resume(window, cx))
});
}
fn toggle_burn_mode(
@@ -1552,24 +1561,24 @@ impl AgentPanel {
let state = {
let active_thread = active_thread.read(cx);
if active_thread.is_empty() {
&ThreadSummary::Pending
&ThreadTitle::Pending
} else {
active_thread.summary(cx)
}
};
match state {
ThreadSummary::Pending => Label::new(ThreadSummary::DEFAULT.clone())
ThreadTitle::Pending => Label::new(ThreadTitle::DEFAULT.clone())
.truncate()
.into_any_element(),
ThreadSummary::Generating => Label::new(LOADING_SUMMARY_PLACEHOLDER)
ThreadTitle::Generating => Label::new(LOADING_SUMMARY_PLACEHOLDER)
.truncate()
.into_any_element(),
ThreadSummary::Ready(_) => div()
ThreadTitle::Ready(_) => div()
.w_full()
.child(change_title_editor.clone())
.into_any_element(),
ThreadSummary::Error => h_flex()
ThreadTitle::Error => h_flex()
.w_full()
.child(change_title_editor.clone())
.child(
@@ -2024,7 +2033,7 @@ impl AgentPanel {
.read(cx)
.thread()
.read(cx)
.configured_model()
.model()
.map_or(false, |model| {
model.provider.id().0 == ZED_CLOUD_PROVIDER_ID
});
@@ -2629,7 +2638,7 @@ impl AgentPanel {
return None;
}
let model = thread.configured_model()?.model;
let model = thread.model()?.model;
let focus_handle = self.focus_handle(cx);

View File

@@ -121,7 +121,7 @@ pub(crate) enum ModelUsageContext {
impl ModelUsageContext {
pub fn configured_model(&self, cx: &App) -> Option<ConfiguredModel> {
match self {
Self::Thread(thread) => thread.read(cx).configured_model(),
Self::Thread(thread) => thread.read(cx).model(),
Self::InlineAssistant => {
LanguageModelRegistry::read_global(cx).inline_assistant_model()
}

View File

@@ -670,7 +670,7 @@ fn recent_context_picker_entries(
let mut threads = unordered_thread_entries(thread_store, text_thread_store, cx)
.filter(|(_, thread)| match thread {
ThreadContextEntry::Thread { id, .. } => {
Some(id) != active_thread_id && !current_threads.contains(id)
Some(id) != active_thread_id.as_ref() && !current_threads.contains(id)
}
ThreadContextEntry::Context { .. } => true,
})

View File

@@ -169,13 +169,13 @@ impl ContextStrip {
if self
.context_store
.read(cx)
.includes_thread(active_thread.id())
.includes_thread(&active_thread.id())
{
return None;
}
Some(SuggestedContext::Thread {
name: active_thread.summary().or_default(),
name: active_thread.title().or_default(),
thread: weak_active_thread,
})
} else if let Some(active_context_editor) = panel.active_context_editor() {

View File

@@ -387,25 +387,14 @@ impl MessageEditor {
thread
.update(cx, |thread, cx| {
thread.insert_user_message(
user_message,
loaded_context,
checkpoint.ok(),
user_message_creases,
cx,
);
})
.log_err();
thread
.update(cx, |thread, cx| {
thread.advance_prompt_id();
thread.send_to_model(
model,
CompletionIntent::UserPrompt,
Some(window_handle),
cx,
);
todo!();
// thread.send(
// user_message,
// loaded_context,
// checkpoint.ok(),
// user_message_creases,
// cx,
// );
})
.log_err();
})

View File

@@ -156,7 +156,7 @@ impl Render for ProfileSelector {
.map(|profile| profile.name.clone())
.unwrap_or_else(|| "Unknown".into());
let configured_model = self.thread.read(cx).configured_model().or_else(|| {
let configured_model = self.thread.read(cx).model().or_else(|| {
let model_registry = LanguageModelRegistry::read_global(cx);
model_registry.default_model()
});

View File

@@ -498,7 +498,7 @@ impl AddedContext {
render_hover: {
let thread = handle.thread.clone();
Some(Rc::new(move |_, cx| {
let text = thread.read(cx).latest_detailed_summary_or_text();
let text = thread.read(cx).latest_detailed_summary_or_text(cx);
ContextPillHover::new_text(text.clone(), cx).into()
}))
},