Merge branch 'main' into 42286-unsafe-repository

This commit is contained in:
dino
2025-12-12 18:41:13 +00:00
9 changed files with 470 additions and 340 deletions

View File

@@ -1,14 +1,22 @@
use anyhow::{Result, anyhow};
use std::mem;
use crate::example::Example;
pub async fn run_distill(example: &mut Example) {
let [prediction]: [_; 1] = mem::take(&mut example.predictions)
.try_into()
.expect("Run predict first with a single repetition");
pub async fn run_distill(example: &mut Example) -> Result<()> {
let [prediction]: [_; 1] =
mem::take(&mut example.predictions)
.try_into()
.map_err(|preds: Vec<_>| {
anyhow!(
"Example has {} predictions, but it should have exactly one",
preds.len()
)
})?;
example.expected_patch = prediction.actual_patch;
example.prompt = None;
example.predictions = Vec::new();
example.score = Vec::new();
Ok(())
}

View File

@@ -6,6 +6,7 @@ use crate::{
progress::{Progress, Step},
retrieve_context::run_context_retrieval,
};
use anyhow::{Context as _, Result, ensure};
use edit_prediction::{
EditPredictionStore,
zeta2::{zeta2_output_for_patch, zeta2_prompt_input},
@@ -19,8 +20,8 @@ pub async fn run_format_prompt(
prompt_format: PromptFormat,
app_state: Arc<EpAppState>,
mut cx: AsyncApp,
) {
run_context_retrieval(example, app_state.clone(), cx.clone()).await;
) -> Result<()> {
run_context_retrieval(example, app_state.clone(), cx.clone()).await?;
let _step_progress = Progress::global().start(Step::FormatPrompt, &example.name);
@@ -34,29 +35,33 @@ pub async fn run_format_prompt(
});
}
PromptFormat::Zeta2 => {
run_load_project(example, app_state, cx.clone()).await;
run_load_project(example, app_state, cx.clone()).await?;
let ep_store = cx
.update(|cx| EditPredictionStore::try_global(cx).unwrap())
.unwrap();
let ep_store = cx.update(|cx| {
EditPredictionStore::try_global(cx).context("EditPredictionStore not initialized")
})??;
let state = example.state.as_ref().unwrap();
let snapshot = state
.buffer
.read_with(&cx, |buffer, _| buffer.snapshot())
.unwrap();
let state = example.state.as_ref().context("state must be set")?;
let snapshot = state.buffer.read_with(&cx, |buffer, _| buffer.snapshot())?;
let project = state.project.clone();
let (_, input) = ep_store
.update(&mut cx, |ep_store, _cx| {
zeta2_prompt_input(
&snapshot,
example.context.as_ref().unwrap().files.clone(),
ep_store.edit_history_for_project(&project),
example.cursor_path.clone(),
example.buffer.as_ref().unwrap().cursor_offset,
)
})
.unwrap();
let (_, input) = ep_store.update(&mut cx, |ep_store, _cx| {
anyhow::Ok(zeta2_prompt_input(
&snapshot,
example
.context
.as_ref()
.context("context must be set")?
.files
.clone(),
ep_store.edit_history_for_project(&project),
example.cursor_path.clone(),
example
.buffer
.as_ref()
.context("buffer must be set")?
.cursor_offset,
))
})??;
let prompt = format_zeta_prompt(&input);
let expected_output = zeta2_output_for_patch(&input, &example.expected_patch.clone());
example.prompt = Some(ExamplePrompt {
@@ -66,6 +71,7 @@ pub async fn run_format_prompt(
});
}
};
Ok(())
}
pub struct TeacherPrompt;
@@ -91,7 +97,7 @@ impl TeacherPrompt {
prompt
}
pub fn parse(example: &Example, response: &str) -> String {
pub fn parse(example: &Example, response: &str) -> Result<String> {
// Ideally, we should always be able to find cursor position in the retrieved context.
// In reality, sometimes we don't find it for these reasons:
// 1. `example.cursor_position` contains _more_ context than included in the retrieved context
@@ -102,7 +108,7 @@ impl TeacherPrompt {
let cursor_file = &example
.buffer
.as_ref()
.expect("`buffer` should be filled in in the context collection step")
.context("`buffer` should be filled in in the context collection step")?
.content;
// Extract updated (new) editable region from the model response
@@ -111,9 +117,10 @@ impl TeacherPrompt {
// Reconstruct old editable region we sent to the model
let old_editable_region = Self::format_editable_region(example);
let old_editable_region = Self::extract_editable_region(&old_editable_region);
if !cursor_file.contains(&old_editable_region) {
panic!("Something's wrong: editable_region is not found in the cursor file")
}
ensure!(
cursor_file.contains(&old_editable_region),
"Something's wrong: editable_region is not found in the cursor file"
);
// Apply editable region to a larger context and compute diff.
// This is needed to get a better context lines around the editable region
@@ -128,7 +135,7 @@ impl TeacherPrompt {
diff = diff,
};
diff
Ok(diff)
}
fn format_edit_history(edit_history: &str) -> String {
@@ -152,9 +159,7 @@ impl TeacherPrompt {
}
fn format_context(example: &Example) -> String {
if example.context.is_none() {
panic!("Missing context retriever step");
}
assert!(example.context.is_some(), "Missing context retriever step");
let mut prompt = String::new();
zeta_prompt::write_related_files(&mut prompt, &example.context.as_ref().unwrap().files);

View File

@@ -4,7 +4,7 @@ use crate::{
paths::{REPOS_DIR, WORKTREES_DIR},
progress::{InfoStyle, Progress, Step, StepProgress},
};
use anyhow::{Result, anyhow};
use anyhow::{Context as _, Result};
use collections::HashMap;
use edit_prediction::EditPredictionStore;
use edit_prediction::udiff::OpenedBuffers;
@@ -25,38 +25,38 @@ use std::{
use util::{paths::PathStyle, rel_path::RelPath};
use zeta_prompt::CURSOR_MARKER;
pub async fn run_load_project(example: &mut Example, app_state: Arc<EpAppState>, mut cx: AsyncApp) {
pub async fn run_load_project(
example: &mut Example,
app_state: Arc<EpAppState>,
mut cx: AsyncApp,
) -> Result<()> {
if example.state.is_some() {
return;
return Ok(());
}
let progress = Progress::global().start(Step::LoadProject, &example.name);
let project = setup_project(example, &app_state, &progress, &mut cx).await;
let project = setup_project(example, &app_state, &progress, &mut cx).await?;
let _open_buffers = apply_edit_history(example, &project, &mut cx)
.await
.unwrap();
let _open_buffers = apply_edit_history(example, &project, &mut cx).await?;
let (buffer, cursor_position) = cursor_position(example, &project, &mut cx).await;
let (example_buffer, language_name) = buffer
.read_with(&cx, |buffer, _cx| {
let cursor_point = cursor_position.to_point(&buffer);
let language_name = buffer
.language()
.map(|l| l.name().to_string())
.unwrap_or_else(|| "Unknown".to_string());
(
ExampleBuffer {
content: buffer.text(),
cursor_row: cursor_point.row,
cursor_column: cursor_point.column,
cursor_offset: cursor_position.to_offset(&buffer),
},
language_name,
)
})
.unwrap();
let (buffer, cursor_position) = cursor_position(example, &project, &mut cx).await?;
let (example_buffer, language_name) = buffer.read_with(&cx, |buffer, _cx| {
let cursor_point = cursor_position.to_point(&buffer);
let language_name = buffer
.language()
.map(|l| l.name().to_string())
.unwrap_or_else(|| "Unknown".to_string());
(
ExampleBuffer {
content: buffer.text(),
cursor_row: cursor_point.row,
cursor_column: cursor_point.column,
cursor_offset: cursor_position.to_offset(&buffer),
},
language_name,
)
})?;
progress.set_info(language_name, InfoStyle::Normal);
@@ -67,16 +67,15 @@ pub async fn run_load_project(example: &mut Example, app_state: Arc<EpAppState>,
cursor_position,
_open_buffers,
});
Ok(())
}
async fn cursor_position(
example: &Example,
project: &Entity<Project>,
cx: &mut AsyncApp,
) -> (Entity<Buffer>, Anchor) {
let language_registry = project
.read_with(cx, |project, _| project.languages().clone())
.unwrap();
) -> Result<(Entity<Buffer>, Anchor)> {
let language_registry = project.read_with(cx, |project, _| project.languages().clone())?;
let result = language_registry
.load_language_for_file_path(&example.cursor_path)
.await;
@@ -84,17 +83,18 @@ async fn cursor_position(
if let Err(error) = result
&& !error.is::<LanguageNotFound>()
{
panic!("Failed to load language for file path: {}", error);
return Err(error);
}
let worktree = project
.read_with(cx, |project, cx| {
project.visible_worktrees(cx).next().unwrap()
})
.unwrap();
let worktree = project.read_with(cx, |project, cx| {
project
.visible_worktrees(cx)
.next()
.context("No visible worktrees")
})??;
let cursor_path = RelPath::new(&example.cursor_path, PathStyle::Posix)
.unwrap()
.context("Failed to create RelPath")?
.into_arc();
let cursor_buffer = project
.update(cx, |project, cx| {
@@ -105,15 +105,12 @@ async fn cursor_position(
},
cx,
)
})
.unwrap()
.await
.unwrap();
})?
.await?;
let cursor_offset_within_excerpt = example
.cursor_position
.find(CURSOR_MARKER)
.ok_or_else(|| anyhow!("missing cursor marker"))
.unwrap();
.context("missing cursor marker")?;
let mut cursor_excerpt = example.cursor_position.clone();
cursor_excerpt.replace_range(
cursor_offset_within_excerpt..(cursor_offset_within_excerpt + CURSOR_MARKER.len()),
@@ -123,22 +120,21 @@ async fn cursor_position(
let text = buffer.text();
let mut matches = text.match_indices(&cursor_excerpt);
let (excerpt_offset, _) = matches.next().unwrap_or_else(|| {
panic!(
let (excerpt_offset, _) = matches.next().with_context(|| {
format!(
"\nExcerpt:\n\n{cursor_excerpt}\nBuffer text:\n{text}\n.Example: {}\nCursor excerpt did not exist in buffer.",
example.name
);
});
assert!(matches.next().is_none(), "More than one cursor position match found for {}", &example.name);
excerpt_offset
}).unwrap();
)
})?;
anyhow::ensure!(matches.next().is_none(), "More than one cursor position match found for {}", &example.name);
Ok(excerpt_offset)
})??;
let cursor_offset = excerpt_offset + cursor_offset_within_excerpt;
let cursor_anchor = cursor_buffer
.read_with(cx, |buffer, _| buffer.anchor_after(cursor_offset))
.unwrap();
let cursor_anchor =
cursor_buffer.read_with(cx, |buffer, _| buffer.anchor_after(cursor_offset))?;
(cursor_buffer, cursor_anchor)
Ok((cursor_buffer, cursor_anchor))
}
async fn setup_project(
@@ -146,67 +142,54 @@ async fn setup_project(
app_state: &Arc<EpAppState>,
step_progress: &StepProgress,
cx: &mut AsyncApp,
) -> Entity<Project> {
) -> Result<Entity<Project>> {
let ep_store = cx
.update(|cx| EditPredictionStore::try_global(cx).unwrap())
.unwrap();
.update(|cx| EditPredictionStore::try_global(cx))?
.context("Store should be initialized at init")?;
let worktree_path = setup_worktree(example, step_progress).await;
let worktree_path = setup_worktree(example, step_progress).await?;
if let Some(project) = app_state.project_cache.get(&example.repository_url) {
ep_store
.update(cx, |ep_store, _| {
ep_store.clear_history_for_project(&project);
})
.unwrap();
let buffer_store = project
.read_with(cx, |project, _| project.buffer_store().clone())
.unwrap();
let buffers = buffer_store
.read_with(cx, |buffer_store, _| {
buffer_store.buffers().collect::<Vec<_>>()
})
.unwrap();
ep_store.update(cx, |ep_store, _| {
ep_store.clear_history_for_project(&project);
})?;
let buffer_store = project.read_with(cx, |project, _| project.buffer_store().clone())?;
let buffers = buffer_store.read_with(cx, |buffer_store, _| {
buffer_store.buffers().collect::<Vec<_>>()
})?;
for buffer in buffers {
buffer
.update(cx, |buffer, cx| buffer.reload(cx))
.unwrap()
.update(cx, |buffer, cx| buffer.reload(cx))?
.await
.ok();
}
return project;
return Ok(project);
}
let project = cx
.update(|cx| {
Project::local(
app_state.client.clone(),
app_state.node_runtime.clone(),
app_state.user_store.clone(),
app_state.languages.clone(),
app_state.fs.clone(),
None,
cx,
)
})
.unwrap();
let project = cx.update(|cx| {
Project::local(
app_state.client.clone(),
app_state.node_runtime.clone(),
app_state.user_store.clone(),
app_state.languages.clone(),
app_state.fs.clone(),
None,
cx,
)
})?;
project
.update(cx, |project, cx| {
project.disable_worktree_scanner(cx);
project.create_worktree(&worktree_path, true, cx)
})
.unwrap()
.await
.unwrap();
})?
.await?;
app_state
.project_cache
.insert(example.repository_url.clone(), project.clone());
let buffer_store = project
.read_with(cx, |project, _| project.buffer_store().clone())
.unwrap();
let buffer_store = project.read_with(cx, |project, _| project.buffer_store().clone())?;
cx.subscribe(&buffer_store, {
let project = project.clone();
move |_, event, cx| match event {
@@ -215,15 +198,14 @@ async fn setup_project(
}
_ => {}
}
})
.unwrap()
})?
.detach();
project
Ok(project)
}
async fn setup_worktree(example: &Example, step_progress: &StepProgress) -> PathBuf {
let (repo_owner, repo_name) = example.repo_name().expect("failed to get repo name");
async fn setup_worktree(example: &Example, step_progress: &StepProgress) -> Result<PathBuf> {
let (repo_owner, repo_name) = example.repo_name().context("failed to get repo name")?;
let repo_dir = REPOS_DIR.join(repo_owner.as_ref()).join(repo_name.as_ref());
let worktree_path = WORKTREES_DIR
.join(repo_owner.as_ref())
@@ -232,14 +214,13 @@ async fn setup_worktree(example: &Example, step_progress: &StepProgress) -> Path
if !repo_dir.is_dir() {
step_progress.set_substatus(format!("cloning {}", repo_name));
fs::create_dir_all(&repo_dir).unwrap();
run_git(&repo_dir, &["init"]).await.unwrap();
fs::create_dir_all(&repo_dir)?;
run_git(&repo_dir, &["init"]).await?;
run_git(
&repo_dir,
&["remote", "add", "origin", &example.repository_url],
)
.await
.unwrap();
.await?;
}
// Resolve the example to a revision, fetching it if needed.
@@ -259,34 +240,25 @@ async fn setup_worktree(example: &Example, step_progress: &StepProgress) -> Path
.await
.is_err()
{
run_git(&repo_dir, &["fetch", "origin"]).await.unwrap();
run_git(&repo_dir, &["fetch", "origin"]).await?;
}
let revision = run_git(&repo_dir, &["rev-parse", "FETCH_HEAD"])
.await
.unwrap();
let revision = run_git(&repo_dir, &["rev-parse", "FETCH_HEAD"]).await?;
revision
};
// Create the worktree for this example if needed.
step_progress.set_substatus("preparing worktree");
if worktree_path.is_dir() {
run_git(&worktree_path, &["clean", "--force", "-d"])
.await
.unwrap();
run_git(&worktree_path, &["reset", "--hard", "HEAD"])
.await
.unwrap();
run_git(&worktree_path, &["checkout", revision.as_str()])
.await
.unwrap();
run_git(&worktree_path, &["clean", "--force", "-d"]).await?;
run_git(&worktree_path, &["reset", "--hard", "HEAD"]).await?;
run_git(&worktree_path, &["checkout", revision.as_str()]).await?;
} else {
let worktree_path_string = worktree_path.to_string_lossy();
run_git(
&repo_dir,
&["branch", "-f", &example.name, revision.as_str()],
)
.await
.unwrap();
.await?;
run_git(
&repo_dir,
&[
@@ -297,8 +269,7 @@ async fn setup_worktree(example: &Example, step_progress: &StepProgress) -> Path
&example.name,
],
)
.await
.unwrap();
.await?;
}
drop(repo_lock);
@@ -309,30 +280,25 @@ async fn setup_worktree(example: &Example, step_progress: &StepProgress) -> Path
.current_dir(&worktree_path)
.args(&["apply", "-"])
.stdin(std::process::Stdio::piped())
.spawn()
.unwrap();
.spawn()?;
let mut stdin = apply_process.stdin.take().unwrap();
stdin
.write_all(example.uncommitted_diff.as_bytes())
.await
.unwrap();
stdin.close().await.unwrap();
let mut stdin = apply_process.stdin.take().context("Failed to get stdin")?;
stdin.write_all(example.uncommitted_diff.as_bytes()).await?;
stdin.close().await?;
drop(stdin);
let apply_result = apply_process.output().await.unwrap();
if !apply_result.status.success() {
panic!(
"Failed to apply uncommitted diff patch with status: {}\nstderr:\n{}\nstdout:\n{}",
apply_result.status,
String::from_utf8_lossy(&apply_result.stderr),
String::from_utf8_lossy(&apply_result.stdout),
);
}
let apply_result = apply_process.output().await?;
anyhow::ensure!(
apply_result.status.success(),
"Failed to apply uncommitted diff patch with status: {}\nstderr:\n{}\nstdout:\n{}",
apply_result.status,
String::from_utf8_lossy(&apply_result.stderr),
String::from_utf8_lossy(&apply_result.stdout),
);
}
step_progress.clear_substatus();
worktree_path
Ok(worktree_path)
}
async fn apply_edit_history(

View File

@@ -16,12 +16,14 @@ use edit_prediction::EditPredictionStore;
use gpui::Application;
use reqwest_client::ReqwestClient;
use serde::{Deserialize, Serialize};
use std::fmt::Display;
use std::{path::PathBuf, sync::Arc};
use crate::distill::run_distill;
use crate::example::{group_examples_by_repo, read_examples, write_examples};
use crate::format_prompt::run_format_prompt;
use crate::load_project::run_load_project;
use crate::paths::FAILED_EXAMPLES_DIR;
use crate::predict::run_prediction;
use crate::progress::Progress;
use crate::retrieve_context::run_context_retrieval;
@@ -42,6 +44,8 @@ struct EpArgs {
output: Option<PathBuf>,
#[arg(long, short, global = true)]
in_place: bool,
#[arg(long, short, global = true)]
failfast: bool,
}
#[derive(Subcommand, Debug)]
@@ -67,6 +71,58 @@ enum Command {
Clean,
}
impl Display for Command {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Command::ParseExample => write!(f, "parse-example"),
Command::LoadProject => write!(f, "load-project"),
Command::Context => write!(f, "context"),
Command::FormatPrompt(format_prompt_args) => write!(
f,
"format-prompt --prompt-format={}",
format_prompt_args
.prompt_format
.to_possible_value()
.unwrap()
.get_name()
),
Command::Predict(predict_args) => {
write!(
f,
"predict --provider={:?}",
predict_args
.provider
.to_possible_value()
.unwrap()
.get_name()
)
}
Command::Score(predict_args) => {
write!(
f,
"score --provider={:?}",
predict_args
.provider
.to_possible_value()
.unwrap()
.get_name()
)
}
Command::Distill => write!(f, "distill"),
Command::Eval(predict_args) => write!(
f,
"eval --provider={:?}",
predict_args
.provider
.to_possible_value()
.unwrap()
.get_name()
),
Command::Clean => write!(f, "clean"),
}
}
}
#[derive(Debug, Args)]
struct FormatPromptArgs {
#[clap(long)]
@@ -145,71 +201,140 @@ fn main() {
EditPredictionStore::global(&app_state.client, &app_state.user_store, cx);
cx.spawn(async move |cx| {
if let Command::Predict(args) = &command {
predict::sync_batches(&args.provider).await
};
let result = async {
if let Command::Predict(args) = &command {
predict::sync_batches(&args.provider).await?;
}
let total_examples = examples.len();
Progress::global().set_total_examples(total_examples);
let total_examples = examples.len();
Progress::global().set_total_examples(total_examples);
let mut grouped_examples = group_examples_by_repo(&mut examples);
let example_batches = grouped_examples.chunks_mut(args.max_parallelism);
let mut grouped_examples = group_examples_by_repo(&mut examples);
let example_batches = grouped_examples.chunks_mut(args.max_parallelism);
for example_batch in example_batches {
let futures = example_batch.into_iter().map(|repo_examples| async {
for example in repo_examples.iter_mut() {
match &command {
Command::ParseExample => {}
Command::LoadProject => {
run_load_project(example, app_state.clone(), cx.clone()).await;
for example_batch in example_batches {
let futures = example_batch.into_iter().map(|repo_examples| async {
for example in repo_examples.iter_mut() {
let result = async {
match &command {
Command::ParseExample => {}
Command::LoadProject => {
run_load_project(example, app_state.clone(), cx.clone())
.await?;
}
Command::Context => {
run_context_retrieval(
example,
app_state.clone(),
cx.clone(),
)
.await?;
}
Command::FormatPrompt(args) => {
run_format_prompt(
example,
args.prompt_format,
app_state.clone(),
cx.clone(),
)
.await?;
}
Command::Predict(args) => {
run_prediction(
example,
Some(args.provider),
args.repetitions,
app_state.clone(),
cx.clone(),
)
.await?;
}
Command::Distill => {
run_distill(example).await?;
}
Command::Score(args) | Command::Eval(args) => {
run_scoring(example, &args, app_state.clone(), cx.clone())
.await?;
}
Command::Clean => {
unreachable!()
}
}
anyhow::Ok(())
}
Command::Context => {
run_context_retrieval(example, app_state.clone(), cx.clone()).await;
}
Command::FormatPrompt(args) => {
run_format_prompt(
example,
args.prompt_format,
app_state.clone(),
cx.clone(),
)
.await;
}
Command::Predict(args) => {
run_prediction(
example,
Some(args.provider),
args.repetitions,
app_state.clone(),
cx.clone(),
)
.await;
}
Command::Distill => {
run_distill(example).await;
}
Command::Score(args) | Command::Eval(args) => {
run_scoring(example, &args, app_state.clone(), cx.clone()).await;
}
Command::Clean => {
unreachable!()
.await;
if let Err(e) = result {
Progress::global().increment_failed();
let failed_example_path =
FAILED_EXAMPLES_DIR.join(format!("{}.json", example.name));
app_state
.fs
.write(
&failed_example_path,
&serde_json::to_vec_pretty(&example).unwrap(),
)
.await
.unwrap();
let err_path =
FAILED_EXAMPLES_DIR.join(format!("{}_err.txt", example.name));
app_state
.fs
.write(&err_path, e.to_string().as_bytes())
.await
.unwrap();
let msg = format!(
indoc::indoc! {"
While processing {}:
{:?}
Written to: \x1b[36m{}\x1b[0m
Explore this example data with:
fx \x1b[36m{}\x1b[0m
Re-run this example with:
cargo run -p edit_prediction_cli -- {} \x1b[36m{}\x1b[0m
"},
example.name,
e,
err_path.display(),
failed_example_path.display(),
command,
failed_example_path.display(),
);
if args.failfast || total_examples == 1 {
Progress::global().finalize();
panic!("{}", msg);
} else {
log::error!("{}", msg);
}
}
}
}
});
futures::future::join_all(futures).await;
}
Progress::global().clear();
});
futures::future::join_all(futures).await;
}
Progress::global().finalize();
if args.output.is_some() || !matches!(command, Command::Eval(_)) {
write_examples(&examples, output.as_ref());
}
if args.output.is_some() || !matches!(command, Command::Eval(_)) {
write_examples(&examples, output.as_ref());
}
match &command {
Command::Predict(args) => predict::sync_batches(&args.provider).await,
Command::Eval(_) => score::print_report(&examples),
_ => (),
};
match &command {
Command::Predict(args) => predict::sync_batches(&args.provider).await?,
Command::Eval(_) => score::print_report(&examples),
_ => (),
};
anyhow::Ok(())
}
.await;
if let Err(e) = result {
panic!("Fatal error: {:?}", e);
}
let _ = cx.update(|cx| cx.quit());
})

View File

@@ -18,6 +18,8 @@ pub static RUN_DIR: LazyLock<PathBuf> = LazyLock::new(|| {
});
pub static LATEST_EXAMPLE_RUN_DIR: LazyLock<PathBuf> = LazyLock::new(|| DATA_DIR.join("latest"));
pub static LLM_CACHE_DB: LazyLock<PathBuf> = LazyLock::new(|| CACHE_DIR.join("llm_cache.sqlite"));
pub static FAILED_EXAMPLES_DIR: LazyLock<PathBuf> =
LazyLock::new(|| ensure_dir(&RUN_DIR.join("failed")));
fn ensure_dir(path: &Path) -> PathBuf {
std::fs::create_dir_all(path).expect("Failed to create directory");

View File

@@ -9,6 +9,7 @@ use crate::{
progress::{InfoStyle, Progress, Step},
retrieve_context::run_context_retrieval,
};
use anyhow::Context as _;
use edit_prediction::{DebugEvent, EditPredictionStore};
use futures::{FutureExt as _, StreamExt as _, future::Shared};
use gpui::{AppContext as _, AsyncApp, Task};
@@ -26,14 +27,14 @@ pub async fn run_prediction(
repetition_count: usize,
app_state: Arc<EpAppState>,
mut cx: AsyncApp,
) {
) -> anyhow::Result<()> {
if !example.predictions.is_empty() {
return;
return Ok(());
}
let provider = provider.unwrap();
let provider = provider.context("provider is required")?;
run_context_retrieval(example, app_state.clone(), cx.clone()).await;
run_context_retrieval(example, app_state.clone(), cx.clone()).await?;
if matches!(
provider,
@@ -42,14 +43,14 @@ pub async fn run_prediction(
let _step_progress = Progress::global().start(Step::Predict, &example.name);
if example.prompt.is_none() {
run_format_prompt(example, PromptFormat::Teacher, app_state.clone(), cx).await;
run_format_prompt(example, PromptFormat::Teacher, app_state.clone(), cx).await?;
}
let batched = matches!(provider, PredictionProvider::Teacher);
return predict_anthropic(example, repetition_count, batched).await;
}
run_load_project(example, app_state.clone(), cx.clone()).await;
run_load_project(example, app_state.clone(), cx.clone()).await?;
let _step_progress = Progress::global().start(Step::Predict, &example.name);
@@ -62,10 +63,9 @@ pub async fn run_prediction(
.get_or_init(|| {
let client = app_state.client.clone();
cx.spawn(async move |cx| {
client
.sign_in_with_optional_connect(true, cx)
.await
.unwrap();
if let Err(e) = client.sign_in_with_optional_connect(true, cx).await {
eprintln!("Authentication failed: {}", e);
}
})
.shared()
})
@@ -73,33 +73,30 @@ pub async fn run_prediction(
.await;
}
let ep_store = cx
.update(|cx| EditPredictionStore::try_global(cx).unwrap())
.unwrap();
let ep_store = cx.update(|cx| {
EditPredictionStore::try_global(cx).context("EditPredictionStore not initialized")
})??;
ep_store
.update(&mut cx, |store, _cx| {
let model = match provider {
PredictionProvider::Zeta1 => edit_prediction::EditPredictionModel::Zeta1,
PredictionProvider::Zeta2 => edit_prediction::EditPredictionModel::Zeta2,
PredictionProvider::Sweep => edit_prediction::EditPredictionModel::Sweep,
PredictionProvider::Mercury => edit_prediction::EditPredictionModel::Mercury,
PredictionProvider::Teacher | PredictionProvider::TeacherNonBatching => {
unreachable!()
}
};
store.set_edit_prediction_model(model);
})
.unwrap();
let state = example.state.as_ref().unwrap();
ep_store.update(&mut cx, |store, _cx| {
let model = match provider {
PredictionProvider::Zeta1 => edit_prediction::EditPredictionModel::Zeta1,
PredictionProvider::Zeta2 => edit_prediction::EditPredictionModel::Zeta2,
PredictionProvider::Sweep => edit_prediction::EditPredictionModel::Sweep,
PredictionProvider::Mercury => edit_prediction::EditPredictionModel::Mercury,
PredictionProvider::Teacher | PredictionProvider::TeacherNonBatching => {
unreachable!()
}
};
store.set_edit_prediction_model(model);
})?;
let state = example.state.as_ref().context("state must be set")?;
let run_dir = RUN_DIR.join(&example.name);
let updated_example = Arc::new(Mutex::new(example.clone()));
let current_run_ix = Arc::new(AtomicUsize::new(0));
let mut debug_rx = ep_store
.update(&mut cx, |store, cx| store.debug_info(&state.project, cx))
.unwrap();
let mut debug_rx =
ep_store.update(&mut cx, |store, cx| store.debug_info(&state.project, cx))?;
let debug_task = cx.background_spawn({
let updated_example = updated_example.clone();
let current_run_ix = current_run_ix.clone();
@@ -153,14 +150,14 @@ pub async fn run_prediction(
run_dir.clone()
};
fs::create_dir_all(&run_dir).unwrap();
fs::create_dir_all(&run_dir)?;
if LATEST_EXAMPLE_RUN_DIR.is_symlink() {
fs::remove_file(&*LATEST_EXAMPLE_RUN_DIR).unwrap();
fs::remove_file(&*LATEST_EXAMPLE_RUN_DIR)?;
}
#[cfg(unix)]
std::os::unix::fs::symlink(&run_dir, &*LATEST_EXAMPLE_RUN_DIR).unwrap();
std::os::unix::fs::symlink(&run_dir, &*LATEST_EXAMPLE_RUN_DIR)?;
#[cfg(windows)]
std::os::windows::fs::symlink_dir(&run_dir, &*LATEST_EXAMPLE_RUN_DIR).unwrap();
std::os::windows::fs::symlink_dir(&run_dir, &*LATEST_EXAMPLE_RUN_DIR)?;
updated_example
.lock()
@@ -181,10 +178,8 @@ pub async fn run_prediction(
cloud_llm_client::PredictEditsRequestTrigger::Cli,
cx,
)
})
.unwrap()
.await
.unwrap();
})?
.await?;
let actual_patch = prediction
.and_then(|prediction| {
@@ -213,20 +208,23 @@ pub async fn run_prediction(
}
}
ep_store
.update(&mut cx, |store, _| {
store.remove_project(&state.project);
})
.unwrap();
debug_task.await.unwrap();
ep_store.update(&mut cx, |store, _| {
store.remove_project(&state.project);
})?;
debug_task.await?;
*example = Arc::into_inner(updated_example)
.unwrap()
.ok_or_else(|| anyhow::anyhow!("Failed to unwrap Arc"))?
.into_inner()
.unwrap();
.map_err(|_| anyhow::anyhow!("Failed to unwrap Mutex"))?;
Ok(())
}
async fn predict_anthropic(example: &mut Example, _repetition_count: usize, batched: bool) {
async fn predict_anthropic(
example: &mut Example,
_repetition_count: usize,
batched: bool,
) -> anyhow::Result<()> {
let llm_model_name = "claude-sonnet-4-5";
let max_tokens = 16384;
let llm_client = if batched {
@@ -234,12 +232,9 @@ async fn predict_anthropic(example: &mut Example, _repetition_count: usize, batc
} else {
AnthropicClient::plain()
};
let llm_client = llm_client.expect("Failed to create LLM client");
let llm_client = llm_client.context("Failed to create LLM client")?;
let prompt = example
.prompt
.as_ref()
.unwrap_or_else(|| panic!("Prompt is required for an example {}", &example.name));
let prompt = example.prompt.as_ref().context("Prompt is required")?;
let messages = vec![anthropic::Message {
role: anthropic::Role::User,
@@ -251,11 +246,10 @@ async fn predict_anthropic(example: &mut Example, _repetition_count: usize, batc
let Some(response) = llm_client
.generate(llm_model_name, max_tokens, messages)
.await
.unwrap()
.await?
else {
// Request stashed for batched processing
return;
return Ok(());
};
let actual_output = response
@@ -268,7 +262,7 @@ async fn predict_anthropic(example: &mut Example, _repetition_count: usize, batc
.collect::<Vec<String>>()
.join("\n");
let actual_patch = TeacherPrompt::parse(example, &actual_output);
let actual_patch = TeacherPrompt::parse(example, &actual_output)?;
let prediction = ExamplePrediction {
actual_patch,
@@ -277,19 +271,21 @@ async fn predict_anthropic(example: &mut Example, _repetition_count: usize, batc
};
example.predictions.push(prediction);
Ok(())
}
pub async fn sync_batches(provider: &PredictionProvider) {
pub async fn sync_batches(provider: &PredictionProvider) -> anyhow::Result<()> {
match provider {
PredictionProvider::Teacher => {
let cache_path = crate::paths::LLM_CACHE_DB.as_ref();
let llm_client =
AnthropicClient::batch(cache_path).expect("Failed to create LLM client");
AnthropicClient::batch(cache_path).context("Failed to create LLM client")?;
llm_client
.sync_batches()
.await
.expect("Failed to sync batches");
.context("Failed to sync batches")?;
}
_ => (),
}
};
Ok(())
}

View File

@@ -20,6 +20,7 @@ struct ProgressInner {
max_example_name_len: usize,
status_lines_displayed: usize,
total_examples: usize,
failed_examples: usize,
last_line_is_logging: bool,
}
@@ -78,7 +79,7 @@ impl Step {
static GLOBAL: OnceLock<Arc<Progress>> = OnceLock::new();
static LOGGER: ProgressLogger = ProgressLogger;
const RIGHT_MARGIN: usize = 4;
const MARGIN: usize = 4;
const MAX_STATUS_LINES: usize = 10;
impl Progress {
@@ -95,6 +96,7 @@ impl Progress {
max_example_name_len: 0,
status_lines_displayed: 0,
total_examples: 0,
failed_examples: 0,
last_line_is_logging: false,
}),
});
@@ -110,6 +112,11 @@ impl Progress {
inner.total_examples = total;
}
pub fn increment_failed(&self) {
let mut inner = self.inner.lock().unwrap();
inner.failed_examples += 1;
}
/// Prints a message to stderr, clearing and redrawing status lines to avoid corruption.
/// This should be used for any output that needs to appear above the status lines.
fn log(&self, message: &str) {
@@ -119,7 +126,7 @@ impl Progress {
if !inner.last_line_is_logging {
let reset = "\x1b[0m";
let dim = "\x1b[2m";
let divider = "─".repeat(inner.terminal_width.saturating_sub(RIGHT_MARGIN));
let divider = "─".repeat(inner.terminal_width.saturating_sub(MARGIN));
eprintln!("{dim}{divider}{reset}");
inner.last_line_is_logging = true;
}
@@ -180,7 +187,7 @@ impl Progress {
if inner.last_line_is_logging {
let reset = "\x1b[0m";
let dim = "\x1b[2m";
let divider = "─".repeat(inner.terminal_width.saturating_sub(RIGHT_MARGIN));
let divider = "─".repeat(inner.terminal_width.saturating_sub(MARGIN));
eprintln!("{dim}{divider}{reset}");
inner.last_line_is_logging = false;
}
@@ -229,7 +236,7 @@ impl Progress {
let duration_with_margin = format!("{duration} ");
let padding_needed = inner
.terminal_width
.saturating_sub(RIGHT_MARGIN)
.saturating_sub(MARGIN)
.saturating_sub(duration_with_margin.len())
.saturating_sub(strip_ansi_len(&prefix));
let padding = " ".repeat(padding_needed);
@@ -263,20 +270,33 @@ impl Progress {
// Build the done/in-progress/total label
let done_count = inner.completed.len();
let in_progress_count = inner.in_progress.len();
let failed_count = inner.failed_examples;
let failed_label = if failed_count > 0 {
format!(" {} failed ", failed_count)
} else {
String::new()
};
let range_label = format!(
" {}/{}/{} ",
done_count, in_progress_count, inner.total_examples
);
// Print a divider line with range label aligned with timestamps
// Print a divider line with failed count on left, range label on right
let failed_visible_len = strip_ansi_len(&failed_label);
let range_visible_len = range_label.len();
let left_divider_len = inner
let middle_divider_len = inner
.terminal_width
.saturating_sub(RIGHT_MARGIN)
.saturating_sub(MARGIN * 2)
.saturating_sub(failed_visible_len)
.saturating_sub(range_visible_len);
let left_divider = "".repeat(left_divider_len);
let right_divider = "".repeat(RIGHT_MARGIN);
eprintln!("{dim}{left_divider}{reset}{range_label}{dim}{right_divider}{reset}");
let left_divider = "".repeat(MARGIN);
let middle_divider = "".repeat(middle_divider_len);
let right_divider = "".repeat(MARGIN);
eprintln!(
"{dim}{left_divider}{reset}{failed_label}{dim}{middle_divider}{reset}{range_label}{dim}{right_divider}{reset}"
);
let mut tasks: Vec<_> = inner.in_progress.iter().collect();
tasks.sort_by_key(|(name, _)| *name);
@@ -304,7 +324,7 @@ impl Progress {
let duration_with_margin = format!("{elapsed} ");
let padding_needed = inner
.terminal_width
.saturating_sub(RIGHT_MARGIN)
.saturating_sub(MARGIN)
.saturating_sub(duration_with_margin.len())
.saturating_sub(strip_ansi_len(&prefix));
let padding = " ".repeat(padding_needed);
@@ -324,9 +344,23 @@ impl Progress {
let _ = std::io::stderr().flush();
}
pub fn clear(&self) {
pub fn finalize(&self) {
let mut inner = self.inner.lock().unwrap();
Self::clear_status_lines(&mut inner);
// Print summary if there were failures
if inner.failed_examples > 0 {
let total_processed = inner.completed.len() + inner.failed_examples;
let percentage = if total_processed > 0 {
inner.failed_examples as f64 / total_processed as f64 * 100.0
} else {
0.0
};
eprintln!(
"\n{} of {} examples failed ({:.1}%)",
inner.failed_examples, total_processed, percentage
);
}
}
}

View File

@@ -4,6 +4,7 @@ use crate::{
load_project::run_load_project,
progress::{InfoStyle, Progress, Step, StepProgress},
};
use anyhow::Context as _;
use collections::HashSet;
use edit_prediction::{DebugEvent, EditPredictionStore};
use futures::{FutureExt as _, StreamExt as _, channel::mpsc};
@@ -17,12 +18,12 @@ pub async fn run_context_retrieval(
example: &mut Example,
app_state: Arc<EpAppState>,
mut cx: AsyncApp,
) {
) -> anyhow::Result<()> {
if example.context.is_some() {
return;
return Ok(());
}
run_load_project(example, app_state.clone(), cx.clone()).await;
run_load_project(example, app_state.clone(), cx.clone()).await?;
let step_progress: Arc<StepProgress> = Progress::global()
.start(Step::Context, &example.name)
@@ -31,25 +32,21 @@ pub async fn run_context_retrieval(
let state = example.state.as_ref().unwrap();
let project = state.project.clone();
let _lsp_handle = project
.update(&mut cx, |project, cx| {
project.register_buffer_with_language_servers(&state.buffer, cx)
})
.unwrap();
wait_for_language_servers_to_start(&project, &state.buffer, &step_progress, &mut cx).await;
let _lsp_handle = project.update(&mut cx, |project, cx| {
project.register_buffer_with_language_servers(&state.buffer, cx)
})?;
wait_for_language_servers_to_start(&project, &state.buffer, &step_progress, &mut cx).await?;
let ep_store = cx
.update(|cx| EditPredictionStore::try_global(cx).unwrap())
.unwrap();
let ep_store = cx.update(|cx| {
EditPredictionStore::try_global(cx).context("EditPredictionStore not initialized")
})??;
let mut events = ep_store
.update(&mut cx, |store, cx| {
store.register_buffer(&state.buffer, &project, cx);
store.set_use_context(true);
store.refresh_context(&project, &state.buffer, state.cursor_position, cx);
store.debug_info(&project, cx)
})
.unwrap();
let mut events = ep_store.update(&mut cx, |store, cx| {
store.register_buffer(&state.buffer, &project, cx);
store.set_use_context(true);
store.refresh_context(&project, &state.buffer, state.cursor_position, cx);
store.debug_info(&project, cx)
})?;
while let Some(event) = events.next().await {
match event {
@@ -60,9 +57,8 @@ pub async fn run_context_retrieval(
}
}
let context_files = ep_store
.update(&mut cx, |store, cx| store.context_for_project(&project, cx))
.unwrap();
let context_files =
ep_store.update(&mut cx, |store, cx| store.context_for_project(&project, cx))?;
let excerpt_count: usize = context_files.iter().map(|f| f.excerpts.len()).sum();
step_progress.set_info(format!("{} excerpts", excerpt_count), InfoStyle::Normal);
@@ -70,6 +66,7 @@ pub async fn run_context_retrieval(
example.context = Some(ExampleContext {
files: context_files,
});
Ok(())
}
async fn wait_for_language_servers_to_start(
@@ -77,10 +74,8 @@ async fn wait_for_language_servers_to_start(
buffer: &Entity<Buffer>,
step_progress: &Arc<StepProgress>,
cx: &mut AsyncApp,
) {
let lsp_store = project
.read_with(cx, |project, _| project.lsp_store())
.unwrap();
) -> anyhow::Result<()> {
let lsp_store = project.read_with(cx, |project, _| project.lsp_store())?;
let (language_server_ids, mut starting_language_server_ids) = buffer
.update(cx, |buffer, cx| {
@@ -123,7 +118,7 @@ async fn wait_for_language_servers_to_start(
}
},
_ = timeout.clone().fuse() => {
panic!("LSP wait timed out after 5 minutes");
return Err(anyhow::anyhow!("LSP wait timed out after 5 minutes"));
}
}
}
@@ -132,8 +127,7 @@ async fn wait_for_language_servers_to_start(
if !language_server_ids.is_empty() {
project
.update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))
.unwrap()
.update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))?
.detach();
}
@@ -175,10 +169,8 @@ async fn wait_for_language_servers_to_start(
];
project
.update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))
.unwrap()
.await
.unwrap();
.update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))?
.await?;
let mut pending_language_server_ids = HashSet::from_iter(language_server_ids.into_iter());
while !pending_language_server_ids.is_empty() {
@@ -189,11 +181,12 @@ async fn wait_for_language_servers_to_start(
}
},
_ = timeout.clone().fuse() => {
panic!("LSP wait timed out after 5 minutes");
return Err(anyhow::anyhow!("LSP wait timed out after 5 minutes"));
}
}
}
drop(subscriptions);
step_progress.clear_substatus();
Ok(())
}

View File

@@ -15,7 +15,7 @@ pub async fn run_scoring(
args: &PredictArgs,
app_state: Arc<EpAppState>,
cx: AsyncApp,
) {
) -> anyhow::Result<()> {
run_prediction(
example,
Some(args.provider),
@@ -23,7 +23,7 @@ pub async fn run_scoring(
app_state,
cx,
)
.await;
.await?;
let _progress = Progress::global().start(Step::Score, &example.name);
@@ -43,6 +43,7 @@ pub async fn run_scoring(
}
example.score = scores;
Ok(())
}
fn parse_patch(patch: &str) -> Vec<DiffLine<'_>> {