Compare commits

...

2 Commits

Author SHA1 Message Date
Max Brunsfeld
ccd2502672 Start work on allowing alternative patches 2025-11-19 17:42:35 -08:00
Max Brunsfeld
b2ba012251 eval: Make zeta2 and old-text-new-text the defaults 2025-11-19 17:41:20 -08:00
3 changed files with 157 additions and 102 deletions

View File

@@ -237,7 +237,7 @@ fn write_eval_result(
out,
"## Expected edit prediction:\n\n```diff\n{}\n```\n",
compare_diffs(
&example.example.expected_patch,
&example.example.expected_patches,
&predictions.diff,
use_color
)
@@ -247,7 +247,7 @@ fn write_eval_result(
"## Actual edit prediction:\n\n```diff\n{}\n```\n",
compare_diffs(
&predictions.diff,
&example.example.expected_patch,
&example.example.expected_patches,
use_color
)
)?;
@@ -434,101 +434,103 @@ fn evaluate(example: &Example, preds: &PredictionDetails, predict: bool) -> Eval
..Default::default()
};
let actual_context_lines: HashSet<_> = preds
.excerpts
.iter()
.flat_map(|excerpt| {
excerpt
.text
.lines()
.map(|line| format!("{}: {line}", excerpt.path.display()))
})
.collect();
let mut false_positive_lines = actual_context_lines.clone();
for entry in &example.expected_context {
let mut best_alternative_score: Option<Scores> = None;
for alternative in &entry.alternatives {
let expected: HashSet<_> = alternative
.excerpts
.iter()
.flat_map(|excerpt| {
excerpt
.text
.lines()
.map(|line| format!("{}: {line}", excerpt.path.display()))
})
.collect();
let scores = Scores::new(&expected, &actual_context_lines);
false_positive_lines.retain(|line| !expected.contains(line));
if best_alternative_score
.as_ref()
.is_none_or(|best| scores.recall() > best.recall())
{
best_alternative_score = Some(scores);
}
}
let best_alternative = best_alternative_score.unwrap_or_default();
eval_result.context.false_negatives += best_alternative.false_negatives;
eval_result.context.true_positives += best_alternative.true_positives;
}
eval_result.context.false_positives = false_positive_lines.len();
if predict {
// todo: alternatives for patches
let expected_patch = example
.expected_patch
.lines()
.map(DiffLine::parse)
.collect::<Vec<_>>();
let expected_patch_lines = expected_patch
.iter()
.filter(|line| matches!(line, DiffLine::Addition(_) | DiffLine::Deletion(_)))
.map(|line| line.to_string())
.collect();
let expected_context_lines = expected_patch
.iter()
.filter_map(|line| {
if let DiffLine::Context(str) = line {
Some(String::from(*str))
} else {
None
}
})
.collect::<BTreeSet<_>>();
let actual_context_lines = preds
// Context score
{
let actual_context_lines: HashSet<_> = preds
.excerpts
.iter()
.flat_map(|excerpt| excerpt.text.lines().map(ToOwned::to_owned))
.collect::<BTreeSet<_>>();
let matched = expected_context_lines
.intersection(&actual_context_lines)
.count();
let actual_patch_lines = preds
.diff
.lines()
.map(DiffLine::parse)
.filter(|line| matches!(line, DiffLine::Addition(_) | DiffLine::Deletion(_)))
.map(|line| line.to_string())
.flat_map(|excerpt| {
excerpt
.text
.lines()
.map(|line| format!("{}: {line}", excerpt.path.display()))
})
.collect();
eval_result.edit_prediction = Some(Scores::new(&expected_patch_lines, &actual_patch_lines));
eval_result.context_lines_in_expected_patch = expected_context_lines.len();
eval_result.context_lines_found_in_context = matched;
let mut false_positive_context_lines = actual_context_lines.clone();
for entry in &example.expected_context {
let mut best_alternative_score: Option<Scores> = None;
for alternative in &entry.alternatives {
let expected_context_lines: HashSet<_> = alternative
.excerpts
.iter()
.flat_map(|excerpt| {
excerpt
.text
.lines()
.map(|line| format!("{}: {line}", excerpt.path.display()))
})
.collect();
let scores = Scores::new(&expected_context_lines, &actual_context_lines);
false_positive_context_lines.retain(|line| !expected_context_lines.contains(line));
if best_alternative_score
.as_ref()
.is_none_or(|best| scores.recall() > best.recall())
{
best_alternative_score = Some(scores);
}
}
let best_alternative = best_alternative_score.unwrap_or_default();
eval_result.context.false_negatives += best_alternative.false_negatives;
eval_result.context.true_positives += best_alternative.true_positives;
}
eval_result.context.false_positives = false_positive_context_lines.len();
}
// Patch score
if predict {
let mut prediction_scores = Scores::default();
let actual_patch_lines = diff_lines(&preds.diff);
let mut false_positive_patch_lines = actual_patch_lines.clone();
for entry in &example.expected_patches {
let mut best_alternative_score: Option<Scores> = None;
for alternative in &entry.alternatives {
let expected_patch_lines = diff_lines(&alternative.patch);
let scores = Scores::new(&expected_patch_lines, &actual_patch_lines);
false_positive_patch_lines.retain(|line| !expected_patch_lines.contains(line));
if best_alternative_score
.as_ref()
.is_none_or(|best| scores.recall() > best.recall())
{
best_alternative_score = Some(scores);
}
}
let best_alternative = best_alternative_score.unwrap_or_default();
prediction_scores.false_negatives += best_alternative.false_negatives;
prediction_scores.true_positives += best_alternative.true_positives;
}
prediction_scores.false_positives = false_positive_patch_lines.len();
eval_result.edit_prediction = Some(prediction_scores);
// eval_result.context_lines_in_expected_patch = expected_context_lines.len();
// eval_result.context_lines_found_in_context = matched;
}
eval_result
}
fn diff_lines(diff: &str) -> HashSet<String> {
diff.lines()
.map(DiffLine::parse)
.filter(|line| matches!(line, DiffLine::Addition(_) | DiffLine::Deletion(_)))
.map(|line| line.to_string())
.collect()
}
/// Return annotated `patch_a` so that:
/// Additions and deletions that are not present in `patch_b` will be highlighted in red.
/// Additions and deletions that are present in `patch_b` will be highlighted in green.

