Compare commits

...

5 Commits

Author SHA1 Message Date
Richard Feldman
1a15c63b25 wip 2025-03-06 10:12:48 -05:00
Richard Feldman
d6a8b4cfb1 Give Tools control over how they render input/output 2025-03-05 23:32:51 -05:00
Richard Feldman
264ac61210 Initial pass at Lua syntax highlighting
Co-Authored-By: Danilo <danilo@zed.dev>
2025-03-05 20:35:51 -05:00
Michael Sloan
ad4742a5b8 Use a global for scripting tool in-memory fs
Suggested in https://github.com/zed-industries/zed/pull/26132#discussion_r1981670385
2025-03-05 17:30:07 -07:00
Richard Feldman
b0d4abb82e Persist in-memory filesystem between tool uses 2025-03-05 10:30:20 -05:00
12 changed files with 371 additions and 71 deletions

10
Cargo.lock generated
View File

@@ -448,6 +448,7 @@ dependencies = [
"command_palette_hooks", "command_palette_hooks",
"context_server", "context_server",
"db", "db",
"diff",
"editor", "editor",
"feature_flags", "feature_flags",
"file_icons", "file_icons",
@@ -644,9 +645,11 @@ dependencies = [
"collections", "collections",
"derive_more", "derive_more",
"gpui", "gpui",
"language",
"parking_lot", "parking_lot",
"serde", "serde",
"serde_json", "serde_json",
"ui",
"workspace", "workspace",
] ]
@@ -11887,12 +11890,19 @@ version = "0.1.0"
dependencies = [ dependencies = [
"anyhow", "anyhow",
"assistant_tool", "assistant_tool",
"collections",
"gpui", "gpui",
"language",
"mlua", "mlua",
"parking_lot",
"regex", "regex",
"rich_text",
"schemars", "schemars",
"serde", "serde",
"serde_json", "serde_json",
"settings",
"theme",
"ui",
"workspace", "workspace",
] ]

View File

@@ -415,6 +415,7 @@ core-foundation-sys = "0.8.6"
ctor = "0.4.0" ctor = "0.4.0"
dashmap = "6.0" dashmap = "6.0"
derive_more = "0.99.17" derive_more = "0.99.17"
diff = "0.1.13"
dirs = "4.0" dirs = "4.0"
ec4rs = "1.1" ec4rs = "1.1"
emojis = "0.6.1" emojis = "0.6.1"

View File

@@ -32,6 +32,7 @@ collections.workspace = true
command_palette_hooks.workspace = true command_palette_hooks.workspace = true
context_server.workspace = true context_server.workspace = true
db.workspace = true db.workspace = true
diff.workspace = true
editor.workspace = true editor.workspace = true
feature_flags.workspace = true feature_flags.workspace = true
file_icons.workspace = true file_icons.workspace = true

View File

