Compare commits
8 Commits (simplify-e...sqlite-ove)

| Author | SHA1 | Date |
|---|---|---|
| | ef60a60e11 | |
| | 5d2ac968d8 | |
| | e9d4b8766f | |
| | 6812872d1a | |
| | 2aebeb067c | |
| | 7dfd5d1963 | |
| | a677b891a1 | |
| | a2cb480244 | |

Cargo.lock (generated, 2 changes)
@@ -117,6 +117,7 @@ dependencies = [
 "streaming_diff",
 "telemetry",
 "telemetry_events",
+"tempfile",
 "terminal",
 "terminal_view",
 "text",
@@ -683,6 +684,7 @@ dependencies = [
 "language_model",
 "language_models",
 "log",
 "lsp",
 "markdown",
 "open",
 "paths",
@@ -107,3 +107,4 @@ language = { workspace = true, "features" = ["test-support"] }
 language_model = { workspace = true, "features" = ["test-support"] }
 project = { workspace = true, features = ["test-support"] }
 rand.workspace = true
+tempfile.workspace = true
@@ -1,6 +1,6 @@
 use std::borrow::Cow;
 use std::cell::{Ref, RefCell};
-use std::path::{Path, PathBuf};
+use std::path::Path;
 use std::rc::Rc;
 use std::sync::Arc;
@@ -10,6 +10,10 @@ use assistant_tool::{ToolId, ToolSource, ToolWorkingSet};
 use chrono::{DateTime, Utc};
 use collections::HashMap;
 use context_server::ContextServerId;
+use db::sqlez::bindable::Column;
+use db::sqlez::statement::Statement;
+use db::sqlez_macros::sql;
+use db::{define_connection, query};
 use futures::channel::{mpsc, oneshot};
 use futures::future::{self, BoxFuture, Shared};
 use futures::{FutureExt as _, StreamExt as _};
@@ -17,8 +21,7 @@ use gpui::{
     App, BackgroundExecutor, Context, Entity, EventEmitter, Global, ReadGlobal, SharedString,
     Subscription, Task, prelude::*,
 };
-use heed::Database;
-use heed::types::SerdeBincode;
+use heed;
 use language_model::{LanguageModelToolResultContent, LanguageModelToolUseId, Role, TokenUsage};
 use project::context_server_store::{ContextServerStatus, ContextServerStore};
 use project::{Project, ProjectItem, ProjectPath, Worktree};
@@ -36,6 +39,31 @@ use crate::thread::{
    DetailedSummaryState, ExceededWindowError, MessageId, ProjectSnapshot, Thread, ThreadId,
};

// Implement Bind trait for ThreadId to use in SQL queries
// impl db::sqlez::bindable::Bind for ThreadId {
//     fn bind(&self, statement: &Statement, start_index: i32) -> anyhow::Result<i32> {
//         self.to_string().bind(statement, start_index)
//     }
// }

// Implement Column trait for SerializedThreadMetadata
impl Column for SerializedThreadMetadata {
    fn column(statement: &mut Statement, start_index: i32) -> Result<(Self, i32)> {
        let (id_str, next_index): (String, i32) = Column::column(statement, start_index)?;
        let (summary, next_index): (String, i32) = Column::column(statement, next_index)?;
        let (updated_at_timestamp, next_index): (i64, i32) = Column::column(statement, next_index)?;

        Ok((
            Self {
                id: ThreadId::from(id_str.as_str()),
                summary: summary.into(),
                updated_at: DateTime::from_timestamp(updated_at_timestamp, 0).unwrap_or_default(),
            },
            next_index,
        ))
    }
}

const RULES_FILE_NAMES: [&'static str; 6] = [
    ".rules",
    ".cursorrules",
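For context on the `Column` implementation above: sqlez-style decoding reads one SQL column per field and threads the advanced column index through each call. A self-contained sketch of that pattern (the `Row` type and the trait below are illustrative stand-ins, not Zed's actual `db::sqlez` API):

```rust
// Minimal sketch of index-threaded row decoding, in the style of the
// `Column` impl above. `Row` and this `Column` trait are stand-ins.
struct Row(Vec<String>);

trait Column: Sized {
    fn column(row: &Row, start_index: usize) -> (Self, usize);
}

impl Column for String {
    fn column(row: &Row, i: usize) -> (Self, usize) {
        (row.0[i].clone(), i + 1) // consume one column, advance the index
    }
}

impl Column for i64 {
    fn column(row: &Row, i: usize) -> (Self, usize) {
        (row.0[i].parse().expect("not an integer"), i + 1)
    }
}

struct Metadata {
    id: String,
    summary: String,
    updated_at: i64,
}

impl Column for Metadata {
    // Composite decoding: each field consumes columns and passes the
    // advanced index along, exactly like SerializedThreadMetadata::column.
    fn column(row: &Row, start: usize) -> (Self, usize) {
        let (id, next) = String::column(row, start);
        let (summary, next) = String::column(row, next);
        let (updated_at, next) = i64::column(row, next);
        (Metadata { id, summary, updated_at }, next)
    }
}

fn main() {
    let row = Row(vec!["t-1".into(), "Hello".into(), "1700000000".into()]);
    let (meta, consumed) = Metadata::column(&row, 0);
    assert_eq!(consumed, 3);
    println!("{} / {} / {}", meta.id, meta.summary, meta.updated_at);
}
```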
@@ -657,6 +685,7 @@ pub struct SerializedThreadMetadata {
}

#[derive(Serialize, Deserialize, Debug)]
+#[cfg_attr(test, derive(Clone))]
pub struct SerializedThread {
    pub version: String,
    pub summary: SharedString,
@@ -681,6 +710,7 @@ pub struct SerializedThread {
}

#[derive(Serialize, Deserialize, Debug)]
+#[cfg_attr(test, derive(Clone))]
pub struct SerializedLanguageModel {
    pub provider: String,
    pub model: String,
@@ -745,7 +775,7 @@ impl SerializedThreadV0_1_0 {
    }
}

-#[derive(Debug, Serialize, Deserialize)]
+#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct SerializedMessage {
    pub id: MessageId,
    pub role: Role,
@@ -763,7 +793,7 @@ pub struct SerializedMessage {
    pub is_hidden: bool,
}

-#[derive(Debug, Serialize, Deserialize)]
+#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
#[serde(tag = "type")]
pub enum SerializedMessageSegment {
    #[serde(rename = "text")]
@@ -781,14 +811,14 @@ pub enum SerializedMessageSegment {
    },
}

-#[derive(Debug, Serialize, Deserialize)]
+#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct SerializedToolUse {
    pub id: LanguageModelToolUseId,
    pub name: SharedString,
    pub input: serde_json::Value,
}

-#[derive(Debug, Serialize, Deserialize)]
+#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct SerializedToolResult {
    pub tool_use_id: LanguageModelToolUseId,
    pub is_error: bool,
@@ -850,7 +880,7 @@ impl LegacySerializedMessage {
    }
}

-#[derive(Debug, Serialize, Deserialize)]
+#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct SerializedCrease {
    pub start: usize,
    pub end: usize,
@@ -866,26 +896,6 @@ impl Global for GlobalThreadsDatabase {}

pub(crate) struct ThreadsDatabase {
    executor: BackgroundExecutor,
-    env: heed::Env,
-    threads: Database<SerdeBincode<ThreadId>, SerializedThread>,
}

-impl heed::BytesEncode<'_> for SerializedThread {
-    type EItem = SerializedThread;
-
-    fn bytes_encode(item: &Self::EItem) -> Result<Cow<[u8]>, heed::BoxedError> {
-        serde_json::to_vec(item).map(Cow::Owned).map_err(Into::into)
-    }
-}
-
-impl<'a> heed::BytesDecode<'a> for SerializedThread {
-    type DItem = SerializedThread;
-
-    fn bytes_decode(bytes: &'a [u8]) -> Result<Self::DItem, heed::BoxedError> {
-        // We implement this type manually because we want to call `SerializedThread::from_json`,
-        // instead of the Deserialize trait implementation for `SerializedThread`.
-        SerializedThread::from_json(bytes).map_err(Into::into)
-    }
-}
-
impl ThreadsDatabase {
@@ -900,8 +910,7 @@ impl ThreadsDatabase {
        let database_future = executor
            .spawn({
                let executor = executor.clone();
-                let database_path = paths::data_dir().join("threads/threads-db.1.mdb");
-                async move { ThreadsDatabase::new(database_path, executor) }
+                async move { ThreadsDatabase::new(executor).await }
            })
            .then(|result| future::ready(result.map(Arc::new).map_err(Arc::new)))
            .boxed()
@@ -910,80 +919,511 @@ impl ThreadsDatabase {
        cx.set_global(GlobalThreadsDatabase(database_future));
    }

-    pub fn new(path: PathBuf, executor: BackgroundExecutor) -> Result<Self> {
-        std::fs::create_dir_all(&path)?;
-
-        const ONE_GB_IN_BYTES: usize = 1024 * 1024 * 1024;
-        let env = unsafe {
-            heed::EnvOpenOptions::new()
-                .map_size(ONE_GB_IN_BYTES)
-                .max_dbs(1)
-                .open(path)?
-        };
-
-        let mut txn = env.write_txn()?;
-        let threads = env.create_database(&mut txn, Some("threads"))?;
-        txn.commit()?;
-
-        Ok(Self {
-            executor,
-            env,
-            threads,
-        })
+    pub async fn new(executor: BackgroundExecutor) -> Result<Self> {
+        Ok(Self { executor })
    }
    pub fn list_threads(&self) -> Task<Result<Vec<SerializedThreadMetadata>>> {
-        let env = self.env.clone();
-        let threads = self.threads;
-
-        self.executor.spawn(async move {
-            let txn = env.read_txn()?;
-            let mut iter = threads.iter(&txn)?;
-            let mut threads = Vec::new();
-            while let Some((key, value)) = iter.next().transpose()? {
-                threads.push(SerializedThreadMetadata {
-                    id: key,
-                    summary: value.summary,
-                    updated_at: value.updated_at,
-                });
-            }
-
-            Ok(threads)
-        })
+        self.executor
+            .spawn(async move { AGENT_THREADS.all_threads().await })
    }

    pub fn try_find_thread(&self, id: ThreadId) -> Task<Result<Option<SerializedThread>>> {
-        let env = self.env.clone();
-        let threads = self.threads;
-
-        self.executor.spawn(async move {
-            let txn = env.read_txn()?;
-            let thread = threads.get(&txn, &id)?;
-            Ok(thread)
-        })
+        self.executor
+            .spawn(async move { AGENT_THREADS.get_thread(id).await })
    }

    pub fn save_thread(&self, id: ThreadId, thread: SerializedThread) -> Task<Result<()>> {
-        let env = self.env.clone();
-        let threads = self.threads;
-
-        self.executor.spawn(async move {
-            let mut txn = env.write_txn()?;
-            threads.put(&mut txn, &id, &thread)?;
-            txn.commit()?;
-            Ok(())
-        })
+        self.executor
+            .spawn(async move { AGENT_THREADS.save_thread(id, thread).await })
    }

    pub fn delete_thread(&self, id: ThreadId) -> Task<Result<()>> {
-        let env = self.env.clone();
-        let threads = self.threads;
-
-        self.executor.spawn(async move {
-            let mut txn = env.write_txn()?;
-            threads.delete(&mut txn, &id)?;
-            txn.commit()?;
-            Ok(())
-        })
+        self.executor
+            .spawn(async move { AGENT_THREADS.delete_thread_by_id(id).await })
    }
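    // Note (added for this review): the four methods above keep the same
    // public signatures but now just forward to the global SQLite-backed
    // AGENT_THREADS connection (defined below) on the background executor;
    // no LMDB env or table handles live on `self` anymore.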
    /// Migrate a legacy `heed` LMDB database to SQLite
    pub async fn migrate_from_heed(heed_path: &Path) -> Result<()> {
        Self::migrate_from_heed_to_db(heed_path, &AGENT_THREADS).await
    }

    /// Migrate a legacy `heed` LMDB database to a specific SQLite database
    pub async fn migrate_from_heed_to_db(heed_path: &Path, db: &ThreadStoreDB) -> Result<()> {
        if !heed_path.exists() {
            return Ok(()); // No migration needed
        }

        // Open the old heed database
        let env = unsafe {
            heed::EnvOpenOptions::new()
                .map_size(1024 * 1024 * 1024) // 1GB
                .max_dbs(1)
                .open(&heed_path)?
        };

        let txn = env.read_txn()?;
        let old_threads: heed::Database<heed::types::SerdeBincode<ThreadId>, SerializedThread> =
            env.open_database(&txn, Some("threads"))?
                .ok_or_else(|| anyhow!("threads database not found"))?;

        // Migrate all threads
        for result in old_threads.iter(&txn)? {
            if let Some((id, thread)) = result.log_err() {
                db.save_thread(id, thread).await.log_err();
            }
        }

        drop(txn);
        drop(env);

        // Rename the old heed database with .bak suffix
        let mut backup_path = heed_path.to_path_buf();
        let file_name = heed_path
            .file_name()
            .ok_or_else(|| anyhow!("invalid heed path"))?;
        let new_name = format!("{}.bak", file_name.to_string_lossy());
        backup_path.set_file_name(new_name);
        std::fs::rename(&heed_path, &backup_path)?;

        Ok(())
    }
}
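The migration above follows a copy-then-retire pattern: read every record out of the old LMDB environment, write it into SQLite, drop the read transaction and environment, and only then rename the old directory so the migration never runs a second time. A std-only sketch of the rename step (the helper and demo paths here are illustrative, not from the diff):

```rust
// Sketch of the ".bak" rename step at the end of migrate_from_heed_to_db,
// reduced to std-only code so it can run standalone.
use std::io;
use std::path::{Path, PathBuf};

fn backup_rename(heed_path: &Path) -> io::Result<PathBuf> {
    // Build "<name>.bak" next to the original, as the real function does.
    let file_name = heed_path
        .file_name()
        .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidInput, "invalid heed path"))?;
    let mut backup_path = heed_path.to_path_buf();
    backup_path.set_file_name(format!("{}.bak", file_name.to_string_lossy()));

    // The real code drops the read txn and env first; on some platforms a
    // rename can fail while the LMDB files are still memory-mapped.
    std::fs::rename(heed_path, &backup_path)?;
    Ok(backup_path)
}

fn main() -> io::Result<()> {
    let dir = std::env::temp_dir().join("threads-db-demo");
    std::fs::create_dir_all(&dir)?;
    let backup = backup_rename(&dir)?;
    println!("renamed to {}", backup.display());
    std::fs::remove_dir_all(backup)
}
```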
// Heed serialization helpers for migration
impl heed::BytesEncode<'_> for SerializedThread {
    type EItem = SerializedThread;

    fn bytes_encode(item: &Self::EItem) -> Result<Cow<[u8]>, heed::BoxedError> {
        serde_json::to_vec(item).map(Cow::Owned).map_err(Into::into)
    }
}

impl<'a> heed::BytesDecode<'a> for SerializedThread {
    type DItem = SerializedThread;

    fn bytes_decode(bytes: &'a [u8]) -> Result<Self::DItem, heed::BoxedError> {
        SerializedThread::from_json(bytes).map_err(Into::into)
    }
}
define_connection!(pub static ref AGENT_THREADS: ThreadStoreDB<()> =
    &[sql!(
        CREATE TABLE IF NOT EXISTS agent_threads(
            id TEXT PRIMARY KEY,
            summary TEXT NOT NULL,
            updated_at INTEGER NOT NULL,
            data TEXT NOT NULL
        ) STRICT;
    )];
);
impl ThreadStoreDB {
    query! {
        pub async fn all_threads() -> Result<Vec<SerializedThreadMetadata>> {
            SELECT id, summary, updated_at
            FROM agent_threads
            ORDER BY updated_at DESC
        }
    }

    query! {
        async fn get_thread_data(id: String) -> Result<Option<String>> {
            SELECT data FROM agent_threads WHERE id = (?)
        }
    }

    query! {
        async fn save_thread_data(id: String, summary: String, updated_at: i64, data: String) -> Result<()> {
            INSERT OR REPLACE INTO agent_threads (id, summary, updated_at, data)
            VALUES ((?), (?), (?), (?))
        }
    }

    query! {
        async fn delete_thread_data(id: String) -> Result<()> {
            DELETE FROM agent_threads WHERE id = (?)
        }
    }

    pub async fn get_thread(&self, id: ThreadId) -> Result<Option<SerializedThread>> {
        let id_str = id.to_string();
        let result = self.get_thread_data(id_str).await?;

        match result {
            Some(json_str) => {
                let thread = SerializedThread::from_json(json_str.as_bytes())?;
                Ok(Some(thread))
            }
            None => Ok(None),
        }
    }

    pub async fn save_thread(&self, id: ThreadId, thread: SerializedThread) -> Result<()> {
        let thread_json = serde_json::to_string(&thread)?;
        let updated_at = thread.updated_at.timestamp();
        let id_str = id.to_string();
        let summary = thread.summary.clone();

        self.save_thread_data(id_str, summary.to_string(), updated_at, thread_json)
            .await
    }

    pub async fn delete_thread_by_id(&self, id: ThreadId) -> Result<()> {
        let id_str = id.to_string();
        self.delete_thread_data(id_str).await
    }
}
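For readers unfamiliar with Zed's `define_connection!`/`query!` macros: the store above boils down to one STRICT table keyed by thread id, an upsert, and an ordered scan. A standalone sketch of those SQL semantics using rusqlite (not Zed's `sqlez`; shown only to make the behavior the tests below rely on concrete):

```rust
// Demonstrates the agent_threads schema, the INSERT OR REPLACE upsert used
// by save_thread, and the updated_at DESC ordering used by all_threads.
use rusqlite::{Connection, params};

fn main() -> rusqlite::Result<()> {
    let conn = Connection::open_in_memory()?;
    conn.execute_batch(
        "CREATE TABLE IF NOT EXISTS agent_threads(
            id TEXT PRIMARY KEY,
            summary TEXT NOT NULL,
            updated_at INTEGER NOT NULL,
            data TEXT NOT NULL
        ) STRICT;",
    )?;

    // save_thread: INSERT OR REPLACE is keyed on the primary key, so saving
    // an existing id overwrites the row instead of failing.
    let upsert = "INSERT OR REPLACE INTO agent_threads (id, summary, updated_at, data)
                  VALUES (?1, ?2, ?3, ?4)";
    conn.execute(upsert, params!["t-1", "First summary", 100, "{}"])?;
    conn.execute(upsert, params!["t-1", "Updated summary", 200, "{}"])?;
    conn.execute(upsert, params!["t-2", "Other thread", 150, "{}"])?;

    // all_threads: newest first, which the ordering tests assert on.
    let mut stmt = conn.prepare(
        "SELECT id, summary, updated_at FROM agent_threads ORDER BY updated_at DESC",
    )?;
    let rows = stmt.query_map([], |row| {
        Ok((
            row.get::<_, String>(0)?,
            row.get::<_, String>(1)?,
            row.get::<_, i64>(2)?,
        ))
    })?;
    for row in rows {
        println!("{:?}", row?); // ("t-1", "Updated summary", 200), then "t-2"
    }
    Ok(())
}
```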
#[cfg(test)]
mod tests {
    use super::*;
    use chrono::Utc;
    use gpui::TestAppContext;
    use std::sync::Arc;
    use tempfile::TempDir;

    #[gpui::test]
    async fn test_save_load_delete_threads(_cx: &mut TestAppContext) {
        let db = ThreadStoreDB::open_test_db("test_save_load_delete_threads").await;

        // Test that no threads exist initially
        let threads = db.all_threads().await.unwrap();
        assert_eq!(threads.len(), 0);

        // Create test thread data
        let thread_id = ThreadId::from("test-thread-1");
        let thread = SerializedThread {
            version: SerializedThread::VERSION.to_string(),
            summary: SharedString::from("Test thread summary"),
            updated_at: Utc::now(),
            messages: vec![],
            initial_project_snapshot: None,
            cumulative_token_usage: TokenUsage::default(),
            request_token_usage: vec![],
            detailed_summary_state: DetailedSummaryState::NotGenerated,
            exceeded_window_error: None,
            model: None,
            completion_mode: Some(CompletionMode::Normal),
            tool_use_limit_reached: false,
        };

        let thread_summary = thread.summary.clone();
        let thread_version = thread.version.clone();

        // Save thread
        db.save_thread(thread_id.clone(), thread.clone())
            .await
            .unwrap();

        // Load all threads
        let threads = db.all_threads().await.unwrap();
        assert_eq!(threads.len(), 1);
        assert_eq!(threads[0].id, thread_id);
        assert_eq!(threads[0].summary, thread_summary);

        // Load specific thread
        let loaded_thread = db.get_thread(thread_id.clone()).await.unwrap();
        assert!(loaded_thread.is_some());
        let loaded_thread = loaded_thread.unwrap();
        assert_eq!(loaded_thread.summary, thread_summary);
        assert_eq!(loaded_thread.version, thread_version);

        // Update thread
        let updated_thread = SerializedThread {
            summary: SharedString::from("Updated summary"),
            updated_at: Utc::now(),
            ..thread
        };
        db.save_thread(thread_id.clone(), updated_thread.clone())
            .await
            .unwrap();

        // Verify update
        let loaded_thread = db.get_thread(thread_id.clone()).await.unwrap().unwrap();
        assert_eq!(loaded_thread.summary, SharedString::from("Updated summary"));

        // Delete thread
        db.delete_thread_by_id(thread_id.clone()).await.unwrap();

        // Verify deletion
        let loaded_thread = db.get_thread(thread_id.clone()).await.unwrap();
        assert!(loaded_thread.is_none());

        let threads = db.all_threads().await.unwrap();
        assert_eq!(threads.len(), 0);
    }

    #[gpui::test]
    async fn test_multiple_threads(_cx: &mut TestAppContext) {
        let db = ThreadStoreDB::open_test_db("test_multiple_threads").await;

        // Create multiple threads
        let thread_ids = [
            ThreadId::from("thread-1"),
            ThreadId::from("thread-2"),
            ThreadId::from("thread-3"),
        ];

        for (i, thread_id) in thread_ids.iter().enumerate() {
            let thread = SerializedThread {
                version: SerializedThread::VERSION.to_string(),
                summary: SharedString::from(format!("Thread {}", i + 1)),
                updated_at: Utc::now() - chrono::Duration::hours(i as i64),
                messages: vec![],
                initial_project_snapshot: None,
                cumulative_token_usage: TokenUsage::default(),
                request_token_usage: Vec::new(),
                detailed_summary_state: DetailedSummaryState::NotGenerated,
                exceeded_window_error: None,
                model: None,
                completion_mode: Some(CompletionMode::Normal),
                tool_use_limit_reached: false,
            };
            db.save_thread(thread_id.clone(), thread).await.unwrap();
        }

        // Load all threads - should be ordered by updated_at DESC
        let threads = db.all_threads().await.unwrap();
        assert_eq!(threads.len(), 3);
        assert_eq!(threads[0].summary.as_ref(), "Thread 1");
        assert_eq!(threads[1].summary.as_ref(), "Thread 2");
        assert_eq!(threads[2].summary.as_ref(), "Thread 3");

        // Delete middle thread
        db.delete_thread_by_id(thread_ids[1].clone()).await.unwrap();

        let threads = db.all_threads().await.unwrap();
        assert_eq!(threads.len(), 2);
        assert_eq!(threads[0].summary.as_ref(), "Thread 1");
        assert_eq!(threads[1].summary.as_ref(), "Thread 3");
    }

    #[gpui::test]
    async fn test_heed_to_sqlite_migration(_cx: &mut TestAppContext) {
        use heed::types::SerdeBincode;

        // Create a temporary directory for the heed database
        let temp_dir = TempDir::new().unwrap();
        let heed_path = temp_dir.path().join("test-heed-db");

        // Create and populate heed database
        {
            std::fs::create_dir_all(&heed_path).unwrap();
            let env = unsafe {
                heed::EnvOpenOptions::new()
                    .map_size(1024 * 1024 * 1024)
                    .max_dbs(1)
                    .open(&heed_path)
                    .unwrap()
            };

            let mut txn = env.write_txn().unwrap();
            let threads: heed::Database<SerdeBincode<ThreadId>, SerializedThread> =
                env.create_database(&mut txn, Some("threads")).unwrap();

            // Insert test data
            let thread_ids = [
                ThreadId::from("legacy-thread-1"),
                ThreadId::from("legacy-thread-2"),
                ThreadId::from("legacy-thread-3"),
            ];

            for (i, thread_id) in thread_ids.iter().enumerate() {
                let thread = SerializedThread {
                    version: SerializedThread::VERSION.to_string(),
                    summary: SharedString::from(format!("Legacy Thread {}", i + 1)),
                    updated_at: DateTime::from_timestamp(1700000000 - (i as i64) * 86400, 0)
                        .unwrap(),
                    messages: vec![SerializedMessage {
                        id: MessageId(i),
                        role: Role::User,
                        segments: vec![SerializedMessageSegment::Text {
                            text: format!("Test message {}", i),
                        }],
                        tool_uses: vec![],
                        tool_results: vec![],
                        context: String::new(),
                        creases: vec![],
                        is_hidden: false,
                    }],
                    initial_project_snapshot: None,
                    cumulative_token_usage: TokenUsage {
                        input_tokens: ((i + 1) * 100) as u32,
                        output_tokens: ((i + 1) * 50) as u32,
                        cache_creation_input_tokens: 0,
                        cache_read_input_tokens: 0,
                    },
                    request_token_usage: vec![],
                    detailed_summary_state: DetailedSummaryState::NotGenerated,
                    exceeded_window_error: None,
                    model: None,
                    completion_mode: Some(CompletionMode::Normal),
                    tool_use_limit_reached: false,
                };
                threads.put(&mut txn, thread_id, &thread).unwrap();
            }

            txn.commit().unwrap();
        }

        // Clear any existing SQLite data
        let db = ThreadStoreDB::open_test_db("test_heed_to_sqlite_migration").await;

        // Verify SQLite is empty
        let threads_before = db.all_threads().await.unwrap();
        assert_eq!(threads_before.len(), 0);

        // Run migration
        ThreadsDatabase::migrate_from_heed_to_db(&heed_path, &db)
            .await
            .unwrap();

        // Verify all threads were migrated
        let threads_after = db.all_threads().await.unwrap();
        assert_eq!(threads_after.len(), 3);

        // Verify thread metadata
        let thread_summaries: Vec<_> = threads_after.iter().map(|t| t.summary.as_ref()).collect();
        assert!(thread_summaries.contains(&"Legacy Thread 1"));
        assert!(thread_summaries.contains(&"Legacy Thread 2"));
        assert!(thread_summaries.contains(&"Legacy Thread 3"));

        // Verify full thread data
        for i in 1..=3 {
            let thread_id = ThreadId::from(&format!("legacy-thread-{}", i) as &str);
            let thread = db.get_thread(thread_id).await.unwrap().unwrap();
            assert_eq!(thread.summary.as_ref(), format!("Legacy Thread {}", i));
            assert_eq!(thread.messages.len(), 1);
            assert_eq!(
                thread.messages[0].segments[0],
                SerializedMessageSegment::Text {
                    text: format!("Test message {}", i - 1)
                }
            );
            assert_eq!(thread.cumulative_token_usage.input_tokens, (i * 100) as u32);
            assert_eq!(thread.cumulative_token_usage.output_tokens, (i * 50) as u32);
        }

        // Verify heed database was renamed with .bak suffix
        assert!(!heed_path.exists());
        let mut backup_path = heed_path.to_path_buf();
        backup_path
            .set_file_name(format!("{}.bak", heed_path.file_name().unwrap().to_string_lossy()));
        assert!(backup_path.exists());
    }

    #[gpui::test]
    async fn test_thread_serialization_deserialization(_cx: &mut TestAppContext) {
        let db = ThreadStoreDB::open_test_db("test_thread_serialization_deserialization").await;

        let thread_id = ThreadId::from("serialization-test");
        let original_thread = SerializedThread {
            version: SerializedThread::VERSION.to_string(),
            summary: SharedString::from("Serialization test thread"),
            updated_at: Utc::now(),
            messages: vec![
                SerializedMessage {
                    id: MessageId(1),
                    role: Role::User,
                    segments: vec![
                        SerializedMessageSegment::Text {
                            text: "Hello".to_string(),
                        },
                        SerializedMessageSegment::Thinking {
                            text: "Thinking about the response".to_string(),
                            signature: Some("sig123".to_string()),
                        },
                    ],
                    tool_uses: vec![SerializedToolUse {
                        id: LanguageModelToolUseId::from("tool-1"),
                        name: SharedString::from("test_tool"),
                        input: serde_json::json!({"key": "value"}),
                    }],
                    tool_results: vec![SerializedToolResult {
                        tool_use_id: LanguageModelToolUseId::from("tool-1"),
                        is_error: false,
                        content: LanguageModelToolResultContent::Text("Result".into()),
                        output: None,
                    }],
                    context: String::new(),
                    creases: vec![SerializedCrease {
                        start: 0,
                        end: 5,
                        icon_path: SharedString::from("icon.png"),
                        label: SharedString::from("test-crease"),
                    }],
                    is_hidden: false,
                },
                SerializedMessage {
                    id: MessageId(2),
                    role: Role::Assistant,
                    segments: vec![SerializedMessageSegment::RedactedThinking {
                        data: vec![1, 2, 3, 4, 5],
                    }],
                    tool_uses: vec![],
                    tool_results: vec![],
                    context: String::new(),
                    creases: vec![],
                    is_hidden: true,
                },
            ],
            initial_project_snapshot: Some(Arc::new(ProjectSnapshot {
                worktree_snapshots: vec![],
                unsaved_buffer_paths: vec![],
                timestamp: Utc::now(),
            })),
            cumulative_token_usage: TokenUsage {
                input_tokens: 1000,
                output_tokens: 500,
                cache_creation_input_tokens: 0,
                cache_read_input_tokens: 0,
            },
            request_token_usage: vec![TokenUsage {
                input_tokens: 100,
                output_tokens: 50,
                cache_creation_input_tokens: 0,
                cache_read_input_tokens: 0,
            }],
            detailed_summary_state: DetailedSummaryState::Generated {
                text: SharedString::from("Detailed summary"),
                message_id: MessageId(1),
            },
            exceeded_window_error: None,
            model: Some(SerializedLanguageModel {
                provider: "test-provider".to_string(),
                model: "test-model".to_string(),
            }),
            completion_mode: Some(CompletionMode::Normal),
            tool_use_limit_reached: true,
        };

        // Save thread
        db.save_thread(thread_id.clone(), original_thread.clone())
            .await
            .unwrap();

        // Load thread
        let loaded_thread = db.get_thread(thread_id).await.unwrap().unwrap();

        // Verify all fields
        assert_eq!(loaded_thread.version, original_thread.version);
        assert_eq!(loaded_thread.summary, original_thread.summary);
        assert_eq!(loaded_thread.messages.len(), original_thread.messages.len());
        assert_eq!(loaded_thread.messages[0].segments.len(), 2);
        assert_eq!(loaded_thread.messages[0].tool_uses.len(), 1);
        assert_eq!(loaded_thread.messages[0].tool_results.len(), 1);
        assert_eq!(loaded_thread.messages[0].creases.len(), 1);
        assert_eq!(loaded_thread.messages[1].is_hidden, true);
        assert!(loaded_thread.initial_project_snapshot.is_some());
        assert_eq!(
            loaded_thread.cumulative_token_usage.input_tokens,
            original_thread.cumulative_token_usage.input_tokens
        );
        assert_eq!(
            loaded_thread.exceeded_window_error.is_none(),
            original_thread.exceeded_window_error.is_none()
        );
        assert!(loaded_thread.model.is_some());
        assert_eq!(loaded_thread.tool_use_limit_reached, true);
    }
}
@@ -36,6 +36,7 @@ itertools.workspace = true
 language.workspace = true
 language_model.workspace = true
 log.workspace = true
 lsp.workspace = true
 markdown.workspace = true
 open.workspace = true
 paths.workspace = true
@@ -64,6 +65,7 @@ workspace.workspace = true
 zed_llm_client.workspace = true

 [dev-dependencies]
+lsp = { workspace = true, features = ["test-support"] }
 client = { workspace = true, features = ["test-support"] }
 clock = { workspace = true, features = ["test-support"] }
 collections = { workspace = true, features = ["test-support"] }
@@ -26,7 +26,6 @@ use schemars::JsonSchema;
 use serde::{Deserialize, Serialize};
 use std::{cmp, iter, mem, ops::Range, path::PathBuf, sync::Arc, task::Poll};
 use streaming_diff::{CharOperation, StreamingDiff};
-use util::debug_panic;

 #[derive(Serialize)]
 struct CreateFilePromptTemplate {
@@ -582,8 +581,9 @@ impl EditAgent {
            conversation.messages.pop();
        }
    } else {
-        debug_panic!(
-            "Last message must be an Assistant tool calling! Got {:?}",
+        log::error!(
+            "Last message should have had a role of Assistant, but instead was {:?}. Full content of that message: {:?}",
            last_message.role,
+            last_message.content
        );
    }
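On the `debug_panic!` to `log::error!` swap above: a debug_panic-style macro typically aborts only in debug builds and logs in release, so this change trades a loud dev-time crash for an unconditional error log. A hedged sketch of that macro shape (the general pattern, sketched from memory; not necessarily Zed's exact `util::debug_panic`):

```rust
// A debug_panic-style macro: hard failure during development, soft failure
// in release builds. `eprintln!` stands in for `log::error!` so the sketch
// has no dependencies.
macro_rules! debug_panic {
    ($($fmt:tt)*) => {
        if cfg!(debug_assertions) {
            panic!($($fmt)*);
        } else {
            eprintln!($($fmt)*);
        }
    };
}

fn main() {
    // Under a debug build this call would panic; under --release it logs.
    if !cfg!(debug_assertions) {
        debug_panic!("unexpected role: {:?}", "User");
    }
}
```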
@@ -18,15 +18,19 @@ use gpui::{
use indoc::formatdoc;
use language::{
    Anchor, Buffer, Capability, LanguageRegistry, LineEnding, OffsetRangeExt, Rope, TextBuffer,
-    language_settings::SoftWrap,
+    language_settings::{self, FormatOnSave, SoftWrap},
};
use language_model::{LanguageModel, LanguageModelRequest, LanguageModelToolSchemaFormat};
use markdown::{Markdown, MarkdownElement, MarkdownStyle};
-use project::{Project, ProjectPath};
+use project::{
+    Project, ProjectPath,
+    lsp_store::{FormatTrigger, LspFormatTarget},
+};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use settings::Settings;
use std::{
+    collections::HashSet,
    path::{Path, PathBuf},
    sync::Arc,
    time::Duration,
@@ -34,6 +38,7 @@ use std::{
use theme::ThemeSettings;
use ui::{Disclosure, Tooltip, prelude::*};
use util::ResultExt;

+use workspace::Workspace;

pub struct EditFileTool;
@@ -187,8 +192,9 @@ impl Tool for EditFileTool {
        });

        let card_clone = card.clone();
+        let action_log_clone = action_log.clone();
        let task = cx.spawn(async move |cx: &mut AsyncApp| {
-            let edit_agent = EditAgent::new(model, project.clone(), action_log, Templates::new());
+            let edit_agent = EditAgent::new(model, project.clone(), action_log_clone, Templates::new());

            let buffer = project
                .update(cx, |project, cx| {
@@ -249,19 +255,49 @@ impl Tool for EditFileTool {
            }
            let agent_output = output.await?;

+            // Check if format_on_save is enabled
+            let format_on_save_enabled = buffer
+                .read_with(cx, |buffer, cx| {
+                    let settings = language_settings::language_settings(
+                        buffer.language().map(|l| l.name()),
+                        buffer.file(),
+                        cx,
+                    );
+                    !matches!(settings.format_on_save, FormatOnSave::Off)
+                })
+                .unwrap_or(false);
+
+            // If format_on_save is enabled, format the buffer before saving
+            if format_on_save_enabled {
+                let format_task = project.update(cx, |project, cx| {
+                    project.format(
+                        HashSet::from_iter([buffer.clone()]),
+                        LspFormatTarget::Buffers,
+                        false, // Don't push to history since the tool did it.
+                        FormatTrigger::Save,
+                        cx,
+                    )
+                })?;
+                format_task.await.log_err();
+            }
+
            project
                .update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))?
                .await?;

+            // Notify the action log that we've edited the buffer AFTER save completes
+            // This ensures the tracked version matches the saved version
+            action_log.update(cx, |log, cx| {
+                log.buffer_edited(buffer.clone(), cx);
+            })?;
+
            let new_snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
            let new_text = cx.background_spawn({
                let new_snapshot = new_snapshot.clone();
                async move { new_snapshot.text() }
-            });
-            let diff = cx.background_spawn(async move {
-                language::unified_diff(&old_snapshot.text(), &new_snapshot.text())
-            });
-            let (new_text, diff) = futures::join!(new_text, diff);
+            })
+            .await;
+            let diff = language::unified_diff(&old_text, &new_text);
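The hunk above ends by diffing the buffer text before and after the save (now including any format-on-save edits), rather than diffing concurrently against a snapshot that formatting could make stale. As an illustration of what a unified diff of those two strings looks like, here is a standalone sketch using the `similar` crate (Zed's `language::unified_diff` is its own implementation; this is only a stand-in for the concept):

```rust
// Line-based unified diff of the pre-edit and post-save text, in the spirit
// of the old_text/new_text diff computed by the tool.
use similar::TextDiff;

fn main() {
    let old_text = "fn main() {println!(\"Hello!\");}\n";
    let new_text = "fn main() {\n    println!(\"Hello!\");\n}\n";

    let diff = TextDiff::from_lines(old_text, new_text);
    print!("{}", diff.unified_diff().header("old", "new"));
}
```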
@@ -272,7 +308,7 @@ impl Tool for EditFileTool {

            if let Some(card) = card_clone {
                card.update(cx, |card, cx| {
-                    card.set_diff(project_path.path.clone(), old_text, new_text, cx);
+                    card.set_diff(project_path.path.clone(), old_text, new_text.clone(), cx);
                })
                .log_err();
            }
@@ -920,8 +956,8 @@ async fn build_buffer_diff(
mod tests {
    use super::*;
    use client::TelemetrySettings;
-    use fs::FakeFs;
-    use gpui::TestAppContext;
+    use fs::{FakeFs, Fs};
+    use gpui::{TestAppContext, UpdateGlobal};
    use language_model::fake_provider::FakeLanguageModel;
    use serde_json::json;
    use settings::SettingsStore;
@@ -1131,4 +1167,483 @@ mod tests {
            Project::init_settings(cx);
        });
    }

    #[gpui::test]
    async fn test_format_on_save(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree("/root", json!({"src": {}})).await;

        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;

        // Set up a Rust language with LSP formatting support
        let rust_language = Arc::new(language::Language::new(
            language::LanguageConfig {
                name: "Rust".into(),
                matcher: language::LanguageMatcher {
                    path_suffixes: vec!["rs".to_string()],
                    ..Default::default()
                },
                ..Default::default()
            },
            None,
        ));

        // Register the language and fake LSP
        let language_registry = project.read_with(cx, |project, _| project.languages().clone());
        language_registry.add(rust_language);

        let mut fake_language_servers = language_registry.register_fake_lsp(
            "Rust",
            language::FakeLspAdapter {
                capabilities: lsp::ServerCapabilities {
                    document_formatting_provider: Some(lsp::OneOf::Left(true)),
                    ..Default::default()
                },
                ..Default::default()
            },
        );

        // Create the file
        fs.save(
            path!("/root/src/main.rs").as_ref(),
            &"initial content".into(),
            language::LineEnding::Unix,
        )
        .await
        .unwrap();

        // Open the buffer to trigger LSP initialization
        let buffer = project
            .update(cx, |project, cx| {
                project.open_local_buffer(path!("/root/src/main.rs"), cx)
            })
            .await
            .unwrap();

        // Register the buffer with language servers
        let _handle = project.update(cx, |project, cx| {
            project.register_buffer_with_language_servers(&buffer, cx)
        });

        const UNFORMATTED_CONTENT: &str = "fn main() {println!(\"Hello!\");}\n";
        const FORMATTED_CONTENT: &str =
            "This file was formatted by the fake formatter in the test.\n";

        // Get the fake language server and set up formatting handler
        let fake_language_server = fake_language_servers.next().await.unwrap();
        fake_language_server.set_request_handler::<lsp::request::Formatting, _, _>({
            |_, _| async move {
                Ok(Some(vec![lsp::TextEdit {
                    range: lsp::Range::new(lsp::Position::new(0, 0), lsp::Position::new(1, 0)),
                    new_text: FORMATTED_CONTENT.to_string(),
                }]))
            }
        });

        let action_log = cx.new(|_| ActionLog::new(project.clone()));
        let model = Arc::new(FakeLanguageModel::default());

        // First, test with format_on_save enabled
        cx.update(|cx| {
            SettingsStore::update_global(cx, |store, cx| {
                store.update_user_settings::<language::language_settings::AllLanguageSettings>(
                    cx,
                    |settings| {
                        settings.defaults.format_on_save = Some(FormatOnSave::On);
                        settings.defaults.formatter =
                            Some(language::language_settings::SelectedFormatter::Auto);
                    },
                );
            });
        });

        // Have the model stream unformatted content
        let edit_result = {
            let edit_task = cx.update(|cx| {
                let input = serde_json::to_value(EditFileToolInput {
                    display_description: "Create main function".into(),
                    path: "root/src/main.rs".into(),
                    mode: EditFileMode::Overwrite,
                })
                .unwrap();
                Arc::new(EditFileTool)
                    .run(
                        input,
                        Arc::default(),
                        project.clone(),
                        action_log.clone(),
                        model.clone(),
                        None,
                        cx,
                    )
                    .output
            });

            // Stream the unformatted content
            cx.executor().run_until_parked();
            model.stream_last_completion_response(UNFORMATTED_CONTENT.to_string());
            model.end_last_completion_stream();

            edit_task.await
        };
        assert!(edit_result.is_ok());

        // Wait for any async operations (e.g. formatting) to complete
        cx.executor().run_until_parked();

        // Read the file to verify it was formatted automatically
        let new_content = fs.load(path!("/root/src/main.rs").as_ref()).await.unwrap();
        assert_eq!(
            // Ignore carriage returns on Windows
            new_content.replace("\r\n", "\n"),
            FORMATTED_CONTENT,
            "Code should be formatted when format_on_save is enabled"
        );

        // Next, test with format_on_save disabled
        cx.update(|cx| {
            SettingsStore::update_global(cx, |store, cx| {
                store.update_user_settings::<language::language_settings::AllLanguageSettings>(
                    cx,
                    |settings| {
                        settings.defaults.format_on_save = Some(FormatOnSave::Off);
                    },
                );
            });
        });

        // Stream unformatted edits again
        let edit_result = {
            let edit_task = cx.update(|cx| {
                let input = serde_json::to_value(EditFileToolInput {
                    display_description: "Update main function".into(),
                    path: "root/src/main.rs".into(),
                    mode: EditFileMode::Overwrite,
                })
                .unwrap();
                Arc::new(EditFileTool)
                    .run(
                        input,
                        Arc::default(),
                        project.clone(),
                        action_log.clone(),
                        model.clone(),
                        None,
                        cx,
                    )
                    .output
            });

            // Stream the unformatted content
            cx.executor().run_until_parked();
            model.stream_last_completion_response(UNFORMATTED_CONTENT.to_string());
            model.end_last_completion_stream();

            edit_task.await
        };
        assert!(edit_result.is_ok());

        // Wait for any async operations (e.g. formatting) to complete
        cx.executor().run_until_parked();

        // Verify the file was not formatted
        let new_content = fs.load(path!("/root/src/main.rs").as_ref()).await.unwrap();
        assert_eq!(
            // Ignore carriage returns on Windows
            new_content.replace("\r\n", "\n"),
            UNFORMATTED_CONTENT,
            "Code should not be formatted when format_on_save is disabled"
        );
    }

    #[gpui::test]
    async fn test_remove_trailing_whitespace(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree("/root", json!({"src": {}})).await;

        // Create a simple file with trailing whitespace
        fs.save(
            path!("/root/src/main.rs").as_ref(),
            &"initial content".into(),
            language::LineEnding::Unix,
        )
        .await
        .unwrap();

        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
        let action_log = cx.new(|_| ActionLog::new(project.clone()));
        let model = Arc::new(FakeLanguageModel::default());

        // First, test with remove_trailing_whitespace_on_save enabled
        cx.update(|cx| {
            SettingsStore::update_global(cx, |store, cx| {
                store.update_user_settings::<language::language_settings::AllLanguageSettings>(
                    cx,
                    |settings| {
                        settings.defaults.remove_trailing_whitespace_on_save = Some(true);
                    },
                );
            });
        });

        const CONTENT_WITH_TRAILING_WHITESPACE: &str =
            "fn main() { \n println!(\"Hello!\"); \n}\n";

        // Have the model stream content that contains trailing whitespace
        let edit_result = {
            let edit_task = cx.update(|cx| {
                let input = serde_json::to_value(EditFileToolInput {
                    display_description: "Create main function".into(),
                    path: "root/src/main.rs".into(),
                    mode: EditFileMode::Overwrite,
                })
                .unwrap();
                Arc::new(EditFileTool)
                    .run(
                        input,
                        Arc::default(),
                        project.clone(),
                        action_log.clone(),
                        model.clone(),
                        None,
                        cx,
                    )
                    .output
            });

            // Stream the content with trailing whitespace
            cx.executor().run_until_parked();
            model.stream_last_completion_response(CONTENT_WITH_TRAILING_WHITESPACE.to_string());
            model.end_last_completion_stream();

            edit_task.await
        };
        assert!(edit_result.is_ok());

        // Wait for any async operations (e.g. formatting) to complete
        cx.executor().run_until_parked();

        // Read the file to verify trailing whitespace was removed automatically
        assert_eq!(
            // Ignore carriage returns on Windows
            fs.load(path!("/root/src/main.rs").as_ref())
                .await
                .unwrap()
                .replace("\r\n", "\n"),
            "fn main() {\n println!(\"Hello!\");\n}\n",
            "Trailing whitespace should be removed when remove_trailing_whitespace_on_save is enabled"
        );

        // Next, test with remove_trailing_whitespace_on_save disabled
        cx.update(|cx| {
            SettingsStore::update_global(cx, |store, cx| {
                store.update_user_settings::<language::language_settings::AllLanguageSettings>(
                    cx,
                    |settings| {
                        settings.defaults.remove_trailing_whitespace_on_save = Some(false);
                    },
                );
            });
        });

        // Stream edits again with trailing whitespace
        let edit_result = {
            let edit_task = cx.update(|cx| {
                let input = serde_json::to_value(EditFileToolInput {
                    display_description: "Update main function".into(),
                    path: "root/src/main.rs".into(),
                    mode: EditFileMode::Overwrite,
                })
                .unwrap();
                Arc::new(EditFileTool)
                    .run(
                        input,
                        Arc::default(),
                        project.clone(),
                        action_log.clone(),
                        model.clone(),
                        None,
                        cx,
                    )
                    .output
            });

            // Stream the content with trailing whitespace
            cx.executor().run_until_parked();
            model.stream_last_completion_response(CONTENT_WITH_TRAILING_WHITESPACE.to_string());
            model.end_last_completion_stream();

            edit_task.await
        };
        assert!(edit_result.is_ok());

        // Wait for any async operations (e.g. formatting) to complete
        cx.executor().run_until_parked();

        // Read the file again - it should still have trailing whitespace
        let final_content = fs.load(path!("/root/src/main.rs").as_ref()).await.unwrap();
        assert_eq!(
            // Ignore carriage returns on Windows
            final_content.replace("\r\n", "\n"),
            CONTENT_WITH_TRAILING_WHITESPACE,
            "Trailing whitespace should remain when remove_trailing_whitespace_on_save is disabled"
        );
    }

    #[gpui::test]
    async fn test_format_on_save_marks_buffer_as_conflicted(cx: &mut TestAppContext) {
        // This test demonstrates a bug where format-on-save causes the buffer to be
        // incorrectly marked as stale, leading the agent to think the file has been
        // modified externally when it was actually just formatted as part of the save.
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(path!("/root"), json!({"src": {}})).await;

        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;

        // Set up a Rust language with LSP formatting support
        let rust_language = Arc::new(language::Language::new(
            language::LanguageConfig {
                name: "Rust".into(),
                matcher: language::LanguageMatcher {
                    path_suffixes: vec!["rs".to_string()],
                    ..Default::default()
                },
                ..Default::default()
            },
            None,
        ));

        // Register the language and fake LSP
        let language_registry = project.read_with(cx, |project, _| project.languages().clone());
        language_registry.add(rust_language);

        let mut fake_language_servers = language_registry.register_fake_lsp(
            "Rust",
            language::FakeLspAdapter {
                capabilities: lsp::ServerCapabilities {
                    document_formatting_provider: Some(lsp::OneOf::Left(true)),
                    ..Default::default()
                },
                ..Default::default()
            },
        );

        // Create the file with some initial content
        fs.save(
            path!("/root/src/main.rs").as_ref(),
            &"fn main() { }\n".into(),
            language::LineEnding::Unix,
        )
        .await
        .unwrap();

        // Open the buffer to trigger LSP initialization
        let buffer = project
            .update(cx, |project, cx| {
                project.open_local_buffer(path!("/root/src/main.rs"), cx)
            })
            .await
            .unwrap();

        // Register the buffer with language servers
        let _handle = project.update(cx, |project, cx| {
            project.register_buffer_with_language_servers(&buffer, cx)
        });

        const UNFORMATTED_CONTENT: &str = "fn main() {println!(\"Hello!\");}\n";
        const FORMATTED_CONTENT: &str = "fn main() {\n println!(\"Hello!\");\n}\n";

        // Get the fake language server and set up formatting handler
        let fake_language_server = fake_language_servers.next().await.unwrap();
        fake_language_server.set_request_handler::<lsp::request::Formatting, _, _>({
            move |_, _| async move {
                Ok(Some(vec![lsp::TextEdit {
                    range: lsp::Range::new(lsp::Position::new(0, 0), lsp::Position::new(1, 0)),
                    new_text: FORMATTED_CONTENT.to_string(),
                }]))
            }
        });

        let action_log = cx.new(|_| ActionLog::new(project.clone()));
        let model = Arc::new(FakeLanguageModel::default());

        // Enable format_on_save
        cx.update(|cx| {
            SettingsStore::update_global(cx, |store, cx| {
                store.update_user_settings::<language::language_settings::AllLanguageSettings>(
                    cx,
                    |settings| {
                        settings.defaults.format_on_save = Some(FormatOnSave::On);
                        settings.defaults.formatter =
                            Some(language::language_settings::SelectedFormatter::Auto);
                    },
                );
            });
        });

        // Have the agent edit the file (this should trigger format on save)
        let edit_result = {
            let edit_task = cx.update(|cx| {
                let input = serde_json::to_value(EditFileToolInput {
                    display_description: "Update main function".into(),
                    path: "root/src/main.rs".into(),
                    mode: EditFileMode::Overwrite,
                })
                .unwrap();
                Arc::new(EditFileTool)
                    .run(
                        input,
                        Arc::default(),
                        project.clone(),
                        action_log.clone(),
                        model.clone(),
                        None,
                        cx,
                    )
                    .output
            });

            // Stream the unformatted content
            cx.executor().run_until_parked();
            model.stream_last_completion_response(UNFORMATTED_CONTENT.to_string());
            model.end_last_completion_stream();

            edit_task.await
        };
        assert!(edit_result.is_ok());

        // Wait for any async operations (e.g. formatting) to complete
        cx.executor().run_until_parked();

        // Read the file to verify it was formatted automatically
        let new_content = fs.load(path!("/root/src/main.rs").as_ref()).await.unwrap();
        assert_eq!(
            new_content.replace("\r\n", "\n"),
            FORMATTED_CONTENT,
            "Code should be formatted when format_on_save is enabled"
        );

        // Now simulate what would happen if the agent re-reads the buffer
        // (e.g., because it thinks the file changed externally)
        let stale_buffer_count = action_log.read_with(cx, |log, cx| log.stale_buffers(cx).count());

        // This assertion demonstrates the bug: the buffer should NOT be considered stale
        // after format-on-save, but it is incorrectly marked as stale because the
        // formatting process causes the file's mtime to change, making it appear as if
        // the file was modified externally. This causes the agent to unnecessarily
        // re-read the file, thinking it has changed outside of the agent's control.
        assert_eq!(
            stale_buffer_count, 0,
            "BUG: Buffer is incorrectly marked as stale after format-on-save. Found {} stale buffers. \
             This causes the agent to think the file was modified externally when it was just formatted.",
            stale_buffer_count
        );
    }
}