Compare commits
3 Commits
debug-shel
...
streaming-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
44501581ee | ||
|
|
ae95142cc8 | ||
|
|
b1b8d596b9 |
1411
Cargo.lock
generated
1411
Cargo.lock
generated
File diff suppressed because it is too large
Load Diff
@@ -22,3 +22,6 @@ serde.workspace = true
|
||||
serde_json.workspace = true
|
||||
workspace.workspace = true
|
||||
regex.workspace = true
|
||||
|
||||
[dev-dependencies]
|
||||
editor = { workspace = true, features = ["test-support"] }
|
||||
|
||||
@@ -1,3 +1,6 @@
|
||||
mod streaming_json;
|
||||
mod streaming_lua;
|
||||
|
||||
use anyhow::anyhow;
|
||||
use assistant_tool::{Tool, ToolRegistry};
|
||||
use gpui::{App, AppContext as _, Task, WeakEntity, Window};
|
||||
|
||||
152
crates/scripting_tool/src/streaming_json.rs
Normal file
152
crates/scripting_tool/src/streaming_json.rs
Normal file
@@ -0,0 +1,152 @@
|
||||
/// This module works with streaming_lua to allow us to run fragments of
|
||||
/// Lua scripts that come back from LLM JSON tool calls immediately as they arrive,
|
||||
/// even when the full script (and the full JSON) has not been received yet.
|
||||
|
||||
pub fn from_json(json_str: &str) {
|
||||
// The JSON structure we're looking for is very simple:
|
||||
// 1. Open curly bracket
|
||||
// 2. Optional whitespace
|
||||
// 3. Quoted key - either "lua_script" or "description" (if description, just parse it)
|
||||
// 4. Colon
|
||||
// 5. Optional whitespace
|
||||
// 6. Open quote
|
||||
// 7. Now we start streaming until we see a closed quote
|
||||
|
||||
// TODO all of this needs to be stored in state in a struct instead of in variables,
|
||||
// and that includes the iterator part.
|
||||
let mut chars = json_str.trim_start().chars().peekable();
|
||||
|
||||
// Skip the opening curly brace
|
||||
if chars.next() != Some('{') {
|
||||
return;
|
||||
}
|
||||
|
||||
let key = parse_key(&mut chars);
|
||||
|
||||
if key.map(|k| k.as_str()) == Some("description") {
|
||||
// TODO parse the description here
|
||||
parse_comma_then_quote(&mut chars);
|
||||
if parse_key(&mut chars).map(|k| k.as_str()) != Some("lua_script") {
|
||||
return; // This was the only remaining valid option.
|
||||
}
|
||||
// TODO parse the script here, remembering to s/backslash//g to unescape everything.
|
||||
} else if key.map(|k| k.as_str()) == Some("lua_script") {
|
||||
// TODO parse the script here, remembering to s/backslash//g to unescape everything.
|
||||
parse_comma_then_quote(&mut chars);
|
||||
if parse_key(&mut chars).map(|k| k.as_str()) != Some("description") {
|
||||
return; // This was the only remaining valid option.
|
||||
}
|
||||
// TODO parse the description here
|
||||
} else {
|
||||
// The key wasn't one of the two valid options.
|
||||
return;
|
||||
}
|
||||
|
||||
// Parse value
|
||||
let mut value = String::new();
|
||||
let mut escape_next = false;
|
||||
|
||||
while let Some(c) = chars.next() {
|
||||
if escape_next {
|
||||
value.push(match c {
|
||||
'n' => '\n',
|
||||
't' => '\t',
|
||||
'r' => '\r',
|
||||
'\\' => '\\',
|
||||
'"' => '"',
|
||||
_ => c,
|
||||
});
|
||||
escape_next = false;
|
||||
} else if c == '\\' {
|
||||
escape_next = true;
|
||||
} else if c == '"' {
|
||||
break; // End of value
|
||||
} else {
|
||||
value.push(c);
|
||||
}
|
||||
}
|
||||
|
||||
// Process the parsed key-value pair
|
||||
match key.as_str() {
|
||||
"lua_script" => {
|
||||
// Handle the lua script
|
||||
println!("Found lua script: {}", value);
|
||||
}
|
||||
"description" => {
|
||||
// Handle the description
|
||||
println!("Found description: {}", value);
|
||||
}
|
||||
_ => {} // Should not reach here due to earlier check
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_key(chars: &mut impl Iterator<Item = char>) -> Option<String> {
|
||||
// Skip whitespace until we reach the start of the key
|
||||
while let Some(c) = chars.next() {
|
||||
if c.is_whitespace() {
|
||||
// Consume the whitespace and continue
|
||||
} else if c == '"' {
|
||||
break; // Found the start of the key
|
||||
} else {
|
||||
return None; // Invalid format - expected a quote to start the key
|
||||
}
|
||||
}
|
||||
|
||||
// Parse the key. We don't need to escape backslashes because the exact key
|
||||
// we expect does not include backslashes or quotes.
|
||||
let mut key = String::new();
|
||||
|
||||
while let Some(c) = chars.next() {
|
||||
if c == '"' {
|
||||
break; // End of key
|
||||
}
|
||||
key.push(c);
|
||||
}
|
||||
|
||||
// Skip colon and whitespace and next opening quote.
|
||||
let mut found_colon = false;
|
||||
while let Some(c) = chars.next() {
|
||||
if c == ':' {
|
||||
found_colon = true;
|
||||
} else if found_colon && !c.is_whitespace() {
|
||||
if c == '"' {
|
||||
break; // Found the opening quote
|
||||
}
|
||||
return None; // Invalid format - expected a quote after colon and whitespace
|
||||
} else if !c.is_whitespace() {
|
||||
return None; // Invalid format - expected whitespace or colon
|
||||
}
|
||||
}
|
||||
|
||||
Some(key)
|
||||
}
|
||||
|
||||
fn parse_comma_then_quote(chars: &mut impl Iterator<Item = char>) -> bool {
|
||||
// Skip any whitespace
|
||||
while let Some(&c) = chars.peek() {
|
||||
if !c.is_whitespace() {
|
||||
break;
|
||||
}
|
||||
chars.next();
|
||||
}
|
||||
|
||||
// Check for comma
|
||||
if chars.next() != Some(',') {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Skip any whitespace after the comma
|
||||
while let Some(&c) = chars.peek() {
|
||||
if !c.is_whitespace() {
|
||||
break;
|
||||
}
|
||||
chars.next();
|
||||
}
|
||||
|
||||
// Check for opening quote
|
||||
if chars.next() != Some('"') {
|
||||
return false;
|
||||
}
|
||||
|
||||
true
|
||||
}
|
||||
268
crates/scripting_tool/src/streaming_lua.rs
Normal file
268
crates/scripting_tool/src/streaming_lua.rs
Normal file
@@ -0,0 +1,268 @@
|
||||
/// This module accepts fragments of Lua code from LLM responses, and executes
|
||||
/// them as they come in (to the extent possible) rather than having to wait
|
||||
/// for the entire script to arrive to execute it. (Since these are tool calls,
|
||||
/// they will presumably come back in JSON; it's up to the caller to deal with
|
||||
/// parsing the JSON, escaping `\\` and `\"` in the JSON-quoted Lua, etc.)
|
||||
///
|
||||
/// By design, Lua does not preserve top-level locals across chunks ("chunk" is a
|
||||
/// Lua term for a chunk of Lua code that can be executed), and chunks are the
|
||||
/// smallest unit of execution you can run in Lua. To make sure that top-level
|
||||
/// locals the LLM writes are preserved across multiple silently translates
|
||||
/// locals to globals. This should be harmless for our use case, because we only
|
||||
/// have a single "file" and not multiple files where the distinction could matter.
|
||||
///
|
||||
/// Since fragments will invariably arrive that don't happen to correspond to valid
|
||||
/// Lua chunks (e.g. maybe they have an opening quote for a string literal and the
|
||||
/// close quote will be coming in the next fragment), we use a simple heuristic to
|
||||
/// split them up: we take each fragment and split it into lines, and then whenever
|
||||
/// we have a complete line, we send it to Lua to process as a chunk. If it comes back
|
||||
/// with a syntax error due to it being incomplete (which mlua tells us), then we
|
||||
/// know to keep waiting for more lines and try again.
|
||||
///
|
||||
/// Eventually we'll either succeed, or else the response will end and we'll know it
|
||||
/// had an actual syntax error. (Again, it's the caller's responsibility to deal
|
||||
/// with detecting when the response ends due to the JSON quote having finally closed.)
|
||||
///
|
||||
/// This heuristic relies on the assumption that the LLM is generating normal-looking
|
||||
/// Lua code where statements are split using newlines rather than semicolons.
|
||||
/// In practice, this is a safe assumption.
|
||||
|
||||
#[derive(Default)]
|
||||
struct ChunkBuffer {
|
||||
buffer: String,
|
||||
incomplete_multiline_string: bool,
|
||||
last_newline_index: usize,
|
||||
}
|
||||
|
||||
impl ChunkBuffer {
|
||||
pub fn receive_chunk(
|
||||
&mut self,
|
||||
src_chunk: &str,
|
||||
exec_chunk: &mut impl FnMut(&str) -> mlua::Result<()>,
|
||||
) -> mlua::Result<()> {
|
||||
self.buffer.push_str(src_chunk);
|
||||
|
||||
// Execute each line until we hit an incomplete parse
|
||||
while let Some(index) = &self.buffer[self.last_newline_index..].find('\n') {
|
||||
let mut index = *index;
|
||||
|
||||
// LLMs can produce incredibly long multiline strings. We don't want to keep
|
||||
// attempting to re-parse those every time a new line of the string comes in.
|
||||
// that would be extremely wasteful! Instead, just keep waiting until it ends.
|
||||
{
|
||||
let line = &self.buffer[self.last_newline_index..index];
|
||||
|
||||
const LOCAL_PREFIX: &str = "local ";
|
||||
|
||||
// It's safe to assume we'll never see a line which
|
||||
// includes both "]]" and "[[" other than single-line
|
||||
// assignments which are just using them to escape quotes.
|
||||
//
|
||||
// If that assumption turns out not to hold, we can always
|
||||
// make this more robust.
|
||||
if line.contains("[[") && !line.contains("]]") {
|
||||
self.incomplete_multiline_string = true;
|
||||
}
|
||||
|
||||
// In practice, LLMs produce multiline strings that always end
|
||||
// with the ]] at the start of the line.
|
||||
if line.starts_with("]]") {
|
||||
self.incomplete_multiline_string = false;
|
||||
} else if line.starts_with("local ") {
|
||||
// We can't have top-level locals because they don't preserve
|
||||
// across chunk executions. So just turn locals into globals.
|
||||
// Since this is just one script, they're the same anyway.
|
||||
self.buffer
|
||||
.replace_range(self.last_newline_index..LOCAL_PREFIX.len(), "");
|
||||
|
||||
index -= LOCAL_PREFIX.len();
|
||||
}
|
||||
}
|
||||
|
||||
self.last_newline_index = index;
|
||||
|
||||
if self.incomplete_multiline_string {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Execute all lines up to (and including) this one.
|
||||
match exec_chunk(&self.buffer[..index]) {
|
||||
Ok(()) => {
|
||||
// The chunk executed successfully. Advance the buffer
|
||||
// to reflect the fact that we've executed that code.
|
||||
self.buffer = self.buffer[index + 1..].to_string();
|
||||
self.last_newline_index = 0;
|
||||
}
|
||||
Err(mlua::Error::SyntaxError {
|
||||
incomplete_input: true,
|
||||
message: _,
|
||||
}) => {
|
||||
// If it errored specifically because the input was incomplete, no problem.
|
||||
// We'll keep trying with more and more lines until eventually we find a
|
||||
// sequence of lines that are valid together!
|
||||
}
|
||||
Err(other) => {
|
||||
return Err(other);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn finish(
|
||||
&mut self,
|
||||
exec_chunk: &mut impl FnMut(&str) -> mlua::Result<()>,
|
||||
) -> mlua::Result<()> {
|
||||
if !self.buffer.is_empty() {
|
||||
// Execute whatever is left in the buffer
|
||||
match exec_chunk(&self.buffer) {
|
||||
Ok(()) => {
|
||||
// Clear the buffer as everything has been executed
|
||||
self.buffer.clear();
|
||||
self.last_newline_index = 0;
|
||||
self.incomplete_multiline_string = false;
|
||||
}
|
||||
Err(err) => {
|
||||
return Err(err);
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use mlua::Lua;
|
||||
use std::cell::RefCell;
|
||||
use std::rc::Rc;
|
||||
|
||||
#[test]
|
||||
fn test_lua_runtime_receive_chunk() {
|
||||
let mut chunk_buffer = ChunkBuffer::default();
|
||||
let output = Rc::new(RefCell::new(String::new()));
|
||||
|
||||
let mut exec_chunk = |chunk: &str| -> mlua::Result<()> {
|
||||
let lua = Lua::new();
|
||||
|
||||
// Clone the Rc to share ownership of the same RefCell
|
||||
let output_ref = output.clone();
|
||||
|
||||
lua.globals().set(
|
||||
"print",
|
||||
lua.create_function(move |_, msg: String| {
|
||||
let mut output = output_ref.borrow_mut();
|
||||
output.push_str(&msg);
|
||||
output.push('\n');
|
||||
Ok(())
|
||||
})?,
|
||||
)?;
|
||||
|
||||
lua.load(chunk).exec()
|
||||
};
|
||||
|
||||
exec_chunk("print('Hello, World!')").unwrap();
|
||||
|
||||
chunk_buffer
|
||||
.receive_chunk("print('Hello, World!')", &mut exec_chunk)
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(*output.borrow(), "Hello, World!\n");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_lua_runtime_receive_chunk_shared_lua() {
|
||||
let mut chunk_buffer = ChunkBuffer::default();
|
||||
let output = Rc::new(RefCell::new(String::new()));
|
||||
let lua = Lua::new();
|
||||
|
||||
// Set up the print function once for the shared Lua instance
|
||||
{
|
||||
let output_ref = output.clone();
|
||||
lua.globals()
|
||||
.set(
|
||||
"print",
|
||||
lua.create_function(move |_, msg: String| {
|
||||
let mut output = output_ref.borrow_mut();
|
||||
output.push_str(&msg);
|
||||
output.push('\n');
|
||||
Ok(())
|
||||
})
|
||||
.unwrap(),
|
||||
)
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
let mut exec_chunk = |chunk: &str| -> mlua::Result<()> { lua.load(chunk).exec() };
|
||||
|
||||
// Send first incomplete chunk
|
||||
chunk_buffer
|
||||
.receive_chunk("local message = 'Hello, '\n", &mut exec_chunk)
|
||||
.unwrap();
|
||||
|
||||
// Send second chunk that completes the code
|
||||
chunk_buffer
|
||||
.receive_chunk(
|
||||
"message = message .. 'World!'\nprint(message)",
|
||||
&mut exec_chunk,
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
chunk_buffer.finish(&mut exec_chunk).unwrap();
|
||||
|
||||
assert_eq!(*output.borrow(), "Hello, World!\n");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_multiline_string_across_chunks() {
|
||||
let mut chunk_buffer = ChunkBuffer::default();
|
||||
let output = Rc::new(RefCell::new(String::new()));
|
||||
let lua = Lua::new();
|
||||
|
||||
// Set up the print function for the shared Lua instance
|
||||
{
|
||||
let output_ref = output.clone();
|
||||
lua.globals()
|
||||
.set(
|
||||
"print",
|
||||
lua.create_function(move |_, msg: String| {
|
||||
let mut output = output_ref.borrow_mut();
|
||||
output.push_str(&msg);
|
||||
output.push('\n');
|
||||
Ok(())
|
||||
})
|
||||
.unwrap(),
|
||||
)
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
let mut exec_chunk = |chunk: &str| -> mlua::Result<()> { lua.load(chunk).exec() };
|
||||
|
||||
// Send first chunk with the beginning of a multiline string
|
||||
chunk_buffer
|
||||
.receive_chunk("local multiline = [[This is the start\n", &mut exec_chunk)
|
||||
.unwrap();
|
||||
|
||||
// Send second chunk with more lines
|
||||
chunk_buffer
|
||||
.receive_chunk("of a very long\nmultiline string\n", &mut exec_chunk)
|
||||
.unwrap();
|
||||
|
||||
// Send third chunk with more content
|
||||
chunk_buffer
|
||||
.receive_chunk("that spans across\n", &mut exec_chunk)
|
||||
.unwrap();
|
||||
|
||||
// Send final chunk that completes the multiline string
|
||||
chunk_buffer
|
||||
.receive_chunk("multiple chunks]]\nprint(multiline)", &mut exec_chunk)
|
||||
.unwrap();
|
||||
|
||||
chunk_buffer.finish(&mut exec_chunk).unwrap();
|
||||
|
||||
let expected = "This is the start\nof a very long\nmultiline string\nthat spans across\nmultiple chunks\n";
|
||||
assert_eq!(*output.borrow(), expected);
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user