Compare commits

...

4 Commits

Author SHA1 Message Date
Richard Feldman
d6a8b4cfb1 Give Tools control over how they render input/output 2025-03-05 23:32:51 -05:00
Richard Feldman
264ac61210 Initial pass at Lua syntax highlighting
Co-Authored-By: Danilo <danilo@zed.dev>
2025-03-05 20:35:51 -05:00
Michael Sloan
ad4742a5b8 Use a global for scripting tool in-memory fs
Suggested in https://github.com/zed-industries/zed/pull/26132#discussion_r1981670385
2025-03-05 17:30:07 -07:00
Richard Feldman
b0d4abb82e Persist in-memory filesystem between tool uses 2025-03-05 10:30:20 -05:00
9 changed files with 267 additions and 62 deletions

9
Cargo.lock generated
View File

@@ -644,9 +644,11 @@ dependencies = [
"collections",
"derive_more",
"gpui",
"language",
"parking_lot",
"serde",
"serde_json",
"ui",
"workspace",
]
@@ -11887,12 +11889,19 @@ version = "0.1.0"
dependencies = [
"anyhow",
"assistant_tool",
"collections",
"gpui",
"language",
"mlua",
"parking_lot",
"regex",
"rich_text",
"schemars",
"serde",
"serde_json",
"settings",
"theme",
"ui",
"workspace",
]

View File

