Compare commits
3 Commits
tracing-ac
...
streaming-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
44501581ee | ||
|
|
ae95142cc8 | ||
|
|
b1b8d596b9 |
1411
Cargo.lock
generated
1411
Cargo.lock
generated
File diff suppressed because it is too large
Load Diff
@@ -22,3 +22,6 @@ serde.workspace = true
|
|||||||
serde_json.workspace = true
|
serde_json.workspace = true
|
||||||
workspace.workspace = true
|
workspace.workspace = true
|
||||||
regex.workspace = true
|
regex.workspace = true
|
||||||
|
|
||||||
|
[dev-dependencies]
|
||||||
|
editor = { workspace = true, features = ["test-support"] }
|
||||||
|
|||||||
@@ -1,3 +1,6 @@
|
|||||||
|
mod streaming_json;
|
||||||
|
mod streaming_lua;
|
||||||
|
|
||||||
use anyhow::anyhow;
|
use anyhow::anyhow;
|
||||||
use assistant_tool::{Tool, ToolRegistry};
|
use assistant_tool::{Tool, ToolRegistry};
|
||||||
use gpui::{App, AppContext as _, Task, WeakEntity, Window};
|
use gpui::{App, AppContext as _, Task, WeakEntity, Window};
|
||||||
|
|||||||
152
crates/scripting_tool/src/streaming_json.rs
Normal file
152
crates/scripting_tool/src/streaming_json.rs
Normal file
@@ -0,0 +1,152 @@
|
|||||||
|
/// This module works with streaming_lua to allow us to run fragments of
|
||||||
|
/// Lua scripts that come back from LLM JSON tool calls immediately as they arrive,
|
||||||
|
/// even when the full script (and the full JSON) has not been received yet.
|
||||||
|
|
||||||
|
pub fn from_json(json_str: &str) {
|
||||||
|
// The JSON structure we're looking for is very simple:
|
||||||
|
// 1. Open curly bracket
|
||||||
|
// 2. Optional whitespace
|
||||||
|
// 3. Quoted key - either "lua_script" or "description" (if description, just parse it)
|
||||||
|
// 4. Colon
|
||||||
|
// 5. Optional whitespace
|
||||||
|
// 6. Open quote
|
||||||
|
// 7. Now we start streaming until we see a closed quote
|
||||||
|
|
||||||
|
// TODO all of this needs to be stored in state in a struct instead of in variables,
|
||||||
|
// and that includes the iterator part.
|
||||||
|
let mut chars = json_str.trim_start().chars().peekable();
|
||||||
|
|
||||||
|
// Skip the opening curly brace
|
||||||
|
if chars.next() != Some('{') {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
let key = parse_key(&mut chars);
|
||||||
|
|
||||||
|
if key.map(|k| k.as_str()) == Some("description") {
|
||||||
|
// TODO parse the description here
|
||||||
|
parse_comma_then_quote(&mut chars);
|
||||||
|
if parse_key(&mut chars).map(|k| k.as_str()) != Some("lua_script") {
|
||||||
|
return; // This was the only remaining valid option.
|
||||||
|
}
|
||||||
|
// TODO parse the script here, remembering to s/backslash//g to unescape everything.
|
||||||
|
} else if key.map(|k| k.as_str()) == Some("lua_script") {
|
||||||
|
// TODO parse the script here, remembering to s/backslash//g to unescape everything.
|
||||||
|
parse_comma_then_quote(&mut chars);
|
||||||
|
if parse_key(&mut chars).map(|k| k.as_str()) != Some("description") {
|
||||||
|
return; // This was the only remaining valid option.
|
||||||
|
}
|
||||||
|
// TODO parse the description here
|
||||||
|
} else {
|
||||||
|
// The key wasn't one of the two valid options.
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse value
|
||||||
|
let mut value = String::new();
|
||||||
|
let mut escape_next = false;
|
||||||
|
|
||||||
|
while let Some(c) = chars.next() {
|
||||||
|
if escape_next {
|
||||||
|
value.push(match c {
|
||||||
|
'n' => '\n',
|
||||||
|
't' => '\t',
|
||||||
|
'r' => '\r',
|
||||||
|
'\\' => '\\',
|
||||||
|
'"' => '"',
|
||||||
|
_ => c,
|
||||||
|
});
|
||||||
|
escape_next = false;
|
||||||
|
} else if c == '\\' {
|
||||||
|
escape_next = true;
|
||||||
|
} else if c == '"' {
|
||||||
|
break; // End of value
|
||||||
|
} else {
|
||||||
|
value.push(c);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Process the parsed key-value pair
|
||||||
|
match key.as_str() {
|
||||||
|
"lua_script" => {
|
||||||
|
// Handle the lua script
|
||||||
|
println!("Found lua script: {}", value);
|
||||||
|
}
|
||||||
|
"description" => {
|
||||||
|
// Handle the description
|
||||||
|
println!("Found description: {}", value);
|
||||||
|
}
|
||||||
|
_ => {} // Should not reach here due to earlier check
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn parse_key(chars: &mut impl Iterator<Item = char>) -> Option<String> {
|
||||||
|
// Skip whitespace until we reach the start of the key
|
||||||
|
while let Some(c) = chars.next() {
|
||||||
|
if c.is_whitespace() {
|
||||||
|
// Consume the whitespace and continue
|
||||||
|
} else if c == '"' {
|
||||||
|
break; // Found the start of the key
|
||||||
|
} else {
|
||||||
|
return None; // Invalid format - expected a quote to start the key
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse the key. We don't need to escape backslashes because the exact key
|
||||||
|
// we expect does not include backslashes or quotes.
|
||||||
|
let mut key = String::new();
|
||||||
|
|
||||||
|
while let Some(c) = chars.next() {
|
||||||
|
if c == '"' {
|
||||||
|
break; // End of key
|
||||||
|
}
|
||||||
|
key.push(c);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Skip colon and whitespace and next opening quote.
|
||||||
|
let mut found_colon = false;
|
||||||
|
while let Some(c) = chars.next() {
|
||||||
|
if c == ':' {
|
||||||
|
found_colon = true;
|
||||||
|
} else if found_colon && !c.is_whitespace() {
|
||||||
|
if c == '"' {
|
||||||
|
break; // Found the opening quote
|
||||||
|
}
|
||||||
|
return None; // Invalid format - expected a quote after colon and whitespace
|
||||||
|
} else if !c.is_whitespace() {
|
||||||
|
return None; // Invalid format - expected whitespace or colon
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Some(key)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn parse_comma_then_quote(chars: &mut impl Iterator<Item = char>) -> bool {
|
||||||
|
// Skip any whitespace
|
||||||
|
while let Some(&c) = chars.peek() {
|
||||||
|
if !c.is_whitespace() {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
chars.next();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check for comma
|
||||||
|
if chars.next() != Some(',') {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Skip any whitespace after the comma
|
||||||
|
while let Some(&c) = chars.peek() {
|
||||||
|
if !c.is_whitespace() {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
chars.next();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check for opening quote
|
||||||
|
if chars.next() != Some('"') {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
true
|
||||||
|
}
|
||||||
268
crates/scripting_tool/src/streaming_lua.rs
Normal file
268
crates/scripting_tool/src/streaming_lua.rs
Normal file
@@ -0,0 +1,268 @@
|
|||||||
|
/// This module accepts fragments of Lua code from LLM responses, and executes
|
||||||
|
/// them as they come in (to the extent possible) rather than having to wait
|
||||||
|
/// for the entire script to arrive to execute it. (Since these are tool calls,
|
||||||
|
/// they will presumably come back in JSON; it's up to the caller to deal with
|
||||||
|
/// parsing the JSON, escaping `\\` and `\"` in the JSON-quoted Lua, etc.)
|
||||||
|
///
|
||||||
|
/// By design, Lua does not preserve top-level locals across chunks ("chunk" is a
|
||||||
|
/// Lua term for a chunk of Lua code that can be executed), and chunks are the
|
||||||
|
/// smallest unit of execution you can run in Lua. To make sure that top-level
|
||||||
|
/// locals the LLM writes are preserved across multiple silently translates
|
||||||
|
/// locals to globals. This should be harmless for our use case, because we only
|
||||||
|
/// have a single "file" and not multiple files where the distinction could matter.
|
||||||
|
///
|
||||||
|
/// Since fragments will invariably arrive that don't happen to correspond to valid
|
||||||
|
/// Lua chunks (e.g. maybe they have an opening quote for a string literal and the
|
||||||
|
/// close quote will be coming in the next fragment), we use a simple heuristic to
|
||||||
|
/// split them up: we take each fragment and split it into lines, and then whenever
|
||||||
|
/// we have a complete line, we send it to Lua to process as a chunk. If it comes back
|
||||||
|
/// with a syntax error due to it being incomplete (which mlua tells us), then we
|
||||||
|
/// know to keep waiting for more lines and try again.
|
||||||
|
///
|
||||||
|
/// Eventually we'll either succeed, or else the response will end and we'll know it
|
||||||
|
/// had an actual syntax error. (Again, it's the caller's responsibility to deal
|
||||||
|
/// with detecting when the response ends due to the JSON quote having finally closed.)
|
||||||
|
///
|
||||||
|
/// This heuristic relies on the assumption that the LLM is generating normal-looking
|
||||||
|
/// Lua code where statements are split using newlines rather than semicolons.
|
||||||
|
/// In practice, this is a safe assumption.
|
||||||
|
|
||||||
|
#[derive(Default)]
|
||||||
|
struct ChunkBuffer {
|
||||||
|
buffer: String,
|
||||||
|
incomplete_multiline_string: bool,
|
||||||
|
last_newline_index: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ChunkBuffer {
|
||||||
|
pub fn receive_chunk(
|
||||||
|
&mut self,
|
||||||
|
src_chunk: &str,
|
||||||
|
exec_chunk: &mut impl FnMut(&str) -> mlua::Result<()>,
|
||||||
|
) -> mlua::Result<()> {
|
||||||
|
self.buffer.push_str(src_chunk);
|
||||||
|
|
||||||
|
// Execute each line until we hit an incomplete parse
|
||||||
|
while let Some(index) = &self.buffer[self.last_newline_index..].find('\n') {
|
||||||
|
let mut index = *index;
|
||||||
|
|
||||||
|
// LLMs can produce incredibly long multiline strings. We don't want to keep
|
||||||
|
// attempting to re-parse those every time a new line of the string comes in.
|
||||||
|
// that would be extremely wasteful! Instead, just keep waiting until it ends.
|
||||||
|
{
|
||||||
|
let line = &self.buffer[self.last_newline_index..index];
|
||||||
|
|
||||||
|
const LOCAL_PREFIX: &str = "local ";
|
||||||
|
|
||||||
|
// It's safe to assume we'll never see a line which
|
||||||
|
// includes both "]]" and "[[" other than single-line
|
||||||
|
// assignments which are just using them to escape quotes.
|
||||||
|
//
|
||||||
|
// If that assumption turns out not to hold, we can always
|
||||||
|
// make this more robust.
|
||||||
|
if line.contains("[[") && !line.contains("]]") {
|
||||||
|
self.incomplete_multiline_string = true;
|
||||||
|
}
|
||||||
|
|
||||||
|
// In practice, LLMs produce multiline strings that always end
|
||||||
|
// with the ]] at the start of the line.
|
||||||
|
if line.starts_with("]]") {
|
||||||
|
self.incomplete_multiline_string = false;
|
||||||
|
} else if line.starts_with("local ") {
|
||||||
|
// We can't have top-level locals because they don't preserve
|
||||||
|
// across chunk executions. So just turn locals into globals.
|
||||||
|
// Since this is just one script, they're the same anyway.
|
||||||
|
self.buffer
|
||||||
|
.replace_range(self.last_newline_index..LOCAL_PREFIX.len(), "");
|
||||||
|
|
||||||
|
index -= LOCAL_PREFIX.len();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
self.last_newline_index = index;
|
||||||
|
|
||||||
|
if self.incomplete_multiline_string {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Execute all lines up to (and including) this one.
|
||||||
|
match exec_chunk(&self.buffer[..index]) {
|
||||||
|
Ok(()) => {
|
||||||
|
// The chunk executed successfully. Advance the buffer
|
||||||
|
// to reflect the fact that we've executed that code.
|
||||||
|
self.buffer = self.buffer[index + 1..].to_string();
|
||||||
|
self.last_newline_index = 0;
|
||||||
|
}
|
||||||
|
Err(mlua::Error::SyntaxError {
|
||||||
|
incomplete_input: true,
|
||||||
|
message: _,
|
||||||
|
}) => {
|
||||||
|
// If it errored specifically because the input was incomplete, no problem.
|
||||||
|
// We'll keep trying with more and more lines until eventually we find a
|
||||||
|
// sequence of lines that are valid together!
|
||||||
|
}
|
||||||
|
Err(other) => {
|
||||||
|
return Err(other);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn finish(
|
||||||
|
&mut self,
|
||||||
|
exec_chunk: &mut impl FnMut(&str) -> mlua::Result<()>,
|
||||||
|
) -> mlua::Result<()> {
|
||||||
|
if !self.buffer.is_empty() {
|
||||||
|
// Execute whatever is left in the buffer
|
||||||
|
match exec_chunk(&self.buffer) {
|
||||||
|
Ok(()) => {
|
||||||
|
// Clear the buffer as everything has been executed
|
||||||
|
self.buffer.clear();
|
||||||
|
self.last_newline_index = 0;
|
||||||
|
self.incomplete_multiline_string = false;
|
||||||
|
}
|
||||||
|
Err(err) => {
|
||||||
|
return Err(err);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use mlua::Lua;
|
||||||
|
use std::cell::RefCell;
|
||||||
|
use std::rc::Rc;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_lua_runtime_receive_chunk() {
|
||||||
|
let mut chunk_buffer = ChunkBuffer::default();
|
||||||
|
let output = Rc::new(RefCell::new(String::new()));
|
||||||
|
|
||||||
|
let mut exec_chunk = |chunk: &str| -> mlua::Result<()> {
|
||||||
|
let lua = Lua::new();
|
||||||
|
|
||||||
|
// Clone the Rc to share ownership of the same RefCell
|
||||||
|
let output_ref = output.clone();
|
||||||
|
|
||||||
|
lua.globals().set(
|
||||||
|
"print",
|
||||||
|
lua.create_function(move |_, msg: String| {
|
||||||
|
let mut output = output_ref.borrow_mut();
|
||||||
|
output.push_str(&msg);
|
||||||
|
output.push('\n');
|
||||||
|
Ok(())
|
||||||
|
})?,
|
||||||
|
)?;
|
||||||
|
|
||||||
|
lua.load(chunk).exec()
|
||||||
|
};
|
||||||
|
|
||||||
|
exec_chunk("print('Hello, World!')").unwrap();
|
||||||
|
|
||||||
|
chunk_buffer
|
||||||
|
.receive_chunk("print('Hello, World!')", &mut exec_chunk)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
assert_eq!(*output.borrow(), "Hello, World!\n");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_lua_runtime_receive_chunk_shared_lua() {
|
||||||
|
let mut chunk_buffer = ChunkBuffer::default();
|
||||||
|
let output = Rc::new(RefCell::new(String::new()));
|
||||||
|
let lua = Lua::new();
|
||||||
|
|
||||||
|
// Set up the print function once for the shared Lua instance
|
||||||
|
{
|
||||||
|
let output_ref = output.clone();
|
||||||
|
lua.globals()
|
||||||
|
.set(
|
||||||
|
"print",
|
||||||
|
lua.create_function(move |_, msg: String| {
|
||||||
|
let mut output = output_ref.borrow_mut();
|
||||||
|
output.push_str(&msg);
|
||||||
|
output.push('\n');
|
||||||
|
Ok(())
|
||||||
|
})
|
||||||
|
.unwrap(),
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut exec_chunk = |chunk: &str| -> mlua::Result<()> { lua.load(chunk).exec() };
|
||||||
|
|
||||||
|
// Send first incomplete chunk
|
||||||
|
chunk_buffer
|
||||||
|
.receive_chunk("local message = 'Hello, '\n", &mut exec_chunk)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
// Send second chunk that completes the code
|
||||||
|
chunk_buffer
|
||||||
|
.receive_chunk(
|
||||||
|
"message = message .. 'World!'\nprint(message)",
|
||||||
|
&mut exec_chunk,
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
chunk_buffer.finish(&mut exec_chunk).unwrap();
|
||||||
|
|
||||||
|
assert_eq!(*output.borrow(), "Hello, World!\n");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_multiline_string_across_chunks() {
|
||||||
|
let mut chunk_buffer = ChunkBuffer::default();
|
||||||
|
let output = Rc::new(RefCell::new(String::new()));
|
||||||
|
let lua = Lua::new();
|
||||||
|
|
||||||
|
// Set up the print function for the shared Lua instance
|
||||||
|
{
|
||||||
|
let output_ref = output.clone();
|
||||||
|
lua.globals()
|
||||||
|
.set(
|
||||||
|
"print",
|
||||||
|
lua.create_function(move |_, msg: String| {
|
||||||
|
let mut output = output_ref.borrow_mut();
|
||||||
|
output.push_str(&msg);
|
||||||
|
output.push('\n');
|
||||||
|
Ok(())
|
||||||
|
})
|
||||||
|
.unwrap(),
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut exec_chunk = |chunk: &str| -> mlua::Result<()> { lua.load(chunk).exec() };
|
||||||
|
|
||||||
|
// Send first chunk with the beginning of a multiline string
|
||||||
|
chunk_buffer
|
||||||
|
.receive_chunk("local multiline = [[This is the start\n", &mut exec_chunk)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
// Send second chunk with more lines
|
||||||
|
chunk_buffer
|
||||||
|
.receive_chunk("of a very long\nmultiline string\n", &mut exec_chunk)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
// Send third chunk with more content
|
||||||
|
chunk_buffer
|
||||||
|
.receive_chunk("that spans across\n", &mut exec_chunk)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
// Send final chunk that completes the multiline string
|
||||||
|
chunk_buffer
|
||||||
|
.receive_chunk("multiple chunks]]\nprint(multiline)", &mut exec_chunk)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
chunk_buffer.finish(&mut exec_chunk).unwrap();
|
||||||
|
|
||||||
|
let expected = "This is the start\nof a very long\nmultiline string\nthat spans across\nmultiple chunks\n";
|
||||||
|
assert_eq!(*output.borrow(), expected);
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user