Merge branch 'main' into 42286-unsafe-repository
This commit is contained in:
@@ -1,14 +1,22 @@
|
||||
use anyhow::{Result, anyhow};
|
||||
use std::mem;
|
||||
|
||||
use crate::example::Example;
|
||||
|
||||
pub async fn run_distill(example: &mut Example) {
|
||||
let [prediction]: [_; 1] = mem::take(&mut example.predictions)
|
||||
.try_into()
|
||||
.expect("Run predict first with a single repetition");
|
||||
pub async fn run_distill(example: &mut Example) -> Result<()> {
|
||||
let [prediction]: [_; 1] =
|
||||
mem::take(&mut example.predictions)
|
||||
.try_into()
|
||||
.map_err(|preds: Vec<_>| {
|
||||
anyhow!(
|
||||
"Example has {} predictions, but it should have exactly one",
|
||||
preds.len()
|
||||
)
|
||||
})?;
|
||||
|
||||
example.expected_patch = prediction.actual_patch;
|
||||
example.prompt = None;
|
||||
example.predictions = Vec::new();
|
||||
example.score = Vec::new();
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -6,6 +6,7 @@ use crate::{
|
||||
progress::{Progress, Step},
|
||||
retrieve_context::run_context_retrieval,
|
||||
};
|
||||
use anyhow::{Context as _, Result, ensure};
|
||||
use edit_prediction::{
|
||||
EditPredictionStore,
|
||||
zeta2::{zeta2_output_for_patch, zeta2_prompt_input},
|
||||
@@ -19,8 +20,8 @@ pub async fn run_format_prompt(
|
||||
prompt_format: PromptFormat,
|
||||
app_state: Arc<EpAppState>,
|
||||
mut cx: AsyncApp,
|
||||
) {
|
||||
run_context_retrieval(example, app_state.clone(), cx.clone()).await;
|
||||
) -> Result<()> {
|
||||
run_context_retrieval(example, app_state.clone(), cx.clone()).await?;
|
||||
|
||||
let _step_progress = Progress::global().start(Step::FormatPrompt, &example.name);
|
||||
|
||||
@@ -34,29 +35,33 @@ pub async fn run_format_prompt(
|
||||
});
|
||||
}
|
||||
PromptFormat::Zeta2 => {
|
||||
run_load_project(example, app_state, cx.clone()).await;
|
||||
run_load_project(example, app_state, cx.clone()).await?;
|
||||
|
||||
let ep_store = cx
|
||||
.update(|cx| EditPredictionStore::try_global(cx).unwrap())
|
||||
.unwrap();
|
||||
let ep_store = cx.update(|cx| {
|
||||
EditPredictionStore::try_global(cx).context("EditPredictionStore not initialized")
|
||||
})??;
|
||||
|
||||
let state = example.state.as_ref().unwrap();
|
||||
let snapshot = state
|
||||
.buffer
|
||||
.read_with(&cx, |buffer, _| buffer.snapshot())
|
||||
.unwrap();
|
||||
let state = example.state.as_ref().context("state must be set")?;
|
||||
let snapshot = state.buffer.read_with(&cx, |buffer, _| buffer.snapshot())?;
|
||||
let project = state.project.clone();
|
||||
let (_, input) = ep_store
|
||||
.update(&mut cx, |ep_store, _cx| {
|
||||
zeta2_prompt_input(
|
||||
&snapshot,
|
||||
example.context.as_ref().unwrap().files.clone(),
|
||||
ep_store.edit_history_for_project(&project),
|
||||
example.cursor_path.clone(),
|
||||
example.buffer.as_ref().unwrap().cursor_offset,
|
||||
)
|
||||
})
|
||||
.unwrap();
|
||||
let (_, input) = ep_store.update(&mut cx, |ep_store, _cx| {
|
||||
anyhow::Ok(zeta2_prompt_input(
|
||||
&snapshot,
|
||||
example
|
||||
.context
|
||||
.as_ref()
|
||||
.context("context must be set")?
|
||||
.files
|
||||
.clone(),
|
||||
ep_store.edit_history_for_project(&project),
|
||||
example.cursor_path.clone(),
|
||||
example
|
||||
.buffer
|
||||
.as_ref()
|
||||
.context("buffer must be set")?
|
||||
.cursor_offset,
|
||||
))
|
||||
})??;
|
||||
let prompt = format_zeta_prompt(&input);
|
||||
let expected_output = zeta2_output_for_patch(&input, &example.expected_patch.clone());
|
||||
example.prompt = Some(ExamplePrompt {
|
||||
@@ -66,6 +71,7 @@ pub async fn run_format_prompt(
|
||||
});
|
||||
}
|
||||
};
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub struct TeacherPrompt;
|
||||
@@ -91,7 +97,7 @@ impl TeacherPrompt {
|
||||
prompt
|
||||
}
|
||||
|
||||
pub fn parse(example: &Example, response: &str) -> String {
|
||||
pub fn parse(example: &Example, response: &str) -> Result<String> {
|
||||
// Ideally, we should always be able to find cursor position in the retrieved context.
|
||||
// In reality, sometimes we don't find it for these reasons:
|
||||
// 1. `example.cursor_position` contains _more_ context than included in the retrieved context
|
||||
@@ -102,7 +108,7 @@ impl TeacherPrompt {
|
||||
let cursor_file = &example
|
||||
.buffer
|
||||
.as_ref()
|
||||
.expect("`buffer` should be filled in in the context collection step")
|
||||
.context("`buffer` should be filled in in the context collection step")?
|
||||
.content;
|
||||
|
||||
// Extract updated (new) editable region from the model response
|
||||
@@ -111,9 +117,10 @@ impl TeacherPrompt {
|
||||
// Reconstruct old editable region we sent to the model
|
||||
let old_editable_region = Self::format_editable_region(example);
|
||||
let old_editable_region = Self::extract_editable_region(&old_editable_region);
|
||||
if !cursor_file.contains(&old_editable_region) {
|
||||
panic!("Something's wrong: editable_region is not found in the cursor file")
|
||||
}
|
||||
ensure!(
|
||||
cursor_file.contains(&old_editable_region),
|
||||
"Something's wrong: editable_region is not found in the cursor file"
|
||||
);
|
||||
|
||||
// Apply editable region to a larger context and compute diff.
|
||||
// This is needed to get a better context lines around the editable region
|
||||
@@ -128,7 +135,7 @@ impl TeacherPrompt {
|
||||
diff = diff,
|
||||
};
|
||||
|
||||
diff
|
||||
Ok(diff)
|
||||
}
|
||||
|
||||
fn format_edit_history(edit_history: &str) -> String {
|
||||
@@ -152,9 +159,7 @@ impl TeacherPrompt {
|
||||
}
|
||||
|
||||
fn format_context(example: &Example) -> String {
|
||||
if example.context.is_none() {
|
||||
panic!("Missing context retriever step");
|
||||
}
|
||||
assert!(example.context.is_some(), "Missing context retriever step");
|
||||
|
||||
let mut prompt = String::new();
|
||||
zeta_prompt::write_related_files(&mut prompt, &example.context.as_ref().unwrap().files);
|
||||
|
||||
@@ -4,7 +4,7 @@ use crate::{
|
||||
paths::{REPOS_DIR, WORKTREES_DIR},
|
||||
progress::{InfoStyle, Progress, Step, StepProgress},
|
||||
};
|
||||
use anyhow::{Result, anyhow};
|
||||
use anyhow::{Context as _, Result};
|
||||
use collections::HashMap;
|
||||
use edit_prediction::EditPredictionStore;
|
||||
use edit_prediction::udiff::OpenedBuffers;
|
||||
@@ -25,38 +25,38 @@ use std::{
|
||||
use util::{paths::PathStyle, rel_path::RelPath};
|
||||
use zeta_prompt::CURSOR_MARKER;
|
||||
|
||||
pub async fn run_load_project(example: &mut Example, app_state: Arc<EpAppState>, mut cx: AsyncApp) {
|
||||
pub async fn run_load_project(
|
||||
example: &mut Example,
|
||||
app_state: Arc<EpAppState>,
|
||||
mut cx: AsyncApp,
|
||||
) -> Result<()> {
|
||||
if example.state.is_some() {
|
||||
return;
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let progress = Progress::global().start(Step::LoadProject, &example.name);
|
||||
|
||||
let project = setup_project(example, &app_state, &progress, &mut cx).await;
|
||||
let project = setup_project(example, &app_state, &progress, &mut cx).await?;
|
||||
|
||||
let _open_buffers = apply_edit_history(example, &project, &mut cx)
|
||||
.await
|
||||
.unwrap();
|
||||
let _open_buffers = apply_edit_history(example, &project, &mut cx).await?;
|
||||
|
||||
let (buffer, cursor_position) = cursor_position(example, &project, &mut cx).await;
|
||||
let (example_buffer, language_name) = buffer
|
||||
.read_with(&cx, |buffer, _cx| {
|
||||
let cursor_point = cursor_position.to_point(&buffer);
|
||||
let language_name = buffer
|
||||
.language()
|
||||
.map(|l| l.name().to_string())
|
||||
.unwrap_or_else(|| "Unknown".to_string());
|
||||
(
|
||||
ExampleBuffer {
|
||||
content: buffer.text(),
|
||||
cursor_row: cursor_point.row,
|
||||
cursor_column: cursor_point.column,
|
||||
cursor_offset: cursor_position.to_offset(&buffer),
|
||||
},
|
||||
language_name,
|
||||
)
|
||||
})
|
||||
.unwrap();
|
||||
let (buffer, cursor_position) = cursor_position(example, &project, &mut cx).await?;
|
||||
let (example_buffer, language_name) = buffer.read_with(&cx, |buffer, _cx| {
|
||||
let cursor_point = cursor_position.to_point(&buffer);
|
||||
let language_name = buffer
|
||||
.language()
|
||||
.map(|l| l.name().to_string())
|
||||
.unwrap_or_else(|| "Unknown".to_string());
|
||||
(
|
||||
ExampleBuffer {
|
||||
content: buffer.text(),
|
||||
cursor_row: cursor_point.row,
|
||||
cursor_column: cursor_point.column,
|
||||
cursor_offset: cursor_position.to_offset(&buffer),
|
||||
},
|
||||
language_name,
|
||||
)
|
||||
})?;
|
||||
|
||||
progress.set_info(language_name, InfoStyle::Normal);
|
||||
|
||||
@@ -67,16 +67,15 @@ pub async fn run_load_project(example: &mut Example, app_state: Arc<EpAppState>,
|
||||
cursor_position,
|
||||
_open_buffers,
|
||||
});
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn cursor_position(
|
||||
example: &Example,
|
||||
project: &Entity<Project>,
|
||||
cx: &mut AsyncApp,
|
||||
) -> (Entity<Buffer>, Anchor) {
|
||||
let language_registry = project
|
||||
.read_with(cx, |project, _| project.languages().clone())
|
||||
.unwrap();
|
||||
) -> Result<(Entity<Buffer>, Anchor)> {
|
||||
let language_registry = project.read_with(cx, |project, _| project.languages().clone())?;
|
||||
let result = language_registry
|
||||
.load_language_for_file_path(&example.cursor_path)
|
||||
.await;
|
||||
@@ -84,17 +83,18 @@ async fn cursor_position(
|
||||
if let Err(error) = result
|
||||
&& !error.is::<LanguageNotFound>()
|
||||
{
|
||||
panic!("Failed to load language for file path: {}", error);
|
||||
return Err(error);
|
||||
}
|
||||
|
||||
let worktree = project
|
||||
.read_with(cx, |project, cx| {
|
||||
project.visible_worktrees(cx).next().unwrap()
|
||||
})
|
||||
.unwrap();
|
||||
let worktree = project.read_with(cx, |project, cx| {
|
||||
project
|
||||
.visible_worktrees(cx)
|
||||
.next()
|
||||
.context("No visible worktrees")
|
||||
})??;
|
||||
|
||||
let cursor_path = RelPath::new(&example.cursor_path, PathStyle::Posix)
|
||||
.unwrap()
|
||||
.context("Failed to create RelPath")?
|
||||
.into_arc();
|
||||
let cursor_buffer = project
|
||||
.update(cx, |project, cx| {
|
||||
@@ -105,15 +105,12 @@ async fn cursor_position(
|
||||
},
|
||||
cx,
|
||||
)
|
||||
})
|
||||
.unwrap()
|
||||
.await
|
||||
.unwrap();
|
||||
})?
|
||||
.await?;
|
||||
let cursor_offset_within_excerpt = example
|
||||
.cursor_position
|
||||
.find(CURSOR_MARKER)
|
||||
.ok_or_else(|| anyhow!("missing cursor marker"))
|
||||
.unwrap();
|
||||
.context("missing cursor marker")?;
|
||||
let mut cursor_excerpt = example.cursor_position.clone();
|
||||
cursor_excerpt.replace_range(
|
||||
cursor_offset_within_excerpt..(cursor_offset_within_excerpt + CURSOR_MARKER.len()),
|
||||
@@ -123,22 +120,21 @@ async fn cursor_position(
|
||||
let text = buffer.text();
|
||||
|
||||
let mut matches = text.match_indices(&cursor_excerpt);
|
||||
let (excerpt_offset, _) = matches.next().unwrap_or_else(|| {
|
||||
panic!(
|
||||
let (excerpt_offset, _) = matches.next().with_context(|| {
|
||||
format!(
|
||||
"\nExcerpt:\n\n{cursor_excerpt}\nBuffer text:\n{text}\n.Example: {}\nCursor excerpt did not exist in buffer.",
|
||||
example.name
|
||||
);
|
||||
});
|
||||
assert!(matches.next().is_none(), "More than one cursor position match found for {}", &example.name);
|
||||
excerpt_offset
|
||||
}).unwrap();
|
||||
)
|
||||
})?;
|
||||
anyhow::ensure!(matches.next().is_none(), "More than one cursor position match found for {}", &example.name);
|
||||
Ok(excerpt_offset)
|
||||
})??;
|
||||
|
||||
let cursor_offset = excerpt_offset + cursor_offset_within_excerpt;
|
||||
let cursor_anchor = cursor_buffer
|
||||
.read_with(cx, |buffer, _| buffer.anchor_after(cursor_offset))
|
||||
.unwrap();
|
||||
let cursor_anchor =
|
||||
cursor_buffer.read_with(cx, |buffer, _| buffer.anchor_after(cursor_offset))?;
|
||||
|
||||
(cursor_buffer, cursor_anchor)
|
||||
Ok((cursor_buffer, cursor_anchor))
|
||||
}
|
||||
|
||||
async fn setup_project(
|
||||
@@ -146,67 +142,54 @@ async fn setup_project(
|
||||
app_state: &Arc<EpAppState>,
|
||||
step_progress: &StepProgress,
|
||||
cx: &mut AsyncApp,
|
||||
) -> Entity<Project> {
|
||||
) -> Result<Entity<Project>> {
|
||||
let ep_store = cx
|
||||
.update(|cx| EditPredictionStore::try_global(cx).unwrap())
|
||||
.unwrap();
|
||||
.update(|cx| EditPredictionStore::try_global(cx))?
|
||||
.context("Store should be initialized at init")?;
|
||||
|
||||
let worktree_path = setup_worktree(example, step_progress).await;
|
||||
let worktree_path = setup_worktree(example, step_progress).await?;
|
||||
|
||||
if let Some(project) = app_state.project_cache.get(&example.repository_url) {
|
||||
ep_store
|
||||
.update(cx, |ep_store, _| {
|
||||
ep_store.clear_history_for_project(&project);
|
||||
})
|
||||
.unwrap();
|
||||
let buffer_store = project
|
||||
.read_with(cx, |project, _| project.buffer_store().clone())
|
||||
.unwrap();
|
||||
let buffers = buffer_store
|
||||
.read_with(cx, |buffer_store, _| {
|
||||
buffer_store.buffers().collect::<Vec<_>>()
|
||||
})
|
||||
.unwrap();
|
||||
ep_store.update(cx, |ep_store, _| {
|
||||
ep_store.clear_history_for_project(&project);
|
||||
})?;
|
||||
let buffer_store = project.read_with(cx, |project, _| project.buffer_store().clone())?;
|
||||
let buffers = buffer_store.read_with(cx, |buffer_store, _| {
|
||||
buffer_store.buffers().collect::<Vec<_>>()
|
||||
})?;
|
||||
for buffer in buffers {
|
||||
buffer
|
||||
.update(cx, |buffer, cx| buffer.reload(cx))
|
||||
.unwrap()
|
||||
.update(cx, |buffer, cx| buffer.reload(cx))?
|
||||
.await
|
||||
.ok();
|
||||
}
|
||||
return project;
|
||||
return Ok(project);
|
||||
}
|
||||
|
||||
let project = cx
|
||||
.update(|cx| {
|
||||
Project::local(
|
||||
app_state.client.clone(),
|
||||
app_state.node_runtime.clone(),
|
||||
app_state.user_store.clone(),
|
||||
app_state.languages.clone(),
|
||||
app_state.fs.clone(),
|
||||
None,
|
||||
cx,
|
||||
)
|
||||
})
|
||||
.unwrap();
|
||||
let project = cx.update(|cx| {
|
||||
Project::local(
|
||||
app_state.client.clone(),
|
||||
app_state.node_runtime.clone(),
|
||||
app_state.user_store.clone(),
|
||||
app_state.languages.clone(),
|
||||
app_state.fs.clone(),
|
||||
None,
|
||||
cx,
|
||||
)
|
||||
})?;
|
||||
|
||||
project
|
||||
.update(cx, |project, cx| {
|
||||
project.disable_worktree_scanner(cx);
|
||||
project.create_worktree(&worktree_path, true, cx)
|
||||
})
|
||||
.unwrap()
|
||||
.await
|
||||
.unwrap();
|
||||
})?
|
||||
.await?;
|
||||
|
||||
app_state
|
||||
.project_cache
|
||||
.insert(example.repository_url.clone(), project.clone());
|
||||
|
||||
let buffer_store = project
|
||||
.read_with(cx, |project, _| project.buffer_store().clone())
|
||||
.unwrap();
|
||||
let buffer_store = project.read_with(cx, |project, _| project.buffer_store().clone())?;
|
||||
cx.subscribe(&buffer_store, {
|
||||
let project = project.clone();
|
||||
move |_, event, cx| match event {
|
||||
@@ -215,15 +198,14 @@ async fn setup_project(
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
})
|
||||
.unwrap()
|
||||
})?
|
||||
.detach();
|
||||
|
||||
project
|
||||
Ok(project)
|
||||
}
|
||||
|
||||
async fn setup_worktree(example: &Example, step_progress: &StepProgress) -> PathBuf {
|
||||
let (repo_owner, repo_name) = example.repo_name().expect("failed to get repo name");
|
||||
async fn setup_worktree(example: &Example, step_progress: &StepProgress) -> Result<PathBuf> {
|
||||
let (repo_owner, repo_name) = example.repo_name().context("failed to get repo name")?;
|
||||
let repo_dir = REPOS_DIR.join(repo_owner.as_ref()).join(repo_name.as_ref());
|
||||
let worktree_path = WORKTREES_DIR
|
||||
.join(repo_owner.as_ref())
|
||||
@@ -232,14 +214,13 @@ async fn setup_worktree(example: &Example, step_progress: &StepProgress) -> Path
|
||||
|
||||
if !repo_dir.is_dir() {
|
||||
step_progress.set_substatus(format!("cloning {}", repo_name));
|
||||
fs::create_dir_all(&repo_dir).unwrap();
|
||||
run_git(&repo_dir, &["init"]).await.unwrap();
|
||||
fs::create_dir_all(&repo_dir)?;
|
||||
run_git(&repo_dir, &["init"]).await?;
|
||||
run_git(
|
||||
&repo_dir,
|
||||
&["remote", "add", "origin", &example.repository_url],
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
.await?;
|
||||
}
|
||||
|
||||
// Resolve the example to a revision, fetching it if needed.
|
||||
@@ -259,34 +240,25 @@ async fn setup_worktree(example: &Example, step_progress: &StepProgress) -> Path
|
||||
.await
|
||||
.is_err()
|
||||
{
|
||||
run_git(&repo_dir, &["fetch", "origin"]).await.unwrap();
|
||||
run_git(&repo_dir, &["fetch", "origin"]).await?;
|
||||
}
|
||||
let revision = run_git(&repo_dir, &["rev-parse", "FETCH_HEAD"])
|
||||
.await
|
||||
.unwrap();
|
||||
let revision = run_git(&repo_dir, &["rev-parse", "FETCH_HEAD"]).await?;
|
||||
revision
|
||||
};
|
||||
|
||||
// Create the worktree for this example if needed.
|
||||
step_progress.set_substatus("preparing worktree");
|
||||
if worktree_path.is_dir() {
|
||||
run_git(&worktree_path, &["clean", "--force", "-d"])
|
||||
.await
|
||||
.unwrap();
|
||||
run_git(&worktree_path, &["reset", "--hard", "HEAD"])
|
||||
.await
|
||||
.unwrap();
|
||||
run_git(&worktree_path, &["checkout", revision.as_str()])
|
||||
.await
|
||||
.unwrap();
|
||||
run_git(&worktree_path, &["clean", "--force", "-d"]).await?;
|
||||
run_git(&worktree_path, &["reset", "--hard", "HEAD"]).await?;
|
||||
run_git(&worktree_path, &["checkout", revision.as_str()]).await?;
|
||||
} else {
|
||||
let worktree_path_string = worktree_path.to_string_lossy();
|
||||
run_git(
|
||||
&repo_dir,
|
||||
&["branch", "-f", &example.name, revision.as_str()],
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
.await?;
|
||||
run_git(
|
||||
&repo_dir,
|
||||
&[
|
||||
@@ -297,8 +269,7 @@ async fn setup_worktree(example: &Example, step_progress: &StepProgress) -> Path
|
||||
&example.name,
|
||||
],
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
.await?;
|
||||
}
|
||||
drop(repo_lock);
|
||||
|
||||
@@ -309,30 +280,25 @@ async fn setup_worktree(example: &Example, step_progress: &StepProgress) -> Path
|
||||
.current_dir(&worktree_path)
|
||||
.args(&["apply", "-"])
|
||||
.stdin(std::process::Stdio::piped())
|
||||
.spawn()
|
||||
.unwrap();
|
||||
.spawn()?;
|
||||
|
||||
let mut stdin = apply_process.stdin.take().unwrap();
|
||||
stdin
|
||||
.write_all(example.uncommitted_diff.as_bytes())
|
||||
.await
|
||||
.unwrap();
|
||||
stdin.close().await.unwrap();
|
||||
let mut stdin = apply_process.stdin.take().context("Failed to get stdin")?;
|
||||
stdin.write_all(example.uncommitted_diff.as_bytes()).await?;
|
||||
stdin.close().await?;
|
||||
drop(stdin);
|
||||
|
||||
let apply_result = apply_process.output().await.unwrap();
|
||||
if !apply_result.status.success() {
|
||||
panic!(
|
||||
"Failed to apply uncommitted diff patch with status: {}\nstderr:\n{}\nstdout:\n{}",
|
||||
apply_result.status,
|
||||
String::from_utf8_lossy(&apply_result.stderr),
|
||||
String::from_utf8_lossy(&apply_result.stdout),
|
||||
);
|
||||
}
|
||||
let apply_result = apply_process.output().await?;
|
||||
anyhow::ensure!(
|
||||
apply_result.status.success(),
|
||||
"Failed to apply uncommitted diff patch with status: {}\nstderr:\n{}\nstdout:\n{}",
|
||||
apply_result.status,
|
||||
String::from_utf8_lossy(&apply_result.stderr),
|
||||
String::from_utf8_lossy(&apply_result.stdout),
|
||||
);
|
||||
}
|
||||
|
||||
step_progress.clear_substatus();
|
||||
worktree_path
|
||||
Ok(worktree_path)
|
||||
}
|
||||
|
||||
async fn apply_edit_history(
|
||||
|
||||
@@ -16,12 +16,14 @@ use edit_prediction::EditPredictionStore;
|
||||
use gpui::Application;
|
||||
use reqwest_client::ReqwestClient;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::fmt::Display;
|
||||
use std::{path::PathBuf, sync::Arc};
|
||||
|
||||
use crate::distill::run_distill;
|
||||
use crate::example::{group_examples_by_repo, read_examples, write_examples};
|
||||
use crate::format_prompt::run_format_prompt;
|
||||
use crate::load_project::run_load_project;
|
||||
use crate::paths::FAILED_EXAMPLES_DIR;
|
||||
use crate::predict::run_prediction;
|
||||
use crate::progress::Progress;
|
||||
use crate::retrieve_context::run_context_retrieval;
|
||||
@@ -42,6 +44,8 @@ struct EpArgs {
|
||||
output: Option<PathBuf>,
|
||||
#[arg(long, short, global = true)]
|
||||
in_place: bool,
|
||||
#[arg(long, short, global = true)]
|
||||
failfast: bool,
|
||||
}
|
||||
|
||||
#[derive(Subcommand, Debug)]
|
||||
@@ -67,6 +71,58 @@ enum Command {
|
||||
Clean,
|
||||
}
|
||||
|
||||
impl Display for Command {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
Command::ParseExample => write!(f, "parse-example"),
|
||||
Command::LoadProject => write!(f, "load-project"),
|
||||
Command::Context => write!(f, "context"),
|
||||
Command::FormatPrompt(format_prompt_args) => write!(
|
||||
f,
|
||||
"format-prompt --prompt-format={}",
|
||||
format_prompt_args
|
||||
.prompt_format
|
||||
.to_possible_value()
|
||||
.unwrap()
|
||||
.get_name()
|
||||
),
|
||||
Command::Predict(predict_args) => {
|
||||
write!(
|
||||
f,
|
||||
"predict --provider={:?}",
|
||||
predict_args
|
||||
.provider
|
||||
.to_possible_value()
|
||||
.unwrap()
|
||||
.get_name()
|
||||
)
|
||||
}
|
||||
Command::Score(predict_args) => {
|
||||
write!(
|
||||
f,
|
||||
"score --provider={:?}",
|
||||
predict_args
|
||||
.provider
|
||||
.to_possible_value()
|
||||
.unwrap()
|
||||
.get_name()
|
||||
)
|
||||
}
|
||||
Command::Distill => write!(f, "distill"),
|
||||
Command::Eval(predict_args) => write!(
|
||||
f,
|
||||
"eval --provider={:?}",
|
||||
predict_args
|
||||
.provider
|
||||
.to_possible_value()
|
||||
.unwrap()
|
||||
.get_name()
|
||||
),
|
||||
Command::Clean => write!(f, "clean"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Args)]
|
||||
struct FormatPromptArgs {
|
||||
#[clap(long)]
|
||||
@@ -145,71 +201,140 @@ fn main() {
|
||||
EditPredictionStore::global(&app_state.client, &app_state.user_store, cx);
|
||||
|
||||
cx.spawn(async move |cx| {
|
||||
if let Command::Predict(args) = &command {
|
||||
predict::sync_batches(&args.provider).await
|
||||
};
|
||||
let result = async {
|
||||
if let Command::Predict(args) = &command {
|
||||
predict::sync_batches(&args.provider).await?;
|
||||
}
|
||||
|
||||
let total_examples = examples.len();
|
||||
Progress::global().set_total_examples(total_examples);
|
||||
let total_examples = examples.len();
|
||||
Progress::global().set_total_examples(total_examples);
|
||||
|
||||
let mut grouped_examples = group_examples_by_repo(&mut examples);
|
||||
let example_batches = grouped_examples.chunks_mut(args.max_parallelism);
|
||||
let mut grouped_examples = group_examples_by_repo(&mut examples);
|
||||
let example_batches = grouped_examples.chunks_mut(args.max_parallelism);
|
||||
|
||||
for example_batch in example_batches {
|
||||
let futures = example_batch.into_iter().map(|repo_examples| async {
|
||||
for example in repo_examples.iter_mut() {
|
||||
match &command {
|
||||
Command::ParseExample => {}
|
||||
Command::LoadProject => {
|
||||
run_load_project(example, app_state.clone(), cx.clone()).await;
|
||||
for example_batch in example_batches {
|
||||
let futures = example_batch.into_iter().map(|repo_examples| async {
|
||||
for example in repo_examples.iter_mut() {
|
||||
let result = async {
|
||||
match &command {
|
||||
Command::ParseExample => {}
|
||||
Command::LoadProject => {
|
||||
run_load_project(example, app_state.clone(), cx.clone())
|
||||
.await?;
|
||||
}
|
||||
Command::Context => {
|
||||
run_context_retrieval(
|
||||
example,
|
||||
app_state.clone(),
|
||||
cx.clone(),
|
||||
)
|
||||
.await?;
|
||||
}
|
||||
Command::FormatPrompt(args) => {
|
||||
run_format_prompt(
|
||||
example,
|
||||
args.prompt_format,
|
||||
app_state.clone(),
|
||||
cx.clone(),
|
||||
)
|
||||
.await?;
|
||||
}
|
||||
Command::Predict(args) => {
|
||||
run_prediction(
|
||||
example,
|
||||
Some(args.provider),
|
||||
args.repetitions,
|
||||
app_state.clone(),
|
||||
cx.clone(),
|
||||
)
|
||||
.await?;
|
||||
}
|
||||
Command::Distill => {
|
||||
run_distill(example).await?;
|
||||
}
|
||||
Command::Score(args) | Command::Eval(args) => {
|
||||
run_scoring(example, &args, app_state.clone(), cx.clone())
|
||||
.await?;
|
||||
}
|
||||
Command::Clean => {
|
||||
unreachable!()
|
||||
}
|
||||
}
|
||||
anyhow::Ok(())
|
||||
}
|
||||
Command::Context => {
|
||||
run_context_retrieval(example, app_state.clone(), cx.clone()).await;
|
||||
}
|
||||
Command::FormatPrompt(args) => {
|
||||
run_format_prompt(
|
||||
example,
|
||||
args.prompt_format,
|
||||
app_state.clone(),
|
||||
cx.clone(),
|
||||
)
|
||||
.await;
|
||||
}
|
||||
Command::Predict(args) => {
|
||||
run_prediction(
|
||||
example,
|
||||
Some(args.provider),
|
||||
args.repetitions,
|
||||
app_state.clone(),
|
||||
cx.clone(),
|
||||
)
|
||||
.await;
|
||||
}
|
||||
Command::Distill => {
|
||||
run_distill(example).await;
|
||||
}
|
||||
Command::Score(args) | Command::Eval(args) => {
|
||||
run_scoring(example, &args, app_state.clone(), cx.clone()).await;
|
||||
}
|
||||
Command::Clean => {
|
||||
unreachable!()
|
||||
.await;
|
||||
|
||||
if let Err(e) = result {
|
||||
Progress::global().increment_failed();
|
||||
let failed_example_path =
|
||||
FAILED_EXAMPLES_DIR.join(format!("{}.json", example.name));
|
||||
app_state
|
||||
.fs
|
||||
.write(
|
||||
&failed_example_path,
|
||||
&serde_json::to_vec_pretty(&example).unwrap(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
let err_path =
|
||||
FAILED_EXAMPLES_DIR.join(format!("{}_err.txt", example.name));
|
||||
app_state
|
||||
.fs
|
||||
.write(&err_path, e.to_string().as_bytes())
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let msg = format!(
|
||||
indoc::indoc! {"
|
||||
While processing {}:
|
||||
|
||||
{:?}
|
||||
|
||||
Written to: \x1b[36m{}\x1b[0m
|
||||
|
||||
Explore this example data with:
|
||||
fx \x1b[36m{}\x1b[0m
|
||||
|
||||
Re-run this example with:
|
||||
cargo run -p edit_prediction_cli -- {} \x1b[36m{}\x1b[0m
|
||||
"},
|
||||
example.name,
|
||||
e,
|
||||
err_path.display(),
|
||||
failed_example_path.display(),
|
||||
command,
|
||||
failed_example_path.display(),
|
||||
);
|
||||
if args.failfast || total_examples == 1 {
|
||||
Progress::global().finalize();
|
||||
panic!("{}", msg);
|
||||
} else {
|
||||
log::error!("{}", msg);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
futures::future::join_all(futures).await;
|
||||
}
|
||||
Progress::global().clear();
|
||||
});
|
||||
futures::future::join_all(futures).await;
|
||||
}
|
||||
Progress::global().finalize();
|
||||
|
||||
if args.output.is_some() || !matches!(command, Command::Eval(_)) {
|
||||
write_examples(&examples, output.as_ref());
|
||||
}
|
||||
if args.output.is_some() || !matches!(command, Command::Eval(_)) {
|
||||
write_examples(&examples, output.as_ref());
|
||||
}
|
||||
|
||||
match &command {
|
||||
Command::Predict(args) => predict::sync_batches(&args.provider).await,
|
||||
Command::Eval(_) => score::print_report(&examples),
|
||||
_ => (),
|
||||
};
|
||||
match &command {
|
||||
Command::Predict(args) => predict::sync_batches(&args.provider).await?,
|
||||
Command::Eval(_) => score::print_report(&examples),
|
||||
_ => (),
|
||||
};
|
||||
|
||||
anyhow::Ok(())
|
||||
}
|
||||
.await;
|
||||
|
||||
if let Err(e) = result {
|
||||
panic!("Fatal error: {:?}", e);
|
||||
}
|
||||
|
||||
let _ = cx.update(|cx| cx.quit());
|
||||
})
|
||||
|
||||
@@ -18,6 +18,8 @@ pub static RUN_DIR: LazyLock<PathBuf> = LazyLock::new(|| {
|
||||
});
|
||||
pub static LATEST_EXAMPLE_RUN_DIR: LazyLock<PathBuf> = LazyLock::new(|| DATA_DIR.join("latest"));
|
||||
pub static LLM_CACHE_DB: LazyLock<PathBuf> = LazyLock::new(|| CACHE_DIR.join("llm_cache.sqlite"));
|
||||
pub static FAILED_EXAMPLES_DIR: LazyLock<PathBuf> =
|
||||
LazyLock::new(|| ensure_dir(&RUN_DIR.join("failed")));
|
||||
|
||||
fn ensure_dir(path: &Path) -> PathBuf {
|
||||
std::fs::create_dir_all(path).expect("Failed to create directory");
|
||||
|
||||
@@ -9,6 +9,7 @@ use crate::{
|
||||
progress::{InfoStyle, Progress, Step},
|
||||
retrieve_context::run_context_retrieval,
|
||||
};
|
||||
use anyhow::Context as _;
|
||||
use edit_prediction::{DebugEvent, EditPredictionStore};
|
||||
use futures::{FutureExt as _, StreamExt as _, future::Shared};
|
||||
use gpui::{AppContext as _, AsyncApp, Task};
|
||||
@@ -26,14 +27,14 @@ pub async fn run_prediction(
|
||||
repetition_count: usize,
|
||||
app_state: Arc<EpAppState>,
|
||||
mut cx: AsyncApp,
|
||||
) {
|
||||
) -> anyhow::Result<()> {
|
||||
if !example.predictions.is_empty() {
|
||||
return;
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let provider = provider.unwrap();
|
||||
let provider = provider.context("provider is required")?;
|
||||
|
||||
run_context_retrieval(example, app_state.clone(), cx.clone()).await;
|
||||
run_context_retrieval(example, app_state.clone(), cx.clone()).await?;
|
||||
|
||||
if matches!(
|
||||
provider,
|
||||
@@ -42,14 +43,14 @@ pub async fn run_prediction(
|
||||
let _step_progress = Progress::global().start(Step::Predict, &example.name);
|
||||
|
||||
if example.prompt.is_none() {
|
||||
run_format_prompt(example, PromptFormat::Teacher, app_state.clone(), cx).await;
|
||||
run_format_prompt(example, PromptFormat::Teacher, app_state.clone(), cx).await?;
|
||||
}
|
||||
|
||||
let batched = matches!(provider, PredictionProvider::Teacher);
|
||||
return predict_anthropic(example, repetition_count, batched).await;
|
||||
}
|
||||
|
||||
run_load_project(example, app_state.clone(), cx.clone()).await;
|
||||
run_load_project(example, app_state.clone(), cx.clone()).await?;
|
||||
|
||||
let _step_progress = Progress::global().start(Step::Predict, &example.name);
|
||||
|
||||
@@ -62,10 +63,9 @@ pub async fn run_prediction(
|
||||
.get_or_init(|| {
|
||||
let client = app_state.client.clone();
|
||||
cx.spawn(async move |cx| {
|
||||
client
|
||||
.sign_in_with_optional_connect(true, cx)
|
||||
.await
|
||||
.unwrap();
|
||||
if let Err(e) = client.sign_in_with_optional_connect(true, cx).await {
|
||||
eprintln!("Authentication failed: {}", e);
|
||||
}
|
||||
})
|
||||
.shared()
|
||||
})
|
||||
@@ -73,33 +73,30 @@ pub async fn run_prediction(
|
||||
.await;
|
||||
}
|
||||
|
||||
let ep_store = cx
|
||||
.update(|cx| EditPredictionStore::try_global(cx).unwrap())
|
||||
.unwrap();
|
||||
let ep_store = cx.update(|cx| {
|
||||
EditPredictionStore::try_global(cx).context("EditPredictionStore not initialized")
|
||||
})??;
|
||||
|
||||
ep_store
|
||||
.update(&mut cx, |store, _cx| {
|
||||
let model = match provider {
|
||||
PredictionProvider::Zeta1 => edit_prediction::EditPredictionModel::Zeta1,
|
||||
PredictionProvider::Zeta2 => edit_prediction::EditPredictionModel::Zeta2,
|
||||
PredictionProvider::Sweep => edit_prediction::EditPredictionModel::Sweep,
|
||||
PredictionProvider::Mercury => edit_prediction::EditPredictionModel::Mercury,
|
||||
PredictionProvider::Teacher | PredictionProvider::TeacherNonBatching => {
|
||||
unreachable!()
|
||||
}
|
||||
};
|
||||
store.set_edit_prediction_model(model);
|
||||
})
|
||||
.unwrap();
|
||||
let state = example.state.as_ref().unwrap();
|
||||
ep_store.update(&mut cx, |store, _cx| {
|
||||
let model = match provider {
|
||||
PredictionProvider::Zeta1 => edit_prediction::EditPredictionModel::Zeta1,
|
||||
PredictionProvider::Zeta2 => edit_prediction::EditPredictionModel::Zeta2,
|
||||
PredictionProvider::Sweep => edit_prediction::EditPredictionModel::Sweep,
|
||||
PredictionProvider::Mercury => edit_prediction::EditPredictionModel::Mercury,
|
||||
PredictionProvider::Teacher | PredictionProvider::TeacherNonBatching => {
|
||||
unreachable!()
|
||||
}
|
||||
};
|
||||
store.set_edit_prediction_model(model);
|
||||
})?;
|
||||
let state = example.state.as_ref().context("state must be set")?;
|
||||
let run_dir = RUN_DIR.join(&example.name);
|
||||
|
||||
let updated_example = Arc::new(Mutex::new(example.clone()));
|
||||
let current_run_ix = Arc::new(AtomicUsize::new(0));
|
||||
|
||||
let mut debug_rx = ep_store
|
||||
.update(&mut cx, |store, cx| store.debug_info(&state.project, cx))
|
||||
.unwrap();
|
||||
let mut debug_rx =
|
||||
ep_store.update(&mut cx, |store, cx| store.debug_info(&state.project, cx))?;
|
||||
let debug_task = cx.background_spawn({
|
||||
let updated_example = updated_example.clone();
|
||||
let current_run_ix = current_run_ix.clone();
|
||||
@@ -153,14 +150,14 @@ pub async fn run_prediction(
|
||||
run_dir.clone()
|
||||
};
|
||||
|
||||
fs::create_dir_all(&run_dir).unwrap();
|
||||
fs::create_dir_all(&run_dir)?;
|
||||
if LATEST_EXAMPLE_RUN_DIR.is_symlink() {
|
||||
fs::remove_file(&*LATEST_EXAMPLE_RUN_DIR).unwrap();
|
||||
fs::remove_file(&*LATEST_EXAMPLE_RUN_DIR)?;
|
||||
}
|
||||
#[cfg(unix)]
|
||||
std::os::unix::fs::symlink(&run_dir, &*LATEST_EXAMPLE_RUN_DIR).unwrap();
|
||||
std::os::unix::fs::symlink(&run_dir, &*LATEST_EXAMPLE_RUN_DIR)?;
|
||||
#[cfg(windows)]
|
||||
std::os::windows::fs::symlink_dir(&run_dir, &*LATEST_EXAMPLE_RUN_DIR).unwrap();
|
||||
std::os::windows::fs::symlink_dir(&run_dir, &*LATEST_EXAMPLE_RUN_DIR)?;
|
||||
|
||||
updated_example
|
||||
.lock()
|
||||
@@ -181,10 +178,8 @@ pub async fn run_prediction(
|
||||
cloud_llm_client::PredictEditsRequestTrigger::Cli,
|
||||
cx,
|
||||
)
|
||||
})
|
||||
.unwrap()
|
||||
.await
|
||||
.unwrap();
|
||||
})?
|
||||
.await?;
|
||||
|
||||
let actual_patch = prediction
|
||||
.and_then(|prediction| {
|
||||
@@ -213,20 +208,23 @@ pub async fn run_prediction(
|
||||
}
|
||||
}
|
||||
|
||||
ep_store
|
||||
.update(&mut cx, |store, _| {
|
||||
store.remove_project(&state.project);
|
||||
})
|
||||
.unwrap();
|
||||
debug_task.await.unwrap();
|
||||
ep_store.update(&mut cx, |store, _| {
|
||||
store.remove_project(&state.project);
|
||||
})?;
|
||||
debug_task.await?;
|
||||
|
||||
*example = Arc::into_inner(updated_example)
|
||||
.unwrap()
|
||||
.ok_or_else(|| anyhow::anyhow!("Failed to unwrap Arc"))?
|
||||
.into_inner()
|
||||
.unwrap();
|
||||
.map_err(|_| anyhow::anyhow!("Failed to unwrap Mutex"))?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn predict_anthropic(example: &mut Example, _repetition_count: usize, batched: bool) {
|
||||
async fn predict_anthropic(
|
||||
example: &mut Example,
|
||||
_repetition_count: usize,
|
||||
batched: bool,
|
||||
) -> anyhow::Result<()> {
|
||||
let llm_model_name = "claude-sonnet-4-5";
|
||||
let max_tokens = 16384;
|
||||
let llm_client = if batched {
|
||||
@@ -234,12 +232,9 @@ async fn predict_anthropic(example: &mut Example, _repetition_count: usize, batc
|
||||
} else {
|
||||
AnthropicClient::plain()
|
||||
};
|
||||
let llm_client = llm_client.expect("Failed to create LLM client");
|
||||
let llm_client = llm_client.context("Failed to create LLM client")?;
|
||||
|
||||
let prompt = example
|
||||
.prompt
|
||||
.as_ref()
|
||||
.unwrap_or_else(|| panic!("Prompt is required for an example {}", &example.name));
|
||||
let prompt = example.prompt.as_ref().context("Prompt is required")?;
|
||||
|
||||
let messages = vec![anthropic::Message {
|
||||
role: anthropic::Role::User,
|
||||
@@ -251,11 +246,10 @@ async fn predict_anthropic(example: &mut Example, _repetition_count: usize, batc
|
||||
|
||||
let Some(response) = llm_client
|
||||
.generate(llm_model_name, max_tokens, messages)
|
||||
.await
|
||||
.unwrap()
|
||||
.await?
|
||||
else {
|
||||
// Request stashed for batched processing
|
||||
return;
|
||||
return Ok(());
|
||||
};
|
||||
|
||||
let actual_output = response
|
||||
@@ -268,7 +262,7 @@ async fn predict_anthropic(example: &mut Example, _repetition_count: usize, batc
|
||||
.collect::<Vec<String>>()
|
||||
.join("\n");
|
||||
|
||||
let actual_patch = TeacherPrompt::parse(example, &actual_output);
|
||||
let actual_patch = TeacherPrompt::parse(example, &actual_output)?;
|
||||
|
||||
let prediction = ExamplePrediction {
|
||||
actual_patch,
|
||||
@@ -277,19 +271,21 @@ async fn predict_anthropic(example: &mut Example, _repetition_count: usize, batc
|
||||
};
|
||||
|
||||
example.predictions.push(prediction);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn sync_batches(provider: &PredictionProvider) {
|
||||
pub async fn sync_batches(provider: &PredictionProvider) -> anyhow::Result<()> {
|
||||
match provider {
|
||||
PredictionProvider::Teacher => {
|
||||
let cache_path = crate::paths::LLM_CACHE_DB.as_ref();
|
||||
let llm_client =
|
||||
AnthropicClient::batch(cache_path).expect("Failed to create LLM client");
|
||||
AnthropicClient::batch(cache_path).context("Failed to create LLM client")?;
|
||||
llm_client
|
||||
.sync_batches()
|
||||
.await
|
||||
.expect("Failed to sync batches");
|
||||
.context("Failed to sync batches")?;
|
||||
}
|
||||
_ => (),
|
||||
}
|
||||
};
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -20,6 +20,7 @@ struct ProgressInner {
|
||||
max_example_name_len: usize,
|
||||
status_lines_displayed: usize,
|
||||
total_examples: usize,
|
||||
failed_examples: usize,
|
||||
last_line_is_logging: bool,
|
||||
}
|
||||
|
||||
@@ -78,7 +79,7 @@ impl Step {
|
||||
static GLOBAL: OnceLock<Arc<Progress>> = OnceLock::new();
|
||||
static LOGGER: ProgressLogger = ProgressLogger;
|
||||
|
||||
const RIGHT_MARGIN: usize = 4;
|
||||
const MARGIN: usize = 4;
|
||||
const MAX_STATUS_LINES: usize = 10;
|
||||
|
||||
impl Progress {
|
||||
@@ -95,6 +96,7 @@ impl Progress {
|
||||
max_example_name_len: 0,
|
||||
status_lines_displayed: 0,
|
||||
total_examples: 0,
|
||||
failed_examples: 0,
|
||||
last_line_is_logging: false,
|
||||
}),
|
||||
});
|
||||
@@ -110,6 +112,11 @@ impl Progress {
|
||||
inner.total_examples = total;
|
||||
}
|
||||
|
||||
pub fn increment_failed(&self) {
|
||||
let mut inner = self.inner.lock().unwrap();
|
||||
inner.failed_examples += 1;
|
||||
}
|
||||
|
||||
/// Prints a message to stderr, clearing and redrawing status lines to avoid corruption.
|
||||
/// This should be used for any output that needs to appear above the status lines.
|
||||
fn log(&self, message: &str) {
|
||||
@@ -119,7 +126,7 @@ impl Progress {
|
||||
if !inner.last_line_is_logging {
|
||||
let reset = "\x1b[0m";
|
||||
let dim = "\x1b[2m";
|
||||
let divider = "─".repeat(inner.terminal_width.saturating_sub(RIGHT_MARGIN));
|
||||
let divider = "─".repeat(inner.terminal_width.saturating_sub(MARGIN));
|
||||
eprintln!("{dim}{divider}{reset}");
|
||||
inner.last_line_is_logging = true;
|
||||
}
|
||||
@@ -180,7 +187,7 @@ impl Progress {
|
||||
if inner.last_line_is_logging {
|
||||
let reset = "\x1b[0m";
|
||||
let dim = "\x1b[2m";
|
||||
let divider = "─".repeat(inner.terminal_width.saturating_sub(RIGHT_MARGIN));
|
||||
let divider = "─".repeat(inner.terminal_width.saturating_sub(MARGIN));
|
||||
eprintln!("{dim}{divider}{reset}");
|
||||
inner.last_line_is_logging = false;
|
||||
}
|
||||
@@ -229,7 +236,7 @@ impl Progress {
|
||||
let duration_with_margin = format!("{duration} ");
|
||||
let padding_needed = inner
|
||||
.terminal_width
|
||||
.saturating_sub(RIGHT_MARGIN)
|
||||
.saturating_sub(MARGIN)
|
||||
.saturating_sub(duration_with_margin.len())
|
||||
.saturating_sub(strip_ansi_len(&prefix));
|
||||
let padding = " ".repeat(padding_needed);
|
||||
@@ -263,20 +270,33 @@ impl Progress {
|
||||
// Build the done/in-progress/total label
|
||||
let done_count = inner.completed.len();
|
||||
let in_progress_count = inner.in_progress.len();
|
||||
let failed_count = inner.failed_examples;
|
||||
|
||||
let failed_label = if failed_count > 0 {
|
||||
format!(" {} failed ", failed_count)
|
||||
} else {
|
||||
String::new()
|
||||
};
|
||||
|
||||
let range_label = format!(
|
||||
" {}/{}/{} ",
|
||||
done_count, in_progress_count, inner.total_examples
|
||||
);
|
||||
|
||||
// Print a divider line with range label aligned with timestamps
|
||||
// Print a divider line with failed count on left, range label on right
|
||||
let failed_visible_len = strip_ansi_len(&failed_label);
|
||||
let range_visible_len = range_label.len();
|
||||
let left_divider_len = inner
|
||||
let middle_divider_len = inner
|
||||
.terminal_width
|
||||
.saturating_sub(RIGHT_MARGIN)
|
||||
.saturating_sub(MARGIN * 2)
|
||||
.saturating_sub(failed_visible_len)
|
||||
.saturating_sub(range_visible_len);
|
||||
let left_divider = "─".repeat(left_divider_len);
|
||||
let right_divider = "─".repeat(RIGHT_MARGIN);
|
||||
eprintln!("{dim}{left_divider}{reset}{range_label}{dim}{right_divider}{reset}");
|
||||
let left_divider = "─".repeat(MARGIN);
|
||||
let middle_divider = "─".repeat(middle_divider_len);
|
||||
let right_divider = "─".repeat(MARGIN);
|
||||
eprintln!(
|
||||
"{dim}{left_divider}{reset}{failed_label}{dim}{middle_divider}{reset}{range_label}{dim}{right_divider}{reset}"
|
||||
);
|
||||
|
||||
let mut tasks: Vec<_> = inner.in_progress.iter().collect();
|
||||
tasks.sort_by_key(|(name, _)| *name);
|
||||
@@ -304,7 +324,7 @@ impl Progress {
|
||||
let duration_with_margin = format!("{elapsed} ");
|
||||
let padding_needed = inner
|
||||
.terminal_width
|
||||
.saturating_sub(RIGHT_MARGIN)
|
||||
.saturating_sub(MARGIN)
|
||||
.saturating_sub(duration_with_margin.len())
|
||||
.saturating_sub(strip_ansi_len(&prefix));
|
||||
let padding = " ".repeat(padding_needed);
|
||||
@@ -324,9 +344,23 @@ impl Progress {
|
||||
let _ = std::io::stderr().flush();
|
||||
}
|
||||
|
||||
pub fn clear(&self) {
|
||||
pub fn finalize(&self) {
|
||||
let mut inner = self.inner.lock().unwrap();
|
||||
Self::clear_status_lines(&mut inner);
|
||||
|
||||
// Print summary if there were failures
|
||||
if inner.failed_examples > 0 {
|
||||
let total_processed = inner.completed.len() + inner.failed_examples;
|
||||
let percentage = if total_processed > 0 {
|
||||
inner.failed_examples as f64 / total_processed as f64 * 100.0
|
||||
} else {
|
||||
0.0
|
||||
};
|
||||
eprintln!(
|
||||
"\n{} of {} examples failed ({:.1}%)",
|
||||
inner.failed_examples, total_processed, percentage
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@ use crate::{
|
||||
load_project::run_load_project,
|
||||
progress::{InfoStyle, Progress, Step, StepProgress},
|
||||
};
|
||||
use anyhow::Context as _;
|
||||
use collections::HashSet;
|
||||
use edit_prediction::{DebugEvent, EditPredictionStore};
|
||||
use futures::{FutureExt as _, StreamExt as _, channel::mpsc};
|
||||
@@ -17,12 +18,12 @@ pub async fn run_context_retrieval(
|
||||
example: &mut Example,
|
||||
app_state: Arc<EpAppState>,
|
||||
mut cx: AsyncApp,
|
||||
) {
|
||||
) -> anyhow::Result<()> {
|
||||
if example.context.is_some() {
|
||||
return;
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
run_load_project(example, app_state.clone(), cx.clone()).await;
|
||||
run_load_project(example, app_state.clone(), cx.clone()).await?;
|
||||
|
||||
let step_progress: Arc<StepProgress> = Progress::global()
|
||||
.start(Step::Context, &example.name)
|
||||
@@ -31,25 +32,21 @@ pub async fn run_context_retrieval(
|
||||
let state = example.state.as_ref().unwrap();
|
||||
let project = state.project.clone();
|
||||
|
||||
let _lsp_handle = project
|
||||
.update(&mut cx, |project, cx| {
|
||||
project.register_buffer_with_language_servers(&state.buffer, cx)
|
||||
})
|
||||
.unwrap();
|
||||
wait_for_language_servers_to_start(&project, &state.buffer, &step_progress, &mut cx).await;
|
||||
let _lsp_handle = project.update(&mut cx, |project, cx| {
|
||||
project.register_buffer_with_language_servers(&state.buffer, cx)
|
||||
})?;
|
||||
wait_for_language_servers_to_start(&project, &state.buffer, &step_progress, &mut cx).await?;
|
||||
|
||||
let ep_store = cx
|
||||
.update(|cx| EditPredictionStore::try_global(cx).unwrap())
|
||||
.unwrap();
|
||||
let ep_store = cx.update(|cx| {
|
||||
EditPredictionStore::try_global(cx).context("EditPredictionStore not initialized")
|
||||
})??;
|
||||
|
||||
let mut events = ep_store
|
||||
.update(&mut cx, |store, cx| {
|
||||
store.register_buffer(&state.buffer, &project, cx);
|
||||
store.set_use_context(true);
|
||||
store.refresh_context(&project, &state.buffer, state.cursor_position, cx);
|
||||
store.debug_info(&project, cx)
|
||||
})
|
||||
.unwrap();
|
||||
let mut events = ep_store.update(&mut cx, |store, cx| {
|
||||
store.register_buffer(&state.buffer, &project, cx);
|
||||
store.set_use_context(true);
|
||||
store.refresh_context(&project, &state.buffer, state.cursor_position, cx);
|
||||
store.debug_info(&project, cx)
|
||||
})?;
|
||||
|
||||
while let Some(event) = events.next().await {
|
||||
match event {
|
||||
@@ -60,9 +57,8 @@ pub async fn run_context_retrieval(
|
||||
}
|
||||
}
|
||||
|
||||
let context_files = ep_store
|
||||
.update(&mut cx, |store, cx| store.context_for_project(&project, cx))
|
||||
.unwrap();
|
||||
let context_files =
|
||||
ep_store.update(&mut cx, |store, cx| store.context_for_project(&project, cx))?;
|
||||
|
||||
let excerpt_count: usize = context_files.iter().map(|f| f.excerpts.len()).sum();
|
||||
step_progress.set_info(format!("{} excerpts", excerpt_count), InfoStyle::Normal);
|
||||
@@ -70,6 +66,7 @@ pub async fn run_context_retrieval(
|
||||
example.context = Some(ExampleContext {
|
||||
files: context_files,
|
||||
});
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn wait_for_language_servers_to_start(
|
||||
@@ -77,10 +74,8 @@ async fn wait_for_language_servers_to_start(
|
||||
buffer: &Entity<Buffer>,
|
||||
step_progress: &Arc<StepProgress>,
|
||||
cx: &mut AsyncApp,
|
||||
) {
|
||||
let lsp_store = project
|
||||
.read_with(cx, |project, _| project.lsp_store())
|
||||
.unwrap();
|
||||
) -> anyhow::Result<()> {
|
||||
let lsp_store = project.read_with(cx, |project, _| project.lsp_store())?;
|
||||
|
||||
let (language_server_ids, mut starting_language_server_ids) = buffer
|
||||
.update(cx, |buffer, cx| {
|
||||
@@ -123,7 +118,7 @@ async fn wait_for_language_servers_to_start(
|
||||
}
|
||||
},
|
||||
_ = timeout.clone().fuse() => {
|
||||
panic!("LSP wait timed out after 5 minutes");
|
||||
return Err(anyhow::anyhow!("LSP wait timed out after 5 minutes"));
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -132,8 +127,7 @@ async fn wait_for_language_servers_to_start(
|
||||
|
||||
if !language_server_ids.is_empty() {
|
||||
project
|
||||
.update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))
|
||||
.unwrap()
|
||||
.update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))?
|
||||
.detach();
|
||||
}
|
||||
|
||||
@@ -175,10 +169,8 @@ async fn wait_for_language_servers_to_start(
|
||||
];
|
||||
|
||||
project
|
||||
.update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))
|
||||
.unwrap()
|
||||
.await
|
||||
.unwrap();
|
||||
.update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))?
|
||||
.await?;
|
||||
|
||||
let mut pending_language_server_ids = HashSet::from_iter(language_server_ids.into_iter());
|
||||
while !pending_language_server_ids.is_empty() {
|
||||
@@ -189,11 +181,12 @@ async fn wait_for_language_servers_to_start(
|
||||
}
|
||||
},
|
||||
_ = timeout.clone().fuse() => {
|
||||
panic!("LSP wait timed out after 5 minutes");
|
||||
return Err(anyhow::anyhow!("LSP wait timed out after 5 minutes"));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
drop(subscriptions);
|
||||
step_progress.clear_substatus();
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -15,7 +15,7 @@ pub async fn run_scoring(
|
||||
args: &PredictArgs,
|
||||
app_state: Arc<EpAppState>,
|
||||
cx: AsyncApp,
|
||||
) {
|
||||
) -> anyhow::Result<()> {
|
||||
run_prediction(
|
||||
example,
|
||||
Some(args.provider),
|
||||
@@ -23,7 +23,7 @@ pub async fn run_scoring(
|
||||
app_state,
|
||||
cx,
|
||||
)
|
||||
.await;
|
||||
.await?;
|
||||
|
||||
let _progress = Progress::global().start(Step::Score, &example.name);
|
||||
|
||||
@@ -43,6 +43,7 @@ pub async fn run_scoring(
|
||||
}
|
||||
|
||||
example.score = scores;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn parse_patch(patch: &str) -> Vec<DiffLine<'_>> {
|
||||
|
||||
Reference in New Issue
Block a user