@@ -1,6 +1,4 @@
use std::sync::Arc;
use assistant_tool::ToolWorkingSet;
use assistant_tool::{ToolRegistry, ToolWorkingSet};
use collections::HashMap;
use editor::{Editor, MultiBuffer};
use gpui::{
@@ -8,10 +6,11 @@ use gpui::{
Entity, Focusable, Length, ListAlignment, ListOffset, ListState, StyleRefinement, Subscription,
Task, TextStyleRefinement, UnderlineStyle, WeakEntity,
};
use language::{Buffer, LanguageRegistry};
use language::{Buffer, Language, LanguageRegistry};
use language_model::{LanguageModelRegistry, LanguageModelToolUseId, Role};
use markdown::{Markdown, MarkdownStyle};
use settings::Settings as _;
use std::sync::Arc;
use theme::ThemeSettings;
use ui::{prelude::*, Disclosure, KeyBinding};
use util::ResultExt as _;
@@ -35,6 +34,7 @@ pub struct ActiveThread {
editing_message: Option<(MessageId, EditMessageState)>,
expanded_tool_uses: HashMap<LanguageModelToolUseId, bool>,
last_error: Option<ThreadError>,
lua_language: Option<Arc<Language>>, // Used for syntax highlighting in the Lua script tool
_subscriptions: Vec<Subscription>,
}
@@ -76,9 +76,23 @@ impl ActiveThread {
}),
editing_message: None,
last_error: None,
lua_language: None,
_subscriptions: subscriptions,
};
// Initialize the Lua language in the background, for syntax highlighting.
let language_registry = this.language_registry.clone();
cx.spawn(|this, mut cx| async move {
if let Ok(lua_language) = language_registry.language_for_name("Lua").await {
this.update(&mut cx, |this, _| {
this.lua_language = Some(lua_language);
})?;
}
Ok::<_, anyhow::Error>(())
})
.detach();
for message in thread.read(cx).messages().cloned().collect::<Vec<_>>() {
this.push_message(&message.id, message.text.clone(), window, cx);
}
@@ -294,9 +308,9 @@ impl ActiveThread {
cx.notify();
}
ThreadEvent::UsePendingTools => {
let pending_tool_uses = self
.thread
.read(cx)
let thread = self.thread.read(cx);
let thread_id = thread.id().0.clone();
let pending_tool_uses = thread
.pending_tool_uses()
.into_iter()
.filter(|tool_use| tool_use.status.is_idle())
@@ -305,7 +319,13 @@ impl ActiveThread {
for tool_use in pending_tool_uses {
if let Some(tool) = self.tools.tool(&tool_use.name, cx) {
let task = tool.run(tool_use.input, self.workspace.clone(), window, cx);
let task = tool.run(
tool_use.input,
thread_id.clone(),
self.workspace.clone(),
window,
cx,
);
self.thread.update(cx, |thread, cx| {
thread.insert_tool_output(tool_use.id.clone(), task, cx);
@@ -654,6 +674,7 @@ impl ActiveThread {
}
fn render_tool_use(&self, tool_use: ToolUse, cx: &mut Context<Self>) -> impl IntoElement {
let tool = ToolRegistry::global(cx).tool(&tool_use.name);
let is_open = self
.expanded_tool_uses
.get(&tool_use.id)
@@ -719,11 +740,17 @@ impl ActiveThread {
.px_2p5()
.border_b_1()
.border_color(cx.theme().colors().border)
.child(Label::new("Input:"))
.child(Label::new(
serde_json::to_string_pretty(&tool_use.input)
.unwrap_or_default(),
)),
.bg(cx.theme().colors().editor_background)
.child(match tool.clone() {
Some(tool) => tool.render_input(
tool_use.input,
self.lua_language.clone(),
cx,
),
None => {
assistant_tool::default_render_input(tool_use.input)
}
}),
)
.map(|parent| match tool_use.status {
ToolUseStatus::Finished(output) => parent.child(
@@ -731,16 +758,17 @@ impl ActiveThread {
.gap_0p5()
.py_1()
.px_2p5()
.child(Label::new("Result:"))
.child(Label::new(output)),
.bg(cx.theme().colors().editor_background)
.child(match tool {
Some(tool) => tool.render_output(output, cx),
None => assistant_tool::default_render_output(output),
}),
),
ToolUseStatus::Error(err) => parent.child(
v_flex()
.gap_0p5()
.py_1()
.px_2p5()
.child(Label::new("Error:"))
.child(Label::new(err)),
v_flex().gap_0p5().py_1().px_2p5().child(match tool {
Some(tool) => tool.render_error(err, cx),
None => assistant_tool::default_render_output(err),
}),
),
ToolUseStatus::Pending | ToolUseStatus::Running => parent,
}),

View File

@@ -28,7 +28,7 @@ pub enum RequestKind {
}
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize)]
pub struct ThreadId(Arc<str>);
pub struct ThreadId(pub Arc<str>);
impl ThreadId {
pub fn new() -> Self {

View File

@@ -16,7 +16,9 @@ anyhow.workspace = true
collections.workspace = true
derive_more.workspace = true
gpui.workspace = true
language.workspace = true
parking_lot.workspace = true
serde.workspace = true
serde_json.workspace = true
workspace.workspace = true
ui.workspace = true

View File

@@ -4,7 +4,16 @@ mod tool_working_set;
use std::sync::Arc;
use anyhow::Result;
use gpui::AnyElement;
use gpui::IntoElement;
use gpui::{App, Task, WeakEntity, Window};
use language::Language;
use ui::div;
use ui::Label;
use ui::LabelCommon;
use ui::LabelSize;
use ui::ParentElement;
use ui::SharedString;
use workspace::Workspace;
pub use crate::tool_registry::*;
@@ -31,8 +40,52 @@ pub trait Tool: 'static + Send + Sync {
fn run(
self: Arc<Self>,
input: serde_json::Value,
thread_id: Arc<str>,
workspace: WeakEntity<Workspace>,
window: &mut Window,
cx: &mut App,
) -> Task<Result<String>>;
/// Renders the tool's input when the user expands it.
fn render_input(
self: Arc<Self>,
input: serde_json::Value,
_lua_language: Option<Arc<Language>>,
_cx: &mut App,
) -> AnyElement {
default_render_input(input)
}
/// Renders the tool's output when the user expands it.
fn render_output(self: Arc<Self>, output: SharedString, _cx: &mut App) -> AnyElement {
default_render_output(output)
}
/// Renders the tool's error message when the user expands it.
fn render_error(self: Arc<Self>, err: SharedString, _cx: &mut App) -> AnyElement {
default_render_error(err)
}
}
pub fn default_render_input(input: serde_json::Value) -> AnyElement {
div()
.child(Label::new("Input:").size(LabelSize::Small))
.child(Label::new(
serde_json::to_string_pretty(&input).unwrap_or_default(),
))
.into_any_element()
}
pub fn default_render_output(output: SharedString) -> AnyElement {
div()
.child(Label::new("Result:").size(LabelSize::Small))
.child(Label::new(output))
.into_any_element()
}
pub fn default_render_error(err: SharedString) -> AnyElement {
div()
.child(Label::new("Error:").size(LabelSize::Small))
.child(Label::new(err))
.into_any_element()
}

View File

@@ -1,11 +1,10 @@
use std::sync::Arc;
use anyhow::{anyhow, Result};
use assistant_tool::Tool;
use chrono::{Local, Utc};
use gpui::{App, Task, WeakEntity, Window};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use std::sync::Arc;
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "snake_case")]
@@ -41,6 +40,7 @@ impl Tool for NowTool {
fn run(
self: Arc<Self>,
input: serde_json::Value,
_thread_id: Arc<str>,
_workspace: WeakEntity<workspace::Workspace>,
_window: &mut Window,
_cx: &mut App,

View File

@@ -51,6 +51,7 @@ impl Tool for ContextServerTool {
fn run(
self: std::sync::Arc<Self>,
input: serde_json::Value,
_thread_id: Arc<str>,
_workspace: gpui::WeakEntity<workspace::Workspace>,
_: &mut Window,
cx: &mut App,

View File

@@ -15,10 +15,17 @@ doctest = false
[dependencies]
anyhow.workspace = true
assistant_tool.workspace = true
collections.workspace = true
gpui.workspace = true
language.workspace = true
mlua.workspace = true
parking_lot.workspace = true
schemars.workspace = true
serde.workspace = true
serde_json.workspace = true
workspace.workspace = true
regex.workspace = true
rich_text.workspace = true
settings.workspace = true
theme.workspace = true
ui.workspace = true

View File

@@ -1,16 +1,25 @@
use anyhow::anyhow;
use assistant_tool::{Tool, ToolRegistry};
use gpui::{App, AppContext as _, Task, WeakEntity, Window};
use collections::HashMap;
use gpui::{
AnyElement, App, AppContext as _, Global, HighlightStyle, StyledText, Task, WeakEntity, Window,
};
use language::Language;
use mlua::{Function, Lua, MultiValue, Result, UserData, UserDataMethods};
use parking_lot::Mutex;
use rich_text::{self, Highlight};
use schemars::JsonSchema;
use serde::Deserialize;
use settings::Settings;
use std::ops::Range;
use std::{
cell::RefCell,
collections::HashMap,
path::{Path, PathBuf},
rc::Rc,
sync::Arc,
};
use theme::ThemeSettings;
use ui::prelude::*;
use workspace::Workspace;
pub fn init(cx: &App) {
@@ -23,11 +32,11 @@ struct ScriptingToolInput {
lua_script: String,
}
struct ScriptingTool;
pub struct ScriptingTool;
impl Tool for ScriptingTool {
fn name(&self) -> String {
"lua-interpreter".into()
"lua-interpreter".to_string()
}
fn description(&self) -> String {
@@ -55,6 +64,7 @@ string being a match that was found within the file)."#.into()
fn run(
self: Arc<Self>,
input: serde_json::Value,
thread_id: Arc<str>,
workspace: WeakEntity<Workspace>,
_window: &mut Window,
cx: &mut App,
@@ -81,15 +91,110 @@ string being a match that was found within the file)."#.into()
Ok(input) => input,
};
let lua_script = input.lua_script;
let fs_changes = ScriptingToolFileChanges::get(thread_id, cx);
cx.background_spawn(async move {
let fs_changes = HashMap::new();
let output = run_sandboxed_lua(&lua_script, fs_changes, root_dir)
.map_err(|err| anyhow!(format!("{err}")))?;
let output = output.printed_lines.join("\n");
.map_err(|err| anyhow!(format!("{err}")))?
.printed_lines
.join("\n");
Ok(format!("The script output the following:\n{output}"))
if output.is_empty() {
Ok("(The script had no output.)".to_string())
} else {
Ok(output.to_string())
}
})
}
fn render_input(
self: Arc<Self>,
input: serde_json::Value,
lua_language: Option<Arc<Language>>,
cx: &mut App,
) -> AnyElement {
let theme_settings = ThemeSettings::get_global(cx);
let theme = cx.theme();
if let Ok(input) = serde_json::from_value::<ScriptingToolInput>(input.clone()) {
let label;
// Use Lua syntax highlighting, if available
if let Some(lua_language) = &lua_language.clone() {
let mut highlights = Vec::new();
let mut buf = String::new();
rich_text::render_code(&mut buf, &mut highlights, &input.lua_script, lua_language);
let gpui_highlights: Vec<(Range<usize>, HighlightStyle)> = highlights
.iter()
.map(|(range, highlight)| {
let style = match highlight {
Highlight::Code => Default::default(),
Highlight::Id(id) => id.style(theme.syntax()).unwrap_or_default(),
Highlight::InlineCode(_link) => Default::default(), // Links won't come up
Highlight::Highlight(highlight) => *highlight,
_ => HighlightStyle::default(),
};
(range.clone(), style)
})
.collect();
label = StyledText::new(buf)
.with_highlights(gpui_highlights)
.into_any_element();
} else {
label = Label::new(&input.lua_script).into_any_element();
};
div()
.when(lua_language.is_some(), |this| {
this.bg(theme.colors().editor_background)
})
.font_family(theme_settings.buffer_font.family.clone())
.child(label)
.into_any_element()
} else {
// Fallback to JSON if lua_language is unavailable
Label::new(serde_json::to_string_pretty(&input).unwrap_or_default()).into_any_element()
}
}
fn render_output(self: Arc<Self>, output: SharedString, cx: &mut App) -> AnyElement {
let theme_settings = ThemeSettings::get_global(cx);
div()
.font_family(theme_settings.buffer_font.family.clone())
.child(output)
.into_any_element()
}
}
// Accumulates file changes made during script execution.
struct ScriptingToolFileChanges {
// Assistant thread ID that these files changes are associated with. Only file changes for one
// thread are supported to avoid the need for dropping these when the associated `Thread` is
// dropped.
thread_id: Arc<str>,
// Map from path to file contents for files changed by script execution.
file_changes: Arc<Mutex<HashMap<PathBuf, Vec<u8>>>>,
}
impl Global for ScriptingToolFileChanges {}
impl ScriptingToolFileChanges {
fn get(thread_id: Arc<str>, cx: &mut App) -> Arc<Mutex<HashMap<PathBuf, Vec<u8>>>> {
match cx.try_global::<ScriptingToolFileChanges>() {
Some(global) if global.thread_id == thread_id => global.file_changes.clone(),
_ => {
let file_changes = Arc::new(Mutex::new(HashMap::default()));
cx.set_global(ScriptingToolFileChanges {
thread_id,
file_changes: file_changes.clone(),
});
file_changes
}
}
}
}
const SANDBOX_PREAMBLE: &str = include_str!("sandbox_preamble.lua");
@@ -125,7 +230,7 @@ fn print(lua: &Lua, printed_lines: Rc<RefCell<Vec<String>>>) -> Result<Function>
fn search(
lua: &Lua,
_fs_changes: Rc<RefCell<HashMap<PathBuf, Vec<u8>>>>,
_fs_changes: Arc<Mutex<HashMap<PathBuf, Vec<u8>>>>,
root_dir: PathBuf,
) -> Result<Function> {
lua.create_function(move |lua, regex: String| {
@@ -221,7 +326,7 @@ fn search(
/// Sandboxed io.open() function in Lua.
fn io_open(
lua: &Lua,
fs_changes: Rc<RefCell<HashMap<PathBuf, Vec<u8>>>>,
fs_changes: Arc<Mutex<HashMap<PathBuf, Vec<u8>>>>,
root_dir: PathBuf,
) -> Result<Function> {
lua.create_function(move |lua, (path_str, mode): (String, Option<String>)| {
@@ -281,7 +386,7 @@ fn io_open(
// Don't actually write to disk; instead, just update fs_changes.
let path_buf = PathBuf::from(&path);
fs_changes
.borrow_mut()
.lock()
.insert(path_buf.clone(), content_vec.clone());
}
@@ -333,21 +438,28 @@ fn io_open(
return Ok((Some(file), String::new()));
}
let is_in_changes = fs_changes.borrow().contains_key(&path);
let file_exists = is_in_changes || path.exists();
let mut file_content = Vec::new();
let is_in_changes;
let file_exists;
let file_content = {
let fs_changes = fs_changes.lock();
is_in_changes = fs_changes.contains_key(&path);
file_exists = is_in_changes || path.exists();
if file_exists && !truncate {
if is_in_changes {
file_content = fs_changes.borrow().get(&path).unwrap().clone();
} else {
// Try to read existing content if file exists and we're not truncating
match std::fs::read(&path) {
Ok(content) => file_content = content,
Err(e) => return Ok((None, format!("Error reading file: {}", e))),
if file_exists && !truncate {
if is_in_changes {
fs_changes.get(&path).unwrap().clone()
} else {
drop(fs_changes); // Release the lock before starting the read.
match std::fs::read(&path) {
Ok(content) => content,
Err(e) => return Ok((None, format!("Error reading file: {}", e))),
}
}
} else {
Vec::new()
}
}
};
// If in append mode, position should be at the end
let position = if append && file_exists {
@@ -355,6 +467,7 @@ fn io_open(
} else {
0
};
file.set("__position", position)?;
file.set(
"__content",
@@ -582,9 +695,7 @@ fn io_open(
// Update fs_changes
let path = file_userdata.get::<String>("__path")?;
let path_buf = PathBuf::from(path);
fs_changes
.borrow_mut()
.insert(path_buf, content_vec.clone());
fs_changes.lock().insert(path_buf, content_vec.clone());
Ok(true)
})?
@@ -599,7 +710,7 @@ fn io_open(
/// Runs a Lua script in a sandboxed environment and returns the printed lines
pub fn run_sandboxed_lua(
script: &str,
fs_changes: HashMap<PathBuf, Vec<u8>>,
fs: Arc<Mutex<HashMap<PathBuf, Vec<u8>>>>,
root_dir: PathBuf,
) -> Result<ScriptOutput> {
let lua = Lua::new();
@@ -608,7 +719,6 @@ pub fn run_sandboxed_lua(
// Track the lines the Lua script prints out.
let printed_lines = Rc::new(RefCell::new(Vec::new()));
let fs = Rc::new(RefCell::new(fs_changes));
globals.set("sb_print", print(&lua, printed_lines.clone())?)?;
globals.set("search", search(&lua, fs.clone(), root_dir.clone())?)?;
@@ -623,23 +733,18 @@ pub fn run_sandboxed_lua(
printed_lines: Rc::try_unwrap(printed_lines)
.expect("There are still other references to printed_lines")
.into_inner(),
fs_changes: Rc::try_unwrap(fs)
.expect("There are still other references to fs_changes")
.into_inner(),
})
}
pub struct ScriptOutput {
printed_lines: Vec<String>,
#[allow(dead_code)]
fs_changes: HashMap<PathBuf, Vec<u8>>,
}
#[allow(dead_code)]
impl ScriptOutput {
fn fs_diff(&self) -> HashMap<PathBuf, String> {
let mut diff_map = HashMap::new();
for (path, content) in &self.fs_changes {
fn fs_diff(&self, fs_changes: &HashMap<PathBuf, Vec<u8>>) -> HashMap<PathBuf, String> {
let mut diff_map = HashMap::default();
for (path, content) in fs_changes {
let diff = if path.exists() {
// Read the current file content
match std::fs::read(path) {
@@ -758,9 +863,9 @@ impl ScriptOutput {
diff_map
}
fn diff_to_string(&self) -> String {
fn diff_to_string(&self, fs_changes: &HashMap<PathBuf, Vec<u8>>) -> String {
let mut answer = String::new();
let diff_map = self.fs_diff();
let diff_map = self.fs_diff(fs_changes);
if diff_map.is_empty() {
return "No changes to files".to_string();