Compare commits
22 Commits
main
...
ep-example
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
8953b487ad | ||
|
|
196c488ed4 | ||
|
|
dfbbacec12 | ||
|
|
9161a23513 | ||
|
|
9a8ccb32ac | ||
|
|
5cfdfd32c6 | ||
|
|
defcc2f51b | ||
|
|
6ebe0edea0 | ||
|
|
1a83c0f5e4 | ||
|
|
27a6d54efe | ||
|
|
a168d8f50a | ||
|
|
e243a658a5 | ||
|
|
a93fd51f35 | ||
|
|
0dcdc6d9a4 | ||
|
|
7e09b59fa3 | ||
|
|
1e28bf8279 | ||
|
|
b6eec44a99 | ||
|
|
d83c985923 | ||
|
|
74c4e25b8c | ||
|
|
2021f32947 | ||
|
|
299ca2e8ac | ||
|
|
c284f9086b |
1
Cargo.lock
generated
1
Cargo.lock
generated
@@ -5254,6 +5254,7 @@ dependencies = [
|
||||
"text",
|
||||
"thiserror 2.0.17",
|
||||
"time",
|
||||
"toml 0.8.23",
|
||||
"ui",
|
||||
"util",
|
||||
"uuid",
|
||||
|
||||
@@ -56,6 +56,7 @@ telemetry_events.workspace = true
|
||||
text.workspace = true
|
||||
thiserror.workspace = true
|
||||
time.workspace = true
|
||||
toml.workspace = true
|
||||
ui.workspace = true
|
||||
util.workspace = true
|
||||
uuid.workspace = true
|
||||
|
||||
@@ -7,15 +7,14 @@ use buffer_diff::BufferDiffSnapshot;
|
||||
use collections::HashMap;
|
||||
use gpui::{App, Entity, Task};
|
||||
use language::{Buffer, ToPoint as _};
|
||||
use project::Project;
|
||||
use project::{Project, WorktreeId};
|
||||
use std::{collections::hash_map, fmt::Write as _, path::Path, sync::Arc};
|
||||
use text::{BufferSnapshot as TextBufferSnapshot, ToOffset as _};
|
||||
use text::BufferSnapshot as TextBufferSnapshot;
|
||||
|
||||
pub fn capture_example(
|
||||
project: Entity<Project>,
|
||||
buffer: Entity<Buffer>,
|
||||
cursor_anchor: language::Anchor,
|
||||
last_event_is_expected_patch: bool,
|
||||
cx: &mut App,
|
||||
) -> Option<Task<Result<ExampleSpec>>> {
|
||||
let ep_store = EditPredictionStore::try_global(cx)?;
|
||||
@@ -43,8 +42,26 @@ pub fn capture_example(
|
||||
let git_store = project.read(cx).git_store().clone();
|
||||
|
||||
Some(cx.spawn(async move |mut cx| {
|
||||
let snapshots_by_path = collect_snapshots(&project, &git_store, &events, &mut cx).await?;
|
||||
let cursor_excerpt = cx
|
||||
let snapshots_by_path =
|
||||
collect_snapshots(&project, &git_store, worktree_id, &events, &mut cx).await?;
|
||||
|
||||
events.retain(|stored_event| {
|
||||
match stored_event.event.as_ref() {
|
||||
zeta_prompt::Event::BufferChange { path, .. } => {
|
||||
if !snapshots_by_path.contains_key(path) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
true
|
||||
});
|
||||
|
||||
let line_comment_prefix = snapshot
|
||||
.language()
|
||||
.and_then(|lang| lang.config().line_comments.first())
|
||||
.map(|s| s.to_string())
|
||||
.unwrap_or_default();
|
||||
let (cursor_excerpt, cursor_offset) = cx
|
||||
.background_executor()
|
||||
.spawn(async move { compute_cursor_excerpt(&snapshot, cursor_anchor) })
|
||||
.await;
|
||||
@@ -54,13 +71,6 @@ pub fn capture_example(
|
||||
.await;
|
||||
|
||||
let mut edit_history = String::new();
|
||||
let mut expected_patch = String::new();
|
||||
if last_event_is_expected_patch {
|
||||
if let Some(stored_event) = events.pop() {
|
||||
zeta_prompt::write_event(&mut expected_patch, &stored_event.event);
|
||||
}
|
||||
}
|
||||
|
||||
for stored_event in &events {
|
||||
zeta_prompt::write_event(&mut edit_history, &stored_event.event);
|
||||
if !edit_history.ends_with('\n') {
|
||||
@@ -68,57 +78,62 @@ pub fn capture_example(
|
||||
}
|
||||
}
|
||||
|
||||
let name = generate_timestamp_name();
|
||||
|
||||
Ok(ExampleSpec {
|
||||
name,
|
||||
let mut spec = ExampleSpec {
|
||||
name: generate_timestamp_name(),
|
||||
repository_url,
|
||||
revision,
|
||||
tags: Vec::new(),
|
||||
reasoning: None,
|
||||
uncommitted_diff,
|
||||
cursor_path: cursor_path.as_std_path().into(),
|
||||
cursor_position: cursor_excerpt,
|
||||
cursor_position: String::new(),
|
||||
edit_history,
|
||||
expected_patch,
|
||||
})
|
||||
expected_patches: Vec::new(),
|
||||
};
|
||||
spec.set_cursor_excerpt(&cursor_excerpt, cursor_offset, &line_comment_prefix);
|
||||
Ok(spec)
|
||||
}))
|
||||
}
|
||||
|
||||
fn compute_cursor_excerpt(
|
||||
snapshot: &language::BufferSnapshot,
|
||||
cursor_anchor: language::Anchor,
|
||||
) -> String {
|
||||
) -> (String, usize) {
|
||||
use text::ToOffset as _;
|
||||
|
||||
let cursor_point = cursor_anchor.to_point(snapshot);
|
||||
let (_editable_range, context_range) =
|
||||
editable_and_context_ranges_for_cursor_position(cursor_point, snapshot, 100, 50);
|
||||
|
||||
let context_start_offset = context_range.start.to_offset(snapshot);
|
||||
let cursor_offset = cursor_anchor.to_offset(snapshot);
|
||||
let cursor_offset_in_excerpt = cursor_offset.saturating_sub(context_start_offset);
|
||||
let mut excerpt = snapshot.text_for_range(context_range).collect::<String>();
|
||||
if cursor_offset_in_excerpt <= excerpt.len() {
|
||||
excerpt.insert_str(cursor_offset_in_excerpt, zeta_prompt::CURSOR_MARKER);
|
||||
}
|
||||
excerpt
|
||||
let excerpt = snapshot.text_for_range(context_range).collect::<String>();
|
||||
(excerpt, cursor_offset_in_excerpt)
|
||||
}
|
||||
|
||||
async fn collect_snapshots(
|
||||
project: &Entity<Project>,
|
||||
git_store: &Entity<project::git_store::GitStore>,
|
||||
worktree_id: WorktreeId,
|
||||
events: &[StoredEvent],
|
||||
cx: &mut gpui::AsyncApp,
|
||||
) -> Result<HashMap<Arc<Path>, (TextBufferSnapshot, BufferDiffSnapshot)>> {
|
||||
let mut snapshots_by_path = HashMap::default();
|
||||
let root_name = project.read_with(cx, |project, cx| {
|
||||
project
|
||||
.worktree_for_id(worktree_id, cx)
|
||||
.unwrap()
|
||||
.read(cx)
|
||||
.root_name()
|
||||
.to_owned()
|
||||
})?;
|
||||
for stored_event in events {
|
||||
let zeta_prompt::Event::BufferChange { path, .. } = stored_event.event.as_ref();
|
||||
if let Some((project_path, full_path)) = project.read_with(cx, |project, cx| {
|
||||
let project_path = project.find_project_path(path, cx)?;
|
||||
let full_path = project
|
||||
.worktree_for_id(project_path.worktree_id, cx)?
|
||||
.read(cx)
|
||||
.root_name()
|
||||
.join(&project_path.path)
|
||||
.as_std_path()
|
||||
.into();
|
||||
let project_path = project
|
||||
.find_project_path(path, cx)
|
||||
.filter(|path| path.worktree_id == worktree_id)?;
|
||||
let full_path = root_name.join(&project_path.path).as_std_path().into();
|
||||
Some((project_path, full_path))
|
||||
})? {
|
||||
if let hash_map::Entry::Vacant(entry) = snapshots_by_path.entry(full_path) {
|
||||
@@ -289,9 +304,7 @@ mod tests {
|
||||
cx.run_until_parked();
|
||||
|
||||
let mut example = cx
|
||||
.update(|cx| {
|
||||
capture_example(project.clone(), buffer.clone(), Anchor::MIN, false, cx).unwrap()
|
||||
})
|
||||
.update(|cx| capture_example(project.clone(), buffer.clone(), Anchor::MIN, cx).unwrap())
|
||||
.await
|
||||
.unwrap();
|
||||
example.name = "test".to_string();
|
||||
@@ -302,6 +315,8 @@ mod tests {
|
||||
name: "test".to_string(),
|
||||
repository_url: "https://github.com/test/repo.git".to_string(),
|
||||
revision: "abc123def456".to_string(),
|
||||
tags: Vec::new(),
|
||||
reasoning: None,
|
||||
uncommitted_diff: indoc! {"
|
||||
--- a/project/src/main.rs
|
||||
+++ b/project/src/main.rs
|
||||
@@ -322,7 +337,8 @@ mod tests {
|
||||
.to_string(),
|
||||
cursor_path: Path::new("project/src/main.rs").into(),
|
||||
cursor_position: indoc! {"
|
||||
<|user_cursor|>fn main() {
|
||||
fn main() {
|
||||
^[CURSOR_POSITION]
|
||||
// comment 1
|
||||
one();
|
||||
two();
|
||||
@@ -355,7 +371,7 @@ mod tests {
|
||||
seven();
|
||||
"}
|
||||
.to_string(),
|
||||
expected_patch: "".to_string(),
|
||||
expected_patches: Vec::new()
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
@@ -688,12 +688,14 @@ impl EditPredictionStore {
|
||||
pub fn clear_history(&mut self) {
|
||||
for project_state in self.projects.values_mut() {
|
||||
project_state.events.clear();
|
||||
project_state.last_event.take();
|
||||
}
|
||||
}
|
||||
|
||||
pub fn clear_history_for_project(&mut self, project: &Entity<Project>) {
|
||||
if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
|
||||
project_state.events.clear();
|
||||
project_state.last_event.take();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2044,7 +2046,9 @@ impl EditPredictionStore {
|
||||
"Edit Prediction Rated",
|
||||
rating,
|
||||
inputs = prediction.inputs,
|
||||
output = prediction.edit_preview.as_unified_diff(&prediction.edits),
|
||||
output = prediction
|
||||
.edit_preview
|
||||
.as_unified_diff(prediction.snapshot.file(), &prediction.edits),
|
||||
feedback
|
||||
);
|
||||
self.client.telemetry().flush_events().detach();
|
||||
|
||||
@@ -1,5 +1,9 @@
|
||||
use anyhow::{Context as _, Result};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::{fmt::Write as _, mem, path::Path, sync::Arc};
|
||||
use std::{borrow::Cow, fmt::Write as _, mem, path::Path, sync::Arc};
|
||||
|
||||
pub const CURSOR_POSITION_MARKER: &str = "[CURSOR_POSITION]";
|
||||
pub const INLINE_CURSOR_MARKER: &str = "<|user_cursor|>";
|
||||
|
||||
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
|
||||
pub struct ExampleSpec {
|
||||
@@ -7,33 +11,80 @@ pub struct ExampleSpec {
|
||||
pub name: String,
|
||||
pub repository_url: String,
|
||||
pub revision: String,
|
||||
#[serde(default, skip_serializing_if = "Vec::is_empty")]
|
||||
pub tags: Vec<String>,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub reasoning: Option<String>,
|
||||
#[serde(default)]
|
||||
pub uncommitted_diff: String,
|
||||
pub cursor_path: Arc<Path>,
|
||||
pub cursor_position: String,
|
||||
pub edit_history: String,
|
||||
pub expected_patch: String,
|
||||
pub expected_patches: Vec<String>,
|
||||
}
|
||||
|
||||
const REASONING_HEADING: &str = "Reasoning";
|
||||
const UNCOMMITTED_DIFF_HEADING: &str = "Uncommitted Diff";
|
||||
const EDIT_HISTORY_HEADING: &str = "Edit History";
|
||||
const CURSOR_POSITION_HEADING: &str = "Cursor Position";
|
||||
const EXPECTED_PATCH_HEADING: &str = "Expected Patch";
|
||||
const EXPECTED_CONTEXT_HEADING: &str = "Expected Context";
|
||||
const REPOSITORY_URL_FIELD: &str = "repository_url";
|
||||
const REVISION_FIELD: &str = "revision";
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
struct FrontMatter<'a> {
|
||||
repository_url: Cow<'a, str>,
|
||||
revision: Cow<'a, str>,
|
||||
#[serde(default, skip_serializing_if = "Vec::is_empty")]
|
||||
tags: Vec<String>,
|
||||
}
|
||||
|
||||
impl ExampleSpec {
|
||||
/// Generate a sanitized filename for this example.
|
||||
pub fn filename(&self) -> String {
|
||||
self.name
|
||||
.chars()
|
||||
.map(|c| match c {
|
||||
' ' | ':' | '~' | '^' | '?' | '*' | '[' | '\\' | '@' | '{' | '/' | '<' | '>'
|
||||
| '|' | '"' => '-',
|
||||
c => c,
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Format this example spec as markdown.
|
||||
pub fn to_markdown(&self) -> String {
|
||||
use std::fmt::Write as _;
|
||||
|
||||
let front_matter = FrontMatter {
|
||||
repository_url: Cow::Borrowed(&self.repository_url),
|
||||
revision: Cow::Borrowed(&self.revision),
|
||||
tags: self.tags.clone(),
|
||||
};
|
||||
let front_matter_toml =
|
||||
toml::to_string_pretty(&front_matter).unwrap_or_else(|_| String::new());
|
||||
|
||||
let mut markdown = String::new();
|
||||
|
||||
_ = writeln!(markdown, "+++");
|
||||
markdown.push_str(&front_matter_toml);
|
||||
if !markdown.ends_with('\n') {
|
||||
markdown.push('\n');
|
||||
}
|
||||
_ = writeln!(markdown, "+++");
|
||||
markdown.push('\n');
|
||||
|
||||
_ = writeln!(markdown, "# {}", self.name);
|
||||
markdown.push('\n');
|
||||
|
||||
_ = writeln!(markdown, "repository_url = {}", self.repository_url);
|
||||
_ = writeln!(markdown, "revision = {}", self.revision);
|
||||
markdown.push('\n');
|
||||
if let Some(reasoning) = &self.reasoning {
|
||||
_ = writeln!(markdown, "## {}", REASONING_HEADING);
|
||||
markdown.push('\n');
|
||||
markdown.push_str(reasoning);
|
||||
if !markdown.ends_with('\n') {
|
||||
markdown.push('\n');
|
||||
}
|
||||
markdown.push('\n');
|
||||
}
|
||||
|
||||
if !self.uncommitted_diff.is_empty() {
|
||||
_ = writeln!(markdown, "## {}", UNCOMMITTED_DIFF_HEADING);
|
||||
@@ -75,34 +126,48 @@ impl ExampleSpec {
|
||||
|
||||
_ = writeln!(markdown, "## {}", EXPECTED_PATCH_HEADING);
|
||||
markdown.push('\n');
|
||||
_ = writeln!(markdown, "```diff");
|
||||
markdown.push_str(&self.expected_patch);
|
||||
if !markdown.ends_with('\n') {
|
||||
for patch in &self.expected_patches {
|
||||
_ = writeln!(markdown, "```diff");
|
||||
markdown.push_str(patch);
|
||||
if !markdown.ends_with('\n') {
|
||||
markdown.push('\n');
|
||||
}
|
||||
_ = writeln!(markdown, "```");
|
||||
markdown.push('\n');
|
||||
}
|
||||
_ = writeln!(markdown, "```");
|
||||
markdown.push('\n');
|
||||
|
||||
markdown
|
||||
}
|
||||
|
||||
/// Parse an example spec from markdown.
|
||||
pub fn from_markdown(name: String, input: &str) -> anyhow::Result<Self> {
|
||||
pub fn from_markdown(mut input: &str) -> anyhow::Result<Self> {
|
||||
use pulldown_cmark::{CodeBlockKind, CowStr, Event, HeadingLevel, Parser, Tag, TagEnd};
|
||||
|
||||
let parser = Parser::new(input);
|
||||
|
||||
let mut spec = ExampleSpec {
|
||||
name,
|
||||
name: String::new(),
|
||||
repository_url: String::new(),
|
||||
revision: String::new(),
|
||||
tags: Vec::new(),
|
||||
reasoning: None,
|
||||
uncommitted_diff: String::new(),
|
||||
cursor_path: Path::new("").into(),
|
||||
cursor_position: String::new(),
|
||||
edit_history: String::new(),
|
||||
expected_patch: String::new(),
|
||||
expected_patches: Vec::new(),
|
||||
};
|
||||
|
||||
if let Some(rest) = input.strip_prefix("+++\n")
|
||||
&& let Some((front_matter, rest)) = rest.split_once("+++\n")
|
||||
{
|
||||
if let Ok(data) = toml::from_str::<FrontMatter<'_>>(front_matter) {
|
||||
spec.repository_url = data.repository_url.into_owned();
|
||||
spec.revision = data.revision.into_owned();
|
||||
spec.tags = data.tags;
|
||||
}
|
||||
input = rest.trim_start();
|
||||
}
|
||||
|
||||
let parser = Parser::new(input);
|
||||
let mut text = String::new();
|
||||
let mut block_info: CowStr = "".into();
|
||||
|
||||
@@ -123,20 +188,9 @@ impl ExampleSpec {
|
||||
match event {
|
||||
Event::Text(line) => {
|
||||
text.push_str(&line);
|
||||
|
||||
if let Section::Start = current_section
|
||||
&& let Some((field, value)) = line.split_once('=')
|
||||
{
|
||||
match field.trim() {
|
||||
REPOSITORY_URL_FIELD => {
|
||||
spec.repository_url = value.trim().to_string();
|
||||
}
|
||||
REVISION_FIELD => {
|
||||
spec.revision = value.trim().to_string();
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
Event::End(TagEnd::Heading(HeadingLevel::H1)) => {
|
||||
spec.name = mem::take(&mut text);
|
||||
}
|
||||
Event::End(TagEnd::Heading(HeadingLevel::H2)) => {
|
||||
let title = mem::take(&mut text);
|
||||
@@ -194,7 +248,7 @@ impl ExampleSpec {
|
||||
mem::take(&mut text);
|
||||
}
|
||||
Section::ExpectedPatch => {
|
||||
spec.expected_patch = mem::take(&mut text);
|
||||
spec.expected_patches.push(mem::take(&mut text));
|
||||
}
|
||||
Section::Start | Section::Other => {}
|
||||
}
|
||||
@@ -209,4 +263,326 @@ impl ExampleSpec {
|
||||
|
||||
Ok(spec)
|
||||
}
|
||||
|
||||
/// Returns the excerpt of text around the cursor, and the offset of the cursor within that
|
||||
/// excerpt.
|
||||
///
|
||||
/// The cursor's position is marked with a special comment that appears
|
||||
/// below the cursor line, which contains the string `[CURSOR_POSITION]`,
|
||||
/// preceded by an arrow marking the cursor's column. The arrow can be
|
||||
/// either:
|
||||
/// - `^` - The cursor column is at the position of the `^` character (pointing up to the cursor)
|
||||
/// - `<` - The cursor column is at the first non-whitespace character on that line.
|
||||
pub fn cursor_excerpt(&self) -> Result<(String, usize)> {
|
||||
let input = &self.cursor_position;
|
||||
|
||||
// Check for inline cursor marker first
|
||||
if let Some(inline_offset) = input.find(INLINE_CURSOR_MARKER) {
|
||||
let excerpt = input[..inline_offset].to_string()
|
||||
+ &input[inline_offset + INLINE_CURSOR_MARKER.len()..];
|
||||
return Ok((excerpt, inline_offset));
|
||||
}
|
||||
|
||||
let marker_offset = input
|
||||
.find(CURSOR_POSITION_MARKER)
|
||||
.context("missing [CURSOR_POSITION] marker")?;
|
||||
let marker_line_start = input[..marker_offset]
|
||||
.rfind('\n')
|
||||
.map(|pos| pos + 1)
|
||||
.unwrap_or(0);
|
||||
let marker_line_end = input[marker_line_start..]
|
||||
.find('\n')
|
||||
.map(|pos| marker_line_start + pos + 1)
|
||||
.unwrap_or(input.len());
|
||||
let marker_line = &input[marker_line_start..marker_line_end].trim_end_matches('\n');
|
||||
|
||||
let cursor_column = if let Some(cursor_offset) = marker_line.find('^') {
|
||||
cursor_offset
|
||||
} else if let Some(less_than_pos) = marker_line.find('<') {
|
||||
marker_line
|
||||
.find(|c: char| !c.is_whitespace())
|
||||
.unwrap_or(less_than_pos)
|
||||
} else {
|
||||
anyhow::bail!(
|
||||
"cursor position marker line must contain '^' or '<' before [CURSOR_POSITION]"
|
||||
);
|
||||
};
|
||||
|
||||
let mut excerpt = input[..marker_line_start].to_string() + &input[marker_line_end..];
|
||||
excerpt.truncate(excerpt.trim_end_matches('\n').len());
|
||||
|
||||
// The cursor is on the line above the marker line.
|
||||
let cursor_line_end = marker_line_start.saturating_sub(1);
|
||||
let cursor_line_start = excerpt[..cursor_line_end]
|
||||
.rfind('\n')
|
||||
.map(|pos| pos + 1)
|
||||
.unwrap_or(0);
|
||||
let cursor_offset = cursor_line_start + cursor_column;
|
||||
|
||||
Ok((excerpt, cursor_offset))
|
||||
}
|
||||
|
||||
/// Sets the cursor position excerpt from a plain excerpt and cursor byte offset.
|
||||
///
|
||||
/// The `line_comment_prefix` is used to format the marker line as a comment.
|
||||
/// If the cursor column is less than the comment prefix length, the `<` format is used.
|
||||
/// Otherwise, the `^` format is used.
|
||||
pub fn set_cursor_excerpt(
|
||||
&mut self,
|
||||
excerpt: &str,
|
||||
cursor_offset: usize,
|
||||
line_comment_prefix: &str,
|
||||
) {
|
||||
// Find which line the cursor is on and its column
|
||||
let cursor_line_start = excerpt[..cursor_offset]
|
||||
.rfind('\n')
|
||||
.map(|pos| pos + 1)
|
||||
.unwrap_or(0);
|
||||
let cursor_line_end = excerpt[cursor_line_start..]
|
||||
.find('\n')
|
||||
.map(|pos| cursor_line_start + pos + 1)
|
||||
.unwrap_or(excerpt.len());
|
||||
let cursor_line = &excerpt[cursor_line_start..cursor_line_end];
|
||||
let cursor_line_indent = &cursor_line[..cursor_line.len() - cursor_line.trim_start().len()];
|
||||
let cursor_column = cursor_offset - cursor_line_start;
|
||||
|
||||
// Build the marker line
|
||||
let mut marker_line = String::new();
|
||||
if cursor_column < line_comment_prefix.len() {
|
||||
for _ in 0..cursor_column {
|
||||
marker_line.push(' ');
|
||||
}
|
||||
marker_line.push_str(line_comment_prefix);
|
||||
write!(marker_line, " <{}", CURSOR_POSITION_MARKER).unwrap();
|
||||
} else {
|
||||
if cursor_column >= cursor_line_indent.len() + line_comment_prefix.len() {
|
||||
marker_line.push_str(cursor_line_indent);
|
||||
}
|
||||
marker_line.push_str(line_comment_prefix);
|
||||
while marker_line.len() < cursor_column {
|
||||
marker_line.push(' ');
|
||||
}
|
||||
write!(marker_line, "^{}", CURSOR_POSITION_MARKER).unwrap();
|
||||
}
|
||||
|
||||
// Build the final cursor_position string
|
||||
let mut result = String::with_capacity(excerpt.len() + marker_line.len() + 2);
|
||||
result.push_str(&excerpt[..cursor_line_end]);
|
||||
if !result.ends_with('\n') {
|
||||
result.push('\n');
|
||||
}
|
||||
result.push_str(&marker_line);
|
||||
if cursor_line_end < excerpt.len() {
|
||||
result.push('\n');
|
||||
result.push_str(&excerpt[cursor_line_end..]);
|
||||
}
|
||||
|
||||
self.cursor_position = result;
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use indoc::indoc;
|
||||
|
||||
#[test]
|
||||
fn test_cursor_excerpt_with_caret() {
|
||||
let mut spec = ExampleSpec {
|
||||
name: String::new(),
|
||||
repository_url: String::new(),
|
||||
revision: String::new(),
|
||||
tags: Vec::new(),
|
||||
reasoning: None,
|
||||
uncommitted_diff: String::new(),
|
||||
cursor_path: Path::new("test.rs").into(),
|
||||
cursor_position: String::new(),
|
||||
edit_history: String::new(),
|
||||
expected_patches: Vec::new(),
|
||||
};
|
||||
|
||||
// Cursor before `42`
|
||||
let excerpt = indoc! {"
|
||||
fn main() {
|
||||
let x = 42;
|
||||
println!(\"{}\", x);
|
||||
}"
|
||||
};
|
||||
let offset = excerpt.find("42").unwrap();
|
||||
let position_string = indoc! {"
|
||||
fn main() {
|
||||
let x = 42;
|
||||
// ^[CURSOR_POSITION]
|
||||
println!(\"{}\", x);
|
||||
}"
|
||||
}
|
||||
.to_string();
|
||||
|
||||
spec.set_cursor_excerpt(excerpt, offset, "//");
|
||||
assert_eq!(spec.cursor_position, position_string);
|
||||
assert_eq!(
|
||||
spec.cursor_excerpt().unwrap(),
|
||||
(excerpt.to_string(), offset)
|
||||
);
|
||||
|
||||
// Cursor after `l` in `let`
|
||||
let offset = excerpt.find("et x").unwrap();
|
||||
let position_string = indoc! {"
|
||||
fn main() {
|
||||
let x = 42;
|
||||
// ^[CURSOR_POSITION]
|
||||
println!(\"{}\", x);
|
||||
}"
|
||||
}
|
||||
.to_string();
|
||||
|
||||
spec.set_cursor_excerpt(excerpt, offset, "//");
|
||||
assert_eq!(spec.cursor_position, position_string);
|
||||
assert_eq!(
|
||||
spec.cursor_excerpt().unwrap(),
|
||||
(excerpt.to_string(), offset)
|
||||
);
|
||||
|
||||
// Cursor before `let`
|
||||
let offset = excerpt.find("let").unwrap();
|
||||
let position_string = indoc! {"
|
||||
fn main() {
|
||||
let x = 42;
|
||||
// ^[CURSOR_POSITION]
|
||||
println!(\"{}\", x);
|
||||
}"
|
||||
}
|
||||
.to_string();
|
||||
|
||||
spec.set_cursor_excerpt(excerpt, offset, "//");
|
||||
assert_eq!(spec.cursor_position, position_string);
|
||||
assert_eq!(
|
||||
spec.cursor_excerpt().unwrap(),
|
||||
(excerpt.to_string(), offset)
|
||||
);
|
||||
|
||||
// Cursor at beginning of the line with `let`
|
||||
let offset = excerpt.find(" let").unwrap();
|
||||
let position_string = indoc! {"
|
||||
fn main() {
|
||||
let x = 42;
|
||||
// <[CURSOR_POSITION]
|
||||
println!(\"{}\", x);
|
||||
}"
|
||||
}
|
||||
.to_string();
|
||||
|
||||
spec.set_cursor_excerpt(excerpt, offset, "//");
|
||||
assert_eq!(spec.cursor_position, position_string);
|
||||
assert_eq!(
|
||||
spec.cursor_excerpt().unwrap(),
|
||||
(excerpt.to_string(), offset)
|
||||
);
|
||||
|
||||
// Cursor at end of line, after the semicolon
|
||||
let offset = excerpt.find(';').unwrap() + 1;
|
||||
let position_string = indoc! {"
|
||||
fn main() {
|
||||
let x = 42;
|
||||
// ^[CURSOR_POSITION]
|
||||
println!(\"{}\", x);
|
||||
}"
|
||||
}
|
||||
.to_string();
|
||||
|
||||
spec.set_cursor_excerpt(excerpt, offset, "//");
|
||||
assert_eq!(spec.cursor_position, position_string);
|
||||
assert_eq!(
|
||||
spec.cursor_excerpt().unwrap(),
|
||||
(excerpt.to_string(), offset)
|
||||
);
|
||||
|
||||
// Caret at end of file (no trailing newline)
|
||||
let excerpt = indoc! {"
|
||||
fn main() {
|
||||
let x = 42;"
|
||||
};
|
||||
let offset = excerpt.find(';').unwrap() + 1;
|
||||
let position_string = indoc! {"
|
||||
fn main() {
|
||||
let x = 42;
|
||||
// ^[CURSOR_POSITION]"
|
||||
}
|
||||
.to_string();
|
||||
|
||||
spec.set_cursor_excerpt(excerpt, offset, "//");
|
||||
assert_eq!(spec.cursor_position, position_string);
|
||||
assert_eq!(
|
||||
spec.cursor_excerpt().unwrap(),
|
||||
(excerpt.to_string(), offset)
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_cursor_excerpt_with_inline_marker() {
|
||||
let mut spec = ExampleSpec {
|
||||
name: String::new(),
|
||||
repository_url: String::new(),
|
||||
revision: String::new(),
|
||||
tags: Vec::new(),
|
||||
reasoning: None,
|
||||
uncommitted_diff: String::new(),
|
||||
cursor_path: Path::new("test.rs").into(),
|
||||
cursor_position: String::new(),
|
||||
edit_history: String::new(),
|
||||
expected_patches: Vec::new(),
|
||||
};
|
||||
|
||||
// Cursor before `42` using inline marker
|
||||
spec.cursor_position = indoc! {"
|
||||
fn main() {
|
||||
let x = <|user_cursor|>42;
|
||||
println!(\"{}\", x);
|
||||
}"
|
||||
}
|
||||
.to_string();
|
||||
|
||||
let expected_excerpt = indoc! {"
|
||||
fn main() {
|
||||
let x = 42;
|
||||
println!(\"{}\", x);
|
||||
}"
|
||||
};
|
||||
let expected_offset = expected_excerpt.find("42").unwrap();
|
||||
|
||||
assert_eq!(
|
||||
spec.cursor_excerpt().unwrap(),
|
||||
(expected_excerpt.to_string(), expected_offset)
|
||||
);
|
||||
|
||||
// Cursor at beginning of line
|
||||
spec.cursor_position = indoc! {"
|
||||
fn main() {
|
||||
<|user_cursor|> let x = 42;
|
||||
}"
|
||||
}
|
||||
.to_string();
|
||||
|
||||
let expected_excerpt = indoc! {"
|
||||
fn main() {
|
||||
let x = 42;
|
||||
}"
|
||||
};
|
||||
let expected_offset = expected_excerpt.find(" let").unwrap();
|
||||
|
||||
assert_eq!(
|
||||
spec.cursor_excerpt().unwrap(),
|
||||
(expected_excerpt.to_string(), expected_offset)
|
||||
);
|
||||
|
||||
// Cursor at end of file
|
||||
spec.cursor_position = "fn main() {}<|user_cursor|>".to_string();
|
||||
let expected_excerpt = "fn main() {}";
|
||||
let expected_offset = expected_excerpt.len();
|
||||
|
||||
assert_eq!(
|
||||
spec.cursor_excerpt().unwrap(),
|
||||
(expected_excerpt.to_string(), expected_offset)
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -14,10 +14,8 @@ use anyhow::anyhow;
|
||||
use collections::HashMap;
|
||||
use gpui::AsyncApp;
|
||||
use gpui::Entity;
|
||||
use language::{Anchor, Buffer, OffsetRangeExt as _, TextBufferSnapshot};
|
||||
use project::{Project, ProjectPath};
|
||||
use util::paths::PathStyle;
|
||||
use util::rel_path::RelPath;
|
||||
use language::{Anchor, Buffer, OffsetRangeExt as _, TextBufferSnapshot, text_diff};
|
||||
use project::Project;
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct OpenedBuffers(#[allow(unused)] HashMap<String, Entity<Buffer>>);
|
||||
@@ -30,54 +28,26 @@ pub async fn apply_diff(
|
||||
) -> Result<OpenedBuffers> {
|
||||
let mut included_files = HashMap::default();
|
||||
|
||||
let worktree_id = project.read_with(cx, |project, cx| {
|
||||
anyhow::Ok(
|
||||
project
|
||||
.visible_worktrees(cx)
|
||||
.next()
|
||||
.context("no worktrees")?
|
||||
.read(cx)
|
||||
.id(),
|
||||
)
|
||||
})??;
|
||||
|
||||
for line in diff_str.lines() {
|
||||
let diff_line = DiffLine::parse(line);
|
||||
|
||||
if let DiffLine::OldPath { path } = diff_line {
|
||||
let buffer = project
|
||||
.update(cx, |project, cx| {
|
||||
let project_path = ProjectPath {
|
||||
worktree_id,
|
||||
path: RelPath::new(Path::new(path.as_ref()), PathStyle::Posix)?.into_arc(),
|
||||
};
|
||||
anyhow::Ok(project.open_buffer(project_path, cx))
|
||||
})??
|
||||
.await?;
|
||||
|
||||
included_files.insert(path.to_string(), buffer);
|
||||
}
|
||||
}
|
||||
|
||||
let ranges = [Anchor::MIN..Anchor::MAX];
|
||||
|
||||
let mut diff = DiffParser::new(diff_str);
|
||||
let mut current_file = None;
|
||||
let mut edits = vec![];
|
||||
|
||||
while let Some(event) = diff.next()? {
|
||||
match event {
|
||||
DiffEvent::Hunk {
|
||||
path: file_path,
|
||||
hunk,
|
||||
} => {
|
||||
let (buffer, ranges) = match current_file {
|
||||
DiffEvent::Hunk { path, hunk } => {
|
||||
let buffer = match current_file {
|
||||
None => {
|
||||
let buffer = included_files
|
||||
.get_mut(file_path.as_ref())
|
||||
.expect("Opened all files in diff");
|
||||
|
||||
current_file = Some((buffer, ranges.as_slice()));
|
||||
let buffer = project
|
||||
.update(cx, |project, cx| {
|
||||
let project_path = project
|
||||
.find_project_path(path.as_ref(), cx)
|
||||
.context("no such path")?;
|
||||
anyhow::Ok(project.open_buffer(project_path, cx))
|
||||
})??
|
||||
.await?;
|
||||
included_files.insert(path.to_string(), buffer.clone());
|
||||
current_file = Some(buffer);
|
||||
current_file.as_ref().unwrap()
|
||||
}
|
||||
Some(ref current) => current,
|
||||
@@ -85,14 +55,14 @@ pub async fn apply_diff(
|
||||
|
||||
buffer.read_with(cx, |buffer, _| {
|
||||
edits.extend(
|
||||
resolve_hunk_edits_in_buffer(hunk, buffer, ranges)
|
||||
resolve_hunk_edits_in_buffer(hunk, buffer, ranges.as_slice())
|
||||
.with_context(|| format!("Diff:\n{diff_str}"))?,
|
||||
);
|
||||
anyhow::Ok(())
|
||||
})??;
|
||||
}
|
||||
DiffEvent::FileEnd { renamed_to } => {
|
||||
let (buffer, _) = current_file
|
||||
let buffer = current_file
|
||||
.take()
|
||||
.context("Got a FileEnd event before an Hunk event")?;
|
||||
|
||||
@@ -128,10 +98,69 @@ pub async fn apply_diff(
|
||||
Ok(OpenedBuffers(included_files))
|
||||
}
|
||||
|
||||
pub fn apply_diff_to_string(diff_str: &str, text: &str) -> Result<String> {
|
||||
/// Extract the diff for a specific file from a multi-file diff.
|
||||
/// Returns an error if the file is not found in the diff.
|
||||
pub fn extract_file_diff(full_diff: &str, file_path: &str) -> Result<String> {
|
||||
let mut result = String::new();
|
||||
let mut in_target_file = false;
|
||||
let mut found_file = false;
|
||||
|
||||
for line in full_diff.lines() {
|
||||
if line.starts_with("diff --git") {
|
||||
if in_target_file {
|
||||
break;
|
||||
}
|
||||
in_target_file = line.contains(&format!("a/{}", file_path))
|
||||
|| line.contains(&format!("b/{}", file_path));
|
||||
if in_target_file {
|
||||
found_file = true;
|
||||
}
|
||||
}
|
||||
|
||||
if in_target_file {
|
||||
result.push_str(line);
|
||||
result.push('\n');
|
||||
}
|
||||
}
|
||||
|
||||
if !found_file {
|
||||
anyhow::bail!("File '{}' not found in diff", file_path);
|
||||
}
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
/// Strip unnecessary git metadata lines from a diff, keeping only the lines
|
||||
/// needed for patch application: path headers (--- and +++), hunk headers (@@),
|
||||
/// and content lines (+, -, space).
|
||||
pub fn strip_diff_metadata(diff: &str) -> String {
|
||||
let mut result = String::new();
|
||||
|
||||
for line in diff.lines() {
|
||||
let dominated = DiffLine::parse(line);
|
||||
match dominated {
|
||||
// Keep path headers, hunk headers, and content lines
|
||||
DiffLine::OldPath { .. }
|
||||
| DiffLine::NewPath { .. }
|
||||
| DiffLine::HunkHeader(_)
|
||||
| DiffLine::Context(_)
|
||||
| DiffLine::Deletion(_)
|
||||
| DiffLine::Addition(_) => {
|
||||
result.push_str(line);
|
||||
result.push('\n');
|
||||
}
|
||||
// Skip garbage lines (diff --git, index, etc.)
|
||||
DiffLine::Garbage(_) => {}
|
||||
}
|
||||
}
|
||||
|
||||
result
|
||||
}
|
||||
|
||||
pub fn apply_diff_to_string(original: &str, diff_str: &str) -> Result<String> {
|
||||
let mut diff = DiffParser::new(diff_str);
|
||||
|
||||
let mut text = text.to_string();
|
||||
let mut text = original.to_string();
|
||||
|
||||
while let Some(event) = diff.next()? {
|
||||
match event {
|
||||
@@ -151,6 +180,51 @@ pub fn apply_diff_to_string(diff_str: &str, text: &str) -> Result<String> {
|
||||
Ok(text)
|
||||
}
|
||||
|
||||
/// Returns the individual edits that would be applied by a diff to the given content.
|
||||
/// Each edit is a tuple of (byte_range_in_content, replacement_text).
|
||||
/// Uses sub-line diffing to find the precise character positions of changes.
|
||||
/// Returns an empty vec if the hunk context is not found or is ambiguous.
|
||||
pub fn edits_for_diff(content: &str, diff_str: &str) -> Result<Vec<(Range<usize>, String)>> {
|
||||
let mut diff = DiffParser::new(diff_str);
|
||||
let mut result = Vec::new();
|
||||
|
||||
while let Some(event) = diff.next()? {
|
||||
match event {
|
||||
DiffEvent::Hunk { hunk, .. } => {
|
||||
if hunk.context.is_empty() {
|
||||
return Ok(Vec::new());
|
||||
}
|
||||
|
||||
// Find the context in the content
|
||||
let first_match = content.find(&hunk.context);
|
||||
let Some(context_offset) = first_match else {
|
||||
return Ok(Vec::new());
|
||||
};
|
||||
|
||||
// Check for ambiguity - if context appears more than once, reject
|
||||
if content[context_offset + 1..].contains(&hunk.context) {
|
||||
return Ok(Vec::new());
|
||||
}
|
||||
|
||||
// Use sub-line diffing to find precise edit positions
|
||||
for edit in &hunk.edits {
|
||||
let old_text = &content
|
||||
[context_offset + edit.range.start..context_offset + edit.range.end];
|
||||
let edits_within_hunk = text_diff(old_text, &edit.text);
|
||||
for (inner_range, inner_text) in edits_within_hunk {
|
||||
let absolute_start = context_offset + edit.range.start + inner_range.start;
|
||||
let absolute_end = context_offset + edit.range.start + inner_range.end;
|
||||
result.push((absolute_start..absolute_end, inner_text.to_string()));
|
||||
}
|
||||
}
|
||||
}
|
||||
DiffEvent::FileEnd { .. } => {}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
struct PatchFile<'a> {
|
||||
old_path: Cow<'a, str>,
|
||||
new_path: Cow<'a, str>,
|
||||
@@ -873,4 +947,135 @@ mod tests {
|
||||
|
||||
FakeFs::new(cx.background_executor.clone())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_extract_file_diff() {
|
||||
let multi_file_diff = indoc! {r#"
|
||||
diff --git a/file1.txt b/file1.txt
|
||||
index 1234567..abcdefg 100644
|
||||
--- a/file1.txt
|
||||
+++ b/file1.txt
|
||||
@@ -1,3 +1,4 @@
|
||||
line1
|
||||
+added line
|
||||
line2
|
||||
line3
|
||||
diff --git a/file2.txt b/file2.txt
|
||||
index 2345678..bcdefgh 100644
|
||||
--- a/file2.txt
|
||||
+++ b/file2.txt
|
||||
@@ -1,2 +1,2 @@
|
||||
-old line
|
||||
+new line
|
||||
unchanged
|
||||
"#};
|
||||
|
||||
let file1_diff = extract_file_diff(multi_file_diff, "file1.txt").unwrap();
|
||||
assert_eq!(
|
||||
file1_diff,
|
||||
indoc! {r#"
|
||||
diff --git a/file1.txt b/file1.txt
|
||||
index 1234567..abcdefg 100644
|
||||
--- a/file1.txt
|
||||
+++ b/file1.txt
|
||||
@@ -1,3 +1,4 @@
|
||||
line1
|
||||
+added line
|
||||
line2
|
||||
line3
|
||||
"#}
|
||||
);
|
||||
|
||||
let file2_diff = extract_file_diff(multi_file_diff, "file2.txt").unwrap();
|
||||
assert_eq!(
|
||||
file2_diff,
|
||||
indoc! {r#"
|
||||
diff --git a/file2.txt b/file2.txt
|
||||
index 2345678..bcdefgh 100644
|
||||
--- a/file2.txt
|
||||
+++ b/file2.txt
|
||||
@@ -1,2 +1,2 @@
|
||||
-old line
|
||||
+new line
|
||||
unchanged
|
||||
"#}
|
||||
);
|
||||
|
||||
let result = extract_file_diff(multi_file_diff, "nonexistent.txt");
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_edits_for_diff() {
|
||||
let content = indoc! {"
|
||||
fn main() {
|
||||
let x = 1;
|
||||
let y = 2;
|
||||
println!(\"{} {}\", x, y);
|
||||
}
|
||||
"};
|
||||
|
||||
let diff = indoc! {"
|
||||
--- a/file.rs
|
||||
+++ b/file.rs
|
||||
@@ -1,5 +1,5 @@
|
||||
fn main() {
|
||||
- let x = 1;
|
||||
+ let x = 42;
|
||||
let y = 2;
|
||||
println!(\"{} {}\", x, y);
|
||||
}
|
||||
"};
|
||||
|
||||
let edits = edits_for_diff(content, diff).unwrap();
|
||||
assert_eq!(edits.len(), 1);
|
||||
|
||||
let (range, replacement) = &edits[0];
|
||||
// With sub-line diffing, the edit should start at "1" (the actual changed character)
|
||||
let expected_start = content.find("let x = 1;").unwrap() + "let x = ".len();
|
||||
assert_eq!(range.start, expected_start);
|
||||
// The deleted text is just "1"
|
||||
assert_eq!(range.end, expected_start + "1".len());
|
||||
// The replacement text
|
||||
assert_eq!(replacement, "42");
|
||||
|
||||
// Verify the cursor would be positioned at the column of "1"
|
||||
let line_start = content[..range.start]
|
||||
.rfind('\n')
|
||||
.map(|p| p + 1)
|
||||
.unwrap_or(0);
|
||||
let cursor_column = range.start - line_start;
|
||||
// " let x = " is 12 characters, so column 12
|
||||
assert_eq!(cursor_column, " let x = ".len());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_strip_diff_metadata() {
|
||||
let diff_with_metadata = indoc! {r#"
|
||||
diff --git a/file.txt b/file.txt
|
||||
index 1234567..abcdefg 100644
|
||||
--- a/file.txt
|
||||
+++ b/file.txt
|
||||
@@ -1,3 +1,4 @@
|
||||
context line
|
||||
-removed line
|
||||
+added line
|
||||
more context
|
||||
"#};
|
||||
|
||||
let stripped = strip_diff_metadata(diff_with_metadata);
|
||||
|
||||
assert_eq!(
|
||||
stripped,
|
||||
indoc! {r#"
|
||||
--- a/file.txt
|
||||
+++ b/file.txt
|
||||
@@ -1,3 +1,4 @@
|
||||
context line
|
||||
-removed line
|
||||
+added line
|
||||
more context
|
||||
"#}
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
use anthropic::{
|
||||
ANTHROPIC_API_URL, Message, Request as AnthropicRequest, RequestContent,
|
||||
Response as AnthropicResponse, Role, non_streaming_completion,
|
||||
ANTHROPIC_API_URL, Event, Message, Request as AnthropicRequest, RequestContent,
|
||||
Response as AnthropicResponse, ResponseContent, Role, non_streaming_completion,
|
||||
stream_completion,
|
||||
};
|
||||
use anyhow::Result;
|
||||
use futures::StreamExt as _;
|
||||
use http_client::HttpClient;
|
||||
use indoc::indoc;
|
||||
use reqwest_client::ReqwestClient;
|
||||
@@ -15,12 +17,12 @@ use std::path::Path;
|
||||
use std::sync::Arc;
|
||||
|
||||
pub struct PlainLlmClient {
|
||||
http_client: Arc<dyn HttpClient>,
|
||||
api_key: String,
|
||||
pub http_client: Arc<dyn HttpClient>,
|
||||
pub api_key: String,
|
||||
}
|
||||
|
||||
impl PlainLlmClient {
|
||||
fn new() -> Result<Self> {
|
||||
pub fn new() -> Result<Self> {
|
||||
let http_client: Arc<dyn http_client::HttpClient> = Arc::new(ReqwestClient::new());
|
||||
let api_key = std::env::var("ANTHROPIC_API_KEY")
|
||||
.map_err(|_| anyhow::anyhow!("ANTHROPIC_API_KEY environment variable not set"))?;
|
||||
@@ -30,7 +32,7 @@ impl PlainLlmClient {
|
||||
})
|
||||
}
|
||||
|
||||
async fn generate(
|
||||
pub async fn generate(
|
||||
&self,
|
||||
model: &str,
|
||||
max_tokens: u64,
|
||||
@@ -63,6 +65,72 @@ impl PlainLlmClient {
|
||||
|
||||
Ok(response)
|
||||
}
|
||||
|
||||
pub async fn generate_streaming<F>(
|
||||
&self,
|
||||
model: &str,
|
||||
max_tokens: u64,
|
||||
messages: Vec<Message>,
|
||||
mut on_progress: F,
|
||||
) -> Result<AnthropicResponse>
|
||||
where
|
||||
F: FnMut(usize, &str),
|
||||
{
|
||||
let request = AnthropicRequest {
|
||||
model: model.to_string(),
|
||||
max_tokens,
|
||||
messages,
|
||||
tools: Vec::new(),
|
||||
thinking: None,
|
||||
tool_choice: None,
|
||||
system: None,
|
||||
metadata: None,
|
||||
stop_sequences: Vec::new(),
|
||||
temperature: None,
|
||||
top_k: None,
|
||||
top_p: None,
|
||||
};
|
||||
|
||||
let mut stream = stream_completion(
|
||||
self.http_client.as_ref(),
|
||||
ANTHROPIC_API_URL,
|
||||
&self.api_key,
|
||||
request,
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.map_err(|e| anyhow::anyhow!("{:?}", e))?;
|
||||
|
||||
let mut response: Option<AnthropicResponse> = None;
|
||||
let mut text_content = String::new();
|
||||
|
||||
while let Some(event_result) = stream.next().await {
|
||||
let event = event_result.map_err(|e| anyhow::anyhow!("{:?}", e))?;
|
||||
|
||||
match event {
|
||||
Event::MessageStart { message } => {
|
||||
response = Some(message);
|
||||
}
|
||||
Event::ContentBlockDelta { delta, .. } => {
|
||||
if let anthropic::ContentDelta::TextDelta { text } = delta {
|
||||
text_content.push_str(&text);
|
||||
on_progress(text_content.len(), &text_content);
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
let mut response = response.ok_or_else(|| anyhow::anyhow!("No response received"))?;
|
||||
|
||||
if response.content.is_empty() && !text_content.is_empty() {
|
||||
response
|
||||
.content
|
||||
.push(ResponseContent::Text { text: text_content });
|
||||
}
|
||||
|
||||
Ok(response)
|
||||
}
|
||||
}
|
||||
|
||||
pub struct BatchingLlmClient {
|
||||
@@ -408,6 +476,29 @@ impl AnthropicClient {
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub async fn generate_streaming<F>(
|
||||
&self,
|
||||
model: &str,
|
||||
max_tokens: u64,
|
||||
messages: Vec<Message>,
|
||||
on_progress: F,
|
||||
) -> Result<Option<AnthropicResponse>>
|
||||
where
|
||||
F: FnMut(usize, &str),
|
||||
{
|
||||
match self {
|
||||
AnthropicClient::Plain(plain_llm_client) => plain_llm_client
|
||||
.generate_streaming(model, max_tokens, messages, on_progress)
|
||||
.await
|
||||
.map(Some),
|
||||
AnthropicClient::Batch(_) => {
|
||||
anyhow::bail!("Streaming not supported with batching client")
|
||||
}
|
||||
AnthropicClient::Dummy => panic!("Dummy LLM client is not expected to be used"),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn sync_batches(&self) -> Result<()> {
|
||||
match self {
|
||||
AnthropicClient::Plain(_) => Ok(()),
|
||||
|
||||
@@ -1,20 +1,15 @@
|
||||
use anyhow::{Result, anyhow};
|
||||
use anyhow::Result;
|
||||
use std::mem;
|
||||
|
||||
use crate::example::Example;
|
||||
|
||||
pub async fn run_distill(example: &mut Example) -> Result<()> {
|
||||
let [prediction]: [_; 1] =
|
||||
mem::take(&mut example.predictions)
|
||||
.try_into()
|
||||
.map_err(|preds: Vec<_>| {
|
||||
anyhow!(
|
||||
"Example has {} predictions, but it should have exactly one",
|
||||
preds.len()
|
||||
)
|
||||
})?;
|
||||
let predictions = mem::take(&mut example.predictions)
|
||||
.into_iter()
|
||||
.map(|p| p.actual_patch)
|
||||
.collect();
|
||||
|
||||
example.spec.expected_patch = prediction.actual_patch;
|
||||
example.spec.expected_patches = predictions;
|
||||
example.prompt = None;
|
||||
example.predictions = Vec::new();
|
||||
example.score = Vec::new();
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
use crate::{PredictionProvider, PromptFormat, metrics::ClassificationMetrics};
|
||||
use crate::{PredictionProvider, PromptFormat};
|
||||
use anyhow::{Context as _, Result};
|
||||
use collections::HashMap;
|
||||
use edit_prediction::example_spec::ExampleSpec;
|
||||
@@ -87,7 +87,6 @@ pub struct ExamplePrediction {
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct ExampleScore {
|
||||
pub delta_chr_f: f32,
|
||||
pub line_match: ClassificationMetrics,
|
||||
}
|
||||
|
||||
impl Example {
|
||||
@@ -190,7 +189,11 @@ pub fn read_examples(inputs: &[PathBuf]) -> Vec<Example> {
|
||||
.collect::<Vec<Example>>(),
|
||||
),
|
||||
"md" => {
|
||||
examples.push(parse_markdown_example(filename, &content).unwrap());
|
||||
let mut example = parse_markdown_example(&content).unwrap();
|
||||
if example.spec.name.is_empty() {
|
||||
example.spec.name = filename;
|
||||
}
|
||||
examples.push(example);
|
||||
}
|
||||
ext => {
|
||||
panic!("{} has invalid example extension `{ext}`", path.display())
|
||||
@@ -236,8 +239,8 @@ pub fn group_examples_by_repo(examples: &mut [Example]) -> Vec<Vec<&mut Example>
|
||||
examples_by_repo.into_values().collect()
|
||||
}
|
||||
|
||||
fn parse_markdown_example(name: String, input: &str) -> Result<Example> {
|
||||
let spec = ExampleSpec::from_markdown(name, input)?;
|
||||
fn parse_markdown_example(input: &str) -> Result<Example> {
|
||||
let spec = ExampleSpec::from_markdown(input)?;
|
||||
Ok(Example {
|
||||
spec,
|
||||
buffer: None,
|
||||
|
||||
@@ -30,7 +30,12 @@ pub async fn run_format_prompt(
|
||||
let prompt = TeacherPrompt::format_prompt(example);
|
||||
example.prompt = Some(ExamplePrompt {
|
||||
input: prompt,
|
||||
expected_output: example.spec.expected_patch.clone(), // TODO
|
||||
expected_output: example
|
||||
.spec
|
||||
.expected_patches
|
||||
.first()
|
||||
.cloned()
|
||||
.unwrap_or_default(),
|
||||
format: prompt_format,
|
||||
});
|
||||
}
|
||||
@@ -68,8 +73,15 @@ pub async fn run_format_prompt(
|
||||
))
|
||||
})??;
|
||||
let prompt = format_zeta_prompt(&input);
|
||||
let expected_output =
|
||||
zeta2_output_for_patch(&input, &example.spec.expected_patch.clone())?;
|
||||
let expected_output = zeta2_output_for_patch(
|
||||
&input,
|
||||
&example
|
||||
.spec
|
||||
.expected_patches
|
||||
.first()
|
||||
.context("expected patches is empty")?
|
||||
.clone(),
|
||||
)?;
|
||||
example.prompt = Some(ExamplePrompt {
|
||||
input: prompt,
|
||||
expected_output,
|
||||
@@ -86,6 +98,7 @@ impl TeacherPrompt {
|
||||
const PROMPT: &str = include_str!("teacher.prompt.md");
|
||||
pub(crate) const EDITABLE_REGION_START: &str = "<|editable_region_start|>\n";
|
||||
pub(crate) const EDITABLE_REGION_END: &str = "<|editable_region_end|>";
|
||||
pub(crate) const USER_CURSOR_MARKER: &str = "<|user_cursor|>";
|
||||
|
||||
/// Truncate edit history to this number of last lines
|
||||
const MAX_HISTORY_LINES: usize = 128;
|
||||
@@ -181,13 +194,15 @@ impl TeacherPrompt {
|
||||
result.push_str(Self::EDITABLE_REGION_START);
|
||||
|
||||
// TODO: control number of lines around cursor
|
||||
result.push_str(&example.spec.cursor_position);
|
||||
if !example.spec.cursor_position.ends_with('\n') {
|
||||
let (mut excerpt, offset) = example.spec.cursor_excerpt().unwrap();
|
||||
excerpt.insert_str(offset, Self::USER_CURSOR_MARKER);
|
||||
result.push_str(&excerpt);
|
||||
if !result.ends_with('\n') {
|
||||
result.push('\n');
|
||||
}
|
||||
|
||||
result.push_str(&format!("{}\n", Self::EDITABLE_REGION_END));
|
||||
result.push_str("`````");
|
||||
result.push_str(Self::EDITABLE_REGION_END);
|
||||
result.push_str("\n`````");
|
||||
|
||||
result
|
||||
}
|
||||
|
||||
110
crates/edit_prediction_cli/src/git.rs
Normal file
110
crates/edit_prediction_cli/src/git.rs
Normal file
@@ -0,0 +1,110 @@
|
||||
use anyhow::{Context as _, Result};
|
||||
use collections::HashMap;
|
||||
use futures::lock::{Mutex, OwnedMutexGuard};
|
||||
use std::{
|
||||
cell::RefCell,
|
||||
path::{Path, PathBuf},
|
||||
sync::Arc,
|
||||
};
|
||||
|
||||
use crate::paths::REPOS_DIR;
|
||||
|
||||
thread_local! {
|
||||
static REPO_LOCKS: RefCell<HashMap<PathBuf, Arc<Mutex<()>>>> = RefCell::new(HashMap::default());
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub async fn lock_repo(path: impl AsRef<Path>) -> OwnedMutexGuard<()> {
|
||||
REPO_LOCKS
|
||||
.with(|cell| {
|
||||
cell.borrow_mut()
|
||||
.entry(path.as_ref().to_path_buf())
|
||||
.or_default()
|
||||
.clone()
|
||||
})
|
||||
.lock_owned()
|
||||
.await
|
||||
}
|
||||
|
||||
pub async fn run_git(repo_path: &Path, args: &[&str]) -> Result<String> {
|
||||
let output = smol::process::Command::new("git")
|
||||
.current_dir(repo_path)
|
||||
.args(args)
|
||||
.output()
|
||||
.await?;
|
||||
|
||||
anyhow::ensure!(
|
||||
output.status.success(),
|
||||
"`git {}` within `{}` failed with status: {}\nstderr:\n{}\nstdout:\n{}",
|
||||
args.join(" "),
|
||||
repo_path.display(),
|
||||
output.status,
|
||||
String::from_utf8_lossy(&output.stderr),
|
||||
String::from_utf8_lossy(&output.stdout),
|
||||
);
|
||||
Ok(String::from_utf8(output.stdout)?.trim().to_string())
|
||||
}
|
||||
|
||||
pub fn parse_repo_url(url: &str) -> Result<(String, String)> {
|
||||
if url.contains('@') {
|
||||
let (_, path) = url.split_once(':').context("expected : in git url")?;
|
||||
let (owner, repo) = path.split_once('/').context("expected / in git url")?;
|
||||
Ok((owner.to_string(), repo.trim_end_matches(".git").to_string()))
|
||||
} else {
|
||||
let parsed = http_client::Url::parse(url)?;
|
||||
let mut segments = parsed.path_segments().context("empty http url")?;
|
||||
let owner = segments.next().context("expected owner")?;
|
||||
let repo = segments.next().context("expected repo")?;
|
||||
Ok((owner.to_string(), repo.trim_end_matches(".git").to_string()))
|
||||
}
|
||||
}
|
||||
|
||||
pub fn repo_path_for_url(url: &str) -> Result<PathBuf> {
|
||||
let (owner, name) = parse_repo_url(url)?;
|
||||
Ok(REPOS_DIR.join(&owner).join(&name))
|
||||
}
|
||||
|
||||
pub async fn ensure_repo_cloned(repo_url: &str) -> Result<PathBuf> {
|
||||
let repo_path = repo_path_for_url(repo_url)?;
|
||||
let _lock = lock_repo(&repo_path).await;
|
||||
|
||||
if !repo_path.is_dir() {
|
||||
log::info!("Cloning {} into {:?}", repo_url, repo_path);
|
||||
std::fs::create_dir_all(&repo_path)?;
|
||||
run_git(&repo_path, &["init"]).await?;
|
||||
run_git(&repo_path, &["remote", "add", "origin", repo_url]).await?;
|
||||
}
|
||||
|
||||
// Always fetch to get latest commits
|
||||
run_git(&repo_path, &["fetch", "origin"]).await?;
|
||||
|
||||
// Check if we have a valid HEAD, if not checkout FETCH_HEAD
|
||||
let has_head = run_git(&repo_path, &["rev-parse", "HEAD"]).await.is_ok();
|
||||
if !has_head {
|
||||
// Use reset to set HEAD without needing a branch
|
||||
run_git(&repo_path, &["reset", "--hard", "FETCH_HEAD"]).await?;
|
||||
}
|
||||
|
||||
Ok(repo_path)
|
||||
}
|
||||
|
||||
pub async fn fetch_if_needed(repo_path: &Path, revision: &str) -> Result<String> {
|
||||
let resolved = run_git(
|
||||
repo_path,
|
||||
&["rev-parse", &format!("{}^{{commit}}", revision)],
|
||||
)
|
||||
.await;
|
||||
|
||||
if let Ok(sha) = resolved {
|
||||
return Ok(sha);
|
||||
}
|
||||
|
||||
if run_git(repo_path, &["fetch", "--depth", "1", "origin", revision])
|
||||
.await
|
||||
.is_err()
|
||||
{
|
||||
run_git(repo_path, &["fetch", "origin"]).await?;
|
||||
}
|
||||
|
||||
run_git(repo_path, &["rev-parse", "FETCH_HEAD"]).await
|
||||
}
|
||||
@@ -1,29 +1,19 @@
|
||||
use crate::{
|
||||
example::{Example, ExampleBuffer, ExampleState},
|
||||
git,
|
||||
headless::EpAppState,
|
||||
paths::{REPOS_DIR, WORKTREES_DIR},
|
||||
paths::WORKTREES_DIR,
|
||||
progress::{InfoStyle, Progress, Step, StepProgress},
|
||||
};
|
||||
use anyhow::{Context as _, Result};
|
||||
use collections::HashMap;
|
||||
use edit_prediction::EditPredictionStore;
|
||||
use edit_prediction::udiff::OpenedBuffers;
|
||||
use futures::{
|
||||
AsyncWriteExt as _,
|
||||
lock::{Mutex, OwnedMutexGuard},
|
||||
};
|
||||
use futures::AsyncWriteExt as _;
|
||||
use gpui::{AsyncApp, Entity};
|
||||
use language::{Anchor, Buffer, LanguageNotFound, ToOffset, ToPoint};
|
||||
use project::Project;
|
||||
use project::buffer_store::BufferStoreEvent;
|
||||
use project::{Project, ProjectPath};
|
||||
use std::{
|
||||
cell::RefCell,
|
||||
fs,
|
||||
path::{Path, PathBuf},
|
||||
sync::Arc,
|
||||
};
|
||||
use util::{paths::PathStyle, rel_path::RelPath};
|
||||
use zeta_prompt::CURSOR_MARKER;
|
||||
use std::{fs, path::PathBuf, sync::Arc};
|
||||
|
||||
pub async fn run_load_project(
|
||||
example: &mut Example,
|
||||
@@ -86,37 +76,22 @@ async fn cursor_position(
|
||||
return Err(error);
|
||||
}
|
||||
|
||||
let worktree = project.read_with(cx, |project, cx| {
|
||||
project
|
||||
.visible_worktrees(cx)
|
||||
.next()
|
||||
.context("No visible worktrees")
|
||||
})??;
|
||||
|
||||
let cursor_path = RelPath::new(&example.spec.cursor_path, PathStyle::Posix)
|
||||
.context("Failed to create RelPath")?
|
||||
.into_arc();
|
||||
let cursor_buffer = project
|
||||
.update(cx, |project, cx| {
|
||||
project.open_buffer(
|
||||
ProjectPath {
|
||||
worktree_id: worktree.read(cx).id(),
|
||||
path: cursor_path,
|
||||
},
|
||||
cx,
|
||||
)
|
||||
let cursor_path = project
|
||||
.read_with(cx, |project, cx| {
|
||||
project.find_project_path(&example.spec.cursor_path, cx)
|
||||
})?
|
||||
.with_context(|| {
|
||||
format!(
|
||||
"failed to find cursor path {}",
|
||||
example.spec.cursor_path.display()
|
||||
)
|
||||
})?;
|
||||
let cursor_buffer = project
|
||||
.update(cx, |project, cx| project.open_buffer(cursor_path, cx))?
|
||||
.await?;
|
||||
let cursor_offset_within_excerpt = example
|
||||
.spec
|
||||
.cursor_position
|
||||
.find(CURSOR_MARKER)
|
||||
.context("missing cursor marker")?;
|
||||
let mut cursor_excerpt = example.spec.cursor_position.clone();
|
||||
cursor_excerpt.replace_range(
|
||||
cursor_offset_within_excerpt..(cursor_offset_within_excerpt + CURSOR_MARKER.len()),
|
||||
"",
|
||||
);
|
||||
|
||||
let (cursor_excerpt, cursor_offset_within_excerpt) = example.spec.cursor_excerpt()?;
|
||||
|
||||
let excerpt_offset = cursor_buffer.read_with(cx, |buffer, _cx| {
|
||||
let text = buffer.text();
|
||||
|
||||
@@ -212,17 +187,17 @@ async fn setup_project(
|
||||
|
||||
async fn setup_worktree(example: &Example, step_progress: &StepProgress) -> Result<PathBuf> {
|
||||
let (repo_owner, repo_name) = example.repo_name().context("failed to get repo name")?;
|
||||
let repo_dir = REPOS_DIR.join(repo_owner.as_ref()).join(repo_name.as_ref());
|
||||
let repo_dir = git::repo_path_for_url(&example.spec.repository_url)?;
|
||||
let worktree_path = WORKTREES_DIR
|
||||
.join(repo_owner.as_ref())
|
||||
.join(repo_name.as_ref());
|
||||
let repo_lock = lock_repo(&repo_dir).await;
|
||||
let repo_lock = git::lock_repo(&repo_dir).await;
|
||||
|
||||
if !repo_dir.is_dir() {
|
||||
step_progress.set_substatus(format!("cloning {}", repo_name));
|
||||
fs::create_dir_all(&repo_dir)?;
|
||||
run_git(&repo_dir, &["init"]).await?;
|
||||
run_git(
|
||||
git::run_git(&repo_dir, &["init"]).await?;
|
||||
git::run_git(
|
||||
&repo_dir,
|
||||
&["remote", "add", "origin", &example.spec.repository_url],
|
||||
)
|
||||
@@ -230,53 +205,26 @@ async fn setup_worktree(example: &Example, step_progress: &StepProgress) -> Resu
|
||||
}
|
||||
|
||||
// Resolve the example to a revision, fetching it if needed.
|
||||
let revision = run_git(
|
||||
&repo_dir,
|
||||
&[
|
||||
"rev-parse",
|
||||
&format!("{}^{{commit}}", example.spec.revision),
|
||||
],
|
||||
)
|
||||
.await;
|
||||
let revision = if let Ok(revision) = revision {
|
||||
revision
|
||||
} else {
|
||||
step_progress.set_substatus("fetching");
|
||||
if run_git(
|
||||
&repo_dir,
|
||||
&["fetch", "--depth", "1", "origin", &example.spec.revision],
|
||||
)
|
||||
.await
|
||||
.is_err()
|
||||
{
|
||||
run_git(&repo_dir, &["fetch", "origin"]).await?;
|
||||
}
|
||||
let revision = run_git(&repo_dir, &["rev-parse", "FETCH_HEAD"]).await?;
|
||||
revision
|
||||
};
|
||||
step_progress.set_substatus("fetching");
|
||||
let revision = git::fetch_if_needed(&repo_dir, &example.spec.revision).await?;
|
||||
|
||||
// Create the worktree for this example if needed.
|
||||
step_progress.set_substatus("preparing worktree");
|
||||
if worktree_path.is_dir() {
|
||||
run_git(&worktree_path, &["clean", "--force", "-d"]).await?;
|
||||
run_git(&worktree_path, &["reset", "--hard", "HEAD"]).await?;
|
||||
run_git(&worktree_path, &["checkout", revision.as_str()]).await?;
|
||||
git::run_git(&worktree_path, &["clean", "--force", "-d"]).await?;
|
||||
git::run_git(&worktree_path, &["reset", "--hard", "HEAD"]).await?;
|
||||
git::run_git(&worktree_path, &["checkout", revision.as_str()]).await?;
|
||||
} else {
|
||||
let worktree_path_string = worktree_path.to_string_lossy();
|
||||
run_git(
|
||||
let branch_name = example.spec.filename();
|
||||
git::run_git(
|
||||
&repo_dir,
|
||||
&["branch", "-f", &example.spec.name, revision.as_str()],
|
||||
&["branch", "-f", &branch_name, revision.as_str()],
|
||||
)
|
||||
.await?;
|
||||
run_git(
|
||||
git::run_git(
|
||||
&repo_dir,
|
||||
&[
|
||||
"worktree",
|
||||
"add",
|
||||
"-f",
|
||||
&worktree_path_string,
|
||||
&example.spec.name,
|
||||
],
|
||||
&["worktree", "add", "-f", &worktree_path_string, &branch_name],
|
||||
)
|
||||
.await?;
|
||||
}
|
||||
@@ -319,39 +267,3 @@ async fn apply_edit_history(
|
||||
) -> Result<OpenedBuffers> {
|
||||
edit_prediction::udiff::apply_diff(&example.spec.edit_history, project, cx).await
|
||||
}
|
||||
|
||||
thread_local! {
|
||||
static REPO_LOCKS: RefCell<HashMap<PathBuf, Arc<Mutex<()>>>> = RefCell::new(HashMap::default());
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub async fn lock_repo(path: impl AsRef<Path>) -> OwnedMutexGuard<()> {
|
||||
REPO_LOCKS
|
||||
.with(|cell| {
|
||||
cell.borrow_mut()
|
||||
.entry(path.as_ref().to_path_buf())
|
||||
.or_default()
|
||||
.clone()
|
||||
})
|
||||
.lock_owned()
|
||||
.await
|
||||
}
|
||||
|
||||
async fn run_git(repo_path: &Path, args: &[&str]) -> Result<String> {
|
||||
let output = smol::process::Command::new("git")
|
||||
.current_dir(repo_path)
|
||||
.args(args)
|
||||
.output()
|
||||
.await?;
|
||||
|
||||
anyhow::ensure!(
|
||||
output.status.success(),
|
||||
"`git {}` within `{}` failed with status: {}\nstderr:\n{}\nstdout:\n{}",
|
||||
args.join(" "),
|
||||
repo_path.display(),
|
||||
output.status,
|
||||
String::from_utf8_lossy(&output.stderr),
|
||||
String::from_utf8_lossy(&output.stdout),
|
||||
);
|
||||
Ok(String::from_utf8(output.stdout)?.trim().to_string())
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ mod anthropic_client;
|
||||
mod distill;
|
||||
mod example;
|
||||
mod format_prompt;
|
||||
mod git;
|
||||
mod headless;
|
||||
mod load_project;
|
||||
mod metrics;
|
||||
@@ -10,6 +11,7 @@ mod predict;
|
||||
mod progress;
|
||||
mod retrieve_context;
|
||||
mod score;
|
||||
mod synthesize;
|
||||
|
||||
use clap::{Args, CommandFactory, Parser, Subcommand, ValueEnum};
|
||||
use edit_prediction::EditPredictionStore;
|
||||
@@ -28,6 +30,7 @@ use crate::predict::run_prediction;
|
||||
use crate::progress::Progress;
|
||||
use crate::retrieve_context::run_context_retrieval;
|
||||
use crate::score::run_scoring;
|
||||
use crate::synthesize::{SynthesizeConfig, run_synthesize};
|
||||
|
||||
#[derive(Parser, Debug)]
|
||||
#[command(name = "ep")]
|
||||
@@ -67,6 +70,8 @@ enum Command {
|
||||
Distill,
|
||||
/// Print aggregated scores
|
||||
Eval(PredictArgs),
|
||||
/// Generate eval examples by analyzing git commits from a repository
|
||||
Synthesize(SynthesizeArgs),
|
||||
/// Remove git repositories and worktrees
|
||||
Clean,
|
||||
}
|
||||
@@ -118,6 +123,9 @@ impl Display for Command {
|
||||
.unwrap()
|
||||
.get_name()
|
||||
),
|
||||
Command::Synthesize(args) => {
|
||||
write!(f, "synthesize --repo={}", args.repo)
|
||||
}
|
||||
Command::Clean => write!(f, "clean"),
|
||||
}
|
||||
}
|
||||
@@ -143,7 +151,7 @@ struct PredictArgs {
|
||||
repetitions: usize,
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, Debug, ValueEnum, Serialize, Deserialize)]
|
||||
#[derive(Clone, Copy, Debug, PartialEq, ValueEnum, Serialize, Deserialize)]
|
||||
enum PredictionProvider {
|
||||
Sweep,
|
||||
Mercury,
|
||||
@@ -153,6 +161,29 @@ enum PredictionProvider {
|
||||
TeacherNonBatching,
|
||||
}
|
||||
|
||||
#[derive(Debug, Args)]
|
||||
struct SynthesizeArgs {
|
||||
/// Repository URL (git@github.com:owner/repo or https://...)
|
||||
#[clap(long)]
|
||||
repo: String,
|
||||
|
||||
/// Number of examples to generate
|
||||
#[clap(long, default_value_t = 5)]
|
||||
count: usize,
|
||||
|
||||
/// Maximum commits to scan before giving up
|
||||
#[clap(long, default_value_t = 100)]
|
||||
max_commits: usize,
|
||||
|
||||
/// Only generate examples that require retrieved context to make a correct prediction
|
||||
#[clap(long)]
|
||||
require_context: bool,
|
||||
|
||||
/// Ignore state file and reprocess all commits
|
||||
#[clap(long)]
|
||||
fresh: bool,
|
||||
}
|
||||
|
||||
impl EpArgs {
|
||||
fn output_path(&self) -> Option<PathBuf> {
|
||||
if self.in_place {
|
||||
@@ -189,6 +220,26 @@ fn main() {
|
||||
std::fs::remove_dir_all(&*paths::DATA_DIR).unwrap();
|
||||
return;
|
||||
}
|
||||
Command::Synthesize(synth_args) => {
|
||||
let Some(output_dir) = args.output else {
|
||||
panic!("output dir is required");
|
||||
};
|
||||
let config = SynthesizeConfig {
|
||||
repo_url: synth_args.repo.clone(),
|
||||
count: synth_args.count,
|
||||
max_commits: synth_args.max_commits,
|
||||
output_dir,
|
||||
require_context: synth_args.require_context,
|
||||
fresh: synth_args.fresh,
|
||||
};
|
||||
smol::block_on(async {
|
||||
if let Err(e) = run_synthesize(config).await {
|
||||
eprintln!("Error: {:?}", e);
|
||||
std::process::exit(1);
|
||||
}
|
||||
});
|
||||
return;
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
|
||||
@@ -256,7 +307,7 @@ fn main() {
|
||||
run_scoring(example, &args, app_state.clone(), cx.clone())
|
||||
.await?;
|
||||
}
|
||||
Command::Clean => {
|
||||
Command::Clean | Command::Synthesize(_) => {
|
||||
unreachable!()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,34 +1,17 @@
|
||||
use collections::{HashMap, HashSet};
|
||||
use edit_prediction::udiff::DiffLine;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use collections::HashMap;
|
||||
|
||||
type Counts = HashMap<String, usize>;
|
||||
type CountsDelta = HashMap<String, isize>;
|
||||
|
||||
#[derive(Default, Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ClassificationMetrics {
|
||||
pub true_positives: usize,
|
||||
pub false_positives: usize,
|
||||
pub false_negatives: usize,
|
||||
#[derive(Default, Debug, Clone)]
|
||||
struct ClassificationMetrics {
|
||||
true_positives: usize,
|
||||
false_positives: usize,
|
||||
false_negatives: usize,
|
||||
}
|
||||
|
||||
impl ClassificationMetrics {
|
||||
pub fn from_sets(
|
||||
expected: &HashSet<String>,
|
||||
actual: &HashSet<String>,
|
||||
) -> ClassificationMetrics {
|
||||
let true_positives = expected.intersection(actual).count();
|
||||
let false_positives = actual.difference(expected).count();
|
||||
let false_negatives = expected.difference(actual).count();
|
||||
|
||||
ClassificationMetrics {
|
||||
true_positives,
|
||||
false_positives,
|
||||
false_negatives,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn from_counts(expected: &Counts, actual: &Counts) -> ClassificationMetrics {
|
||||
fn from_counts(expected: &Counts, actual: &Counts) -> ClassificationMetrics {
|
||||
let mut true_positives = 0;
|
||||
let mut false_positives = 0;
|
||||
let mut false_negatives = 0;
|
||||
@@ -56,27 +39,7 @@ impl ClassificationMetrics {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn aggregate<'a>(
|
||||
scores: impl Iterator<Item = &'a ClassificationMetrics>,
|
||||
) -> ClassificationMetrics {
|
||||
let mut true_positives = 0;
|
||||
let mut false_positives = 0;
|
||||
let mut false_negatives = 0;
|
||||
|
||||
for score in scores {
|
||||
true_positives += score.true_positives;
|
||||
false_positives += score.false_positives;
|
||||
false_negatives += score.false_negatives;
|
||||
}
|
||||
|
||||
ClassificationMetrics {
|
||||
true_positives,
|
||||
false_positives,
|
||||
false_negatives,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn precision(&self) -> f64 {
|
||||
fn precision(&self) -> f64 {
|
||||
if self.true_positives + self.false_positives == 0 {
|
||||
0.0
|
||||
} else {
|
||||
@@ -84,42 +47,13 @@ impl ClassificationMetrics {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn recall(&self) -> f64 {
|
||||
fn recall(&self) -> f64 {
|
||||
if self.true_positives + self.false_negatives == 0 {
|
||||
0.0
|
||||
} else {
|
||||
self.true_positives as f64 / (self.true_positives + self.false_negatives) as f64
|
||||
}
|
||||
}
|
||||
|
||||
pub fn f1_score(&self) -> f64 {
|
||||
let recall = self.recall();
|
||||
let precision = self.precision();
|
||||
if precision + recall == 0.0 {
|
||||
0.0
|
||||
} else {
|
||||
2.0 * precision * recall / (precision + recall)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn line_match_score(
|
||||
expected_patch: &[DiffLine],
|
||||
actual_patch: &[DiffLine],
|
||||
) -> ClassificationMetrics {
|
||||
let expected_change_lines = expected_patch
|
||||
.iter()
|
||||
.filter(|line| matches!(line, DiffLine::Addition(_) | DiffLine::Deletion(_)))
|
||||
.map(|line| line.to_string())
|
||||
.collect();
|
||||
|
||||
let actual_change_lines = actual_patch
|
||||
.iter()
|
||||
.filter(|line| matches!(line, DiffLine::Addition(_) | DiffLine::Deletion(_)))
|
||||
.map(|line| line.to_string())
|
||||
.collect();
|
||||
|
||||
ClassificationMetrics::from_sets(&expected_change_lines, &actual_change_lines)
|
||||
}
|
||||
|
||||
enum ChrfWhitespace {
|
||||
@@ -135,55 +69,26 @@ const CHR_F_WHITESPACE: ChrfWhitespace = ChrfWhitespace::Ignore;
|
||||
/// Computes a delta-chrF score that compares two sets of edits.
|
||||
///
|
||||
/// This metric works by:
|
||||
/// 1. Reconstructing original, golden (expected result), and actual texts from diffs
|
||||
/// 2. Computing n-gram count differences (deltas) between original→golden and original→actual
|
||||
/// 3. Comparing these deltas to measure how well actual edits match expected edits
|
||||
pub fn delta_chr_f(expected: &[DiffLine], actual: &[DiffLine]) -> f64 {
|
||||
// Reconstruct texts from diffs
|
||||
let mut original_text = String::new(); // state of the text before any edits
|
||||
let mut golden_text = String::new(); // text after applying golden edits
|
||||
let mut actual_text = String::new(); // text after applying actual edits
|
||||
|
||||
for line in expected {
|
||||
match line {
|
||||
DiffLine::Context(s) => {
|
||||
original_text.push_str(s);
|
||||
golden_text.push_str(s);
|
||||
}
|
||||
DiffLine::Deletion(s) => {
|
||||
original_text.push_str(s);
|
||||
}
|
||||
DiffLine::Addition(s) => {
|
||||
golden_text.push_str(s);
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
for line in actual {
|
||||
match line {
|
||||
DiffLine::Context(s) | DiffLine::Addition(s) => {
|
||||
actual_text.push_str(s);
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
// Edge case
|
||||
if original_text == golden_text && golden_text == actual_text {
|
||||
/// 1. Computing n-gram count differences (deltas) between original→expected and original→actual
|
||||
/// 2. Comparing these deltas to measure how well actual edits match expected edits
|
||||
///
|
||||
/// Returns a score from 0.0 to 100.0, where 100.0 means the actual edits perfectly match
|
||||
/// the expected edits.
|
||||
pub fn delta_chr_f(original: &str, expected: &str, actual: &str) -> f64 {
|
||||
// Edge case: if all texts are identical, the edits match perfectly
|
||||
if original == expected && expected == actual {
|
||||
return 100.0;
|
||||
}
|
||||
|
||||
// Compute the metric
|
||||
let original_ngrams = chr_f_ngram_counts(&original_text);
|
||||
let golden_ngrams = chr_f_ngram_counts(&golden_text);
|
||||
let actual_ngrams = chr_f_ngram_counts(&actual_text);
|
||||
let original_ngrams = chr_f_ngram_counts(original);
|
||||
let expected_ngrams = chr_f_ngram_counts(expected);
|
||||
let actual_ngrams = chr_f_ngram_counts(actual);
|
||||
|
||||
let mut total_precision = 0.0;
|
||||
let mut total_recall = 0.0;
|
||||
|
||||
for order in 0..CHR_F_CHAR_ORDER {
|
||||
let expected_delta = compute_ngram_delta(&golden_ngrams[order], &original_ngrams[order]);
|
||||
let expected_delta = compute_ngram_delta(&expected_ngrams[order], &original_ngrams[order]);
|
||||
let actual_delta = compute_ngram_delta(&actual_ngrams[order], &original_ngrams[order]);
|
||||
|
||||
if expected_delta.is_empty() && actual_delta.is_empty() {
|
||||
@@ -255,7 +160,7 @@ fn ngram_delta_to_counts(delta: &CountsDelta) -> Counts {
|
||||
for (ngram, &delta) in delta {
|
||||
if delta > 0 {
|
||||
counts.insert(ngram.clone(), delta as usize);
|
||||
} else {
|
||||
} else if delta < 0 {
|
||||
counts.insert(format!("¬{ngram}"), delta.unsigned_abs());
|
||||
}
|
||||
}
|
||||
@@ -278,94 +183,68 @@ fn count_ngrams(text: &str, n: usize) -> Counts {
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use super::*;
|
||||
use edit_prediction::udiff::DiffLine;
|
||||
|
||||
#[test]
|
||||
fn test_delta_chr_f_perfect_match() {
|
||||
let diff = vec![
|
||||
DiffLine::Context("fn main() {"),
|
||||
DiffLine::Deletion(" println!(\"Hello\");"),
|
||||
DiffLine::Addition(" println!(\"Hello, World!\");"),
|
||||
DiffLine::Context("}"),
|
||||
];
|
||||
let original = "fn main() { println!(\"Hello\");}";
|
||||
let expected = "fn main() { println!(\"Hello, World!\");}";
|
||||
|
||||
let score = delta_chr_f(&diff, &diff);
|
||||
let score = delta_chr_f(original, expected, expected);
|
||||
assert!((score - 100.0).abs() < 1e-2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_delta_chr_f_wrong_edit() {
|
||||
// When the edit is wrong
|
||||
let expected = vec![
|
||||
DiffLine::Context("one "),
|
||||
DiffLine::Deletion("two "),
|
||||
DiffLine::Context("three"),
|
||||
];
|
||||
|
||||
let actual = vec![
|
||||
DiffLine::Context("one "),
|
||||
DiffLine::Context("two "),
|
||||
DiffLine::Deletion("three"),
|
||||
DiffLine::Addition("four"),
|
||||
];
|
||||
let original = "one two three";
|
||||
let expected = "one three"; // deleted "two "
|
||||
let actual = "one two four"; // deleted "three", added "four"
|
||||
|
||||
// Then the score should be low
|
||||
let score = delta_chr_f(&expected, &actual);
|
||||
let score = delta_chr_f(original, expected, actual);
|
||||
assert!(score > 20.0 && score < 40.0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_delta_chr_f_partial_match() {
|
||||
let expected = vec![
|
||||
DiffLine::Deletion("let x = 42;"),
|
||||
DiffLine::Addition("let x = 100;"),
|
||||
];
|
||||
|
||||
let actual = vec![
|
||||
DiffLine::Deletion("let x = 42;"),
|
||||
DiffLine::Addition("let x = 99;"),
|
||||
];
|
||||
let original = "let x = 42;";
|
||||
let expected = "let x = 100;";
|
||||
let actual = "let x = 99;";
|
||||
|
||||
// We got the edit location right, but the replacement text is wrong.
|
||||
// Deleted ngrams will match, bringing the score somewhere in the middle.
|
||||
let score = delta_chr_f(&expected, &actual);
|
||||
let score = delta_chr_f(original, expected, actual);
|
||||
assert!(score > 40.0 && score < 60.0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_delta_chr_f_missed_edit() {
|
||||
// When predictions makes no changes
|
||||
let expected = vec![
|
||||
DiffLine::Context("prefix "),
|
||||
DiffLine::Deletion("old"),
|
||||
DiffLine::Addition("new"),
|
||||
DiffLine::Context(" suffix"),
|
||||
];
|
||||
|
||||
let actual = vec![
|
||||
DiffLine::Context("prefix "),
|
||||
DiffLine::Context("old"),
|
||||
DiffLine::Context(" suffix"),
|
||||
];
|
||||
let original = "prefix old suffix";
|
||||
let expected = "prefix new suffix";
|
||||
let actual = "prefix old suffix"; // no change
|
||||
|
||||
// Then the score should be low (all expected changes are false negatives)
|
||||
let score = delta_chr_f(&expected, &actual);
|
||||
let score = delta_chr_f(original, expected, actual);
|
||||
assert!(score < 20.0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_delta_chr_f_extra_edit() {
|
||||
// When adding unexpected content
|
||||
let expected = vec![DiffLine::Context("hello"), DiffLine::Context("world")];
|
||||
|
||||
let actual = vec![
|
||||
DiffLine::Context("hello"),
|
||||
DiffLine::Addition("extra"),
|
||||
DiffLine::Context("world"),
|
||||
];
|
||||
let original = "helloworld";
|
||||
let expected = "helloworld"; // no change expected
|
||||
let actual = "helloextraworld"; // added "extra"
|
||||
|
||||
// Then the score should be low (all actual changes are false positives)
|
||||
let score = delta_chr_f(&expected, &actual);
|
||||
let score = delta_chr_f(original, expected, actual);
|
||||
assert!(score < 20.0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_delta_chr_f_no_changes() {
|
||||
let text = "unchanged text";
|
||||
let score = delta_chr_f(text, text, text);
|
||||
assert!((score - 100.0).abs() < 1e-2);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -17,7 +17,11 @@ pub static RUN_DIR: LazyLock<PathBuf> = LazyLock::new(|| {
|
||||
.join(chrono::Local::now().format("%d-%m-%y-%H_%M_%S").to_string())
|
||||
});
|
||||
pub static LATEST_EXAMPLE_RUN_DIR: LazyLock<PathBuf> = LazyLock::new(|| DATA_DIR.join("latest"));
|
||||
pub static LATEST_FAILED_EXAMPLES_DIR: LazyLock<PathBuf> =
|
||||
LazyLock::new(|| DATA_DIR.join("latest_failed"));
|
||||
pub static LLM_CACHE_DB: LazyLock<PathBuf> = LazyLock::new(|| CACHE_DIR.join("llm_cache.sqlite"));
|
||||
pub static SYNTHESIZE_STATE_FILE: LazyLock<PathBuf> =
|
||||
LazyLock::new(|| DATA_DIR.join("synthesize_state.json"));
|
||||
pub static FAILED_EXAMPLES_DIR: LazyLock<PathBuf> =
|
||||
LazyLock::new(|| ensure_dir(&RUN_DIR.join("failed")));
|
||||
|
||||
|
||||
@@ -28,12 +28,16 @@ pub async fn run_prediction(
|
||||
app_state: Arc<EpAppState>,
|
||||
mut cx: AsyncApp,
|
||||
) -> anyhow::Result<()> {
|
||||
if !example.predictions.is_empty() {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let provider = provider.context("provider is required")?;
|
||||
|
||||
if let Some(existing_prediction) = example.predictions.first() {
|
||||
if existing_prediction.provider == provider {
|
||||
return Ok(());
|
||||
} else {
|
||||
example.predictions.clear();
|
||||
}
|
||||
}
|
||||
|
||||
run_context_retrieval(example, app_state.clone(), cx.clone()).await?;
|
||||
|
||||
if matches!(
|
||||
@@ -184,7 +188,9 @@ pub async fn run_prediction(
|
||||
let actual_patch = prediction
|
||||
.and_then(|prediction| {
|
||||
let prediction = prediction.prediction.ok()?;
|
||||
prediction.edit_preview.as_unified_diff(&prediction.edits)
|
||||
prediction
|
||||
.edit_preview
|
||||
.as_unified_diff(prediction.snapshot.file(), &prediction.edits)
|
||||
})
|
||||
.unwrap_or_default();
|
||||
|
||||
|
||||
@@ -46,6 +46,7 @@ pub enum Step {
|
||||
FormatPrompt,
|
||||
Predict,
|
||||
Score,
|
||||
Synthesize,
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
|
||||
@@ -62,6 +63,7 @@ impl Step {
|
||||
Step::FormatPrompt => "Format",
|
||||
Step::Predict => "Predict",
|
||||
Step::Score => "Score",
|
||||
Step::Synthesize => "Synthesize",
|
||||
}
|
||||
}
|
||||
|
||||
@@ -72,6 +74,7 @@ impl Step {
|
||||
Step::FormatPrompt => "\x1b[34m",
|
||||
Step::Predict => "\x1b[32m",
|
||||
Step::Score => "\x1b[31m",
|
||||
Step::Synthesize => "\x1b[36m",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,11 +2,12 @@ use crate::{
|
||||
PredictArgs,
|
||||
example::{Example, ExampleScore},
|
||||
headless::EpAppState,
|
||||
metrics::{self, ClassificationMetrics},
|
||||
metrics,
|
||||
predict::run_prediction,
|
||||
progress::{Progress, Step},
|
||||
};
|
||||
use edit_prediction::udiff::DiffLine;
|
||||
use anyhow::Context as _;
|
||||
use edit_prediction::udiff::apply_diff_to_string;
|
||||
use gpui::AsyncApp;
|
||||
use std::sync::Arc;
|
||||
|
||||
@@ -27,18 +28,32 @@ pub async fn run_scoring(
|
||||
|
||||
let _progress = Progress::global().start(Step::Score, &example.spec.name);
|
||||
|
||||
let expected_patch = parse_patch(&example.spec.expected_patch);
|
||||
let original_text = &example.buffer.as_ref().unwrap().content;
|
||||
let expected_texts: Vec<String> = example
|
||||
.spec
|
||||
.expected_patches
|
||||
.iter()
|
||||
.map(|patch| {
|
||||
apply_diff_to_string(original_text, patch)
|
||||
.with_context(|| format!("Expected patch did not apply for {}", example.spec.name))
|
||||
})
|
||||
.collect::<Result<Vec<_>, _>>()?;
|
||||
|
||||
let mut scores = vec![];
|
||||
|
||||
for pred in &example.predictions {
|
||||
let actual_patch = parse_patch(&pred.actual_patch);
|
||||
let line_match = metrics::line_match_score(&expected_patch, &actual_patch);
|
||||
let delta_chr_f = metrics::delta_chr_f(&expected_patch, &actual_patch) as f32;
|
||||
|
||||
for prediction in &example.predictions {
|
||||
let actual_text = match apply_diff_to_string(original_text, &prediction.actual_patch) {
|
||||
Ok(text) => text,
|
||||
Err(_) => {
|
||||
scores.push(ExampleScore { delta_chr_f: 0.0 });
|
||||
continue;
|
||||
}
|
||||
};
|
||||
let best_delta_chr_f = expected_texts
|
||||
.iter()
|
||||
.map(|expected| metrics::delta_chr_f(original_text, expected, &actual_text) as f32)
|
||||
.fold(0.0, f32::max);
|
||||
scores.push(ExampleScore {
|
||||
delta_chr_f,
|
||||
line_match,
|
||||
delta_chr_f: best_delta_chr_f,
|
||||
});
|
||||
}
|
||||
|
||||
@@ -46,42 +61,25 @@ pub async fn run_scoring(
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn parse_patch(patch: &str) -> Vec<DiffLine<'_>> {
|
||||
patch.lines().map(DiffLine::parse).collect()
|
||||
}
|
||||
|
||||
pub fn print_report(examples: &[Example]) {
|
||||
eprintln!(
|
||||
"──────────────────────────────────────────────────────────────────────────────────────"
|
||||
);
|
||||
eprintln!(
|
||||
"{:<30} {:>4} {:>4} {:>4} {:>10} {:>8} {:>8} {:>10}",
|
||||
"Example name", "TP", "FP", "FN", "Precision", "Recall", "F1", "DeltaChrF"
|
||||
);
|
||||
eprintln!("{:<50} {:>10}", "Example name", "DeltaChrF");
|
||||
eprintln!(
|
||||
"──────────────────────────────────────────────────────────────────────────────────────"
|
||||
);
|
||||
|
||||
let mut all_line_match_scores = Vec::new();
|
||||
let mut all_delta_chr_f_scores = Vec::new();
|
||||
|
||||
for example in examples {
|
||||
for score in example.score.iter() {
|
||||
let line_match = &score.line_match;
|
||||
|
||||
eprintln!(
|
||||
"{:<30} {:>4} {:>4} {:>4} {:>9.2}% {:>7.2}% {:>7.2}% {:>9.2}",
|
||||
truncate_name(&example.spec.name, 30),
|
||||
line_match.true_positives,
|
||||
line_match.false_positives,
|
||||
line_match.false_negatives,
|
||||
line_match.precision() * 100.0,
|
||||
line_match.recall() * 100.0,
|
||||
line_match.f1_score() * 100.0,
|
||||
"{:<50} {:>9.2}",
|
||||
truncate_name(&example.spec.name, 50),
|
||||
score.delta_chr_f
|
||||
);
|
||||
|
||||
all_line_match_scores.push(line_match.clone());
|
||||
all_delta_chr_f_scores.push(score.delta_chr_f);
|
||||
}
|
||||
}
|
||||
@@ -90,22 +88,11 @@ pub fn print_report(examples: &[Example]) {
|
||||
"──────────────────────────────────────────────────────────────────────────────────────"
|
||||
);
|
||||
|
||||
if !all_line_match_scores.is_empty() {
|
||||
let total_line_match = ClassificationMetrics::aggregate(all_line_match_scores.iter());
|
||||
if !all_delta_chr_f_scores.is_empty() {
|
||||
let avg_delta_chr_f: f32 =
|
||||
all_delta_chr_f_scores.iter().sum::<f32>() / all_delta_chr_f_scores.len() as f32;
|
||||
|
||||
eprintln!(
|
||||
"{:<30} {:>4} {:>4} {:>4} {:>9.2}% {:>7.2}% {:>7.2}% {:>9.2}",
|
||||
"TOTAL",
|
||||
total_line_match.true_positives,
|
||||
total_line_match.false_positives,
|
||||
total_line_match.false_negatives,
|
||||
total_line_match.precision() * 100.0,
|
||||
total_line_match.recall() * 100.0,
|
||||
total_line_match.f1_score() * 100.0,
|
||||
avg_delta_chr_f
|
||||
);
|
||||
eprintln!("{:<50} {:>9.2}", "AVERAGE", avg_delta_chr_f);
|
||||
eprintln!(
|
||||
"──────────────────────────────────────────────────────────────────────────────────────"
|
||||
);
|
||||
|
||||
902
crates/edit_prediction_cli/src/synthesize.rs
Normal file
902
crates/edit_prediction_cli/src/synthesize.rs
Normal file
@@ -0,0 +1,902 @@
|
||||
use crate::{
|
||||
anthropic_client::PlainLlmClient,
|
||||
git::{ensure_repo_cloned, run_git},
|
||||
paths::{FAILED_EXAMPLES_DIR, LATEST_FAILED_EXAMPLES_DIR, SYNTHESIZE_STATE_FILE},
|
||||
progress::{InfoStyle, Progress, Step, StepProgress},
|
||||
};
|
||||
use anthropic::ResponseContent;
|
||||
use anyhow::{Context as _, Result};
|
||||
use chrono::Local;
|
||||
use collections::{HashMap, HashSet};
|
||||
use edit_prediction::{
|
||||
example_spec::ExampleSpec,
|
||||
udiff::{apply_diff_to_string, edits_for_diff},
|
||||
};
|
||||
use indoc::indoc;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::{
|
||||
path::{Path, PathBuf},
|
||||
sync::Arc,
|
||||
};
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct SynthesizeConfig {
|
||||
pub repo_url: String,
|
||||
pub count: usize,
|
||||
pub max_commits: usize,
|
||||
pub output_dir: PathBuf,
|
||||
pub require_context: bool,
|
||||
pub fresh: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug, Default, Serialize, Deserialize)]
|
||||
struct SynthesizeState {
|
||||
repositories: HashMap<String, RepoState>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Default, Serialize, Deserialize)]
|
||||
struct RepoState {
|
||||
processed_commits: HashSet<String>,
|
||||
examples_generated: usize,
|
||||
}
|
||||
|
||||
impl SynthesizeState {
|
||||
fn load() -> Self {
|
||||
if SYNTHESIZE_STATE_FILE.exists() {
|
||||
std::fs::read_to_string(&*SYNTHESIZE_STATE_FILE)
|
||||
.ok()
|
||||
.and_then(|s| serde_json::from_str(&s).ok())
|
||||
.unwrap_or_default()
|
||||
} else {
|
||||
Self::default()
|
||||
}
|
||||
}
|
||||
|
||||
fn save(&self) -> Result<()> {
|
||||
let content = serde_json::to_string_pretty(self)?;
|
||||
std::fs::write(&*SYNTHESIZE_STATE_FILE, content)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn is_processed(&self, repo_url: &str, commit_sha: &str) -> bool {
|
||||
self.repositories
|
||||
.get(repo_url)
|
||||
.is_some_and(|repo| repo.processed_commits.contains(commit_sha))
|
||||
}
|
||||
|
||||
fn mark_processed(&mut self, repo_url: &str, commit_sha: &str, examples_count: usize) {
|
||||
let repo = self.repositories.entry(repo_url.to_string()).or_default();
|
||||
repo.processed_commits.insert(commit_sha.to_string());
|
||||
repo.examples_generated += examples_count;
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct CommitInfo {
|
||||
sha: String,
|
||||
parent_sha: String,
|
||||
message: String,
|
||||
diff: String,
|
||||
expanded_diff: String,
|
||||
}
|
||||
|
||||
/// Claude's response parsed into structured form
|
||||
#[derive(Debug)]
|
||||
struct ClaudeResponse {
|
||||
name: String,
|
||||
reasoning: String,
|
||||
edit_history_hunks: Vec<String>,
|
||||
expected_patch_hunks: Vec<String>,
|
||||
}
|
||||
|
||||
pub async fn run_synthesize(config: SynthesizeConfig) -> Result<()> {
|
||||
let mut state = if config.fresh {
|
||||
SynthesizeState::default()
|
||||
} else {
|
||||
SynthesizeState::load()
|
||||
};
|
||||
|
||||
std::fs::create_dir_all(&config.output_dir)?;
|
||||
std::fs::create_dir_all(&*FAILED_EXAMPLES_DIR)?;
|
||||
|
||||
// Create "latest_failed" symlink pointing to this run's failed directory
|
||||
if LATEST_FAILED_EXAMPLES_DIR.is_symlink() {
|
||||
std::fs::remove_file(&*LATEST_FAILED_EXAMPLES_DIR)?;
|
||||
}
|
||||
#[cfg(unix)]
|
||||
std::os::unix::fs::symlink(&*FAILED_EXAMPLES_DIR, &*LATEST_FAILED_EXAMPLES_DIR)?;
|
||||
#[cfg(windows)]
|
||||
std::os::windows::fs::symlink_dir(&*FAILED_EXAMPLES_DIR, &*LATEST_FAILED_EXAMPLES_DIR)?;
|
||||
|
||||
let progress = Progress::global();
|
||||
progress.set_total_examples(config.count);
|
||||
|
||||
let clone_progress = progress.start(Step::Synthesize, "clone");
|
||||
let repo_path = ensure_repo_cloned(&config.repo_url).await?;
|
||||
drop(clone_progress);
|
||||
|
||||
let client = PlainLlmClient::new()?;
|
||||
let mut examples_generated = 0;
|
||||
let mut commits_skipped = 0;
|
||||
let batch_size = config.max_commits;
|
||||
|
||||
'outer: loop {
|
||||
let list_progress = progress.start(Step::Synthesize, "list-commits");
|
||||
let commits = list_commits(&repo_path, batch_size, commits_skipped).await?;
|
||||
drop(list_progress);
|
||||
|
||||
if commits.is_empty() {
|
||||
break;
|
||||
}
|
||||
|
||||
commits_skipped += commits.len();
|
||||
|
||||
for commit in commits {
|
||||
if examples_generated >= config.count {
|
||||
break 'outer;
|
||||
}
|
||||
|
||||
if !config.fresh && state.is_processed(&config.repo_url, &commit.sha) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if should_skip_commit(&commit) {
|
||||
continue;
|
||||
}
|
||||
|
||||
let commit_label = format!(
|
||||
"{} {}",
|
||||
&commit.sha[..8],
|
||||
truncate_message(&commit.message, 40)
|
||||
);
|
||||
let step_progress = Arc::new(progress.start(Step::Synthesize, &commit_label));
|
||||
|
||||
// Single Claude call to identify and copy hunks
|
||||
step_progress.set_substatus("analyzing...");
|
||||
let claude_response =
|
||||
match analyze_commit(&client, &config, &commit, step_progress.clone()).await {
|
||||
Ok(Some(response)) => response,
|
||||
Ok(None) => {
|
||||
step_progress.set_info("no pattern", InfoStyle::Normal);
|
||||
state.mark_processed(&config.repo_url, &commit.sha, 0);
|
||||
state.save()?;
|
||||
continue;
|
||||
}
|
||||
Err(e) => {
|
||||
step_progress.set_info(format!("error: {:?}", e), InfoStyle::Warning);
|
||||
state.mark_processed(&config.repo_url, &commit.sha, 0);
|
||||
state.save()?;
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
// Validate and build the example
|
||||
step_progress.set_substatus("validating...");
|
||||
match build_example(&config, &commit, &repo_path, &claude_response).await {
|
||||
Ok(spec) => {
|
||||
let timestamp = Local::now().format("%Y-%m-%d--%H-%M-%S");
|
||||
let filename = format!("{}.md", timestamp);
|
||||
let path = config.output_dir.join(&filename);
|
||||
std::fs::write(&path, spec.to_markdown())?;
|
||||
examples_generated += 1;
|
||||
step_progress.set_info(filename, InfoStyle::Normal);
|
||||
}
|
||||
Err(rejection_reason) => {
|
||||
log::debug!("Example rejected: {}", rejection_reason);
|
||||
let timestamp = Local::now().format("%Y-%m-%d--%H-%M-%S%.3f");
|
||||
let filename = format!("{}.md", timestamp);
|
||||
let path = FAILED_EXAMPLES_DIR.join(&filename);
|
||||
let content = format_rejected_example(&claude_response, &rejection_reason);
|
||||
if let Err(e) = std::fs::write(&path, content) {
|
||||
log::warn!("Failed to write rejected example: {:?}", e);
|
||||
}
|
||||
step_progress.set_info(format!("rejected: {}", filename), InfoStyle::Warning);
|
||||
}
|
||||
}
|
||||
|
||||
state.mark_processed(&config.repo_url, &commit.sha, 1);
|
||||
state.save()?;
|
||||
}
|
||||
}
|
||||
|
||||
progress.finalize();
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn truncate_message(msg: &str, max_len: usize) -> String {
|
||||
let first_line = msg.lines().next().unwrap_or("");
|
||||
if first_line.len() <= max_len {
|
||||
first_line.to_string()
|
||||
} else {
|
||||
format!("{}...", &first_line[..max_len - 3])
|
||||
}
|
||||
}
|
||||
|
||||
fn should_skip_commit(commit: &CommitInfo) -> bool {
|
||||
let lines_changed = commit
|
||||
.diff
|
||||
.lines()
|
||||
.filter(|l| l.starts_with('+') || l.starts_with('-'))
|
||||
.count();
|
||||
lines_changed < 10
|
||||
|| lines_changed > 1000
|
||||
|| is_non_code_commit(commit)
|
||||
|| is_rename_commit(commit)
|
||||
}
|
||||
|
||||
fn is_non_code_commit(commit: &CommitInfo) -> bool {
|
||||
let non_code_extensions = [
|
||||
".md", ".txt", ".json", ".yaml", ".yml", ".toml", ".lock", ".svg", ".png", ".jpg", ".gif",
|
||||
".ico", ".woff", ".ttf", ".eot",
|
||||
];
|
||||
|
||||
let diff_files: Vec<&str> = commit
|
||||
.diff
|
||||
.lines()
|
||||
.filter(|l| l.starts_with("+++ b/") || l.starts_with("--- a/"))
|
||||
.filter_map(|l| {
|
||||
l.strip_prefix("+++ b/")
|
||||
.or_else(|| l.strip_prefix("--- a/"))
|
||||
})
|
||||
.collect();
|
||||
|
||||
if diff_files.is_empty() {
|
||||
return false;
|
||||
}
|
||||
|
||||
diff_files
|
||||
.iter()
|
||||
.all(|f| non_code_extensions.iter().any(|ext| f.ends_with(ext)))
|
||||
}
|
||||
|
||||
fn is_rename_commit(commit: &CommitInfo) -> bool {
|
||||
commit.diff.contains("similarity index")
|
||||
|| commit.diff.contains("rename from")
|
||||
|| commit.diff.contains("rename to")
|
||||
}
|
||||
|
||||
async fn list_commits(
|
||||
repo_path: &Path,
|
||||
max_commits: usize,
|
||||
skip: usize,
|
||||
) -> Result<Vec<CommitInfo>> {
|
||||
let output = run_git(
|
||||
repo_path,
|
||||
&[
|
||||
"log",
|
||||
"--no-merges",
|
||||
&format!("--skip={}", skip),
|
||||
&format!("-{}", max_commits),
|
||||
"--format=%H|%P|%s",
|
||||
],
|
||||
)
|
||||
.await?;
|
||||
|
||||
let mut commits = Vec::new();
|
||||
for line in output.lines() {
|
||||
let parts: Vec<&str> = line.splitn(3, '|').collect();
|
||||
if parts.len() < 3 {
|
||||
continue;
|
||||
}
|
||||
let sha = parts[0].to_string();
|
||||
let parent_sha = parts[1].split_whitespace().next().unwrap_or("").to_string();
|
||||
if parent_sha.is_empty() {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Get standard diff (for skip checks)
|
||||
let diff = run_git(repo_path, &["show", "--format=", &sha])
|
||||
.await
|
||||
.unwrap_or_default();
|
||||
|
||||
// Get expanded diff with 30 lines of context
|
||||
let expanded_diff = run_git(repo_path, &["show", "-U30", "--format=", &sha])
|
||||
.await
|
||||
.unwrap_or_default();
|
||||
|
||||
commits.push(CommitInfo {
|
||||
sha,
|
||||
parent_sha,
|
||||
message: parts[2].to_string(),
|
||||
diff,
|
||||
expanded_diff,
|
||||
});
|
||||
}
|
||||
|
||||
Ok(commits)
|
||||
}
|
||||
|
||||
fn build_prompt(config: &SynthesizeConfig, commit: &CommitInfo) -> String {
|
||||
let context_guidance = if config.require_context {
|
||||
"IMPORTANT: Only identify patterns that REQUIRE reading context from other files to make the prediction. \
|
||||
Single-file patterns (where the edit history and expected patch are in the same file) are NOT acceptable \
|
||||
unless the pattern clearly requires understanding code from other files."
|
||||
} else {
|
||||
"Both single-file and multi-file patterns are acceptable."
|
||||
};
|
||||
|
||||
format!(
|
||||
indoc! {r#"
|
||||
You are analyzing a git commit to construct a realistic edit prediction example.
|
||||
|
||||
Your goal is to tell the story of a programmer's editing session: what sequence of changes did they make, and what change logically comes next? We use these examples to train a model to predict edits, so the quality of the EDIT HISTORY is what matters most.
|
||||
|
||||
An edit prediction example consists of:
|
||||
1. **Edit History**: 3-6 hunks showing what the programmer did BEFORE making the expected patch. This is the most important part - it must tell a coherent story of the changes leading up to the prediction.
|
||||
2. **Expected Patch**: One small hunk that logically follows from the edit history.
|
||||
|
||||
{context_guidance}
|
||||
|
||||
## What Makes a Good Example
|
||||
|
||||
The edit history should read like a story: "First the programmer changed X, then Y, then Z, and now they need to change W."
|
||||
|
||||
GOOD examples (rich sequences with 3+ steps):
|
||||
- Removing a parameter: docstring update → constructor change → field removal → (predict) usage site update
|
||||
- Adding a feature: type definition → first usage → second usage → (predict) third usage
|
||||
- Bug fix pattern: fix in file A → fix in file B → fix in file C → (predict) fix in file D
|
||||
|
||||
BAD examples (respond NO_PATTERN):
|
||||
- Commits where all changes are independent (no narrative thread)
|
||||
- Simple find-and-replace (renaming, version bumps)
|
||||
- Documentation-only or config-only changes
|
||||
- Changes where you can only find 1-2 hunks for the edit history
|
||||
|
||||
## Commit Information
|
||||
|
||||
Repository: {repo_url}
|
||||
Commit: {sha}
|
||||
Message: {message}
|
||||
|
||||
## Diff (30 lines context)
|
||||
|
||||
```diff
|
||||
{expanded_diff}
|
||||
```
|
||||
|
||||
## Your Task
|
||||
|
||||
First, THINK through whether this commit can support a good example:
|
||||
|
||||
1. What is the high-level pattern in this commit?
|
||||
2. Can you identify at least 4 related hunks (3 for edit history + 1 for expected patch)?
|
||||
3. What would be the narrative? (First... then... then... finally predict...)
|
||||
4. Which specific hunk should be the expected patch (the "punchline")?
|
||||
|
||||
If you cannot construct a coherent 3+ hunk story, respond with just:
|
||||
NO_PATTERN: <brief reason>
|
||||
|
||||
If you CAN construct a good example, respond in this format:
|
||||
|
||||
ANALYSIS:
|
||||
Pattern: <one sentence describing the pattern>
|
||||
Steps:
|
||||
1. <file:line-range> - <what this hunk does>
|
||||
2. <file:line-range> - <what this hunk does>
|
||||
3. <file:line-range> - <what this hunk does>
|
||||
4. [EXPECTED PATCH] <file:line-range> - <what this hunk does>
|
||||
|
||||
NAME: <short description, like a commit message, under 60 chars>
|
||||
|
||||
EDIT_HISTORY:
|
||||
|
||||
Hunk 1:
|
||||
```diff
|
||||
--- a/src/models/user.py
|
||||
+++ b/src/models/user.py
|
||||
@@ -15,7 +15,6 @@ class User:
|
||||
"""A user in the system.
|
||||
|
||||
Attributes:
|
||||
- email: The user's email address.
|
||||
name: The user's display name.
|
||||
"""
|
||||
```
|
||||
|
||||
Hunk 2:
|
||||
```diff
|
||||
--- a/src/models/user.py
|
||||
+++ b/src/models/user.py
|
||||
@@ -25,10 +24,9 @@ class User:
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
- email: str,
|
||||
created_at: datetime,
|
||||
):
|
||||
self.name = name
|
||||
- self.email = email
|
||||
self.created_at = created_at
|
||||
```
|
||||
|
||||
Hunk 3:
|
||||
```diff
|
||||
--- a/src/api/handlers.py
|
||||
+++ b/src/api/handlers.py
|
||||
@@ -42,7 +42,6 @@ def create_user(request):
|
||||
data = request.json()
|
||||
user = User(
|
||||
name=data["name"],
|
||||
- email=data["email"],
|
||||
created_at=datetime.now(),
|
||||
)
|
||||
return user.save()
|
||||
```
|
||||
|
||||
EXPECTED_PATCH:
|
||||
```diff
|
||||
--- a/src/api/handlers.py
|
||||
+++ b/src/api/handlers.py
|
||||
@@ -58,7 +57,6 @@ def update_user(request, user_id):
|
||||
user = User.get(user_id)
|
||||
user.name = data.get("name", user.name)
|
||||
- user.email = data.get("email", user.email)
|
||||
user.save()
|
||||
return user
|
||||
```
|
||||
|
||||
## Requirements for the diffs
|
||||
|
||||
Edit history:
|
||||
- MUST have 3-6 hunks (if you cannot find 3+, respond NO_PATTERN instead)
|
||||
- Each hunk needs file headers (--- a/path and +++ b/path)
|
||||
- Hunks must be valid unified diffs that apply to the parent commit
|
||||
- Order hunks as a programmer would naturally make the changes
|
||||
|
||||
Expected patch:
|
||||
- Must be a SINGLE hunk from a SINGLE file
|
||||
- Must be SMALL: 1-15 changed lines (not counting context)
|
||||
- Must be clearly predictable from the edit history narrative
|
||||
"#},
|
||||
context_guidance = context_guidance,
|
||||
repo_url = config.repo_url,
|
||||
sha = commit.sha,
|
||||
message = commit.message,
|
||||
expanded_diff = commit.expanded_diff,
|
||||
)
|
||||
}
|
||||
|
||||
async fn analyze_commit(
|
||||
client: &PlainLlmClient,
|
||||
config: &SynthesizeConfig,
|
||||
commit: &CommitInfo,
|
||||
step_progress: Arc<StepProgress>,
|
||||
) -> Result<Option<ClaudeResponse>> {
|
||||
use anthropic::{Message, RequestContent, Role};
|
||||
|
||||
let prompt = build_prompt(config, commit);
|
||||
let messages = vec![Message {
|
||||
role: Role::User,
|
||||
content: vec![RequestContent::Text {
|
||||
text: prompt,
|
||||
cache_control: None,
|
||||
}],
|
||||
}];
|
||||
|
||||
let response = client
|
||||
.generate_streaming("claude-sonnet-4-5", 8192, messages, |chars, _text| {
|
||||
step_progress.set_substatus(format!("analyzing: {:.1}K", chars as f64 / 1000.0));
|
||||
})
|
||||
.await?;
|
||||
|
||||
// Extract text content from response
|
||||
let response_text: String = response
|
||||
.content
|
||||
.iter()
|
||||
.filter_map(|block| {
|
||||
if let ResponseContent::Text { text } = block {
|
||||
Some(text.as_str())
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
.join("\n");
|
||||
|
||||
parse_claude_response(&response_text)
|
||||
}
|
||||
|
||||
fn parse_claude_response(response: &str) -> Result<Option<ClaudeResponse>> {
|
||||
// Check for NO_PATTERN
|
||||
if response.contains("NO_PATTERN:") {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
// Parse NAME
|
||||
let name = response
|
||||
.lines()
|
||||
.find(|l| l.starts_with("NAME:"))
|
||||
.map(|l| l.strip_prefix("NAME:").unwrap_or("").trim().to_string())
|
||||
.unwrap_or_else(|| "unnamed example".to_string());
|
||||
|
||||
// Parse ANALYSIS section (Claude's planning) - this is the primary reasoning
|
||||
let reasoning = extract_section(
|
||||
response,
|
||||
"ANALYSIS:",
|
||||
&["NAME:", "REASONING:", "EDIT_HISTORY:", "EXPECTED_PATCH:"],
|
||||
)
|
||||
.unwrap_or_default();
|
||||
|
||||
// Parse EDIT_HISTORY diff block
|
||||
let edit_history_hunks = extract_diff_block(response, "EDIT_HISTORY:")?;
|
||||
|
||||
// Parse EXPECTED_PATCH diff block
|
||||
let expected_patch_hunks = extract_diff_block(response, "EXPECTED_PATCH:")?;
|
||||
|
||||
if edit_history_hunks.is_empty() {
|
||||
anyhow::bail!("No edit history hunks found in response");
|
||||
}
|
||||
if expected_patch_hunks.is_empty() {
|
||||
anyhow::bail!("No expected patch hunks found in response");
|
||||
}
|
||||
|
||||
Ok(Some(ClaudeResponse {
|
||||
name,
|
||||
reasoning,
|
||||
edit_history_hunks,
|
||||
expected_patch_hunks,
|
||||
}))
|
||||
}
|
||||
|
||||
/// Return the trimmed text between `start_marker` and the nearest of
/// `end_markers` (or the end of `text` when none follows).
///
/// Yields `None` when `start_marker` is absent from `text`.
fn extract_section(text: &str, start_marker: &str, end_markers: &[&str]) -> Option<String> {
    let content_start = text.find(start_marker)? + start_marker.len();
    let tail = &text[content_start..];

    // The section ends at whichever end marker shows up first, or at EOF.
    let mut end = tail.len();
    for marker in end_markers {
        if let Some(idx) = tail.find(marker) {
            end = end.min(idx);
        }
    }

    Some(tail[..end].trim().to_string())
}
|
||||
|
||||
fn extract_diff_block(text: &str, section_marker: &str) -> Result<Vec<String>> {
|
||||
let section_start = text
|
||||
.find(section_marker)
|
||||
.context(format!("Section {} not found", section_marker))?;
|
||||
|
||||
let after_marker = &text[section_start + section_marker.len()..];
|
||||
|
||||
// Find where the next major section starts (to bound our search)
|
||||
let section_end = ["EXPECTED_PATCH:", "## "]
|
||||
.iter()
|
||||
.filter(|&&m| m != section_marker)
|
||||
.filter_map(|marker| after_marker.find(marker))
|
||||
.min()
|
||||
.unwrap_or(after_marker.len());
|
||||
|
||||
let section_content = &after_marker[..section_end];
|
||||
|
||||
// Collect all ```diff blocks in this section
|
||||
let mut hunks = Vec::new();
|
||||
let mut search_start = 0;
|
||||
|
||||
while let Some(diff_start) = section_content[search_start..].find("```diff") {
|
||||
let abs_diff_start = search_start + diff_start;
|
||||
let block_content_start = section_content[abs_diff_start..]
|
||||
.find('\n')
|
||||
.map(|i| abs_diff_start + i + 1)
|
||||
.unwrap_or(abs_diff_start);
|
||||
|
||||
if let Some(block_end_rel) = section_content[block_content_start..].find("```") {
|
||||
let block_end = block_content_start + block_end_rel;
|
||||
let diff_content = section_content[block_content_start..block_end].trim();
|
||||
|
||||
// Split this block into hunks (in case multiple hunks in one block)
|
||||
hunks.extend(split_into_hunks(diff_content));
|
||||
|
||||
search_start = block_end + 3;
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if hunks.is_empty() {
|
||||
anyhow::bail!("No diff blocks found in section {}", section_marker);
|
||||
}
|
||||
|
||||
Ok(hunks)
|
||||
}
|
||||
|
||||
/// Split a diff block into individual hunks, preserving file headers.
///
/// Each returned hunk is prefixed with the most recent `--- a/…` / `+++ b/…`
/// header pair seen before it, so every hunk can stand alone as a diff.
fn split_into_hunks(diff: &str) -> Vec<String> {
    // Flush the accumulated hunk lines (prefixed with the current file
    // header, if any) into `hunks`, leaving the accumulator empty.
    // Extracted because the original inlined this logic three times.
    fn flush(hunks: &mut Vec<String>, header: Option<&String>, lines: &mut Vec<String>) {
        if lines.is_empty() {
            return;
        }
        let mut hunk_text = String::new();
        if let Some(header) = header {
            hunk_text.push_str(header);
            hunk_text.push('\n');
        }
        hunk_text.push_str(&lines.join("\n"));
        hunks.push(hunk_text);
        lines.clear();
    }

    let mut hunks = Vec::new();
    let mut current_file_header: Option<String> = None;
    let mut current_hunk: Vec<String> = Vec::new();

    for line in diff.lines() {
        if line.starts_with("--- a/") || line.starts_with("--- /") {
            // New file header: finish any in-progress hunk first.
            flush(&mut hunks, current_file_header.as_ref(), &mut current_hunk);
            current_file_header = Some(line.to_string());
        } else if line.starts_with("+++ b/") || line.starts_with("+++ /") {
            // Second header line: append it to the pending `---` line.
            if let Some(header) = current_file_header.as_mut() {
                header.push('\n');
                header.push_str(line);
            }
        } else if line.starts_with("@@ ") {
            // New hunk boundary: flush the previous hunk, start a new one.
            flush(&mut hunks, current_file_header.as_ref(), &mut current_hunk);
            current_hunk.push(line.to_string());
        } else if !current_hunk.is_empty() {
            // Body line of the hunk in progress. (A non-empty accumulator is
            // exactly the condition the original tracked with an `in_hunk`
            // flag; lines outside any hunk are dropped, as before.)
            current_hunk.push(line.to_string());
        }
    }

    // Flush the final hunk, if any.
    flush(&mut hunks, current_file_header.as_ref(), &mut current_hunk);

    hunks
}
|
||||
|
||||
/// Validate Claude's output by applying diffs and build the `ExampleSpec`.
///
/// Validation pipeline:
/// 1. require exactly one expected-patch hunk and derive the cursor file from it;
/// 2. fetch the file's pre-commit content via `git show <sha>^:<path>`;
/// 3. apply the edit-history hunks to reach the intermediate state;
/// 4. confirm the expected patch applies on top of that state;
/// 5. place the cursor at the first edit location and assemble the spec.
///
/// Returns a human-readable error string describing the first failing step
/// (used by the caller to record rejected examples).
async fn build_example(
    config: &SynthesizeConfig,
    commit: &CommitInfo,
    repo_path: &Path,
    response: &ClaudeResponse,
) -> Result<ExampleSpec, String> {
    // Validate expected patch hunks: the spec format stores a single patch.
    if response.expected_patch_hunks.len() != 1 {
        return Err(format!(
            "Expected exactly 1 expected patch hunk, got {}",
            response.expected_patch_hunks.len()
        ));
    }

    // Parse the expected patch to determine cursor file
    let expected_patch = &response.expected_patch_hunks[0];
    let cursor_file = extract_file_from_hunk(expected_patch)
        .ok_or_else(|| "Could not determine file from expected patch".to_string())?;

    // Get the file content before the commit (`<sha>^` is the parent commit).
    let before_content = run_git(
        repo_path,
        &["show", &format!("{}^:{}", commit.sha, cursor_file)],
    )
    .await
    .map_err(|e| format!("Failed to get file content for {}: {}", cursor_file, e))?;

    // Build edit history diff from Claude's hunks
    let edit_history = response.edit_history_hunks.join("\n");

    // Apply edit history to get intermediate state (validates edit history)
    let intermediate_state =
        apply_edit_history_to_content(&before_content, &edit_history, &cursor_file)?;

    // Validate expected patch applies to intermediate state
    let expected_patch_with_header = ensure_diff_header(expected_patch, &cursor_file);
    apply_diff_to_string(&intermediate_state, &expected_patch_with_header)
        .map_err(|e| format!("Expected patch failed to apply: {}", e))?;

    // Find where the expected patch edits would apply in the intermediate state
    let edits = edits_for_diff(&intermediate_state, &expected_patch_with_header)
        .map_err(|e| format!("Failed to parse expected patch: {}", e))?;
    if edits.is_empty() {
        return Err(
            "Could not locate expected patch in file (context not found or ambiguous)".to_string(),
        );
    }

    // Use the start of the first edit for cursor positioning.
    // NOTE(review): assumes `edits[0].0` is a byte range into
    // `intermediate_state` — confirm against `edits_for_diff`.
    let cursor_byte_offset = edits[0].0.start;

    // Extract excerpt around the edit location
    let (excerpt, cursor_offset) = extract_cursor_excerpt(&intermediate_state, cursor_byte_offset)?;

    // Build the ExampleSpec and use set_cursor_excerpt to format with comment marker
    let comment_prefix = line_comment_prefix(&cursor_file);
    // Prepend the source commit so the example records its provenance.
    let reasoning_with_source = format!(
        "Source commit: {} ({})\n\n{}",
        commit.sha,
        truncate_message(&commit.message, 60),
        response.reasoning
    );
    let mut spec = ExampleSpec {
        name: response.name.clone(),
        repository_url: config.repo_url.clone(),
        // The example starts from the parent commit, before the change.
        revision: commit.parent_sha.clone(),
        tags: Vec::new(),
        reasoning: Some(reasoning_with_source),
        uncommitted_diff: String::new(),
        cursor_path: Arc::from(Path::new(&cursor_file)),
        // Filled in by `set_cursor_excerpt` below.
        cursor_position: String::new(),
        edit_history,
        expected_patches: vec![expected_patch_with_header],
    };
    spec.set_cursor_excerpt(&excerpt, cursor_offset, comment_prefix);

    Ok(spec)
}
|
||||
|
||||
/// Extract file path from a hunk (looks for `--- a/path` or `+++ b/path`).
///
/// Returns the path from the first header line encountered, or `None` when
/// the hunk carries no file header.
fn extract_file_from_hunk(hunk: &str) -> Option<String> {
    hunk.lines().find_map(|line| {
        line.strip_prefix("+++ b/")
            .or_else(|| line.strip_prefix("--- a/"))
            .map(str::to_string)
    })
}
|
||||
|
||||
/// Ensure a hunk has proper file headers.
///
/// Hunks that already mention `--- a/` or `+++ b/` pass through untouched;
/// otherwise a standard header pair for `file_path` is prepended.
fn ensure_diff_header(hunk: &str, file_path: &str) -> String {
    let has_header = hunk.contains("--- a/") || hunk.contains("+++ b/");
    if has_header {
        hunk.to_string()
    } else {
        format!("--- a/{}\n+++ b/{}\n{}", file_path, file_path, hunk)
    }
}
|
||||
|
||||
/// Apply edit history to file content, only if hunks affect this file
|
||||
fn apply_edit_history_to_content(
|
||||
content: &str,
|
||||
edit_history: &str,
|
||||
cursor_file: &str,
|
||||
) -> Result<String, String> {
|
||||
// Extract just the hunks for this file from the edit history
|
||||
let file_diff = extract_file_diff_from_combined(edit_history, cursor_file);
|
||||
|
||||
if file_diff.is_empty() {
|
||||
return Ok(content.to_string());
|
||||
}
|
||||
|
||||
apply_diff_to_string(content, &file_diff)
|
||||
.map_err(|e| format!("Failed to apply edit history: {}", e))
|
||||
}
|
||||
|
||||
/// Extract hunks for a specific file from a combined diff.
///
/// Copies every `--- a/<target>` / `+++ b/<target>` header pair and the hunk
/// lines that follow it, skipping sections belonging to other files. Returns
/// an empty string when the combined diff never touches the target.
fn extract_file_diff_from_combined(combined_diff: &str, target_file: &str) -> String {
    let mut result = String::new();
    let mut in_target_file = false;
    let mut found_header = false;

    for line in combined_diff.lines() {
        if let Some(file) = line.strip_prefix("--- a/") {
            // A new file section starts; copy it only if it's the target.
            // (This also ends a previous target section, so the original's
            // inner `break` on this prefix was unreachable dead code.)
            in_target_file = file == target_file;
            if in_target_file {
                result.push_str(line);
                result.push('\n');
                // Wait for the matching `+++` line before copying hunk lines.
                found_header = false;
            }
        } else if line.starts_with("+++ b/") && in_target_file {
            result.push_str(line);
            result.push('\n');
            found_header = true;
        } else if in_target_file && found_header {
            // Hunk header/body line inside the target file's section.
            result.push_str(line);
            result.push('\n');
        }
    }

    result
}
|
||||
|
||||
/// Extract a cursor position excerpt from content around a byte offset.
/// Returns the excerpt and the cursor offset within the excerpt.
///
/// The excerpt spans up to 3 lines before the cursor line and up to 4 lines
/// after it, and never ends with a trailing newline.
fn extract_cursor_excerpt(
    content: &str,
    cursor_byte_offset: usize,
) -> Result<(String, usize), String> {
    // Boundaries of the line containing the cursor.
    let line_start = match content[..cursor_byte_offset].rfind('\n') {
        Some(pos) => pos + 1,
        None => 0,
    };
    let line_end = match content[cursor_byte_offset..].find('\n') {
        Some(pos) => cursor_byte_offset + pos,
        None => content.len(),
    };
    let cursor_line = &content[line_start..line_end];
    let cursor_column = cursor_byte_offset - line_start;

    // Up to three lines of context before the cursor line.
    let preceding: Vec<&str> = content[..line_start].lines().collect();
    let skip = preceding.len().saturating_sub(3);
    let context_before = &preceding[skip..];

    // Up to four lines of context after the cursor line.
    let after_start = if line_end < content.len() {
        line_end + 1
    } else {
        line_end
    };
    let context_after: Vec<&str> = content[after_start..].lines().take(4).collect();

    // Assemble the excerpt, noting where the cursor lands inside it.
    let mut excerpt = String::new();
    for &line in context_before {
        excerpt.push_str(line);
        excerpt.push('\n');
    }
    let cursor_offset_in_excerpt = excerpt.len() + cursor_column;
    excerpt.push_str(cursor_line);
    excerpt.push('\n');
    for line in context_after {
        excerpt.push_str(line);
        excerpt.push('\n');
    }

    // Drop the trailing newline so the excerpt ends on its last line.
    if excerpt.ends_with('\n') {
        excerpt.pop();
    }

    Ok((excerpt, cursor_offset_in_excerpt))
}
|
||||
|
||||
/// Get the line comment prefix for a file based on its extension.
///
/// Unknown extensions (and extensionless paths) default to `//`.
fn line_comment_prefix(file_path: &str) -> &'static str {
    const HASH: &[&str] = &[
        "py", "rb", "sh", "bash", "zsh", "pl", "pm", "r", "jl", "yaml", "yml", "toml",
        "coffee", "cr", "ex", "exs", "elixir",
    ];
    const DOUBLE_DASH: &[&str] = &["lua", "hs", "sql"];
    const SEMICOLON: &[&str] = &["lisp", "clj", "cljs", "scm", "rkt", "el"];
    const PERCENT: &[&str] = &["erl", "hrl"];

    // For extensionless paths this yields the whole name, which falls
    // through to the `//` default — same as the C-family bucket.
    let extension = file_path.rsplit('.').next().unwrap_or("");

    if HASH.contains(&extension) {
        "#"
    } else if DOUBLE_DASH.contains(&extension) {
        "--"
    } else if SEMICOLON.contains(&extension) {
        ";"
    } else if PERCENT.contains(&extension) {
        "%"
    } else {
        // C-family languages (rs, c, cpp, js, ts, go, java, …) and everything
        // unrecognized share the `//` prefix.
        "//"
    }
}
|
||||
|
||||
fn format_rejected_example(response: &ClaudeResponse, rejection_reason: &str) -> String {
|
||||
let mut content = String::new();
|
||||
content.push_str("# Rejected Example\n\n");
|
||||
content.push_str(&format!("## Name\n\n{}\n\n", response.name));
|
||||
content.push_str(&format!("## Reasoning\n\n{}\n\n", response.reasoning));
|
||||
content.push_str("## Edit History Hunks\n\n```diff\n");
|
||||
for hunk in &response.edit_history_hunks {
|
||||
content.push_str(hunk);
|
||||
content.push_str("\n\n");
|
||||
}
|
||||
content.push_str("```\n\n");
|
||||
content.push_str("## Expected Patch Hunks\n\n```diff\n");
|
||||
for hunk in &response.expected_patch_hunks {
|
||||
content.push_str(hunk);
|
||||
content.push_str("\n\n");
|
||||
}
|
||||
content.push_str("```\n\n");
|
||||
content.push_str(&format!("## Rejection Reason\n\n{}\n", rejection_reason));
|
||||
content
|
||||
}
|
||||
@@ -150,7 +150,7 @@ fn capture_example_as_markdown(
|
||||
.buffer()
|
||||
.read(cx)
|
||||
.text_anchor_for_position(editor.selections.newest_anchor().head(), cx)?;
|
||||
let example = capture_example(project.clone(), buffer, cursor_anchor, true, cx)?;
|
||||
let example = capture_example(project.clone(), buffer, cursor_anchor, cx)?;
|
||||
|
||||
let examples_dir = AllLanguageSettings::get_global(cx)
|
||||
.edit_predictions
|
||||
|
||||
@@ -13,7 +13,7 @@ use crate::{
|
||||
},
|
||||
task_context::RunnableRange,
|
||||
text_diff::text_diff,
|
||||
unified_diff,
|
||||
unified_diff_with_offsets,
|
||||
};
|
||||
pub use crate::{
|
||||
Grammar, Language, LanguageRegistry,
|
||||
@@ -773,7 +773,11 @@ pub struct EditPreview {
|
||||
}
|
||||
|
||||
impl EditPreview {
|
||||
pub fn as_unified_diff(&self, edits: &[(Range<Anchor>, impl AsRef<str>)]) -> Option<String> {
|
||||
pub fn as_unified_diff(
|
||||
&self,
|
||||
file: Option<&Arc<dyn File>>,
|
||||
edits: &[(Range<Anchor>, impl AsRef<str>)],
|
||||
) -> Option<String> {
|
||||
let (first, _) = edits.first()?;
|
||||
let (last, _) = edits.last()?;
|
||||
|
||||
@@ -788,7 +792,7 @@ impl EditPreview {
|
||||
let old_end = Point::new(old_end.row + 4, 0).min(self.old_snapshot.max_point());
|
||||
let new_end = Point::new(new_end.row + 4, 0).min(self.applied_edits_snapshot.max_point());
|
||||
|
||||
Some(unified_diff(
|
||||
let diff_body = unified_diff_with_offsets(
|
||||
&self
|
||||
.old_snapshot
|
||||
.text_for_range(start..old_end)
|
||||
@@ -797,7 +801,17 @@ impl EditPreview {
|
||||
.applied_edits_snapshot
|
||||
.text_for_range(start..new_end)
|
||||
.collect::<String>(),
|
||||
))
|
||||
start.row,
|
||||
start.row,
|
||||
);
|
||||
|
||||
let path = file.map(|f| f.path().as_unix_str());
|
||||
let header = match path {
|
||||
Some(p) => format!("--- a/{}\n+++ b/{}\n", p, p),
|
||||
None => String::new(),
|
||||
};
|
||||
|
||||
Some(format!("{}{}", header, diff_body))
|
||||
}
|
||||
|
||||
pub fn highlight_edits(
|
||||
|
||||
@@ -4538,13 +4538,19 @@ impl Project {
|
||||
|
||||
for worktree in worktree_store.visible_worktrees(cx) {
|
||||
let worktree = worktree.read(cx);
|
||||
if let Ok(path) = RelPath::new(path, path_style)
|
||||
&& let Some(entry) = worktree.entry_for_path(&path)
|
||||
{
|
||||
return Some(ProjectPath {
|
||||
worktree_id: worktree.id(),
|
||||
path: entry.path.clone(),
|
||||
});
|
||||
if let Ok(rel_path) = RelPath::new(path, path_style) {
|
||||
if let Some(entry) = worktree.entry_for_path(&rel_path) {
|
||||
return Some(ProjectPath {
|
||||
worktree_id: worktree.id(),
|
||||
path: entry.path.clone(),
|
||||
});
|
||||
}
|
||||
if worktree_store.visible_worktrees(cx).count() == 1 {
|
||||
return Some(ProjectPath {
|
||||
worktree_id: worktree.id(),
|
||||
path: rel_path.into_arc(),
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user