From f35fbbb78ff53a7a871e446af8a499a37084c0cd Mon Sep 17 00:00:00 2001 From: Max Brunsfeld Date: Thu, 26 Jun 2025 17:37:22 -0700 Subject: [PATCH] Move ActionLog from ZedAgent to Thread Co-authored-by: Mikayla Maki --- crates/agent/src/thread.rs | 69 +++---- crates/agent_ui/src/active_thread.rs | 4 +- crates/agent_ui/src/agent_diff.rs | 171 ++++++++++-------- crates/agent_ui/src/agent_panel.rs | 27 +-- crates/agent_ui/src/agent_ui.rs | 2 +- crates/agent_ui/src/message_editor.rs | 69 +++---- crates/agent_ui/src/tool_compatibility.rs | 2 +- crates/assistant_tool/src/action_log.rs | 17 ++ crates/eval/src/example.rs | 2 +- .../src/examples/file_change_notification.rs | 2 +- 10 files changed, 190 insertions(+), 175 deletions(-) diff --git a/crates/agent/src/thread.rs b/crates/agent/src/thread.rs index c0d04c225f..b90d9f7f69 100644 --- a/crates/agent/src/thread.rs +++ b/crates/agent/src/thread.rs @@ -432,7 +432,6 @@ pub struct ZedAgent { tool_result_cards: HashMap, tool_results: HashMap, - action_log: Entity, initial_project_snapshot: Shared>>>, request_token_usage: Vec, cumulative_token_usage: TokenUsage, @@ -460,6 +459,7 @@ pub struct Thread { checkpoints_by_message: HashMap, last_restore_checkpoint: Option, project: Entity, + action_log: Entity, updated_at: DateTime, } @@ -710,6 +710,14 @@ impl Thread { }) } + pub fn action_log(&self) -> Entity { + self.action_log.clone() + } + + pub fn project(&self) -> &Entity { + &self.project + } + fn insert_checkpoint(&mut self, checkpoint: ThreadCheckpoint, cx: &mut Context) { self.checkpoints_by_message .insert(checkpoint.message_id, checkpoint); @@ -782,7 +790,7 @@ impl ZedAgent { let (detailed_summary_tx, detailed_summary_rx) = postage::watch::channel(); let configured_model = LanguageModelRegistry::read_global(cx).default_model(); let profile_id = AgentSettings::get_global(cx).default_profile.clone(); - let thread = cx.new(|_| Thread { + let thread = cx.new(|cx| Thread { summary: ThreadSummary::Pending, pending_checkpoint: None, next_message_id: MessageId(0), @@ -791,6 +799,7 @@ impl ZedAgent { checkpoints_by_message: HashMap::default(), last_restore_checkpoint: None, updated_at: Utc::now(), + action_log: cx.new(|_| ActionLog::new(project.clone())), project: project.clone(), }); @@ -819,7 +828,6 @@ impl ZedAgent { pending_tool_uses_by_id: HashMap::default(), tool_result_cards: HashMap::default(), tool_use_metadata_by_id: HashMap::default(), - action_log: cx.new(|_| ActionLog::new(project.clone())), initial_project_snapshot: { let project_snapshot = Self::project_snapshot(project, cx); cx.foreground_executor() @@ -842,6 +850,10 @@ impl ZedAgent { } } + pub fn action_log(&self, cx: &App) -> Entity { + self.thread().read(cx).action_log().clone() + } + pub fn thread(&self) -> &Entity { &self.thread } @@ -955,7 +967,7 @@ impl ZedAgent { }) .collect(); - let thread = cx.new(|_| Thread { + let thread = cx.new(|cx| Thread { id, next_message_id, messages, @@ -964,6 +976,7 @@ impl ZedAgent { last_restore_checkpoint: None, project: project.clone(), updated_at: serialized.updated_at, + action_log: cx.new(|_| ActionLog::new(project.clone())), summary: ThreadSummary::Ready(serialized.summary), }); @@ -988,7 +1001,6 @@ impl ZedAgent { project: project.clone(), prompt_builder, tools: tools.clone(), - action_log: cx.new(|_| ActionLog::new(project.clone())), initial_project_snapshot: Task::ready(serialized.initial_project_snapshot).shared(), request_token_usage: serialized.request_token_usage, cumulative_token_usage: serialized.cumulative_token_usage, @@ -1441,11 +1453,15 @@ impl ZedAgent { cx: &mut Context, ) -> MessageId { if !loaded_context.referenced_buffers.is_empty() { - self.action_log.update(cx, |log, cx| { - for buffer in loaded_context.referenced_buffers { - log.buffer_read(buffer, cx); - } - }); + self.thread + .read(cx) + .action_log + .clone() + .update(cx, |log, cx| { + for buffer in loaded_context.referenced_buffers { + log.buffer_read(buffer, cx); + } + }); } let message_id = self.thread.update(cx, |thread, cx| { @@ -3007,7 +3023,7 @@ impl ZedAgent { input, request, self.project.clone(), - self.action_log.clone(), + self.thread.read(cx).action_log(), model, window, cx, @@ -3439,37 +3455,6 @@ impl ZedAgent { Ok(String::from_utf8_lossy(&markdown).to_string()) } - pub fn keep_edits_in_range( - &mut self, - buffer: Entity, - buffer_range: Range, - cx: &mut Context, - ) { - self.action_log.update(cx, |action_log, cx| { - action_log.keep_edits_in_range(buffer, buffer_range, cx) - }); - } - - pub fn keep_all_edits(&mut self, cx: &mut Context) { - self.action_log - .update(cx, |action_log, cx| action_log.keep_all_edits(cx)); - } - - pub fn reject_edits_in_ranges( - &mut self, - buffer: Entity, - buffer_ranges: Vec>, - cx: &mut Context, - ) -> Task> { - self.action_log.update(cx, |action_log, cx| { - action_log.reject_edits_in_ranges(buffer, buffer_ranges, cx) - }) - } - - pub fn action_log(&self) -> &Entity { - &self.action_log - } - pub fn project(&self) -> &Entity { &self.project } diff --git a/crates/agent_ui/src/active_thread.rs b/crates/agent_ui/src/active_thread.rs index d705ab0c54..ef27ad747f 100644 --- a/crates/agent_ui/src/active_thread.rs +++ b/crates/agent_ui/src/active_thread.rs @@ -791,7 +791,7 @@ impl ActiveThread { text_thread_store, context_store, agent: agent.clone(), - thread, + thread: thread.clone(), project, workspace, save_thread_task: None, @@ -816,7 +816,7 @@ impl ActiveThread { }; // todo! hold on to thread entity and get messages directly - for message in agent.read(cx).messages(cx).cloned().collect::>() { + for message in thread.read(cx).messages().cloned().collect::>() { let rendered_message = RenderedMessage::from_segments( &message.segments, this.language_registry.clone(), diff --git a/crates/agent_ui/src/agent_diff.rs b/crates/agent_ui/src/agent_diff.rs index 3953c9de09..ced989c051 100644 --- a/crates/agent_ui/src/agent_diff.rs +++ b/crates/agent_ui/src/agent_diff.rs @@ -2,6 +2,7 @@ use crate::{Keep, KeepAll, OpenAgentDiff, Reject, RejectAll}; use agent::{ThreadEvent, ZedAgent}; use agent_settings::AgentSettings; use anyhow::Result; +use assistant_tool::ActionLog; use buffer_diff::DiffHunkStatus; use collections::{HashMap, HashSet}; use editor::{ @@ -40,7 +41,8 @@ use zed_actions::assistant::ToggleFocus; pub struct AgentDiffPane { multibuffer: Entity, editor: Entity, - thread: Entity, + agent: Entity, + action_log: Entity, focus_handle: FocusHandle, workspace: WeakEntity, title: SharedString, @@ -49,70 +51,71 @@ pub struct AgentDiffPane { impl AgentDiffPane { pub fn deploy( - thread: Entity, + agent: Entity, workspace: WeakEntity, window: &mut Window, cx: &mut App, ) -> Result> { workspace.update(cx, |workspace, cx| { - Self::deploy_in_workspace(thread, workspace, window, cx) + Self::deploy_in_workspace(agent, workspace, window, cx) }) } pub fn deploy_in_workspace( - thread: Entity, + agent: Entity, workspace: &mut Workspace, window: &mut Window, cx: &mut Context, ) -> Entity { let existing_diff = workspace .items_of_type::(cx) - .find(|diff| diff.read(cx).thread == thread); + .find(|diff| diff.read(cx).agent == agent); if let Some(existing_diff) = existing_diff { workspace.activate_item(&existing_diff, true, true, window, cx); existing_diff } else { - let agent_diff = cx - .new(|cx| AgentDiffPane::new(thread.clone(), workspace.weak_handle(), window, cx)); + let agent_diff = + cx.new(|cx| AgentDiffPane::new(agent.clone(), workspace.weak_handle(), window, cx)); workspace.add_item_to_center(Box::new(agent_diff.clone()), window, cx); agent_diff } } pub fn new( - thread: Entity, + agent: Entity, workspace: WeakEntity, window: &mut Window, cx: &mut Context, ) -> Self { let focus_handle = cx.focus_handle(); let multibuffer = cx.new(|_| MultiBuffer::new(Capability::ReadWrite)); + let action_log = agent.read(cx).thread().read(cx).action_log(); + let project = agent.read(cx).project().clone(); - let project = thread.read(cx).project().clone(); let editor = cx.new(|cx| { let mut editor = Editor::for_multibuffer(multibuffer.clone(), Some(project.clone()), window, cx); editor.disable_inline_diagnostics(); editor.set_expand_all_diff_hunks(cx); - editor.set_render_diff_hunk_controls(diff_hunk_controls(&thread), cx); + editor.set_render_diff_hunk_controls(diff_hunk_controls(&action_log), cx); editor.register_addon(AgentDiffAddon); editor }); - let action_log = thread.read(cx).action_log().clone(); let mut this = Self { _subscriptions: vec![ cx.observe_in(&action_log, window, |this, _action_log, window, cx| { this.update_excerpts(window, cx) }), - cx.subscribe(&thread, |this, _thread, event, cx| { + cx.subscribe(&agent, |this, _thread, event, cx| { this.handle_thread_event(event, cx) }), ], title: SharedString::default(), + action_log, multibuffer, editor, - thread, + agent, focus_handle, workspace, }; @@ -122,8 +125,13 @@ impl AgentDiffPane { } fn update_excerpts(&mut self, window: &mut Window, cx: &mut Context) { - let thread = self.thread.read(cx); - let changed_buffers = thread.action_log().read(cx).changed_buffers(cx); + let agent = self.agent.read(cx); + let changed_buffers = agent + .thread() + .read(cx) + .action_log() + .read(cx) + .changed_buffers(cx); let mut paths_to_delete = self.multibuffer.read(cx).paths().collect::>(); for (buffer, diff_handle) in changed_buffers { @@ -216,7 +224,13 @@ impl AgentDiffPane { } fn update_title(&mut self, cx: &mut Context) { - let new_title = self.thread.read(cx).summary(cx).unwrap_or("Agent Changes"); + let new_title = self + .agent + .read(cx) + .thread() + .read(cx) + .summary() + .unwrap_or("Agent Changes"); if new_title != self.title { self.title = new_title; cx.emit(EditorEvent::TitleChanged); @@ -253,14 +267,14 @@ impl AgentDiffPane { fn keep(&mut self, _: &Keep, window: &mut Window, cx: &mut Context) { self.editor.update(cx, |editor, cx| { let snapshot = editor.buffer().read(cx).snapshot(cx); - keep_edits_in_selection(editor, &snapshot, &self.thread, window, cx); + keep_edits_in_selection(editor, &snapshot, &self.action_log, window, cx); }); } fn reject(&mut self, _: &Reject, window: &mut Window, cx: &mut Context) { self.editor.update(cx, |editor, cx| { let snapshot = editor.buffer().read(cx).snapshot(cx); - reject_edits_in_selection(editor, &snapshot, &self.thread, window, cx); + reject_edits_in_selection(editor, &snapshot, &self.action_log, window, cx); }); } @@ -270,7 +284,7 @@ impl AgentDiffPane { reject_edits_in_ranges( editor, &snapshot, - &self.thread, + &self.action_log, vec![editor::Anchor::min()..editor::Anchor::max()], window, cx, @@ -279,15 +293,15 @@ impl AgentDiffPane { } fn keep_all(&mut self, _: &KeepAll, _window: &mut Window, cx: &mut Context) { - self.thread - .update(cx, |thread, cx| thread.keep_all_edits(cx)); + self.action_log + .update(cx, |action_log, cx| action_log.keep_all_edits(cx)); } } fn keep_edits_in_selection( editor: &mut Editor, buffer_snapshot: &MultiBufferSnapshot, - thread: &Entity, + action_log: &Entity, window: &mut Window, cx: &mut Context, ) { @@ -296,13 +310,13 @@ fn keep_edits_in_selection( .disjoint_anchor_ranges() .collect::>(); - keep_edits_in_ranges(editor, buffer_snapshot, &thread, ranges, window, cx) + keep_edits_in_ranges(editor, buffer_snapshot, &action_log, ranges, window, cx) } fn reject_edits_in_selection( editor: &mut Editor, buffer_snapshot: &MultiBufferSnapshot, - thread: &Entity, + action_log: &Entity, window: &mut Window, cx: &mut Context, ) { @@ -310,13 +324,13 @@ fn reject_edits_in_selection( .selections .disjoint_anchor_ranges() .collect::>(); - reject_edits_in_ranges(editor, buffer_snapshot, &thread, ranges, window, cx) + reject_edits_in_ranges(editor, buffer_snapshot, &action_log, ranges, window, cx) } fn keep_edits_in_ranges( editor: &mut Editor, buffer_snapshot: &MultiBufferSnapshot, - thread: &Entity, + action_log: &Entity, ranges: Vec>, window: &mut Window, cx: &mut Context, @@ -331,8 +345,8 @@ fn keep_edits_in_ranges( for hunk in &diff_hunks_in_ranges { let buffer = multibuffer.read(cx).buffer(hunk.buffer_id); if let Some(buffer) = buffer { - thread.update(cx, |thread, cx| { - thread.keep_edits_in_range(buffer, hunk.buffer_range.clone(), cx) + action_log.update(cx, |action_log, cx| { + action_log.keep_edits_in_range(buffer, hunk.buffer_range.clone(), cx) }); } } @@ -341,7 +355,7 @@ fn keep_edits_in_ranges( fn reject_edits_in_ranges( editor: &mut Editor, buffer_snapshot: &MultiBufferSnapshot, - thread: &Entity, + action_log: &Entity, ranges: Vec>, window: &mut Window, cx: &mut Context, @@ -366,9 +380,9 @@ fn reject_edits_in_ranges( } for (buffer, ranges) in ranges_by_buffer { - thread - .update(cx, |thread, cx| { - thread.reject_edits_in_ranges(buffer, ranges, cx) + action_log + .update(cx, |action_log, cx| { + action_log.reject_edits_in_ranges(buffer, ranges, cx) }) .detach_and_log_err(cx); } @@ -466,7 +480,7 @@ impl Item for AgentDiffPane { } fn tab_content(&self, params: TabContentParams, _window: &Window, cx: &App) -> AnyElement { - let summary = self.thread.read(cx).summary(cx).unwrap_or("Agent Changes"); + let summary = self.agent.read(cx).summary(cx).or_default(); Label::new(format!("Review: {}", summary)) .color(if params.selected { Color::Default @@ -516,7 +530,7 @@ impl Item for AgentDiffPane { where Self: Sized, { - Some(cx.new(|cx| Self::new(self.thread.clone(), self.workspace.clone(), window, cx))) + Some(cx.new(|cx| Self::new(self.agent.clone(), self.workspace.clone(), window, cx))) } fn is_dirty(&self, cx: &App) -> bool { @@ -646,8 +660,8 @@ impl Render for AgentDiffPane { } } -fn diff_hunk_controls(thread: &Entity) -> editor::RenderDiffHunkControlsFn { - let thread = thread.clone(); +fn diff_hunk_controls(action_log: &Entity) -> editor::RenderDiffHunkControlsFn { + let action_log = action_log.clone(); Arc::new( move |row, @@ -665,7 +679,7 @@ fn diff_hunk_controls(thread: &Entity) -> editor::RenderDiffHunkContro hunk_range, is_created_file, line_height, - &thread, + &action_log, editor, window, cx, @@ -681,7 +695,7 @@ fn render_diff_hunk_controls( hunk_range: Range, is_created_file: bool, line_height: Pixels, - thread: &Entity, + action_log: &Entity, editor: &Entity, window: &mut Window, cx: &mut App, @@ -716,14 +730,14 @@ fn render_diff_hunk_controls( ) .on_click({ let editor = editor.clone(); - let thread = thread.clone(); + let action_log = action_log.clone(); move |_event, window, cx| { editor.update(cx, |editor, cx| { let snapshot = editor.buffer().read(cx).snapshot(cx); reject_edits_in_ranges( editor, &snapshot, - &thread, + &action_log, vec![hunk_range.start..hunk_range.start], window, cx, @@ -738,14 +752,14 @@ fn render_diff_hunk_controls( ) .on_click({ let editor = editor.clone(); - let thread = thread.clone(); + let action_log = action_log.clone(); move |_event, window, cx| { editor.update(cx, |editor, cx| { let snapshot = editor.buffer().read(cx).snapshot(cx); keep_edits_in_ranges( editor, &snapshot, - &thread, + &action_log, vec![hunk_range.start..hunk_range.start], window, cx, @@ -1119,7 +1133,7 @@ impl Render for AgentDiffToolbar { let has_pending_edit_tool_use = agent_diff .read(cx) - .thread + .agent .read(cx) .has_pending_edit_tool_uses(); @@ -1192,7 +1206,7 @@ pub enum EditorState { } struct WorkspaceThread { - thread: WeakEntity, + agent: WeakEntity, _thread_subscriptions: [Subscription; 2], singleton_editors: HashMap, HashMap, Subscription>>, _settings_subscription: Subscription, @@ -1229,11 +1243,11 @@ impl AgentDiff { fn register_active_thread_impl( &mut self, workspace: &WeakEntity, - thread: &Entity, + agent: &Entity, window: &mut Window, cx: &mut Context, ) { - let action_log = thread.read(cx).action_log().clone(); + let action_log = agent.read(cx).action_log(cx).clone(); let action_log_subscription = cx.observe_in(&action_log, window, { let workspace = workspace.clone(); @@ -1242,7 +1256,7 @@ impl AgentDiff { } }); - let thread_subscription = cx.subscribe_in(&thread, window, { + let thread_subscription = cx.subscribe_in(&agent, window, { let workspace = workspace.clone(); move |this, _thread, event, window, cx| { this.handle_thread_event(&workspace, event, window, cx) @@ -1251,7 +1265,7 @@ impl AgentDiff { if let Some(workspace_thread) = self.workspace_threads.get_mut(&workspace) { // replace thread and action log subscription, but keep editors - workspace_thread.thread = thread.downgrade(); + workspace_thread.agent = agent.downgrade(); workspace_thread._thread_subscriptions = [action_log_subscription, thread_subscription]; self.update_reviewing_editors(&workspace, window, cx); return; @@ -1276,7 +1290,7 @@ impl AgentDiff { self.workspace_threads.insert( workspace.clone(), WorkspaceThread { - thread: thread.downgrade(), + agent: agent.downgrade(), _thread_subscriptions: [action_log_subscription, thread_subscription], singleton_editors: HashMap::default(), _settings_subscription: settings_subscription, @@ -1486,11 +1500,11 @@ impl AgentDiff { return; }; - let Some(thread) = workspace_thread.thread.upgrade() else { + let Some(agent) = workspace_thread.agent.upgrade() else { return; }; - let action_log = thread.read(cx).action_log(); + let action_log = agent.read(cx).thread().read(cx).action_log(); let changed_buffers = action_log.read(cx).changed_buffers(cx); let mut unaffected = self.reviewing_editors.clone(); @@ -1515,7 +1529,7 @@ impl AgentDiff { multibuffer.add_diff(diff_handle.clone(), cx); }); - let new_state = if thread.read(cx).is_generating() { + let new_state = if agent.read(cx).is_generating() { EditorState::Generating } else { EditorState::Reviewing @@ -1528,7 +1542,7 @@ impl AgentDiff { if previous_state.is_none() { editor.update(cx, |editor, cx| { editor.start_temporary_diff_override(); - editor.set_render_diff_hunk_controls(diff_hunk_controls(&thread), cx); + editor.set_render_diff_hunk_controls(diff_hunk_controls(&action_log), cx); editor.set_expand_all_diff_hunks(cx); editor.register_addon(EditorAgentDiffAddon); }); @@ -1596,7 +1610,7 @@ impl AgentDiff { return; }; - let Some(WorkspaceThread { thread, .. }) = + let Some(WorkspaceThread { agent: thread, .. }) = self.workspace_threads.get(&workspace.downgrade()) else { return; @@ -1611,7 +1625,7 @@ impl AgentDiff { fn keep_all( editor: &Entity, - thread: &Entity, + agent: &Entity, window: &mut Window, cx: &mut App, ) -> PostReviewState { @@ -1620,7 +1634,7 @@ impl AgentDiff { keep_edits_in_ranges( editor, &snapshot, - thread, + &agent.read(cx).action_log(cx), vec![editor::Anchor::min()..editor::Anchor::max()], window, cx, @@ -1640,7 +1654,7 @@ impl AgentDiff { reject_edits_in_ranges( editor, &snapshot, - thread, + &thread.read(cx).action_log(cx), vec![editor::Anchor::min()..editor::Anchor::max()], window, cx, @@ -1651,26 +1665,38 @@ impl AgentDiff { fn keep( editor: &Entity, - thread: &Entity, + agent: &Entity, window: &mut Window, cx: &mut App, ) -> PostReviewState { editor.update(cx, |editor, cx| { let snapshot = editor.buffer().read(cx).snapshot(cx); - keep_edits_in_selection(editor, &snapshot, thread, window, cx); + keep_edits_in_selection( + editor, + &snapshot, + &agent.read(cx).action_log(cx), + window, + cx, + ); Self::post_review_state(&snapshot) }) } fn reject( editor: &Entity, - thread: &Entity, + agent: &Entity, window: &mut Window, cx: &mut App, ) -> PostReviewState { editor.update(cx, |editor, cx| { let snapshot = editor.buffer().read(cx).snapshot(cx); - reject_edits_in_selection(editor, &snapshot, thread, window, cx); + reject_edits_in_selection( + editor, + &snapshot, + &agent.read(cx).action_log(cx), + window, + cx, + ); Self::post_review_state(&snapshot) }) } @@ -1701,14 +1727,13 @@ impl AgentDiff { return None; } - let WorkspaceThread { thread, .. } = - self.workspace_threads.get(&workspace.weak_handle())?; + let WorkspaceThread { agent, .. } = self.workspace_threads.get(&workspace.weak_handle())?; - let thread = thread.upgrade()?; + let agent = agent.upgrade()?; - if let PostReviewState::AllReviewed = review(&editor, &thread, window, cx) { + if let PostReviewState::AllReviewed = review(&editor, &agent, window, cx) { if let Some(curr_buffer) = editor.read(cx).buffer().read(cx).as_singleton() { - let changed_buffers = thread.read(cx).action_log().read(cx).changed_buffers(cx); + let changed_buffers = agent.read(cx).action_log(cx).read(cx).changed_buffers(cx); let mut keys = changed_buffers.keys().cycle(); keys.find(|k| *k == &curr_buffer); @@ -1806,13 +1831,14 @@ mod tests { }) .await .unwrap(); - let thread = thread_store.update(cx, |store, cx| store.create_thread(cx)); + let agent = thread_store.update(cx, |store, cx| store.create_thread(cx)); + let thread = agent.read_with(cx, |agent, _| agent.thread().clone()); let action_log = thread.read_with(cx, |thread, _| thread.action_log().clone()); let (workspace, cx) = cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx)); let agent_diff = cx.new_window_entity(|window, cx| { - AgentDiffPane::new(thread.clone(), workspace.downgrade(), window, cx) + AgentDiffPane::new(agent.clone(), workspace.downgrade(), window, cx) }); let editor = agent_diff.read_with(cx, |diff, _cx| diff.editor.clone()); @@ -1900,7 +1926,7 @@ mod tests { keep_edits_in_ranges( editor, &snapshot, - &thread, + &thread.read(cx).action_log(), vec![position..position], window, cx, @@ -1971,7 +1997,8 @@ mod tests { }) .await .unwrap(); - let thread = thread_store.update(cx, |store, cx| store.create_thread(cx)); + let agent = thread_store.update(cx, |store, cx| store.create_thread(cx)); + let thread = agent.read_with(cx, |agent, _| agent.thread().clone()); let action_log = thread.read_with(cx, |thread, _| thread.action_log().clone()); let (workspace, cx) = @@ -1994,7 +2021,7 @@ mod tests { // Set the active thread cx.update(|window, cx| { - AgentDiff::set_active_thread(&workspace.downgrade(), &thread, window, cx) + AgentDiff::set_active_thread(&workspace.downgrade(), &agent, window, cx) }); let buffer1 = project @@ -2151,7 +2178,7 @@ mod tests { keep_edits_in_ranges( editor, &snapshot, - &thread, + &thread.read(cx).action_log(), vec![position..position], window, cx, diff --git a/crates/agent_ui/src/agent_panel.rs b/crates/agent_ui/src/agent_panel.rs index cd0c863d38..0468445a7a 100644 --- a/crates/agent_ui/src/agent_panel.rs +++ b/crates/agent_ui/src/agent_panel.rs @@ -524,7 +524,8 @@ impl AgentPanel { window: &mut Window, cx: &mut Context, ) -> Self { - let thread = thread_store.update(cx, |this, cx| this.create_thread(cx)); + let agent = thread_store.update(cx, |this, cx| this.create_thread(cx)); + let thread = agent.read(cx).thread().clone(); let fs = workspace.app_state().fs.clone(); let user_store = workspace.app_state().user_store.clone(); let project = workspace.project(); @@ -546,13 +547,13 @@ impl AgentPanel { prompt_store.clone(), thread_store.downgrade(), context_store.downgrade(), - thread.clone(), + agent.clone(), window, cx, ) }); - let thread_id = thread.read(cx).id(cx).clone(); + let thread_id = thread.read(cx).id().clone(); let history_store = cx.new(|cx| { HistoryStore::new( thread_store.clone(), @@ -566,7 +567,7 @@ impl AgentPanel { let active_thread = cx.new(|cx| { ActiveThread::new( - thread.clone(), + agent.clone(), thread_store.clone(), context_store.clone(), message_editor_context_store.clone(), @@ -607,7 +608,7 @@ impl AgentPanel { } }; - AgentDiff::set_active_thread(&workspace, &thread, window, cx); + AgentDiff::set_active_thread(&workspace, &agent, window, cx); let weak_panel = weak_self.clone(); @@ -753,7 +754,7 @@ impl AgentPanel { None }; - let thread = self + let agent = self .thread_store .update(cx, |this, cx| this.create_thread(cx)); @@ -786,7 +787,7 @@ impl AgentPanel { let active_thread = cx.new(|cx| { ActiveThread::new( - thread.clone(), + agent.clone(), self.thread_store.clone(), self.context_store.clone(), context_store.clone(), @@ -806,7 +807,7 @@ impl AgentPanel { self.prompt_store.clone(), self.thread_store.downgrade(), self.context_store.downgrade(), - thread.clone(), + agent.clone(), window, cx, ) @@ -823,7 +824,7 @@ impl AgentPanel { let thread_view = ActiveView::thread(active_thread.clone(), message_editor, window, cx); self.set_active_view(thread_view, window, cx); - AgentDiff::set_active_thread(&self.workspace, &thread, window, cx); + AgentDiff::set_active_thread(&self.workspace, &agent, window, cx); } fn new_prompt_editor(&mut self, window: &mut Window, cx: &mut Context) { @@ -971,7 +972,7 @@ impl AgentPanel { pub(crate) fn open_thread( &mut self, - thread: Entity, + agent: Entity, window: &mut Window, cx: &mut Context, ) { @@ -984,7 +985,7 @@ impl AgentPanel { let active_thread = cx.new(|cx| { ActiveThread::new( - thread.clone(), + agent.clone(), self.thread_store.clone(), self.context_store.clone(), context_store.clone(), @@ -1003,7 +1004,7 @@ impl AgentPanel { self.prompt_store.clone(), self.thread_store.downgrade(), self.context_store.downgrade(), - thread.clone(), + agent.clone(), window, cx, ) @@ -1012,7 +1013,7 @@ impl AgentPanel { let thread_view = ActiveView::thread(active_thread.clone(), message_editor, window, cx); self.set_active_view(thread_view, window, cx); - AgentDiff::set_active_thread(&self.workspace, &thread, window, cx); + AgentDiff::set_active_thread(&self.workspace, &agent, window, cx); } pub fn go_back(&mut self, _: &workspace::GoBack, window: &mut Window, cx: &mut Context) { diff --git a/crates/agent_ui/src/agent_ui.rs b/crates/agent_ui/src/agent_ui.rs index eef467fbe1..fd502f03e2 100644 --- a/crates/agent_ui/src/agent_ui.rs +++ b/crates/agent_ui/src/agent_ui.rs @@ -26,7 +26,7 @@ mod ui; use std::sync::Arc; -use agent::{ZedAgent, ThreadId}; +use agent::{ThreadId, ZedAgent}; use agent_settings::{AgentProfileId, AgentSettings, LanguageModelSelection}; use assistant_slash_command::SlashCommandRegistry; use client::Client; diff --git a/crates/agent_ui/src/message_editor.rs b/crates/agent_ui/src/message_editor.rs index 1288a74900..25eae796ab 100644 --- a/crates/agent_ui/src/message_editor.rs +++ b/crates/agent_ui/src/message_editor.rs @@ -12,6 +12,7 @@ use crate::ui::{ use agent::{ context::{AgentContextKey, ContextLoadResult, load_context}, context_store::ContextStoreEvent, + thread::Thread, }; use agent_settings::{AgentSettings, CompletionMode}; use buffer_diff::BufferDiff; @@ -31,7 +32,7 @@ use gpui::{ Animation, AnimationExt, App, Entity, EventEmitter, Focusable, Subscription, Task, TextStyle, WeakEntity, linear_color_stop, linear_gradient, point, pulsating_between, }; -use language::{Buffer, Language, Point}; +use language::{Buffer, Language}; use language_model::{ ConfiguredModel, LanguageModelRequestMessage, MessageContent, ZED_CLOUD_PROVIDER_ID, }; @@ -66,6 +67,7 @@ use agent::{ #[derive(RegisterComponent)] pub struct MessageEditor { agent: Entity, + thread: Entity, incompatible_tools_state: Entity, editor: Entity, workspace: WeakEntity, @@ -156,10 +158,11 @@ impl MessageEditor { prompt_store: Option>, thread_store: WeakEntity, text_thread_store: WeakEntity, - thread: Entity, + agent: Entity, window: &mut Window, cx: &mut Context, ) -> Self { + let thread = agent.read(cx).thread().clone(); let context_picker_menu_handle = PopoverMenuHandle::default(); let model_selector_menu_handle = PopoverMenuHandle::default(); @@ -182,13 +185,13 @@ impl MessageEditor { Some(text_thread_store.clone()), context_picker_menu_handle.clone(), SuggestContextKind::File, - ModelUsageContext::Thread(thread.clone()), + ModelUsageContext::Thread(agent.clone()), window, cx, ) }); - let incompatible_tools = cx.new(|cx| IncompatibleToolsState::new(thread.clone(), cx)); + let incompatible_tools = cx.new(|cx| IncompatibleToolsState::new(agent.clone(), cx)); let subscriptions = vec![ cx.subscribe_in(&context_strip, window, Self::handle_context_strip_event), @@ -210,20 +213,21 @@ impl MessageEditor { fs.clone(), model_selector_menu_handle, editor.focus_handle(cx), - ModelUsageContext::Thread(thread.clone()), + ModelUsageContext::Thread(agent.clone()), window, cx, ) }); let profile_selector = - cx.new(|cx| ProfileSelector::new(fs, thread.clone(), editor.focus_handle(cx), cx)); + cx.new(|cx| ProfileSelector::new(fs, agent.clone(), editor.focus_handle(cx), cx)); Self { editor: editor.clone(), - project: thread.read(cx).project().clone(), + project: agent.read(cx).project().clone(), user_store, - agent: thread, + agent, + thread, incompatible_tools_state: incompatible_tools.clone(), workspace, context_store, @@ -503,9 +507,8 @@ impl MessageEditor { return; } - self.agent.update(cx, |thread, cx| { - thread.keep_all_edits(cx); - }); + let action_log = self.thread.read(cx).action_log(); + action_log.update(cx, |action_log, cx| action_log.keep_all_edits(cx)); cx.notify(); } @@ -514,21 +517,8 @@ impl MessageEditor { return; } - // Since there's no reject_all_edits method in the thread API, - // we need to iterate through all buffers and reject their edits - let action_log = self.agent.read(cx).action_log().clone(); - let changed_buffers = action_log.read(cx).changed_buffers(cx); - - for (buffer, _) in changed_buffers { - self.agent.update(cx, |thread, cx| { - let buffer_snapshot = buffer.read(cx); - let start = buffer_snapshot.anchor_before(Point::new(0, 0)); - let end = buffer_snapshot.anchor_after(buffer_snapshot.max_point()); - thread - .reject_edits_in_ranges(buffer, vec![start..end], cx) - .detach(); - }); - } + let action_log = self.thread.read(cx).action_log(); + action_log.update(cx, |action_log, cx| action_log.reject_all_edits(cx)); cx.notify(); } @@ -542,13 +532,9 @@ impl MessageEditor { return; } - self.agent.update(cx, |thread, cx| { - let buffer_snapshot = buffer.read(cx); - let start = buffer_snapshot.anchor_before(Point::new(0, 0)); - let end = buffer_snapshot.anchor_after(buffer_snapshot.max_point()); - thread - .reject_edits_in_ranges(buffer, vec![start..end], cx) - .detach(); + let action_log = self.thread.read(cx).action_log(); + action_log.update(cx, |action_log, cx| { + action_log.reject_buffer_edits(buffer, cx) }); cx.notify(); } @@ -563,11 +549,9 @@ impl MessageEditor { return; } - self.agent.update(cx, |thread, cx| { - let buffer_snapshot = buffer.read(cx); - let start = buffer_snapshot.anchor_before(Point::new(0, 0)); - let end = buffer_snapshot.anchor_after(buffer_snapshot.max_point()); - thread.keep_edits_in_range(buffer, start..end, cx); + let action_log = self.thread.read(cx).action_log(); + action_log.update(cx, |action_log, cx| { + action_log.keep_buffer_edits(buffer, cx) }); cx.notify(); } @@ -1600,16 +1584,17 @@ impl Focusable for MessageEditor { impl Render for MessageEditor { fn render(&mut self, window: &mut Window, cx: &mut Context) -> impl IntoElement { - let thread = self.agent.read(cx); - let token_usage_ratio = thread + let agent = self.agent.read(cx); + let thread = agent.thread(); + let token_usage_ratio = agent .total_token_usage(cx) .map_or(TokenUsageRatio::Normal, |total_token_usage| { total_token_usage.ratio() }); - let burn_mode_enabled = thread.completion_mode() == CompletionMode::Burn; + let burn_mode_enabled = agent.completion_mode() == CompletionMode::Burn; - let action_log = self.agent.read(cx).action_log(); + let action_log = thread.read(cx).action_log(); let changed_buffers = action_log.read(cx).changed_buffers(cx); let line_height = TextSize::Small.rems(cx).to_pixels(window.rem_size()) * 1.5; diff --git a/crates/agent_ui/src/tool_compatibility.rs b/crates/agent_ui/src/tool_compatibility.rs index aa29b01dc7..0e25cb38ac 100644 --- a/crates/agent_ui/src/tool_compatibility.rs +++ b/crates/agent_ui/src/tool_compatibility.rs @@ -1,4 +1,4 @@ -use agent::{ZedAgent, ThreadEvent}; +use agent::{ThreadEvent, ZedAgent}; use assistant_tool::{Tool, ToolSource}; use collections::HashMap; use gpui::{App, Context, Entity, IntoElement, Render, Subscription, Window}; diff --git a/crates/assistant_tool/src/action_log.rs b/crates/assistant_tool/src/action_log.rs index 0877f18060..5519a51b91 100644 --- a/crates/assistant_tool/src/action_log.rs +++ b/crates/assistant_tool/src/action_log.rs @@ -495,6 +495,10 @@ impl ActionLog { cx.notify(); } + pub fn keep_buffer_edits(&mut self, buffer: Entity, cx: &mut Context) { + self.keep_edits_in_range(buffer, Anchor::MIN..Anchor::MAX, cx); + } + pub fn keep_edits_in_range( &mut self, buffer: Entity, @@ -555,6 +559,19 @@ impl ActionLog { } } + pub fn reject_all_edits(&mut self, cx: &mut Context) { + let changed_buffers = self.changed_buffers(cx); + for (buffer, _) in changed_buffers { + self.reject_edits_in_ranges(buffer, vec![Anchor::MIN..Anchor::MAX], cx) + .detach(); + } + } + + pub fn reject_buffer_edits(&mut self, buffer: Entity, cx: &mut Context) { + self.reject_edits_in_ranges(buffer, vec![Anchor::MIN..Anchor::MAX], cx) + .detach() + } + pub fn reject_edits_in_ranges( &mut self, buffer: Entity, diff --git a/crates/eval/src/example.rs b/crates/eval/src/example.rs index 27bce6a6f3..9c1575187a 100644 --- a/crates/eval/src/example.rs +++ b/crates/eval/src/example.rs @@ -360,7 +360,7 @@ impl ExampleContext { pub fn edits(&self) -> HashMap, FileEdits> { self.agent_thread .read_with(&self.app, |thread, cx| { - let action_log = thread.action_log().read(cx); + let action_log = thread.action_log(cx).read(cx); HashMap::from_iter(action_log.changed_buffers(cx).into_iter().map( |(buffer, diff)| { let snapshot = buffer.read(cx).snapshot(); diff --git a/crates/eval/src/examples/file_change_notification.rs b/crates/eval/src/examples/file_change_notification.rs index 0e4f770a67..acc4c742bb 100644 --- a/crates/eval/src/examples/file_change_notification.rs +++ b/crates/eval/src/examples/file_change_notification.rs @@ -42,7 +42,7 @@ impl Example for FileChangeNotificationExample { }; cx.agent_thread().update(cx, |thread, cx| { - thread.action_log().update(cx, |action_log, cx| { + thread.action_log(cx).update(cx, |action_log, cx| { action_log.buffer_read(buffer.clone(), cx); }); })?;