Allow multiple expected patches, remove line-based patch scoring

This commit is contained in:
Max Brunsfeld
2025-12-23 16:46:37 -08:00
parent 7e09b59fa3
commit 0dcdc6d9a4
7 changed files with 114 additions and 239 deletions

View File

@@ -74,7 +74,7 @@ pub fn capture_example(
cursor_path: cursor_path.as_std_path().into(),
cursor_position: String::new(),
edit_history,
expected_patch: String::new(),
expected_patches: Vec::new(),
};
spec.set_cursor_excerpt(&cursor_excerpt, cursor_offset, &line_comment_prefix);
Ok(spec)
@@ -350,7 +350,7 @@ mod tests {
seven();
"}
.to_string(),
expected_patch: "".to_string(),
expected_patches: Vec::new()
}
);
}

View File

@@ -15,7 +15,7 @@ pub struct ExampleSpec {
pub cursor_path: Arc<Path>,
pub cursor_position: String,
pub edit_history: String,
pub expected_patch: String,
pub expected_patches: Vec<String>,
}
const UNCOMMITTED_DIFF_HEADING: &str = "Uncommitted Diff";
@@ -95,13 +95,15 @@ impl ExampleSpec {
_ = writeln!(markdown, "## {}", EXPECTED_PATCH_HEADING);
markdown.push('\n');
_ = writeln!(markdown, "```diff");
markdown.push_str(&self.expected_patch);
if !markdown.ends_with('\n') {
for patch in &self.expected_patches {
_ = writeln!(markdown, "```diff");
markdown.push_str(patch);
if !markdown.ends_with('\n') {
markdown.push('\n');
}
_ = writeln!(markdown, "```");
markdown.push('\n');
}
_ = writeln!(markdown, "```");
markdown.push('\n');
markdown
}
@@ -118,7 +120,7 @@ impl ExampleSpec {
cursor_path: Path::new("").into(),
cursor_position: String::new(),
edit_history: String::new(),
expected_patch: String::new(),
expected_patches: Vec::new(),
};
if let Some(rest) = input.strip_prefix("+++\n")
@@ -212,7 +214,7 @@ impl ExampleSpec {
mem::take(&mut text);
}
Section::ExpectedPatch => {
spec.expected_patch = mem::take(&mut text);
spec.expected_patches.push(mem::take(&mut text));
}
Section::Start | Section::Other => {}
}
@@ -353,7 +355,7 @@ mod tests {
cursor_path: Path::new("test.rs").into(),
cursor_position: String::new(),
edit_history: String::new(),
expected_patch: String::new(),
expected_patches: Vec::new(),
};
// Cursor before `42`

View File

