Compare commits

...

3 Commits

Author SHA1 Message Date
Richard Feldman
44501581ee Start on streaming JSON 2025-03-07 10:14:16 -05:00
Richard Feldman
ae95142cc8 Got basic chunk streaming working 2025-03-07 00:23:37 -05:00
Richard Feldman
b1b8d596b9 Use full_moon for lexing 2025-03-06 21:53:18 -05:00
5 changed files with 1166 additions and 671 deletions

1411
Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@@ -22,3 +22,6 @@ serde.workspace = true
serde_json.workspace = true
workspace.workspace = true
regex.workspace = true
[dev-dependencies]
editor = { workspace = true, features = ["test-support"] }

View File

@@ -1,3 +1,6 @@
mod streaming_json;
mod streaming_lua;
use anyhow::anyhow;
use assistant_tool::{Tool, ToolRegistry};
use gpui::{App, AppContext as _, Task, WeakEntity, Window};

View File

@@ -0,0 +1,152 @@
/// This module works with streaming_lua to allow us to run fragments of
/// Lua scripts that come back from LLM JSON tool calls immediately as they arrive,
/// even when the full script (and the full JSON) has not been received yet.
pub fn from_json(json_str: &str) {
// The JSON structure we're looking for is very simple:
// 1. Open curly bracket
// 2. Optional whitespace
// 3. Quoted key - either "lua_script" or "description" (if description, just parse it)
// 4. Colon
// 5. Optional whitespace
// 6. Open quote
// 7. Now we start streaming until we see a closed quote
// TODO all of this needs to be stored in state in a struct instead of in variables,
// and that includes the iterator part.
let mut chars = json_str.trim_start().chars().peekable();
// Skip the opening curly brace
if chars.next() != Some('{') {
return;
}
let key = parse_key(&mut chars);
if key.map(|k| k.as_str()) == Some("description") {
// TODO parse the description here
parse_comma_then_quote(&mut chars);
if parse_key(&mut chars).map(|k| k.as_str()) != Some("lua_script") {
return; // This was the only remaining valid option.
}
// TODO parse the script here, remembering to s/backslash//g to unescape everything.
} else if key.map(|k| k.as_str()) == Some("lua_script") {
// TODO parse the script here, remembering to s/backslash//g to unescape everything.
parse_comma_then_quote(&mut chars);
if parse_key(&mut chars).map(|k| k.as_str()) != Some("description") {
return; // This was the only remaining valid option.
}
// TODO parse the description here
} else {
// The key wasn't one of the two valid options.
return;
}
// Parse value
let mut value = String::new();
let mut escape_next = false;
while let Some(c) = chars.next() {
if escape_next {
value.push(match c {
'n' => '\n',
't' => '\t',
'r' => '\r',
'\\' => '\\',
'"' => '"',
_ => c,
});
escape_next = false;
} else if c == '\\' {
escape_next = true;
} else if c == '"' {
break; // End of value
} else {
value.push(c);
}
}
// Process the parsed key-value pair
match key.as_str() {
"lua_script" => {
// Handle the lua script
println!("Found lua script: {}", value);
}
"description" => {
// Handle the description
println!("Found description: {}", value);
}
_ => {} // Should not reach here due to earlier check
}
}
fn parse_key(chars: &mut impl Iterator<Item = char>) -> Option<String> {
// Skip whitespace until we reach the start of the key
while let Some(c) = chars.next() {
if c.is_whitespace() {
// Consume the whitespace and continue
} else if c == '"' {
break; // Found the start of the key
} else {
return None; // Invalid format - expected a quote to start the key
}
}
// Parse the key. We don't need to escape backslashes because the exact key
// we expect does not include backslashes or quotes.
let mut key = String::new();
while let Some(c) = chars.next() {
if c == '"' {
break; // End of key
}
key.push(c);
}
// Skip colon and whitespace and next opening quote.
let mut found_colon = false;
while let Some(c) = chars.next() {
if c == ':' {
found_colon = true;
} else if found_colon && !c.is_whitespace() {
if c == '"' {
break; // Found the opening quote
}
return None; // Invalid format - expected a quote after colon and whitespace
} else if !c.is_whitespace() {
return None; // Invalid format - expected whitespace or colon
}
}
Some(key)
}
fn parse_comma_then_quote(chars: &mut impl Iterator<Item = char>) -> bool {
// Skip any whitespace
while let Some(&c) = chars.peek() {
if !c.is_whitespace() {
break;
}
chars.next();
}
// Check for comma
if chars.next() != Some(',') {
return false;
}
// Skip any whitespace after the comma
while let Some(&c) = chars.peek() {
if !c.is_whitespace() {
break;
}
chars.next();
}
// Check for opening quote
if chars.next() != Some('"') {
return false;
}
true
}

