Compare commits

...

10 Commits

Author SHA1 Message Date
Oleksiy Syvokon
be488f68bc Fix a typo 2025-12-08 11:39:36 +02:00
Oleksiy Syvokon
f380c25925 Extract the last codeblock from LLM response
Co-authored-by: Piotr Osiewicz <24362066+osiewicz@users.noreply.github.com>
2025-12-05 15:06:35 +02:00
Piotr Osiewicz
54cd4220c0 Merge branch 'main' into zeta-distill-2 2025-12-05 11:24:04 +01:00
Oleksiy Syvokon
d502082717 Compute diff
Co-authored-by: Ben Kunkle <ben@zed.dev>
2025-12-04 20:31:44 +02:00
Oleksiy Syvokon
a747c2e3ed Preprocess edit history
Co-authored-by: Ben Kunkle <ben@zed.dev>
2025-12-04 20:11:27 +02:00
Oleksiy Syvokon
f105171aab Insert special tags
Co-authored-by: Ben Kunkle <ben@zed.dev>
2025-12-04 19:57:19 +02:00
Oleksiy Syvokon
09b191cc49 Use current file as a context
Co-authored-by: Ben Kunkle <ben@zed.dev>
2025-12-04 19:16:06 +02:00
Oleksiy Syvokon
5431e12cc5 Extract ediable region 2025-12-04 16:36:01 +02:00
Oleksiy Syvokon
dbafd0aab6 Setup worktree, parse LLM response 2025-12-04 16:24:12 +02:00
Oleksiy Syvokon
ad0283cd8d Dummy zeta distill command
Co-authored-by: Piotr Osiewicz <24362066+osiewicz@users.noreply.github.com>
2025-12-04 15:40:02 +02:00
9 changed files with 672 additions and 122 deletions

89
Cargo.lock generated
View File

