Compare commits

...

8 Commits

Author SHA1 Message Date
Richard Feldman
ef60a60e11 After migration, rename the heed db to delete in a way we can get back
Co-authored-by: Agus Zubiaga <hi@aguz.me>
2025-05-28 15:09:06 -04:00
Richard Feldman
5d2ac968d8 Add a bunch of tests
Co-authored-by: Agus Zubiaga <hi@aguz.me>
2025-05-28 15:05:32 -04:00
Richard Feldman
e9d4b8766f Initial LMDB -> SQLite changes
Co-authored-by: Agus Zubiaga <hi@aguz.me>
2025-05-28 14:11:05 -04:00
Richard Feldman
6812872d1a wip 2025-05-28 13:43:15 -04:00
Richard Feldman
2aebeb067c Don't save buffer until *after* autoformatting completes.
Co-authored-by: Agus Zubiaga <hi@aguz.me>
2025-05-28 13:22:55 -04:00
Richard Feldman
7dfd5d1963 Reproduce format_on_save incorrectly marking buffers as stale for LLM
Co-authored-by: Agus Zubiaga <hi@aguz.me>
2025-05-28 13:10:34 -04:00
Richard Feldman
a677b891a1 Do a log::error on an unexpected role, not debug_panic! 2025-05-28 12:16:56 -04:00
Richard Feldman
a2cb480244 Apply format_on_save to edits made with the edit tool 2025-05-28 10:58:02 -04:00
6 changed files with 1066 additions and 106 deletions

2
Cargo.lock generated
View File

