Compare commits

...

42 Commits

Author SHA1 Message Date
Antonio Scandurra
587ed1e314 WIP 2025-07-01 12:52:08 +02:00
Conrad Irwin
1cf7a0f97b Rename WIPity WIP
Co-authored-by: Mikayla Maki <mikayla.c.maki@gmail.com>
Co-authored-by: Nathan Sobo <nathan@zed.dev>
2025-06-30 17:48:49 -06:00
Conrad Irwin
f9b43cbd1f Re-merge ZedAgent and Thread
Co-authored-by: Mikayla Maki <mikayla.c.maki@gmail.com>
Co-authored-by: Nathan Sobo <nathan@zed.dev>
2025-06-30 17:43:53 -06:00
Conrad Irwin
dab7ca4a84 WIP
Co-authored-by: Nathan Sobo <nathan@zed.dev>
2025-06-30 17:01:05 -06:00
Agus Zubiaga
e061fbefae Replace edit_message + delete_messages with new truncate method 2025-06-30 15:46:59 -03:00
Agus Zubiaga
64d19c44e4 Remove Message is_hidden 2025-06-30 15:05:44 -03:00
Agus Zubiaga
e51a0852e1 Replace insert_invisible_continue_message with send_continue_message 2025-06-30 14:39:42 -03:00
Agus Zubiaga
2ea1488aca Replace insert_user_message fully 2025-06-30 13:22:47 -03:00
Agus Zubiaga
c76361d213 Test new retry 2025-06-30 13:16:39 -03:00
Agus Zubiaga
9d7c94a16e Rename send_to_model2 to send_message and fix is_generating 2025-06-30 11:37:48 -03:00
Agus Zubiaga
2af70370e9 Trigger summary generation from send_to_model2 2025-06-30 11:14:51 -03:00
Agus Zubiaga
7725b95571 Replace more insert_user_message usages 2025-06-30 10:58:08 -03:00
Ben Brandt
be3a295ae4 Refactor tool use deserialization
Extract the tool use deserialization logic from `ZedAgent::new` into a
new `DeserializedToolUse` helper struct, so we don't have to clone
messages
2025-06-30 12:09:12 +02:00
Ben Brandt
269f73ab7c Report tool output in finished status 2025-06-30 11:53:24 +02:00
Ben Brandt
90899465a2 Cleanup from merge and clippy warnings 2025-06-30 11:31:56 +02:00
Ben Brandt
34a2d23134 Merge branch 'main' into split-agent-from-thread 2025-06-30 11:11:29 +02:00
Agus Zubiaga
3e2bcb05fb Start using send_to_model2 in message editor
Co-authored-by: Conrad Irwin <conrad.irwin@gmail.com>
2025-06-27 19:35:53 -03:00
Agus Zubiaga
f32af6ab52 Checkpoint: Rendering tool uses
Co-authored-by: Conrad Irwin <conrad.irwin@gmail.com>
2025-06-27 19:09:34 -03:00
Agus Zubiaga
eef7c07061 Remove MessageSegment::RedactedThinking
Co-authored-by: Max Brunsfeld <maxbrunsfeld@gmail.com>
Co-authored-by: Conrad Irwin <conrad.irwin@gmail.com>
2025-06-27 18:37:38 -03:00
Agus Zubiaga
b1a7812232 BASE_RETRY_DELAY_SECS -> BASE_RETRY_DELAY
Co-authored-by: Max Brunsfeld <maxbrunsfeld@gmail.com>
Co-authored-by: Ben Brandt <benjamin.j.brandt@gmail.com>
Co-authored-by: Conrad Irwin <conrad.irwin@gmail.com>
2025-06-27 18:18:21 -03:00
Agus Zubiaga
2f8fa209bc Test send_to_model2
Co-authored-by: Max Brunsfeld <maxbrunsfeld@gmail.com>
Co-authored-by: Ben Brandt <benjamin.j.brandt@gmail.com>
Co-authored-by: Conrad Irwin <conrad.irwin@gmail.com>
2025-06-27 18:07:23 -03:00
Max Brunsfeld
5e0f3e0ead Start writing assistant messages + tool calls to thread in ZedAgent
Co-authored-by: Agus Zubiaga <agus@zed.dev>
Co-authored-by: Ben Brandt <benjamin.j.brandt@gmail.com>
2025-06-27 13:00:19 -07:00
Agus Zubiaga
8776548b02 Use build_request in tests 2025-06-27 12:46:15 -03:00
Agus Zubiaga
82b243e4ea Add user messages to agent request 2025-06-27 12:42:35 -03:00
Agus Zubiaga
b2434e7fef Checkpoint: Handle all retryable errors 2025-06-27 12:05:24 -03:00
Antonio Scandurra
6036c09c1a Checkpoint
Co-authored-by: Agus Zubiaga <agus@zed.dev>
2025-06-27 16:00:08 +02:00
Antonio Scandurra
865970d42b WIP
Co-authored-by: Agus Zubiaga <agus@zed.dev>
2025-06-27 15:11:50 +02:00
Antonio Scandurra
b9c4f2c7a8 WIP
Co-authored-by: Agus Zubiaga <agus@zed.dev>
Co-authored-by: Ben Brandt <benjamin.j.brandt@gmail.com>
2025-06-27 14:26:15 +02:00
Antonio Scandurra
e458ba2293 WIP
Co-authored-by: Ben Brandt <benjamin.j.brandt@gmail.com>
2025-06-27 13:12:45 +02:00
Antonio Scandurra
04c842a7c2 WIP: actually run tools
Co-authored-by: Ben Brandt <benjamin.j.brandt@gmail.com>
2025-06-27 13:02:40 +02:00
Antonio Scandurra
7a055b4865 WIP: start reworking tool use
Co-authored-by: Ben Brandt <benjamin.j.brandt@gmail.com>
2025-06-27 12:34:27 +02:00
Antonio Scandurra
9eff1c32af Merge remote-tracking branch 'origin/main' into split-agent-from-thread 2025-06-27 10:40:24 +02:00
Ben Brandt
88b1345595 variable cleanup 2025-06-27 10:09:31 +02:00
Max Brunsfeld
a02a0b9c0a Remove some methods that delegate from ZedAgent to Thread
Co-authored-by: Agus Zubiaga <agus@zed.dev>
Co-authored-by: Mikayla Maki <mikayla.c.maki@gmail.com>
2025-06-26 17:56:13 -07:00
Max Brunsfeld
f35fbbb78f Move ActionLog from ZedAgent to Thread
Co-authored-by: Mikayla Maki <mikayla.c.maki@gmail.com>
2025-06-26 17:37:22 -07:00
Max Brunsfeld
bdeaddc59d Move checkpoints from agent to thread 2025-06-26 16:39:43 -07:00
Conrad Irwin
d5aa609bee Split thread/agent in ActiveThread
Co-authored-by: Agus Zubiaga <agus@zed.dev>
2025-06-26 16:53:57 -06:00
Conrad Irwin
1f0512cd2f Move summary() into ThreadData. Split thread/agent in tests
Co-authored-by: Agus Zubiaga <agus@zed.dev>
2025-06-26 16:32:51 -06:00
Conrad Irwin
438acc98d6 Move messages -> ThreadData
Co-authored-by: Agus Zubiaga <agus@zed.dev>
2025-06-26 16:09:43 -06:00
Conrad Irwin
5cc016291d Factor id -> ThreadData
Co-authored-by: Agus Zubiaga <agus@zed.dev>
2025-06-26 15:34:35 -06:00
Conrad Irwin
61ab3bcd8e Rename Thread -> Agent
Co-authored-by: Agus Zubiaga <agus@zed.dev>
2025-06-26 15:31:53 -06:00
Max Brunsfeld
03478d5715 Inline ToolUseState
Co-authored-by: Agus Zubiaga <agus@zed.dev>
Co-authored-by: Conrad Irwin <conrad.irwin@gmail.com>
2025-06-26 12:30:52 -07:00
25 changed files with 3910 additions and 3098 deletions

View File

@@ -68,6 +68,7 @@ zstd.workspace = true
[dev-dependencies] [dev-dependencies]
assistant_tools.workspace = true assistant_tools.workspace = true
assistant_tool = { workspace = true, "features" = ["test-support"] }
gpui = { workspace = true, "features" = ["test-support"] } gpui = { workspace = true, "features" = ["test-support"] }
indoc.workspace = true indoc.workspace = true
language = { workspace = true, "features" = ["test-support"] } language = { workspace = true, "features" = ["test-support"] }

View File

@@ -5,13 +5,12 @@ pub mod context_store;
pub mod history_store; pub mod history_store;
pub mod thread; pub mod thread;
pub mod thread_store; pub mod thread_store;
pub mod tool_use;
pub use context::{AgentContext, ContextId, ContextLoadResult}; pub use context::{AgentContext, ContextId, ContextLoadResult};
pub use context_store::ContextStore; pub use context_store::ContextStore;
pub use thread::{ pub use thread::{
LastRestoreCheckpoint, Message, MessageCrease, MessageId, MessageSegment, Thread, ThreadError, LastRestoreCheckpoint, Message, MessageCrease, MessageId, MessageSegment, ThreadError,
ThreadEvent, ThreadFeedback, ThreadId, ThreadSummary, TokenUsageRatio, ThreadEvent, ThreadFeedback, ThreadId, ThreadSummary, TokenUsageRatio, ZedAgentThread,
}; };
pub use thread_store::{SerializedThread, TextThreadStore, ThreadStore}; pub use thread_store::{SerializedThread, TextThreadStore, ThreadStore};

View File