View File

@@ -0,0 +1,268 @@
/// This module accepts fragments of Lua code from LLM responses, and executes
/// them as they come in (to the extent possible) rather than having to wait
/// for the entire script to arrive to execute it. (Since these are tool calls,
/// they will presumably come back in JSON; it's up to the caller to deal with
/// parsing the JSON, escaping `\\` and `\"` in the JSON-quoted Lua, etc.)
///
/// By design, Lua does not preserve top-level locals across chunks ("chunk" is a
/// Lua term for a chunk of Lua code that can be executed), and chunks are the
/// smallest unit of execution you can run in Lua. To make sure that top-level
/// locals the LLM writes are preserved across multiple silently translates
/// locals to globals. This should be harmless for our use case, because we only
/// have a single "file" and not multiple files where the distinction could matter.
///
/// Since fragments will invariably arrive that don't happen to correspond to valid
/// Lua chunks (e.g. maybe they have an opening quote for a string literal and the
/// close quote will be coming in the next fragment), we use a simple heuristic to
/// split them up: we take each fragment and split it into lines, and then whenever
/// we have a complete line, we send it to Lua to process as a chunk. If it comes back
/// with a syntax error due to it being incomplete (which mlua tells us), then we
/// know to keep waiting for more lines and try again.
///
/// Eventually we'll either succeed, or else the response will end and we'll know it
/// had an actual syntax error. (Again, it's the caller's responsibility to deal
/// with detecting when the response ends due to the JSON quote having finally closed.)
///
/// This heuristic relies on the assumption that the LLM is generating normal-looking
/// Lua code where statements are split using newlines rather than semicolons.
/// In practice, this is a safe assumption.
#[derive(Default)]
struct ChunkBuffer {
buffer: String,
incomplete_multiline_string: bool,
last_newline_index: usize,
}
impl ChunkBuffer {
pub fn receive_chunk(
&mut self,
src_chunk: &str,
exec_chunk: &mut impl FnMut(&str) -> mlua::Result<()>,
) -> mlua::Result<()> {
self.buffer.push_str(src_chunk);
// Execute each line until we hit an incomplete parse
while let Some(index) = &self.buffer[self.last_newline_index..].find('\n') {
let mut index = *index;
// LLMs can produce incredibly long multiline strings. We don't want to keep
// attempting to re-parse those every time a new line of the string comes in.
// that would be extremely wasteful! Instead, just keep waiting until it ends.
{
let line = &self.buffer[self.last_newline_index..index];
const LOCAL_PREFIX: &str = "local ";
// It's safe to assume we'll never see a line which
// includes both "]]" and "[[" other than single-line
// assignments which are just using them to escape quotes.
//
// If that assumption turns out not to hold, we can always
// make this more robust.
if line.contains("[[") && !line.contains("]]") {
self.incomplete_multiline_string = true;
}
// In practice, LLMs produce multiline strings that always end
// with the ]] at the start of the line.
if line.starts_with("]]") {
self.incomplete_multiline_string = false;
} else if line.starts_with("local ") {
// We can't have top-level locals because they don't preserve
// across chunk executions. So just turn locals into globals.
// Since this is just one script, they're the same anyway.
self.buffer
.replace_range(self.last_newline_index..LOCAL_PREFIX.len(), "");
index -= LOCAL_PREFIX.len();
}
}
self.last_newline_index = index;
if self.incomplete_multiline_string {
continue;
}
// Execute all lines up to (and including) this one.
match exec_chunk(&self.buffer[..index]) {
Ok(()) => {
// The chunk executed successfully. Advance the buffer
// to reflect the fact that we've executed that code.
self.buffer = self.buffer[index + 1..].to_string();
self.last_newline_index = 0;
}
Err(mlua::Error::SyntaxError {
incomplete_input: true,
message: _,
}) => {
// If it errored specifically because the input was incomplete, no problem.
// We'll keep trying with more and more lines until eventually we find a
// sequence of lines that are valid together!
}
Err(other) => {
return Err(other);
}
}
}
Ok(())
}
pub fn finish(
&mut self,
exec_chunk: &mut impl FnMut(&str) -> mlua::Result<()>,
) -> mlua::Result<()> {
if !self.buffer.is_empty() {
// Execute whatever is left in the buffer
match exec_chunk(&self.buffer) {
Ok(()) => {
// Clear the buffer as everything has been executed
self.buffer.clear();
self.last_newline_index = 0;
self.incomplete_multiline_string = false;
}
Err(err) => {
return Err(err);
}
}
}
Ok(())
}
}
#[cfg(test)]
mod tests {
use super::*;
use mlua::Lua;
use std::cell::RefCell;
use std::rc::Rc;
#[test]
fn test_lua_runtime_receive_chunk() {
let mut chunk_buffer = ChunkBuffer::default();
let output = Rc::new(RefCell::new(String::new()));
let mut exec_chunk = |chunk: &str| -> mlua::Result<()> {
let lua = Lua::new();
// Clone the Rc to share ownership of the same RefCell
let output_ref = output.clone();
lua.globals().set(
"print",
lua.create_function(move |_, msg: String| {
let mut output = output_ref.borrow_mut();
output.push_str(&msg);
output.push('\n');
Ok(())
})?,
)?;
lua.load(chunk).exec()
};
exec_chunk("print('Hello, World!')").unwrap();
chunk_buffer
.receive_chunk("print('Hello, World!')", &mut exec_chunk)
.unwrap();
assert_eq!(*output.borrow(), "Hello, World!\n");
}
#[test]
fn test_lua_runtime_receive_chunk_shared_lua() {
let mut chunk_buffer = ChunkBuffer::default();
let output = Rc::new(RefCell::new(String::new()));
let lua = Lua::new();
// Set up the print function once for the shared Lua instance
{
let output_ref = output.clone();
lua.globals()
.set(
"print",
lua.create_function(move |_, msg: String| {
let mut output = output_ref.borrow_mut();
output.push_str(&msg);
output.push('\n');
Ok(())
})
.unwrap(),
)
.unwrap();
}
let mut exec_chunk = |chunk: &str| -> mlua::Result<()> { lua.load(chunk).exec() };
// Send first incomplete chunk
chunk_buffer
.receive_chunk("local message = 'Hello, '\n", &mut exec_chunk)
.unwrap();
// Send second chunk that completes the code
chunk_buffer
.receive_chunk(
"message = message .. 'World!'\nprint(message)",
&mut exec_chunk,
)
.unwrap();
chunk_buffer.finish(&mut exec_chunk).unwrap();
assert_eq!(*output.borrow(), "Hello, World!\n");
}
#[test]
fn test_multiline_string_across_chunks() {
let mut chunk_buffer = ChunkBuffer::default();
let output = Rc::new(RefCell::new(String::new()));
let lua = Lua::new();
// Set up the print function for the shared Lua instance
{
let output_ref = output.clone();
lua.globals()
.set(
"print",
lua.create_function(move |_, msg: String| {
let mut output = output_ref.borrow_mut();
output.push_str(&msg);
output.push('\n');
Ok(())
})
.unwrap(),
)
.unwrap();
}
let mut exec_chunk = |chunk: &str| -> mlua::Result<()> { lua.load(chunk).exec() };
// Send first chunk with the beginning of a multiline string
chunk_buffer
.receive_chunk("local multiline = [[This is the start\n", &mut exec_chunk)
.unwrap();
// Send second chunk with more lines
chunk_buffer
.receive_chunk("of a very long\nmultiline string\n", &mut exec_chunk)
.unwrap();
// Send third chunk with more content
chunk_buffer
.receive_chunk("that spans across\n", &mut exec_chunk)
.unwrap();
// Send final chunk that completes the multiline string
chunk_buffer
.receive_chunk("multiple chunks]]\nprint(multiline)", &mut exec_chunk)
.unwrap();
chunk_buffer.finish(&mut exec_chunk).unwrap();
let expected = "This is the start\nof a very long\nmultiline string\nthat spans across\nmultiple chunks\n";
assert_eq!(*output.borrow(), expected);
}
}