Compare commits
5 Commits
tool-rende
...
accept-lua
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
1a15c63b25 | ||
|
|
d6a8b4cfb1 | ||
|
|
264ac61210 | ||
|
|
ad4742a5b8 | ||
|
|
b0d4abb82e |
10
Cargo.lock
generated
10
Cargo.lock
generated
@@ -448,6 +448,7 @@ dependencies = [
|
|||||||
"command_palette_hooks",
|
"command_palette_hooks",
|
||||||
"context_server",
|
"context_server",
|
||||||
"db",
|
"db",
|
||||||
|
"diff",
|
||||||
"editor",
|
"editor",
|
||||||
"feature_flags",
|
"feature_flags",
|
||||||
"file_icons",
|
"file_icons",
|
||||||
@@ -644,9 +645,11 @@ dependencies = [
|
|||||||
"collections",
|
"collections",
|
||||||
"derive_more",
|
"derive_more",
|
||||||
"gpui",
|
"gpui",
|
||||||
|
"language",
|
||||||
"parking_lot",
|
"parking_lot",
|
||||||
"serde",
|
"serde",
|
||||||
"serde_json",
|
"serde_json",
|
||||||
|
"ui",
|
||||||
"workspace",
|
"workspace",
|
||||||
]
|
]
|
||||||
|
|
||||||
@@ -11887,12 +11890,19 @@ version = "0.1.0"
|
|||||||
dependencies = [
|
dependencies = [
|
||||||
"anyhow",
|
"anyhow",
|
||||||
"assistant_tool",
|
"assistant_tool",
|
||||||
|
"collections",
|
||||||
"gpui",
|
"gpui",
|
||||||
|
"language",
|
||||||
"mlua",
|
"mlua",
|
||||||
|
"parking_lot",
|
||||||
"regex",
|
"regex",
|
||||||
|
"rich_text",
|
||||||
"schemars",
|
"schemars",
|
||||||
"serde",
|
"serde",
|
||||||
"serde_json",
|
"serde_json",
|
||||||
|
"settings",
|
||||||
|
"theme",
|
||||||
|
"ui",
|
||||||
"workspace",
|
"workspace",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|||||||
@@ -415,6 +415,7 @@ core-foundation-sys = "0.8.6"
|
|||||||
ctor = "0.4.0"
|
ctor = "0.4.0"
|
||||||
dashmap = "6.0"
|
dashmap = "6.0"
|
||||||
derive_more = "0.99.17"
|
derive_more = "0.99.17"
|
||||||
|
diff = "0.1.13"
|
||||||
dirs = "4.0"
|
dirs = "4.0"
|
||||||
ec4rs = "1.1"
|
ec4rs = "1.1"
|
||||||
emojis = "0.6.1"
|
emojis = "0.6.1"
|
||||||
|
|||||||
@@ -32,6 +32,7 @@ collections.workspace = true
|
|||||||
command_palette_hooks.workspace = true
|
command_palette_hooks.workspace = true
|
||||||
context_server.workspace = true
|
context_server.workspace = true
|
||||||
db.workspace = true
|
db.workspace = true
|
||||||
|
diff.workspace = true
|
||||||
editor.workspace = true
|
editor.workspace = true
|
||||||
feature_flags.workspace = true
|
feature_flags.workspace = true
|
||||||
file_icons.workspace = true
|
file_icons.workspace = true
|
||||||
|
|||||||
@@ -1,6 +1,4 @@
|
|||||||
use std::sync::Arc;
|
use assistant_tool::{ToolFileChanges, ToolRegistry, ToolWorkingSet};
|
||||||
|
|
||||||
use assistant_tool::ToolWorkingSet;
|
|
||||||
use collections::HashMap;
|
use collections::HashMap;
|
||||||
use editor::{Editor, MultiBuffer};
|
use editor::{Editor, MultiBuffer};
|
||||||
use gpui::{
|
use gpui::{
|
||||||
@@ -8,10 +6,11 @@ use gpui::{
|
|||||||
Entity, Focusable, Length, ListAlignment, ListOffset, ListState, StyleRefinement, Subscription,
|
Entity, Focusable, Length, ListAlignment, ListOffset, ListState, StyleRefinement, Subscription,
|
||||||
Task, TextStyleRefinement, UnderlineStyle, WeakEntity,
|
Task, TextStyleRefinement, UnderlineStyle, WeakEntity,
|
||||||
};
|
};
|
||||||
use language::{Buffer, LanguageRegistry};
|
use language::{Buffer, Language, LanguageRegistry};
|
||||||
use language_model::{LanguageModelRegistry, LanguageModelToolUseId, Role};
|
use language_model::{LanguageModelRegistry, LanguageModelToolUseId, Role};
|
||||||
use markdown::{Markdown, MarkdownStyle};
|
use markdown::{Markdown, MarkdownStyle};
|
||||||
use settings::Settings as _;
|
use settings::Settings as _;
|
||||||
|
use std::sync::Arc;
|
||||||
use theme::ThemeSettings;
|
use theme::ThemeSettings;
|
||||||
use ui::{prelude::*, Disclosure, KeyBinding};
|
use ui::{prelude::*, Disclosure, KeyBinding};
|
||||||
use util::ResultExt as _;
|
use util::ResultExt as _;
|
||||||
@@ -35,6 +34,7 @@ pub struct ActiveThread {
|
|||||||
editing_message: Option<(MessageId, EditMessageState)>,
|
editing_message: Option<(MessageId, EditMessageState)>,
|
||||||
expanded_tool_uses: HashMap<LanguageModelToolUseId, bool>,
|
expanded_tool_uses: HashMap<LanguageModelToolUseId, bool>,
|
||||||
last_error: Option<ThreadError>,
|
last_error: Option<ThreadError>,
|
||||||
|
lua_language: Option<Arc<Language>>, // Used for syntax highlighting in the Lua script tool
|
||||||
_subscriptions: Vec<Subscription>,
|
_subscriptions: Vec<Subscription>,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -76,9 +76,23 @@ impl ActiveThread {
|
|||||||
}),
|
}),
|
||||||
editing_message: None,
|
editing_message: None,
|
||||||
last_error: None,
|
last_error: None,
|
||||||
|
lua_language: None,
|
||||||
_subscriptions: subscriptions,
|
_subscriptions: subscriptions,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// Initialize the Lua language in the background, for syntax highlighting.
|
||||||
|
let language_registry = this.language_registry.clone();
|
||||||
|
cx.spawn(|this, mut cx| async move {
|
||||||
|
if let Ok(lua_language) = language_registry.language_for_name("Lua").await {
|
||||||
|
this.update(&mut cx, |this, _| {
|
||||||
|
this.lua_language = Some(lua_language);
|
||||||
|
})?;
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok::<_, anyhow::Error>(())
|
||||||
|
})
|
||||||
|
.detach();
|
||||||
|
|
||||||
for message in thread.read(cx).messages().cloned().collect::<Vec<_>>() {
|
for message in thread.read(cx).messages().cloned().collect::<Vec<_>>() {
|
||||||
this.push_message(&message.id, message.text.clone(), window, cx);
|
this.push_message(&message.id, message.text.clone(), window, cx);
|
||||||
}
|
}
|
||||||
@@ -294,9 +308,9 @@ impl ActiveThread {
|
|||||||
cx.notify();
|
cx.notify();
|
||||||
}
|
}
|
||||||
ThreadEvent::UsePendingTools => {
|
ThreadEvent::UsePendingTools => {
|
||||||
let pending_tool_uses = self
|
let thread = self.thread.read(cx);
|
||||||
.thread
|
let thread_id = thread.id().0.clone();
|
||||||
.read(cx)
|
let pending_tool_uses = thread
|
||||||
.pending_tool_uses()
|
.pending_tool_uses()
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.filter(|tool_use| tool_use.status.is_idle())
|
.filter(|tool_use| tool_use.status.is_idle())
|
||||||
@@ -305,7 +319,13 @@ impl ActiveThread {
|
|||||||
|
|
||||||
for tool_use in pending_tool_uses {
|
for tool_use in pending_tool_uses {
|
||||||
if let Some(tool) = self.tools.tool(&tool_use.name, cx) {
|
if let Some(tool) = self.tools.tool(&tool_use.name, cx) {
|
||||||
let task = tool.run(tool_use.input, self.workspace.clone(), window, cx);
|
let task = tool.run(
|
||||||
|
tool_use.input,
|
||||||
|
thread_id.clone(),
|
||||||
|
self.workspace.clone(),
|
||||||
|
window,
|
||||||
|
cx,
|
||||||
|
);
|
||||||
|
|
||||||
self.thread.update(cx, |thread, cx| {
|
self.thread.update(cx, |thread, cx| {
|
||||||
thread.insert_tool_output(tool_use.id.clone(), task, cx);
|
thread.insert_tool_output(tool_use.id.clone(), task, cx);
|
||||||
@@ -324,6 +344,16 @@ impl ActiveThread {
|
|||||||
let model_registry = LanguageModelRegistry::read_global(cx);
|
let model_registry = LanguageModelRegistry::read_global(cx);
|
||||||
if let Some(model) = model_registry.active_model() {
|
if let Some(model) = model_registry.active_model() {
|
||||||
self.thread.update(cx, |thread, cx| {
|
self.thread.update(cx, |thread, cx| {
|
||||||
|
if let Some(global) = cx.try_global::<ToolFileChanges>() {
|
||||||
|
let thread_id = thread.id().0.clone();
|
||||||
|
|
||||||
|
if global.thread_id == thread_id
|
||||||
|
&& !global.file_changes.lock().is_empty()
|
||||||
|
{
|
||||||
|
println!("Changes:\n{}", self.handle_fs_changes(cx));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Insert a user message to contain the tool results.
|
// Insert a user message to contain the tool results.
|
||||||
thread.insert_user_message(
|
thread.insert_user_message(
|
||||||
// TODO: Sending up a user message without any content results in the model sending back
|
// TODO: Sending up a user message without any content results in the model sending back
|
||||||
@@ -360,6 +390,46 @@ impl ActiveThread {
|
|||||||
}));
|
}));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn handle_fs_changes(&self, cx: &mut Context<Thread>) -> String {
|
||||||
|
if let Some(global) = cx.try_global::<ToolFileChanges>() {
|
||||||
|
let fs_changes = global.file_changes.lock().clone();
|
||||||
|
if !fs_changes.is_empty() {
|
||||||
|
let mut diff_output = String::new();
|
||||||
|
|
||||||
|
for (path, content) in fs_changes {
|
||||||
|
let path_str = path.to_string_lossy();
|
||||||
|
diff_output.push_str(&format!("--- {}\n+++ {}\n", path_str, path_str));
|
||||||
|
|
||||||
|
let old_content = match std::fs::read(&path) {
|
||||||
|
Ok(content) => String::from_utf8_lossy(&content).to_string(),
|
||||||
|
Err(_) => String::new(),
|
||||||
|
};
|
||||||
|
|
||||||
|
let new_content = String::from_utf8_lossy(&content).to_string();
|
||||||
|
|
||||||
|
let diff = diff::lines(&old_content, &new_content);
|
||||||
|
|
||||||
|
for change in diff {
|
||||||
|
match change {
|
||||||
|
diff::Result::Left(l) => diff_output.push_str(&format!("-{}\n", l)),
|
||||||
|
diff::Result::Right(r) => diff_output.push_str(&format!("+{}\n", r)),
|
||||||
|
diff::Result::Both(b, _) => diff_output.push_str(&format!(" {}\n", b)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
diff_output.push_str("\n");
|
||||||
|
}
|
||||||
|
|
||||||
|
// Reset fs_changes
|
||||||
|
global.file_changes.lock().clear();
|
||||||
|
|
||||||
|
return diff_output;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
String::new()
|
||||||
|
}
|
||||||
|
|
||||||
fn start_editing_message(
|
fn start_editing_message(
|
||||||
&mut self,
|
&mut self,
|
||||||
message_id: MessageId,
|
message_id: MessageId,
|
||||||
@@ -654,6 +724,7 @@ impl ActiveThread {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn render_tool_use(&self, tool_use: ToolUse, cx: &mut Context<Self>) -> impl IntoElement {
|
fn render_tool_use(&self, tool_use: ToolUse, cx: &mut Context<Self>) -> impl IntoElement {
|
||||||
|
let tool = ToolRegistry::global(cx).tool(&tool_use.name);
|
||||||
let is_open = self
|
let is_open = self
|
||||||
.expanded_tool_uses
|
.expanded_tool_uses
|
||||||
.get(&tool_use.id)
|
.get(&tool_use.id)
|
||||||
@@ -719,11 +790,17 @@ impl ActiveThread {
|
|||||||
.px_2p5()
|
.px_2p5()
|
||||||
.border_b_1()
|
.border_b_1()
|
||||||
.border_color(cx.theme().colors().border)
|
.border_color(cx.theme().colors().border)
|
||||||
.child(Label::new("Input:"))
|
.bg(cx.theme().colors().editor_background)
|
||||||
.child(Label::new(
|
.child(match tool.clone() {
|
||||||
serde_json::to_string_pretty(&tool_use.input)
|
Some(tool) => tool.render_input(
|
||||||
.unwrap_or_default(),
|
tool_use.input,
|
||||||
)),
|
self.lua_language.clone(),
|
||||||
|
cx,
|
||||||
|
),
|
||||||
|
None => {
|
||||||
|
assistant_tool::default_render_input(tool_use.input)
|
||||||
|
}
|
||||||
|
}),
|
||||||
)
|
)
|
||||||
.map(|parent| match tool_use.status {
|
.map(|parent| match tool_use.status {
|
||||||
ToolUseStatus::Finished(output) => parent.child(
|
ToolUseStatus::Finished(output) => parent.child(
|
||||||
@@ -731,16 +808,17 @@ impl ActiveThread {
|
|||||||
.gap_0p5()
|
.gap_0p5()
|
||||||
.py_1()
|
.py_1()
|
||||||
.px_2p5()
|
.px_2p5()
|
||||||
.child(Label::new("Result:"))
|
.bg(cx.theme().colors().editor_background)
|
||||||
.child(Label::new(output)),
|
.child(match tool {
|
||||||
|
Some(tool) => tool.render_output(output, cx),
|
||||||
|
None => assistant_tool::default_render_output(output),
|
||||||
|
}),
|
||||||
),
|
),
|
||||||
ToolUseStatus::Error(err) => parent.child(
|
ToolUseStatus::Error(err) => parent.child(
|
||||||
v_flex()
|
v_flex().gap_0p5().py_1().px_2p5().child(match tool {
|
||||||
.gap_0p5()
|
Some(tool) => tool.render_error(err, cx),
|
||||||
.py_1()
|
None => assistant_tool::default_render_output(err),
|
||||||
.px_2p5()
|
}),
|
||||||
.child(Label::new("Error:"))
|
|
||||||
.child(Label::new(err)),
|
|
||||||
),
|
),
|
||||||
ToolUseStatus::Pending | ToolUseStatus::Running => parent,
|
ToolUseStatus::Pending | ToolUseStatus::Running => parent,
|
||||||
}),
|
}),
|
||||||
|
|||||||
@@ -28,7 +28,7 @@ pub enum RequestKind {
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize)]
|
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize)]
|
||||||
pub struct ThreadId(Arc<str>);
|
pub struct ThreadId(pub Arc<str>);
|
||||||
|
|
||||||
impl ThreadId {
|
impl ThreadId {
|
||||||
pub fn new() -> Self {
|
pub fn new() -> Self {
|
||||||
|
|||||||
@@ -16,7 +16,9 @@ anyhow.workspace = true
|
|||||||
collections.workspace = true
|
collections.workspace = true
|
||||||
derive_more.workspace = true
|
derive_more.workspace = true
|
||||||
gpui.workspace = true
|
gpui.workspace = true
|
||||||
|
language.workspace = true
|
||||||
parking_lot.workspace = true
|
parking_lot.workspace = true
|
||||||
serde.workspace = true
|
serde.workspace = true
|
||||||
serde_json.workspace = true
|
serde_json.workspace = true
|
||||||
workspace.workspace = true
|
workspace.workspace = true
|
||||||
|
ui.workspace = true
|
||||||
|
|||||||
@@ -1,12 +1,23 @@
|
|||||||
|
mod tool_file_changes;
|
||||||
mod tool_registry;
|
mod tool_registry;
|
||||||
mod tool_working_set;
|
mod tool_working_set;
|
||||||
|
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
|
||||||
use anyhow::Result;
|
use anyhow::Result;
|
||||||
|
use gpui::AnyElement;
|
||||||
|
use gpui::IntoElement;
|
||||||
use gpui::{App, Task, WeakEntity, Window};
|
use gpui::{App, Task, WeakEntity, Window};
|
||||||
|
use language::Language;
|
||||||
|
use ui::div;
|
||||||
|
use ui::Label;
|
||||||
|
use ui::LabelCommon;
|
||||||
|
use ui::LabelSize;
|
||||||
|
use ui::ParentElement;
|
||||||
|
use ui::SharedString;
|
||||||
use workspace::Workspace;
|
use workspace::Workspace;
|
||||||
|
|
||||||
|
pub use crate::tool_file_changes::*;
|
||||||
pub use crate::tool_registry::*;
|
pub use crate::tool_registry::*;
|
||||||
pub use crate::tool_working_set::*;
|
pub use crate::tool_working_set::*;
|
||||||
|
|
||||||
@@ -31,8 +42,52 @@ pub trait Tool: 'static + Send + Sync {
|
|||||||
fn run(
|
fn run(
|
||||||
self: Arc<Self>,
|
self: Arc<Self>,
|
||||||
input: serde_json::Value,
|
input: serde_json::Value,
|
||||||
|
thread_id: Arc<str>,
|
||||||
workspace: WeakEntity<Workspace>,
|
workspace: WeakEntity<Workspace>,
|
||||||
window: &mut Window,
|
window: &mut Window,
|
||||||
cx: &mut App,
|
cx: &mut App,
|
||||||
) -> Task<Result<String>>;
|
) -> Task<Result<String>>;
|
||||||
|
|
||||||
|
/// Renders the tool's input when the user expands it.
|
||||||
|
fn render_input(
|
||||||
|
self: Arc<Self>,
|
||||||
|
input: serde_json::Value,
|
||||||
|
_lua_language: Option<Arc<Language>>,
|
||||||
|
_cx: &mut App,
|
||||||
|
) -> AnyElement {
|
||||||
|
default_render_input(input)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Renders the tool's output when the user expands it.
|
||||||
|
fn render_output(self: Arc<Self>, output: SharedString, _cx: &mut App) -> AnyElement {
|
||||||
|
default_render_output(output)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Renders the tool's error message when the user expands it.
|
||||||
|
fn render_error(self: Arc<Self>, err: SharedString, _cx: &mut App) -> AnyElement {
|
||||||
|
default_render_error(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn default_render_input(input: serde_json::Value) -> AnyElement {
|
||||||
|
div()
|
||||||
|
.child(Label::new("Input:").size(LabelSize::Small))
|
||||||
|
.child(Label::new(
|
||||||
|
serde_json::to_string_pretty(&input).unwrap_or_default(),
|
||||||
|
))
|
||||||
|
.into_any_element()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn default_render_output(output: SharedString) -> AnyElement {
|
||||||
|
div()
|
||||||
|
.child(Label::new("Result:").size(LabelSize::Small))
|
||||||
|
.child(Label::new(output))
|
||||||
|
.into_any_element()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn default_render_error(err: SharedString) -> AnyElement {
|
||||||
|
div()
|
||||||
|
.child(Label::new("Error:").size(LabelSize::Small))
|
||||||
|
.child(Label::new(err))
|
||||||
|
.into_any_element()
|
||||||
}
|
}
|
||||||
|
|||||||
34
crates/assistant_tool/src/tool_file_changes.rs
Normal file
34
crates/assistant_tool/src/tool_file_changes.rs
Normal file
@@ -0,0 +1,34 @@
|
|||||||
|
use std::{path::PathBuf, sync::Arc};
|
||||||
|
|
||||||
|
use collections::HashMap;
|
||||||
|
use gpui::Global;
|
||||||
|
use parking_lot::Mutex;
|
||||||
|
use ui::App;
|
||||||
|
|
||||||
|
// Accumulates file changes made during script execution.
|
||||||
|
pub struct ToolFileChanges {
|
||||||
|
// Assistant thread ID that these files changes are associated with. Only file changes for one
|
||||||
|
// thread are supported to avoid the need for dropping these when the associated `Thread` is
|
||||||
|
// dropped.
|
||||||
|
pub thread_id: Arc<str>,
|
||||||
|
// Map from path to file contents for files changed by script execution.
|
||||||
|
pub file_changes: Arc<Mutex<HashMap<PathBuf, Vec<u8>>>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Global for ToolFileChanges {}
|
||||||
|
|
||||||
|
impl ToolFileChanges {
|
||||||
|
pub fn get(thread_id: Arc<str>, cx: &mut App) -> Arc<Mutex<HashMap<PathBuf, Vec<u8>>>> {
|
||||||
|
match cx.try_global::<ToolFileChanges>() {
|
||||||
|
Some(global) if global.thread_id == thread_id => global.file_changes.clone(),
|
||||||
|
_ => {
|
||||||
|
let file_changes = Arc::new(Mutex::new(HashMap::default()));
|
||||||
|
cx.set_global(ToolFileChanges {
|
||||||
|
thread_id,
|
||||||
|
file_changes: file_changes.clone(),
|
||||||
|
});
|
||||||
|
file_changes
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,11 +1,10 @@
|
|||||||
use std::sync::Arc;
|
|
||||||
|
|
||||||
use anyhow::{anyhow, Result};
|
use anyhow::{anyhow, Result};
|
||||||
use assistant_tool::Tool;
|
use assistant_tool::Tool;
|
||||||
use chrono::{Local, Utc};
|
use chrono::{Local, Utc};
|
||||||
use gpui::{App, Task, WeakEntity, Window};
|
use gpui::{App, Task, WeakEntity, Window};
|
||||||
use schemars::JsonSchema;
|
use schemars::JsonSchema;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
|
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
|
||||||
#[serde(rename_all = "snake_case")]
|
#[serde(rename_all = "snake_case")]
|
||||||
@@ -41,6 +40,7 @@ impl Tool for NowTool {
|
|||||||
fn run(
|
fn run(
|
||||||
self: Arc<Self>,
|
self: Arc<Self>,
|
||||||
input: serde_json::Value,
|
input: serde_json::Value,
|
||||||
|
_thread_id: Arc<str>,
|
||||||
_workspace: WeakEntity<workspace::Workspace>,
|
_workspace: WeakEntity<workspace::Workspace>,
|
||||||
_window: &mut Window,
|
_window: &mut Window,
|
||||||
_cx: &mut App,
|
_cx: &mut App,
|
||||||
|
|||||||
@@ -51,6 +51,7 @@ impl Tool for ContextServerTool {
|
|||||||
fn run(
|
fn run(
|
||||||
self: std::sync::Arc<Self>,
|
self: std::sync::Arc<Self>,
|
||||||
input: serde_json::Value,
|
input: serde_json::Value,
|
||||||
|
_thread_id: Arc<str>,
|
||||||
_workspace: gpui::WeakEntity<workspace::Workspace>,
|
_workspace: gpui::WeakEntity<workspace::Workspace>,
|
||||||
_: &mut Window,
|
_: &mut Window,
|
||||||
cx: &mut App,
|
cx: &mut App,
|
||||||
|
|||||||
@@ -15,10 +15,17 @@ doctest = false
|
|||||||
[dependencies]
|
[dependencies]
|
||||||
anyhow.workspace = true
|
anyhow.workspace = true
|
||||||
assistant_tool.workspace = true
|
assistant_tool.workspace = true
|
||||||
|
collections.workspace = true
|
||||||
gpui.workspace = true
|
gpui.workspace = true
|
||||||
|
language.workspace = true
|
||||||
mlua.workspace = true
|
mlua.workspace = true
|
||||||
|
parking_lot.workspace = true
|
||||||
schemars.workspace = true
|
schemars.workspace = true
|
||||||
serde.workspace = true
|
serde.workspace = true
|
||||||
serde_json.workspace = true
|
serde_json.workspace = true
|
||||||
workspace.workspace = true
|
workspace.workspace = true
|
||||||
regex.workspace = true
|
regex.workspace = true
|
||||||
|
rich_text.workspace = true
|
||||||
|
settings.workspace = true
|
||||||
|
theme.workspace = true
|
||||||
|
ui.workspace = true
|
||||||
|
|||||||
@@ -1,16 +1,25 @@
|
|||||||
use anyhow::anyhow;
|
use anyhow::anyhow;
|
||||||
use assistant_tool::{Tool, ToolRegistry};
|
use assistant_tool::{Tool, ToolFileChanges, ToolRegistry};
|
||||||
use gpui::{App, AppContext as _, Task, WeakEntity, Window};
|
use collections::HashMap;
|
||||||
use mlua::{Function, Lua, MultiValue, Result, UserData, UserDataMethods};
|
use gpui::{
|
||||||
|
AnyElement, App, AppContext as _, HighlightStyle, StyledText, Task, WeakEntity, Window,
|
||||||
|
};
|
||||||
|
use language::Language;
|
||||||
|
use mlua::{Function, HookTriggers, Lua, MultiValue, Result, UserData, UserDataMethods, VmState};
|
||||||
|
use parking_lot::Mutex;
|
||||||
|
use rich_text::{self, Highlight};
|
||||||
use schemars::JsonSchema;
|
use schemars::JsonSchema;
|
||||||
use serde::Deserialize;
|
use serde::Deserialize;
|
||||||
|
use settings::Settings;
|
||||||
|
use std::ops::Range;
|
||||||
use std::{
|
use std::{
|
||||||
cell::RefCell,
|
cell::RefCell,
|
||||||
collections::HashMap,
|
|
||||||
path::{Path, PathBuf},
|
path::{Path, PathBuf},
|
||||||
rc::Rc,
|
rc::Rc,
|
||||||
sync::Arc,
|
sync::Arc,
|
||||||
};
|
};
|
||||||
|
use theme::ThemeSettings;
|
||||||
|
use ui::prelude::*;
|
||||||
use workspace::Workspace;
|
use workspace::Workspace;
|
||||||
|
|
||||||
pub fn init(cx: &App) {
|
pub fn init(cx: &App) {
|
||||||
@@ -23,11 +32,11 @@ struct ScriptingToolInput {
|
|||||||
lua_script: String,
|
lua_script: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ScriptingTool;
|
pub struct ScriptingTool;
|
||||||
|
|
||||||
impl Tool for ScriptingTool {
|
impl Tool for ScriptingTool {
|
||||||
fn name(&self) -> String {
|
fn name(&self) -> String {
|
||||||
"lua-interpreter".into()
|
"lua-interpreter".to_string()
|
||||||
}
|
}
|
||||||
|
|
||||||
fn description(&self) -> String {
|
fn description(&self) -> String {
|
||||||
@@ -38,13 +47,18 @@ The lua script will have access to `io` and it will run with the current working
|
|||||||
the root of the code base, so you can use it to explore, search, make changes, etc. You can also have
|
the root of the code base, so you can use it to explore, search, make changes, etc. You can also have
|
||||||
the script print things, and I'll tell you what the output was. Note that `io` only has `open`, and
|
the script print things, and I'll tell you what the output was. Note that `io` only has `open`, and
|
||||||
then the file it returns only has the methods read, write, and close - it doesn't have popen or
|
then the file it returns only has the methods read, write, and close - it doesn't have popen or
|
||||||
anything else. Also, I'm going to be putting this Lua script into JSON, so please don't use Lua's
|
anything else. `os` is not available, so don't try to use it. There will be a global called
|
||||||
double quote syntax for string literals - use one of Lua's other syntaxes for string literals, so I
|
`search` which accepts a regex (it's implemented using Rust's regex crate,
|
||||||
don't have to escape the double quotes. There will be a global called `search` which accepts a regex
|
so use that regex syntax) and runs that regex on the contents of every file in the code base
|
||||||
(it's implemented using Rust's regex crate, so use that regex syntax) and runs that regex on the contents
|
(aside from gitignored files), then returns an array of tables with two fields: "path"
|
||||||
of every file in the code base (aside from gitignored files), then returns an array of tables with two
|
(the path to the file that had the matches) and "matches" (an array of strings, with each
|
||||||
fields: "path" (the path to the file that had the matches) and "matches" (an array of strings, with each
|
string being a match that was found within the file).
|
||||||
string being a match that was found within the file)."#.into()
|
|
||||||
|
If I reference something in the code base that you aren't familiar with, I suggest writing
|
||||||
|
multiple `search` calls to find what you're looking for in one script invocation, as opposed
|
||||||
|
to trying multiple times. When you're done, if you have any proposed changes to files,
|
||||||
|
you should write them directly to disk in the script. I'll review them later.
|
||||||
|
"#.into()
|
||||||
}
|
}
|
||||||
|
|
||||||
fn input_schema(&self) -> serde_json::Value {
|
fn input_schema(&self) -> serde_json::Value {
|
||||||
@@ -55,6 +69,7 @@ string being a match that was found within the file)."#.into()
|
|||||||
fn run(
|
fn run(
|
||||||
self: Arc<Self>,
|
self: Arc<Self>,
|
||||||
input: serde_json::Value,
|
input: serde_json::Value,
|
||||||
|
thread_id: Arc<str>,
|
||||||
workspace: WeakEntity<Workspace>,
|
workspace: WeakEntity<Workspace>,
|
||||||
_window: &mut Window,
|
_window: &mut Window,
|
||||||
cx: &mut App,
|
cx: &mut App,
|
||||||
@@ -81,15 +96,82 @@ string being a match that was found within the file)."#.into()
|
|||||||
Ok(input) => input,
|
Ok(input) => input,
|
||||||
};
|
};
|
||||||
let lua_script = input.lua_script;
|
let lua_script = input.lua_script;
|
||||||
|
let fs_changes = ToolFileChanges::get(thread_id, cx);
|
||||||
cx.background_spawn(async move {
|
cx.background_spawn(async move {
|
||||||
let fs_changes = HashMap::new();
|
|
||||||
let output = run_sandboxed_lua(&lua_script, fs_changes, root_dir)
|
let output = run_sandboxed_lua(&lua_script, fs_changes, root_dir)
|
||||||
.map_err(|err| anyhow!(format!("{err}")))?;
|
.map_err(|err| anyhow!(format!("{err}")))?
|
||||||
let output = output.printed_lines.join("\n");
|
.printed_lines
|
||||||
|
.join("\n");
|
||||||
|
|
||||||
Ok(format!("The script output the following:\n{output}"))
|
if output.is_empty() {
|
||||||
|
Ok("(The script had no output.)".to_string())
|
||||||
|
} else {
|
||||||
|
Ok(output.to_string())
|
||||||
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn render_input(
|
||||||
|
self: Arc<Self>,
|
||||||
|
input: serde_json::Value,
|
||||||
|
lua_language: Option<Arc<Language>>,
|
||||||
|
cx: &mut App,
|
||||||
|
) -> AnyElement {
|
||||||
|
let theme_settings = ThemeSettings::get_global(cx);
|
||||||
|
let theme = cx.theme();
|
||||||
|
|
||||||
|
if let Ok(input) = serde_json::from_value::<ScriptingToolInput>(input.clone()) {
|
||||||
|
let label;
|
||||||
|
|
||||||
|
// Use Lua syntax highlighting, if available
|
||||||
|
if let Some(lua_language) = &lua_language.clone() {
|
||||||
|
let mut highlights = Vec::new();
|
||||||
|
let mut buf = String::new();
|
||||||
|
|
||||||
|
rich_text::render_code(&mut buf, &mut highlights, &input.lua_script, lua_language);
|
||||||
|
|
||||||
|
let gpui_highlights: Vec<(Range<usize>, HighlightStyle)> = highlights
|
||||||
|
.iter()
|
||||||
|
.map(|(range, highlight)| {
|
||||||
|
let style = match highlight {
|
||||||
|
Highlight::Code => Default::default(),
|
||||||
|
Highlight::Id(id) => id.style(theme.syntax()).unwrap_or_default(),
|
||||||
|
Highlight::InlineCode(_link) => Default::default(), // Links won't come up
|
||||||
|
Highlight::Highlight(highlight) => *highlight,
|
||||||
|
|
||||||
|
_ => HighlightStyle::default(),
|
||||||
|
};
|
||||||
|
(range.clone(), style)
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
label = StyledText::new(buf)
|
||||||
|
.with_highlights(gpui_highlights)
|
||||||
|
.into_any_element();
|
||||||
|
} else {
|
||||||
|
label = Label::new(&input.lua_script).into_any_element();
|
||||||
|
};
|
||||||
|
|
||||||
|
div()
|
||||||
|
.when(lua_language.is_some(), |this| {
|
||||||
|
this.bg(theme.colors().editor_background)
|
||||||
|
})
|
||||||
|
.font_family(theme_settings.buffer_font.family.clone())
|
||||||
|
.child(label)
|
||||||
|
.into_any_element()
|
||||||
|
} else {
|
||||||
|
// Fallback to JSON if lua_language is unavailable
|
||||||
|
Label::new(serde_json::to_string_pretty(&input).unwrap_or_default()).into_any_element()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn render_output(self: Arc<Self>, output: SharedString, cx: &mut App) -> AnyElement {
|
||||||
|
let theme_settings = ThemeSettings::get_global(cx);
|
||||||
|
|
||||||
|
div()
|
||||||
|
.font_family(theme_settings.buffer_font.family.clone())
|
||||||
|
.child(output)
|
||||||
|
.into_any_element()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
const SANDBOX_PREAMBLE: &str = include_str!("sandbox_preamble.lua");
|
const SANDBOX_PREAMBLE: &str = include_str!("sandbox_preamble.lua");
|
||||||
@@ -125,7 +207,7 @@ fn print(lua: &Lua, printed_lines: Rc<RefCell<Vec<String>>>) -> Result<Function>
|
|||||||
|
|
||||||
fn search(
|
fn search(
|
||||||
lua: &Lua,
|
lua: &Lua,
|
||||||
_fs_changes: Rc<RefCell<HashMap<PathBuf, Vec<u8>>>>,
|
_fs_changes: Arc<Mutex<HashMap<PathBuf, Vec<u8>>>>,
|
||||||
root_dir: PathBuf,
|
root_dir: PathBuf,
|
||||||
) -> Result<Function> {
|
) -> Result<Function> {
|
||||||
lua.create_function(move |lua, regex: String| {
|
lua.create_function(move |lua, regex: String| {
|
||||||
@@ -221,7 +303,7 @@ fn search(
|
|||||||
/// Sandboxed io.open() function in Lua.
|
/// Sandboxed io.open() function in Lua.
|
||||||
fn io_open(
|
fn io_open(
|
||||||
lua: &Lua,
|
lua: &Lua,
|
||||||
fs_changes: Rc<RefCell<HashMap<PathBuf, Vec<u8>>>>,
|
fs_changes: Arc<Mutex<HashMap<PathBuf, Vec<u8>>>>,
|
||||||
root_dir: PathBuf,
|
root_dir: PathBuf,
|
||||||
) -> Result<Function> {
|
) -> Result<Function> {
|
||||||
lua.create_function(move |lua, (path_str, mode): (String, Option<String>)| {
|
lua.create_function(move |lua, (path_str, mode): (String, Option<String>)| {
|
||||||
@@ -281,7 +363,7 @@ fn io_open(
|
|||||||
// Don't actually write to disk; instead, just update fs_changes.
|
// Don't actually write to disk; instead, just update fs_changes.
|
||||||
let path_buf = PathBuf::from(&path);
|
let path_buf = PathBuf::from(&path);
|
||||||
fs_changes
|
fs_changes
|
||||||
.borrow_mut()
|
.lock()
|
||||||
.insert(path_buf.clone(), content_vec.clone());
|
.insert(path_buf.clone(), content_vec.clone());
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -333,21 +415,28 @@ fn io_open(
|
|||||||
return Ok((Some(file), String::new()));
|
return Ok((Some(file), String::new()));
|
||||||
}
|
}
|
||||||
|
|
||||||
let is_in_changes = fs_changes.borrow().contains_key(&path);
|
let is_in_changes;
|
||||||
let file_exists = is_in_changes || path.exists();
|
let file_exists;
|
||||||
let mut file_content = Vec::new();
|
let file_content = {
|
||||||
|
let fs_changes = fs_changes.lock();
|
||||||
|
is_in_changes = fs_changes.contains_key(&path);
|
||||||
|
file_exists = is_in_changes || path.exists();
|
||||||
|
|
||||||
if file_exists && !truncate {
|
if file_exists && !truncate {
|
||||||
if is_in_changes {
|
if is_in_changes {
|
||||||
file_content = fs_changes.borrow().get(&path).unwrap().clone();
|
fs_changes.get(&path).unwrap().clone()
|
||||||
} else {
|
} else {
|
||||||
// Try to read existing content if file exists and we're not truncating
|
drop(fs_changes); // Release the lock before starting the read.
|
||||||
match std::fs::read(&path) {
|
|
||||||
Ok(content) => file_content = content,
|
match std::fs::read(&path) {
|
||||||
Err(e) => return Ok((None, format!("Error reading file: {}", e))),
|
Ok(content) => content,
|
||||||
|
Err(e) => return Ok((None, format!("Error reading file: {}", e))),
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
} else {
|
||||||
|
Vec::new()
|
||||||
}
|
}
|
||||||
}
|
};
|
||||||
|
|
||||||
// If in append mode, position should be at the end
|
// If in append mode, position should be at the end
|
||||||
let position = if append && file_exists {
|
let position = if append && file_exists {
|
||||||
@@ -355,6 +444,7 @@ fn io_open(
|
|||||||
} else {
|
} else {
|
||||||
0
|
0
|
||||||
};
|
};
|
||||||
|
|
||||||
file.set("__position", position)?;
|
file.set("__position", position)?;
|
||||||
file.set(
|
file.set(
|
||||||
"__content",
|
"__content",
|
||||||
@@ -582,9 +672,7 @@ fn io_open(
|
|||||||
// Update fs_changes
|
// Update fs_changes
|
||||||
let path = file_userdata.get::<String>("__path")?;
|
let path = file_userdata.get::<String>("__path")?;
|
||||||
let path_buf = PathBuf::from(path);
|
let path_buf = PathBuf::from(path);
|
||||||
fs_changes
|
fs_changes.lock().insert(path_buf, content_vec.clone());
|
||||||
.borrow_mut()
|
|
||||||
.insert(path_buf, content_vec.clone());
|
|
||||||
|
|
||||||
Ok(true)
|
Ok(true)
|
||||||
})?
|
})?
|
||||||
@@ -599,16 +687,44 @@ fn io_open(
|
|||||||
/// Runs a Lua script in a sandboxed environment and returns the printed lines
|
/// Runs a Lua script in a sandboxed environment and returns the printed lines
|
||||||
pub fn run_sandboxed_lua(
|
pub fn run_sandboxed_lua(
|
||||||
script: &str,
|
script: &str,
|
||||||
fs_changes: HashMap<PathBuf, Vec<u8>>,
|
fs: Arc<Mutex<HashMap<PathBuf, Vec<u8>>>>,
|
||||||
root_dir: PathBuf,
|
root_dir: PathBuf,
|
||||||
) -> Result<ScriptOutput> {
|
) -> Result<ScriptOutput> {
|
||||||
let lua = Lua::new();
|
let lua = Lua::new();
|
||||||
lua.set_memory_limit(2 * 1024 * 1024 * 1024)?; // 2 GB
|
lua.set_memory_limit(2 * 1024 * 1024 * 1024)?; // 2 GB
|
||||||
|
lua.set_hook(
|
||||||
|
HookTriggers::new().every_nth_instruction(2048),
|
||||||
|
|_lua, _| {
|
||||||
|
// Check if we need to yield to prevent long-running scripts
|
||||||
|
static EXECUTION_START: std::sync::atomic::AtomicU64 =
|
||||||
|
std::sync::atomic::AtomicU64::new(0);
|
||||||
|
let now = std::time::SystemTime::now()
|
||||||
|
.duration_since(std::time::UNIX_EPOCH)
|
||||||
|
.unwrap_or_default()
|
||||||
|
.as_millis() as u64;
|
||||||
|
|
||||||
|
// Initialize start time on first call
|
||||||
|
let start = EXECUTION_START.load(std::sync::atomic::Ordering::Relaxed);
|
||||||
|
if start == 0 {
|
||||||
|
EXECUTION_START.store(now, std::sync::atomic::Ordering::Relaxed);
|
||||||
|
return Ok(VmState::Continue);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if execution time exceeds 5 seconds
|
||||||
|
if now - start > 5000 {
|
||||||
|
EXECUTION_START.store(0, std::sync::atomic::Ordering::Relaxed);
|
||||||
|
return Err(mlua::Error::runtime(
|
||||||
|
"Script execution timed out after 5 seconds",
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(VmState::Continue)
|
||||||
|
},
|
||||||
|
); // 2 GB
|
||||||
let globals = lua.globals();
|
let globals = lua.globals();
|
||||||
|
|
||||||
// Track the lines the Lua script prints out.
|
// Track the lines the Lua script prints out.
|
||||||
let printed_lines = Rc::new(RefCell::new(Vec::new()));
|
let printed_lines = Rc::new(RefCell::new(Vec::new()));
|
||||||
let fs = Rc::new(RefCell::new(fs_changes));
|
|
||||||
|
|
||||||
globals.set("sb_print", print(&lua, printed_lines.clone())?)?;
|
globals.set("sb_print", print(&lua, printed_lines.clone())?)?;
|
||||||
globals.set("search", search(&lua, fs.clone(), root_dir.clone())?)?;
|
globals.set("search", search(&lua, fs.clone(), root_dir.clone())?)?;
|
||||||
@@ -623,23 +739,18 @@ pub fn run_sandboxed_lua(
|
|||||||
printed_lines: Rc::try_unwrap(printed_lines)
|
printed_lines: Rc::try_unwrap(printed_lines)
|
||||||
.expect("There are still other references to printed_lines")
|
.expect("There are still other references to printed_lines")
|
||||||
.into_inner(),
|
.into_inner(),
|
||||||
fs_changes: Rc::try_unwrap(fs)
|
|
||||||
.expect("There are still other references to fs_changes")
|
|
||||||
.into_inner(),
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct ScriptOutput {
|
pub struct ScriptOutput {
|
||||||
printed_lines: Vec<String>,
|
printed_lines: Vec<String>,
|
||||||
#[allow(dead_code)]
|
|
||||||
fs_changes: HashMap<PathBuf, Vec<u8>>,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[allow(dead_code)]
|
#[allow(dead_code)]
|
||||||
impl ScriptOutput {
|
impl ScriptOutput {
|
||||||
fn fs_diff(&self) -> HashMap<PathBuf, String> {
|
fn fs_diff(&self, fs_changes: &HashMap<PathBuf, Vec<u8>>) -> HashMap<PathBuf, String> {
|
||||||
let mut diff_map = HashMap::new();
|
let mut diff_map = HashMap::default();
|
||||||
for (path, content) in &self.fs_changes {
|
for (path, content) in fs_changes {
|
||||||
let diff = if path.exists() {
|
let diff = if path.exists() {
|
||||||
// Read the current file content
|
// Read the current file content
|
||||||
match std::fs::read(path) {
|
match std::fs::read(path) {
|
||||||
@@ -758,9 +869,9 @@ impl ScriptOutput {
|
|||||||
diff_map
|
diff_map
|
||||||
}
|
}
|
||||||
|
|
||||||
fn diff_to_string(&self) -> String {
|
fn diff_to_string(&self, fs_changes: &HashMap<PathBuf, Vec<u8>>) -> String {
|
||||||
let mut answer = String::new();
|
let mut answer = String::new();
|
||||||
let diff_map = self.fs_diff();
|
let diff_map = self.fs_diff(fs_changes);
|
||||||
|
|
||||||
if diff_map.is_empty() {
|
if diff_map.is_empty() {
|
||||||
return "No changes to files".to_string();
|
return "No changes to files".to_string();
|
||||||
|
|||||||
Reference in New Issue
Block a user