@@ -1,4 +1,4 @@
use crate::thread::Thread; use crate::thread::ZedAgentThread;
use assistant_context::AssistantContext; use assistant_context::AssistantContext;
use assistant_tool::outline; use assistant_tool::outline;
use collections::HashSet; use collections::HashSet;
@@ -560,7 +560,7 @@ impl Display for FetchedUrlContext {
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct ThreadContextHandle { pub struct ThreadContextHandle {
pub thread: Entity<Thread>, pub agent: Entity<ZedAgentThread>,
pub context_id: ContextId, pub context_id: ContextId,
} }
@@ -573,23 +573,23 @@ pub struct ThreadContext {
impl ThreadContextHandle { impl ThreadContextHandle {
pub fn eq_for_key(&self, other: &Self) -> bool { pub fn eq_for_key(&self, other: &Self) -> bool {
self.thread == other.thread self.agent == other.agent
} }
pub fn hash_for_key<H: Hasher>(&self, state: &mut H) { pub fn hash_for_key<H: Hasher>(&self, state: &mut H) {
self.thread.hash(state) self.agent.hash(state)
} }
pub fn title(&self, cx: &App) -> SharedString { pub fn title(&self, cx: &App) -> SharedString {
self.thread.read(cx).summary().or_default() self.agent.read(cx).summary().or_default()
} }
fn load(self, cx: &App) -> Task<Option<(AgentContext, Vec<Entity<Buffer>>)>> { fn load(self, cx: &App) -> Task<Option<(AgentContext, Vec<Entity<Buffer>>)>> {
cx.spawn(async move |cx| { cx.spawn(async move |cx| {
let text = Thread::wait_for_detailed_summary_or_text(&self.thread, cx).await?; let text = ZedAgentThread::wait_for_detailed_summary_or_text(&self.agent, cx).await?;
let title = self let title = self
.thread .agent
.read_with(cx, |thread, _cx| thread.summary().or_default()) .read_with(cx, |thread, _| thread.summary().or_default())
.ok()?; .ok()?;
let context = AgentContext::Thread(ThreadContext { let context = AgentContext::Thread(ThreadContext {
title, title,

View File

@@ -4,7 +4,7 @@ use crate::{
FetchedUrlContext, FileContextHandle, ImageContext, RulesContextHandle, FetchedUrlContext, FileContextHandle, ImageContext, RulesContextHandle,
SelectionContextHandle, SymbolContextHandle, TextThreadContextHandle, ThreadContextHandle, SelectionContextHandle, SymbolContextHandle, TextThreadContextHandle, ThreadContextHandle,
}, },
thread::{MessageId, Thread, ThreadId}, thread::{MessageId, ThreadId, ZedAgentThread},
thread_store::ThreadStore, thread_store::ThreadStore,
}; };
use anyhow::{Context as _, Result, anyhow}; use anyhow::{Context as _, Result, anyhow};
@@ -66,8 +66,9 @@ impl ContextStore {
pub fn new_context_for_thread( pub fn new_context_for_thread(
&self, &self,
thread: &Thread, thread: &ZedAgentThread,
exclude_messages_from_id: Option<MessageId>, exclude_messages_from_id: Option<MessageId>,
_cx: &App,
) -> Vec<AgentContextHandle> { ) -> Vec<AgentContextHandle> {
let existing_context = thread let existing_context = thread
.messages() .messages()
@@ -206,12 +207,15 @@ impl ContextStore {
pub fn add_thread( pub fn add_thread(
&mut self, &mut self,
thread: Entity<Thread>, thread: Entity<ZedAgentThread>,
remove_if_exists: bool, remove_if_exists: bool,
cx: &mut Context<Self>, cx: &mut Context<Self>,
) -> Option<AgentContextHandle> { ) -> Option<AgentContextHandle> {
let context_id = self.next_context_id.post_inc(); let context_id = self.next_context_id.post_inc();
let context = AgentContextHandle::Thread(ThreadContextHandle { thread, context_id }); let context = AgentContextHandle::Thread(ThreadContextHandle {
agent: thread,
context_id,
});
if let Some(existing) = self.context_set.get(AgentContextKey::ref_cast(&context)) { if let Some(existing) = self.context_set.get(AgentContextKey::ref_cast(&context)) {
if remove_if_exists { if remove_if_exists {
@@ -387,7 +391,10 @@ impl ContextStore {
if let Some(thread) = thread.upgrade() { if let Some(thread) = thread.upgrade() {
let context_id = self.next_context_id.post_inc(); let context_id = self.next_context_id.post_inc();
self.insert_context( self.insert_context(
AgentContextHandle::Thread(ThreadContextHandle { thread, context_id }), AgentContextHandle::Thread(ThreadContextHandle {
agent: thread,
context_id,
}),
cx, cx,
); );
} }
@@ -411,11 +418,11 @@ impl ContextStore {
match &context { match &context {
AgentContextHandle::Thread(thread_context) => { AgentContextHandle::Thread(thread_context) => {
if let Some(thread_store) = self.thread_store.clone() { if let Some(thread_store) = self.thread_store.clone() {
thread_context.thread.update(cx, |thread, cx| { thread_context.agent.update(cx, |thread, cx| {
thread.start_generating_detailed_summary_if_needed(thread_store, cx); thread.start_generating_detailed_summary_if_needed(thread_store, cx);
}); });
self.context_thread_ids self.context_thread_ids
.insert(thread_context.thread.read(cx).id().clone()); .insert(thread_context.agent.read(cx).id().clone());
} else { } else {
return false; return false;
} }
@@ -441,7 +448,7 @@ impl ContextStore {
match context { match context {
AgentContextHandle::Thread(thread_context) => { AgentContextHandle::Thread(thread_context) => {
self.context_thread_ids self.context_thread_ids
.remove(thread_context.thread.read(cx).id()); .remove(thread_context.agent.read(cx).id());
} }
AgentContextHandle::TextThread(text_thread_context) => { AgentContextHandle::TextThread(text_thread_context) => {
if let Some(path) = text_thread_context.context.read(cx).path() { if let Some(path) = text_thread_context.context.read(cx).path() {
@@ -570,7 +577,7 @@ pub enum SuggestedContext {
}, },
Thread { Thread {
name: SharedString, name: SharedString,
thread: WeakEntity<Thread>, thread: WeakEntity<ZedAgentThread>,
}, },
TextThread { TextThread {
name: SharedString, name: SharedString,

File diff suppressed because it is too large Load Diff

View File

@@ -1,7 +1,7 @@
use crate::{ use crate::{
context_server_tool::ContextServerTool, context_server_tool::ContextServerTool,
thread::{ thread::{
DetailedSummaryState, ExceededWindowError, MessageId, ProjectSnapshot, Thread, ThreadId, DetailedSummaryState, ExceededWindowError, MessageId, ProjectSnapshot, ThreadId, ZedAgentThread,
}, },
}; };
use agent_settings::{AgentProfileId, CompletionMode}; use agent_settings::{AgentProfileId, CompletionMode};
@@ -400,9 +400,9 @@ impl ThreadStore {
self.threads.iter() self.threads.iter()
} }
pub fn create_thread(&mut self, cx: &mut Context<Self>) -> Entity<Thread> { pub fn create_thread(&mut self, cx: &mut Context<Self>) -> Entity<ZedAgentThread> {
cx.new(|cx| { cx.new(|cx| {
Thread::new( ZedAgentThread::new(
self.project.clone(), self.project.clone(),
self.tools.clone(), self.tools.clone(),
self.prompt_builder.clone(), self.prompt_builder.clone(),
@@ -416,9 +416,9 @@ impl ThreadStore {
&mut self, &mut self,
serialized: SerializedThread, serialized: SerializedThread,
cx: &mut Context<Self>, cx: &mut Context<Self>,
) -> Entity<Thread> { ) -> Entity<ZedAgentThread> {
cx.new(|cx| { cx.new(|cx| {
Thread::deserialize( ZedAgentThread::deserialize(
ThreadId::new(), ThreadId::new(),
serialized, serialized,
self.project.clone(), self.project.clone(),
@@ -436,7 +436,7 @@ impl ThreadStore {
id: &ThreadId, id: &ThreadId,
window: &mut Window, window: &mut Window,
cx: &mut Context<Self>, cx: &mut Context<Self>,
) -> Task<Result<Entity<Thread>>> { ) -> Task<Result<Entity<ZedAgentThread>>> {
let id = id.clone(); let id = id.clone();
let database_future = ThreadsDatabase::global_future(cx); let database_future = ThreadsDatabase::global_future(cx);
let this = cx.weak_entity(); let this = cx.weak_entity();
@@ -449,7 +449,7 @@ impl ThreadStore {
let thread = this.update_in(cx, |this, window, cx| { let thread = this.update_in(cx, |this, window, cx| {
cx.new(|cx| { cx.new(|cx| {
Thread::deserialize( ZedAgentThread::deserialize(
id.clone(), id.clone(),
thread, thread,
this.project.clone(), this.project.clone(),
@@ -466,9 +466,14 @@ impl ThreadStore {
}) })
} }
pub fn save_thread(&self, thread: &Entity<Thread>, cx: &mut Context<Self>) -> Task<Result<()>> { pub fn save_thread(
let (metadata, serialized_thread) = &self,
thread.update(cx, |thread, cx| (thread.id().clone(), thread.serialize(cx))); thread: &Entity<ZedAgentThread>,
cx: &mut Context<Self>,
) -> Task<Result<()>> {
let (metadata, serialized_thread) = thread.update(cx, |thread, cx| {
(thread.id().clone(), thread.serialize(cx))
});
let database_future = ThreadsDatabase::global_future(cx); let database_future = ThreadsDatabase::global_future(cx);
cx.spawn(async move |this, cx| { cx.spawn(async move |this, cx| {
@@ -700,7 +705,7 @@ impl SerializedThreadV0_1_0 {
} }
} }
#[derive(Debug, Serialize, Deserialize, PartialEq)] #[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct SerializedMessage { pub struct SerializedMessage {
pub id: MessageId, pub id: MessageId,
pub role: Role, pub role: Role,
@@ -714,11 +719,9 @@ pub struct SerializedMessage {
pub context: String, pub context: String,
#[serde(default)] #[serde(default)]
pub creases: Vec<SerializedCrease>, pub creases: Vec<SerializedCrease>,
#[serde(default)]
pub is_hidden: bool,
} }
#[derive(Debug, Serialize, Deserialize, PartialEq)] #[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(tag = "type")] #[serde(tag = "type")]
pub enum SerializedMessageSegment { pub enum SerializedMessageSegment {
#[serde(rename = "text")] #[serde(rename = "text")]
@@ -736,14 +739,14 @@ pub enum SerializedMessageSegment {
}, },
} }
#[derive(Debug, Serialize, Deserialize, PartialEq)] #[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct SerializedToolUse { pub struct SerializedToolUse {
pub id: LanguageModelToolUseId, pub id: LanguageModelToolUseId,
pub name: SharedString, pub name: SharedString,
pub input: serde_json::Value, pub input: serde_json::Value,
} }
#[derive(Debug, Serialize, Deserialize, PartialEq)] #[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct SerializedToolResult { pub struct SerializedToolResult {
pub tool_use_id: LanguageModelToolUseId, pub tool_use_id: LanguageModelToolUseId,
pub is_error: bool, pub is_error: bool,
@@ -801,12 +804,11 @@ impl LegacySerializedMessage {
tool_results: self.tool_results, tool_results: self.tool_results,
context: String::new(), context: String::new(),
creases: Vec::new(), creases: Vec::new(),
is_hidden: false,
} }
} }
} }
#[derive(Debug, Serialize, Deserialize, PartialEq)] #[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct SerializedCrease { pub struct SerializedCrease {
pub start: usize, pub start: usize,
pub end: usize, pub end: usize,
@@ -1105,7 +1107,6 @@ mod tests {
tool_results: vec![], tool_results: vec![],
context: "".to_string(), context: "".to_string(),
creases: vec![], creases: vec![],
is_hidden: false
}], }],
version: SerializedThread::VERSION.to_string(), version: SerializedThread::VERSION.to_string(),
initial_project_snapshot: None, initial_project_snapshot: None,
@@ -1138,7 +1139,6 @@ mod tests {
tool_results: vec![], tool_results: vec![],
context: "".to_string(), context: "".to_string(),
creases: vec![], creases: vec![],
is_hidden: false,
}, },
SerializedMessage { SerializedMessage {
id: MessageId(2), id: MessageId(2),
@@ -1154,7 +1154,6 @@ mod tests {
tool_results: vec![], tool_results: vec![],
context: "".to_string(), context: "".to_string(),
creases: vec![], creases: vec![],
is_hidden: false,
}, },
SerializedMessage { SerializedMessage {
id: MessageId(1), id: MessageId(1),
@@ -1171,7 +1170,6 @@ mod tests {
}], }],
context: "".to_string(), context: "".to_string(),
creases: vec![], creases: vec![],
is_hidden: false,
}, },
], ],
version: SerializedThreadV0_1_0::VERSION.to_string(), version: SerializedThreadV0_1_0::VERSION.to_string(),
@@ -1203,7 +1201,6 @@ mod tests {
tool_results: vec![], tool_results: vec![],
context: "".to_string(), context: "".to_string(),
creases: vec![], creases: vec![],
is_hidden: false
}, },
SerializedMessage { SerializedMessage {
id: MessageId(2), id: MessageId(2),
@@ -1224,7 +1221,6 @@ mod tests {
}], }],
context: "".to_string(), context: "".to_string(),
creases: vec![], creases: vec![],
is_hidden: false,
}, },
], ],
version: SerializedThread::VERSION.to_string(), version: SerializedThread::VERSION.to_string(),

View File

@@ -1,567 +0,0 @@
use crate::{
thread::{MessageId, PromptId, ThreadId},
thread_store::SerializedMessage,
};
use anyhow::Result;
use assistant_tool::{
AnyToolCard, Tool, ToolResultContent, ToolResultOutput, ToolUseStatus, ToolWorkingSet,
};
use collections::HashMap;
use futures::{FutureExt as _, future::Shared};
use gpui::{App, Entity, SharedString, Task, Window};
use icons::IconName;
use language_model::{
ConfiguredModel, LanguageModel, LanguageModelRequest, LanguageModelToolResult,
LanguageModelToolResultContent, LanguageModelToolUse, LanguageModelToolUseId, Role,
};
use project::Project;
use std::sync::Arc;
use util::truncate_lines_to_byte_limit;
#[derive(Debug)]
pub struct ToolUse {
pub id: LanguageModelToolUseId,
pub name: SharedString,
pub ui_text: SharedString,
pub status: ToolUseStatus,
pub input: serde_json::Value,
pub icon: icons::IconName,
pub needs_confirmation: bool,
}
pub struct ToolUseState {
tools: Entity<ToolWorkingSet>,
tool_uses_by_assistant_message: HashMap<MessageId, Vec<LanguageModelToolUse>>,
tool_results: HashMap<LanguageModelToolUseId, LanguageModelToolResult>,
pending_tool_uses_by_id: HashMap<LanguageModelToolUseId, PendingToolUse>,
tool_result_cards: HashMap<LanguageModelToolUseId, AnyToolCard>,
tool_use_metadata_by_id: HashMap<LanguageModelToolUseId, ToolUseMetadata>,
}
impl ToolUseState {
pub fn new(tools: Entity<ToolWorkingSet>) -> Self {
Self {
tools,
tool_uses_by_assistant_message: HashMap::default(),
tool_results: HashMap::default(),
pending_tool_uses_by_id: HashMap::default(),
tool_result_cards: HashMap::default(),
tool_use_metadata_by_id: HashMap::default(),
}
}
/// Constructs a [`ToolUseState`] from the given list of [`SerializedMessage`]s.
///
/// Accepts a function to filter the tools that should be used to populate the state.
///
/// If `window` is `None` (e.g., when in headless mode or when running evals),
/// tool cards won't be deserialized
pub fn from_serialized_messages(
tools: Entity<ToolWorkingSet>,
messages: &[SerializedMessage],
project: Entity<Project>,
window: Option<&mut Window>, // None in headless mode
cx: &mut App,
) -> Self {
let mut this = Self::new(tools);
let mut tool_names_by_id = HashMap::default();
let mut window = window;
for message in messages {
match message.role {
Role::Assistant => {
if !message.tool_uses.is_empty() {
let tool_uses = message
.tool_uses
.iter()
.map(|tool_use| LanguageModelToolUse {
id: tool_use.id.clone(),
name: tool_use.name.clone().into(),
raw_input: tool_use.input.to_string(),
input: tool_use.input.clone(),
is_input_complete: true,
})
.collect::<Vec<_>>();
tool_names_by_id.extend(
tool_uses
.iter()
.map(|tool_use| (tool_use.id.clone(), tool_use.name.clone())),
);
this.tool_uses_by_assistant_message
.insert(message.id, tool_uses);
for tool_result in &message.tool_results {
let tool_use_id = tool_result.tool_use_id.clone();
let Some(tool_use) = tool_names_by_id.get(&tool_use_id) else {
log::warn!("no tool name found for tool use: {tool_use_id:?}");
continue;
};
this.tool_results.insert(
tool_use_id.clone(),
LanguageModelToolResult {
tool_use_id: tool_use_id.clone(),
tool_name: tool_use.clone(),
is_error: tool_result.is_error,
content: tool_result.content.clone(),
output: tool_result.output.clone(),
},
);
if let Some(window) = &mut window {
if let Some(tool) = this.tools.read(cx).tool(tool_use, cx) {
if let Some(output) = tool_result.output.clone() {
if let Some(card) = tool.deserialize_card(
output,
project.clone(),
window,
cx,
) {
this.tool_result_cards.insert(tool_use_id, card);
}
}
}
}
}
}
}
Role::System | Role::User => {}
}
}
this
}
pub fn cancel_pending(&mut self) -> Vec<PendingToolUse> {
let mut cancelled_tool_uses = Vec::new();
self.pending_tool_uses_by_id
.retain(|tool_use_id, tool_use| {
if matches!(tool_use.status, PendingToolUseStatus::Error { .. }) {
return true;
}
let content = "Tool canceled by user".into();
self.tool_results.insert(
tool_use_id.clone(),
LanguageModelToolResult {
tool_use_id: tool_use_id.clone(),
tool_name: tool_use.name.clone(),
content,
output: None,
is_error: true,
},
);
cancelled_tool_uses.push(tool_use.clone());
false
});
cancelled_tool_uses
}
pub fn pending_tool_uses(&self) -> Vec<&PendingToolUse> {
self.pending_tool_uses_by_id.values().collect()
}
pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
let Some(tool_uses_for_message) = &self.tool_uses_by_assistant_message.get(&id) else {
return Vec::new();
};
let mut tool_uses = Vec::new();
for tool_use in tool_uses_for_message.iter() {
let tool_result = self.tool_results.get(&tool_use.id);
let status = (|| {
if let Some(tool_result) = tool_result {
let content = tool_result
.content
.to_str()
.map(|str| str.to_owned().into())
.unwrap_or_default();
return if tool_result.is_error {
ToolUseStatus::Error(content)
} else {
ToolUseStatus::Finished(content)
};
}
if let Some(pending_tool_use) = self.pending_tool_uses_by_id.get(&tool_use.id) {
match pending_tool_use.status {
PendingToolUseStatus::Idle => ToolUseStatus::Pending,
PendingToolUseStatus::NeedsConfirmation { .. } => {
ToolUseStatus::NeedsConfirmation
}
PendingToolUseStatus::Running { .. } => ToolUseStatus::Running,
PendingToolUseStatus::Error(ref err) => {
ToolUseStatus::Error(err.clone().into())
}
PendingToolUseStatus::InputStillStreaming => {
ToolUseStatus::InputStillStreaming
}
}
} else {
ToolUseStatus::Pending
}
})();
let (icon, needs_confirmation) =
if let Some(tool) = self.tools.read(cx).tool(&tool_use.name, cx) {
(tool.icon(), tool.needs_confirmation(&tool_use.input, cx))
} else {
(IconName::Cog, false)
};
tool_uses.push(ToolUse {
id: tool_use.id.clone(),
name: tool_use.name.clone().into(),
ui_text: self.tool_ui_label(
&tool_use.name,
&tool_use.input,
tool_use.is_input_complete,
cx,
),
input: tool_use.input.clone(),
status,
icon,
needs_confirmation,
})
}
tool_uses
}
pub fn tool_ui_label(
&self,
tool_name: &str,
input: &serde_json::Value,
is_input_complete: bool,
cx: &App,
) -> SharedString {
if let Some(tool) = self.tools.read(cx).tool(tool_name, cx) {
if is_input_complete {
tool.ui_text(input).into()
} else {
tool.still_streaming_ui_text(input).into()
}
} else {
format!("Unknown tool {tool_name:?}").into()
}
}
pub fn tool_results_for_message(
&self,
assistant_message_id: MessageId,
) -> Vec<&LanguageModelToolResult> {
let Some(tool_uses) = self
.tool_uses_by_assistant_message
.get(&assistant_message_id)
else {
return Vec::new();
};
tool_uses
.iter()
.filter_map(|tool_use| self.tool_results.get(&tool_use.id))
.collect()
}
pub fn message_has_tool_results(&self, assistant_message_id: MessageId) -> bool {
self.tool_uses_by_assistant_message
.get(&assistant_message_id)
.map_or(false, |results| !results.is_empty())
}
pub fn tool_result(
&self,
tool_use_id: &LanguageModelToolUseId,
) -> Option<&LanguageModelToolResult> {
self.tool_results.get(tool_use_id)
}
pub fn tool_result_card(&self, tool_use_id: &LanguageModelToolUseId) -> Option<&AnyToolCard> {
self.tool_result_cards.get(tool_use_id)
}
pub fn insert_tool_result_card(
&mut self,
tool_use_id: LanguageModelToolUseId,
card: AnyToolCard,
) {
self.tool_result_cards.insert(tool_use_id, card);
}
pub fn request_tool_use(
&mut self,
assistant_message_id: MessageId,
tool_use: LanguageModelToolUse,
metadata: ToolUseMetadata,
cx: &App,
) -> Arc<str> {
let tool_uses = self
.tool_uses_by_assistant_message
.entry(assistant_message_id)
.or_default();
let mut existing_tool_use_found = false;
for existing_tool_use in tool_uses.iter_mut() {
if existing_tool_use.id == tool_use.id {
*existing_tool_use = tool_use.clone();
existing_tool_use_found = true;
}
}
if !existing_tool_use_found {
tool_uses.push(tool_use.clone());
}
let status = if tool_use.is_input_complete {
self.tool_use_metadata_by_id
.insert(tool_use.id.clone(), metadata);
PendingToolUseStatus::Idle
} else {
PendingToolUseStatus::InputStillStreaming
};
let ui_text: Arc<str> = self
.tool_ui_label(
&tool_use.name,
&tool_use.input,
tool_use.is_input_complete,
cx,
)
.into();
let may_perform_edits = self
.tools
.read(cx)
.tool(&tool_use.name, cx)
.is_some_and(|tool| tool.may_perform_edits());
self.pending_tool_uses_by_id.insert(
tool_use.id.clone(),
PendingToolUse {
assistant_message_id,
id: tool_use.id,
name: tool_use.name.clone(),
ui_text: ui_text.clone(),
input: tool_use.input,
may_perform_edits,
status,
},
);
ui_text
}
pub fn run_pending_tool(
&mut self,
tool_use_id: LanguageModelToolUseId,
ui_text: SharedString,
task: Task<()>,
) {
if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
tool_use.ui_text = ui_text.into();
tool_use.status = PendingToolUseStatus::Running {
_task: task.shared(),
};
}
}
pub fn confirm_tool_use(
&mut self,
tool_use_id: LanguageModelToolUseId,
ui_text: impl Into<Arc<str>>,
input: serde_json::Value,
request: Arc<LanguageModelRequest>,
tool: Arc<dyn Tool>,
) {
if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
let ui_text = ui_text.into();
tool_use.ui_text = ui_text.clone();
let confirmation = Confirmation {
tool_use_id,
input,
request,
tool,
ui_text,
};
tool_use.status = PendingToolUseStatus::NeedsConfirmation(Arc::new(confirmation));
}
}
pub fn insert_tool_output(
&mut self,
tool_use_id: LanguageModelToolUseId,
tool_name: Arc<str>,
output: Result<ToolResultOutput>,
configured_model: Option<&ConfiguredModel>,
) -> Option<PendingToolUse> {
let metadata = self.tool_use_metadata_by_id.remove(&tool_use_id);
telemetry::event!(
"Agent Tool Finished",
model = metadata
.as_ref()
.map(|metadata| metadata.model.telemetry_id()),
model_provider = metadata
.as_ref()
.map(|metadata| metadata.model.provider_id().to_string()),
thread_id = metadata.as_ref().map(|metadata| metadata.thread_id.clone()),
prompt_id = metadata.as_ref().map(|metadata| metadata.prompt_id.clone()),
tool_name,
success = output.is_ok()
);
match output {
Ok(output) => {
let tool_result = output.content;
const BYTES_PER_TOKEN_ESTIMATE: usize = 3;
let old_use = self.pending_tool_uses_by_id.remove(&tool_use_id);
// Protect from overly large output
let tool_output_limit = configured_model
.map(|model| model.model.max_token_count() as usize * BYTES_PER_TOKEN_ESTIMATE)
.unwrap_or(usize::MAX);
let content = match tool_result {
ToolResultContent::Text(text) => {
let text = if text.len() < tool_output_limit {
text
} else {
let truncated = truncate_lines_to_byte_limit(&text, tool_output_limit);
format!(
"Tool result too long. The first {} bytes:\n\n{}",
truncated.len(),
truncated
)
};
LanguageModelToolResultContent::Text(text.into())
}
ToolResultContent::Image(language_model_image) => {
if language_model_image.estimate_tokens() < tool_output_limit {
LanguageModelToolResultContent::Image(language_model_image)
} else {
self.tool_results.insert(
tool_use_id.clone(),
LanguageModelToolResult {
tool_use_id: tool_use_id.clone(),
tool_name,
content: "Tool responded with an image that would exceeded the remaining tokens".into(),
is_error: true,
output: None,
},
);
return old_use;
}
}
};
self.tool_results.insert(
tool_use_id.clone(),
LanguageModelToolResult {
tool_use_id: tool_use_id.clone(),
tool_name,
content,
is_error: false,
output: output.output,
},
);
old_use
}
Err(err) => {
self.tool_results.insert(
tool_use_id.clone(),
LanguageModelToolResult {
tool_use_id: tool_use_id.clone(),
tool_name,
content: LanguageModelToolResultContent::Text(err.to_string().into()),
is_error: true,
output: None,
},
);
if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
tool_use.status = PendingToolUseStatus::Error(err.to_string().into());
}
self.pending_tool_uses_by_id.get(&tool_use_id).cloned()
}
}
}
pub fn has_tool_results(&self, assistant_message_id: MessageId) -> bool {
self.tool_uses_by_assistant_message
.contains_key(&assistant_message_id)
}
pub fn tool_results(
&self,
assistant_message_id: MessageId,
) -> impl Iterator<Item = (&LanguageModelToolUse, Option<&LanguageModelToolResult>)> {
self.tool_uses_by_assistant_message
.get(&assistant_message_id)
.into_iter()
.flatten()
.map(|tool_use| (tool_use, self.tool_results.get(&tool_use.id)))
}
}
#[derive(Debug, Clone)]
pub struct PendingToolUse {
pub id: LanguageModelToolUseId,
/// The ID of the Assistant message in which the tool use was requested.
#[allow(unused)]
pub assistant_message_id: MessageId,
pub name: Arc<str>,
pub ui_text: Arc<str>,
pub input: serde_json::Value,
pub status: PendingToolUseStatus,
pub may_perform_edits: bool,
}
#[derive(Debug, Clone)]
pub struct Confirmation {
pub tool_use_id: LanguageModelToolUseId,
pub input: serde_json::Value,
pub ui_text: Arc<str>,
pub request: Arc<LanguageModelRequest>,
pub tool: Arc<dyn Tool>,
}
#[derive(Debug, Clone)]
pub enum PendingToolUseStatus {
InputStillStreaming,
Idle,
NeedsConfirmation(Arc<Confirmation>),
Running { _task: Shared<Task<()>> },
Error(#[allow(unused)] Arc<str>),
}
impl PendingToolUseStatus {
pub fn is_idle(&self) -> bool {
matches!(self, PendingToolUseStatus::Idle)
}
pub fn is_error(&self) -> bool {
matches!(self, PendingToolUseStatus::Error(_))
}
pub fn needs_confirmation(&self) -> bool {
matches!(self, PendingToolUseStatus::NeedsConfirmation { .. })
}
}
#[derive(Clone)]
pub struct ToolUseMetadata {
pub model: Arc<dyn LanguageModel>,
pub thread_id: ThreadId,
pub prompt_id: PromptId,
}

View File

@@ -96,6 +96,7 @@ zed_llm_client.workspace = true
[dev-dependencies] [dev-dependencies]
assistant_tools.workspace = true assistant_tools.workspace = true
assistant_tool = { workspace = true, "features" = ["test-support"] }
buffer_diff = { workspace = true, features = ["test-support"] } buffer_diff = { workspace = true, features = ["test-support"] }
editor = { workspace = true, features = ["test-support"] } editor = { workspace = true, features = ["test-support"] }
gpui = { workspace = true, "features" = ["test-support"] } gpui = { workspace = true, "features" = ["test-support"] }

View File

@@ -5,16 +5,17 @@ use crate::ui::{
AddedContext, AgentNotification, AgentNotificationEvent, AnimatedLabel, ContextPill, AddedContext, AgentNotification, AgentNotificationEvent, AnimatedLabel, ContextPill,
}; };
use crate::{AgentPanel, ModelUsageContext}; use crate::{AgentPanel, ModelUsageContext};
use agent::thread::{ToolUseSegment, UserMessageParams};
use agent::{ use agent::{
ContextStore, LastRestoreCheckpoint, MessageCrease, MessageId, MessageSegment, TextThreadStore, ContextStore, LastRestoreCheckpoint, MessageCrease, MessageId, MessageSegment, TextThreadStore,
Thread, ThreadError, ThreadEvent, ThreadFeedback, ThreadStore, ThreadSummary, ThreadError, ThreadEvent, ThreadFeedback, ThreadStore, ThreadSummary, ZedAgentThread,
context::{self, AgentContextHandle, RULES_ICON}, context::{self, AgentContextHandle, RULES_ICON},
thread::{PendingToolUseStatus, ToolUse},
thread_store::RulesLoadingError, thread_store::RulesLoadingError,
tool_use::{PendingToolUseStatus, ToolUse},
}; };
use agent_settings::{AgentSettings, NotifyWhenAgentWaiting}; use agent_settings::{AgentSettings, NotifyWhenAgentWaiting};
use anyhow::Context as _; use anyhow::Context as _;
use assistant_tool::ToolUseStatus; use assistant_tool::{AnyToolCard, ToolUseStatus, ToolWorkingSet};
use audio::{Audio, Sound}; use audio::{Audio, Sound};
use collections::{HashMap, HashSet}; use collections::{HashMap, HashSet};
use editor::actions::{MoveUp, Paste}; use editor::actions::{MoveUp, Paste};
@@ -30,13 +31,14 @@ use gpui::{
}; };
use language::{Buffer, Language, LanguageRegistry}; use language::{Buffer, Language, LanguageRegistry};
use language_model::{ use language_model::{
LanguageModelRequestMessage, LanguageModelToolUseId, MessageContent, Role, StopReason, LanguageModelRequestMessage, LanguageModelToolResultContent, LanguageModelToolUseId,
MessageContent, Role, StopReason,
}; };
use markdown::parser::{CodeBlockKind, CodeBlockMetadata}; use markdown::parser::{CodeBlockKind, CodeBlockMetadata};
use markdown::{ use markdown::{
HeadingLevelStyles, Markdown, MarkdownElement, MarkdownStyle, ParsedMarkdown, PathWithRange, HeadingLevelStyles, Markdown, MarkdownElement, MarkdownStyle, ParsedMarkdown, PathWithRange,
}; };
use project::{ProjectEntryId, ProjectItem as _}; use project::{Project, ProjectEntryId, ProjectItem as _};
use rope::Point; use rope::Point;
use settings::{Settings as _, SettingsStore, update_settings_file}; use settings::{Settings as _, SettingsStore, update_settings_file};
use std::ffi::OsStr; use std::ffi::OsStr;
@@ -50,11 +52,10 @@ use ui::{
Disclosure, KeyBinding, PopoverMenuHandle, Scrollbar, ScrollbarState, TextSize, Tooltip, Disclosure, KeyBinding, PopoverMenuHandle, Scrollbar, ScrollbarState, TextSize, Tooltip,
prelude::*, prelude::*,
}; };
use util::ResultExt as _;
use util::markdown::MarkdownCodeBlock; use util::markdown::MarkdownCodeBlock;
use util::{ResultExt as _, debug_panic};
use workspace::{CollaboratorId, Workspace}; use workspace::{CollaboratorId, Workspace};
use zed_actions::assistant::OpenRulesLibrary; use zed_actions::assistant::OpenRulesLibrary;
use zed_llm_client::CompletionIntent;
const CODEBLOCK_CONTAINER_GROUP: &str = "codeblock_container"; const CODEBLOCK_CONTAINER_GROUP: &str = "codeblock_container";
const EDIT_PREVIOUS_MESSAGE_MIN_LINES: usize = 1; const EDIT_PREVIOUS_MESSAGE_MIN_LINES: usize = 1;
@@ -64,8 +65,10 @@ pub struct ActiveThread {
language_registry: Arc<LanguageRegistry>, language_registry: Arc<LanguageRegistry>,
thread_store: Entity<ThreadStore>, thread_store: Entity<ThreadStore>,
text_thread_store: Entity<TextThreadStore>, text_thread_store: Entity<TextThreadStore>,
thread: Entity<Thread>, agent: Entity<ZedAgentThread>,
// thread: Entity<Thread>,
workspace: WeakEntity<Workspace>, workspace: WeakEntity<Workspace>,
project: Entity<Project>,
save_thread_task: Option<Task<()>>, save_thread_task: Option<Task<()>>,
messages: Vec<MessageId>, messages: Vec<MessageId>,
list_state: ListState, list_state: ListState,
@@ -92,7 +95,7 @@ struct RenderedMessage {
segments: Vec<RenderedMessageSegment>, segments: Vec<RenderedMessageSegment>,
} }
#[derive(Clone)] #[derive(Clone, Debug)]
struct RenderedToolUse { struct RenderedToolUse {
label: Entity<Markdown>, label: Entity<Markdown>,
input: Entity<Markdown>, input: Entity<Markdown>,
@@ -162,17 +165,103 @@ impl RenderedMessage {
cx, cx,
))) )))
} }
MessageSegment::RedactedThinking(_) => {} MessageSegment::ToolUse { .. } => {
todo!()
}
}; };
} }
fn update_tool_call(
&mut self,
segment_index: usize,
segment: &ToolUseSegment,
_tools: &Entity<ToolWorkingSet>,
cx: &mut App,
) {
if let Some(card) = segment.card.clone() {
if self.segments.len() < segment_index {
self.segments.push(RenderedMessageSegment::ToolUseCard(
segment.status.clone(),
card,
))
}
return;
}
if self.segments.len() <= segment_index {
self.segments
.push(RenderedMessageSegment::ToolUseMarkdown(RenderedToolUse {
label: cx.new(|cx| {
Markdown::new("".into(), Some(self.language_registry.clone()), None, cx)
}),
input: cx.new(|cx| {
Markdown::new("".into(), Some(self.language_registry.clone()), None, cx)
}),
output: cx.new(|cx| {
Markdown::new("".into(), Some(self.language_registry.clone()), None, cx)
}),
}))
}
dbg!(&self.segments);
let RenderedMessageSegment::ToolUseMarkdown(rendered) = &self.segments[segment_index]
else {
panic!()
};
// todo!()
// let ui_label = if let Some(tool) = tools.read(cx).tool(segment.name, cx) {
// if segment.is_input_complete {
// tool.ui_text(segment.input).into()
// } else {
// tool.still_streaming_ui_text(segment.input).into()
// }
// } else {
// format!("Unknown tool {:?}", segment.name).into()
// };
rendered.label.update(cx, |this, cx| {
this.replace(segment.name.clone(), cx);
});
rendered.input.update(cx, |this, cx| {
this.replace(
MarkdownCodeBlock {
tag: "json",
text: &serde_json::to_string_pretty(&segment.input).unwrap_or_default(),
}
.to_string(),
cx,
);
});
rendered.output.update(cx, |_this, _cx| {
match &segment.output {
Some(Ok(LanguageModelToolResultContent::Text(_text))) => {
// todo!
}
Some(Ok(LanguageModelToolResultContent::Image(_image))) => {
// todo!
}
Some(Err(_error)) => {
// todo!
}
None => {
// todo!
}
}
});
}
} }
#[derive(Debug)]
enum RenderedMessageSegment { enum RenderedMessageSegment {
Thinking { Thinking {
content: Entity<Markdown>, content: Entity<Markdown>,
scroll_handle: ScrollHandle, scroll_handle: ScrollHandle,
}, },
Text(Entity<Markdown>), Text(Entity<Markdown>),
ToolUseCard(ToolUseStatus, AnyToolCard),
ToolUseMarkdown(RenderedToolUse),
} }
fn parse_markdown( fn parse_markdown(
@@ -765,7 +854,7 @@ struct EditingMessageState {
impl ActiveThread { impl ActiveThread {
pub fn new( pub fn new(
thread: Entity<Thread>, agent: Entity<ZedAgentThread>,
thread_store: Entity<ThreadStore>, thread_store: Entity<ThreadStore>,
text_thread_store: Entity<TextThreadStore>, text_thread_store: Entity<TextThreadStore>,
context_store: Entity<ContextStore>, context_store: Entity<ContextStore>,
@@ -775,8 +864,8 @@ impl ActiveThread {
cx: &mut Context<Self>, cx: &mut Context<Self>,
) -> Self { ) -> Self {
let subscriptions = vec![ let subscriptions = vec![
cx.observe(&thread, |_, _, cx| cx.notify()), cx.observe(&agent, |_, _, cx| cx.notify()),
cx.subscribe_in(&thread, window, Self::handle_thread_event), cx.subscribe_in(&agent, window, Self::handle_thread_event),
cx.subscribe(&thread_store, Self::handle_rules_loading_error), cx.subscribe(&thread_store, Self::handle_rules_loading_error),
cx.observe_global::<SettingsStore>(|_, cx| cx.notify()), cx.observe_global::<SettingsStore>(|_, cx| cx.notify()),
]; ];
@@ -788,12 +877,14 @@ impl ActiveThread {
.unwrap() .unwrap()
} }
}); });
let project = agent.read(cx).project().clone();
let mut this = Self { let mut this = Self {
language_registry, language_registry,
thread_store, thread_store,
text_thread_store, text_thread_store,
context_store, context_store,
thread: thread.clone(), agent: agent.clone(),
project,
workspace, workspace,
save_thread_task: None, save_thread_task: None,
messages: Vec::new(), messages: Vec::new(),
@@ -816,7 +907,8 @@ impl ActiveThread {
_load_edited_message_context_task: None, _load_edited_message_context_task: None,
}; };
for message in thread.read(cx).messages().cloned().collect::<Vec<_>>() { // todo! hold on to thread entity and get messages directly
for message in agent.read(cx).messages().cloned().collect::<Vec<_>>() {
let rendered_message = RenderedMessage::from_segments( let rendered_message = RenderedMessage::from_segments(
&message.segments, &message.segments,
this.language_registry.clone(), this.language_registry.clone(),
@@ -824,7 +916,7 @@ impl ActiveThread {
); );
this.push_rendered_message(message.id, rendered_message); this.push_rendered_message(message.id, rendered_message);
for tool_use in thread.read(cx).tool_uses_for_message(message.id, cx) { for tool_use in agent.read(cx).tool_uses_for_message(message.id, cx) {
this.render_tool_use_markdown( this.render_tool_use_markdown(
tool_use.id.clone(), tool_use.id.clone(),
tool_use.ui_text.clone(), tool_use.ui_text.clone(),
@@ -838,8 +930,8 @@ impl ActiveThread {
this this
} }
pub fn thread(&self) -> &Entity<Thread> { pub fn agent(&self) -> &Entity<ZedAgentThread> {
&self.thread &self.agent
} }
pub fn is_empty(&self) -> bool { pub fn is_empty(&self) -> bool {
@@ -847,17 +939,17 @@ impl ActiveThread {
} }
pub fn summary<'a>(&'a self, cx: &'a App) -> &'a ThreadSummary { pub fn summary<'a>(&'a self, cx: &'a App) -> &'a ThreadSummary {
self.thread.read(cx).summary() self.agent.read(cx).summary()
} }
pub fn regenerate_summary(&self, cx: &mut App) { pub fn regenerate_summary(&self, cx: &mut App) {
self.thread.update(cx, |thread, cx| thread.summarize(cx)) self.agent.update(cx, |agent, cx| agent.summarize(cx))
} }
pub fn cancel_last_completion(&mut self, window: &mut Window, cx: &mut App) -> bool { pub fn cancel_last_completion(&mut self, window: &mut Window, cx: &mut App) -> bool {
self.last_error.take(); self.last_error.take();
self.thread.update(cx, |thread, cx| { self.agent.update(cx, |agent, cx| {
thread.cancel_last_completion(Some(window.window_handle()), cx) agent.cancel_last_completion(Some(window.window_handle()), cx)
}) })
} }
@@ -947,7 +1039,7 @@ impl ActiveThread {
fn handle_thread_event( fn handle_thread_event(
&mut self, &mut self,
_thread: &Entity<Thread>, _agent: &Entity<ZedAgentThread>,
event: &ThreadEvent, event: &ThreadEvent,
window: &mut Window, window: &mut Window,
cx: &mut Context<Self>, cx: &mut Context<Self>,
@@ -965,10 +1057,8 @@ impl ActiveThread {
cx.notify(); cx.notify();
} }
ThreadEvent::CompletionCanceled => { ThreadEvent::CompletionCanceled => {
self.thread.update(cx, |thread, cx| { self.project.update(cx, |project, cx| {
thread.project().update(cx, |project, cx| { project.set_agent_location(None, cx);
project.set_agent_location(None, cx);
})
}); });
self.workspace self.workspace
.update(cx, |workspace, cx| { .update(cx, |workspace, cx| {
@@ -986,7 +1076,7 @@ impl ActiveThread {
} }
ThreadEvent::Stopped(reason) => match reason { ThreadEvent::Stopped(reason) => match reason {
Ok(StopReason::EndTurn | StopReason::MaxTokens) => { Ok(StopReason::EndTurn | StopReason::MaxTokens) => {
let used_tools = self.thread.read(cx).used_tools_since_last_user_message(); let used_tools = self.agent.read(cx).used_tools_since_last_user_message(cx);
self.play_notification_sound(window, cx); self.play_notification_sound(window, cx);
self.show_notification( self.show_notification(
if used_tools { if used_tools {
@@ -1024,9 +1114,28 @@ impl ActiveThread {
rendered_message.append_thinking(text, cx); rendered_message.append_thinking(text, cx);
} }
} }
ThreadEvent::StreamedToolUse2 {
message_id,
segment_index,
} => {
if let Some(rendered_message) = self.rendered_messages_by_id.get_mut(&message_id) {
self.agent.update(cx, |agent, cx| {
if let Some(message) = agent.message(*message_id) {
let MessageSegment::ToolUse(tool_use) =
&message.segments[*segment_index]
else {
debug_panic!("segment index mismatch");
return;
};
let tools = self.agent.read(cx).tools().clone();
rendered_message.update_tool_call(*segment_index, tool_use, &tools, cx);
}
})
}
}
ThreadEvent::MessageAdded(message_id) => { ThreadEvent::MessageAdded(message_id) => {
if let Some(rendered_message) = self.thread.update(cx, |thread, cx| { if let Some(rendered_message) = self.agent.update(cx, |agent, cx| {
thread.message(*message_id).map(|message| { agent.message(*message_id).map(|message| {
RenderedMessage::from_segments( RenderedMessage::from_segments(
&message.segments, &message.segments,
self.language_registry.clone(), self.language_registry.clone(),
@@ -1042,8 +1151,8 @@ impl ActiveThread {
} }
ThreadEvent::MessageEdited(message_id) => { ThreadEvent::MessageEdited(message_id) => {
if let Some(index) = self.messages.iter().position(|id| id == message_id) { if let Some(index) = self.messages.iter().position(|id| id == message_id) {
if let Some(rendered_message) = self.thread.update(cx, |thread, cx| { if let Some(rendered_message) = self.agent.update(cx, |agent, cx| {
thread.message(*message_id).map(|message| { agent.message(*message_id).map(|message| {
let mut rendered_message = RenderedMessage { let mut rendered_message = RenderedMessage {
language_registry: self.language_registry.clone(), language_registry: self.language_registry.clone(),
segments: Vec::with_capacity(message.segments.len()), segments: Vec::with_capacity(message.segments.len()),
@@ -1100,7 +1209,7 @@ impl ActiveThread {
tool_use.id.clone(), tool_use.id.clone(),
tool_use.ui_text.clone(), tool_use.ui_text.clone(),
&serde_json::to_string_pretty(&tool_use.input).unwrap_or_default(), &serde_json::to_string_pretty(&tool_use.input).unwrap_or_default(),
self.thread self.agent
.read(cx) .read(cx)
.output_for_tool(&tool_use.id) .output_for_tool(&tool_use.id)
.map(|output| output.clone().into()) .map(|output| output.clone().into())
@@ -1120,7 +1229,7 @@ impl ActiveThread {
tool_use_id.clone(), tool_use_id.clone(),
ui_text, ui_text,
invalid_input_json, invalid_input_json,
self.thread self.agent
.read(cx) .read(cx)
.output_for_tool(tool_use_id) .output_for_tool(tool_use_id)
.map(|output| output.clone().into()) .map(|output| output.clone().into())
@@ -1136,7 +1245,7 @@ impl ActiveThread {
tool_use_id.clone(), tool_use_id.clone(),
ui_text, ui_text,
"", "",
self.thread self.agent
.read(cx) .read(cx)
.output_for_tool(tool_use_id) .output_for_tool(tool_use_id)
.map(|output| output.clone().into()) .map(|output| output.clone().into())
@@ -1185,7 +1294,7 @@ impl ActiveThread {
return; return;
} }
let title = self.thread.read(cx).summary().unwrap_or("Agent Panel"); let title = self.agent.read(cx).summary().unwrap_or("Agent Panel");
match AgentSettings::get_global(cx).notify_when_agent_waiting { match AgentSettings::get_global(cx).notify_when_agent_waiting {
NotifyWhenAgentWaiting::PrimaryScreen => { NotifyWhenAgentWaiting::PrimaryScreen => {
@@ -1296,12 +1405,12 @@ impl ActiveThread {
/// ///
/// Only one task to save the thread will be in flight at a time. /// Only one task to save the thread will be in flight at a time.
fn save_thread(&mut self, cx: &mut Context<Self>) { fn save_thread(&mut self, cx: &mut Context<Self>) {
let thread = self.thread.clone(); let agent = self.agent.clone();
self.save_thread_task = Some(cx.spawn(async move |this, cx| { self.save_thread_task = Some(cx.spawn(async move |this, cx| {
let task = this let task = this
.update(cx, |this, cx| { .update(cx, |this, cx| {
this.thread_store this.thread_store
.update(cx, |thread_store, cx| thread_store.save_thread(&thread, cx)) .update(cx, |thread_store, cx| thread_store.save_thread(&agent, cx))
}) })
.ok(); .ok();
@@ -1351,7 +1460,7 @@ impl ActiveThread {
Some(self.text_thread_store.downgrade()), Some(self.text_thread_store.downgrade()),
context_picker_menu_handle.clone(), context_picker_menu_handle.clone(),
SuggestContextKind::File, SuggestContextKind::File,
ModelUsageContext::Thread(self.thread.clone()), ModelUsageContext::Thread(self.agent.clone()),
window, window,
cx, cx,
) )
@@ -1403,13 +1512,13 @@ impl ActiveThread {
cx.emit(ActiveThreadEvent::EditingMessageTokenCountChanged); cx.emit(ActiveThreadEvent::EditingMessageTokenCountChanged);
state._update_token_count_task.take(); state._update_token_count_task.take();
let Some(configured_model) = self.thread.read(cx).configured_model() else { let Some(configured_model) = self.agent.read(cx).configured_model() else {
state.last_estimated_token_count.take(); state.last_estimated_token_count.take();
return; return;
}; };
let editor = state.editor.clone(); let editor = state.editor.clone();
let thread = self.thread.clone(); let agent = self.agent.clone();
let message_id = *message_id; let message_id = *message_id;
state._update_token_count_task = Some(cx.spawn(async move |this, cx| { state._update_token_count_task = Some(cx.spawn(async move |this, cx| {
@@ -1421,7 +1530,7 @@ impl ActiveThread {
let token_count = if let Some(task) = cx let token_count = if let Some(task) = cx
.update(|cx| { .update(|cx| {
let Some(message) = thread.read(cx).message(message_id) else { let Some(message) = agent.read(cx).message(message_id) else {
log::error!("Message that was being edited no longer exists"); log::error!("Message that was being edited no longer exists");
return None; return None;
}; };
@@ -1553,8 +1662,8 @@ impl ActiveThread {
}; };
let Some(model) = self let Some(model) = self
.thread .agent
.update(cx, |thread, cx| thread.get_or_init_configured_model(cx)) .update(cx, |agent, cx| agent.get_or_init_configured_model(cx))
else { else {
return; return;
}; };
@@ -1568,12 +1677,13 @@ impl ActiveThread {
let creases = state.editor.update(cx, extract_message_creases); let creases = state.editor.update(cx, extract_message_creases);
let new_context = self let new_context = self.context_store.read(cx).new_context_for_thread(
.context_store self.agent.read(cx),
.read(cx) Some(message_id),
.new_context_for_thread(self.thread.read(cx), Some(message_id)); cx,
);
let project = self.thread.read(cx).project().clone(); let project = self.project.clone();
let prompt_store = self.thread_store.read(cx).prompt_store().clone(); let prompt_store = self.thread_store.read(cx).prompt_store().clone();
let git_store = project.read(cx).git_store().clone(); let git_store = project.read(cx).git_store().clone();
@@ -1586,32 +1696,24 @@ impl ActiveThread {
futures::future::join(load_context_task, checkpoint).await; futures::future::join(load_context_task, checkpoint).await;
let _ = this let _ = this
.update_in(cx, |this, window, cx| { .update_in(cx, |this, window, cx| {
this.thread.update(cx, |thread, cx| { this.agent.update(cx, |agent, cx| {
thread.edit_message( agent.truncate(message_id, cx);
message_id, agent.send_message(
Role::User, UserMessageParams {
vec![MessageSegment::Text(edited_text)], text: edited_text,
creases, creases,
Some(context.loaded_context), checkpoint: checkpoint.ok(),
checkpoint.ok(), context,
cx, },
);
for message_id in this.messages_after(message_id) {
thread.delete_message(*message_id, cx);
}
});
this.thread.update(cx, |thread, cx| {
thread.advance_prompt_id();
thread.cancel_last_completion(Some(window.window_handle()), cx);
thread.send_to_model(
model.model, model.model,
CompletionIntent::UserPrompt,
Some(window.window_handle()), Some(window.window_handle()),
cx, cx,
); );
}); });
// todo! do we need this?
this._load_edited_message_context_task = None; this._load_edited_message_context_task = None;
cx.notify(); cx.notify();
}) })
.log_err(); .log_err();
@@ -1626,14 +1728,6 @@ impl ActiveThread {
} }
} }
fn messages_after(&self, message_id: MessageId) -> &[MessageId] {
self.messages
.iter()
.position(|id| *id == message_id)
.map(|index| &self.messages[index + 1..])
.unwrap_or(&[])
}
fn handle_cancel_click(&mut self, _: &ClickEvent, window: &mut Window, cx: &mut Context<Self>) { fn handle_cancel_click(&mut self, _: &ClickEvent, window: &mut Window, cx: &mut Context<Self>) {
self.cancel_editing_message(&menu::Cancel, window, cx); self.cancel_editing_message(&menu::Cancel, window, cx);
} }
@@ -1654,7 +1748,7 @@ impl ActiveThread {
window: &mut Window, window: &mut Window,
cx: &mut Context<Self>, cx: &mut Context<Self>,
) { ) {
let report = self.thread.update(cx, |thread, cx| { let report = self.agent.update(cx, |thread, cx| {
thread.report_message_feedback(message_id, feedback, cx) thread.report_message_feedback(message_id, feedback, cx)
}); });
@@ -1713,17 +1807,17 @@ impl ActiveThread {
return; return;
}; };
let report_task = self.thread.update(cx, |thread, cx| { let report_task = self.agent.update(cx, |thread, cx| {
thread.report_message_feedback(message_id, ThreadFeedback::Negative, cx) thread.report_message_feedback(message_id, ThreadFeedback::Negative, cx)
}); });
let comments = editor.read(cx).text(cx); let comments = editor.read(cx).text(cx);
if !comments.is_empty() { if !comments.is_empty() {
let thread_id = self.thread.read(cx).id().clone(); let thread_id = self.agent.read(cx).id().clone();
let comments_value = String::from(comments.as_str()); let comments_value = String::from(comments.as_str());
let message_content = self let message_content = self
.thread .agent
.read(cx) .read(cx)
.message(message_id) .message(message_id)
.map(|msg| msg.to_string()) .map(|msg| msg.to_string())
@@ -1799,45 +1893,42 @@ impl ActiveThread {
fn render_message(&self, ix: usize, window: &mut Window, cx: &mut Context<Self>) -> AnyElement { fn render_message(&self, ix: usize, window: &mut Window, cx: &mut Context<Self>) -> AnyElement {
let message_id = self.messages[ix]; let message_id = self.messages[ix];
let workspace = self.workspace.clone(); let workspace = self.workspace.clone();
let thread = self.thread.read(cx); let agent = self.agent.read(cx);
let is_first_message = ix == 0; let is_first_message = ix == 0;
let is_last_message = ix == self.messages.len() - 1; let is_last_message = ix == self.messages.len() - 1;
let Some(message) = thread.message(message_id) else { let Some(message) = agent.message(message_id) else {
return Empty.into_any(); return Empty.into_any();
}; };
let is_generating = thread.is_generating(); let is_generating = agent.is_generating();
let is_generating_stale = thread.is_generation_stale().unwrap_or(false); let is_generating_stale = agent.is_generation_stale().unwrap_or(false);
let loading_dots = (is_generating && is_last_message).then(|| { let loading_dots = (is_generating && is_last_message).then(|| {
h_flex() h_flex()
.h_8() .h_8()
.my_3() .my_3()
.mx_5() .mx_5()
.when(is_generating_stale || message.is_hidden, |this| { .when(is_generating_stale, |this| {
this.child(AnimatedLabel::new("").size(LabelSize::Small)) this.child(AnimatedLabel::new("").size(LabelSize::Small))
}) })
}); });
if message.is_hidden {
return div().children(loading_dots).into_any();
}
let Some(rendered_message) = self.rendered_messages_by_id.get(&message_id) else { let Some(rendered_message) = self.rendered_messages_by_id.get(&message_id) else {
return Empty.into_any(); return Empty.into_any();
}; };
// Get all the data we need from thread before we start using it in closures // Get all the data we need from thread before we start using it in closures
let checkpoint = thread.checkpoint_for_message(message_id); let checkpoint = agent.checkpoint_for_message(message_id);
let configured_model = thread.configured_model().map(|m| m.model); let configured_model = agent.configured_model().map(|m| m.model);
let added_context = thread let added_context = agent
.context_for_message(message_id) .context_for_message(message_id)
.map(|context| AddedContext::new_attached(context, configured_model.as_ref(), cx)) .map(|context| AddedContext::new_attached(context, configured_model.as_ref(), cx))
.collect::<Vec<_>>(); .collect::<Vec<_>>();
let tool_uses = thread.tool_uses_for_message(message_id, cx); // let tool_uses = message.segments
let tool_uses = agent.tool_uses_for_message(message_id, cx);
let has_tool_uses = !tool_uses.is_empty(); let has_tool_uses = !tool_uses.is_empty();
let editing_message_state = self let editing_message_state = self
@@ -1856,11 +1947,11 @@ impl ActiveThread {
.icon_color(Color::Ignored) .icon_color(Color::Ignored)
.tooltip(Tooltip::text("Open Thread as Markdown")) .tooltip(Tooltip::text("Open Thread as Markdown"))
.on_click({ .on_click({
let thread = self.thread.clone(); let agent = self.agent.clone();
let workspace = self.workspace.clone(); let workspace = self.workspace.clone();
move |_, window, cx| { move |_, window, cx| {
if let Some(workspace) = workspace.upgrade() { if let Some(workspace) = workspace.upgrade() {
open_active_thread_as_markdown(thread.clone(), workspace, window, cx) open_active_thread_as_markdown(agent.clone(), workspace, window, cx)
.detach_and_log_err(cx); .detach_and_log_err(cx);
} }
} }
@@ -1877,7 +1968,7 @@ impl ActiveThread {
// For all items that should be aligned with the LLM's response. // For all items that should be aligned with the LLM's response.
const RESPONSE_PADDING_X: Pixels = px(19.); const RESPONSE_PADDING_X: Pixels = px(19.);
let show_feedback = thread.is_turn_end(ix); let show_feedback = self.agent.read(cx).is_turn_end(ix);
let feedback_container = h_flex() let feedback_container = h_flex()
.group("feedback_container") .group("feedback_container")
.mt_1() .mt_1()
@@ -1889,7 +1980,7 @@ impl ActiveThread {
.gap_1p5() .gap_1p5()
.flex_wrap() .flex_wrap()
.justify_end(); .justify_end();
let feedback_items = match self.thread.read(cx).message_feedback(message_id) { let feedback_items = match self.agent.read(cx).message_feedback(message_id) {
Some(feedback) => feedback_container Some(feedback) => feedback_container
.child( .child(
div().visible_on_hover("feedback_container").child( div().visible_on_hover("feedback_container").child(
@@ -1995,6 +2086,9 @@ impl ActiveThread {
}; };
let message_is_empty = message.should_display_content(); let message_is_empty = message.should_display_content();
let message_is_ui_only = message.ui_only;
let message_creases = message.creases.clone();
let role = message.role;
let has_content = !message_is_empty || !added_context.is_empty(); let has_content = !message_is_empty || !added_context.is_empty();
let message_content = has_content.then(|| { let message_content = has_content.then(|| {
@@ -2037,10 +2131,10 @@ impl ActiveThread {
} }
}); });
let styled_message = if message.ui_only { let styled_message = if message_is_ui_only {
self.render_ui_notification(message_content, ix, cx) self.render_ui_notification(message_content, ix, cx)
} else { } else {
match message.role { match role {
Role::User => { Role::User => {
let colors = cx.theme().colors(); let colors = cx.theme().colors();
v_flex() v_flex()
@@ -2145,10 +2239,9 @@ impl ActiveThread {
}), }),
) )
.on_click(cx.listener({ .on_click(cx.listener({
let message_creases = message.creases.clone();
move |this, _, window, cx| { move |this, _, window, cx| {
if let Some(message_text) = if let Some(message_text) =
this.thread.read(cx).message(message_id).and_then(|message| { this.agent.read(cx).message(message_id).and_then(|message| {
message.segments.first().and_then(|segment| { message.segments.first().and_then(|segment| {
match segment { match segment {
MessageSegment::Text(message_text) => { MessageSegment::Text(message_text) => {
@@ -2219,7 +2312,7 @@ impl ActiveThread {
let mut is_pending = false; let mut is_pending = false;
let mut error = None; let mut error = None;
if let Some(last_restore_checkpoint) = if let Some(last_restore_checkpoint) =
self.thread.read(cx).last_restore_checkpoint() self.agent.read(cx).last_restore_checkpoint()
{ {
if last_restore_checkpoint.message_id() == message_id { if last_restore_checkpoint.message_id() == message_id {
match last_restore_checkpoint { match last_restore_checkpoint {
@@ -2248,7 +2341,7 @@ impl ActiveThread {
.label_size(LabelSize::XSmall) .label_size(LabelSize::XSmall)
.disabled(is_pending) .disabled(is_pending)
.on_click(cx.listener(move |this, _, _window, cx| { .on_click(cx.listener(move |this, _, _window, cx| {
this.thread.update(cx, |thread, cx| { this.agent.update(cx, |thread, cx| {
thread thread
.restore_checkpoint(checkpoint.clone(), cx) .restore_checkpoint(checkpoint.clone(), cx)
.detach_and_log_err(cx); .detach_and_log_err(cx);
@@ -2382,11 +2475,11 @@ impl ActiveThread {
rendered_message: &RenderedMessage, rendered_message: &RenderedMessage,
has_tool_uses: bool, has_tool_uses: bool,
workspace: WeakEntity<Workspace>, workspace: WeakEntity<Workspace>,
window: &Window, window: &mut Window,
cx: &Context<Self>, cx: &mut Context<Self>,
) -> impl IntoElement { ) -> impl IntoElement {
let is_last_message = self.messages.last() == Some(&message_id); let is_last_message = self.messages.last() == Some(&message_id);
let is_generating = self.thread.read(cx).is_generating(); let is_generating = self.agent.read(cx).is_generating();
let pending_thinking_segment_index = if is_generating && is_last_message && !has_tool_uses { let pending_thinking_segment_index = if is_generating && is_last_message && !has_tool_uses {
rendered_message rendered_message
.segments .segments
@@ -2400,7 +2493,7 @@ impl ActiveThread {
}; };
let message_role = self let message_role = self
.thread .agent
.read(cx) .read(cx)
.message(message_id) .message(message_id)
.map(|m| m.role) .map(|m| m.role)
@@ -2515,6 +2608,23 @@ impl ActiveThread {
})) }))
.into_any_element() .into_any_element()
} }
RenderedMessageSegment::ToolUseCard(status, card) => {
card.render(status, window, workspace.clone(), cx)
}
RenderedMessageSegment::ToolUseMarkdown(rendered) => v_flex()
.child(MarkdownElement::new(
rendered.label.clone(),
default_markdown_style(window, cx),
))
.child(MarkdownElement::new(
rendered.input.clone(),
default_markdown_style(window, cx),
))
.child(MarkdownElement::new(
rendered.output.clone(),
default_markdown_style(window, cx),
))
.into_any(), // todo!()
}, },
), ),
) )
@@ -2784,7 +2894,7 @@ impl ActiveThread {
workspace: WeakEntity<Workspace>, workspace: WeakEntity<Workspace>,
cx: &mut Context<Self>, cx: &mut Context<Self>,
) -> impl IntoElement + use<> { ) -> impl IntoElement + use<> {
if let Some(card) = self.thread.read(cx).card_for_tool(&tool_use.id) { if let Some(card) = self.agent.read(cx).card_for_tool(&tool_use.id) {
return card.render(&tool_use.status, window, workspace, cx); return card.render(&tool_use.status, window, workspace, cx);
} }
@@ -3265,7 +3375,7 @@ impl ActiveThread {
} }
fn render_rules_item(&self, cx: &Context<Self>) -> AnyElement { fn render_rules_item(&self, cx: &Context<Self>) -> AnyElement {
let project_context = self.thread.read(cx).project_context(); let project_context = self.agent.read(cx).project_context();
let project_context = project_context.borrow(); let project_context = project_context.borrow();
let Some(project_context) = project_context.as_ref() else { let Some(project_context) = project_context.as_ref() else {
return div().into_any(); return div().into_any();
@@ -3389,12 +3499,12 @@ impl ActiveThread {
cx: &mut Context<Self>, cx: &mut Context<Self>,
) { ) {
if let Some(PendingToolUseStatus::NeedsConfirmation(c)) = self if let Some(PendingToolUseStatus::NeedsConfirmation(c)) = self
.thread .agent
.read(cx) .read(cx)
.pending_tool(&tool_use_id) .pending_tool(&tool_use_id)
.map(|tool_use| tool_use.status.clone()) .map(|tool_use| tool_use.status.clone())
{ {
self.thread.update(cx, |thread, cx| { self.agent.update(cx, |thread, cx| {
if let Some(configured) = thread.get_or_init_configured_model(cx) { if let Some(configured) = thread.get_or_init_configured_model(cx) {
thread.run_tool( thread.run_tool(
c.tool_use_id.clone(), c.tool_use_id.clone(),
@@ -3420,13 +3530,13 @@ impl ActiveThread {
cx: &mut Context<Self>, cx: &mut Context<Self>,
) { ) {
let window_handle = window.window_handle(); let window_handle = window.window_handle();
self.thread.update(cx, |thread, cx| { self.agent.update(cx, |thread, cx| {
thread.deny_tool_use(tool_use_id, tool_name, Some(window_handle), cx); thread.deny_tool_use(tool_use_id, tool_name, Some(window_handle), cx);
}); });
} }
fn handle_open_rules(&mut self, _: &ClickEvent, window: &mut Window, cx: &mut Context<Self>) { fn handle_open_rules(&mut self, _: &ClickEvent, window: &mut Window, cx: &mut Context<Self>) {
let project_context = self.thread.read(cx).project_context(); let project_context = self.agent.read(cx).project_context();
let project_context = project_context.borrow(); let project_context = project_context.borrow();
let Some(project_context) = project_context.as_ref() else { let Some(project_context) = project_context.as_ref() else {
return; return;
@@ -3588,7 +3698,7 @@ impl Render for ActiveThread {
} }
pub(crate) fn open_active_thread_as_markdown( pub(crate) fn open_active_thread_as_markdown(
thread: Entity<Thread>, agent: Entity<ZedAgentThread>,
workspace: Entity<Workspace>, workspace: Entity<Workspace>,
window: &mut Window, window: &mut Window,
cx: &mut App, cx: &mut App,
@@ -3603,7 +3713,7 @@ pub(crate) fn open_active_thread_as_markdown(
let markdown_language = markdown_language_task.await?; let markdown_language = markdown_language_task.await?;
workspace.update_in(cx, |workspace, window, cx| { workspace.update_in(cx, |workspace, window, cx| {
let thread = thread.read(cx); let thread = agent.read(cx);
let markdown = thread.to_markdown(cx)?; let markdown = thread.to_markdown(cx)?;
let thread_summary = thread.summary().or_default().to_string(); let thread_summary = thread.summary().or_default().to_string();
@@ -3692,7 +3802,7 @@ pub(crate) fn open_context(
AgentContextHandle::Thread(thread_context) => workspace.update(cx, |workspace, cx| { AgentContextHandle::Thread(thread_context) => workspace.update(cx, |workspace, cx| {
if let Some(panel) = workspace.panel::<AgentPanel>(cx) { if let Some(panel) = workspace.panel::<AgentPanel>(cx) {
panel.update(cx, |panel, cx| { panel.update(cx, |panel, cx| {
panel.open_thread(thread_context.thread.clone(), window, cx); panel.open_thread(thread_context.agent.clone(), window, cx);
}); });
} }
}), }),
@@ -3779,7 +3889,9 @@ fn open_editor_at_position(
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use agent::{MessageSegment, context::ContextLoadResult, thread_store}; use agent::{
MessageSegment, context::ContextLoadResult, thread::UserMessageParams, thread_store,
};
use assistant_tool::{ToolRegistry, ToolWorkingSet}; use assistant_tool::{ToolRegistry, ToolWorkingSet};
use editor::EditorSettings; use editor::EditorSettings;
use fs::FakeFs; use fs::FakeFs;
@@ -3794,6 +3906,7 @@ mod tests {
use settings::SettingsStore; use settings::SettingsStore;
use util::path; use util::path;
use workspace::CollaboratorId; use workspace::CollaboratorId;
use zed_llm_client::CompletionIntent;
#[gpui::test] #[gpui::test]
async fn test_agent_is_unfollowed_after_cancelling_completion(cx: &mut TestAppContext) { async fn test_agent_is_unfollowed_after_cancelling_completion(cx: &mut TestAppContext) {
@@ -3810,13 +3923,12 @@ mod tests {
// Insert user message without any context (empty context vector) // Insert user message without any context (empty context vector)
thread.update(cx, |thread, cx| { thread.update(cx, |thread, cx| {
thread.insert_user_message( thread.send_message(
"What is the best way to learn Rust?", "What is the best way to learn Rust?",
ContextLoadResult::default(), model.clone(),
None, None,
vec![],
cx, cx,
); )
}); });
// Stream response to user message // Stream response to user message
@@ -3857,7 +3969,7 @@ mod tests {
registry.set_default_model( registry.set_default_model(
Some(ConfiguredModel { Some(ConfiguredModel {
provider: Arc::new(FakeLanguageModelProvider), provider: Arc::new(FakeLanguageModelProvider),
model, model: model.clone(),
}), }),
cx, cx,
); );
@@ -3871,15 +3983,19 @@ mod tests {
context: None, context: None,
}]; }];
let message = thread.update(cx, |thread, cx| { let message = thread.update(cx, |agent, cx| {
let message_id = thread.insert_user_message( let message_id = agent.send_message(
"Tell me about @foo.txt", UserMessageParams {
ContextLoadResult::default(), text: "Tell me about @foo.txt".to_string(),
creases,
checkpoint: None,
context: ContextLoadResult::default(),
},
model.clone(),
None, None,
creases,
cx, cx,
); );
thread.message(message_id).cloned().unwrap() agent.message(message_id).cloned().unwrap()
}); });
active_thread.update_in(cx, |active_thread, window, cx| { active_thread.update_in(cx, |active_thread, window, cx| {
@@ -3971,20 +4087,8 @@ mod tests {
// Insert a user message and start streaming a response // Insert a user message and start streaming a response
let message = thread.update(cx, |thread, cx| { let message = thread.update(cx, |thread, cx| {
let message_id = thread.insert_user_message( let message_id =
"Hello, how are you?", thread.send_message("Hello, how are you?", model.clone(), cx.active_window(), cx);
ContextLoadResult::default(),
None,
vec![],
cx,
);
thread.advance_prompt_id();
thread.send_to_model(
model.clone(),
CompletionIntent::UserPrompt,
cx.active_window(),
cx,
);
thread.message(message_id).cloned().unwrap() thread.message(message_id).cloned().unwrap()
}); });
@@ -4071,7 +4175,7 @@ mod tests {
&mut VisualTestContext, &mut VisualTestContext,
Entity<ActiveThread>, Entity<ActiveThread>,
Entity<Workspace>, Entity<Workspace>,
Entity<Thread>, Entity<ZedAgentThread>,
Arc<dyn LanguageModel>, Arc<dyn LanguageModel>,
) { ) {
let (workspace, cx) = let (workspace, cx) =

View File

@@ -1,7 +1,8 @@
use crate::{Keep, KeepAll, OpenAgentDiff, Reject, RejectAll}; use crate::{Keep, KeepAll, OpenAgentDiff, Reject, RejectAll};
use agent::{Thread, ThreadEvent}; use agent::{ThreadEvent, ZedAgentThread};
use agent_settings::AgentSettings; use agent_settings::AgentSettings;
use anyhow::Result; use anyhow::Result;
use assistant_tool::ActionLog;
use buffer_diff::DiffHunkStatus; use buffer_diff::DiffHunkStatus;
use collections::{HashMap, HashSet}; use collections::{HashMap, HashSet};
use editor::{ use editor::{
@@ -41,7 +42,8 @@ use zed_actions::assistant::ToggleFocus;
pub struct AgentDiffPane { pub struct AgentDiffPane {
multibuffer: Entity<MultiBuffer>, multibuffer: Entity<MultiBuffer>,
editor: Entity<Editor>, editor: Entity<Editor>,
thread: Entity<Thread>, agent: Entity<ZedAgentThread>,
action_log: Entity<ActionLog>,
focus_handle: FocusHandle, focus_handle: FocusHandle,
workspace: WeakEntity<Workspace>, workspace: WeakEntity<Workspace>,
title: SharedString, title: SharedString,
@@ -50,70 +52,71 @@ pub struct AgentDiffPane {
impl AgentDiffPane { impl AgentDiffPane {
pub fn deploy( pub fn deploy(
thread: Entity<Thread>, agent: Entity<ZedAgentThread>,
workspace: WeakEntity<Workspace>, workspace: WeakEntity<Workspace>,
window: &mut Window, window: &mut Window,
cx: &mut App, cx: &mut App,
) -> Result<Entity<Self>> { ) -> Result<Entity<Self>> {
workspace.update(cx, |workspace, cx| { workspace.update(cx, |workspace, cx| {
Self::deploy_in_workspace(thread, workspace, window, cx) Self::deploy_in_workspace(agent, workspace, window, cx)
}) })
} }
pub fn deploy_in_workspace( pub fn deploy_in_workspace(
thread: Entity<Thread>, agent: Entity<ZedAgentThread>,
workspace: &mut Workspace, workspace: &mut Workspace,
window: &mut Window, window: &mut Window,
cx: &mut Context<Workspace>, cx: &mut Context<Workspace>,
) -> Entity<Self> { ) -> Entity<Self> {
let existing_diff = workspace let existing_diff = workspace
.items_of_type::<AgentDiffPane>(cx) .items_of_type::<AgentDiffPane>(cx)
.find(|diff| diff.read(cx).thread == thread); .find(|diff| diff.read(cx).agent == agent);
if let Some(existing_diff) = existing_diff { if let Some(existing_diff) = existing_diff {
workspace.activate_item(&existing_diff, true, true, window, cx); workspace.activate_item(&existing_diff, true, true, window, cx);
existing_diff existing_diff
} else { } else {
let agent_diff = cx let agent_diff =
.new(|cx| AgentDiffPane::new(thread.clone(), workspace.weak_handle(), window, cx)); cx.new(|cx| AgentDiffPane::new(agent.clone(), workspace.weak_handle(), window, cx));
workspace.add_item_to_center(Box::new(agent_diff.clone()), window, cx); workspace.add_item_to_center(Box::new(agent_diff.clone()), window, cx);
agent_diff agent_diff
} }
} }
pub fn new( pub fn new(
thread: Entity<Thread>, agent: Entity<ZedAgentThread>,
workspace: WeakEntity<Workspace>, workspace: WeakEntity<Workspace>,
window: &mut Window, window: &mut Window,
cx: &mut Context<Self>, cx: &mut Context<Self>,
) -> Self { ) -> Self {
let focus_handle = cx.focus_handle(); let focus_handle = cx.focus_handle();
let multibuffer = cx.new(|_| MultiBuffer::new(Capability::ReadWrite)); let multibuffer = cx.new(|_| MultiBuffer::new(Capability::ReadWrite));
let action_log = agent.read(cx).action_log();
let project = agent.read(cx).project().clone();
let project = thread.read(cx).project().clone();
let editor = cx.new(|cx| { let editor = cx.new(|cx| {
let mut editor = let mut editor =
Editor::for_multibuffer(multibuffer.clone(), Some(project.clone()), window, cx); Editor::for_multibuffer(multibuffer.clone(), Some(project.clone()), window, cx);
editor.disable_inline_diagnostics(); editor.disable_inline_diagnostics();
editor.set_expand_all_diff_hunks(cx); editor.set_expand_all_diff_hunks(cx);
editor.set_render_diff_hunk_controls(diff_hunk_controls(&thread), cx); editor.set_render_diff_hunk_controls(diff_hunk_controls(&action_log), cx);
editor.register_addon(AgentDiffAddon); editor.register_addon(AgentDiffAddon);
editor editor
}); });
let action_log = thread.read(cx).action_log().clone();
let mut this = Self { let mut this = Self {
_subscriptions: vec![ _subscriptions: vec![
cx.observe_in(&action_log, window, |this, _action_log, window, cx| { cx.observe_in(&action_log, window, |this, _action_log, window, cx| {
this.update_excerpts(window, cx) this.update_excerpts(window, cx)
}), }),
cx.subscribe(&thread, |this, _thread, event, cx| { cx.subscribe(&agent, |this, _thread, event, cx| {
this.handle_thread_event(event, cx) this.handle_thread_event(event, cx)
}), }),
], ],
title: SharedString::default(), title: SharedString::default(),
action_log,
multibuffer, multibuffer,
editor, editor,
thread, agent,
focus_handle, focus_handle,
workspace, workspace,
}; };
@@ -123,8 +126,8 @@ impl AgentDiffPane {
} }
fn update_excerpts(&mut self, window: &mut Window, cx: &mut Context<Self>) { fn update_excerpts(&mut self, window: &mut Window, cx: &mut Context<Self>) {
let thread = self.thread.read(cx); let agent = self.agent.read(cx);
let changed_buffers = thread.action_log().read(cx).changed_buffers(cx); let changed_buffers = agent.action_log().read(cx).changed_buffers(cx);
let mut paths_to_delete = self.multibuffer.read(cx).paths().collect::<HashSet<_>>(); let mut paths_to_delete = self.multibuffer.read(cx).paths().collect::<HashSet<_>>();
for (buffer, diff_handle) in changed_buffers { for (buffer, diff_handle) in changed_buffers {
@@ -211,7 +214,7 @@ impl AgentDiffPane {
} }
fn update_title(&mut self, cx: &mut Context<Self>) { fn update_title(&mut self, cx: &mut Context<Self>) {
let new_title = self.thread.read(cx).summary().unwrap_or("Agent Changes"); let new_title = self.agent.read(cx).summary().unwrap_or("Agent Changes");
if new_title != self.title { if new_title != self.title {
self.title = new_title; self.title = new_title;
cx.emit(EditorEvent::TitleChanged); cx.emit(EditorEvent::TitleChanged);
@@ -248,14 +251,14 @@ impl AgentDiffPane {
fn keep(&mut self, _: &Keep, window: &mut Window, cx: &mut Context<Self>) { fn keep(&mut self, _: &Keep, window: &mut Window, cx: &mut Context<Self>) {
self.editor.update(cx, |editor, cx| { self.editor.update(cx, |editor, cx| {
let snapshot = editor.buffer().read(cx).snapshot(cx); let snapshot = editor.buffer().read(cx).snapshot(cx);
keep_edits_in_selection(editor, &snapshot, &self.thread, window, cx); keep_edits_in_selection(editor, &snapshot, &self.action_log, window, cx);
}); });
} }
fn reject(&mut self, _: &Reject, window: &mut Window, cx: &mut Context<Self>) { fn reject(&mut self, _: &Reject, window: &mut Window, cx: &mut Context<Self>) {
self.editor.update(cx, |editor, cx| { self.editor.update(cx, |editor, cx| {
let snapshot = editor.buffer().read(cx).snapshot(cx); let snapshot = editor.buffer().read(cx).snapshot(cx);
reject_edits_in_selection(editor, &snapshot, &self.thread, window, cx); reject_edits_in_selection(editor, &snapshot, &self.action_log, window, cx);
}); });
} }
@@ -265,7 +268,7 @@ impl AgentDiffPane {
reject_edits_in_ranges( reject_edits_in_ranges(
editor, editor,
&snapshot, &snapshot,
&self.thread, &self.action_log,
vec![editor::Anchor::min()..editor::Anchor::max()], vec![editor::Anchor::min()..editor::Anchor::max()],
window, window,
cx, cx,
@@ -274,15 +277,15 @@ impl AgentDiffPane {
} }
fn keep_all(&mut self, _: &KeepAll, _window: &mut Window, cx: &mut Context<Self>) { fn keep_all(&mut self, _: &KeepAll, _window: &mut Window, cx: &mut Context<Self>) {
self.thread self.action_log
.update(cx, |thread, cx| thread.keep_all_edits(cx)); .update(cx, |action_log, cx| action_log.keep_all_edits(cx));
} }
} }
fn keep_edits_in_selection( fn keep_edits_in_selection(
editor: &mut Editor, editor: &mut Editor,
buffer_snapshot: &MultiBufferSnapshot, buffer_snapshot: &MultiBufferSnapshot,
thread: &Entity<Thread>, action_log: &Entity<ActionLog>,
window: &mut Window, window: &mut Window,
cx: &mut Context<Editor>, cx: &mut Context<Editor>,
) { ) {
@@ -291,13 +294,13 @@ fn keep_edits_in_selection(
.disjoint_anchor_ranges() .disjoint_anchor_ranges()
.collect::<Vec<_>>(); .collect::<Vec<_>>();
keep_edits_in_ranges(editor, buffer_snapshot, &thread, ranges, window, cx) keep_edits_in_ranges(editor, buffer_snapshot, &action_log, ranges, window, cx)
} }
fn reject_edits_in_selection( fn reject_edits_in_selection(
editor: &mut Editor, editor: &mut Editor,
buffer_snapshot: &MultiBufferSnapshot, buffer_snapshot: &MultiBufferSnapshot,
thread: &Entity<Thread>, action_log: &Entity<ActionLog>,
window: &mut Window, window: &mut Window,
cx: &mut Context<Editor>, cx: &mut Context<Editor>,
) { ) {
@@ -305,13 +308,13 @@ fn reject_edits_in_selection(
.selections .selections
.disjoint_anchor_ranges() .disjoint_anchor_ranges()
.collect::<Vec<_>>(); .collect::<Vec<_>>();
reject_edits_in_ranges(editor, buffer_snapshot, &thread, ranges, window, cx) reject_edits_in_ranges(editor, buffer_snapshot, &action_log, ranges, window, cx)
} }
fn keep_edits_in_ranges( fn keep_edits_in_ranges(
editor: &mut Editor, editor: &mut Editor,
buffer_snapshot: &MultiBufferSnapshot, buffer_snapshot: &MultiBufferSnapshot,
thread: &Entity<Thread>, action_log: &Entity<ActionLog>,
ranges: Vec<Range<editor::Anchor>>, ranges: Vec<Range<editor::Anchor>>,
window: &mut Window, window: &mut Window,
cx: &mut Context<Editor>, cx: &mut Context<Editor>,
@@ -326,8 +329,8 @@ fn keep_edits_in_ranges(
for hunk in &diff_hunks_in_ranges { for hunk in &diff_hunks_in_ranges {
let buffer = multibuffer.read(cx).buffer(hunk.buffer_id); let buffer = multibuffer.read(cx).buffer(hunk.buffer_id);
if let Some(buffer) = buffer { if let Some(buffer) = buffer {
thread.update(cx, |thread, cx| { action_log.update(cx, |action_log, cx| {
thread.keep_edits_in_range(buffer, hunk.buffer_range.clone(), cx) action_log.keep_edits_in_range(buffer, hunk.buffer_range.clone(), cx)
}); });
} }
} }
@@ -336,7 +339,7 @@ fn keep_edits_in_ranges(
fn reject_edits_in_ranges( fn reject_edits_in_ranges(
editor: &mut Editor, editor: &mut Editor,
buffer_snapshot: &MultiBufferSnapshot, buffer_snapshot: &MultiBufferSnapshot,
thread: &Entity<Thread>, action_log: &Entity<ActionLog>,
ranges: Vec<Range<editor::Anchor>>, ranges: Vec<Range<editor::Anchor>>,
window: &mut Window, window: &mut Window,
cx: &mut Context<Editor>, cx: &mut Context<Editor>,
@@ -361,9 +364,9 @@ fn reject_edits_in_ranges(
} }
for (buffer, ranges) in ranges_by_buffer { for (buffer, ranges) in ranges_by_buffer {
thread action_log
.update(cx, |thread, cx| { .update(cx, |action_log, cx| {
thread.reject_edits_in_ranges(buffer, ranges, cx) action_log.reject_edits_in_ranges(buffer, ranges, cx)
}) })
.detach_and_log_err(cx); .detach_and_log_err(cx);
} }
@@ -461,7 +464,7 @@ impl Item for AgentDiffPane {
} }
fn tab_content(&self, params: TabContentParams, _window: &Window, cx: &App) -> AnyElement { fn tab_content(&self, params: TabContentParams, _window: &Window, cx: &App) -> AnyElement {
let summary = self.thread.read(cx).summary().unwrap_or("Agent Changes"); let summary = self.agent.read(cx).summary().or_default();
Label::new(format!("Review: {}", summary)) Label::new(format!("Review: {}", summary))
.color(if params.selected { .color(if params.selected {
Color::Default Color::Default
@@ -511,7 +514,7 @@ impl Item for AgentDiffPane {
where where
Self: Sized, Self: Sized,
{ {
Some(cx.new(|cx| Self::new(self.thread.clone(), self.workspace.clone(), window, cx))) Some(cx.new(|cx| Self::new(self.agent.clone(), self.workspace.clone(), window, cx)))
} }
fn is_dirty(&self, cx: &App) -> bool { fn is_dirty(&self, cx: &App) -> bool {
@@ -641,8 +644,8 @@ impl Render for AgentDiffPane {
} }
} }
fn diff_hunk_controls(thread: &Entity<Thread>) -> editor::RenderDiffHunkControlsFn { fn diff_hunk_controls(action_log: &Entity<ActionLog>) -> editor::RenderDiffHunkControlsFn {
let thread = thread.clone(); let action_log = action_log.clone();
Arc::new( Arc::new(
move |row, move |row,
@@ -660,7 +663,7 @@ fn diff_hunk_controls(thread: &Entity<Thread>) -> editor::RenderDiffHunkControls
hunk_range, hunk_range,
is_created_file, is_created_file,
line_height, line_height,
&thread, &action_log,
editor, editor,
window, window,
cx, cx,
@@ -676,7 +679,7 @@ fn render_diff_hunk_controls(
hunk_range: Range<editor::Anchor>, hunk_range: Range<editor::Anchor>,
is_created_file: bool, is_created_file: bool,
line_height: Pixels, line_height: Pixels,
thread: &Entity<Thread>, action_log: &Entity<ActionLog>,
editor: &Entity<Editor>, editor: &Entity<Editor>,
window: &mut Window, window: &mut Window,
cx: &mut App, cx: &mut App,
@@ -711,14 +714,14 @@ fn render_diff_hunk_controls(
) )
.on_click({ .on_click({
let editor = editor.clone(); let editor = editor.clone();
let thread = thread.clone(); let action_log = action_log.clone();
move |_event, window, cx| { move |_event, window, cx| {
editor.update(cx, |editor, cx| { editor.update(cx, |editor, cx| {
let snapshot = editor.buffer().read(cx).snapshot(cx); let snapshot = editor.buffer().read(cx).snapshot(cx);
reject_edits_in_ranges( reject_edits_in_ranges(
editor, editor,
&snapshot, &snapshot,
&thread, &action_log,
vec![hunk_range.start..hunk_range.start], vec![hunk_range.start..hunk_range.start],
window, window,
cx, cx,
@@ -733,14 +736,14 @@ fn render_diff_hunk_controls(
) )
.on_click({ .on_click({
let editor = editor.clone(); let editor = editor.clone();
let thread = thread.clone(); let action_log = action_log.clone();
move |_event, window, cx| { move |_event, window, cx| {
editor.update(cx, |editor, cx| { editor.update(cx, |editor, cx| {
let snapshot = editor.buffer().read(cx).snapshot(cx); let snapshot = editor.buffer().read(cx).snapshot(cx);
keep_edits_in_ranges( keep_edits_in_ranges(
editor, editor,
&snapshot, &snapshot,
&thread, &action_log,
vec![hunk_range.start..hunk_range.start], vec![hunk_range.start..hunk_range.start],
window, window,
cx, cx,
@@ -1114,7 +1117,7 @@ impl Render for AgentDiffToolbar {
let has_pending_edit_tool_use = agent_diff let has_pending_edit_tool_use = agent_diff
.read(cx) .read(cx)
.thread .agent
.read(cx) .read(cx)
.has_pending_edit_tool_uses(); .has_pending_edit_tool_uses();
@@ -1187,7 +1190,7 @@ pub enum EditorState {
} }
struct WorkspaceThread { struct WorkspaceThread {
thread: WeakEntity<Thread>, agent: WeakEntity<ZedAgentThread>,
_thread_subscriptions: [Subscription; 2], _thread_subscriptions: [Subscription; 2],
singleton_editors: HashMap<WeakEntity<Buffer>, HashMap<WeakEntity<Editor>, Subscription>>, singleton_editors: HashMap<WeakEntity<Buffer>, HashMap<WeakEntity<Editor>, Subscription>>,
_settings_subscription: Subscription, _settings_subscription: Subscription,
@@ -1212,7 +1215,7 @@ impl AgentDiff {
pub fn set_active_thread( pub fn set_active_thread(
workspace: &WeakEntity<Workspace>, workspace: &WeakEntity<Workspace>,
thread: &Entity<Thread>, thread: &Entity<ZedAgentThread>,
window: &mut Window, window: &mut Window,
cx: &mut App, cx: &mut App,
) { ) {
@@ -1224,11 +1227,11 @@ impl AgentDiff {
fn register_active_thread_impl( fn register_active_thread_impl(
&mut self, &mut self,
workspace: &WeakEntity<Workspace>, workspace: &WeakEntity<Workspace>,
thread: &Entity<Thread>, agent: &Entity<ZedAgentThread>,
window: &mut Window, window: &mut Window,
cx: &mut Context<Self>, cx: &mut Context<Self>,
) { ) {
let action_log = thread.read(cx).action_log().clone(); let action_log = agent.read(cx).action_log().clone();
let action_log_subscription = cx.observe_in(&action_log, window, { let action_log_subscription = cx.observe_in(&action_log, window, {
let workspace = workspace.clone(); let workspace = workspace.clone();
@@ -1237,7 +1240,7 @@ impl AgentDiff {
} }
}); });
let thread_subscription = cx.subscribe_in(&thread, window, { let thread_subscription = cx.subscribe_in(&agent, window, {
let workspace = workspace.clone(); let workspace = workspace.clone();
move |this, _thread, event, window, cx| { move |this, _thread, event, window, cx| {
this.handle_thread_event(&workspace, event, window, cx) this.handle_thread_event(&workspace, event, window, cx)
@@ -1246,7 +1249,7 @@ impl AgentDiff {
if let Some(workspace_thread) = self.workspace_threads.get_mut(&workspace) { if let Some(workspace_thread) = self.workspace_threads.get_mut(&workspace) {
// replace thread and action log subscription, but keep editors // replace thread and action log subscription, but keep editors
workspace_thread.thread = thread.downgrade(); workspace_thread.agent = agent.downgrade();
workspace_thread._thread_subscriptions = [action_log_subscription, thread_subscription]; workspace_thread._thread_subscriptions = [action_log_subscription, thread_subscription];
self.update_reviewing_editors(&workspace, window, cx); self.update_reviewing_editors(&workspace, window, cx);
return; return;
@@ -1271,7 +1274,7 @@ impl AgentDiff {
self.workspace_threads.insert( self.workspace_threads.insert(
workspace.clone(), workspace.clone(),
WorkspaceThread { WorkspaceThread {
thread: thread.downgrade(), agent: agent.downgrade(),
_thread_subscriptions: [action_log_subscription, thread_subscription], _thread_subscriptions: [action_log_subscription, thread_subscription],
singleton_editors: HashMap::default(), singleton_editors: HashMap::default(),
_settings_subscription: settings_subscription, _settings_subscription: settings_subscription,
@@ -1319,7 +1322,7 @@ impl AgentDiff {
fn register_review_action<T: Action>( fn register_review_action<T: Action>(
workspace: &mut Workspace, workspace: &mut Workspace,
review: impl Fn(&Entity<Editor>, &Entity<Thread>, &mut Window, &mut App) -> PostReviewState review: impl Fn(&Entity<Editor>, &Entity<ZedAgentThread>, &mut Window, &mut App) -> PostReviewState
+ 'static, + 'static,
this: &Entity<AgentDiff>, this: &Entity<AgentDiff>,
) { ) {
@@ -1362,6 +1365,7 @@ impl AgentDiff {
| ThreadEvent::StreamedAssistantText(_, _) | ThreadEvent::StreamedAssistantText(_, _)
| ThreadEvent::StreamedAssistantThinking(_, _) | ThreadEvent::StreamedAssistantThinking(_, _)
| ThreadEvent::StreamedToolUse { .. } | ThreadEvent::StreamedToolUse { .. }
| ThreadEvent::StreamedToolUse2 { .. }
| ThreadEvent::InvalidToolInput { .. } | ThreadEvent::InvalidToolInput { .. }
| ThreadEvent::MissingToolUse { .. } | ThreadEvent::MissingToolUse { .. }
| ThreadEvent::MessageAdded(_) | ThreadEvent::MessageAdded(_)
@@ -1481,11 +1485,11 @@ impl AgentDiff {
return; return;
}; };
let Some(thread) = workspace_thread.thread.upgrade() else { let Some(agent) = workspace_thread.agent.upgrade() else {
return; return;
}; };
let action_log = thread.read(cx).action_log(); let action_log = agent.read(cx).action_log();
let changed_buffers = action_log.read(cx).changed_buffers(cx); let changed_buffers = action_log.read(cx).changed_buffers(cx);
let mut unaffected = self.reviewing_editors.clone(); let mut unaffected = self.reviewing_editors.clone();
@@ -1510,7 +1514,7 @@ impl AgentDiff {
multibuffer.add_diff(diff_handle.clone(), cx); multibuffer.add_diff(diff_handle.clone(), cx);
}); });
let new_state = if thread.read(cx).is_generating() { let new_state = if agent.read(cx).is_generating() {
EditorState::Generating EditorState::Generating
} else { } else {
EditorState::Reviewing EditorState::Reviewing
@@ -1523,7 +1527,7 @@ impl AgentDiff {
if previous_state.is_none() { if previous_state.is_none() {
editor.update(cx, |editor, cx| { editor.update(cx, |editor, cx| {
editor.start_temporary_diff_override(); editor.start_temporary_diff_override();
editor.set_render_diff_hunk_controls(diff_hunk_controls(&thread), cx); editor.set_render_diff_hunk_controls(diff_hunk_controls(&action_log), cx);
editor.set_expand_all_diff_hunks(cx); editor.set_expand_all_diff_hunks(cx);
editor.register_addon(EditorAgentDiffAddon); editor.register_addon(EditorAgentDiffAddon);
}); });
@@ -1591,22 +1595,22 @@ impl AgentDiff {
return; return;
}; };
let Some(WorkspaceThread { thread, .. }) = let Some(WorkspaceThread { agent, .. }) =
self.workspace_threads.get(&workspace.downgrade()) self.workspace_threads.get(&workspace.downgrade())
else { else {
return; return;
}; };
let Some(thread) = thread.upgrade() else { let Some(agent) = agent.upgrade() else {
return; return;
}; };
AgentDiffPane::deploy(thread, workspace.downgrade(), window, cx).log_err(); AgentDiffPane::deploy(agent, workspace.downgrade(), window, cx).log_err();
} }
fn keep_all( fn keep_all(
editor: &Entity<Editor>, editor: &Entity<Editor>,
thread: &Entity<Thread>, agent: &Entity<ZedAgentThread>,
window: &mut Window, window: &mut Window,
cx: &mut App, cx: &mut App,
) -> PostReviewState { ) -> PostReviewState {
@@ -1615,7 +1619,7 @@ impl AgentDiff {
keep_edits_in_ranges( keep_edits_in_ranges(
editor, editor,
&snapshot, &snapshot,
thread, &agent.read(cx).action_log(),
vec![editor::Anchor::min()..editor::Anchor::max()], vec![editor::Anchor::min()..editor::Anchor::max()],
window, window,
cx, cx,
@@ -1626,7 +1630,7 @@ impl AgentDiff {
fn reject_all( fn reject_all(
editor: &Entity<Editor>, editor: &Entity<Editor>,
thread: &Entity<Thread>, thread: &Entity<ZedAgentThread>,
window: &mut Window, window: &mut Window,
cx: &mut App, cx: &mut App,
) -> PostReviewState { ) -> PostReviewState {
@@ -1635,7 +1639,7 @@ impl AgentDiff {
reject_edits_in_ranges( reject_edits_in_ranges(
editor, editor,
&snapshot, &snapshot,
thread, &thread.read(cx).action_log(),
vec![editor::Anchor::min()..editor::Anchor::max()], vec![editor::Anchor::min()..editor::Anchor::max()],
window, window,
cx, cx,
@@ -1646,26 +1650,26 @@ impl AgentDiff {
fn keep( fn keep(
editor: &Entity<Editor>, editor: &Entity<Editor>,
thread: &Entity<Thread>, agent: &Entity<ZedAgentThread>,
window: &mut Window, window: &mut Window,
cx: &mut App, cx: &mut App,
) -> PostReviewState { ) -> PostReviewState {
editor.update(cx, |editor, cx| { editor.update(cx, |editor, cx| {
let snapshot = editor.buffer().read(cx).snapshot(cx); let snapshot = editor.buffer().read(cx).snapshot(cx);
keep_edits_in_selection(editor, &snapshot, thread, window, cx); keep_edits_in_selection(editor, &snapshot, &agent.read(cx).action_log(), window, cx);
Self::post_review_state(&snapshot) Self::post_review_state(&snapshot)
}) })
} }
fn reject( fn reject(
editor: &Entity<Editor>, editor: &Entity<Editor>,
thread: &Entity<Thread>, agent: &Entity<ZedAgentThread>,
window: &mut Window, window: &mut Window,
cx: &mut App, cx: &mut App,
) -> PostReviewState { ) -> PostReviewState {
editor.update(cx, |editor, cx| { editor.update(cx, |editor, cx| {
let snapshot = editor.buffer().read(cx).snapshot(cx); let snapshot = editor.buffer().read(cx).snapshot(cx);
reject_edits_in_selection(editor, &snapshot, thread, window, cx); reject_edits_in_selection(editor, &snapshot, &agent.read(cx).action_log(), window, cx);
Self::post_review_state(&snapshot) Self::post_review_state(&snapshot)
}) })
} }
@@ -1682,7 +1686,7 @@ impl AgentDiff {
fn review_in_active_editor( fn review_in_active_editor(
&mut self, &mut self,
workspace: &mut Workspace, workspace: &mut Workspace,
review: impl Fn(&Entity<Editor>, &Entity<Thread>, &mut Window, &mut App) -> PostReviewState, review: impl Fn(&Entity<Editor>, &Entity<ZedAgentThread>, &mut Window, &mut App) -> PostReviewState,
window: &mut Window, window: &mut Window,
cx: &mut Context<Self>, cx: &mut Context<Self>,
) -> Option<Task<Result<()>>> { ) -> Option<Task<Result<()>>> {
@@ -1696,14 +1700,13 @@ impl AgentDiff {
return None; return None;
} }
let WorkspaceThread { thread, .. } = let WorkspaceThread { agent, .. } = self.workspace_threads.get(&workspace.weak_handle())?;
self.workspace_threads.get(&workspace.weak_handle())?;
let thread = thread.upgrade()?; let agent = agent.upgrade()?;
if let PostReviewState::AllReviewed = review(&editor, &thread, window, cx) { if let PostReviewState::AllReviewed = review(&editor, &agent, window, cx) {
if let Some(curr_buffer) = editor.read(cx).buffer().read(cx).as_singleton() { if let Some(curr_buffer) = editor.read(cx).buffer().read(cx).as_singleton() {
let changed_buffers = thread.read(cx).action_log().read(cx).changed_buffers(cx); let changed_buffers = agent.read(cx).action_log().read(cx).changed_buffers(cx);
let mut keys = changed_buffers.keys().cycle(); let mut keys = changed_buffers.keys().cycle();
keys.find(|k| *k == &curr_buffer); keys.find(|k| *k == &curr_buffer);
@@ -1801,13 +1804,13 @@ mod tests {
}) })
.await .await
.unwrap(); .unwrap();
let thread = thread_store.update(cx, |store, cx| store.create_thread(cx)); let agent = thread_store.update(cx, |store, cx| store.create_thread(cx));
let action_log = thread.read_with(cx, |thread, _| thread.action_log().clone()); let action_log = agent.read_with(cx, |agent, _| agent.action_log().clone());
let (workspace, cx) = let (workspace, cx) =
cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx)); cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));
let agent_diff = cx.new_window_entity(|window, cx| { let agent_diff = cx.new_window_entity(|window, cx| {
AgentDiffPane::new(thread.clone(), workspace.downgrade(), window, cx) AgentDiffPane::new(agent.clone(), workspace.downgrade(), window, cx)
}); });
let editor = agent_diff.read_with(cx, |diff, _cx| diff.editor.clone()); let editor = agent_diff.read_with(cx, |diff, _cx| diff.editor.clone());
@@ -1895,7 +1898,7 @@ mod tests {
keep_edits_in_ranges( keep_edits_in_ranges(
editor, editor,
&snapshot, &snapshot,
&thread, &agent.read(cx).action_log(),
vec![position..position], vec![position..position],
window, window,
cx, cx,
@@ -1966,8 +1969,8 @@ mod tests {
}) })
.await .await
.unwrap(); .unwrap();
let thread = thread_store.update(cx, |store, cx| store.create_thread(cx)); let agent = thread_store.update(cx, |store, cx| store.create_thread(cx));
let action_log = thread.read_with(cx, |thread, _| thread.action_log().clone()); let action_log = agent.read_with(cx, |agent, _| agent.action_log().clone());
let (workspace, cx) = let (workspace, cx) =
cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx)); cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));
@@ -1989,7 +1992,7 @@ mod tests {
// Set the active thread // Set the active thread
cx.update(|window, cx| { cx.update(|window, cx| {
AgentDiff::set_active_thread(&workspace.downgrade(), &thread, window, cx) AgentDiff::set_active_thread(&workspace.downgrade(), &agent, window, cx)
}); });
let buffer1 = project let buffer1 = project
@@ -2146,7 +2149,7 @@ mod tests {
keep_edits_in_ranges( keep_edits_in_ranges(
editor, editor,
&snapshot, &snapshot,
&thread, &agent.read(cx).action_log(),
vec![position..position], vec![position..position],
window, window,
cx, cx,

View File

@@ -26,7 +26,7 @@ use crate::{
ui::AgentOnboardingModal, ui::AgentOnboardingModal,
}; };
use agent::{ use agent::{
Thread, ThreadError, ThreadEvent, ThreadId, ThreadSummary, TokenUsageRatio, ThreadError, ThreadEvent, ThreadId, ThreadSummary, TokenUsageRatio, ZedAgentThread,
context_store::ContextStore, context_store::ContextStore,
history_store::{HistoryEntryId, HistoryStore}, history_store::{HistoryEntryId, HistoryStore},
thread_store::{TextThreadStore, ThreadStore}, thread_store::{TextThreadStore, ThreadStore},
@@ -72,7 +72,7 @@ use zed_actions::{
agent::{OpenConfiguration, OpenOnboardingModal, ResetOnboarding}, agent::{OpenConfiguration, OpenOnboardingModal, ResetOnboarding},
assistant::{OpenRulesLibrary, ToggleFocus}, assistant::{OpenRulesLibrary, ToggleFocus},
}; };
use zed_llm_client::{CompletionIntent, UsageLimit}; use zed_llm_client::UsageLimit;
const AGENT_PANEL_KEY: &str = "agent_panel"; const AGENT_PANEL_KEY: &str = "agent_panel";
@@ -122,8 +122,8 @@ pub fn init(cx: &mut App) {
workspace.focus_panel::<AgentPanel>(window, cx); workspace.focus_panel::<AgentPanel>(window, cx);
match &panel.read(cx).active_view { match &panel.read(cx).active_view {
ActiveView::Thread { thread, .. } => { ActiveView::Thread { thread, .. } => {
let thread = thread.read(cx).thread().clone(); let agent = thread.read(cx).agent().clone();
AgentDiffPane::deploy_in_workspace(thread, workspace, window, cx); AgentDiffPane::deploy_in_workspace(agent, workspace, window, cx);
} }
ActiveView::TextThread { .. } ActiveView::TextThread { .. }
| ActiveView::History | ActiveView::History
@@ -251,9 +251,9 @@ impl ActiveView {
let new_summary = editor.read(cx).text(cx); let new_summary = editor.read(cx).text(cx);
thread.update(cx, |thread, cx| { thread.update(cx, |thread, cx| {
thread.thread().update(cx, |thread, cx| { thread.agent().update(cx, |agent, cx| {
thread.set_summary(new_summary, cx); agent.set_summary(new_summary, cx);
}); })
}) })
} }
EditorEvent::Blurred => { EditorEvent::Blurred => {
@@ -274,11 +274,11 @@ impl ActiveView {
cx.notify(); cx.notify();
} }
}), }),
cx.subscribe_in(&active_thread.read(cx).thread().clone(), window, { cx.subscribe_in(&active_thread.read(cx).agent().clone(), window, {
let editor = editor.clone(); let editor = editor.clone();
move |_, thread, event, window, cx| match event { move |_, agent, event, window, cx| match event {
ThreadEvent::SummaryGenerated => { ThreadEvent::SummaryGenerated => {
let summary = thread.read(cx).summary().or_default(); let summary = agent.read(cx).summary().or_default();
editor.update(cx, |editor, cx| { editor.update(cx, |editor, cx| {
editor.set_text(summary, window, cx); editor.set_text(summary, window, cx);
@@ -524,7 +524,7 @@ impl AgentPanel {
window: &mut Window, window: &mut Window,
cx: &mut Context<Self>, cx: &mut Context<Self>,
) -> Self { ) -> Self {
let thread = thread_store.update(cx, |this, cx| this.create_thread(cx)); let agent = thread_store.update(cx, |this, cx| this.create_thread(cx));
let fs = workspace.app_state().fs.clone(); let fs = workspace.app_state().fs.clone();
let user_store = workspace.app_state().user_store.clone(); let user_store = workspace.app_state().user_store.clone();
let project = workspace.project(); let project = workspace.project();
@@ -546,13 +546,13 @@ impl AgentPanel {
prompt_store.clone(), prompt_store.clone(),
thread_store.downgrade(), thread_store.downgrade(),
context_store.downgrade(), context_store.downgrade(),
thread.clone(), agent.clone(),
window, window,
cx, cx,
) )
}); });
let thread_id = thread.read(cx).id().clone(); let thread_id = agent.read(cx).id().clone();
let history_store = cx.new(|cx| { let history_store = cx.new(|cx| {
HistoryStore::new( HistoryStore::new(
thread_store.clone(), thread_store.clone(),
@@ -566,7 +566,7 @@ impl AgentPanel {
let active_thread = cx.new(|cx| { let active_thread = cx.new(|cx| {
ActiveThread::new( ActiveThread::new(
thread.clone(), agent.clone(),
thread_store.clone(), thread_store.clone(),
context_store.clone(), context_store.clone(),
message_editor_context_store.clone(), message_editor_context_store.clone(),
@@ -607,7 +607,7 @@ impl AgentPanel {
} }
}; };
AgentDiff::set_active_thread(&workspace, &thread, window, cx); AgentDiff::set_active_thread(&workspace, &agent, window, cx);
let weak_panel = weak_self.clone(); let weak_panel = weak_self.clone();
@@ -649,9 +649,9 @@ impl AgentPanel {
ActiveView::Thread { thread, .. } => { ActiveView::Thread { thread, .. } => {
thread thread
.read(cx) .read(cx)
.thread() .agent()
.clone() .clone()
.update(cx, |thread, cx| thread.get_or_init_configured_model(cx)); .update(cx, |agent, cx| agent.get_or_init_configured_model(cx));
} }
ActiveView::TextThread { .. } ActiveView::TextThread { .. }
| ActiveView::History | ActiveView::History
@@ -753,7 +753,7 @@ impl AgentPanel {
None None
}; };
let thread = self let agent = self
.thread_store .thread_store
.update(cx, |this, cx| this.create_thread(cx)); .update(cx, |this, cx| this.create_thread(cx));
@@ -786,7 +786,7 @@ impl AgentPanel {
let active_thread = cx.new(|cx| { let active_thread = cx.new(|cx| {
ActiveThread::new( ActiveThread::new(
thread.clone(), agent.clone(),
self.thread_store.clone(), self.thread_store.clone(),
self.context_store.clone(), self.context_store.clone(),
context_store.clone(), context_store.clone(),
@@ -806,7 +806,7 @@ impl AgentPanel {
self.prompt_store.clone(), self.prompt_store.clone(),
self.thread_store.downgrade(), self.thread_store.downgrade(),
self.context_store.downgrade(), self.context_store.downgrade(),
thread.clone(), agent.clone(),
window, window,
cx, cx,
) )
@@ -823,7 +823,7 @@ impl AgentPanel {
let thread_view = ActiveView::thread(active_thread.clone(), message_editor, window, cx); let thread_view = ActiveView::thread(active_thread.clone(), message_editor, window, cx);
self.set_active_view(thread_view, window, cx); self.set_active_view(thread_view, window, cx);
AgentDiff::set_active_thread(&self.workspace, &thread, window, cx); AgentDiff::set_active_thread(&self.workspace, &agent, window, cx);
} }
fn new_prompt_editor(&mut self, window: &mut Window, cx: &mut Context<Self>) { fn new_prompt_editor(&mut self, window: &mut Window, cx: &mut Context<Self>) {
@@ -971,7 +971,7 @@ impl AgentPanel {
pub(crate) fn open_thread( pub(crate) fn open_thread(
&mut self, &mut self,
thread: Entity<Thread>, agent: Entity<ZedAgentThread>,
window: &mut Window, window: &mut Window,
cx: &mut Context<Self>, cx: &mut Context<Self>,
) { ) {
@@ -984,7 +984,7 @@ impl AgentPanel {
let active_thread = cx.new(|cx| { let active_thread = cx.new(|cx| {
ActiveThread::new( ActiveThread::new(
thread.clone(), agent.clone(),
self.thread_store.clone(), self.thread_store.clone(),
self.context_store.clone(), self.context_store.clone(),
context_store.clone(), context_store.clone(),
@@ -1003,7 +1003,7 @@ impl AgentPanel {
self.prompt_store.clone(), self.prompt_store.clone(),
self.thread_store.downgrade(), self.thread_store.downgrade(),
self.context_store.downgrade(), self.context_store.downgrade(),
thread.clone(), agent.clone(),
window, window,
cx, cx,
) )
@@ -1012,7 +1012,7 @@ impl AgentPanel {
let thread_view = ActiveView::thread(active_thread.clone(), message_editor, window, cx); let thread_view = ActiveView::thread(active_thread.clone(), message_editor, window, cx);
self.set_active_view(thread_view, window, cx); self.set_active_view(thread_view, window, cx);
AgentDiff::set_active_thread(&self.workspace, &thread, window, cx); AgentDiff::set_active_thread(&self.workspace, &agent, window, cx);
} }
pub fn go_back(&mut self, _: &workspace::GoBack, window: &mut Window, cx: &mut Context<Self>) { pub fn go_back(&mut self, _: &workspace::GoBack, window: &mut Window, cx: &mut Context<Self>) {
@@ -1137,10 +1137,10 @@ impl AgentPanel {
) { ) {
match &self.active_view { match &self.active_view {
ActiveView::Thread { thread, .. } => { ActiveView::Thread { thread, .. } => {
let thread = thread.read(cx).thread().clone(); let agent = thread.read(cx).agent().clone();
self.workspace self.workspace
.update(cx, |workspace, cx| { .update(cx, |workspace, cx| {
AgentDiffPane::deploy_in_workspace(thread, workspace, window, cx) AgentDiffPane::deploy_in_workspace(agent, workspace, window, cx)
}) })
.log_err(); .log_err();
} }
@@ -1190,7 +1190,7 @@ impl AgentPanel {
match &self.active_view { match &self.active_view {
ActiveView::Thread { thread, .. } => { ActiveView::Thread { thread, .. } => {
active_thread::open_active_thread_as_markdown( active_thread::open_active_thread_as_markdown(
thread.read(cx).thread().clone(), thread.read(cx).agent().clone(),
workspace, workspace,
window, window,
cx, cx,
@@ -1228,9 +1228,9 @@ impl AgentPanel {
} }
} }
pub(crate) fn active_thread(&self, cx: &App) -> Option<Entity<Thread>> { pub(crate) fn active_thread(&self, cx: &App) -> Option<Entity<ZedAgentThread>> {
match &self.active_view { match &self.active_view {
ActiveView::Thread { thread, .. } => Some(thread.read(cx).thread().clone()), ActiveView::Thread { thread, .. } => Some(thread.read(cx).agent().clone()),
_ => None, _ => None,
} }
} }
@@ -1249,23 +1249,16 @@ impl AgentPanel {
return; return;
}; };
let thread_state = thread.read(cx).thread().read(cx); let agent_state = thread.read(cx).agent().read(cx);
if !thread_state.tool_use_limit_reached() { if !agent_state.tool_use_limit_reached() {
return; return;
} }
let model = thread_state.configured_model().map(|cm| cm.model.clone()); let model = agent_state.configured_model().map(|cm| cm.model.clone());
if let Some(model) = model { if let Some(model) = model {
thread.update(cx, |active_thread, cx| { thread.update(cx, |active_thread, cx| {
active_thread.thread().update(cx, |thread, cx| { active_thread.agent().update(cx, |agent, cx| {
thread.insert_invisible_continue_message(cx); agent.send_continue_message(model, Some(window.window_handle()), cx);
thread.advance_prompt_id();
thread.send_to_model(
model,
CompletionIntent::UserPrompt,
Some(window.window_handle()),
cx,
);
}); });
}); });
} else { } else {
@@ -1284,10 +1277,10 @@ impl AgentPanel {
}; };
thread.update(cx, |active_thread, cx| { thread.update(cx, |active_thread, cx| {
active_thread.thread().update(cx, |thread, _cx| { active_thread.agent().update(cx, |agent, _cx| {
let current_mode = thread.completion_mode(); let current_mode = agent.completion_mode();
thread.set_completion_mode(match current_mode { agent.set_completion_mode(match current_mode {
CompletionMode::Burn => CompletionMode::Normal, CompletionMode::Burn => CompletionMode::Normal,
CompletionMode::Normal => CompletionMode::Burn, CompletionMode::Normal => CompletionMode::Burn,
}); });
@@ -1330,7 +1323,7 @@ impl AgentPanel {
ActiveView::Thread { thread, .. } => { ActiveView::Thread { thread, .. } => {
let thread = thread.read(cx); let thread = thread.read(cx);
if thread.is_empty() { if thread.is_empty() {
let id = thread.thread().read(cx).id().clone(); let id = thread.agent().read(cx).id().clone();
self.history_store.update(cx, |store, cx| { self.history_store.update(cx, |store, cx| {
store.remove_recently_opened_thread(id, cx); store.remove_recently_opened_thread(id, cx);
}); });
@@ -1341,7 +1334,7 @@ impl AgentPanel {
match &new_view { match &new_view {
ActiveView::Thread { thread, .. } => self.history_store.update(cx, |store, cx| { ActiveView::Thread { thread, .. } => self.history_store.update(cx, |store, cx| {
let id = thread.read(cx).thread().read(cx).id().clone(); let id = thread.read(cx).agent().read(cx).id().clone();
store.push_recently_opened_entry(HistoryEntryId::Thread(id), cx); store.push_recently_opened_entry(HistoryEntryId::Thread(id), cx);
}), }),
ActiveView::TextThread { context_editor, .. } => { ActiveView::TextThread { context_editor, .. } => {
@@ -1726,7 +1719,7 @@ impl AgentPanel {
}; };
let active_thread = match &self.active_view { let active_thread = match &self.active_view {
ActiveView::Thread { thread, .. } => Some(thread.read(cx).thread().clone()), ActiveView::Thread { thread, .. } => Some(thread.clone()),
ActiveView::TextThread { .. } | ActiveView::History | ActiveView::Configuration => None, ActiveView::TextThread { .. } | ActiveView::History | ActiveView::Configuration => None,
}; };
@@ -1761,7 +1754,7 @@ impl AgentPanel {
this.action( this.action(
"New From Summary", "New From Summary",
Box::new(NewThread { Box::new(NewThread {
from_thread_id: Some(thread.id().clone()), from_thread_id: Some(thread.agent().read(cx).id().clone()),
}), }),
) )
} else { } else {
@@ -1904,14 +1897,14 @@ impl AgentPanel {
return None; return None;
} }
let thread = active_thread.thread().read(cx); let agent = active_thread.agent().read(cx);
let is_generating = thread.is_generating(); let is_generating = agent.is_generating();
let conversation_token_usage = thread.total_token_usage()?; let conversation_token_usage = agent.total_token_usage(cx)?;
let (total_token_usage, is_estimating) = let (total_token_usage, is_estimating) =
if let Some((editing_message_id, unsent_tokens)) = active_thread.editing_message_id() { if let Some((editing_message_id, unsent_tokens)) = active_thread.editing_message_id() {
let combined = thread let combined = agent
.token_usage_up_to_message(editing_message_id) .token_usage_up_to_message(editing_message_id, cx)
.add(unsent_tokens); .add(unsent_tokens);
(combined, unsent_tokens > 0) (combined, unsent_tokens > 0)
@@ -2022,7 +2015,7 @@ impl AgentPanel {
ActiveView::Thread { thread, .. } => { ActiveView::Thread { thread, .. } => {
let is_using_zed_provider = thread let is_using_zed_provider = thread
.read(cx) .read(cx)
.thread() .agent()
.read(cx) .read(cx)
.configured_model() .configured_model()
.map_or(false, |model| { .map_or(false, |model| {
@@ -2622,14 +2615,14 @@ impl AgentPanel {
} }
}; };
let thread = active_thread.read(cx).thread().read(cx); let agent = active_thread.read(cx).agent().read(cx);
let tool_use_limit_reached = thread.tool_use_limit_reached(); let tool_use_limit_reached = agent.tool_use_limit_reached();
if !tool_use_limit_reached { if !tool_use_limit_reached {
return None; return None;
} }
let model = thread.configured_model()?.model; let model = agent.configured_model()?.model;
let focus_handle = self.focus_handle(cx); let focus_handle = self.focus_handle(cx);
@@ -2677,8 +2670,8 @@ impl AgentPanel {
let active_thread = active_thread.clone(); let active_thread = active_thread.clone();
cx.listener(move |this, _, window, cx| { cx.listener(move |this, _, window, cx| {
active_thread.update(cx, |active_thread, cx| { active_thread.update(cx, |active_thread, cx| {
active_thread.thread().update(cx, |thread, _cx| { active_thread.agent().update(cx, |agent, _cx| {
thread.set_completion_mode(CompletionMode::Burn); agent.set_completion_mode(CompletionMode::Burn);
}); });
}); });
this.continue_conversation(window, cx); this.continue_conversation(window, cx);
@@ -3062,8 +3055,8 @@ impl Render for AgentPanel {
match &this.active_view { match &this.active_view {
ActiveView::Thread { thread, .. } => { ActiveView::Thread { thread, .. } => {
thread.update(cx, |active_thread, cx| { thread.update(cx, |active_thread, cx| {
active_thread.thread().update(cx, |thread, _cx| { active_thread.agent().update(cx, |agent, _cx| {
thread.set_completion_mode(CompletionMode::Burn); agent.set_completion_mode(CompletionMode::Burn);
}); });
}); });
this.continue_conversation(window, cx); this.continue_conversation(window, cx);

View File

@@ -26,7 +26,7 @@ mod ui;
use std::sync::Arc; use std::sync::Arc;
use agent::{Thread, ThreadId}; use agent::{ThreadId, ZedAgentThread};
use agent_settings::{AgentProfileId, AgentSettings, LanguageModelSelection}; use agent_settings::{AgentProfileId, AgentSettings, LanguageModelSelection};
use assistant_slash_command::SlashCommandRegistry; use assistant_slash_command::SlashCommandRegistry;
use client::Client; use client::Client;
@@ -114,7 +114,7 @@ impl ManageProfiles {
#[derive(Clone)] #[derive(Clone)]
pub(crate) enum ModelUsageContext { pub(crate) enum ModelUsageContext {
Thread(Entity<Thread>), Thread(Entity<ZedAgentThread>),
InlineAssistant, InlineAssistant,
} }

View File

@@ -22,7 +22,7 @@ use util::ResultExt as _;
use workspace::Workspace; use workspace::Workspace;
use agent::{ use agent::{
Thread, ZedAgentThread,
context::{AgentContextHandle, AgentContextKey, RULES_ICON}, context::{AgentContextHandle, AgentContextKey, RULES_ICON},
thread_store::{TextThreadStore, ThreadStore}, thread_store::{TextThreadStore, ThreadStore},
}; };
@@ -449,7 +449,7 @@ impl ContextPickerCompletionProvider {
let context_store = context_store.clone(); let context_store = context_store.clone();
let thread_store = thread_store.clone(); let thread_store = thread_store.clone();
window.spawn::<_, Option<_>>(cx, async move |cx| { window.spawn::<_, Option<_>>(cx, async move |cx| {
let thread: Entity<Thread> = thread_store let thread: Entity<ZedAgentThread> = thread_store
.update_in(cx, |thread_store, window, cx| { .update_in(cx, |thread_store, window, cx| {
thread_store.open_thread(&thread_id, window, cx) thread_store.open_thread(&thread_id, window, cx)
}) })

View File

@@ -9,6 +9,7 @@ use crate::ui::{
MaxModeTooltip, MaxModeTooltip,
preview::{AgentPreview, UsageCallout}, preview::{AgentPreview, UsageCallout},
}; };
use agent::thread::UserMessageParams;
use agent::{ use agent::{
context::{AgentContextKey, ContextLoadResult, load_context}, context::{AgentContextKey, ContextLoadResult, load_context},
context_store::ContextStoreEvent, context_store::ContextStoreEvent,
@@ -31,7 +32,7 @@ use gpui::{
Animation, AnimationExt, App, Entity, EventEmitter, Focusable, Subscription, Task, TextStyle, Animation, AnimationExt, App, Entity, EventEmitter, Focusable, Subscription, Task, TextStyle,
WeakEntity, linear_color_stop, linear_gradient, point, pulsating_between, WeakEntity, linear_color_stop, linear_gradient, point, pulsating_between,
}; };
use language::{Buffer, Language, Point}; use language::{Buffer, Language};
use language_model::{ use language_model::{
ConfiguredModel, LanguageModelRequestMessage, MessageContent, ZED_CLOUD_PROVIDER_ID, ConfiguredModel, LanguageModelRequestMessage, MessageContent, ZED_CLOUD_PROVIDER_ID,
}; };
@@ -47,7 +48,6 @@ use ui::{
}; };
use util::ResultExt as _; use util::ResultExt as _;
use workspace::{CollaboratorId, Workspace}; use workspace::{CollaboratorId, Workspace};
use zed_llm_client::CompletionIntent;
use crate::context_picker::{ContextPicker, ContextPickerCompletionProvider, crease_for_mention}; use crate::context_picker::{ContextPicker, ContextPickerCompletionProvider, crease_for_mention};
use crate::context_strip::{ContextStrip, ContextStripEvent, SuggestContextKind}; use crate::context_strip::{ContextStrip, ContextStripEvent, SuggestContextKind};
@@ -58,14 +58,14 @@ use crate::{
ToggleContextPicker, ToggleProfileSelector, register_agent_preview, ToggleContextPicker, ToggleProfileSelector, register_agent_preview,
}; };
use agent::{ use agent::{
MessageCrease, Thread, TokenUsageRatio, MessageCrease, TokenUsageRatio, ZedAgentThread,
context_store::ContextStore, context_store::ContextStore,
thread_store::{TextThreadStore, ThreadStore}, thread_store::{TextThreadStore, ThreadStore},
}; };
#[derive(RegisterComponent)] #[derive(RegisterComponent)]
pub struct MessageEditor { pub struct MessageEditor {
thread: Entity<Thread>, agent: Entity<ZedAgentThread>,
incompatible_tools_state: Entity<IncompatibleToolsState>, incompatible_tools_state: Entity<IncompatibleToolsState>,
editor: Entity<Editor>, editor: Entity<Editor>,
workspace: WeakEntity<Workspace>, workspace: WeakEntity<Workspace>,
@@ -156,7 +156,7 @@ impl MessageEditor {
prompt_store: Option<Entity<PromptStore>>, prompt_store: Option<Entity<PromptStore>>,
thread_store: WeakEntity<ThreadStore>, thread_store: WeakEntity<ThreadStore>,
text_thread_store: WeakEntity<TextThreadStore>, text_thread_store: WeakEntity<TextThreadStore>,
thread: Entity<Thread>, agent: Entity<ZedAgentThread>,
window: &mut Window, window: &mut Window,
cx: &mut Context<Self>, cx: &mut Context<Self>,
) -> Self { ) -> Self {
@@ -182,13 +182,13 @@ impl MessageEditor {
Some(text_thread_store.clone()), Some(text_thread_store.clone()),
context_picker_menu_handle.clone(), context_picker_menu_handle.clone(),
SuggestContextKind::File, SuggestContextKind::File,
ModelUsageContext::Thread(thread.clone()), ModelUsageContext::Thread(agent.clone()),
window, window,
cx, cx,
) )
}); });
let incompatible_tools = cx.new(|cx| IncompatibleToolsState::new(thread.clone(), cx)); let incompatible_tools = cx.new(|cx| IncompatibleToolsState::new(agent.clone(), cx));
let subscriptions = vec![ let subscriptions = vec![
cx.subscribe_in(&context_strip, window, Self::handle_context_strip_event), cx.subscribe_in(&context_strip, window, Self::handle_context_strip_event),
@@ -200,9 +200,7 @@ impl MessageEditor {
// When context changes, reload it for token counting. // When context changes, reload it for token counting.
let _ = this.reload_context(cx); let _ = this.reload_context(cx);
}), }),
cx.observe(&thread.read(cx).action_log().clone(), |_, _, cx| { cx.observe(&agent.read(cx).action_log().clone(), |_, _, cx| cx.notify()),
cx.notify()
}),
]; ];
let model_selector = cx.new(|cx| { let model_selector = cx.new(|cx| {
@@ -210,20 +208,20 @@ impl MessageEditor {
fs.clone(), fs.clone(),
model_selector_menu_handle, model_selector_menu_handle,
editor.focus_handle(cx), editor.focus_handle(cx),
ModelUsageContext::Thread(thread.clone()), ModelUsageContext::Thread(agent.clone()),
window, window,
cx, cx,
) )
}); });
let profile_selector = let profile_selector =
cx.new(|cx| ProfileSelector::new(fs, thread.clone(), editor.focus_handle(cx), cx)); cx.new(|cx| ProfileSelector::new(fs, agent.clone(), editor.focus_handle(cx), cx));
Self { Self {
editor: editor.clone(), editor: editor.clone(),
project: thread.read(cx).project().clone(), project: agent.read(cx).project().clone(),
user_store, user_store,
thread, agent,
incompatible_tools_state: incompatible_tools.clone(), incompatible_tools_state: incompatible_tools.clone(),
workspace, workspace,
context_store, context_store,
@@ -313,11 +311,11 @@ impl MessageEditor {
return; return;
} }
self.thread.update(cx, |thread, cx| { self.agent.update(cx, |thread, cx| {
thread.cancel_editing(cx); thread.cancel_editing(cx);
}); });
if self.thread.read(cx).is_generating() { if self.agent.read(cx).is_generating() {
self.stop_current_and_send_new_message(window, cx); self.stop_current_and_send_new_message(window, cx);
return; return;
} }
@@ -354,7 +352,7 @@ impl MessageEditor {
fn send_to_model(&mut self, window: &mut Window, cx: &mut Context<Self>) { fn send_to_model(&mut self, window: &mut Window, cx: &mut Context<Self>) {
let Some(ConfiguredModel { model, provider }) = self let Some(ConfiguredModel { model, provider }) = self
.thread .agent
.update(cx, |thread, cx| thread.get_or_init_configured_model(cx)) .update(cx, |thread, cx| thread.get_or_init_configured_model(cx))
else { else {
return; return;
@@ -375,7 +373,7 @@ impl MessageEditor {
self.last_estimated_token_count.take(); self.last_estimated_token_count.take();
cx.emit(MessageEditorEvent::EstimatedTokenCount); cx.emit(MessageEditorEvent::EstimatedTokenCount);
let thread = self.thread.clone(); let agent = self.agent.clone();
let git_store = self.project.read(cx).git_store().clone(); let git_store = self.project.read(cx).git_store().clone();
let checkpoint = git_store.update(cx, |git_store, cx| git_store.checkpoint(cx)); let checkpoint = git_store.update(cx, |git_store, cx| git_store.checkpoint(cx));
let context_task = self.reload_context(cx); let context_task = self.reload_context(cx);
@@ -385,24 +383,16 @@ impl MessageEditor {
let (checkpoint, loaded_context) = future::join(checkpoint, context_task).await; let (checkpoint, loaded_context) = future::join(checkpoint, context_task).await;
let loaded_context = loaded_context.unwrap_or_default(); let loaded_context = loaded_context.unwrap_or_default();
thread agent
.update(cx, |thread, cx| { .update(cx, |thread, cx| {
thread.insert_user_message( thread.send_message(
user_message, UserMessageParams {
loaded_context, text: user_message,
checkpoint.ok(), creases: user_message_creases,
user_message_creases, checkpoint: checkpoint.ok(),
cx, context: loaded_context,
); },
})
.log_err();
thread
.update(cx, |thread, cx| {
thread.advance_prompt_id();
thread.send_to_model(
model, model,
CompletionIntent::UserPrompt,
Some(window_handle), Some(window_handle),
cx, cx,
); );
@@ -413,11 +403,11 @@ impl MessageEditor {
} }
fn stop_current_and_send_new_message(&mut self, window: &mut Window, cx: &mut Context<Self>) { fn stop_current_and_send_new_message(&mut self, window: &mut Window, cx: &mut Context<Self>) {
self.thread.update(cx, |thread, cx| { self.agent.update(cx, |thread, cx| {
thread.cancel_editing(cx); thread.cancel_editing(cx);
}); });
let cancelled = self.thread.update(cx, |thread, cx| { let cancelled = self.agent.update(cx, |thread, cx| {
thread.cancel_last_completion(Some(window.window_handle()), cx) thread.cancel_last_completion(Some(window.window_handle()), cx)
}); });
@@ -459,7 +449,7 @@ impl MessageEditor {
fn handle_review_click(&mut self, window: &mut Window, cx: &mut Context<Self>) { fn handle_review_click(&mut self, window: &mut Window, cx: &mut Context<Self>) {
self.edits_expanded = true; self.edits_expanded = true;
AgentDiffPane::deploy(self.thread.clone(), self.workspace.clone(), window, cx).log_err(); AgentDiffPane::deploy(self.agent.clone(), self.workspace.clone(), window, cx).log_err();
cx.notify(); cx.notify();
} }
@@ -475,7 +465,7 @@ impl MessageEditor {
cx: &mut Context<Self>, cx: &mut Context<Self>,
) { ) {
if let Ok(diff) = if let Ok(diff) =
AgentDiffPane::deploy(self.thread.clone(), self.workspace.clone(), window, cx) AgentDiffPane::deploy(self.agent.clone(), self.workspace.clone(), window, cx)
{ {
let path_key = multi_buffer::PathKey::for_buffer(&buffer, cx); let path_key = multi_buffer::PathKey::for_buffer(&buffer, cx);
diff.update(cx, |diff, cx| diff.move_to_path(path_key, window, cx)); diff.update(cx, |diff, cx| diff.move_to_path(path_key, window, cx));
@@ -488,7 +478,7 @@ impl MessageEditor {
_window: &mut Window, _window: &mut Window,
cx: &mut Context<Self>, cx: &mut Context<Self>,
) { ) {
self.thread.update(cx, |thread, _cx| { self.agent.update(cx, |thread, _cx| {
let active_completion_mode = thread.completion_mode(); let active_completion_mode = thread.completion_mode();
thread.set_completion_mode(match active_completion_mode { thread.set_completion_mode(match active_completion_mode {
@@ -499,36 +489,22 @@ impl MessageEditor {
} }
fn handle_accept_all(&mut self, _window: &mut Window, cx: &mut Context<Self>) { fn handle_accept_all(&mut self, _window: &mut Window, cx: &mut Context<Self>) {
if self.thread.read(cx).has_pending_edit_tool_uses() { if self.agent.read(cx).has_pending_edit_tool_uses() {
return; return;
} }
self.thread.update(cx, |thread, cx| { let action_log = self.agent.read(cx).action_log();
thread.keep_all_edits(cx); action_log.update(cx, |action_log, cx| action_log.keep_all_edits(cx));
});
cx.notify(); cx.notify();
} }
fn handle_reject_all(&mut self, _window: &mut Window, cx: &mut Context<Self>) { fn handle_reject_all(&mut self, _window: &mut Window, cx: &mut Context<Self>) {
if self.thread.read(cx).has_pending_edit_tool_uses() { if self.agent.read(cx).has_pending_edit_tool_uses() {
return; return;
} }
// Since there's no reject_all_edits method in the thread API, let action_log = self.agent.read(cx).action_log();
// we need to iterate through all buffers and reject their edits action_log.update(cx, |action_log, cx| action_log.reject_all_edits(cx));
let action_log = self.thread.read(cx).action_log().clone();
let changed_buffers = action_log.read(cx).changed_buffers(cx);
for (buffer, _) in changed_buffers {
self.thread.update(cx, |thread, cx| {
let buffer_snapshot = buffer.read(cx);
let start = buffer_snapshot.anchor_before(Point::new(0, 0));
let end = buffer_snapshot.anchor_after(buffer_snapshot.max_point());
thread
.reject_edits_in_ranges(buffer, vec![start..end], cx)
.detach();
});
}
cx.notify(); cx.notify();
} }
@@ -538,17 +514,13 @@ impl MessageEditor {
_window: &mut Window, _window: &mut Window,
cx: &mut Context<Self>, cx: &mut Context<Self>,
) { ) {
if self.thread.read(cx).has_pending_edit_tool_uses() { if self.agent.read(cx).has_pending_edit_tool_uses() {
return; return;
} }
self.thread.update(cx, |thread, cx| { let action_log = self.agent.read(cx).action_log();
let buffer_snapshot = buffer.read(cx); action_log.update(cx, |action_log, cx| {
let start = buffer_snapshot.anchor_before(Point::new(0, 0)); action_log.reject_buffer_edits(buffer, cx)
let end = buffer_snapshot.anchor_after(buffer_snapshot.max_point());
thread
.reject_edits_in_ranges(buffer, vec![start..end], cx)
.detach();
}); });
cx.notify(); cx.notify();
} }
@@ -559,21 +531,19 @@ impl MessageEditor {
_window: &mut Window, _window: &mut Window,
cx: &mut Context<Self>, cx: &mut Context<Self>,
) { ) {
if self.thread.read(cx).has_pending_edit_tool_uses() { if self.agent.read(cx).has_pending_edit_tool_uses() {
return; return;
} }
self.thread.update(cx, |thread, cx| { let action_log = self.agent.read(cx).action_log();
let buffer_snapshot = buffer.read(cx); action_log.update(cx, |action_log, cx| {
let start = buffer_snapshot.anchor_before(Point::new(0, 0)); action_log.keep_buffer_edits(buffer, cx)
let end = buffer_snapshot.anchor_after(buffer_snapshot.max_point());
thread.keep_edits_in_range(buffer, start..end, cx);
}); });
cx.notify(); cx.notify();
} }
fn render_burn_mode_toggle(&self, cx: &mut Context<Self>) -> Option<AnyElement> { fn render_burn_mode_toggle(&self, cx: &mut Context<Self>) -> Option<AnyElement> {
let thread = self.thread.read(cx); let thread = self.agent.read(cx);
let model = thread.configured_model(); let model = thread.configured_model();
if !model?.model.supports_burn_mode() { if !model?.model.supports_burn_mode() {
return None; return None;
@@ -644,7 +614,7 @@ impl MessageEditor {
} }
fn render_editor(&self, window: &mut Window, cx: &mut Context<Self>) -> Div { fn render_editor(&self, window: &mut Window, cx: &mut Context<Self>) -> Div {
let thread = self.thread.read(cx); let thread = self.agent.read(cx);
let model = thread.configured_model(); let model = thread.configured_model();
let editor_bg_color = cx.theme().colors().editor_background; let editor_bg_color = cx.theme().colors().editor_background;
@@ -945,7 +915,7 @@ impl MessageEditor {
let bg_edit_files_disclosure = editor_bg_color.blend(active_color.opacity(0.3)); let bg_edit_files_disclosure = editor_bg_color.blend(active_color.opacity(0.3));
let is_edit_changes_expanded = self.edits_expanded; let is_edit_changes_expanded = self.edits_expanded;
let thread = self.thread.read(cx); let thread = self.agent.read(cx);
let pending_edits = thread.has_pending_edit_tool_uses(); let pending_edits = thread.has_pending_edit_tool_uses();
const EDIT_NOT_READY_TOOLTIP_LABEL: &str = "Wait until file edits are complete."; const EDIT_NOT_READY_TOOLTIP_LABEL: &str = "Wait until file edits are complete.";
@@ -1247,7 +1217,7 @@ impl MessageEditor {
} }
fn is_using_zed_provider(&self, cx: &App) -> bool { fn is_using_zed_provider(&self, cx: &App) -> bool {
self.thread self.agent
.read(cx) .read(cx)
.configured_model() .configured_model()
.map_or(false, |model| { .map_or(false, |model| {
@@ -1325,7 +1295,7 @@ impl MessageEditor {
Button::new("start-new-thread", "Start New Thread") Button::new("start-new-thread", "Start New Thread")
.label_size(LabelSize::Small) .label_size(LabelSize::Small)
.on_click(cx.listener(|this, _, window, cx| { .on_click(cx.listener(|this, _, window, cx| {
let from_thread_id = Some(this.thread.read(cx).id().clone()); let from_thread_id = Some(this.agent.read(cx).id().clone());
window.dispatch_action(Box::new(NewThread { from_thread_id }), cx); window.dispatch_action(Box::new(NewThread { from_thread_id }), cx);
})), })),
); );
@@ -1359,10 +1329,11 @@ impl MessageEditor {
fn reload_context(&mut self, cx: &mut Context<Self>) -> Task<Option<ContextLoadResult>> { fn reload_context(&mut self, cx: &mut Context<Self>) -> Task<Option<ContextLoadResult>> {
let load_task = cx.spawn(async move |this, cx| { let load_task = cx.spawn(async move |this, cx| {
let Ok(load_task) = this.update(cx, |this, cx| { let Ok(load_task) = this.update(cx, |this, cx| {
let new_context = this let new_context = this.context_store.read(cx).new_context_for_thread(
.context_store this.agent.read(cx),
.read(cx) None,
.new_context_for_thread(this.thread.read(cx), None); cx,
);
load_context(new_context, &this.project, &this.prompt_store, cx) load_context(new_context, &this.project, &this.prompt_store, cx)
}) else { }) else {
return; return;
@@ -1394,7 +1365,7 @@ impl MessageEditor {
cx.emit(MessageEditorEvent::Changed); cx.emit(MessageEditorEvent::Changed);
self.update_token_count_task.take(); self.update_token_count_task.take();
let Some(model) = self.thread.read(cx).configured_model() else { let Some(model) = self.agent.read(cx).configured_model() else {
self.last_estimated_token_count.take(); self.last_estimated_token_count.take();
return; return;
}; };
@@ -1599,16 +1570,16 @@ impl Focusable for MessageEditor {
impl Render for MessageEditor { impl Render for MessageEditor {
fn render(&mut self, window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement { fn render(&mut self, window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
let thread = self.thread.read(cx); let agent = self.agent.read(cx);
let token_usage_ratio = thread let token_usage_ratio = agent
.total_token_usage() .total_token_usage(cx)
.map_or(TokenUsageRatio::Normal, |total_token_usage| { .map_or(TokenUsageRatio::Normal, |total_token_usage| {
total_token_usage.ratio() total_token_usage.ratio()
}); });
let burn_mode_enabled = thread.completion_mode() == CompletionMode::Burn; let burn_mode_enabled = agent.completion_mode() == CompletionMode::Burn;
let action_log = self.thread.read(cx).action_log(); let action_log = agent.action_log();
let changed_buffers = action_log.read(cx).changed_buffers(cx); let changed_buffers = action_log.read(cx).changed_buffers(cx);
let line_height = TextSize::Small.rems(cx).to_pixels(window.rem_size()) * 1.5; let line_height = TextSize::Small.rems(cx).to_pixels(window.rem_size()) * 1.5;
@@ -1691,7 +1662,7 @@ impl AgentPreview for MessageEditor {
let weak_project = project.downgrade(); let weak_project = project.downgrade();
let context_store = cx.new(|_cx| ContextStore::new(weak_project, None)); let context_store = cx.new(|_cx| ContextStore::new(weak_project, None));
let active_thread = active_thread.read(cx); let active_thread = active_thread.read(cx);
let thread = active_thread.thread().clone(); let agent = active_thread.agent().clone();
let thread_store = active_thread.thread_store().clone(); let thread_store = active_thread.thread_store().clone();
let text_thread_store = active_thread.text_thread_store().clone(); let text_thread_store = active_thread.text_thread_store().clone();
@@ -1704,7 +1675,7 @@ impl AgentPreview for MessageEditor {
None, None,
thread_store.downgrade(), thread_store.downgrade(),
text_thread_store.downgrade(), text_thread_store.downgrade(),
thread, agent,
window, window,
cx, cx,
) )

View File

@@ -1,6 +1,6 @@
use crate::{ManageProfiles, ToggleProfileSelector}; use crate::{ManageProfiles, ToggleProfileSelector};
use agent::{ use agent::{
Thread, ZedAgentThread,
agent_profile::{AgentProfile, AvailableProfiles}, agent_profile::{AgentProfile, AvailableProfiles},
}; };
use agent_settings::{AgentDockPosition, AgentProfileId, AgentSettings, builtin_profiles}; use agent_settings::{AgentDockPosition, AgentProfileId, AgentSettings, builtin_profiles};
@@ -17,7 +17,7 @@ use ui::{
pub struct ProfileSelector { pub struct ProfileSelector {
profiles: AvailableProfiles, profiles: AvailableProfiles,
fs: Arc<dyn Fs>, fs: Arc<dyn Fs>,
thread: Entity<Thread>, thread: Entity<ZedAgentThread>,
menu_handle: PopoverMenuHandle<ContextMenu>, menu_handle: PopoverMenuHandle<ContextMenu>,
focus_handle: FocusHandle, focus_handle: FocusHandle,
_subscriptions: Vec<Subscription>, _subscriptions: Vec<Subscription>,
@@ -26,7 +26,7 @@ pub struct ProfileSelector {
impl ProfileSelector { impl ProfileSelector {
pub fn new( pub fn new(
fs: Arc<dyn Fs>, fs: Arc<dyn Fs>,
thread: Entity<Thread>, thread: Entity<ZedAgentThread>,
focus_handle: FocusHandle, focus_handle: FocusHandle,
cx: &mut Context<Self>, cx: &mut Context<Self>,
) -> Self { ) -> Self {

View File

@@ -1,4 +1,4 @@
use agent::{Thread, ThreadEvent}; use agent::{ThreadEvent, ZedAgentThread};
use assistant_tool::{Tool, ToolSource}; use assistant_tool::{Tool, ToolSource};
use collections::HashMap; use collections::HashMap;
use gpui::{App, Context, Entity, IntoElement, Render, Subscription, Window}; use gpui::{App, Context, Entity, IntoElement, Render, Subscription, Window};
@@ -8,12 +8,12 @@ use ui::prelude::*;
pub struct IncompatibleToolsState { pub struct IncompatibleToolsState {
cache: HashMap<LanguageModelToolSchemaFormat, Vec<Arc<dyn Tool>>>, cache: HashMap<LanguageModelToolSchemaFormat, Vec<Arc<dyn Tool>>>,
thread: Entity<Thread>, thread: Entity<ZedAgentThread>,
_thread_subscription: Subscription, _thread_subscription: Subscription,
} }
impl IncompatibleToolsState { impl IncompatibleToolsState {
pub fn new(thread: Entity<Thread>, cx: &mut Context<Self>) -> Self { pub fn new(thread: Entity<ZedAgentThread>, cx: &mut Context<Self>) -> Self {
let _tool_working_set_subscription = let _tool_working_set_subscription =
cx.subscribe(&thread, |this, _, event, _| match event { cx.subscribe(&thread, |this, _, event, _| match event {
ThreadEvent::ProfileChanged => { ThreadEvent::ProfileChanged => {

View File

@@ -488,7 +488,7 @@ impl AddedContext {
parent: None, parent: None,
tooltip: None, tooltip: None,
icon_path: None, icon_path: None,
status: if handle.thread.read(cx).is_generating_detailed_summary() { status: if handle.agent.read(cx).is_generating_detailed_summary() {
ContextStatus::Loading { ContextStatus::Loading {
message: "Summarizing…".into(), message: "Summarizing…".into(),
} }
@@ -496,9 +496,9 @@ impl AddedContext {
ContextStatus::Ready ContextStatus::Ready
}, },
render_hover: { render_hover: {
let thread = handle.thread.clone(); let agent = handle.agent.clone();
Some(Rc::new(move |_, cx| { Some(Rc::new(move |_, cx| {
let text = thread.read(cx).latest_detailed_summary_or_text(); let text = agent.read(cx).latest_detailed_summary_or_text(cx);
ContextPillHover::new_text(text.clone(), cx).into() ContextPillHover::new_text(text.clone(), cx).into()
})) }))
}, },

View File

@@ -5,6 +5,9 @@ edition.workspace = true
publish.workspace = true publish.workspace = true
license = "GPL-3.0-or-later" license = "GPL-3.0-or-later"
[features]
test-support = []
[lints] [lints]
workspace = true workspace = true

View File

@@ -495,6 +495,10 @@ impl ActionLog {
cx.notify(); cx.notify();
} }
pub fn keep_buffer_edits(&mut self, buffer: Entity<Buffer>, cx: &mut Context<Self>) {
self.keep_edits_in_range(buffer, Anchor::MIN..Anchor::MAX, cx);
}
pub fn keep_edits_in_range( pub fn keep_edits_in_range(
&mut self, &mut self,
buffer: Entity<Buffer>, buffer: Entity<Buffer>,
@@ -555,6 +559,19 @@ impl ActionLog {
} }
} }
pub fn reject_all_edits(&mut self, cx: &mut Context<Self>) {
let changed_buffers = self.changed_buffers(cx);
for (buffer, _) in changed_buffers {
self.reject_edits_in_ranges(buffer, vec![Anchor::MIN..Anchor::MAX], cx)
.detach();
}
}
pub fn reject_buffer_edits(&mut self, buffer: Entity<Buffer>, cx: &mut Context<Self>) {
self.reject_edits_in_ranges(buffer, vec![Anchor::MIN..Anchor::MAX], cx)
.detach()
}
pub fn reject_edits_in_ranges( pub fn reject_edits_in_ranges(
&mut self, &mut self,
buffer: Entity<Buffer>, buffer: Entity<Buffer>,

View File

@@ -70,7 +70,7 @@ pub struct ToolResultOutput {
pub output: Option<serde_json::Value>, pub output: Option<serde_json::Value>,
} }
#[derive(Debug, PartialEq, Eq)] #[derive(Clone, Debug, PartialEq, Eq)]
pub enum ToolResultContent { pub enum ToolResultContent {
Text(String), Text(String),
Image(LanguageModelImage), Image(LanguageModelImage),
@@ -135,7 +135,8 @@ pub trait ToolCard: 'static + Sized {
) -> impl IntoElement; ) -> impl IntoElement;
} }
#[derive(Clone)] #[derive(Debug, Clone)]
#[cfg_attr(any(test, feature = "test-support"), derive(PartialEq, Eq))]
pub struct AnyToolCard { pub struct AnyToolCard {
entity: gpui::AnyEntity, entity: gpui::AnyEntity,
render: fn( render: fn(

View File

@@ -10,7 +10,7 @@ use crate::{
ToolMetrics, ToolMetrics,
assertions::{AssertionsReport, RanAssertion, RanAssertionResult}, assertions::{AssertionsReport, RanAssertion, RanAssertionResult},
}; };
use agent::{ContextLoadResult, Thread, ThreadEvent}; use agent::{ThreadEvent, ZedAgentThread};
use agent_settings::AgentProfileId; use agent_settings::AgentProfileId;
use anyhow::{Result, anyhow}; use anyhow::{Result, anyhow};
use async_trait::async_trait; use async_trait::async_trait;
@@ -89,7 +89,7 @@ impl Error for FailedAssertion {}
pub struct ExampleContext { pub struct ExampleContext {
meta: ExampleMetadata, meta: ExampleMetadata,
log_prefix: String, log_prefix: String,
agent_thread: Entity<agent::Thread>, agent_thread: Entity<agent::ZedAgentThread>,
app: AsyncApp, app: AsyncApp,
model: Arc<dyn LanguageModel>, model: Arc<dyn LanguageModel>,
pub assertions: AssertionsReport, pub assertions: AssertionsReport,
@@ -100,7 +100,7 @@ impl ExampleContext {
pub fn new( pub fn new(
meta: ExampleMetadata, meta: ExampleMetadata,
log_prefix: String, log_prefix: String,
agent_thread: Entity<Thread>, agent_thread: Entity<ZedAgentThread>,
model: Arc<dyn LanguageModel>, model: Arc<dyn LanguageModel>,
app: AsyncApp, app: AsyncApp,
) -> Self { ) -> Self {
@@ -120,13 +120,7 @@ impl ExampleContext {
pub fn push_user_message(&mut self, text: impl ToString) { pub fn push_user_message(&mut self, text: impl ToString) {
self.app self.app
.update_entity(&self.agent_thread, |thread, cx| { .update_entity(&self.agent_thread, |thread, cx| {
thread.insert_user_message( thread.insert_user_message(text.to_string(), cx);
text.to_string(),
ContextLoadResult::default(),
None,
Vec::new(),
cx,
);
}) })
.unwrap(); .unwrap();
} }
@@ -250,6 +244,7 @@ impl ExampleContext {
| ThreadEvent::UsePendingTools { .. } | ThreadEvent::UsePendingTools { .. }
| ThreadEvent::CompletionCanceled => {} | ThreadEvent::CompletionCanceled => {}
ThreadEvent::ToolUseLimitReached => {} ThreadEvent::ToolUseLimitReached => {}
ThreadEvent::StreamedToolUse2 { .. } => {}
ThreadEvent::ToolFinished { ThreadEvent::ToolFinished {
tool_use_id, tool_use_id,
pending_tool_use, pending_tool_use,
@@ -312,10 +307,10 @@ impl ExampleContext {
let model = self.model.clone(); let model = self.model.clone();
let message_count_before = self.app.update_entity(&self.agent_thread, |thread, cx| { let message_count_before = self.app.update_entity(&self.agent_thread, |agent, cx| {
thread.set_remaining_turns(iterations); agent.set_remaining_turns(iterations);
thread.send_to_model(model, CompletionIntent::UserPrompt, None, cx); agent.send_to_model(model, CompletionIntent::UserPrompt, None, cx);
thread.messages().len() agent.messages().len()
})?; })?;
loop { loop {
@@ -333,13 +328,13 @@ impl ExampleContext {
} }
} }
let messages = self.app.read_entity(&self.agent_thread, |thread, cx| { let messages = self.app.read_entity(&self.agent_thread, |agent, cx| {
let mut messages = Vec::new(); let mut messages = Vec::new();
for message in thread.messages().skip(message_count_before) { for message in agent.messages().skip(message_count_before) {
messages.push(Message { messages.push(Message {
_role: message.role, _role: message.role,
text: message.to_string(), text: message.to_string(),
tool_use: thread tool_use: agent
.tool_uses_for_message(message.id, cx) .tool_uses_for_message(message.id, cx)
.into_iter() .into_iter()
.map(|tool_use| ToolUse { .map(|tool_use| ToolUse {
@@ -387,7 +382,7 @@ impl ExampleContext {
.unwrap() .unwrap()
} }
pub fn agent_thread(&self) -> Entity<Thread> { pub fn agent_thread(&self) -> Entity<ZedAgentThread> {
self.agent_thread.clone() self.agent_thread.clone()
} }
} }

View File

@@ -32,9 +32,9 @@ impl Example for CommentTranslation {
cx.run_to_end().await?; cx.run_to_end().await?;
let mut create_or_overwrite_count = 0; let mut create_or_overwrite_count = 0;
cx.agent_thread().read_with(cx, |thread, cx| { cx.agent_thread().read_with(cx, |agent, cx| {
for message in thread.messages() { for message in agent.messages() {
for tool_use in thread.tool_uses_for_message(message.id, cx) { for tool_use in agent.tool_uses_for_message(message.id, cx) {
if tool_use.name == "edit_file" { if tool_use.name == "edit_file" {
let input: EditFileToolInput = serde_json::from_value(tool_use.input)?; let input: EditFileToolInput = serde_json::from_value(tool_use.input)?;
if !matches!(input.mode, EditFileMode::Edit) { if !matches!(input.mode, EditFileMode::Edit) {

View File

@@ -1,3 +1,4 @@
use agent::thread::ToolUseSegment;
use agent::{Message, MessageSegment, SerializedThread, ThreadStore}; use agent::{Message, MessageSegment, SerializedThread, ThreadStore};
use anyhow::{Context as _, Result, anyhow, bail}; use anyhow::{Context as _, Result, anyhow, bail};
use assistant_tool::ToolWorkingSet; use assistant_tool::ToolWorkingSet;
@@ -307,7 +308,7 @@ impl ExampleInstance {
let thread_store = thread_store.await?; let thread_store = thread_store.await?;
let thread = let agent =
thread_store.update(cx, |thread_store, cx| { thread_store.update(cx, |thread_store, cx| {
let thread = if let Some(json) = &meta.existing_thread_json { let thread = if let Some(json) = &meta.existing_thread_json {
let serialized = SerializedThread::from_json(json.as_bytes()).expect("Can't read serialized thread"); let serialized = SerializedThread::from_json(json.as_bytes()).expect("Can't read serialized thread");
@@ -322,7 +323,7 @@ impl ExampleInstance {
})?; })?;
thread.update(cx, |thread, _cx| { agent.update(cx, |thread, _cx| {
let mut request_count = 0; let mut request_count = 0;
let previous_diff = Rc::new(RefCell::new("".to_string())); let previous_diff = Rc::new(RefCell::new("".to_string()));
let example_output_dir = this.run_directory.clone(); let example_output_dir = this.run_directory.clone();
@@ -370,7 +371,7 @@ impl ExampleInstance {
let mut example_cx = ExampleContext::new( let mut example_cx = ExampleContext::new(
meta.clone(), meta.clone(),
this.log_prefix.clone(), this.log_prefix.clone(),
thread.clone(), agent.clone(),
model.clone(), model.clone(),
cx.clone(), cx.clone(),
); );
@@ -419,11 +420,12 @@ impl ExampleInstance {
fs::write(this.run_directory.join("diagnostics_after.txt"), diagnostics_after)?; fs::write(this.run_directory.join("diagnostics_after.txt"), diagnostics_after)?;
} }
thread.update(cx, |thread, _cx| { agent.update(cx, |agent, _cx| {
let response_count = thread let response_count = agent
.messages() .messages()
.filter(|message| message.role == language_model::Role::Assistant) .filter(|message| message.role == language_model::Role::Assistant)
.count(); .count();
let all_messages = messages_to_markdown(agent.messages());
RunOutput { RunOutput {
repository_diff, repository_diff,
diagnostic_summary_before, diagnostic_summary_before,
@@ -431,9 +433,9 @@ impl ExampleInstance {
diagnostics_before, diagnostics_before,
diagnostics_after, diagnostics_after,
response_count, response_count,
token_usage: thread.cumulative_token_usage(), token_usage: agent.cumulative_token_usage(),
tool_metrics: example_cx.tool_metrics.lock().unwrap().clone(), tool_metrics: example_cx.tool_metrics.lock().unwrap().clone(),
all_messages: messages_to_markdown(thread.messages()), all_messages,
programmatic_assertions: example_cx.assertions, programmatic_assertions: example_cx.assertions,
} }
}) })
@@ -848,11 +850,9 @@ fn messages_to_markdown<'a>(message_iter: impl IntoIterator<Item = &'a Message>)
messages.push_str(&text); messages.push_str(&text);
messages.push_str("\n"); messages.push_str("\n");
} }
MessageSegment::RedactedThinking(items) => { MessageSegment::ToolUse(ToolUseSegment { name, input, .. }) => {
messages.push_str(&format!( messages.push_str(&format!("**Tool Use**: {}\n\n", name));
"**Redacted Thinking**: {} item(s)\n\n", messages.push_str(&format!("Input: {:?}\n\n", input));
items.len()
));
} }
} }
} }

View File

@@ -2,7 +2,7 @@ use crate::{
AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId, LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest, LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
LanguageModelToolChoice, LanguageModelToolChoice, LanguageModelToolUse, StopReason,
}; };
use futures::{FutureExt, StreamExt, channel::mpsc, future::BoxFuture, stream::BoxStream}; use futures::{FutureExt, StreamExt, channel::mpsc, future::BoxFuture, stream::BoxStream};
use gpui::{AnyView, App, AsyncApp, Entity, Task, Window}; use gpui::{AnyView, App, AsyncApp, Entity, Task, Window};
@@ -91,7 +91,12 @@ pub struct ToolUseRequest {
#[derive(Default)] #[derive(Default)]
pub struct FakeLanguageModel { pub struct FakeLanguageModel {
current_completion_txs: Mutex<Vec<(LanguageModelRequest, mpsc::UnboundedSender<String>)>>, current_completion_txs: Mutex<
Vec<(
LanguageModelRequest,
mpsc::UnboundedSender<LanguageModelCompletionEvent>,
)>,
>,
} }
impl FakeLanguageModel { impl FakeLanguageModel {
@@ -110,7 +115,7 @@ impl FakeLanguageModel {
pub fn stream_completion_response( pub fn stream_completion_response(
&self, &self,
request: &LanguageModelRequest, request: &LanguageModelRequest,
chunk: impl Into<String>, stream: impl Into<FakeLanguageModelStream>,
) { ) {
let current_completion_txs = self.current_completion_txs.lock(); let current_completion_txs = self.current_completion_txs.lock();
let tx = current_completion_txs let tx = current_completion_txs
@@ -118,7 +123,9 @@ impl FakeLanguageModel {
.find(|(req, _)| req == request) .find(|(req, _)| req == request)
.map(|(_, tx)| tx) .map(|(_, tx)| tx)
.unwrap(); .unwrap();
tx.unbounded_send(chunk.into()).unwrap(); for event in stream.into().events {
tx.unbounded_send(event).unwrap();
}
} }
pub fn end_completion_stream(&self, request: &LanguageModelRequest) { pub fn end_completion_stream(&self, request: &LanguageModelRequest) {
@@ -127,7 +134,7 @@ impl FakeLanguageModel {
.retain(|(req, _)| req != request); .retain(|(req, _)| req != request);
} }
pub fn stream_last_completion_response(&self, chunk: impl Into<String>) { pub fn stream_last_completion_response(&self, chunk: impl Into<FakeLanguageModelStream>) {
self.stream_completion_response(self.pending_completions().last().unwrap(), chunk); self.stream_completion_response(self.pending_completions().last().unwrap(), chunk);
} }
@@ -136,6 +143,29 @@ impl FakeLanguageModel {
} }
} }
pub struct FakeLanguageModelStream {
events: Vec<LanguageModelCompletionEvent>,
}
impl<T: Into<String>> From<T> for FakeLanguageModelStream {
fn from(chunk: T) -> Self {
Self {
events: vec![LanguageModelCompletionEvent::Text(chunk.into())],
}
}
}
impl From<LanguageModelToolUse> for FakeLanguageModelStream {
fn from(tool_use: LanguageModelToolUse) -> Self {
Self {
events: vec![
LanguageModelCompletionEvent::ToolUse(tool_use),
LanguageModelCompletionEvent::Stop(StopReason::ToolUse),
],
}
}
}
impl LanguageModel for FakeLanguageModel { impl LanguageModel for FakeLanguageModel {
fn id(&self) -> LanguageModelId { fn id(&self) -> LanguageModelId {
language_model_id() language_model_id()
@@ -190,12 +220,7 @@ impl LanguageModel for FakeLanguageModel {
> { > {
let (tx, rx) = mpsc::unbounded(); let (tx, rx) = mpsc::unbounded();
self.current_completion_txs.lock().push((request, tx)); self.current_completion_txs.lock().push((request, tx));
async move { async move { Ok(rx.map(Ok).boxed()) }.boxed()
Ok(rx
.map(|text| Ok(LanguageModelCompletionEvent::Text(text)))
.boxed())
}
.boxed()
} }
fn as_fake(&self) -> &Self { fn as_fake(&self) -> &Self {

View File

@@ -330,6 +330,14 @@ impl MessageContent {
| MessageContent::Image(_) => false, | MessageContent::Image(_) => false,
} }
} }
#[cfg(any(test, feature = "test-support"))]
pub fn as_tool_result(&self) -> Option<&LanguageModelToolResult> {
match self {
MessageContent::ToolResult(tool_result) => Some(tool_result),
_ => None,
}
}
} }
impl From<String> for MessageContent { impl From<String> for MessageContent {
@@ -364,6 +372,40 @@ impl LanguageModelRequestMessage {
pub fn contents_empty(&self) -> bool { pub fn contents_empty(&self) -> bool {
self.content.iter().all(|content| content.is_empty()) self.content.iter().all(|content| content.is_empty())
} }
pub fn push(&mut self, content: MessageContent) {
if let Some(last_content) = self.content.last_mut() {
match (last_content, content) {
(MessageContent::Text(last_text), MessageContent::Text(new_text)) => {
last_text.push_str(&new_text);
}
(
MessageContent::Thinking {
text: last_text,
signature,
},
MessageContent::Thinking {
text: new_text,
signature: new_signature,
},
) => {
last_text.push_str(&new_text);
if signature.is_none() {
*signature = new_signature;
}
}
(
MessageContent::RedactedThinking(last_text),
MessageContent::RedactedThinking(new_text),
) => {
last_text.push_str(&new_text);
}
(_, content) => self.content.push(content),
}
} else {
self.content.push(content);
}
}
} }
#[derive(Debug, PartialEq, Hash, Clone, Serialize, Deserialize)] #[derive(Debug, PartialEq, Hash, Clone, Serialize, Deserialize)]