@@ -639,6 +639,38 @@ dependencies = [
"thiserror 2.0.17",
]
[[package]]
name = "anthropic-sdk-rust"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "45395c37cc1ce9981a0bd0ba573462ce4809586324cb7b358ecd2e5ed6652485"
dependencies = [
"async-trait",
"backoff",
"base64 0.22.1",
"bytes 1.10.1",
"chrono",
"dotenvy",
"eventsource-stream",
"futures 0.3.31",
"mime",
"once_cell",
"pin-project",
"reqwest 0.12.24",
"serde",
"serde_json",
"sha2",
"tempfile",
"thiserror 1.0.69",
"tokio",
"tokio-stream",
"tower 0.4.13",
"tracing",
"tracing-subscriber",
"url",
"uuid",
]
[[package]]
name = "any_vec"
version = "0.14.0"
@@ -1959,6 +1991,17 @@ dependencies = [
"tower-service",
]
[[package]]
name = "backoff"
version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b62ddb9cb1ec0a098ad4bbf9344d0713fa193ae1a80af55febcff2627b6a00c1"
dependencies = [
"getrandom 0.2.16",
"instant",
"rand 0.8.5",
]
[[package]]
name = "backtrace"
version = "0.3.76"
@@ -5165,6 +5208,7 @@ dependencies = [
name = "edit_prediction_cli"
version = "0.1.0"
dependencies = [
"anthropic-sdk-rust",
"anyhow",
"chrono",
"clap",
@@ -5738,6 +5782,17 @@ dependencies = [
"pin-project-lite",
]
[[package]]
name = "eventsource-stream"
version = "0.2.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "74fef4569247a5f429d9156b9d0a2599914385dd189c539334c625d8099d90ab"
dependencies = [
"futures-core",
"nom 7.1.3",
"pin-project-lite",
]
[[package]]
name = "exec"
version = "0.3.1"
@@ -7855,6 +7910,22 @@ dependencies = [
"tokio-native-tls",
]
[[package]]
name = "hyper-tls"
version = "0.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "70206fc6890eaca9fde8a0bf71caa2ddfc9fe045ac9e5c70df101a7dbde866e0"
dependencies = [
"bytes 1.10.1",
"http-body-util",
"hyper 1.7.0",
"hyper-util",
"native-tls",
"tokio",
"tokio-native-tls",
"tower-service",
]
[[package]]
name = "hyper-util"
version = "0.1.17"
@@ -7874,9 +7945,11 @@ dependencies = [
"percent-encoding",
"pin-project-lite",
"socket2 0.6.1",
"system-configuration 0.6.1",
"tokio",
"tower-service",
"tracing",
"windows-registry 0.5.3",
]
[[package]]
@@ -8128,7 +8201,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4b0f83760fb341a774ed326568e19f5a863af4a952def8c39f9ab92fd95b88e5"
dependencies = [
"equivalent",
"hashbrown 0.15.5",
"hashbrown 0.16.1",
"serde",
"serde_core",
]
@@ -13460,7 +13533,7 @@ dependencies = [
"http-body 0.4.6",
"hyper 0.14.32",
"hyper-rustls 0.24.2",
"hyper-tls",
"hyper-tls 0.5.0",
"ipnet",
"js-sys",
"log",
@@ -13496,29 +13569,40 @@ checksum = "9d0946410b9f7b082a427e4ef5c8ff541a88b357bc6c637c40db3a68ac70a36f"
dependencies = [
"base64 0.22.1",
"bytes 1.10.1",
"encoding_rs",
"futures-channel",
"futures-core",
"futures-util",
"h2 0.4.12",
"http 1.3.1",
"http-body 1.0.1",
"http-body-util",
"hyper 1.7.0",
"hyper-rustls 0.27.7",
"hyper-tls 0.6.0",
"hyper-util",
"js-sys",
"log",
"mime",
"mime_guess",
"native-tls",
"percent-encoding",
"pin-project-lite",
"rustls-pki-types",
"serde",
"serde_json",
"serde_urlencoded",
"sync_wrapper 1.0.2",
"tokio",
"tokio-native-tls",
"tokio-util",
"tower 0.5.2",
"tower-http 0.6.6",
"tower-service",
"url",
"wasm-bindgen",
"wasm-bindgen-futures",
"wasm-streams",
"web-sys",
]
@@ -16818,6 +16902,7 @@ dependencies = [
"futures-core",
"pin-project-lite",
"tokio",
"tokio-util",
]
[[package]]

View File

@@ -13,6 +13,7 @@ name = "ep_cli"
path = "src/main.rs"
[dependencies]
anthropic-sdk-rust = "0.1.1"
anyhow.workspace = true
chrono.workspace = true

View File

@@ -3,6 +3,8 @@ use std::{
cell::RefCell,
fmt::{self, Display},
fs,
hash::Hash,
hash::Hasher,
io::Write,
mem,
path::{Path, PathBuf},
@@ -43,7 +45,7 @@ pub struct NamedExample {
pub example: Example,
}
#[derive(Clone, Debug, Serialize, Deserialize)]
#[derive(Clone, Debug, Hash, Serialize, Deserialize)]
pub struct Example {
pub repository_url: String,
pub revision: String,
@@ -54,6 +56,134 @@ pub struct Example {
pub expected_patch: String,
}
impl Example {
/// Split `repository_url` into `(owner, repo)`, accepting both SSH-style
/// (`git@host:owner/repo.git`) and HTTP-style (`https://host/owner/repo.git`)
/// urls. The `.git` suffix is stripped from the repo name.
fn repo_name(&self) -> Result<(Cow<'_, str>, Cow<'_, str>)> {
    // git@github.com:owner/repo.git
    if self.repository_url.contains('@') {
        let (owner, repo) = self
            .repository_url
            .split_once(':')
            .context("expected : in git url")?
            .1
            .split_once('/')
            .context("expected / in git url")?;
        Ok((
            Cow::Borrowed(owner),
            Cow::Borrowed(repo.trim_end_matches(".git")),
        ))
    // http://github.com/owner/repo.git
    } else {
        let url = Url::parse(&self.repository_url)?;
        let mut segments = url.path_segments().context("empty http url")?;
        let owner = segments
            .next()
            .context("expected owner path segment")?
            .to_string();
        let repo = segments
            .next()
            .context("expected repo path segment")?
            .trim_end_matches(".git")
            .to_string();
        // Was an `assert!`: return an error instead of panicking on urls
        // with extra path segments.
        segments
            .next()
            .is_none()
            .then_some(())
            .context("unexpected extra path segments in repository url")?;
        Ok((owner.into(), repo.into()))
    }
}
/// Clone/fetch this example's repository into a shared cache (`REPOS_DIR`) and
/// create (or reuse) a dedicated git worktree named after `file_name`, checked
/// out at `self.revision` with `self.uncommitted_diff` applied on top.
/// Returns the path of the ready-to-use worktree.
pub async fn setup_worktree(&self, file_name: String) -> Result<PathBuf> {
    let (repo_owner, repo_name) = self.repo_name()?;
    let repo_dir = REPOS_DIR.join(repo_owner.as_ref()).join(repo_name.as_ref());
    // Serialize access to the shared repo directory across concurrent callers.
    let repo_lock = lock_repo(&repo_dir).await;
    // Lazily initialize the cached bare checkout on first use.
    if !repo_dir.is_dir() {
        fs::create_dir_all(&repo_dir)?;
        run_git(&repo_dir, &["init"]).await?;
        run_git(
            &repo_dir,
            &["remote", "add", "origin", &self.repository_url],
        )
        .await?;
    }
    // Resolve the example to a revision, fetching it if needed.
    let revision = run_git(
        &repo_dir,
        &["rev-parse", &format!("{}^{{commit}}", self.revision)],
    )
    .await;
    let revision = if let Ok(revision) = revision {
        revision
    } else {
        // Prefer a cheap shallow fetch of just this revision; some servers
        // refuse fetching arbitrary objects, so fall back to a full fetch.
        if run_git(
            &repo_dir,
            &["fetch", "--depth", "1", "origin", &self.revision],
        )
        .await
        .is_err()
        {
            run_git(&repo_dir, &["fetch", "origin"]).await?;
        }
        let revision = run_git(&repo_dir, &["rev-parse", "FETCH_HEAD"]).await?;
        // Tag symbolic revisions (branch/tag names) so later runs resolve
        // them locally without another fetch.
        if revision != self.revision {
            run_git(&repo_dir, &["tag", &self.revision, &revision]).await?;
        }
        revision
    };
    // Create the worktree for this example if needed.
    let worktree_path = WORKTREES_DIR.join(&file_name).join(repo_name.as_ref());
    if worktree_path.is_dir() {
        // Reuse an existing worktree: discard leftovers from a previous run,
        // then move it to the requested revision.
        run_git(&worktree_path, &["clean", "--force", "-d"]).await?;
        run_git(&worktree_path, &["reset", "--hard", "HEAD"]).await?;
        run_git(&worktree_path, &["checkout", revision.as_str()]).await?;
    } else {
        let worktree_path_string = worktree_path.to_string_lossy();
        run_git(&repo_dir, &["branch", "-f", &file_name, revision.as_str()]).await?;
        run_git(
            &repo_dir,
            &["worktree", "add", "-f", &worktree_path_string, &file_name],
        )
        .await?;
    }
    drop(repo_lock);
    // Apply the uncommitted diff for this example by piping it to `git apply`.
    if !self.uncommitted_diff.is_empty() {
        let mut apply_process = smol::process::Command::new("git")
            .current_dir(&worktree_path)
            .args(&["apply", "-"])
            .stdin(std::process::Stdio::piped())
            .spawn()?;
        let mut stdin = apply_process.stdin.take().unwrap();
        stdin.write_all(self.uncommitted_diff.as_bytes()).await?;
        // Close stdin so `git apply` sees EOF and can finish.
        stdin.close().await?;
        drop(stdin);
        let apply_result = apply_process.output().await?;
        if !apply_result.status.success() {
            anyhow::bail!(
                "Failed to apply uncommitted diff patch with status: {}\nstderr:\n{}\nstdout:\n{}",
                apply_result.status,
                String::from_utf8_lossy(&apply_result.stderr),
                String::from_utf8_lossy(&apply_result.stdout),
            );
        }
    }
    Ok(worktree_path)
}
/// Build a short, stable identifier for this example: up to 8 chars of the
/// revision plus 4 hex digits of the example's hash as a disambiguator.
pub fn unique_name(&self) -> String {
    let mut hasher = std::hash::DefaultHasher::new();
    self.hash(&mut hasher);
    let disambiguator = hasher.finish();
    // `{:04x}` guarantees at least 4 hex digits, so `[..4]` below is safe.
    let hash = format!("{:04x}", disambiguator);
    // Don't panic when `revision` is shorter than 8 bytes (e.g. a branch name).
    let revision_prefix = self.revision.get(..8).unwrap_or(&self.revision);
    format!("{}_{}", revision_prefix, &hash[..4])
}
}
pub type ActualExcerpt = Excerpt;
#[derive(Clone, Debug, Serialize, Deserialize)]
@@ -292,90 +422,7 @@ impl NamedExample {
}
pub async fn setup_worktree(&self) -> Result<PathBuf> {
let (repo_owner, repo_name) = self.repo_name()?;
let file_name = self.file_name();
let repo_dir = REPOS_DIR.join(repo_owner.as_ref()).join(repo_name.as_ref());
let repo_lock = lock_repo(&repo_dir).await;
if !repo_dir.is_dir() {
fs::create_dir_all(&repo_dir)?;
run_git(&repo_dir, &["init"]).await?;
run_git(
&repo_dir,
&["remote", "add", "origin", &self.example.repository_url],
)
.await?;
}
// Resolve the example to a revision, fetching it if needed.
let revision = run_git(
&repo_dir,
&[
"rev-parse",
&format!("{}^{{commit}}", self.example.revision),
],
)
.await;
let revision = if let Ok(revision) = revision {
revision
} else {
run_git(
&repo_dir,
&["fetch", "--depth", "1", "origin", &self.example.revision],
)
.await?;
let revision = run_git(&repo_dir, &["rev-parse", "FETCH_HEAD"]).await?;
if revision != self.example.revision {
run_git(&repo_dir, &["tag", &self.example.revision, &revision]).await?;
}
revision
};
// Create the worktree for this example if needed.
let worktree_path = WORKTREES_DIR.join(&file_name).join(repo_name.as_ref());
if worktree_path.is_dir() {
run_git(&worktree_path, &["clean", "--force", "-d"]).await?;
run_git(&worktree_path, &["reset", "--hard", "HEAD"]).await?;
run_git(&worktree_path, &["checkout", revision.as_str()]).await?;
} else {
let worktree_path_string = worktree_path.to_string_lossy();
run_git(&repo_dir, &["branch", "-f", &file_name, revision.as_str()]).await?;
run_git(
&repo_dir,
&["worktree", "add", "-f", &worktree_path_string, &file_name],
)
.await?;
}
drop(repo_lock);
// Apply the uncommitted diff for this example.
if !self.example.uncommitted_diff.is_empty() {
let mut apply_process = smol::process::Command::new("git")
.current_dir(&worktree_path)
.args(&["apply", "-"])
.stdin(std::process::Stdio::piped())
.spawn()?;
let mut stdin = apply_process.stdin.take().unwrap();
stdin
.write_all(self.example.uncommitted_diff.as_bytes())
.await?;
stdin.close().await?;
drop(stdin);
let apply_result = apply_process.output().await?;
if !apply_result.status.success() {
anyhow::bail!(
"Failed to apply uncommitted diff patch with status: {}\nstderr:\n{}\nstdout:\n{}",
apply_result.status,
String::from_utf8_lossy(&apply_result.stderr),
String::from_utf8_lossy(&apply_result.stdout),
);
}
}
Ok(worktree_path)
self.example.setup_worktree(self.file_name()).await
}
pub fn file_name(&self) -> String {
@@ -391,40 +438,6 @@ impl NamedExample {
.collect()
}
fn repo_name(&self) -> Result<(Cow<'_, str>, Cow<'_, str>)> {
// git@github.com:owner/repo.git
if self.example.repository_url.contains('@') {
let (owner, repo) = self
.example
.repository_url
.split_once(':')
.context("expected : in git url")?
.1
.split_once('/')
.context("expected / in git url")?;
Ok((
Cow::Borrowed(owner),
Cow::Borrowed(repo.trim_end_matches(".git")),
))
// http://github.com/owner/repo.git
} else {
let url = Url::parse(&self.example.repository_url)?;
let mut segments = url.path_segments().context("empty http url")?;
let owner = segments
.next()
.context("expected owner path segment")?
.to_string();
let repo = segments
.next()
.context("expected repo path segment")?
.trim_end_matches(".git")
.to_string();
assert!(segments.next().is_none());
Ok((owner.into(), repo.into()))
}
}
pub async fn cursor_position(
&self,
project: &Entity<Project>,

View File

@@ -5,6 +5,7 @@ mod metrics;
mod paths;
mod predict;
mod source_location;
mod training;
mod util;
use crate::{
@@ -13,9 +14,10 @@ use crate::{
headless::ZetaCliAppState,
predict::run_predict,
source_location::SourceLocation,
training::{context::ContextType, distill::run_distill},
util::{open_buffer, open_buffer_with_language_server},
};
use ::util::paths::PathStyle;
use ::util::{ResultExt, paths::PathStyle};
use anyhow::{Result, anyhow};
use clap::{Args, Parser, Subcommand, ValueEnum};
use cloud_llm_client::predict_edits_v3;
@@ -43,6 +45,7 @@ enum Command {
Context(ContextArgs),
Predict(PredictArguments),
Eval(EvaluateArguments),
Distill(DistillArguments),
ConvertExample {
path: PathBuf,
#[arg(long, value_enum, default_value_t = ExampleFormat::Md)]
@@ -111,6 +114,13 @@ pub struct PredictArguments {
options: PredictionOptions,
}
// Arguments for the `distill` subcommand.
// (Plain `//` comments on purpose: `///` doc comments on clap-derive items
// would change the generated CLI help text.)
#[derive(Debug, Args)]
pub struct DistillArguments {
    // Path to the JSONL dataset of split commits to distill.
    split_commit_dataset: PathBuf,
    // How to assemble the code context passed to the teacher model.
    #[clap(long, value_enum, default_value_t = ContextType::CurrentFile)]
    context_type: ContextType,
}
#[derive(Clone, Debug, Args)]
pub struct PredictionOptions {
#[clap(flatten)]
@@ -468,6 +478,13 @@ fn main() {
Some(Command::Eval(arguments)) => {
run_evaluate(arguments, &app_state, cx).await;
}
Some(Command::Distill(arguments)) => {
let _guard = cx
.update(|cx| gpui_tokio::Tokio::handle(cx))
.unwrap()
.enter();
run_distill(arguments).await.log_err();
}
Some(Command::ConvertExample {
path,
output_format,

View File

@@ -0,0 +1,89 @@
use std::path::Path;
use crate::{source_location::SourceLocation, training::teacher::TeacherModel};
// Strategy for assembling the code context sent to the teacher model.
// (Plain `//` comments: `///` on clap::ValueEnum items would change CLI help text.)
#[derive(Debug, Clone, Default, clap::ValueEnum)]
pub enum ContextType {
    // Use the entire file containing the cursor as the context.
    #[default]
    CurrentFile,
}

/// Maximum context size, in bytes, sent to the teacher model.
const MAX_CONTEXT_SIZE: usize = 32768;
/// Collect the code context for the teacher model, truncated to
/// `MAX_CONTEXT_SIZE` bytes while keeping the editable region visible.
pub fn collect_context(
    context_type: &ContextType,
    worktree_dir: &Path,
    cursor: SourceLocation,
) -> String {
    let context = match context_type {
        ContextType::CurrentFile => {
            // Use the whole file under the cursor, with the editable-region
            // and cursor markers inserted.
            let file_path = worktree_dir.join(cursor.path.as_std_path());
            let content = std::fs::read_to_string(&file_path).unwrap_or_default();
            add_special_tags(&content, worktree_dir, cursor)
        }
    };
    truncate_context(context, TeacherModel::REGION_END, MAX_CONTEXT_SIZE)
}

/// Truncate `context` to roughly `max_len` bytes. If the end-of-region
/// `marker` would be cut off by dropping the tail, drop the head instead so
/// the editable region stays visible. Slice indices are snapped to char
/// boundaries so multi-byte UTF-8 content cannot cause a panic (the previous
/// version sliced at raw byte offsets).
fn truncate_context(context: String, marker: &str, max_len: usize) -> String {
    if context.len() <= max_len {
        return context;
    }
    let marker_end = context.find(marker).map(|pos| pos + marker.len());
    if let Some(marker_end) = marker_end
        && marker_end > max_len
    {
        // Keep the tail: round the cut point up to the next char boundary.
        let mut to_truncate = context.len() - max_len;
        while !context.is_char_boundary(to_truncate) {
            to_truncate += 1;
        }
        format!(
            "[...{} bytes truncated]\n{}\n",
            to_truncate,
            &context[to_truncate..]
        )
    } else {
        // Keep the head: round the cut point down to the previous boundary.
        let mut keep = max_len;
        while !context.is_char_boundary(keep) {
            keep -= 1;
        }
        format!(
            "{}\n[...{} bytes truncated]\n",
            &context[..keep],
            context.len() - keep
        )
    }
}
/// Add <|editable_region_start/end|> tags around the lines surrounding the
/// cursor, and insert the <|user_cursor|> marker at the cursor position.
/// Falls back to the untagged context (with a warning) when the cursor is out
/// of range or the snippet cannot be located in the context.
fn add_special_tags(context: &str, worktree_dir: &Path, cursor: SourceLocation) -> String {
    let path = worktree_dir.join(cursor.path.as_std_path());
    let file = std::fs::read_to_string(&path).unwrap_or_default();
    let lines = file.lines().collect::<Vec<_>>();
    let cursor_row = cursor.point.row as usize;
    // Guard against a cursor past the end of the file; the slicing below
    // would panic otherwise.
    if cursor_row >= lines.len() {
        log::warn!("Cursor row is outside of the file; proceeding without special tags");
        return context.to_string();
    }
    let start_line = cursor_row.saturating_sub(TeacherModel::LEFT_CONTEXT_SIZE);
    let end_line = (cursor_row + TeacherModel::RIGHT_CONTEXT_SIZE).min(lines.len());
    let snippet = lines[start_line..end_line].join("\n");
    if context.contains(&snippet) {
        let mut cursor_line = lines[cursor_row].to_string();
        // NOTE(review): `column` is treated as a byte offset here; clamp it to
        // the line length and snap to a char boundary so multi-byte lines
        // can't panic. Confirm whether SourceLocation columns are bytes or chars.
        let mut column = (cursor.point.column as usize).min(cursor_line.len());
        while !cursor_line.is_char_boundary(column) {
            column -= 1;
        }
        cursor_line.insert_str(column, TeacherModel::USER_CURSOR);
        let mut snippet_with_tags_lines = vec![];
        snippet_with_tags_lines.push(TeacherModel::REGION_START);
        snippet_with_tags_lines.extend(&lines[start_line..cursor_row]);
        snippet_with_tags_lines.push(&cursor_line);
        snippet_with_tags_lines.extend(&lines[cursor_row + 1..end_line]);
        snippet_with_tags_lines.push(TeacherModel::REGION_END);
        let snippet_with_tags = snippet_with_tags_lines.join("\n");
        // Tag only the first occurrence so a snippet that repeats in the file
        // doesn't get tagged multiple times.
        context.replacen(&snippet, &snippet_with_tags, 1)
    } else {
        log::warn!(
            "Can't find area around the cursor in the context; proceeding without special tags"
        );
        context.to_string()
    }
}
/// Remove every teacher-model special marker (region tags and cursor marker)
/// from `context`.
pub fn strip_special_tags(context: &str) -> String {
    let markers = [
        TeacherModel::REGION_START,
        TeacherModel::REGION_END,
        TeacherModel::USER_CURSOR,
    ];
    markers
        .iter()
        .fold(context.to_string(), |stripped, marker| stripped.replace(marker, ""))
}

View File

@@ -0,0 +1,61 @@
use serde::Deserialize;
use crate::{
DistillArguments,
example::Example,
source_location::SourceLocation,
training::{
context::ContextType,
teacher::{TeacherModel, TeacherOutput},
},
};
use anyhow::Result;
/// One training sample from the split-commit dataset (one JSONL record).
#[derive(Debug, Deserialize)]
pub struct SplitCommit {
    /// Clone/fetch url of the repository the commit belongs to.
    repo_url: String,
    /// Commit to check the worktree out at.
    commit_sha: String,
    /// Unified diff of the edits made so far (also applied to the worktree as
    /// the uncommitted diff).
    edit_history: String,
    /// Unified diff the model is expected to produce next.
    expected_patch: String,
    /// Textual cursor location; parsed into a `SourceLocation`.
    cursor_position: String,
}
/// Distill every split commit in the dataset, writing one JSON result per
/// line to stdout. Errors are propagated instead of panicking (the previous
/// version used `expect` inside a `Result`-returning function).
pub async fn run_distill(arguments: DistillArguments) -> Result<()> {
    // NOTE(review): `arguments.context_type` is currently unused —
    // `distill_one` hard-codes `ContextType::CurrentFile`; confirm whether it
    // should be threaded through.
    // The dataset is JSONL: one `SplitCommit` object per line.
    let dataset = std::fs::read_to_string(&arguments.split_commit_dataset)?;
    for line in dataset.lines() {
        let commit: SplitCommit = serde_json::from_str(line)?;
        let distilled = distill_one(commit).await?;
        println!("{}", serde_json::to_string(&distilled)?);
    }
    Ok(())
}
pub async fn distill_one(commit: SplitCommit) -> Result<TeacherOutput> {
let cursor: SourceLocation = commit
.cursor_position
.parse()
.expect("Failed to parse cursor position");
let path = cursor.path.to_rel_path_buf();
let example = Example {
repository_url: commit.repo_url,
revision: commit.commit_sha,
uncommitted_diff: commit.edit_history.clone(),
cursor_path: path.as_std_path().to_path_buf(),
cursor_position: commit.cursor_position,
edit_history: commit.edit_history, // todo: trim
expected_patch: commit.expected_patch,
};
let teacher = TeacherModel::new("claude-sonnet-4-5".to_string(), ContextType::CurrentFile);
let prediction = teacher.predict(example).await;
prediction
}

View File

@@ -0,0 +1,3 @@
pub mod context;
pub mod distill;
pub mod teacher;

View File

@@ -0,0 +1,48 @@
# Instructions
You are a code completion assistant helping a programmer finish their work. Your task is to:
1. Analyze the edit history to understand what the programmer is trying to achieve
2. Identify any incomplete refactoring or changes that need to be finished
3. Make the remaining edits that a human programmer would logically make next (by rewriting the corresponding code sections)
4. Apply systematic changes consistently across the entire codebase - if you see a pattern starting, complete it everywhere.
Focus on:
- Understanding the intent behind the changes (e.g., improving error handling, refactoring APIs, fixing bugs)
- Completing any partially-applied changes across the codebase
- Ensuring consistency with the programming style and patterns already established
- Making edits that maintain or improve code quality
- If the programmer started refactoring one instance of a pattern, find and update ALL similar instances
- Don't write a lot of code if you're not sure what to do
Rules:
- Do not just mechanically apply patterns - reason about what changes make sense given the context and the programmer's apparent goals.
- Do not just fix syntax errors - look for the broader refactoring pattern and apply it systematically throughout the code.
Input format:
- You receive small code fragments called context (structs, field definitions, function signatures, etc.). They may or may not be relevant.
- Never modify the context code.
- You also receive a code snippet between <|editable_region_start|> and <|editable_region_end|>. This is the editable region.
- The cursor position is marked with <|user_cursor|>.
Output format:
- Return the entire editable region, applying any edits you make.
- Remove the <|user_cursor|> marker.
- Wrap the edited code in a block of exactly five backticks.
Output example:
`````
// `zed --askpass` Makes zed operate in nc/netcat mode for use with askpass
if let Some(socket) = &args.askpass {{
askpass::main(socket);
return Ok(());
}}
`````
## User Edits History
{{edit_history}}
## Code Context
{{context}}

View File

@@ -0,0 +1,233 @@
use crate::{
example::Example,
source_location::SourceLocation,
training::context::{ContextType, collect_context, strip_special_tags},
};
use anthropic_sdk::{Anthropic, ContentBlock, MessageCreateBuilder};
use anyhow::Result;
/// A strong LLM used to generate training targets for edit prediction.
pub struct TeacherModel {
    /// Anthropic model identifier to query.
    llm_name: String,
    /// Strategy used to assemble the code context for the prompt.
    context: ContextType,
}

/// Everything produced for one teacher prediction, serialized as one
/// JSON record of the distillation output.
#[derive(Debug, serde::Serialize)]
pub struct TeacherOutput {
    /// The editable region extracted from the LLM's response.
    parsed_output: String,
    /// The full prompt that was sent to the teacher model.
    prompt: String,
    /// The teacher model's unparsed response text.
    raw_llm_response: String,
    /// The (tagged) context the prediction was made against.
    context: String,
    /// Unified diff between the context before and after applying the edit.
    diff: String,
}
impl TeacherModel {
const PROMPT: &str = include_str!("teacher.prompt.md");
pub(crate) const REGION_START: &str = "<|editable_region_start|>\n";
pub(crate) const REGION_END: &str = "<|editable_region_end|>";
pub(crate) const USER_CURSOR: &str = "<|user_cursor|>";
/// Number of lines to include before the cursor position
pub(crate) const LEFT_CONTEXT_SIZE: usize = 5;
/// Number of lines to include after the cursor position
pub(crate) const RIGHT_CONTEXT_SIZE: usize = 5;
/// Truncate edit history to this number of last lines
const MAX_HISTORY_LINES: usize = 128;
/// Create a teacher backed by the model `llm_name`, collecting context with
/// the given strategy.
pub fn new(llm_name: String, context: ContextType) -> Self {
    TeacherModel { llm_name, context }
}
/// Run the teacher on `input`: set up the example's worktree, build the
/// prompt, call the LLM, and diff its proposed edit against the original
/// context.
pub async fn predict(&self, input: Example) -> Result<TeacherOutput> {
    let name = input.unique_name();
    let worktree_dir = input.setup_worktree(name).await?;
    // NOTE(review): panics on a malformed cursor position — consider
    // propagating an error instead.
    let cursor: SourceLocation = input
        .cursor_position
        .parse()
        .expect("Failed to parse cursor position");
    let context = collect_context(&self.context, &worktree_dir, cursor.clone());
    let edit_history = Self::format_edit_history(&input.edit_history);
    // The prompt template contains literal {{context}} / {{edit_history}}
    // placeholders that are substituted here with plain string replacement.
    let prompt = Self::PROMPT
        .replace("{{context}}", &context)
        .replace("{{edit_history}}", &edit_history);
    // Presumably picks up API credentials from the environment (`from_env`)
    // — confirm against the anthropic-sdk-rust docs.
    let client = Anthropic::from_env()?;
    let response = client
        .messages()
        .create(
            // 16384: response token budget passed to the builder.
            MessageCreateBuilder::new(self.llm_name.clone(), 16384)
                .user(prompt.clone())
                .build(),
        )
        .await?;
    // Concatenate all text content blocks of the response, ignoring
    // non-text blocks.
    let response_text = response
        .content
        .into_iter()
        .filter_map(|content| {
            if let ContentBlock::Text { text } = content {
                Some(text)
            } else {
                None
            }
        })
        .collect::<Vec<String>>()
        .join("\n");
    let parsed_output = self.parse_response(&response_text);
    // Splice the model's editable region back into the context and diff
    // against the original (both with special tags stripped) to obtain the
    // predicted patch.
    let original_editable_region = Self::extract_editable_region(&context);
    let context_after_edit = context.replace(&original_editable_region, &parsed_output);
    let context_after_edit = strip_special_tags(&context_after_edit);
    let context_before_edit = strip_special_tags(&context);
    let diff = language::unified_diff(&context_before_edit, &context_after_edit);
    Ok(TeacherOutput {
        parsed_output,
        prompt,
        raw_llm_response: response_text,
        context,
        diff,
    })
}
/// Parse the teacher's response: take the last fenced code block (if any),
/// then keep only the editable region inside it.
fn parse_response(&self, content: &str) -> String {
    Self::extract_editable_region(&Self::extract_last_codeblock(content))
}
/// Extract content from the last code-fenced block if any, or else return
/// content as is. Fences of any length (``` or longer) are matched against a
/// closing run of at least the same number of backticks.
fn extract_last_codeblock(text: &str) -> String {
    let bytes = text.as_bytes();
    let mut last_block = None;
    let mut search_start = 0;
    while let Some(rel) = text[search_start..].find("```") {
        let fence_start = search_start + rel;
        // Measure the full run of backticks that forms the opening fence.
        let mut fence_end = fence_start;
        while fence_end < bytes.len() && bytes[fence_end] == b'`' {
            fence_end += 1;
        }
        let fence_len = fence_end - fence_start;
        let closing = "`".repeat(fence_len);
        match text[fence_end..].find(&closing) {
            Some(close_rel) => {
                let content = &text[fence_end..fence_end + close_rel];
                // Trim the single newline right after the opening fence and
                // right before the closing fence, when present. (The previous
                // version unconditionally chopped one byte at each end, which
                // panicked on empty blocks and mangled language-tagged fences.)
                let content = content.strip_prefix('\n').unwrap_or(content);
                let content = content.strip_suffix('\n').unwrap_or(content);
                last_block = Some(content.to_string());
                search_start = fence_end + close_rel + fence_len;
            }
            // Unterminated fence: keep whatever complete block we saw last.
            None => break,
        }
    }
    last_block.unwrap_or_else(|| text.to_string())
}
/// Return the text between the editable-region markers, falling back to the
/// start/end of `text` when a marker is missing.
fn extract_editable_region(text: &str) -> String {
    let start = text
        .find(Self::REGION_START)
        .map(|pos| pos + Self::REGION_START.len())
        .unwrap_or(0);
    let end = text.find(Self::REGION_END).unwrap_or(text.len());
    text[start..end].to_string()
}
/// Truncates edit history to the last `MAX_HISTORY_LINES` diff-content lines,
/// dropping anything that is not part of the unified diff itself.
fn format_edit_history(edit_history: &str) -> String {
    let content: Vec<&str> = edit_history
        .lines()
        .filter(|line| Self::is_content_line(line))
        .collect();
    let start = content.len().saturating_sub(Self::MAX_HISTORY_LINES);
    content[start..].join("\n")
}
/// True for lines that belong to the unified diff proper: additions,
/// removals, context lines, and hunk headers. The previous `"---"` / `"+++"`
/// arms were unreachable — those prefixes are already matched by `"-"` / `"+"`.
fn is_content_line(s: &str) -> bool {
    s.starts_with('-') || s.starts_with('+') || s.starts_with(' ') || s.starts_with("@@")
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parse_response() {
        let teacher = TeacherModel::new("test".to_string(), ContextType::CurrentFile);
        // Without a code fence, the response passes through unchanged.
        let response = "This is a test response.";
        let parsed = teacher.parse_response(response);
        assert_eq!(parsed, response.to_string());
        // With a fence, only the fenced content is kept.
        let response = indoc::indoc! {"
            Some thinking
            `````
            actual response
            `````
        "};
        let parsed = teacher.parse_response(response);
        assert_eq!(parsed, "actual response");
    }

    #[test]
    fn test_extract_last_code_block() {
        // The last complete fenced block wins, regardless of fence length.
        let text = indoc::indoc! {"
            Some thinking
            ```
            first block
            ```
            `````
            last block
            `````
        "};
        let last_block = TeacherModel::extract_last_codeblock(text);
        assert_eq!(last_block, "last block");
        // https://on.tty-share.com/s/-pZoHQTn8OTfu6W9KvuyEJgFKwfR1CrCJSRwC1Y2I94SzoVLHekaqmrCcaO1d_lNpGQ/
    }

    #[test]
    fn test_extract_editable_region() {
        let teacher = TeacherModel::new("test".to_string(), ContextType::CurrentFile);
        // Only the span between the region markers is returned.
        let response = indoc::indoc! {"
            some lines
            are
            here
            <|editable_region_start|>
            one
            two three
            <|editable_region_end|>
            more
            lines here
        "};
        let parsed = teacher.parse_response(response);
        assert_eq!(
            parsed,
            indoc::indoc! {"
                one
                two three
            "}
        );
    }
}