@@ -1,20 +1,15 @@
use anyhow::{Result, anyhow};
use anyhow::Result;
use std::mem;
use crate::example::Example;
pub async fn run_distill(example: &mut Example) -> Result<()> {
let [prediction]: [_; 1] =
mem::take(&mut example.predictions)
.try_into()
.map_err(|preds: Vec<_>| {
anyhow!(
"Example has {} predictions, but it should have exactly one",
preds.len()
)
})?;
let predictions = mem::take(&mut example.predictions)
.into_iter()
.map(|p| p.actual_patch)
.collect();
example.spec.expected_patch = prediction.actual_patch;
example.spec.expected_patches = predictions;
example.prompt = None;
example.predictions = Vec::new();
example.score = Vec::new();

View File

@@ -1,4 +1,4 @@
use crate::{PredictionProvider, PromptFormat, metrics::ClassificationMetrics};
use crate::{PredictionProvider, PromptFormat};
use anyhow::{Context as _, Result};
use collections::HashMap;
use edit_prediction::example_spec::ExampleSpec;
@@ -87,7 +87,6 @@ pub struct ExamplePrediction {
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ExampleScore {
pub delta_chr_f: f32,
pub line_match: ClassificationMetrics,
}
impl Example {

View File

@@ -30,7 +30,13 @@ pub async fn run_format_prompt(
let prompt = TeacherPrompt::format_prompt(example);
example.prompt = Some(ExamplePrompt {
input: prompt,
expected_output: example.spec.expected_patch.clone(), // TODO
// TODO
expected_output: example
.spec
.expected_patches
.first()
.context("no expected patches")?
.clone(),
format: prompt_format,
});
}
@@ -68,8 +74,15 @@ pub async fn run_format_prompt(
))
})??;
let prompt = format_zeta_prompt(&input);
let expected_output =
zeta2_output_for_patch(&input, &example.spec.expected_patch.clone())?;
let expected_output = zeta2_output_for_patch(
&input,
&example
.spec
.expected_patches
.first()
.context("expected patches is empty")?
.clone(),
)?;
example.prompt = Some(ExamplePrompt {
input: prompt,
expected_output,

View File

@@ -1,34 +1,17 @@
use collections::{HashMap, HashSet};
use edit_prediction::udiff::DiffLine;
use serde::{Deserialize, Serialize};
use collections::HashMap;
type Counts = HashMap<String, usize>;
type CountsDelta = HashMap<String, isize>;
#[derive(Default, Debug, Clone, Serialize, Deserialize)]
pub struct ClassificationMetrics {
pub true_positives: usize,
pub false_positives: usize,
pub false_negatives: usize,
#[derive(Default, Debug, Clone)]
struct ClassificationMetrics {
true_positives: usize,
false_positives: usize,
false_negatives: usize,
}
impl ClassificationMetrics {
pub fn from_sets(
expected: &HashSet<String>,
actual: &HashSet<String>,
) -> ClassificationMetrics {
let true_positives = expected.intersection(actual).count();
let false_positives = actual.difference(expected).count();
let false_negatives = expected.difference(actual).count();
ClassificationMetrics {
true_positives,
false_positives,
false_negatives,
}
}
pub fn from_counts(expected: &Counts, actual: &Counts) -> ClassificationMetrics {
fn from_counts(expected: &Counts, actual: &Counts) -> ClassificationMetrics {
let mut true_positives = 0;
let mut false_positives = 0;
let mut false_negatives = 0;
@@ -56,27 +39,7 @@ impl ClassificationMetrics {
}
}
pub fn aggregate<'a>(
scores: impl Iterator<Item = &'a ClassificationMetrics>,
) -> ClassificationMetrics {
let mut true_positives = 0;
let mut false_positives = 0;
let mut false_negatives = 0;
for score in scores {
true_positives += score.true_positives;
false_positives += score.false_positives;
false_negatives += score.false_negatives;
}
ClassificationMetrics {
true_positives,
false_positives,
false_negatives,
}
}
pub fn precision(&self) -> f64 {
fn precision(&self) -> f64 {
if self.true_positives + self.false_positives == 0 {
0.0
} else {
@@ -84,42 +47,13 @@ impl ClassificationMetrics {
}
}
pub fn recall(&self) -> f64 {
fn recall(&self) -> f64 {
if self.true_positives + self.false_negatives == 0 {
0.0
} else {
self.true_positives as f64 / (self.true_positives + self.false_negatives) as f64
}
}
pub fn f1_score(&self) -> f64 {
let recall = self.recall();
let precision = self.precision();
if precision + recall == 0.0 {
0.0
} else {
2.0 * precision * recall / (precision + recall)
}
}
}
pub fn line_match_score(
expected_patch: &[DiffLine],
actual_patch: &[DiffLine],
) -> ClassificationMetrics {
let expected_change_lines = expected_patch
.iter()
.filter(|line| matches!(line, DiffLine::Addition(_) | DiffLine::Deletion(_)))
.map(|line| line.to_string())
.collect();
let actual_change_lines = actual_patch
.iter()
.filter(|line| matches!(line, DiffLine::Addition(_) | DiffLine::Deletion(_)))
.map(|line| line.to_string())
.collect();
ClassificationMetrics::from_sets(&expected_change_lines, &actual_change_lines)
}
enum ChrfWhitespace {
@@ -135,55 +69,26 @@ const CHR_F_WHITESPACE: ChrfWhitespace = ChrfWhitespace::Ignore;
/// Computes a delta-chrF score that compares two sets of edits.
///
/// This metric works by:
/// 1. Reconstructing original, golden (expected result), and actual texts from diffs
/// 2. Computing n-gram count differences (deltas) between original→golden and original→actual
/// 3. Comparing these deltas to measure how well actual edits match expected edits
pub fn delta_chr_f(expected: &[DiffLine], actual: &[DiffLine]) -> f64 {
// Reconstruct texts from diffs
let mut original_text = String::new(); // state of the text before any edits
let mut golden_text = String::new(); // text after applying golden edits
let mut actual_text = String::new(); // text after applying actual edits
for line in expected {
match line {
DiffLine::Context(s) => {
original_text.push_str(s);
golden_text.push_str(s);
}
DiffLine::Deletion(s) => {
original_text.push_str(s);
}
DiffLine::Addition(s) => {
golden_text.push_str(s);
}
_ => {}
}
}
for line in actual {
match line {
DiffLine::Context(s) | DiffLine::Addition(s) => {
actual_text.push_str(s);
}
_ => {}
}
}
// Edge case
if original_text == golden_text && golden_text == actual_text {
/// 1. Computing n-gram count differences (deltas) between original→expected and original→actual
/// 2. Comparing these deltas to measure how well actual edits match expected edits
///
/// Returns a score from 0.0 to 100.0, where 100.0 means the actual edits perfectly match
/// the expected edits.
pub fn delta_chr_f(original: &str, expected: &str, actual: &str) -> f64 {
// Edge case: if all texts are identical, the edits match perfectly
if original == expected && expected == actual {
return 100.0;
}
// Compute the metric
let original_ngrams = chr_f_ngram_counts(&original_text);
let golden_ngrams = chr_f_ngram_counts(&golden_text);
let actual_ngrams = chr_f_ngram_counts(&actual_text);
let original_ngrams = chr_f_ngram_counts(original);
let expected_ngrams = chr_f_ngram_counts(expected);
let actual_ngrams = chr_f_ngram_counts(actual);
let mut total_precision = 0.0;
let mut total_recall = 0.0;
for order in 0..CHR_F_CHAR_ORDER {
let expected_delta = compute_ngram_delta(&golden_ngrams[order], &original_ngrams[order]);
let expected_delta = compute_ngram_delta(&expected_ngrams[order], &original_ngrams[order]);
let actual_delta = compute_ngram_delta(&actual_ngrams[order], &original_ngrams[order]);
if expected_delta.is_empty() && actual_delta.is_empty() {
@@ -278,94 +183,68 @@ fn count_ngrams(text: &str, n: usize) -> Counts {
#[cfg(test)]
mod test {
use super::*;
use edit_prediction::udiff::DiffLine;
#[test]
fn test_delta_chr_f_perfect_match() {
let diff = vec![
DiffLine::Context("fn main() {"),
DiffLine::Deletion(" println!(\"Hello\");"),
DiffLine::Addition(" println!(\"Hello, World!\");"),
DiffLine::Context("}"),
];
let original = "fn main() { println!(\"Hello\");}";
let expected = "fn main() { println!(\"Hello, World!\");}";
let score = delta_chr_f(&diff, &diff);
let score = delta_chr_f(original, expected, expected);
assert!((score - 100.0).abs() < 1e-2);
}
#[test]
fn test_delta_chr_f_wrong_edit() {
// When the edit is wrong
let expected = vec![
DiffLine::Context("one "),
DiffLine::Deletion("two "),
DiffLine::Context("three"),
];
let actual = vec![
DiffLine::Context("one "),
DiffLine::Context("two "),
DiffLine::Deletion("three"),
DiffLine::Addition("four"),
];
let original = "one two three";
let expected = "one three"; // deleted "two "
let actual = "one two four"; // deleted "three", added "four"
// Then the score should be low
let score = delta_chr_f(&expected, &actual);
let score = delta_chr_f(original, expected, actual);
assert!(score > 20.0 && score < 40.0);
}
#[test]
fn test_delta_chr_f_partial_match() {
let expected = vec![
DiffLine::Deletion("let x = 42;"),
DiffLine::Addition("let x = 100;"),
];
let actual = vec![
DiffLine::Deletion("let x = 42;"),
DiffLine::Addition("let x = 99;"),
];
let original = "let x = 42;";
let expected = "let x = 100;";
let actual = "let x = 99;";
// We got the edit location right, but the replacement text is wrong.
// Deleted ngrams will match, bringing the score somewhere in the middle.
let score = delta_chr_f(&expected, &actual);
let score = delta_chr_f(original, expected, actual);
assert!(score > 40.0 && score < 60.0);
}
#[test]
fn test_delta_chr_f_missed_edit() {
// When the prediction makes no changes
let expected = vec![
DiffLine::Context("prefix "),
DiffLine::Deletion("old"),
DiffLine::Addition("new"),
DiffLine::Context(" suffix"),
];
let actual = vec![
DiffLine::Context("prefix "),
DiffLine::Context("old"),
DiffLine::Context(" suffix"),
];
let original = "prefix old suffix";
let expected = "prefix new suffix";
let actual = "prefix old suffix"; // no change
// Then the score should be low (all expected changes are false negatives)
let score = delta_chr_f(&expected, &actual);
let score = delta_chr_f(original, expected, actual);
assert!(score < 20.0);
}
#[test]
fn test_delta_chr_f_extra_edit() {
// When adding unexpected content
let expected = vec![DiffLine::Context("hello"), DiffLine::Context("world")];
let actual = vec![
DiffLine::Context("hello"),
DiffLine::Addition("extra"),
DiffLine::Context("world"),
];
let original = "helloworld";
let expected = "helloworld"; // no change expected
let actual = "helloextraworld"; // added "extra"
// Then the score should be low (all actual changes are false positives)
let score = delta_chr_f(&expected, &actual);
let score = delta_chr_f(original, expected, actual);
assert!(score < 20.0);
}
#[test]
fn test_delta_chr_f_no_changes() {
let text = "unchanged text";
let score = delta_chr_f(text, text, text);
assert!((score - 100.0).abs() < 1e-2);
}
}

View File

@@ -2,11 +2,12 @@ use crate::{
PredictArgs,
example::{Example, ExampleScore},
headless::EpAppState,
metrics::{self, ClassificationMetrics},
metrics,
predict::run_prediction,
progress::{Progress, Step},
};
use edit_prediction::udiff::DiffLine;
use anyhow::Context as _;
use edit_prediction::udiff::apply_diff_to_string;
use gpui::AsyncApp;
use std::sync::Arc;
@@ -27,18 +28,32 @@ pub async fn run_scoring(
let _progress = Progress::global().start(Step::Score, &example.spec.name);
let expected_patch = parse_patch(&example.spec.expected_patch);
let original_text = &example.buffer.as_ref().unwrap().content;
let expected_texts: Vec<String> = example
.spec
.expected_patches
.iter()
.map(|patch| {
apply_diff_to_string(original_text, patch)
.with_context(|| format!("Expected patch did not apply for {}", example.spec.name))
})
.collect::<Result<Vec<_>, _>>()?;
let mut scores = vec![];
for pred in &example.predictions {
let actual_patch = parse_patch(&pred.actual_patch);
let line_match = metrics::line_match_score(&expected_patch, &actual_patch);
let delta_chr_f = metrics::delta_chr_f(&expected_patch, &actual_patch) as f32;
for prediction in &example.predictions {
let actual_text = match apply_diff_to_string(original_text, &prediction.actual_patch) {
Ok(text) => text,
Err(_) => {
scores.push(ExampleScore { delta_chr_f: 0.0 });
continue;
}
};
let best_delta_chr_f = expected_texts
.iter()
.map(|expected| metrics::delta_chr_f(original_text, expected, &actual_text) as f32)
.fold(0.0, f32::max);
scores.push(ExampleScore {
delta_chr_f,
line_match,
delta_chr_f: best_delta_chr_f,
});
}
@@ -46,42 +61,25 @@ pub async fn run_scoring(
Ok(())
}
fn parse_patch(patch: &str) -> Vec<DiffLine<'_>> {
patch.lines().map(DiffLine::parse).collect()
}
pub fn print_report(examples: &[Example]) {
eprintln!(
"──────────────────────────────────────────────────────────────────────────────────────"
);
eprintln!(
"{:<30} {:>4} {:>4} {:>4} {:>10} {:>8} {:>8} {:>10}",
"Example name", "TP", "FP", "FN", "Precision", "Recall", "F1", "DeltaChrF"
);
eprintln!("{:<50} {:>10}", "Example name", "DeltaChrF");
eprintln!(
"──────────────────────────────────────────────────────────────────────────────────────"
);
let mut all_line_match_scores = Vec::new();
let mut all_delta_chr_f_scores = Vec::new();
for example in examples {
for score in example.score.iter() {
let line_match = &score.line_match;
eprintln!(
"{:<30} {:>4} {:>4} {:>4} {:>9.2}% {:>7.2}% {:>7.2}% {:>9.2}",
truncate_name(&example.spec.name, 30),
line_match.true_positives,
line_match.false_positives,
line_match.false_negatives,
line_match.precision() * 100.0,
line_match.recall() * 100.0,
line_match.f1_score() * 100.0,
"{:<50} {:>9.2}",
truncate_name(&example.spec.name, 50),
score.delta_chr_f
);
all_line_match_scores.push(line_match.clone());
all_delta_chr_f_scores.push(score.delta_chr_f);
}
}
@@ -90,22 +88,11 @@ pub fn print_report(examples: &[Example]) {
"──────────────────────────────────────────────────────────────────────────────────────"
);
if !all_line_match_scores.is_empty() {
let total_line_match = ClassificationMetrics::aggregate(all_line_match_scores.iter());
if !all_delta_chr_f_scores.is_empty() {
let avg_delta_chr_f: f32 =
all_delta_chr_f_scores.iter().sum::<f32>() / all_delta_chr_f_scores.len() as f32;
eprintln!(
"{:<30} {:>4} {:>4} {:>4} {:>9.2}% {:>7.2}% {:>7.2}% {:>9.2}",
"TOTAL",
total_line_match.true_positives,
total_line_match.false_positives,
total_line_match.false_negatives,
total_line_match.precision() * 100.0,
total_line_match.recall() * 100.0,
total_line_match.f1_score() * 100.0,
avg_delta_chr_f
);
eprintln!("{:<50} {:>9.2}", "AVERAGE", avg_delta_chr_f);
eprintln!(
"──────────────────────────────────────────────────────────────────────────────────────"
);