Compare commits
10 Commits
ex-pointer
...
zeta-disti
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
be488f68bc | ||
|
|
f380c25925 | ||
|
|
54cd4220c0 | ||
|
|
d502082717 | ||
|
|
a747c2e3ed | ||
|
|
f105171aab | ||
|
|
09b191cc49 | ||
|
|
5431e12cc5 | ||
|
|
dbafd0aab6 | ||
|
|
ad0283cd8d |
89
Cargo.lock
generated
89
Cargo.lock
generated
@@ -639,6 +639,38 @@ dependencies = [
|
||||
"thiserror 2.0.17",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "anthropic-sdk-rust"
|
||||
version = "0.1.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "45395c37cc1ce9981a0bd0ba573462ce4809586324cb7b358ecd2e5ed6652485"
|
||||
dependencies = [
|
||||
"async-trait",
|
||||
"backoff",
|
||||
"base64 0.22.1",
|
||||
"bytes 1.10.1",
|
||||
"chrono",
|
||||
"dotenvy",
|
||||
"eventsource-stream",
|
||||
"futures 0.3.31",
|
||||
"mime",
|
||||
"once_cell",
|
||||
"pin-project",
|
||||
"reqwest 0.12.24",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"sha2",
|
||||
"tempfile",
|
||||
"thiserror 1.0.69",
|
||||
"tokio",
|
||||
"tokio-stream",
|
||||
"tower 0.4.13",
|
||||
"tracing",
|
||||
"tracing-subscriber",
|
||||
"url",
|
||||
"uuid",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "any_vec"
|
||||
version = "0.14.0"
|
||||
@@ -1959,6 +1991,17 @@ dependencies = [
|
||||
"tower-service",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "backoff"
|
||||
version = "0.4.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b62ddb9cb1ec0a098ad4bbf9344d0713fa193ae1a80af55febcff2627b6a00c1"
|
||||
dependencies = [
|
||||
"getrandom 0.2.16",
|
||||
"instant",
|
||||
"rand 0.8.5",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "backtrace"
|
||||
version = "0.3.76"
|
||||
@@ -5165,6 +5208,7 @@ dependencies = [
|
||||
name = "edit_prediction_cli"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"anthropic-sdk-rust",
|
||||
"anyhow",
|
||||
"chrono",
|
||||
"clap",
|
||||
@@ -5738,6 +5782,17 @@ dependencies = [
|
||||
"pin-project-lite",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "eventsource-stream"
|
||||
version = "0.2.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "74fef4569247a5f429d9156b9d0a2599914385dd189c539334c625d8099d90ab"
|
||||
dependencies = [
|
||||
"futures-core",
|
||||
"nom 7.1.3",
|
||||
"pin-project-lite",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "exec"
|
||||
version = "0.3.1"
|
||||
@@ -7855,6 +7910,22 @@ dependencies = [
|
||||
"tokio-native-tls",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "hyper-tls"
|
||||
version = "0.6.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "70206fc6890eaca9fde8a0bf71caa2ddfc9fe045ac9e5c70df101a7dbde866e0"
|
||||
dependencies = [
|
||||
"bytes 1.10.1",
|
||||
"http-body-util",
|
||||
"hyper 1.7.0",
|
||||
"hyper-util",
|
||||
"native-tls",
|
||||
"tokio",
|
||||
"tokio-native-tls",
|
||||
"tower-service",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "hyper-util"
|
||||
version = "0.1.17"
|
||||
@@ -7874,9 +7945,11 @@ dependencies = [
|
||||
"percent-encoding",
|
||||
"pin-project-lite",
|
||||
"socket2 0.6.1",
|
||||
"system-configuration 0.6.1",
|
||||
"tokio",
|
||||
"tower-service",
|
||||
"tracing",
|
||||
"windows-registry 0.5.3",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -8128,7 +8201,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "4b0f83760fb341a774ed326568e19f5a863af4a952def8c39f9ab92fd95b88e5"
|
||||
dependencies = [
|
||||
"equivalent",
|
||||
"hashbrown 0.15.5",
|
||||
"hashbrown 0.16.1",
|
||||
"serde",
|
||||
"serde_core",
|
||||
]
|
||||
@@ -13460,7 +13533,7 @@ dependencies = [
|
||||
"http-body 0.4.6",
|
||||
"hyper 0.14.32",
|
||||
"hyper-rustls 0.24.2",
|
||||
"hyper-tls",
|
||||
"hyper-tls 0.5.0",
|
||||
"ipnet",
|
||||
"js-sys",
|
||||
"log",
|
||||
@@ -13496,29 +13569,40 @@ checksum = "9d0946410b9f7b082a427e4ef5c8ff541a88b357bc6c637c40db3a68ac70a36f"
|
||||
dependencies = [
|
||||
"base64 0.22.1",
|
||||
"bytes 1.10.1",
|
||||
"encoding_rs",
|
||||
"futures-channel",
|
||||
"futures-core",
|
||||
"futures-util",
|
||||
"h2 0.4.12",
|
||||
"http 1.3.1",
|
||||
"http-body 1.0.1",
|
||||
"http-body-util",
|
||||
"hyper 1.7.0",
|
||||
"hyper-rustls 0.27.7",
|
||||
"hyper-tls 0.6.0",
|
||||
"hyper-util",
|
||||
"js-sys",
|
||||
"log",
|
||||
"mime",
|
||||
"mime_guess",
|
||||
"native-tls",
|
||||
"percent-encoding",
|
||||
"pin-project-lite",
|
||||
"rustls-pki-types",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"serde_urlencoded",
|
||||
"sync_wrapper 1.0.2",
|
||||
"tokio",
|
||||
"tokio-native-tls",
|
||||
"tokio-util",
|
||||
"tower 0.5.2",
|
||||
"tower-http 0.6.6",
|
||||
"tower-service",
|
||||
"url",
|
||||
"wasm-bindgen",
|
||||
"wasm-bindgen-futures",
|
||||
"wasm-streams",
|
||||
"web-sys",
|
||||
]
|
||||
|
||||
@@ -16818,6 +16902,7 @@ dependencies = [
|
||||
"futures-core",
|
||||
"pin-project-lite",
|
||||
"tokio",
|
||||
"tokio-util",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
||||
@@ -13,6 +13,7 @@ name = "ep_cli"
|
||||
path = "src/main.rs"
|
||||
|
||||
[dependencies]
|
||||
anthropic-sdk-rust = "0.1.1"
|
||||
|
||||
anyhow.workspace = true
|
||||
chrono.workspace = true
|
||||
|
||||
@@ -3,6 +3,8 @@ use std::{
|
||||
cell::RefCell,
|
||||
fmt::{self, Display},
|
||||
fs,
|
||||
hash::Hash,
|
||||
hash::Hasher,
|
||||
io::Write,
|
||||
mem,
|
||||
path::{Path, PathBuf},
|
||||
@@ -43,7 +45,7 @@ pub struct NamedExample {
|
||||
pub example: Example,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
#[derive(Clone, Debug, Hash, Serialize, Deserialize)]
|
||||
pub struct Example {
|
||||
pub repository_url: String,
|
||||
pub revision: String,
|
||||
@@ -54,6 +56,134 @@ pub struct Example {
|
||||
pub expected_patch: String,
|
||||
}
|
||||
|
||||
impl Example {
|
||||
fn repo_name(&self) -> Result<(Cow<'_, str>, Cow<'_, str>)> {
|
||||
// git@github.com:owner/repo.git
|
||||
if self.repository_url.contains('@') {
|
||||
let (owner, repo) = self
|
||||
.repository_url
|
||||
.split_once(':')
|
||||
.context("expected : in git url")?
|
||||
.1
|
||||
.split_once('/')
|
||||
.context("expected / in git url")?;
|
||||
Ok((
|
||||
Cow::Borrowed(owner),
|
||||
Cow::Borrowed(repo.trim_end_matches(".git")),
|
||||
))
|
||||
// http://github.com/owner/repo.git
|
||||
} else {
|
||||
let url = Url::parse(&self.repository_url)?;
|
||||
let mut segments = url.path_segments().context("empty http url")?;
|
||||
let owner = segments
|
||||
.next()
|
||||
.context("expected owner path segment")?
|
||||
.to_string();
|
||||
let repo = segments
|
||||
.next()
|
||||
.context("expected repo path segment")?
|
||||
.trim_end_matches(".git")
|
||||
.to_string();
|
||||
assert!(segments.next().is_none());
|
||||
|
||||
Ok((owner.into(), repo.into()))
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn setup_worktree(&self, file_name: String) -> Result<PathBuf> {
|
||||
let (repo_owner, repo_name) = self.repo_name()?;
|
||||
|
||||
let repo_dir = REPOS_DIR.join(repo_owner.as_ref()).join(repo_name.as_ref());
|
||||
let repo_lock = lock_repo(&repo_dir).await;
|
||||
|
||||
if !repo_dir.is_dir() {
|
||||
fs::create_dir_all(&repo_dir)?;
|
||||
run_git(&repo_dir, &["init"]).await?;
|
||||
run_git(
|
||||
&repo_dir,
|
||||
&["remote", "add", "origin", &self.repository_url],
|
||||
)
|
||||
.await?;
|
||||
}
|
||||
|
||||
// Resolve the example to a revision, fetching it if needed.
|
||||
let revision = run_git(
|
||||
&repo_dir,
|
||||
&["rev-parse", &format!("{}^{{commit}}", self.revision)],
|
||||
)
|
||||
.await;
|
||||
let revision = if let Ok(revision) = revision {
|
||||
revision
|
||||
} else {
|
||||
if run_git(
|
||||
&repo_dir,
|
||||
&["fetch", "--depth", "1", "origin", &self.revision],
|
||||
)
|
||||
.await
|
||||
.is_err()
|
||||
{
|
||||
run_git(&repo_dir, &["fetch", "origin"]).await?;
|
||||
}
|
||||
let revision = run_git(&repo_dir, &["rev-parse", "FETCH_HEAD"]).await?;
|
||||
if revision != self.revision {
|
||||
run_git(&repo_dir, &["tag", &self.revision, &revision]).await?;
|
||||
}
|
||||
revision
|
||||
};
|
||||
|
||||
// Create the worktree for this example if needed.
|
||||
let worktree_path = WORKTREES_DIR.join(&file_name).join(repo_name.as_ref());
|
||||
if worktree_path.is_dir() {
|
||||
run_git(&worktree_path, &["clean", "--force", "-d"]).await?;
|
||||
run_git(&worktree_path, &["reset", "--hard", "HEAD"]).await?;
|
||||
run_git(&worktree_path, &["checkout", revision.as_str()]).await?;
|
||||
} else {
|
||||
let worktree_path_string = worktree_path.to_string_lossy();
|
||||
run_git(&repo_dir, &["branch", "-f", &file_name, revision.as_str()]).await?;
|
||||
run_git(
|
||||
&repo_dir,
|
||||
&["worktree", "add", "-f", &worktree_path_string, &file_name],
|
||||
)
|
||||
.await?;
|
||||
}
|
||||
drop(repo_lock);
|
||||
|
||||
// Apply the uncommitted diff for this example.
|
||||
if !self.uncommitted_diff.is_empty() {
|
||||
let mut apply_process = smol::process::Command::new("git")
|
||||
.current_dir(&worktree_path)
|
||||
.args(&["apply", "-"])
|
||||
.stdin(std::process::Stdio::piped())
|
||||
.spawn()?;
|
||||
|
||||
let mut stdin = apply_process.stdin.take().unwrap();
|
||||
stdin.write_all(self.uncommitted_diff.as_bytes()).await?;
|
||||
stdin.close().await?;
|
||||
drop(stdin);
|
||||
|
||||
let apply_result = apply_process.output().await?;
|
||||
if !apply_result.status.success() {
|
||||
anyhow::bail!(
|
||||
"Failed to apply uncommitted diff patch with status: {}\nstderr:\n{}\nstdout:\n{}",
|
||||
apply_result.status,
|
||||
String::from_utf8_lossy(&apply_result.stderr),
|
||||
String::from_utf8_lossy(&apply_result.stdout),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(worktree_path)
|
||||
}
|
||||
|
||||
pub fn unique_name(&self) -> String {
|
||||
let mut hasher = std::hash::DefaultHasher::new();
|
||||
self.hash(&mut hasher);
|
||||
let disambiguator = hasher.finish();
|
||||
let hash = format!("{:04x}", disambiguator);
|
||||
format!("{}_{}", &self.revision[..8], &hash[..4])
|
||||
}
|
||||
}
|
||||
|
||||
pub type ActualExcerpt = Excerpt;
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
@@ -292,90 +422,7 @@ impl NamedExample {
|
||||
}
|
||||
|
||||
pub async fn setup_worktree(&self) -> Result<PathBuf> {
|
||||
let (repo_owner, repo_name) = self.repo_name()?;
|
||||
let file_name = self.file_name();
|
||||
|
||||
let repo_dir = REPOS_DIR.join(repo_owner.as_ref()).join(repo_name.as_ref());
|
||||
let repo_lock = lock_repo(&repo_dir).await;
|
||||
|
||||
if !repo_dir.is_dir() {
|
||||
fs::create_dir_all(&repo_dir)?;
|
||||
run_git(&repo_dir, &["init"]).await?;
|
||||
run_git(
|
||||
&repo_dir,
|
||||
&["remote", "add", "origin", &self.example.repository_url],
|
||||
)
|
||||
.await?;
|
||||
}
|
||||
|
||||
// Resolve the example to a revision, fetching it if needed.
|
||||
let revision = run_git(
|
||||
&repo_dir,
|
||||
&[
|
||||
"rev-parse",
|
||||
&format!("{}^{{commit}}", self.example.revision),
|
||||
],
|
||||
)
|
||||
.await;
|
||||
let revision = if let Ok(revision) = revision {
|
||||
revision
|
||||
} else {
|
||||
run_git(
|
||||
&repo_dir,
|
||||
&["fetch", "--depth", "1", "origin", &self.example.revision],
|
||||
)
|
||||
.await?;
|
||||
let revision = run_git(&repo_dir, &["rev-parse", "FETCH_HEAD"]).await?;
|
||||
if revision != self.example.revision {
|
||||
run_git(&repo_dir, &["tag", &self.example.revision, &revision]).await?;
|
||||
}
|
||||
revision
|
||||
};
|
||||
|
||||
// Create the worktree for this example if needed.
|
||||
let worktree_path = WORKTREES_DIR.join(&file_name).join(repo_name.as_ref());
|
||||
if worktree_path.is_dir() {
|
||||
run_git(&worktree_path, &["clean", "--force", "-d"]).await?;
|
||||
run_git(&worktree_path, &["reset", "--hard", "HEAD"]).await?;
|
||||
run_git(&worktree_path, &["checkout", revision.as_str()]).await?;
|
||||
} else {
|
||||
let worktree_path_string = worktree_path.to_string_lossy();
|
||||
run_git(&repo_dir, &["branch", "-f", &file_name, revision.as_str()]).await?;
|
||||
run_git(
|
||||
&repo_dir,
|
||||
&["worktree", "add", "-f", &worktree_path_string, &file_name],
|
||||
)
|
||||
.await?;
|
||||
}
|
||||
drop(repo_lock);
|
||||
|
||||
// Apply the uncommitted diff for this example.
|
||||
if !self.example.uncommitted_diff.is_empty() {
|
||||
let mut apply_process = smol::process::Command::new("git")
|
||||
.current_dir(&worktree_path)
|
||||
.args(&["apply", "-"])
|
||||
.stdin(std::process::Stdio::piped())
|
||||
.spawn()?;
|
||||
|
||||
let mut stdin = apply_process.stdin.take().unwrap();
|
||||
stdin
|
||||
.write_all(self.example.uncommitted_diff.as_bytes())
|
||||
.await?;
|
||||
stdin.close().await?;
|
||||
drop(stdin);
|
||||
|
||||
let apply_result = apply_process.output().await?;
|
||||
if !apply_result.status.success() {
|
||||
anyhow::bail!(
|
||||
"Failed to apply uncommitted diff patch with status: {}\nstderr:\n{}\nstdout:\n{}",
|
||||
apply_result.status,
|
||||
String::from_utf8_lossy(&apply_result.stderr),
|
||||
String::from_utf8_lossy(&apply_result.stdout),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(worktree_path)
|
||||
self.example.setup_worktree(self.file_name()).await
|
||||
}
|
||||
|
||||
pub fn file_name(&self) -> String {
|
||||
@@ -391,40 +438,6 @@ impl NamedExample {
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn repo_name(&self) -> Result<(Cow<'_, str>, Cow<'_, str>)> {
|
||||
// git@github.com:owner/repo.git
|
||||
if self.example.repository_url.contains('@') {
|
||||
let (owner, repo) = self
|
||||
.example
|
||||
.repository_url
|
||||
.split_once(':')
|
||||
.context("expected : in git url")?
|
||||
.1
|
||||
.split_once('/')
|
||||
.context("expected / in git url")?;
|
||||
Ok((
|
||||
Cow::Borrowed(owner),
|
||||
Cow::Borrowed(repo.trim_end_matches(".git")),
|
||||
))
|
||||
// http://github.com/owner/repo.git
|
||||
} else {
|
||||
let url = Url::parse(&self.example.repository_url)?;
|
||||
let mut segments = url.path_segments().context("empty http url")?;
|
||||
let owner = segments
|
||||
.next()
|
||||
.context("expected owner path segment")?
|
||||
.to_string();
|
||||
let repo = segments
|
||||
.next()
|
||||
.context("expected repo path segment")?
|
||||
.trim_end_matches(".git")
|
||||
.to_string();
|
||||
assert!(segments.next().is_none());
|
||||
|
||||
Ok((owner.into(), repo.into()))
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn cursor_position(
|
||||
&self,
|
||||
project: &Entity<Project>,
|
||||
|
||||
@@ -5,6 +5,7 @@ mod metrics;
|
||||
mod paths;
|
||||
mod predict;
|
||||
mod source_location;
|
||||
mod training;
|
||||
mod util;
|
||||
|
||||
use crate::{
|
||||
@@ -13,9 +14,10 @@ use crate::{
|
||||
headless::ZetaCliAppState,
|
||||
predict::run_predict,
|
||||
source_location::SourceLocation,
|
||||
training::{context::ContextType, distill::run_distill},
|
||||
util::{open_buffer, open_buffer_with_language_server},
|
||||
};
|
||||
use ::util::paths::PathStyle;
|
||||
use ::util::{ResultExt, paths::PathStyle};
|
||||
use anyhow::{Result, anyhow};
|
||||
use clap::{Args, Parser, Subcommand, ValueEnum};
|
||||
use cloud_llm_client::predict_edits_v3;
|
||||
@@ -43,6 +45,7 @@ enum Command {
|
||||
Context(ContextArgs),
|
||||
Predict(PredictArguments),
|
||||
Eval(EvaluateArguments),
|
||||
Distill(DistillArguments),
|
||||
ConvertExample {
|
||||
path: PathBuf,
|
||||
#[arg(long, value_enum, default_value_t = ExampleFormat::Md)]
|
||||
@@ -111,6 +114,13 @@ pub struct PredictArguments {
|
||||
options: PredictionOptions,
|
||||
}
|
||||
|
||||
#[derive(Debug, Args)]
|
||||
pub struct DistillArguments {
|
||||
split_commit_dataset: PathBuf,
|
||||
#[clap(long, value_enum, default_value_t = ContextType::CurrentFile)]
|
||||
context_type: ContextType,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Args)]
|
||||
pub struct PredictionOptions {
|
||||
#[clap(flatten)]
|
||||
@@ -468,6 +478,13 @@ fn main() {
|
||||
Some(Command::Eval(arguments)) => {
|
||||
run_evaluate(arguments, &app_state, cx).await;
|
||||
}
|
||||
Some(Command::Distill(arguments)) => {
|
||||
let _guard = cx
|
||||
.update(|cx| gpui_tokio::Tokio::handle(cx))
|
||||
.unwrap()
|
||||
.enter();
|
||||
run_distill(arguments).await.log_err();
|
||||
}
|
||||
Some(Command::ConvertExample {
|
||||
path,
|
||||
output_format,
|
||||
|
||||
89
crates/edit_prediction_cli/src/training/context.rs
Normal file
89
crates/edit_prediction_cli/src/training/context.rs
Normal file
@@ -0,0 +1,89 @@
|
||||
use std::path::Path;
|
||||
|
||||
use crate::{source_location::SourceLocation, training::teacher::TeacherModel};
|
||||
|
||||
#[derive(Debug, Clone, Default, clap::ValueEnum)]
|
||||
pub enum ContextType {
|
||||
#[default]
|
||||
CurrentFile,
|
||||
}
|
||||
|
||||
const MAX_CONTEXT_SIZE: usize = 32768;
|
||||
|
||||
pub fn collect_context(
|
||||
context_type: &ContextType,
|
||||
worktree_dir: &Path,
|
||||
cursor: SourceLocation,
|
||||
) -> String {
|
||||
let context = match context_type {
|
||||
ContextType::CurrentFile => {
|
||||
let file_path = worktree_dir.join(cursor.path.as_std_path());
|
||||
let context = std::fs::read_to_string(&file_path).unwrap_or_default();
|
||||
|
||||
let context = add_special_tags(&context, worktree_dir, cursor);
|
||||
context
|
||||
}
|
||||
};
|
||||
|
||||
let region_end_offset = context.find(TeacherModel::REGION_END);
|
||||
|
||||
if context.len() <= MAX_CONTEXT_SIZE {
|
||||
return context;
|
||||
}
|
||||
|
||||
if let Some(region_end_offset) = region_end_offset
|
||||
&& region_end_offset + TeacherModel::REGION_END.len() > MAX_CONTEXT_SIZE
|
||||
{
|
||||
let to_truncate = context.len() - MAX_CONTEXT_SIZE;
|
||||
format!(
|
||||
"[...{} bytes truncated]\n{}\n",
|
||||
to_truncate,
|
||||
&context[to_truncate..]
|
||||
)
|
||||
} else {
|
||||
format!(
|
||||
"{}\n[...{} bytes truncated]\n",
|
||||
&context[..MAX_CONTEXT_SIZE],
|
||||
context.len() - MAX_CONTEXT_SIZE
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
/// Add <|editable_region_start/end|> tags
|
||||
fn add_special_tags(context: &str, worktree_dir: &Path, cursor: SourceLocation) -> String {
|
||||
let path = worktree_dir.join(cursor.path.as_std_path());
|
||||
let file = std::fs::read_to_string(&path).unwrap_or_default();
|
||||
let lines = file.lines().collect::<Vec<_>>();
|
||||
let cursor_row = cursor.point.row as usize;
|
||||
let start_line = cursor_row.saturating_sub(TeacherModel::LEFT_CONTEXT_SIZE);
|
||||
let end_line = (cursor_row + TeacherModel::RIGHT_CONTEXT_SIZE).min(lines.len());
|
||||
|
||||
let snippet = lines[start_line..end_line].join("\n");
|
||||
|
||||
if context.contains(&snippet) {
|
||||
let mut cursor_line = lines[cursor_row].to_string();
|
||||
cursor_line.insert_str(cursor.point.column as usize, TeacherModel::USER_CURSOR);
|
||||
|
||||
let mut snippet_with_tags_lines = vec![];
|
||||
snippet_with_tags_lines.push(TeacherModel::REGION_START);
|
||||
snippet_with_tags_lines.extend(&lines[start_line..cursor_row]);
|
||||
snippet_with_tags_lines.push(&cursor_line);
|
||||
snippet_with_tags_lines.extend(&lines[cursor_row + 1..end_line]);
|
||||
snippet_with_tags_lines.push(TeacherModel::REGION_END);
|
||||
let snippet_with_tags = snippet_with_tags_lines.join("\n");
|
||||
|
||||
context.replace(&snippet, &snippet_with_tags)
|
||||
} else {
|
||||
log::warn!(
|
||||
"Can't find area around the cursor in the context; proceeding without special tags"
|
||||
);
|
||||
context.to_string()
|
||||
}
|
||||
}
|
||||
|
||||
pub fn strip_special_tags(context: &str) -> String {
|
||||
context
|
||||
.replace(TeacherModel::REGION_START, "")
|
||||
.replace(TeacherModel::REGION_END, "")
|
||||
.replace(TeacherModel::USER_CURSOR, "")
|
||||
}
|
||||
61
crates/edit_prediction_cli/src/training/distill.rs
Normal file
61
crates/edit_prediction_cli/src/training/distill.rs
Normal file
@@ -0,0 +1,61 @@
|
||||
use serde::Deserialize;
|
||||
|
||||
use crate::{
|
||||
DistillArguments,
|
||||
example::Example,
|
||||
source_location::SourceLocation,
|
||||
training::{
|
||||
context::ContextType,
|
||||
teacher::{TeacherModel, TeacherOutput},
|
||||
},
|
||||
};
|
||||
use anyhow::Result;
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct SplitCommit {
|
||||
repo_url: String,
|
||||
commit_sha: String,
|
||||
edit_history: String,
|
||||
expected_patch: String,
|
||||
cursor_position: String,
|
||||
}
|
||||
|
||||
pub async fn run_distill(arguments: DistillArguments) -> Result<()> {
|
||||
let split_commits: Vec<SplitCommit> = std::fs::read_to_string(&arguments.split_commit_dataset)
|
||||
.expect("Failed to read split commit dataset")
|
||||
.lines()
|
||||
.map(|line| serde_json::from_str(line).expect("Failed to parse JSON line"))
|
||||
.collect();
|
||||
|
||||
for commit in split_commits {
|
||||
let distilled = distill_one(commit).await?;
|
||||
println!("{}", serde_json::to_string(&distilled)?);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn distill_one(commit: SplitCommit) -> Result<TeacherOutput> {
|
||||
let cursor: SourceLocation = commit
|
||||
.cursor_position
|
||||
.parse()
|
||||
.expect("Failed to parse cursor position");
|
||||
|
||||
let path = cursor.path.to_rel_path_buf();
|
||||
|
||||
let example = Example {
|
||||
repository_url: commit.repo_url,
|
||||
revision: commit.commit_sha,
|
||||
uncommitted_diff: commit.edit_history.clone(),
|
||||
cursor_path: path.as_std_path().to_path_buf(),
|
||||
cursor_position: commit.cursor_position,
|
||||
edit_history: commit.edit_history, // todo: trim
|
||||
expected_patch: commit.expected_patch,
|
||||
};
|
||||
|
||||
let teacher = TeacherModel::new("claude-sonnet-4-5".to_string(), ContextType::CurrentFile);
|
||||
|
||||
let prediction = teacher.predict(example).await;
|
||||
|
||||
prediction
|
||||
}
|
||||
3
crates/edit_prediction_cli/src/training/mod.rs
Normal file
3
crates/edit_prediction_cli/src/training/mod.rs
Normal file
@@ -0,0 +1,3 @@
|
||||
pub mod context;
|
||||
pub mod distill;
|
||||
pub mod teacher;
|
||||
48
crates/edit_prediction_cli/src/training/teacher.prompt.md
Normal file
48
crates/edit_prediction_cli/src/training/teacher.prompt.md
Normal file
@@ -0,0 +1,48 @@
|
||||
# Instructions
|
||||
|
||||
You are a code completion assistant helping a programmer finish their work. Your task is to:
|
||||
|
||||
1. Analyze the edit history to understand what the programmer is trying to achieve
|
||||
2. Identify any incomplete refactoring or changes that need to be finished
|
||||
3. Make the remaining edits that a human programmer would logically make next (by rewriting the corresponding code sections)
|
||||
4. Apply systematic changes consistently across the entire codebase - if you see a pattern starting, complete it everywhere.
|
||||
|
||||
Focus on:
|
||||
- Understanding the intent behind the changes (e.g., improving error handling, refactoring APIs, fixing bugs)
|
||||
- Completing any partially-applied changes across the codebase
|
||||
- Ensuring consistency with the programming style and patterns already established
|
||||
- Making edits that maintain or improve code quality
|
||||
- If the programmer started refactoring one instance of a pattern, find and update ALL similar instances
|
||||
- Don't write a lot of code if you're not sure what to do
|
||||
|
||||
Rules:
|
||||
- Do not just mechanically apply patterns - reason about what changes make sense given the context and the programmer's apparent goals.
|
||||
- Do not just fix syntax errors - look for the broader refactoring pattern and apply it systematically throughout the code.
|
||||
|
||||
Input format:
|
||||
- You receive small code fragments called context (structs, field definitions, function signatures, etc.). They may or may not be relevant.
|
||||
- Never modify the context code.
|
||||
- You also receive a code snippet between <|editable_region_start|> and <|editable_region_end|>. This is the editable region.
|
||||
- The cursor position is marked with <|user_cursor|>.
|
||||
|
||||
Output format:
|
||||
- Return the entire editable region, applying any edits you make.
|
||||
- Remove the <|user_cursor|> marker.
|
||||
- Wrap the edited code in a block of exactly five backticks.
|
||||
|
||||
Output example:
|
||||
`````
|
||||
// `zed --askpass` Makes zed operate in nc/netcat mode for use with askpass
|
||||
if let Some(socket) = &args.askpass {
|
||||
askpass::main(socket);
|
||||
return Ok(());
|
||||
}
|
||||
`````
|
||||
|
||||
## User Edits History
|
||||
|
||||
{{edit_history}}
|
||||
|
||||
## Code Context
|
||||
|
||||
{{context}}
|
||||
233
crates/edit_prediction_cli/src/training/teacher.rs
Normal file
233
crates/edit_prediction_cli/src/training/teacher.rs
Normal file
@@ -0,0 +1,233 @@
|
||||
use crate::{
|
||||
example::Example,
|
||||
source_location::SourceLocation,
|
||||
training::context::{ContextType, collect_context, strip_special_tags},
|
||||
};
|
||||
use anthropic_sdk::{Anthropic, ContentBlock, MessageCreateBuilder};
|
||||
use anyhow::Result;
|
||||
|
||||
pub struct TeacherModel {
|
||||
llm_name: String,
|
||||
context: ContextType,
|
||||
}
|
||||
|
||||
#[derive(Debug, serde::Serialize)]
|
||||
pub struct TeacherOutput {
|
||||
parsed_output: String,
|
||||
prompt: String,
|
||||
raw_llm_response: String,
|
||||
context: String,
|
||||
diff: String,
|
||||
}
|
||||
|
||||
impl TeacherModel {
|
||||
const PROMPT: &str = include_str!("teacher.prompt.md");
|
||||
pub(crate) const REGION_START: &str = "<|editable_region_start|>\n";
|
||||
pub(crate) const REGION_END: &str = "<|editable_region_end|>";
|
||||
pub(crate) const USER_CURSOR: &str = "<|user_cursor|>";
|
||||
|
||||
/// Number of lines to include before the cursor position
|
||||
pub(crate) const LEFT_CONTEXT_SIZE: usize = 5;
|
||||
|
||||
/// Number of lines to include after the cursor position
|
||||
pub(crate) const RIGHT_CONTEXT_SIZE: usize = 5;
|
||||
|
||||
/// Truncate edit history to this number of last lines
|
||||
const MAX_HISTORY_LINES: usize = 128;
|
||||
|
||||
pub fn new(llm_name: String, context: ContextType) -> Self {
|
||||
TeacherModel { llm_name, context }
|
||||
}
|
||||
|
||||
pub async fn predict(&self, input: Example) -> Result<TeacherOutput> {
|
||||
let name = input.unique_name();
|
||||
let worktree_dir = input.setup_worktree(name).await?;
|
||||
let cursor: SourceLocation = input
|
||||
.cursor_position
|
||||
.parse()
|
||||
.expect("Failed to parse cursor position");
|
||||
|
||||
let context = collect_context(&self.context, &worktree_dir, cursor.clone());
|
||||
let edit_history = Self::format_edit_history(&input.edit_history);
|
||||
|
||||
let prompt = Self::PROMPT
|
||||
.replace("{{context}}", &context)
|
||||
.replace("{{edit_history}}", &edit_history);
|
||||
|
||||
let client = Anthropic::from_env()?;
|
||||
let response = client
|
||||
.messages()
|
||||
.create(
|
||||
MessageCreateBuilder::new(self.llm_name.clone(), 16384)
|
||||
.user(prompt.clone())
|
||||
.build(),
|
||||
)
|
||||
.await?;
|
||||
|
||||
let response_text = response
|
||||
.content
|
||||
.into_iter()
|
||||
.filter_map(|content| {
|
||||
if let ContentBlock::Text { text } = content {
|
||||
Some(text)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.collect::<Vec<String>>()
|
||||
.join("\n");
|
||||
|
||||
let parsed_output = self.parse_response(&response_text);
|
||||
|
||||
let original_editable_region = Self::extract_editable_region(&context);
|
||||
let context_after_edit = context.replace(&original_editable_region, &parsed_output);
|
||||
let context_after_edit = strip_special_tags(&context_after_edit);
|
||||
let context_before_edit = strip_special_tags(&context);
|
||||
let diff = language::unified_diff(&context_before_edit, &context_after_edit);
|
||||
|
||||
Ok(TeacherOutput {
|
||||
parsed_output,
|
||||
prompt,
|
||||
raw_llm_response: response_text,
|
||||
context,
|
||||
diff,
|
||||
})
|
||||
}
|
||||
|
||||
fn parse_response(&self, content: &str) -> String {
|
||||
let codeblock = Self::extract_last_codeblock(content);
|
||||
let editable_region = Self::extract_editable_region(&codeblock);
|
||||
|
||||
editable_region
|
||||
}
|
||||
|
||||
/// Extract content from the last code-fenced block if any, or else return content as is
///
/// A fence is a run of three or more backticks; a block is closed by the next run of
/// at least as many backticks. The rest of the opening fence line (an info string such
/// as `rust`) and the newline before the closing fence are not part of the content.
/// An unterminated fence is ignored.
fn extract_last_codeblock(text: &str) -> String {
    let mut last_block = None;
    let mut search_start = 0;

    while let Some(found) = text[search_start..].find("```") {
        let fence_start = search_start + found;
        let fence_len = text[fence_start..]
            .bytes()
            .take_while(|&byte| byte == b'`')
            .count();
        let body_start = fence_start + fence_len;
        let closing_fence = "`".repeat(fence_len);

        let close_offset = match text[body_start..].find(&closing_fence) {
            Some(offset) => offset,
            None => break,
        };
        let body = &text[body_start..body_start + close_offset];

        // Drop the remainder of the opening fence line. A block without any newline
        // (e.g. "```a```") is kept whole instead of producing an inverted range.
        let body = match body.split_once('\n') {
            Some((_info_string, rest)) => rest,
            None => body,
        };
        let body = body.strip_suffix('\n').unwrap_or(body);

        last_block = Some(body.to_string());
        search_start = body_start + close_offset + fence_len;
    }

    last_block.unwrap_or_else(|| text.to_string())
}
|
||||
|
||||
fn extract_editable_region(text: &str) -> String {
|
||||
let start = text
|
||||
.find(Self::REGION_START)
|
||||
.map_or(0, |pos| pos + Self::REGION_START.len());
|
||||
let end = text.find(Self::REGION_END).unwrap_or(text.len());
|
||||
|
||||
text[start..end].to_string()
|
||||
}
|
||||
|
||||
/// Truncates edit history to a maximum length and removes comments (unified diff garbage lines)
|
||||
fn format_edit_history(edit_history: &str) -> String {
|
||||
let lines = edit_history
|
||||
.lines()
|
||||
.filter(|&s| Self::is_content_line(s))
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let history_lines = if lines.len() > Self::MAX_HISTORY_LINES {
|
||||
&lines[lines.len() - Self::MAX_HISTORY_LINES..]
|
||||
} else {
|
||||
&lines
|
||||
};
|
||||
history_lines.join("\n")
|
||||
}
|
||||
|
||||
/// Returns `true` for lines that belong to a unified diff body or header: lines
/// starting with `-`, `+` or a space (which also covers the `---` / `+++` file
/// headers), and `@@` hunk headers.
fn is_content_line(s: &str) -> bool {
    // `---` and `+++` need no separate check: they already start with `-` / `+`.
    matches!(s.as_bytes().first(), Some(b'-' | b'+' | b' ')) || s.starts_with("@@")
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_parse_response() {
|
||||
let teacher = TeacherModel::new("test".to_string(), ContextType::CurrentFile);
|
||||
let response = "This is a test response.";
|
||||
let parsed = teacher.parse_response(response);
|
||||
assert_eq!(parsed, response.to_string());
|
||||
|
||||
let response = indoc::indoc! {"
|
||||
Some thinking
|
||||
|
||||
`````
|
||||
actual response
|
||||
`````
|
||||
"};
|
||||
let parsed = teacher.parse_response(response);
|
||||
assert_eq!(parsed, "actual response");
|
||||
}
|
||||
|
||||
/// The last fenced block wins, regardless of how many backticks each fence uses.
#[test]
fn test_extract_last_code_block() {
    let text = indoc::indoc! {"
        Some thinking

        ```
        first block
        ```

        `````
        last block
        `````
        "};
    let last_block = TeacherModel::extract_last_codeblock(text);
    assert_eq!(last_block, "last block");
}
|
||||
|
||||
#[test]
|
||||
fn test_extract_editable_region() {
|
||||
let teacher = TeacherModel::new("test".to_string(), ContextType::CurrentFile);
|
||||
let response = indoc::indoc! {"
|
||||
some lines
|
||||
are
|
||||
here
|
||||
<|editable_region_start|>
|
||||
one
|
||||
two three
|
||||
|
||||
<|editable_region_end|>
|
||||
more
|
||||
lines here
|
||||
"};
|
||||
let parsed = teacher.parse_response(response);
|
||||
assert_eq!(
|
||||
parsed,
|
||||
indoc::indoc! {"
|
||||
one
|
||||
two three
|
||||
|
||||
"}
|
||||
);
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user