View File

@@ -52,7 +52,7 @@ pub struct Example {
pub cursor_path: PathBuf,
pub cursor_position: String,
pub edit_history: String,
pub expected_patch: String,
pub expected_patches: Vec<ExpectedPatchEntry>,
pub expected_context: Vec<ExpectedContextEntry>,
}
@@ -64,6 +64,18 @@ pub struct Excerpt {
pub text: String,
}
#[derive(Default, Clone, Debug, Serialize, Deserialize)]
pub struct ExpectedPatchEntry {
pub heading: String,
pub alternatives: Vec<ExpectedPatch>,
}
#[derive(Default, Clone, Debug, Serialize, Deserialize)]
pub struct ExpectedPatch {
pub heading: String,
pub patch: String,
}
#[derive(Default, Clone, Debug, Serialize, Deserialize)]
pub struct ExpectedContextEntry {
pub heading: String,
@@ -131,7 +143,7 @@ impl NamedExample {
cursor_path: PathBuf::new(),
cursor_position: String::new(),
edit_history: String::new(),
expected_patch: String::new(),
expected_patches: Vec::new(),
expected_context: Vec::new(),
},
};
@@ -205,6 +217,12 @@ impl NamedExample {
alternatives: Vec::new(),
});
}
Section::ExpectedPatch => {
named.example.expected_patches.push(ExpectedPatchEntry {
heading,
alternatives: Vec::new(),
});
}
_ => {}
}
}
@@ -219,6 +237,14 @@ impl NamedExample {
excerpts: Vec::new(),
})
}
Section::ExpectedPatch => {
let expected_patch = &mut named.example.expected_patches;
let last_entry = expected_patch.last_mut().unwrap();
last_entry.alternatives.push(ExpectedPatch {
heading,
patch: String::new(),
})
}
_ => {}
}
}
@@ -290,7 +316,25 @@ impl NamedExample {
}
}
Section::ExpectedPatch => {
named.example.expected_patch = mem::take(&mut text);
let patch = mem::take(&mut text);
if named.example.expected_patches.is_empty() {
named.example.expected_patches.push(Default::default());
}
let alternatives = &mut named
.example
.expected_patches
.last_mut()
.unwrap()
.alternatives;
if alternatives.is_empty() {
alternatives.push(ExpectedPatch {
heading: String::new(),
patch,
});
}
}
Section::Other => {}
}
@@ -648,14 +692,6 @@ impl Display for NamedExample {
)?;
write!(f, "## {EDIT_HISTORY_HEADING}\n\n")?;
if !self.example.expected_patch.is_empty() {
write!(
f,
"\n## {EXPECTED_PATCH_HEADING}\n\n`````diff\n{}`````\n",
self.example.expected_patch
)?;
}
if !self.example.expected_context.is_empty() {
write!(f, "\n## {EXPECTED_CONTEXT_HEADING}\n\n")?;
@@ -687,6 +723,23 @@ impl Display for NamedExample {
}
}
if !self.example.expected_patches.is_empty() {
for entry in &self.example.expected_patches {
write!(f, "\n### {}\n\n", entry.heading)?;
let skip_h4 =
entry.alternatives.len() == 1 && entry.alternatives[0].heading.is_empty();
for patch in &entry.alternatives {
if !skip_h4 {
write!(f, "\n#### {}\n\n", patch.heading)?;
}
write!(f, "`````diff\n{}`````\n", patch.patch)?;
}
}
}
Ok(())
}
}

View File

@@ -72,7 +72,7 @@ struct ContextStatsArgs {
#[derive(Debug, Args)]
struct ContextArgs {
#[arg(long)]
#[arg(long, value_enum, default_value_t = Default::default())]
provider: ContextProvider,
#[arg(long)]
worktree: PathBuf,
@@ -132,7 +132,7 @@ pub struct PredictionOptions {
use_expected_context: bool,
#[clap(flatten)]
zeta2: Zeta2Args,
#[clap(long)]
#[clap(long, value_enum, default_value_t = Default::default())]
provider: PredictionProvider,
#[clap(long, value_enum, default_value_t = CacheMode::default())]
cache: CacheMode,
@@ -225,8 +225,8 @@ enum PromptFormat {
MarkedExcerpt,
LabeledSections,
OnlySnippets,
#[default]
NumberedLines,
#[default]
OldTextNewText,
Minimal,
MinimalQwen,