From 7340513eee4e6590fcd6c0c2e8d372d42f2387fc Mon Sep 17 00:00:00 2001 From: Agus Zubiaga Date: Fri, 18 Apr 2025 19:30:08 -0300 Subject: [PATCH] Wait for diagnostics after saving a buffer from any tool --- crates/agent/src/agent_diff.rs | 6 +- crates/assistant_tool/Cargo.toml | 1 + crates/assistant_tool/src/action_log.rs | 130 +++++++++++--- .../assistant_tools/src/code_action_tool.rs | 7 +- .../assistant_tools/src/create_file_tool.rs | 13 +- crates/assistant_tools/src/edit_file_tool.rs | 168 +----------------- crates/assistant_tools/src/rename_tool.rs | 8 +- 7 files changed, 123 insertions(+), 210 deletions(-) diff --git a/crates/agent/src/agent_diff.rs b/crates/agent/src/agent_diff.rs index 05f003d03a..50aa334868 100644 --- a/crates/agent/src/agent_diff.rs +++ b/crates/agent/src/agent_diff.rs @@ -984,8 +984,10 @@ mod tests { ) .unwrap() }); - action_log.update(cx, |log, cx| log.buffer_edited(buffer.clone(), cx)); - }); + action_log.update(cx, |log, cx| log.save_edited_buffer(buffer.clone(), cx)) + }) + .await + .unwrap(); cx.run_until_parked(); // When opening the assistant diff, the cursor is positioned on the first hunk. diff --git a/crates/assistant_tool/Cargo.toml b/crates/assistant_tool/Cargo.toml index 9a819d9d81..42f896d621 100644 --- a/crates/assistant_tool/Cargo.toml +++ b/crates/assistant_tool/Cargo.toml @@ -22,6 +22,7 @@ gpui.workspace = true icons.workspace = true language.workspace = true language_model.workspace = true +log.workspace = true parking_lot.workspace = true project.workspace = true serde.workspace = true diff --git a/crates/assistant_tool/src/action_log.rs b/crates/assistant_tool/src/action_log.rs index 0f301d37e5..6d9e45ad19 100644 --- a/crates/assistant_tool/src/action_log.rs +++ b/crates/assistant_tool/src/action_log.rs @@ -1,7 +1,10 @@ use anyhow::{Context as _, Result}; use buffer_diff::BufferDiff; use collections::{BTreeMap, HashMap, HashSet}; -use futures::{StreamExt, channel::mpsc}; +use futures::{ + FutureExt as _, StreamExt, + channel::{mpsc, oneshot}, +}; use gpui::{App, AppContext, AsyncApp, Context, Entity, Subscription, Task, WeakEntity}; use language::{Anchor, Buffer, BufferEvent, DiagnosticEntry, DiskState, Point, ToPoint}; use project::{Project, ProjectItem, ProjectPath, lsp_store::OpenLspBufferHandle}; @@ -9,6 +12,7 @@ use std::{ cmp::{self, Ordering}, ops::Range, sync::Arc, + time::Duration, }; use text::{Edit, Patch, PointUtf16, Rope, Unclipped}; use util::RangeExt; @@ -336,14 +340,22 @@ impl ActionLog { self.track_buffer(buffer, false, cx); } - /// Track a buffer as read, so we can notify the model about user edits. - pub fn will_create_buffer(&mut self, buffer: Entity, cx: &mut Context) { + /// Save and track a new buffer + pub fn save_new_buffer( + &mut self, + buffer: Entity, + cx: &mut Context, + ) -> Task> { self.track_buffer(buffer.clone(), true, cx); - self.buffer_edited(buffer, cx) + self.save_edited_buffer(buffer, cx) } - /// Mark a buffer as edited, so we can refresh it in the context - pub fn buffer_edited(&mut self, buffer: Entity, cx: &mut Context) { + /// Save and track an edited buffer + pub fn save_edited_buffer( + &mut self, + buffer: Entity, + cx: &mut Context, + ) -> Task> { self.edited_since_diagnostics_report = true; let tracked_buffer = self.track_buffer(buffer.clone(), false, cx); @@ -351,6 +363,51 @@ impl ActionLog { tracked_buffer.status = TrackedBufferStatus::Modified; } tracked_buffer.schedule_diff_update(ChangeAuthor::Agent, cx); + + let project = self.project.clone(); + + cx.spawn(async move |_this, cx| { + let (tx, mut rx) = oneshot::channel(); + let mut tx = Some(tx); + + let _subscription = cx.subscribe(&project, move |_, event, _| match event { + project::Event::DiskBasedDiagnosticsFinished { .. } => { + if let Some(tx) = tx.take() { + tx.send(()).ok(); + } + } + _ => {} + }); + + project + .update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))? + .await?; + + let has_lang_server = project.update(cx, |project, cx| { + project.lsp_store().update(cx, |lsp_store, cx| { + buffer.update(cx, |buffer, cx| { + lsp_store + .language_servers_for_local_buffer(buffer, cx) + .next() + .is_some() + }) + }) + })?; + + if has_lang_server { + let timeout = cx.background_executor().timer(Duration::from_secs(30)); + futures::select! { + _ = rx => Ok(()), + _ = timeout.fuse() => { + log::info!("Did not receive diagnostics update 30s after agent edit"); + // We don't want to fail the tool here + Ok(()) + } + } + } else { + Ok(()) + } + }) } pub fn will_delete_buffer(&mut self, buffer: Entity, cx: &mut Context) { @@ -917,8 +974,10 @@ mod tests { .edit([(Point::new(4, 2)..Point::new(4, 3), "O")], None, cx) .unwrap() }); - action_log.update(cx, |log, cx| log.buffer_edited(buffer.clone(), cx)); - }); + action_log.update(cx, |log, cx| log.save_edited_buffer(buffer.clone(), cx)) + }) + .await + .unwrap(); cx.run_until_parked(); assert_eq!( buffer.read_with(cx, |buffer, _| buffer.text()), @@ -989,8 +1048,10 @@ mod tests { .unwrap(); buffer.finalize_last_transaction(); }); - action_log.update(cx, |log, cx| log.buffer_edited(buffer.clone(), cx)); - }); + action_log.update(cx, |log, cx| log.save_edited_buffer(buffer.clone(), cx)) + }) + .await + .unwrap(); cx.run_until_parked(); assert_eq!( buffer.read_with(cx, |buffer, _| buffer.text()), @@ -1056,8 +1117,10 @@ mod tests { .edit([(Point::new(1, 2)..Point::new(2, 3), "F\nGHI")], None, cx) .unwrap() }); - action_log.update(cx, |log, cx| log.buffer_edited(buffer.clone(), cx)); - }); + action_log.update(cx, |log, cx| log.save_edited_buffer(buffer.clone(), cx)) + }) + .await + .unwrap(); cx.run_until_parked(); assert_eq!( buffer.read_with(cx, |buffer, _| buffer.text()), @@ -1152,12 +1215,10 @@ mod tests { .unwrap(); cx.update(|cx| { buffer.update(cx, |buffer, cx| buffer.set_text("lorem", cx)); - action_log.update(cx, |log, cx| log.will_create_buffer(buffer.clone(), cx)); - }); - project - .update(cx, |project, cx| project.save_buffer(buffer.clone(), cx)) - .await - .unwrap(); + action_log.update(cx, |log, cx| log.save_new_buffer(buffer.clone(), cx)) + }) + .await + .unwrap(); cx.run_until_parked(); assert_eq!( unreviewed_hunks(&action_log, cx), @@ -1274,9 +1335,8 @@ mod tests { .await .unwrap(); buffer2.update(cx, |buffer, cx| buffer.set_text("IPSUM", cx)); - action_log.update(cx, |log, cx| log.will_create_buffer(buffer2.clone(), cx)); - project - .update(cx, |project, cx| project.save_buffer(buffer2.clone(), cx)) + action_log + .update(cx, |log, cx| log.save_new_buffer(buffer2.clone(), cx)) .await .unwrap(); @@ -1330,8 +1390,11 @@ mod tests { .edit([(Point::new(5, 2)..Point::new(5, 3), "O")], None, cx) .unwrap() }); - action_log.update(cx, |log, cx| log.buffer_edited(buffer.clone(), cx)); - }); + action_log.update(cx, |log, cx| log.save_edited_buffer(buffer.clone(), cx)) + }) + .await + .unwrap(); + cx.run_until_parked(); assert_eq!( buffer.read_with(cx, |buffer, _| buffer.text()), @@ -1465,8 +1528,10 @@ mod tests { .edit([(Point::new(5, 2)..Point::new(5, 3), "O")], None, cx) .unwrap() }); - action_log.update(cx, |log, cx| log.buffer_edited(buffer.clone(), cx)); - }); + action_log.update(cx, |log, cx| log.save_edited_buffer(buffer.clone(), cx)) + }) + .await + .unwrap(); cx.run_until_parked(); assert_eq!( buffer.read_with(cx, |buffer, _| buffer.text()), @@ -1588,8 +1653,10 @@ mod tests { .unwrap(); cx.update(|cx| { buffer.update(cx, |buffer, cx| buffer.set_text("content", cx)); - action_log.update(cx, |log, cx| log.will_create_buffer(buffer.clone(), cx)); - }); + action_log.update(cx, |log, cx| log.save_new_buffer(buffer.clone(), cx)) + }) + .await + .unwrap(); project .update(cx, |project, cx| project.save_buffer(buffer.clone(), cx)) .await @@ -1713,9 +1780,14 @@ mod tests { cx.update(|cx| { buffer.update(cx, |buffer, cx| buffer.randomly_edit(&mut rng, 1, cx)); if is_agent_change { - action_log.update(cx, |log, cx| log.buffer_edited(buffer.clone(), cx)); + action_log + .update(cx, |log, cx| log.save_edited_buffer(buffer.clone(), cx)) + } else { + Task::ready(Ok(())) } - }); + }) + .await + .unwrap(); } } diff --git a/crates/assistant_tools/src/code_action_tool.rs b/crates/assistant_tools/src/code_action_tool.rs index 8c60c83b56..60a11d0e2e 100644 --- a/crates/assistant_tools/src/code_action_tool.rs +++ b/crates/assistant_tools/src/code_action_tool.rs @@ -241,13 +241,10 @@ impl Tool for CodeActionTool { format!("Completed code action: {}", title) }; - project - .update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))? - .await?; action_log.update(cx, |log, cx| { - log.buffer_edited(buffer.clone(), cx) - })?; + log.save_edited_buffer(buffer.clone(), cx) + })?.await?; Ok(response) } else { diff --git a/crates/assistant_tools/src/create_file_tool.rs b/crates/assistant_tools/src/create_file_tool.rs index dc777bfb8d..6aa7d6465f 100644 --- a/crates/assistant_tools/src/create_file_tool.rs +++ b/crates/assistant_tools/src/create_file_tool.rs @@ -97,14 +97,11 @@ impl Tool for CreateFileTool { cx.update(|cx| { buffer.update(cx, |buffer, cx| buffer.set_text(contents, cx)); action_log.update(cx, |action_log, cx| { - action_log.will_create_buffer(buffer.clone(), cx) - }); - })?; - - project - .update(cx, |project, cx| project.save_buffer(buffer, cx))? - .await - .map_err(|err| anyhow!("Unable to save buffer for {destination_path}: {err}"))?; + action_log.save_new_buffer(buffer.clone(), cx) + }) + })? + .await + .map_err(|err| anyhow!("Unable to save buffer for {destination_path}: {err}"))?; Ok(format!("Created file {destination_path}")) }) diff --git a/crates/assistant_tools/src/edit_file_tool.rs b/crates/assistant_tools/src/edit_file_tool.rs index 8f00873961..043645d831 100644 --- a/crates/assistant_tools/src/edit_file_tool.rs +++ b/crates/assistant_tools/src/edit_file_tool.rs @@ -1,15 +1,15 @@ use crate::{replace::replace_with_flexible_indent, schema::json_schema_for}; use anyhow::{Context as _, Result, anyhow}; use assistant_tool::{ActionLog, Tool, ToolResult}; -use futures::{FutureExt as _, channel::oneshot}; + use gpui::{App, AppContext, AsyncApp, Entity, Task}; -use language::{Anchor, Buffer, BufferSnapshot, DiagnosticEntry, DiagnosticSeverity}; + use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat}; use project::Project; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; -use std::fmt::Write as _; -use std::{path::PathBuf, sync::Arc, time::Duration}; + +use std::{path::PathBuf, sync::Arc}; use ui::IconName; use crate::replace::replace_exact; @@ -104,7 +104,6 @@ impl Tool for EditFileTool { .update(cx, |project, cx| project.open_buffer(project_path, cx))? .await?; - let old_diagnostics = save_buffer_and_get_project_diagnostics(&buffer, &project, cx).await; let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?; if input.old_string.is_empty() { @@ -164,13 +163,12 @@ impl Tool for EditFileTool { buffer.finalize_last_transaction(); buffer.snapshot() }); - action_log.update(cx, |log, cx| { - log.buffer_edited(buffer.clone(), cx) - }); snapshot })?; - let mut output = String::new(); + action_log.update(cx, |log, cx| { + log.save_edited_buffer(buffer.clone(), cx) + })?.await?; let diff_str = cx.background_spawn({ let snapshot = snapshot.clone(); @@ -179,158 +177,8 @@ impl Tool for EditFileTool { language::unified_diff(&old_text, &new_text) } }).await; - writeln!(&mut output, "Edited {}:\n\n```diff\n{}\n```", input.path.display(), diff_str)?; - let new_diagnostics = save_buffer_and_get_project_diagnostics(&buffer, &project, cx).await; - - if let Some((old_diagnostics, new_diagnostics)) = old_diagnostics.ok().zip(new_diagnostics.ok()) { - let diagnostics_diff = cx.background_spawn(async move { - DiagnosticDiff::new(old_diagnostics, new_diagnostics, &snapshot) - }).await; - - writeln!(&mut output, "{}", diagnostics_diff)?; - } - - Ok(output) + Ok(format!("Edited {}:\n\n```diff\n{}\n```", input.path.display(), diff_str)) }).into() } } - -async fn save_buffer_and_get_project_diagnostics( - buffer: &Entity, - project: &Entity, - cx: &mut AsyncApp, -) -> Result>> { - let (tx, mut rx) = oneshot::channel(); - let mut tx = Some(tx); - - let _subscription = cx.subscribe(&project, move |_, event, _| match event { - project::Event::DiskBasedDiagnosticsFinished { .. } => { - if let Some(tx) = tx.take() { - tx.send(()).ok(); - } - } - _ => {} - }); - - project - .update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))? - .await?; - - let has_lang_server = project.update(cx, |project, cx| { - project.lsp_store().update(cx, |lsp_store, cx| { - buffer.update(cx, |buffer, cx| { - lsp_store - .language_servers_for_local_buffer(buffer, cx) - .next() - .is_some() - }) - }) - })?; - - if has_lang_server { - let timeout = cx.background_executor().timer(Duration::from_secs(60)); - futures::select! { - _ = rx => buffer.read_with(cx, |buffer, _| buffer.snapshot().diagnostics_in_range(0..buffer.len(), false).collect()), - _ = timeout.fuse() => Err(anyhow!("LSP timeout")) - } - } else { - Ok(Vec::new()) - } -} - -struct DiagnosticDiff { - added: Vec>, - removed: Vec>, -} - -impl DiagnosticDiff { - fn new( - old: Vec>, - new: Vec>, - buffer: &BufferSnapshot, - ) -> Self { - let mut added = Vec::new(); - let mut removed = Vec::new(); - - let mut old_iter = old.into_iter().peekable(); - let mut new_iter = new.into_iter().peekable(); - - loop { - match (old_iter.peek(), new_iter.peek()) { - (Some(old_entry), Some(new_entry)) => { - match old_entry.cmp(&new_entry, buffer) { - std::cmp::Ordering::Less => { - // Old entry comes first and isn't in new - it's removed - removed.push(old_iter.next().unwrap()); - } - std::cmp::Ordering::Greater => { - // New entry comes first and isn't in old - it's added - added.push(new_iter.next().unwrap()); - } - std::cmp::Ordering::Equal => { - // They're the same - just advance both iterators - old_iter.next(); - new_iter.next(); - } - } - } - (Some(_), None) => { - // Only old entries left - they're all removed - removed.push(old_iter.next().unwrap()); - } - (None, Some(_)) => { - // Only new entries left - they're all added - added.push(new_iter.next().unwrap()); - } - (None, None) => break, - } - } - - Self { added, removed } - } -} - -impl std::fmt::Display for DiagnosticDiff { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - if self.added.is_empty() && self.removed.is_empty() { - return Ok(()); - } - - if !self.removed.is_empty() { - writeln!(f, "Fixed diagnostics:")?; - for diag in &self.removed { - writeln!( - f, - " - {}: {}", - severity_to_str(diag.diagnostic.severity), - diag.diagnostic.message - )?; - } - } - - if !self.added.is_empty() { - writeln!(f, "Introduced diagnostics:")?; - for diag in &self.added { - writeln!( - f, - " + {}: {}", - severity_to_str(diag.diagnostic.severity), - diag.diagnostic.message - )?; - } - } - - Ok(()) - } -} - -fn severity_to_str(severity: DiagnosticSeverity) -> &'static str { - match severity { - DiagnosticSeverity::ERROR => "Error", - DiagnosticSeverity::WARNING => "Warning", - DiagnosticSeverity::INFORMATION => "Info", - DiagnosticSeverity::HINT => "Hint", - _ => "Diagnostic", - } -} diff --git a/crates/assistant_tools/src/rename_tool.rs b/crates/assistant_tools/src/rename_tool.rs index a29ea02e5f..9b3c10d9c3 100644 --- a/crates/assistant_tools/src/rename_tool.rs +++ b/crates/assistant_tools/src/rename_tool.rs @@ -129,13 +129,9 @@ impl Tool for RenameTool { })? .await?; - project - .update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))? - .await?; - action_log.update(cx, |log, cx| { - log.buffer_edited(buffer.clone(), cx) - })?; + log.save_edited_buffer(buffer.clone(), cx) + })?.await?; Ok(format!("Renamed '{}' to '{}'", input.symbol, input.new_name)) }).into()