@@ -1,6 +1,4 @@
use std::sync::Arc; use assistant_tool::{ToolFileChanges, ToolRegistry, ToolWorkingSet};
use assistant_tool::ToolWorkingSet;
use collections::HashMap; use collections::HashMap;
use editor::{Editor, MultiBuffer}; use editor::{Editor, MultiBuffer};
use gpui::{ use gpui::{
@@ -8,10 +6,11 @@ use gpui::{
Entity, Focusable, Length, ListAlignment, ListOffset, ListState, StyleRefinement, Subscription, Entity, Focusable, Length, ListAlignment, ListOffset, ListState, StyleRefinement, Subscription,
Task, TextStyleRefinement, UnderlineStyle, WeakEntity, Task, TextStyleRefinement, UnderlineStyle, WeakEntity,
}; };
use language::{Buffer, LanguageRegistry}; use language::{Buffer, Language, LanguageRegistry};
use language_model::{LanguageModelRegistry, LanguageModelToolUseId, Role}; use language_model::{LanguageModelRegistry, LanguageModelToolUseId, Role};
use markdown::{Markdown, MarkdownStyle}; use markdown::{Markdown, MarkdownStyle};
use settings::Settings as _; use settings::Settings as _;
use std::sync::Arc;
use theme::ThemeSettings; use theme::ThemeSettings;
use ui::{prelude::*, Disclosure, KeyBinding}; use ui::{prelude::*, Disclosure, KeyBinding};
use util::ResultExt as _; use util::ResultExt as _;
@@ -35,6 +34,7 @@ pub struct ActiveThread {
editing_message: Option<(MessageId, EditMessageState)>, editing_message: Option<(MessageId, EditMessageState)>,
expanded_tool_uses: HashMap<LanguageModelToolUseId, bool>, expanded_tool_uses: HashMap<LanguageModelToolUseId, bool>,
last_error: Option<ThreadError>, last_error: Option<ThreadError>,
lua_language: Option<Arc<Language>>, // Used for syntax highlighting in the Lua script tool
_subscriptions: Vec<Subscription>, _subscriptions: Vec<Subscription>,
} }
@@ -76,9 +76,23 @@ impl ActiveThread {
}), }),
editing_message: None, editing_message: None,
last_error: None, last_error: None,
lua_language: None,
_subscriptions: subscriptions, _subscriptions: subscriptions,
}; };
// Initialize the Lua language in the background, for syntax highlighting.
let language_registry = this.language_registry.clone();
cx.spawn(|this, mut cx| async move {
if let Ok(lua_language) = language_registry.language_for_name("Lua").await {
this.update(&mut cx, |this, _| {
this.lua_language = Some(lua_language);
})?;
}
Ok::<_, anyhow::Error>(())
})
.detach();
for message in thread.read(cx).messages().cloned().collect::<Vec<_>>() { for message in thread.read(cx).messages().cloned().collect::<Vec<_>>() {
this.push_message(&message.id, message.text.clone(), window, cx); this.push_message(&message.id, message.text.clone(), window, cx);
} }
@@ -294,9 +308,9 @@ impl ActiveThread {
cx.notify(); cx.notify();
} }
ThreadEvent::UsePendingTools => { ThreadEvent::UsePendingTools => {
let pending_tool_uses = self let thread = self.thread.read(cx);
.thread let thread_id = thread.id().0.clone();
.read(cx) let pending_tool_uses = thread
.pending_tool_uses() .pending_tool_uses()
.into_iter() .into_iter()
.filter(|tool_use| tool_use.status.is_idle()) .filter(|tool_use| tool_use.status.is_idle())
@@ -305,7 +319,13 @@ impl ActiveThread {
for tool_use in pending_tool_uses { for tool_use in pending_tool_uses {
if let Some(tool) = self.tools.tool(&tool_use.name, cx) { if let Some(tool) = self.tools.tool(&tool_use.name, cx) {
let task = tool.run(tool_use.input, self.workspace.clone(), window, cx); let task = tool.run(
tool_use.input,
thread_id.clone(),
self.workspace.clone(),
window,
cx,
);
self.thread.update(cx, |thread, cx| { self.thread.update(cx, |thread, cx| {
thread.insert_tool_output(tool_use.id.clone(), task, cx); thread.insert_tool_output(tool_use.id.clone(), task, cx);
@@ -324,6 +344,16 @@ impl ActiveThread {
let model_registry = LanguageModelRegistry::read_global(cx); let model_registry = LanguageModelRegistry::read_global(cx);
if let Some(model) = model_registry.active_model() { if let Some(model) = model_registry.active_model() {
self.thread.update(cx, |thread, cx| { self.thread.update(cx, |thread, cx| {
if let Some(global) = cx.try_global::<ToolFileChanges>() {
let thread_id = thread.id().0.clone();
if global.thread_id == thread_id
&& !global.file_changes.lock().is_empty()
{
println!("Changes:\n{}", self.handle_fs_changes(cx));
}
}
// Insert a user message to contain the tool results. // Insert a user message to contain the tool results.
thread.insert_user_message( thread.insert_user_message(
// TODO: Sending up a user message without any content results in the model sending back // TODO: Sending up a user message without any content results in the model sending back
@@ -360,6 +390,46 @@ impl ActiveThread {
})); }));
} }
/// Builds a unified-diff-style report of the pending in-memory file changes and resets them.
///
/// For each path recorded in the `ToolFileChanges` global, the current on-disk contents are
/// line-diffed against the recorded contents. Every line is emitted with a `-`, `+`, or
/// space prefix, under a `---`/`+++` header for the path, and entries are ordered by path.
///
/// Returns an empty string when the global is not set or holds no pending changes.
fn handle_fs_changes(&self, cx: &mut Context<Thread>) -> String {
    let pending = match cx.try_global::<ToolFileChanges>() {
        // Take the whole map under a single lock acquisition. Cloning first and clearing
        // under a second lock would silently discard any change written in between.
        Some(global) => std::mem::take(&mut *global.file_changes.lock()),
        None => return String::new(),
    };

    // Hash-map iteration order is unspecified; sort so the report is deterministic.
    let mut pending: Vec<_> = pending.into_iter().collect();
    pending.sort_unstable_by(|a, b| a.0.cmp(&b.0));

    let mut diff_output = String::new();
    for (path, content) in pending {
        let path_str = path.to_string_lossy();
        diff_output.push_str(&format!("--- {}\n+++ {}\n", path_str, path_str));

        // A file that cannot be read (e.g. it does not exist yet) is diffed against empty
        // contents, so every line of the new contents shows up as an addition.
        let old_bytes = std::fs::read(&path).unwrap_or_default();
        let old_content = String::from_utf8_lossy(&old_bytes);
        let new_content = String::from_utf8_lossy(&content);

        for change in diff::lines(&old_content, &new_content) {
            let (marker, line) = match change {
                diff::Result::Left(line) => ('-', line),
                diff::Result::Right(line) => ('+', line),
                diff::Result::Both(line, _) => (' ', line),
            };
            diff_output.push(marker);
            diff_output.push_str(line);
            diff_output.push('\n');
        }
        diff_output.push('\n');
    }
    diff_output
}
fn start_editing_message( fn start_editing_message(
&mut self, &mut self,
message_id: MessageId, message_id: MessageId,
@@ -654,6 +724,7 @@ impl ActiveThread {
} }
fn render_tool_use(&self, tool_use: ToolUse, cx: &mut Context<Self>) -> impl IntoElement { fn render_tool_use(&self, tool_use: ToolUse, cx: &mut Context<Self>) -> impl IntoElement {
let tool = ToolRegistry::global(cx).tool(&tool_use.name);
let is_open = self let is_open = self
.expanded_tool_uses .expanded_tool_uses
.get(&tool_use.id) .get(&tool_use.id)
@@ -719,11 +790,17 @@ impl ActiveThread {
.px_2p5() .px_2p5()
.border_b_1() .border_b_1()
.border_color(cx.theme().colors().border) .border_color(cx.theme().colors().border)
.child(Label::new("Input:")) .bg(cx.theme().colors().editor_background)
.child(Label::new( .child(match tool.clone() {
serde_json::to_string_pretty(&tool_use.input) Some(tool) => tool.render_input(
.unwrap_or_default(), tool_use.input,
)), self.lua_language.clone(),
cx,
),
None => {
assistant_tool::default_render_input(tool_use.input)
}
}),
) )
.map(|parent| match tool_use.status { .map(|parent| match tool_use.status {
ToolUseStatus::Finished(output) => parent.child( ToolUseStatus::Finished(output) => parent.child(
@@ -731,16 +808,17 @@ impl ActiveThread {
.gap_0p5() .gap_0p5()
.py_1() .py_1()
.px_2p5() .px_2p5()
.child(Label::new("Result:")) .bg(cx.theme().colors().editor_background)
.child(Label::new(output)), .child(match tool {
Some(tool) => tool.render_output(output, cx),
None => assistant_tool::default_render_output(output),
}),
), ),
ToolUseStatus::Error(err) => parent.child( ToolUseStatus::Error(err) => parent.child(
v_flex() v_flex().gap_0p5().py_1().px_2p5().child(match tool {
.gap_0p5() Some(tool) => tool.render_error(err, cx),
.py_1() None => assistant_tool::default_render_output(err),
.px_2p5() }),
.child(Label::new("Error:"))
.child(Label::new(err)),
), ),
ToolUseStatus::Pending | ToolUseStatus::Running => parent, ToolUseStatus::Pending | ToolUseStatus::Running => parent,
}), }),

View File

@@ -28,7 +28,7 @@ pub enum RequestKind {
} }
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize)] #[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize)]
pub struct ThreadId(Arc<str>); pub struct ThreadId(pub Arc<str>);
impl ThreadId { impl ThreadId {
pub fn new() -> Self { pub fn new() -> Self {

View File

@@ -16,7 +16,9 @@ anyhow.workspace = true
collections.workspace = true collections.workspace = true
derive_more.workspace = true derive_more.workspace = true
gpui.workspace = true gpui.workspace = true
language.workspace = true
parking_lot.workspace = true parking_lot.workspace = true
serde.workspace = true serde.workspace = true
serde_json.workspace = true serde_json.workspace = true
workspace.workspace = true workspace.workspace = true
ui.workspace = true

View File

@@ -1,12 +1,23 @@
mod tool_file_changes;
mod tool_registry; mod tool_registry;
mod tool_working_set; mod tool_working_set;
use std::sync::Arc; use std::sync::Arc;
use anyhow::Result; use anyhow::Result;
use gpui::AnyElement;
use gpui::IntoElement;
use gpui::{App, Task, WeakEntity, Window}; use gpui::{App, Task, WeakEntity, Window};
use language::Language;
use ui::div;
use ui::Label;
use ui::LabelCommon;
use ui::LabelSize;
use ui::ParentElement;
use ui::SharedString;
use workspace::Workspace; use workspace::Workspace;
pub use crate::tool_file_changes::*;
pub use crate::tool_registry::*; pub use crate::tool_registry::*;
pub use crate::tool_working_set::*; pub use crate::tool_working_set::*;
@@ -31,8 +42,52 @@ pub trait Tool: 'static + Send + Sync {
fn run( fn run(
self: Arc<Self>, self: Arc<Self>,
input: serde_json::Value, input: serde_json::Value,
thread_id: Arc<str>,
workspace: WeakEntity<Workspace>, workspace: WeakEntity<Workspace>,
window: &mut Window, window: &mut Window,
cx: &mut App, cx: &mut App,
) -> Task<Result<String>>; ) -> Task<Result<String>>;
/// Renders the tool's input when the user expands it.
///
/// The default implementation ignores `_lua_language` and `_cx` and delegates to
/// `default_render_input`. Implementors can override it to present `input` differently.
fn render_input(
self: Arc<Self>,
input: serde_json::Value,
_lua_language: Option<Arc<Language>>,
_cx: &mut App,
) -> AnyElement {
default_render_input(input)
}
/// Renders the tool's output when the user expands it.
///
/// The default implementation ignores `_cx` and delegates to `default_render_output`.
fn render_output(self: Arc<Self>, output: SharedString, _cx: &mut App) -> AnyElement {
default_render_output(output)
}
/// Renders the tool's error message when the user expands it.
///
/// The default implementation ignores `_cx` and delegates to `default_render_error`.
fn render_error(self: Arc<Self>, err: SharedString, _cx: &mut App) -> AnyElement {
default_render_error(err)
}
}
/// Renders a tool input as a small "Input:" heading followed by the pretty-printed JSON.
///
/// If the value cannot be serialized, the JSON label is left empty.
pub fn default_render_input(input: serde_json::Value) -> AnyElement {
    let heading = Label::new("Input:").size(LabelSize::Small);
    let pretty_json = serde_json::to_string_pretty(&input).unwrap_or_default();
    div()
        .child(heading)
        .child(Label::new(pretty_json))
        .into_any_element()
}
/// Renders a tool output as a small "Result:" heading followed by the output text.
pub fn default_render_output(output: SharedString) -> AnyElement {
    let heading = Label::new("Result:").size(LabelSize::Small);
    let body = Label::new(output);
    div().child(heading).child(body).into_any_element()
}
pub fn default_render_error(err: SharedString) -> AnyElement {
div()
.child(Label::new("Error:").size(LabelSize::Small))
.child(Label::new(err))
.into_any_element()
} }

View File

@@ -0,0 +1,34 @@
use std::{path::PathBuf, sync::Arc};
use collections::HashMap;
use gpui::Global;
use parking_lot::Mutex;
use ui::App;
/// Accumulates file changes made during script execution.
///
/// Stored as a single GPUI global, so only one thread's changes exist at a time.
pub struct ToolFileChanges {
/// Assistant thread ID that these file changes are associated with. Only file changes for one
/// thread are supported to avoid the need for dropping these when the associated `Thread` is
/// dropped.
pub thread_id: Arc<str>,
/// Map from path to file contents for files changed by script execution. Shared behind an
/// `Arc<Mutex<..>>` so handles returned by `ToolFileChanges::get` observe the same map.
pub file_changes: Arc<Mutex<HashMap<PathBuf, Vec<u8>>>>,
}
// Marker impl that lets `ToolFileChanges` be stored with `set_global` / read with `try_global`.
impl Global for ToolFileChanges {}
impl ToolFileChanges {
    /// Returns the shared file-change map for `thread_id`.
    ///
    /// When the global already belongs to `thread_id`, a handle to its existing map is
    /// returned. Otherwise (no global yet, or it belongs to another thread) the global is
    /// replaced with a fresh, empty map owned by `thread_id`, and a handle to that is returned.
    pub fn get(thread_id: Arc<str>, cx: &mut App) -> Arc<Mutex<HashMap<PathBuf, Vec<u8>>>> {
        let existing = cx
            .try_global::<ToolFileChanges>()
            .filter(|global| global.thread_id == thread_id)
            .map(|global| global.file_changes.clone());
        if let Some(shared) = existing {
            return shared;
        }

        let fresh = Arc::new(Mutex::new(HashMap::default()));
        cx.set_global(ToolFileChanges {
            thread_id,
            file_changes: fresh.clone(),
        });
        fresh
    }
}

View File

@@ -1,11 +1,10 @@
use std::sync::Arc;
use anyhow::{anyhow, Result}; use anyhow::{anyhow, Result};
use assistant_tool::Tool; use assistant_tool::Tool;
use chrono::{Local, Utc}; use chrono::{Local, Utc};
use gpui::{App, Task, WeakEntity, Window}; use gpui::{App, Task, WeakEntity, Window};
use schemars::JsonSchema; use schemars::JsonSchema;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::sync::Arc;
#[derive(Debug, Serialize, Deserialize, JsonSchema)] #[derive(Debug, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "snake_case")] #[serde(rename_all = "snake_case")]
@@ -41,6 +40,7 @@ impl Tool for NowTool {
fn run( fn run(
self: Arc<Self>, self: Arc<Self>,
input: serde_json::Value, input: serde_json::Value,
_thread_id: Arc<str>,
_workspace: WeakEntity<workspace::Workspace>, _workspace: WeakEntity<workspace::Workspace>,
_window: &mut Window, _window: &mut Window,
_cx: &mut App, _cx: &mut App,

View File

@@ -51,6 +51,7 @@ impl Tool for ContextServerTool {
fn run( fn run(
self: std::sync::Arc<Self>, self: std::sync::Arc<Self>,
input: serde_json::Value, input: serde_json::Value,
_thread_id: Arc<str>,
_workspace: gpui::WeakEntity<workspace::Workspace>, _workspace: gpui::WeakEntity<workspace::Workspace>,
_: &mut Window, _: &mut Window,
cx: &mut App, cx: &mut App,

View File

@@ -15,10 +15,17 @@ doctest = false
[dependencies] [dependencies]
anyhow.workspace = true anyhow.workspace = true
assistant_tool.workspace = true assistant_tool.workspace = true
collections.workspace = true
gpui.workspace = true gpui.workspace = true
language.workspace = true
mlua.workspace = true mlua.workspace = true
parking_lot.workspace = true
schemars.workspace = true schemars.workspace = true
serde.workspace = true serde.workspace = true
serde_json.workspace = true serde_json.workspace = true
workspace.workspace = true workspace.workspace = true
regex.workspace = true regex.workspace = true
rich_text.workspace = true
settings.workspace = true
theme.workspace = true
ui.workspace = true

View File

@@ -1,16 +1,25 @@
use anyhow::anyhow; use anyhow::anyhow;
use assistant_tool::{Tool, ToolRegistry}; use assistant_tool::{Tool, ToolFileChanges, ToolRegistry};
use gpui::{App, AppContext as _, Task, WeakEntity, Window}; use collections::HashMap;
use mlua::{Function, Lua, MultiValue, Result, UserData, UserDataMethods}; use gpui::{
AnyElement, App, AppContext as _, HighlightStyle, StyledText, Task, WeakEntity, Window,
};
use language::Language;
use mlua::{Function, HookTriggers, Lua, MultiValue, Result, UserData, UserDataMethods, VmState};
use parking_lot::Mutex;
use rich_text::{self, Highlight};
use schemars::JsonSchema; use schemars::JsonSchema;
use serde::Deserialize; use serde::Deserialize;
use settings::Settings;
use std::ops::Range;
use std::{ use std::{
cell::RefCell, cell::RefCell,
collections::HashMap,
path::{Path, PathBuf}, path::{Path, PathBuf},
rc::Rc, rc::Rc,
sync::Arc, sync::Arc,
}; };
use theme::ThemeSettings;
use ui::prelude::*;
use workspace::Workspace; use workspace::Workspace;
pub fn init(cx: &App) { pub fn init(cx: &App) {
@@ -23,11 +32,11 @@ struct ScriptingToolInput {
lua_script: String, lua_script: String,
} }
struct ScriptingTool; pub struct ScriptingTool;
impl Tool for ScriptingTool { impl Tool for ScriptingTool {
fn name(&self) -> String { fn name(&self) -> String {
"lua-interpreter".into() "lua-interpreter".to_string()
} }
fn description(&self) -> String { fn description(&self) -> String {
@@ -38,13 +47,18 @@ The lua script will have access to `io` and it will run with the current working
the root of the code base, so you can use it to explore, search, make changes, etc. You can also have the root of the code base, so you can use it to explore, search, make changes, etc. You can also have
the script print things, and I'll tell you what the output was. Note that `io` only has `open`, and the script print things, and I'll tell you what the output was. Note that `io` only has `open`, and
then the file it returns only has the methods read, write, and close - it doesn't have popen or then the file it returns only has the methods read, write, and close - it doesn't have popen or
anything else. Also, I'm going to be putting this Lua script into JSON, so please don't use Lua's anything else. `os` is not available, so don't try to use it. There will be a global called
double quote syntax for string literals - use one of Lua's other syntaxes for string literals, so I `search` which accepts a regex (it's implemented using Rust's regex crate,
don't have to escape the double quotes. There will be a global called `search` which accepts a regex so use that regex syntax) and runs that regex on the contents of every file in the code base
(it's implemented using Rust's regex crate, so use that regex syntax) and runs that regex on the contents (aside from gitignored files), then returns an array of tables with two fields: "path"
of every file in the code base (aside from gitignored files), then returns an array of tables with two (the path to the file that had the matches) and "matches" (an array of strings, with each
fields: "path" (the path to the file that had the matches) and "matches" (an array of strings, with each string being a match that was found within the file).
string being a match that was found within the file)."#.into()
If I reference something in the code base that you aren't familiar with, I suggest writing
multiple `search` calls to find what you're looking for in one script invocation, as opposed
to trying multiple times. When you're done, if you have any proposed changes to files,
you should write them directly to disk in the script. I'll review them later.
"#.into()
} }
fn input_schema(&self) -> serde_json::Value { fn input_schema(&self) -> serde_json::Value {
@@ -55,6 +69,7 @@ string being a match that was found within the file)."#.into()
fn run( fn run(
self: Arc<Self>, self: Arc<Self>,
input: serde_json::Value, input: serde_json::Value,
thread_id: Arc<str>,
workspace: WeakEntity<Workspace>, workspace: WeakEntity<Workspace>,
_window: &mut Window, _window: &mut Window,
cx: &mut App, cx: &mut App,
@@ -81,15 +96,82 @@ string being a match that was found within the file)."#.into()
Ok(input) => input, Ok(input) => input,
}; };
let lua_script = input.lua_script; let lua_script = input.lua_script;
let fs_changes = ToolFileChanges::get(thread_id, cx);
cx.background_spawn(async move { cx.background_spawn(async move {
let fs_changes = HashMap::new();
let output = run_sandboxed_lua(&lua_script, fs_changes, root_dir) let output = run_sandboxed_lua(&lua_script, fs_changes, root_dir)
.map_err(|err| anyhow!(format!("{err}")))?; .map_err(|err| anyhow!(format!("{err}")))?
let output = output.printed_lines.join("\n"); .printed_lines
.join("\n");
Ok(format!("The script output the following:\n{output}")) if output.is_empty() {
Ok("(The script had no output.)".to_string())
} else {
Ok(output.to_string())
}
}) })
} }
fn render_input(
self: Arc<Self>,
input: serde_json::Value,
lua_language: Option<Arc<Language>>,
cx: &mut App,
) -> AnyElement {
let theme_settings = ThemeSettings::get_global(cx);
let theme = cx.theme();
if let Ok(input) = serde_json::from_value::<ScriptingToolInput>(input.clone()) {
let label;
// Use Lua syntax highlighting, if available
if let Some(lua_language) = &lua_language.clone() {
let mut highlights = Vec::new();
let mut buf = String::new();
rich_text::render_code(&mut buf, &mut highlights, &input.lua_script, lua_language);
let gpui_highlights: Vec<(Range<usize>, HighlightStyle)> = highlights
.iter()
.map(|(range, highlight)| {
let style = match highlight {
Highlight::Code => Default::default(),
Highlight::Id(id) => id.style(theme.syntax()).unwrap_or_default(),
Highlight::InlineCode(_link) => Default::default(), // Links won't come up
Highlight::Highlight(highlight) => *highlight,
_ => HighlightStyle::default(),
};
(range.clone(), style)
})
.collect();
label = StyledText::new(buf)
.with_highlights(gpui_highlights)
.into_any_element();
} else {
label = Label::new(&input.lua_script).into_any_element();
};
div()
.when(lua_language.is_some(), |this| {
this.bg(theme.colors().editor_background)
})
.font_family(theme_settings.buffer_font.family.clone())
.child(label)
.into_any_element()
} else {
// Fallback to JSON if lua_language is unavailable
Label::new(serde_json::to_string_pretty(&input).unwrap_or_default()).into_any_element()
}
}
/// Renders the script's output as plain text in the buffer font.
fn render_output(self: Arc<Self>, output: SharedString, cx: &mut App) -> AnyElement {
    let buffer_font_family = ThemeSettings::get_global(cx).buffer_font.family.clone();
    div()
        .font_family(buffer_font_family)
        .child(output)
        .into_any_element()
}
} }
const SANDBOX_PREAMBLE: &str = include_str!("sandbox_preamble.lua"); const SANDBOX_PREAMBLE: &str = include_str!("sandbox_preamble.lua");
@@ -125,7 +207,7 @@ fn print(lua: &Lua, printed_lines: Rc<RefCell<Vec<String>>>) -> Result<Function>
fn search( fn search(
lua: &Lua, lua: &Lua,
_fs_changes: Rc<RefCell<HashMap<PathBuf, Vec<u8>>>>, _fs_changes: Arc<Mutex<HashMap<PathBuf, Vec<u8>>>>,
root_dir: PathBuf, root_dir: PathBuf,
) -> Result<Function> { ) -> Result<Function> {
lua.create_function(move |lua, regex: String| { lua.create_function(move |lua, regex: String| {
@@ -221,7 +303,7 @@ fn search(
/// Sandboxed io.open() function in Lua. /// Sandboxed io.open() function in Lua.
fn io_open( fn io_open(
lua: &Lua, lua: &Lua,
fs_changes: Rc<RefCell<HashMap<PathBuf, Vec<u8>>>>, fs_changes: Arc<Mutex<HashMap<PathBuf, Vec<u8>>>>,
root_dir: PathBuf, root_dir: PathBuf,
) -> Result<Function> { ) -> Result<Function> {
lua.create_function(move |lua, (path_str, mode): (String, Option<String>)| { lua.create_function(move |lua, (path_str, mode): (String, Option<String>)| {
@@ -281,7 +363,7 @@ fn io_open(
// Don't actually write to disk; instead, just update fs_changes. // Don't actually write to disk; instead, just update fs_changes.
let path_buf = PathBuf::from(&path); let path_buf = PathBuf::from(&path);
fs_changes fs_changes
.borrow_mut() .lock()
.insert(path_buf.clone(), content_vec.clone()); .insert(path_buf.clone(), content_vec.clone());
} }
@@ -333,21 +415,28 @@ fn io_open(
return Ok((Some(file), String::new())); return Ok((Some(file), String::new()));
} }
let is_in_changes = fs_changes.borrow().contains_key(&path); let is_in_changes;
let file_exists = is_in_changes || path.exists(); let file_exists;
let mut file_content = Vec::new(); let file_content = {
let fs_changes = fs_changes.lock();
is_in_changes = fs_changes.contains_key(&path);
file_exists = is_in_changes || path.exists();
if file_exists && !truncate { if file_exists && !truncate {
if is_in_changes { if is_in_changes {
file_content = fs_changes.borrow().get(&path).unwrap().clone(); fs_changes.get(&path).unwrap().clone()
} else { } else {
// Try to read existing content if file exists and we're not truncating drop(fs_changes); // Release the lock before starting the read.
match std::fs::read(&path) {
Ok(content) => file_content = content, match std::fs::read(&path) {
Err(e) => return Ok((None, format!("Error reading file: {}", e))), Ok(content) => content,
Err(e) => return Ok((None, format!("Error reading file: {}", e))),
}
} }
} else {
Vec::new()
} }
} };
// If in append mode, position should be at the end // If in append mode, position should be at the end
let position = if append && file_exists { let position = if append && file_exists {
@@ -355,6 +444,7 @@ fn io_open(
} else { } else {
0 0
}; };
file.set("__position", position)?; file.set("__position", position)?;
file.set( file.set(
"__content", "__content",
@@ -582,9 +672,7 @@ fn io_open(
// Update fs_changes // Update fs_changes
let path = file_userdata.get::<String>("__path")?; let path = file_userdata.get::<String>("__path")?;
let path_buf = PathBuf::from(path); let path_buf = PathBuf::from(path);
fs_changes fs_changes.lock().insert(path_buf, content_vec.clone());
.borrow_mut()
.insert(path_buf, content_vec.clone());
Ok(true) Ok(true)
})? })?
@@ -599,16 +687,44 @@ fn io_open(
/// Runs a Lua script in a sandboxed environment and returns the printed lines /// Runs a Lua script in a sandboxed environment and returns the printed lines
pub fn run_sandboxed_lua( pub fn run_sandboxed_lua(
script: &str, script: &str,
fs_changes: HashMap<PathBuf, Vec<u8>>, fs: Arc<Mutex<HashMap<PathBuf, Vec<u8>>>>,
root_dir: PathBuf, root_dir: PathBuf,
) -> Result<ScriptOutput> { ) -> Result<ScriptOutput> {
let lua = Lua::new(); let lua = Lua::new();
lua.set_memory_limit(2 * 1024 * 1024 * 1024)?; // 2 GB lua.set_memory_limit(2 * 1024 * 1024 * 1024)?; // 2 GB
lua.set_hook(
HookTriggers::new().every_nth_instruction(2048),
|_lua, _| {
// Check if we need to yield to prevent long-running scripts
static EXECUTION_START: std::sync::atomic::AtomicU64 =
std::sync::atomic::AtomicU64::new(0);
let now = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_default()
.as_millis() as u64;
// Initialize start time on first call
let start = EXECUTION_START.load(std::sync::atomic::Ordering::Relaxed);
if start == 0 {
EXECUTION_START.store(now, std::sync::atomic::Ordering::Relaxed);
return Ok(VmState::Continue);
}
// Check if execution time exceeds 5 seconds
if now - start > 5000 {
EXECUTION_START.store(0, std::sync::atomic::Ordering::Relaxed);
return Err(mlua::Error::runtime(
"Script execution timed out after 5 seconds",
));
}
Ok(VmState::Continue)
},
); // 2 GB
let globals = lua.globals(); let globals = lua.globals();
// Track the lines the Lua script prints out. // Track the lines the Lua script prints out.
let printed_lines = Rc::new(RefCell::new(Vec::new())); let printed_lines = Rc::new(RefCell::new(Vec::new()));
let fs = Rc::new(RefCell::new(fs_changes));
globals.set("sb_print", print(&lua, printed_lines.clone())?)?; globals.set("sb_print", print(&lua, printed_lines.clone())?)?;
globals.set("search", search(&lua, fs.clone(), root_dir.clone())?)?; globals.set("search", search(&lua, fs.clone(), root_dir.clone())?)?;
@@ -623,23 +739,18 @@ pub fn run_sandboxed_lua(
printed_lines: Rc::try_unwrap(printed_lines) printed_lines: Rc::try_unwrap(printed_lines)
.expect("There are still other references to printed_lines") .expect("There are still other references to printed_lines")
.into_inner(), .into_inner(),
fs_changes: Rc::try_unwrap(fs)
.expect("There are still other references to fs_changes")
.into_inner(),
}) })
} }
pub struct ScriptOutput { pub struct ScriptOutput {
printed_lines: Vec<String>, printed_lines: Vec<String>,
#[allow(dead_code)]
fs_changes: HashMap<PathBuf, Vec<u8>>,
} }
#[allow(dead_code)] #[allow(dead_code)]
impl ScriptOutput { impl ScriptOutput {
fn fs_diff(&self) -> HashMap<PathBuf, String> { fn fs_diff(&self, fs_changes: &HashMap<PathBuf, Vec<u8>>) -> HashMap<PathBuf, String> {
let mut diff_map = HashMap::new(); let mut diff_map = HashMap::default();
for (path, content) in &self.fs_changes { for (path, content) in fs_changes {
let diff = if path.exists() { let diff = if path.exists() {
// Read the current file content // Read the current file content
match std::fs::read(path) { match std::fs::read(path) {
@@ -758,9 +869,9 @@ impl ScriptOutput {
diff_map diff_map
} }
fn diff_to_string(&self) -> String { fn diff_to_string(&self, fs_changes: &HashMap<PathBuf, Vec<u8>>) -> String {
let mut answer = String::new(); let mut answer = String::new();
let diff_map = self.fs_diff(); let diff_map = self.fs_diff(fs_changes);
if diff_map.is_empty() { if diff_map.is_empty() {
return "No changes to files".to_string(); return "No changes to files".to_string();