Compare commits
2 Commits
quickfix
...
persist-fs
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
ad4742a5b8 | ||
|
|
b0d4abb82e |
2
Cargo.lock
generated
2
Cargo.lock
generated
@@ -11887,8 +11887,10 @@ version = "0.1.0"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"assistant_tool",
|
||||
"collections",
|
||||
"gpui",
|
||||
"mlua",
|
||||
"parking_lot",
|
||||
"regex",
|
||||
"schemars",
|
||||
"serde",
|
||||
|
||||
@@ -294,9 +294,9 @@ impl ActiveThread {
|
||||
cx.notify();
|
||||
}
|
||||
ThreadEvent::UsePendingTools => {
|
||||
let pending_tool_uses = self
|
||||
.thread
|
||||
.read(cx)
|
||||
let thread = self.thread.read(cx);
|
||||
let thread_id = thread.id().0.clone();
|
||||
let pending_tool_uses = thread
|
||||
.pending_tool_uses()
|
||||
.into_iter()
|
||||
.filter(|tool_use| tool_use.status.is_idle())
|
||||
@@ -305,7 +305,13 @@ impl ActiveThread {
|
||||
|
||||
for tool_use in pending_tool_uses {
|
||||
if let Some(tool) = self.tools.tool(&tool_use.name, cx) {
|
||||
let task = tool.run(tool_use.input, self.workspace.clone(), window, cx);
|
||||
let task = tool.run(
|
||||
tool_use.input,
|
||||
thread_id.clone(),
|
||||
self.workspace.clone(),
|
||||
window,
|
||||
cx,
|
||||
);
|
||||
|
||||
self.thread.update(cx, |thread, cx| {
|
||||
thread.insert_tool_output(tool_use.id.clone(), task, cx);
|
||||
|
||||
@@ -28,7 +28,7 @@ pub enum RequestKind {
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize)]
|
||||
pub struct ThreadId(Arc<str>);
|
||||
pub struct ThreadId(pub Arc<str>);
|
||||
|
||||
impl ThreadId {
|
||||
pub fn new() -> Self {
|
||||
|
||||
@@ -31,6 +31,7 @@ pub trait Tool: 'static + Send + Sync {
|
||||
fn run(
|
||||
self: Arc<Self>,
|
||||
input: serde_json::Value,
|
||||
thread_id: Arc<str>,
|
||||
workspace: WeakEntity<Workspace>,
|
||||
window: &mut Window,
|
||||
cx: &mut App,
|
||||
|
||||
@@ -1,11 +1,10 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use anyhow::{anyhow, Result};
|
||||
use assistant_tool::Tool;
|
||||
use chrono::{Local, Utc};
|
||||
use gpui::{App, Task, WeakEntity, Window};
|
||||
use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::sync::Arc;
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
@@ -41,6 +40,7 @@ impl Tool for NowTool {
|
||||
fn run(
|
||||
self: Arc<Self>,
|
||||
input: serde_json::Value,
|
||||
_thread_id: Arc<str>,
|
||||
_workspace: WeakEntity<workspace::Workspace>,
|
||||
_window: &mut Window,
|
||||
_cx: &mut App,
|
||||
|
||||
@@ -51,6 +51,7 @@ impl Tool for ContextServerTool {
|
||||
fn run(
|
||||
self: std::sync::Arc<Self>,
|
||||
input: serde_json::Value,
|
||||
_thread_id: Arc<str>,
|
||||
_workspace: gpui::WeakEntity<workspace::Workspace>,
|
||||
_: &mut Window,
|
||||
cx: &mut App,
|
||||
|
||||
@@ -15,8 +15,10 @@ doctest = false
|
||||
[dependencies]
|
||||
anyhow.workspace = true
|
||||
assistant_tool.workspace = true
|
||||
collections.workspace = true
|
||||
gpui.workspace = true
|
||||
mlua.workspace = true
|
||||
parking_lot.workspace = true
|
||||
schemars.workspace = true
|
||||
serde.workspace = true
|
||||
serde_json.workspace = true
|
||||
|
||||
@@ -1,12 +1,13 @@
|
||||
use anyhow::anyhow;
|
||||
use assistant_tool::{Tool, ToolRegistry};
|
||||
use gpui::{App, AppContext as _, Task, WeakEntity, Window};
|
||||
use collections::HashMap;
|
||||
use gpui::{App, AppContext as _, Global, Task, WeakEntity, Window};
|
||||
use mlua::{Function, Lua, MultiValue, Result, UserData, UserDataMethods};
|
||||
use parking_lot::Mutex;
|
||||
use schemars::JsonSchema;
|
||||
use serde::Deserialize;
|
||||
use std::{
|
||||
cell::RefCell,
|
||||
collections::HashMap,
|
||||
path::{Path, PathBuf},
|
||||
rc::Rc,
|
||||
sync::Arc,
|
||||
@@ -55,6 +56,7 @@ string being a match that was found within the file)."#.into()
|
||||
fn run(
|
||||
self: Arc<Self>,
|
||||
input: serde_json::Value,
|
||||
thread_id: Arc<str>,
|
||||
workspace: WeakEntity<Workspace>,
|
||||
_window: &mut Window,
|
||||
cx: &mut App,
|
||||
@@ -81,17 +83,46 @@ string being a match that was found within the file)."#.into()
|
||||
Ok(input) => input,
|
||||
};
|
||||
let lua_script = input.lua_script;
|
||||
let fs_changes = ScriptingToolFileChanges::get(thread_id, cx);
|
||||
cx.background_spawn(async move {
|
||||
let fs_changes = HashMap::new();
|
||||
let output = run_sandboxed_lua(&lua_script, fs_changes, root_dir)
|
||||
.map_err(|err| anyhow!(format!("{err}")))?;
|
||||
let output = output.printed_lines.join("\n");
|
||||
.map_err(|err| anyhow!(format!("{err}")))?
|
||||
.printed_lines
|
||||
.join("\n");
|
||||
|
||||
Ok(format!("The script output the following:\n{output}"))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Accumulates file changes made during script execution.
|
||||
struct ScriptingToolFileChanges {
|
||||
// Assistant thread ID that these files changes are associated with. Only file changes for one
|
||||
// thread are supported to avoid the need for dropping these when the associated `Thread` is
|
||||
// dropped.
|
||||
thread_id: Arc<str>,
|
||||
// Map from path to file contents for files changed by script execution.
|
||||
file_changes: Arc<Mutex<HashMap<PathBuf, Vec<u8>>>>,
|
||||
}
|
||||
|
||||
impl Global for ScriptingToolFileChanges {}
|
||||
|
||||
impl ScriptingToolFileChanges {
|
||||
fn get(thread_id: Arc<str>, cx: &mut App) -> Arc<Mutex<HashMap<PathBuf, Vec<u8>>>> {
|
||||
match cx.try_global::<ScriptingToolFileChanges>() {
|
||||
Some(global) if global.thread_id == thread_id => global.file_changes.clone(),
|
||||
_ => {
|
||||
let file_changes = Arc::new(Mutex::new(HashMap::default()));
|
||||
cx.set_global(ScriptingToolFileChanges {
|
||||
thread_id,
|
||||
file_changes: file_changes.clone(),
|
||||
});
|
||||
file_changes
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const SANDBOX_PREAMBLE: &str = include_str!("sandbox_preamble.lua");
|
||||
|
||||
struct FileContent(RefCell<Vec<u8>>);
|
||||
@@ -125,7 +156,7 @@ fn print(lua: &Lua, printed_lines: Rc<RefCell<Vec<String>>>) -> Result<Function>
|
||||
|
||||
fn search(
|
||||
lua: &Lua,
|
||||
_fs_changes: Rc<RefCell<HashMap<PathBuf, Vec<u8>>>>,
|
||||
_fs_changes: Arc<Mutex<HashMap<PathBuf, Vec<u8>>>>,
|
||||
root_dir: PathBuf,
|
||||
) -> Result<Function> {
|
||||
lua.create_function(move |lua, regex: String| {
|
||||
@@ -221,7 +252,7 @@ fn search(
|
||||
/// Sandboxed io.open() function in Lua.
|
||||
fn io_open(
|
||||
lua: &Lua,
|
||||
fs_changes: Rc<RefCell<HashMap<PathBuf, Vec<u8>>>>,
|
||||
fs_changes: Arc<Mutex<HashMap<PathBuf, Vec<u8>>>>,
|
||||
root_dir: PathBuf,
|
||||
) -> Result<Function> {
|
||||
lua.create_function(move |lua, (path_str, mode): (String, Option<String>)| {
|
||||
@@ -281,7 +312,7 @@ fn io_open(
|
||||
// Don't actually write to disk; instead, just update fs_changes.
|
||||
let path_buf = PathBuf::from(&path);
|
||||
fs_changes
|
||||
.borrow_mut()
|
||||
.lock()
|
||||
.insert(path_buf.clone(), content_vec.clone());
|
||||
}
|
||||
|
||||
@@ -333,21 +364,28 @@ fn io_open(
|
||||
return Ok((Some(file), String::new()));
|
||||
}
|
||||
|
||||
let is_in_changes = fs_changes.borrow().contains_key(&path);
|
||||
let file_exists = is_in_changes || path.exists();
|
||||
let mut file_content = Vec::new();
|
||||
let is_in_changes;
|
||||
let file_exists;
|
||||
let file_content = {
|
||||
let fs_changes = fs_changes.lock();
|
||||
is_in_changes = fs_changes.contains_key(&path);
|
||||
file_exists = is_in_changes || path.exists();
|
||||
|
||||
if file_exists && !truncate {
|
||||
if is_in_changes {
|
||||
file_content = fs_changes.borrow().get(&path).unwrap().clone();
|
||||
} else {
|
||||
// Try to read existing content if file exists and we're not truncating
|
||||
match std::fs::read(&path) {
|
||||
Ok(content) => file_content = content,
|
||||
Err(e) => return Ok((None, format!("Error reading file: {}", e))),
|
||||
if file_exists && !truncate {
|
||||
if is_in_changes {
|
||||
fs_changes.get(&path).unwrap().clone()
|
||||
} else {
|
||||
drop(fs_changes); // Release the lock before starting the read.
|
||||
|
||||
match std::fs::read(&path) {
|
||||
Ok(content) => content,
|
||||
Err(e) => return Ok((None, format!("Error reading file: {}", e))),
|
||||
}
|
||||
}
|
||||
} else {
|
||||
Vec::new()
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// If in append mode, position should be at the end
|
||||
let position = if append && file_exists {
|
||||
@@ -355,6 +393,7 @@ fn io_open(
|
||||
} else {
|
||||
0
|
||||
};
|
||||
|
||||
file.set("__position", position)?;
|
||||
file.set(
|
||||
"__content",
|
||||
@@ -582,9 +621,7 @@ fn io_open(
|
||||
// Update fs_changes
|
||||
let path = file_userdata.get::<String>("__path")?;
|
||||
let path_buf = PathBuf::from(path);
|
||||
fs_changes
|
||||
.borrow_mut()
|
||||
.insert(path_buf, content_vec.clone());
|
||||
fs_changes.lock().insert(path_buf, content_vec.clone());
|
||||
|
||||
Ok(true)
|
||||
})?
|
||||
@@ -599,7 +636,7 @@ fn io_open(
|
||||
/// Runs a Lua script in a sandboxed environment and returns the printed lines
|
||||
pub fn run_sandboxed_lua(
|
||||
script: &str,
|
||||
fs_changes: HashMap<PathBuf, Vec<u8>>,
|
||||
fs: Arc<Mutex<HashMap<PathBuf, Vec<u8>>>>,
|
||||
root_dir: PathBuf,
|
||||
) -> Result<ScriptOutput> {
|
||||
let lua = Lua::new();
|
||||
@@ -608,7 +645,6 @@ pub fn run_sandboxed_lua(
|
||||
|
||||
// Track the lines the Lua script prints out.
|
||||
let printed_lines = Rc::new(RefCell::new(Vec::new()));
|
||||
let fs = Rc::new(RefCell::new(fs_changes));
|
||||
|
||||
globals.set("sb_print", print(&lua, printed_lines.clone())?)?;
|
||||
globals.set("search", search(&lua, fs.clone(), root_dir.clone())?)?;
|
||||
@@ -623,23 +659,18 @@ pub fn run_sandboxed_lua(
|
||||
printed_lines: Rc::try_unwrap(printed_lines)
|
||||
.expect("There are still other references to printed_lines")
|
||||
.into_inner(),
|
||||
fs_changes: Rc::try_unwrap(fs)
|
||||
.expect("There are still other references to fs_changes")
|
||||
.into_inner(),
|
||||
})
|
||||
}
|
||||
|
||||
pub struct ScriptOutput {
|
||||
printed_lines: Vec<String>,
|
||||
#[allow(dead_code)]
|
||||
fs_changes: HashMap<PathBuf, Vec<u8>>,
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
impl ScriptOutput {
|
||||
fn fs_diff(&self) -> HashMap<PathBuf, String> {
|
||||
let mut diff_map = HashMap::new();
|
||||
for (path, content) in &self.fs_changes {
|
||||
fn fs_diff(&self, fs_changes: &HashMap<PathBuf, Vec<u8>>) -> HashMap<PathBuf, String> {
|
||||
let mut diff_map = HashMap::default();
|
||||
for (path, content) in fs_changes {
|
||||
let diff = if path.exists() {
|
||||
// Read the current file content
|
||||
match std::fs::read(path) {
|
||||
@@ -758,9 +789,9 @@ impl ScriptOutput {
|
||||
diff_map
|
||||
}
|
||||
|
||||
fn diff_to_string(&self) -> String {
|
||||
fn diff_to_string(&self, fs_changes: &HashMap<PathBuf, Vec<u8>>) -> String {
|
||||
let mut answer = String::new();
|
||||
let diff_map = self.fs_diff();
|
||||
let diff_map = self.fs_diff(fs_changes);
|
||||
|
||||
if diff_map.is_empty() {
|
||||
return "No changes to files".to_string();
|
||||
|
||||
Reference in New Issue
Block a user