@@ -117,6 +117,7 @@ dependencies = [
"streaming_diff",
"telemetry",
"telemetry_events",
"tempfile",
"terminal",
"terminal_view",
"text",
@@ -683,6 +684,7 @@ dependencies = [
"language_model",
"language_models",
"log",
"lsp",
"markdown",
"open",
"paths",

View File

@@ -107,3 +107,4 @@ language = { workspace = true, "features" = ["test-support"] }
language_model = { workspace = true, "features" = ["test-support"] }
project = { workspace = true, features = ["test-support"] }
rand.workspace = true
tempfile.workspace = true

View File

@@ -1,6 +1,6 @@
use std::borrow::Cow;
use std::cell::{Ref, RefCell};
use std::path::{Path, PathBuf};
use std::path::Path;
use std::rc::Rc;
use std::sync::Arc;
@@ -10,6 +10,10 @@ use assistant_tool::{ToolId, ToolSource, ToolWorkingSet};
use chrono::{DateTime, Utc};
use collections::HashMap;
use context_server::ContextServerId;
use db::sqlez::bindable::Column;
use db::sqlez::statement::Statement;
use db::sqlez_macros::sql;
use db::{define_connection, query};
use futures::channel::{mpsc, oneshot};
use futures::future::{self, BoxFuture, Shared};
use futures::{FutureExt as _, StreamExt as _};
@@ -17,8 +21,7 @@ use gpui::{
App, BackgroundExecutor, Context, Entity, EventEmitter, Global, ReadGlobal, SharedString,
Subscription, Task, prelude::*,
};
use heed::Database;
use heed::types::SerdeBincode;
use heed;
use language_model::{LanguageModelToolResultContent, LanguageModelToolUseId, Role, TokenUsage};
use project::context_server_store::{ContextServerStatus, ContextServerStore};
use project::{Project, ProjectItem, ProjectPath, Worktree};
@@ -36,6 +39,31 @@ use crate::thread::{
DetailedSummaryState, ExceededWindowError, MessageId, ProjectSnapshot, Thread, ThreadId,
};
// Implement Bind trait for ThreadId to use in SQL queries
// impl db::sqlez::bindable::Bind for ThreadId {
// fn bind(&self, statement: &Statement, start_index: i32) -> anyhow::Result<i32> {
// self.to_string().bind(statement, start_index)
// }
// }
// Implement Column trait for SerializedThreadMetadata
impl Column for SerializedThreadMetadata {
fn column(statement: &mut Statement, start_index: i32) -> Result<(Self, i32)> {
let (id_str, next_index): (String, i32) = Column::column(statement, start_index)?;
let (summary, next_index): (String, i32) = Column::column(statement, next_index)?;
let (updated_at_timestamp, next_index): (i64, i32) = Column::column(statement, next_index)?;
Ok((
Self {
id: ThreadId::from(id_str.as_str()),
summary: summary.into(),
updated_at: DateTime::from_timestamp(updated_at_timestamp, 0).unwrap_or_default(),
},
next_index,
))
}
}
const RULES_FILE_NAMES: [&'static str; 6] = [
".rules",
".cursorrules",
@@ -657,6 +685,7 @@ pub struct SerializedThreadMetadata {
}
#[derive(Serialize, Deserialize, Debug)]
#[cfg_attr(test, derive(Clone))]
pub struct SerializedThread {
pub version: String,
pub summary: SharedString,
@@ -681,6 +710,7 @@ pub struct SerializedThread {
}
#[derive(Serialize, Deserialize, Debug)]
#[cfg_attr(test, derive(Clone))]
pub struct SerializedLanguageModel {
pub provider: String,
pub model: String,
@@ -745,7 +775,7 @@ impl SerializedThreadV0_1_0 {
}
}
#[derive(Debug, Serialize, Deserialize)]
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct SerializedMessage {
pub id: MessageId,
pub role: Role,
@@ -763,7 +793,7 @@ pub struct SerializedMessage {
pub is_hidden: bool,
}
#[derive(Debug, Serialize, Deserialize)]
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
#[serde(tag = "type")]
pub enum SerializedMessageSegment {
#[serde(rename = "text")]
@@ -781,14 +811,14 @@ pub enum SerializedMessageSegment {
},
}
#[derive(Debug, Serialize, Deserialize)]
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct SerializedToolUse {
pub id: LanguageModelToolUseId,
pub name: SharedString,
pub input: serde_json::Value,
}
#[derive(Debug, Serialize, Deserialize)]
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct SerializedToolResult {
pub tool_use_id: LanguageModelToolUseId,
pub is_error: bool,
@@ -850,7 +880,7 @@ impl LegacySerializedMessage {
}
}
#[derive(Debug, Serialize, Deserialize)]
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct SerializedCrease {
pub start: usize,
pub end: usize,
@@ -866,26 +896,6 @@ impl Global for GlobalThreadsDatabase {}
pub(crate) struct ThreadsDatabase {
executor: BackgroundExecutor,
env: heed::Env,
threads: Database<SerdeBincode<ThreadId>, SerializedThread>,
}
impl heed::BytesEncode<'_> for SerializedThread {
type EItem = SerializedThread;
fn bytes_encode(item: &Self::EItem) -> Result<Cow<[u8]>, heed::BoxedError> {
serde_json::to_vec(item).map(Cow::Owned).map_err(Into::into)
}
}
impl<'a> heed::BytesDecode<'a> for SerializedThread {
type DItem = SerializedThread;
fn bytes_decode(bytes: &'a [u8]) -> Result<Self::DItem, heed::BoxedError> {
// We implement this type manually because we want to call `SerializedThread::from_json`,
// instead of the Deserialize trait implementation for `SerializedThread`.
SerializedThread::from_json(bytes).map_err(Into::into)
}
}
impl ThreadsDatabase {
@@ -900,8 +910,7 @@ impl ThreadsDatabase {
let database_future = executor
.spawn({
let executor = executor.clone();
let database_path = paths::data_dir().join("threads/threads-db.1.mdb");
async move { ThreadsDatabase::new(database_path, executor) }
async move { ThreadsDatabase::new(executor).await }
})
.then(|result| future::ready(result.map(Arc::new).map_err(Arc::new)))
.boxed()
@@ -910,80 +919,511 @@ impl ThreadsDatabase {
cx.set_global(GlobalThreadsDatabase(database_future));
}
pub fn new(path: PathBuf, executor: BackgroundExecutor) -> Result<Self> {
std::fs::create_dir_all(&path)?;
const ONE_GB_IN_BYTES: usize = 1024 * 1024 * 1024;
let env = unsafe {
heed::EnvOpenOptions::new()
.map_size(ONE_GB_IN_BYTES)
.max_dbs(1)
.open(path)?
};
let mut txn = env.write_txn()?;
let threads = env.create_database(&mut txn, Some("threads"))?;
txn.commit()?;
Ok(Self {
executor,
env,
threads,
})
pub async fn new(executor: BackgroundExecutor) -> Result<Self> {
Ok(Self { executor })
}
pub fn list_threads(&self) -> Task<Result<Vec<SerializedThreadMetadata>>> {
let env = self.env.clone();
let threads = self.threads;
self.executor.spawn(async move {
let txn = env.read_txn()?;
let mut iter = threads.iter(&txn)?;
let mut threads = Vec::new();
while let Some((key, value)) = iter.next().transpose()? {
threads.push(SerializedThreadMetadata {
id: key,
summary: value.summary,
updated_at: value.updated_at,
});
}
Ok(threads)
})
self.executor
.spawn(async move { AGENT_THREADS.all_threads().await })
}
pub fn try_find_thread(&self, id: ThreadId) -> Task<Result<Option<SerializedThread>>> {
let env = self.env.clone();
let threads = self.threads;
self.executor.spawn(async move {
let txn = env.read_txn()?;
let thread = threads.get(&txn, &id)?;
Ok(thread)
})
self.executor
.spawn(async move { AGENT_THREADS.get_thread(id).await })
}
pub fn save_thread(&self, id: ThreadId, thread: SerializedThread) -> Task<Result<()>> {
let env = self.env.clone();
let threads = self.threads;
self.executor.spawn(async move {
let mut txn = env.write_txn()?;
threads.put(&mut txn, &id, &thread)?;
txn.commit()?;
Ok(())
})
self.executor
.spawn(async move { AGENT_THREADS.save_thread(id, thread).await })
}
pub fn delete_thread(&self, id: ThreadId) -> Task<Result<()>> {
let env = self.env.clone();
let threads = self.threads;
self.executor
.spawn(async move { AGENT_THREADS.delete_thread_by_id(id).await })
}
self.executor.spawn(async move {
let mut txn = env.write_txn()?;
threads.delete(&mut txn, &id)?;
txn.commit()?;
Ok(())
})
/// Migrate a legacy `heed` LMDB database to SQLite
pub async fn migrate_from_heed(heed_path: &Path) -> Result<()> {
Self::migrate_from_heed_to_db(heed_path, &AGENT_THREADS).await
}
/// Migrate a legacy `heed` LMDB database to a specific SQLite database
pub async fn migrate_from_heed_to_db(heed_path: &Path, db: &ThreadStoreDB) -> Result<()> {
if !heed_path.exists() {
return Ok(()); // No migration needed
}
// Open the old heed database
let env = unsafe {
heed::EnvOpenOptions::new()
.map_size(1024 * 1024 * 1024) // 1GB
.max_dbs(1)
.open(&heed_path)?
};
let txn = env.read_txn()?;
let old_threads: heed::Database<heed::types::SerdeBincode<ThreadId>, SerializedThread> =
env.open_database(&txn, Some("threads"))?
.ok_or_else(|| anyhow!("threads database not found"))?;
// Migrate all threads
for result in old_threads.iter(&txn)? {
if let Some((id, thread)) = result.log_err() {
db.save_thread(id, thread).await.log_err();
}
}
drop(txn);
drop(env);
// Rename the old heed database with .bak suffix
let mut backup_path = heed_path.to_path_buf();
let file_name = heed_path
.file_name()
.ok_or_else(|| anyhow!("invalid heed path"))?;
let new_name = format!("{}.bak", file_name.to_string_lossy());
backup_path.set_file_name(new_name);
std::fs::rename(&heed_path, &backup_path)?;
Ok(())
}
}
// Heed serialization helpers for migration
impl heed::BytesEncode<'_> for SerializedThread {
type EItem = SerializedThread;
fn bytes_encode(item: &Self::EItem) -> Result<Cow<[u8]>, heed::BoxedError> {
serde_json::to_vec(item).map(Cow::Owned).map_err(Into::into)
}
}
impl<'a> heed::BytesDecode<'a> for SerializedThread {
type DItem = SerializedThread;
fn bytes_decode(bytes: &'a [u8]) -> Result<Self::DItem, heed::BoxedError> {
SerializedThread::from_json(bytes).map_err(Into::into)
}
}
define_connection!(pub static ref AGENT_THREADS: ThreadStoreDB<()> =
&[sql!(
CREATE TABLE IF NOT EXISTS agent_threads(
id TEXT PRIMARY KEY,
summary TEXT NOT NULL,
updated_at INTEGER NOT NULL,
data TEXT NOT NULL
) STRICT;
)];
);
impl ThreadStoreDB {
query! {
pub async fn all_threads() -> Result<Vec<SerializedThreadMetadata>> {
SELECT id, summary, updated_at
FROM agent_threads
ORDER BY updated_at DESC
}
}
query! {
async fn get_thread_data(id: String) -> Result<Option<String>> {
SELECT data FROM agent_threads WHERE id = (?)
}
}
query! {
async fn save_thread_data(id: String, summary: String, updated_at: i64, data: String) -> Result<()> {
INSERT OR REPLACE INTO agent_threads (id, summary, updated_at, data)
VALUES ((?), (?), (?), (?))
}
}
query! {
async fn delete_thread_data(id: String) -> Result<()> {
DELETE FROM agent_threads WHERE id = (?)
}
}
pub async fn get_thread(&self, id: ThreadId) -> Result<Option<SerializedThread>> {
let id_str = id.to_string();
let result = self.get_thread_data(id_str).await?;
match result {
Some(json_str) => {
let thread = SerializedThread::from_json(json_str.as_bytes())?;
Ok(Some(thread))
}
None => Ok(None),
}
}
pub async fn save_thread(&self, id: ThreadId, thread: SerializedThread) -> Result<()> {
let thread_json = serde_json::to_string(&thread)?;
let updated_at = thread.updated_at.timestamp();
let id_str = id.to_string();
let summary = thread.summary.clone();
self.save_thread_data(id_str, summary.to_string(), updated_at, thread_json)
.await
}
pub async fn delete_thread_by_id(&self, id: ThreadId) -> Result<()> {
let id_str = id.to_string();
self.delete_thread_data(id_str).await
}
}
#[cfg(test)]
mod tests {
use super::*;
use chrono::Utc;
use gpui::TestAppContext;
use std::sync::Arc;
use tempfile::TempDir;
#[gpui::test]
async fn test_save_load_delete_threads(_cx: &mut TestAppContext) {
let db = ThreadStoreDB::open_test_db("test_save_load_delete_threads").await;
// Test that no threads exist initially
let threads = db.all_threads().await.unwrap();
assert_eq!(threads.len(), 0);
// Create test thread data
let thread_id = ThreadId::from("test-thread-1");
let thread = SerializedThread {
version: SerializedThread::VERSION.to_string(),
summary: SharedString::from("Test thread summary"),
updated_at: Utc::now(),
messages: vec![],
initial_project_snapshot: None,
cumulative_token_usage: TokenUsage::default(),
request_token_usage: vec![],
detailed_summary_state: DetailedSummaryState::NotGenerated,
exceeded_window_error: None,
model: None,
completion_mode: Some(CompletionMode::Normal),
tool_use_limit_reached: false,
};
let thread_summary = thread.summary.clone();
let thread_version = thread.version.clone();
// Save thread
db.save_thread(thread_id.clone(), thread.clone())
.await
.unwrap();
// Load all threads
let threads = db.all_threads().await.unwrap();
assert_eq!(threads.len(), 1);
assert_eq!(threads[0].id, thread_id);
assert_eq!(threads[0].summary, thread_summary);
// Load specific thread
let loaded_thread = db.get_thread(thread_id.clone()).await.unwrap();
assert!(loaded_thread.is_some());
let loaded_thread = loaded_thread.unwrap();
assert_eq!(loaded_thread.summary, thread_summary);
assert_eq!(loaded_thread.version, thread_version);
// Update thread
let updated_thread = SerializedThread {
summary: SharedString::from("Updated summary"),
updated_at: Utc::now(),
..thread
};
db.save_thread(thread_id.clone(), updated_thread.clone())
.await
.unwrap();
// Verify update
let loaded_thread = db.get_thread(thread_id.clone()).await.unwrap().unwrap();
assert_eq!(loaded_thread.summary, SharedString::from("Updated summary"));
// Delete thread
db.delete_thread_by_id(thread_id.clone()).await.unwrap();
// Verify deletion
let loaded_thread = db.get_thread(thread_id.clone()).await.unwrap();
assert!(loaded_thread.is_none());
let threads = db.all_threads().await.unwrap();
assert_eq!(threads.len(), 0);
}
#[gpui::test]
async fn test_multiple_threads(_cx: &mut TestAppContext) {
let db = ThreadStoreDB::open_test_db("test_multiple_threads").await;
// Create multiple threads
let thread_ids = [
ThreadId::from("thread-1"),
ThreadId::from("thread-2"),
ThreadId::from("thread-3"),
];
for (i, thread_id) in thread_ids.iter().enumerate() {
let thread = SerializedThread {
version: SerializedThread::VERSION.to_string(),
summary: SharedString::from(format!("Thread {}", i + 1)),
updated_at: Utc::now() - chrono::Duration::hours(i as i64),
messages: vec![],
initial_project_snapshot: None,
cumulative_token_usage: TokenUsage::default(),
request_token_usage: Vec::new(),
detailed_summary_state: DetailedSummaryState::NotGenerated,
exceeded_window_error: None,
model: None,
completion_mode: Some(CompletionMode::Normal),
tool_use_limit_reached: false,
};
db.save_thread(thread_id.clone(), thread).await.unwrap();
}
// Load all threads - should be ordered by updated_at DESC
let threads = db.all_threads().await.unwrap();
assert_eq!(threads.len(), 3);
assert_eq!(threads[0].summary.as_ref(), "Thread 1");
assert_eq!(threads[1].summary.as_ref(), "Thread 2");
assert_eq!(threads[2].summary.as_ref(), "Thread 3");
// Delete middle thread
db.delete_thread_by_id(thread_ids[1].clone()).await.unwrap();
let threads = db.all_threads().await.unwrap();
assert_eq!(threads.len(), 2);
assert_eq!(threads[0].summary.as_ref(), "Thread 1");
assert_eq!(threads[1].summary.as_ref(), "Thread 3");
}
#[gpui::test]
async fn test_heed_to_sqlite_migration(_cx: &mut TestAppContext) {
use heed::types::SerdeBincode;
// Create a temporary directory for the heed database
let temp_dir = TempDir::new().unwrap();
let heed_path = temp_dir.path().join("test-heed-db");
// Create and populate heed database
{
std::fs::create_dir_all(&heed_path).unwrap();
let env = unsafe {
heed::EnvOpenOptions::new()
.map_size(1024 * 1024 * 1024)
.max_dbs(1)
.open(&heed_path)
.unwrap()
};
let mut txn = env.write_txn().unwrap();
let threads: heed::Database<SerdeBincode<ThreadId>, SerializedThread> =
env.create_database(&mut txn, Some("threads")).unwrap();
// Insert test data
let thread_ids = [
ThreadId::from("legacy-thread-1"),
ThreadId::from("legacy-thread-2"),
ThreadId::from("legacy-thread-3"),
];
for (i, thread_id) in thread_ids.iter().enumerate() {
let thread = SerializedThread {
version: SerializedThread::VERSION.to_string(),
summary: SharedString::from(format!("Legacy Thread {}", i + 1)),
updated_at: DateTime::from_timestamp(1700000000 - (i as i64) * 86400, 0)
.unwrap(),
messages: vec![SerializedMessage {
id: MessageId(i),
role: Role::User,
segments: vec![SerializedMessageSegment::Text {
text: format!("Test message {}", i),
}],
tool_uses: vec![],
tool_results: vec![],
context: String::new(),
creases: vec![],
is_hidden: false,
}],
initial_project_snapshot: None,
cumulative_token_usage: TokenUsage {
input_tokens: ((i + 1) * 100) as u32,
output_tokens: ((i + 1) * 50) as u32,
cache_creation_input_tokens: 0,
cache_read_input_tokens: 0,
},
request_token_usage: vec![],
detailed_summary_state: DetailedSummaryState::NotGenerated,
exceeded_window_error: None,
model: None,
completion_mode: Some(CompletionMode::Normal),
tool_use_limit_reached: false,
};
threads.put(&mut txn, thread_id, &thread).unwrap();
}
txn.commit().unwrap();
}
// Clear any existing SQLite data
let db = ThreadStoreDB::open_test_db("test_heed_to_sqlite_migration").await;
// Verify SQLite is empty
let threads_before = db.all_threads().await.unwrap();
assert_eq!(threads_before.len(), 0);
// Run migration
ThreadsDatabase::migrate_from_heed_to_db(&heed_path, &db)
.await
.unwrap();
// Verify all threads were migrated
let threads_after = db.all_threads().await.unwrap();
assert_eq!(threads_after.len(), 3);
// Verify thread metadata
let thread_summaries: Vec<_> = threads_after.iter().map(|t| t.summary.as_ref()).collect();
assert!(thread_summaries.contains(&"Legacy Thread 1"));
assert!(thread_summaries.contains(&"Legacy Thread 2"));
assert!(thread_summaries.contains(&"Legacy Thread 3"));
// Verify full thread data
for i in 1..=3 {
let thread_id = ThreadId::from(&format!("legacy-thread-{}", i) as &str);
let thread = db.get_thread(thread_id).await.unwrap().unwrap();
assert_eq!(thread.summary.as_ref(), format!("Legacy Thread {}", i));
assert_eq!(thread.messages.len(), 1);
assert_eq!(
thread.messages[0].segments[0],
SerializedMessageSegment::Text {
text: format!("Test message {}", i - 1)
}
);
assert_eq!(thread.cumulative_token_usage.input_tokens, (i * 100) as u32);
assert_eq!(thread.cumulative_token_usage.output_tokens, (i * 50) as u32);
}
// Verify heed database was renamed with .bak suffix
assert!(!heed_path.exists());
let mut backup_path = heed_path.to_path_buf();
backup_path.set_file_name(format!("{}.bak", heed_path.file_name().unwrap().to_string_lossy()));
assert!(backup_path.exists());
}
#[gpui::test]
async fn test_thread_serialization_deserialization(_cx: &mut TestAppContext) {
let db = ThreadStoreDB::open_test_db("test_thread_serialization_deserialization").await;
let thread_id = ThreadId::from("serialization-test");
let original_thread = SerializedThread {
version: SerializedThread::VERSION.to_string(),
summary: SharedString::from("Serialization test thread"),
updated_at: Utc::now(),
messages: vec![
SerializedMessage {
id: MessageId(1),
role: Role::User,
segments: vec![
SerializedMessageSegment::Text {
text: "Hello".to_string(),
},
SerializedMessageSegment::Thinking {
text: "Thinking about the response".to_string(),
signature: Some("sig123".to_string()),
},
],
tool_uses: vec![SerializedToolUse {
id: LanguageModelToolUseId::from("tool-1"),
name: SharedString::from("test_tool"),
input: serde_json::json!({"key": "value"}),
}],
tool_results: vec![SerializedToolResult {
tool_use_id: LanguageModelToolUseId::from("tool-1"),
is_error: false,
content: LanguageModelToolResultContent::Text("Result".into()),
output: None,
}],
context: String::new(),
creases: vec![SerializedCrease {
start: 0,
end: 5,
icon_path: SharedString::from("icon.png"),
label: SharedString::from("test-crease"),
}],
is_hidden: false,
},
SerializedMessage {
id: MessageId(2),
role: Role::Assistant,
segments: vec![SerializedMessageSegment::RedactedThinking {
data: vec![1, 2, 3, 4, 5],
}],
tool_uses: vec![],
tool_results: vec![],
context: String::new(),
creases: vec![],
is_hidden: true,
},
],
initial_project_snapshot: Some(Arc::new(ProjectSnapshot {
worktree_snapshots: vec![],
unsaved_buffer_paths: vec![],
timestamp: Utc::now(),
})),
cumulative_token_usage: TokenUsage {
input_tokens: 1000,
output_tokens: 500,
cache_creation_input_tokens: 0,
cache_read_input_tokens: 0,
},
request_token_usage: vec![TokenUsage {
input_tokens: 100,
output_tokens: 50,
cache_creation_input_tokens: 0,
cache_read_input_tokens: 0,
}],
detailed_summary_state: DetailedSummaryState::Generated {
text: SharedString::from("Detailed summary"),
message_id: MessageId(1),
},
exceeded_window_error: None,
model: Some(SerializedLanguageModel {
provider: "test-provider".to_string(),
model: "test-model".to_string(),
}),
completion_mode: Some(CompletionMode::Normal),
tool_use_limit_reached: true,
};
// Save thread
db.save_thread(thread_id.clone(), original_thread.clone())
.await
.unwrap();
// Load thread
let loaded_thread = db.get_thread(thread_id).await.unwrap().unwrap();
// Verify all fields
assert_eq!(loaded_thread.version, original_thread.version);
assert_eq!(loaded_thread.summary, original_thread.summary);
assert_eq!(loaded_thread.messages.len(), original_thread.messages.len());
assert_eq!(loaded_thread.messages[0].segments.len(), 2);
assert_eq!(loaded_thread.messages[0].tool_uses.len(), 1);
assert_eq!(loaded_thread.messages[0].tool_results.len(), 1);
assert_eq!(loaded_thread.messages[0].creases.len(), 1);
assert_eq!(loaded_thread.messages[1].is_hidden, true);
assert!(loaded_thread.initial_project_snapshot.is_some());
assert_eq!(
loaded_thread.cumulative_token_usage.input_tokens,
original_thread.cumulative_token_usage.input_tokens
);
assert_eq!(
loaded_thread.exceeded_window_error.is_none(),
original_thread.exceeded_window_error.is_none()
);
assert!(loaded_thread.model.is_some());
assert_eq!(loaded_thread.tool_use_limit_reached, true);
}
}

View File

@@ -36,6 +36,7 @@ itertools.workspace = true
language.workspace = true
language_model.workspace = true
log.workspace = true
lsp.workspace = true
markdown.workspace = true
open.workspace = true
paths.workspace = true
@@ -64,6 +65,7 @@ workspace.workspace = true
zed_llm_client.workspace = true
[dev-dependencies]
lsp = { workspace = true, features = ["test-support"] }
client = { workspace = true, features = ["test-support"] }
clock = { workspace = true, features = ["test-support"] }
collections = { workspace = true, features = ["test-support"] }

View File

@@ -26,7 +26,6 @@ use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use std::{cmp, iter, mem, ops::Range, path::PathBuf, sync::Arc, task::Poll};
use streaming_diff::{CharOperation, StreamingDiff};
use util::debug_panic;
#[derive(Serialize)]
struct CreateFilePromptTemplate {
@@ -582,8 +581,9 @@ impl EditAgent {
conversation.messages.pop();
}
} else {
debug_panic!(
"Last message must be an Assistant tool calling! Got {:?}",
log::error!(
"Last message should have had a role of Assistant, but instead was {:?}. Full content of that message: {:?}",
last_message.role,
last_message.content
);
}

View File

@@ -18,15 +18,19 @@ use gpui::{
use indoc::formatdoc;
use language::{
Anchor, Buffer, Capability, LanguageRegistry, LineEnding, OffsetRangeExt, Rope, TextBuffer,
language_settings::SoftWrap,
language_settings::{self, FormatOnSave, SoftWrap},
};
use language_model::{LanguageModel, LanguageModelRequest, LanguageModelToolSchemaFormat};
use markdown::{Markdown, MarkdownElement, MarkdownStyle};
use project::{Project, ProjectPath};
use project::{
Project, ProjectPath,
lsp_store::{FormatTrigger, LspFormatTarget},
};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use settings::Settings;
use std::{
collections::HashSet,
path::{Path, PathBuf},
sync::Arc,
time::Duration,
@@ -34,6 +38,7 @@ use std::{
use theme::ThemeSettings;
use ui::{Disclosure, Tooltip, prelude::*};
use util::ResultExt;
use workspace::Workspace;
pub struct EditFileTool;
@@ -187,8 +192,9 @@ impl Tool for EditFileTool {
});
let card_clone = card.clone();
let action_log_clone = action_log.clone();
let task = cx.spawn(async move |cx: &mut AsyncApp| {
let edit_agent = EditAgent::new(model, project.clone(), action_log, Templates::new());
let edit_agent = EditAgent::new(model, project.clone(), action_log_clone, Templates::new());
let buffer = project
.update(cx, |project, cx| {
@@ -249,19 +255,49 @@ impl Tool for EditFileTool {
}
let agent_output = output.await?;
// Check if format_on_save is enabled
let format_on_save_enabled = buffer
.read_with(cx, |buffer, cx| {
let settings = language_settings::language_settings(
buffer.language().map(|l| l.name()),
buffer.file(),
cx,
);
!matches!(settings.format_on_save, FormatOnSave::Off)
})
.unwrap_or(false);
// If format_on_save is enabled, format the buffer before saving
if format_on_save_enabled {
let format_task = project.update(cx, |project, cx| {
project.format(
HashSet::from_iter([buffer.clone()]),
LspFormatTarget::Buffers,
false, // Don't push to history since the tool did it.
FormatTrigger::Save,
cx,
)
})?;
format_task.await.log_err();
}
project
.update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))?
.await?;
// Notify the action log that we've edited the buffer AFTER save completes
// This ensures the tracked version matches the saved version
action_log.update(cx, |log, cx| {
log.buffer_edited(buffer.clone(), cx);
})?;
let new_snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
let new_text = cx.background_spawn({
let new_snapshot = new_snapshot.clone();
async move { new_snapshot.text() }
});
let diff = cx.background_spawn(async move {
language::unified_diff(&old_snapshot.text(), &new_snapshot.text())
});
let (new_text, diff) = futures::join!(new_text, diff);
})
.await;
let diff = language::unified_diff(&old_text, &new_text);
let output = EditFileToolOutput {
original_path: project_path.path.to_path_buf(),
@@ -272,7 +308,7 @@ impl Tool for EditFileTool {
if let Some(card) = card_clone {
card.update(cx, |card, cx| {
card.set_diff(project_path.path.clone(), old_text, new_text, cx);
card.set_diff(project_path.path.clone(), old_text, new_text.clone(), cx);
})
.log_err();
}
@@ -920,8 +956,8 @@ async fn build_buffer_diff(
mod tests {
use super::*;
use client::TelemetrySettings;
use fs::FakeFs;
use gpui::TestAppContext;
use fs::{FakeFs, Fs};
use gpui::{TestAppContext, UpdateGlobal};
use language_model::fake_provider::FakeLanguageModel;
use serde_json::json;
use settings::SettingsStore;
@@ -1131,4 +1167,483 @@ mod tests {
Project::init_settings(cx);
});
}
#[gpui::test]
async fn test_format_on_save(cx: &mut TestAppContext) {
init_test(cx);
let fs = FakeFs::new(cx.executor());
fs.insert_tree("/root", json!({"src": {}})).await;
let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
// Set up a Rust language with LSP formatting support
let rust_language = Arc::new(language::Language::new(
language::LanguageConfig {
name: "Rust".into(),
matcher: language::LanguageMatcher {
path_suffixes: vec!["rs".to_string()],
..Default::default()
},
..Default::default()
},
None,
));
// Register the language and fake LSP
let language_registry = project.read_with(cx, |project, _| project.languages().clone());
language_registry.add(rust_language);
let mut fake_language_servers = language_registry.register_fake_lsp(
"Rust",
language::FakeLspAdapter {
capabilities: lsp::ServerCapabilities {
document_formatting_provider: Some(lsp::OneOf::Left(true)),
..Default::default()
},
..Default::default()
},
);
// Create the file
fs.save(
path!("/root/src/main.rs").as_ref(),
&"initial content".into(),
language::LineEnding::Unix,
)
.await
.unwrap();
// Open the buffer to trigger LSP initialization
let buffer = project
.update(cx, |project, cx| {
project.open_local_buffer(path!("/root/src/main.rs"), cx)
})
.await
.unwrap();
// Register the buffer with language servers
let _handle = project.update(cx, |project, cx| {
project.register_buffer_with_language_servers(&buffer, cx)
});
const UNFORMATTED_CONTENT: &str = "fn main() {println!(\"Hello!\");}\n";
const FORMATTED_CONTENT: &str =
"This file was formatted by the fake formatter in the test.\n";
// Get the fake language server and set up formatting handler
let fake_language_server = fake_language_servers.next().await.unwrap();
fake_language_server.set_request_handler::<lsp::request::Formatting, _, _>({
|_, _| async move {
Ok(Some(vec![lsp::TextEdit {
range: lsp::Range::new(lsp::Position::new(0, 0), lsp::Position::new(1, 0)),
new_text: FORMATTED_CONTENT.to_string(),
}]))
}
});
let action_log = cx.new(|_| ActionLog::new(project.clone()));
let model = Arc::new(FakeLanguageModel::default());
// First, test with format_on_save enabled
cx.update(|cx| {
SettingsStore::update_global(cx, |store, cx| {
store.update_user_settings::<language::language_settings::AllLanguageSettings>(
cx,
|settings| {
settings.defaults.format_on_save = Some(FormatOnSave::On);
settings.defaults.formatter =
Some(language::language_settings::SelectedFormatter::Auto);
},
);
});
});
// Have the model stream unformatted content
let edit_result = {
let edit_task = cx.update(|cx| {
let input = serde_json::to_value(EditFileToolInput {
display_description: "Create main function".into(),
path: "root/src/main.rs".into(),
mode: EditFileMode::Overwrite,
})
.unwrap();
Arc::new(EditFileTool)
.run(
input,
Arc::default(),
project.clone(),
action_log.clone(),
model.clone(),
None,
cx,
)
.output
});
// Stream the unformatted content
cx.executor().run_until_parked();
model.stream_last_completion_response(UNFORMATTED_CONTENT.to_string());
model.end_last_completion_stream();
edit_task.await
};
assert!(edit_result.is_ok());
// Wait for any async operations (e.g. formatting) to complete
cx.executor().run_until_parked();
// Read the file to verify it was formatted automatically
let new_content = fs.load(path!("/root/src/main.rs").as_ref()).await.unwrap();
assert_eq!(
// Ignore carriage returns on Windows
new_content.replace("\r\n", "\n"),
FORMATTED_CONTENT,
"Code should be formatted when format_on_save is enabled"
);
// Next, test with format_on_save disabled
cx.update(|cx| {
SettingsStore::update_global(cx, |store, cx| {
store.update_user_settings::<language::language_settings::AllLanguageSettings>(
cx,
|settings| {
settings.defaults.format_on_save = Some(FormatOnSave::Off);
},
);
});
});
// Stream unformatted edits again
let edit_result = {
let edit_task = cx.update(|cx| {
let input = serde_json::to_value(EditFileToolInput {
display_description: "Update main function".into(),
path: "root/src/main.rs".into(),
mode: EditFileMode::Overwrite,
})
.unwrap();
Arc::new(EditFileTool)
.run(
input,
Arc::default(),
project.clone(),
action_log.clone(),
model.clone(),
None,
cx,
)
.output
});
// Stream the unformatted content
cx.executor().run_until_parked();
model.stream_last_completion_response(UNFORMATTED_CONTENT.to_string());
model.end_last_completion_stream();
edit_task.await
};
assert!(edit_result.is_ok());
// Wait for any async operations (e.g. formatting) to complete
cx.executor().run_until_parked();
// Verify the file was not formatted
let new_content = fs.load(path!("/root/src/main.rs").as_ref()).await.unwrap();
assert_eq!(
// Ignore carriage returns on Windows
new_content.replace("\r\n", "\n"),
UNFORMATTED_CONTENT,
"Code should not be formatted when format_on_save is disabled"
);
}
#[gpui::test]
async fn test_remove_trailing_whitespace(cx: &mut TestAppContext) {
init_test(cx);
let fs = FakeFs::new(cx.executor());
fs.insert_tree("/root", json!({"src": {}})).await;
// Create a simple file with trailing whitespace
fs.save(
path!("/root/src/main.rs").as_ref(),
&"initial content".into(),
language::LineEnding::Unix,
)
.await
.unwrap();
let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
let action_log = cx.new(|_| ActionLog::new(project.clone()));
let model = Arc::new(FakeLanguageModel::default());
// First, test with remove_trailing_whitespace_on_save enabled
cx.update(|cx| {
SettingsStore::update_global(cx, |store, cx| {
store.update_user_settings::<language::language_settings::AllLanguageSettings>(
cx,
|settings| {
settings.defaults.remove_trailing_whitespace_on_save = Some(true);
},
);
});
});
const CONTENT_WITH_TRAILING_WHITESPACE: &str =
"fn main() { \n println!(\"Hello!\"); \n}\n";
// Have the model stream content that contains trailing whitespace
let edit_result = {
let edit_task = cx.update(|cx| {
let input = serde_json::to_value(EditFileToolInput {
display_description: "Create main function".into(),
path: "root/src/main.rs".into(),
mode: EditFileMode::Overwrite,
})
.unwrap();
Arc::new(EditFileTool)
.run(
input,
Arc::default(),
project.clone(),
action_log.clone(),
model.clone(),
None,
cx,
)
.output
});
// Stream the content with trailing whitespace
cx.executor().run_until_parked();
model.stream_last_completion_response(CONTENT_WITH_TRAILING_WHITESPACE.to_string());
model.end_last_completion_stream();
edit_task.await
};
assert!(edit_result.is_ok());
// Wait for any async operations (e.g. formatting) to complete
cx.executor().run_until_parked();
// Read the file to verify trailing whitespace was removed automatically
assert_eq!(
// Ignore carriage returns on Windows
fs.load(path!("/root/src/main.rs").as_ref())
.await
.unwrap()
.replace("\r\n", "\n"),
"fn main() {\n println!(\"Hello!\");\n}\n",
"Trailing whitespace should be removed when remove_trailing_whitespace_on_save is enabled"
);
// Next, test with remove_trailing_whitespace_on_save disabled
cx.update(|cx| {
SettingsStore::update_global(cx, |store, cx| {
store.update_user_settings::<language::language_settings::AllLanguageSettings>(
cx,
|settings| {
settings.defaults.remove_trailing_whitespace_on_save = Some(false);
},
);
});
});
// Stream edits again with trailing whitespace
let edit_result = {
let edit_task = cx.update(|cx| {
let input = serde_json::to_value(EditFileToolInput {
display_description: "Update main function".into(),
path: "root/src/main.rs".into(),
mode: EditFileMode::Overwrite,
})
.unwrap();
Arc::new(EditFileTool)
.run(
input,
Arc::default(),
project.clone(),
action_log.clone(),
model.clone(),
None,
cx,
)
.output
});
// Stream the content with trailing whitespace
cx.executor().run_until_parked();
model.stream_last_completion_response(CONTENT_WITH_TRAILING_WHITESPACE.to_string());
model.end_last_completion_stream();
edit_task.await
};
assert!(edit_result.is_ok());
// Wait for any async operations (e.g. formatting) to complete
cx.executor().run_until_parked();
// Verify the file still has trailing whitespace
// Read the file again - it should still have trailing whitespace
let final_content = fs.load(path!("/root/src/main.rs").as_ref()).await.unwrap();
assert_eq!(
// Ignore carriage returns on Windows
final_content.replace("\r\n", "\n"),
CONTENT_WITH_TRAILING_WHITESPACE,
"Trailing whitespace should remain when remove_trailing_whitespace_on_save is disabled"
);
}
#[gpui::test]
async fn test_format_on_save_marks_buffer_as_conflicted(cx: &mut TestAppContext) {
// This test demonstrates a bug where format-on-save causes the buffer to be
// incorrectly marked as stale, leading the agent to think the file has been
// modified externally when it was actually just formatted as part of the save.
init_test(cx);
let fs = FakeFs::new(cx.executor());
fs.insert_tree(path!("/root"), json!({"src": {}})).await;
let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
// Set up a Rust language with LSP formatting support
let rust_language = Arc::new(language::Language::new(
language::LanguageConfig {
name: "Rust".into(),
matcher: language::LanguageMatcher {
path_suffixes: vec!["rs".to_string()],
..Default::default()
},
..Default::default()
},
None,
));
// Register the language and fake LSP
let language_registry = project.read_with(cx, |project, _| project.languages().clone());
language_registry.add(rust_language);
let mut fake_language_servers = language_registry.register_fake_lsp(
"Rust",
language::FakeLspAdapter {
capabilities: lsp::ServerCapabilities {
document_formatting_provider: Some(lsp::OneOf::Left(true)),
..Default::default()
},
..Default::default()
},
);
// Create the file with some initial content
fs.save(
path!("/root/src/main.rs").as_ref(),
&"fn main() { }\n".into(),
language::LineEnding::Unix,
)
.await
.unwrap();
// Open the buffer to trigger LSP initialization
let buffer = project
.update(cx, |project, cx| {
project.open_local_buffer(path!("/root/src/main.rs"), cx)
})
.await
.unwrap();
// Register the buffer with language servers
let _handle = project.update(cx, |project, cx| {
project.register_buffer_with_language_servers(&buffer, cx)
});
const UNFORMATTED_CONTENT: &str = "fn main() {println!(\"Hello!\");}\n";
const FORMATTED_CONTENT: &str = "fn main() {\n println!(\"Hello!\");\n}\n";
// Get the fake language server and set up formatting handler
let fake_language_server = fake_language_servers.next().await.unwrap();
fake_language_server.set_request_handler::<lsp::request::Formatting, _, _>({
move |_, _| async move {
Ok(Some(vec![lsp::TextEdit {
range: lsp::Range::new(lsp::Position::new(0, 0), lsp::Position::new(1, 0)),
new_text: FORMATTED_CONTENT.to_string(),
}]))
}
});
let action_log = cx.new(|_| ActionLog::new(project.clone()));
let model = Arc::new(FakeLanguageModel::default());
// Enable format_on_save
cx.update(|cx| {
SettingsStore::update_global(cx, |store, cx| {
store.update_user_settings::<language::language_settings::AllLanguageSettings>(
cx,
|settings| {
settings.defaults.format_on_save = Some(FormatOnSave::On);
settings.defaults.formatter =
Some(language::language_settings::SelectedFormatter::Auto);
},
);
});
});
// Have the agent edit the file (this should trigger format on save)
let edit_result = {
let edit_task = cx.update(|cx| {
let input = serde_json::to_value(EditFileToolInput {
display_description: "Update main function".into(),
path: "root/src/main.rs".into(),
mode: EditFileMode::Overwrite,
})
.unwrap();
Arc::new(EditFileTool)
.run(
input,
Arc::default(),
project.clone(),
action_log.clone(),
model.clone(),
None,
cx,
)
.output
});
// Stream the unformatted content
cx.executor().run_until_parked();
model.stream_last_completion_response(UNFORMATTED_CONTENT.to_string());
model.end_last_completion_stream();
edit_task.await
};
assert!(edit_result.is_ok());
// Wait for any async operations (e.g. formatting) to complete
cx.executor().run_until_parked();
// Read the file to verify it was formatted automatically
let new_content = fs.load(path!("/root/src/main.rs").as_ref()).await.unwrap();
assert_eq!(
new_content.replace("\r\n", "\n"),
FORMATTED_CONTENT,
"Code should be formatted when format_on_save is enabled"
);
// Now simulate what would happen if the agent re-reads the buffer
// (e.g., because it thinks the file changed externally)
let stale_buffer_count = action_log.read_with(cx, |log, cx| log.stale_buffers(cx).count());
// This assertion demonstrates the bug: the buffer should NOT be considered stale
// after format-on-save, but it is incorrectly marked as stale because the
// formatting process causes the file's mtime to change, making it appear as if
// the file was modified externally. This causes the agent to unnecessarily
// re-read the file, thinking it has changed outside of the agent's control.
assert_eq!(
stale_buffer_count, 0,
"BUG: Buffer is incorrectly marked as stale after format-on-save. Found {} stale buffers. \
This causes the agent to think the file was modified externally when it was just formatted.",
stale_buffer_count
);
}
}