Compare commits

...

1 Commit

Author SHA1 Message Date
Max Brunsfeld
f36dfcc99e WIP
Co-authored-by: Agus Zubiaga <agus@zed.dev>
Co-authored-by: Ben Kunkle <ben@zed.dev>
2025-12-08 13:32:58 -08:00
8 changed files with 186 additions and 55 deletions

View File

@@ -24,12 +24,9 @@ pub struct PlanContextRetrievalRequest {
pub struct PredictEditsRequest {
pub excerpt: String,
pub excerpt_path: Arc<Path>,
/// Within file
pub excerpt_range: Range<usize>,
pub excerpt_line_range: Range<Line>,
pub editable_range_in_excerpt: Range<usize>,
pub cursor_offset_in_excerpt: usize,
pub cursor_point: Point,
/// Within `signatures`
pub excerpt_parent: Option<usize>,
#[serde(skip_serializing_if = "Vec::is_empty", default)]
pub related_files: Vec<RelatedFile>,
pub events: Vec<Arc<Event>>,
@@ -74,10 +71,11 @@ pub enum PromptFormat {
MinimalQwen,
/// No instructions, Qwen chat + Seed-Coder 1120 FIM-like template
SeedCoder1120,
Zeta,
}
impl PromptFormat {
pub const DEFAULT: PromptFormat = PromptFormat::Minimal;
pub const DEFAULT: PromptFormat = PromptFormat::Zeta;
}
impl Default for PromptFormat {
@@ -100,6 +98,7 @@ impl std::fmt::Display for PromptFormat {
PromptFormat::Minimal => write!(f, "Minimal"),
PromptFormat::MinimalQwen => write!(f, "Minimal + Qwen FIM"),
PromptFormat::SeedCoder1120 => write!(f, "Seed-Coder 1120"),
PromptFormat::Zeta => write!(f, "Zeta"),
}
}
}

View File

@@ -5,6 +5,7 @@ use cloud_llm_client::predict_edits_v3::{
use indoc::indoc;
use std::cmp;
use std::fmt::Write;
use std::ops::Range;
use std::path::Path;
use std::sync::Arc;
@@ -80,12 +81,17 @@ const OLD_TEXT_NEW_TEXT_REMINDER: &str = indoc! {r#"
"#};
pub fn build_prompt(request: &predict_edits_v3::PredictEditsRequest) -> Result<String> {
// todo! do we need this?
let prompt_data = PromptData {
events: request.events.clone(),
cursor_point: request.cursor_point,
cursor_path: request.excerpt_path.clone(),
included_files: request.related_files.clone(),
excerpt: request.excerpt.clone(),
editable_range_in_excerpt: request.editable_range_in_excerpt.clone(),
cursor_offset_in_excerpt: request.cursor_offset_in_excerpt,
};
match request.prompt_format {
PromptFormat::MinimalQwen => {
return Ok(MinimalQwenPrompt.render(&prompt_data));
@@ -93,6 +99,9 @@ pub fn build_prompt(request: &predict_edits_v3::PredictEditsRequest) -> Result<S
PromptFormat::SeedCoder1120 => {
return Ok(SeedCoder1120Prompt.render(&prompt_data));
}
PromptFormat::Zeta => {
return Ok(ZetaPrompt.render(&prompt_data));
}
_ => (),
};
@@ -101,8 +110,9 @@ pub fn build_prompt(request: &predict_edits_v3::PredictEditsRequest) -> Result<S
vec![(request.cursor_point, CURSOR_MARKER)]
}
PromptFormat::OnlySnippets => vec![],
PromptFormat::MinimalQwen => unreachable!(),
PromptFormat::SeedCoder1120 => unreachable!(),
PromptFormat::MinimalQwen | PromptFormat::SeedCoder1120 | PromptFormat::Zeta => {
unreachable!()
}
};
let mut prompt = match request.prompt_format {
@@ -111,6 +121,7 @@ pub fn build_prompt(request: &predict_edits_v3::PredictEditsRequest) -> Result<S
PromptFormat::Minimal => STUDENT_MODEL_INSTRUCTIONS.to_string(),
PromptFormat::MinimalQwen => unreachable!(),
PromptFormat::SeedCoder1120 => unreachable!(),
PromptFormat::Zeta => unreachable!(),
};
if request.events.is_empty() {
@@ -159,6 +170,7 @@ pub fn build_prompt(request: &predict_edits_v3::PredictEditsRequest) -> Result<S
The file is in current state, edits from edit history have been applied.
"}
}
PromptFormat::Zeta => unreachable!(),
};
prompt.push_str(excerpts_preamble);
@@ -211,6 +223,7 @@ pub fn build_prompt(request: &predict_edits_v3::PredictEditsRequest) -> Result<S
pub fn generation_params(prompt_format: PromptFormat) -> GenerationParams {
match prompt_format {
PromptFormat::SeedCoder1120 => SeedCoder1120Prompt::generation_params(),
PromptFormat::Zeta => ZetaPrompt::generation_params(),
_ => GenerationParams::default(),
}
}
@@ -323,6 +336,9 @@ struct PromptData {
events: Vec<Arc<Event>>,
cursor_point: Point,
cursor_path: Arc<Path>, // TODO: make a common struct with cursor_point
excerpt: String,
editable_range_in_excerpt: Range<usize>,
cursor_offset_in_excerpt: usize,
included_files: Vec<RelatedFile>,
}
@@ -483,3 +499,97 @@ impl SeedCoder1120Prompt {
format!("{}{}{}", suffix, prefix, FIM_MIDDLE)
}
}
/// Prompt format reproducing the original Zeta editable-region template:
/// context snippets, the user's edit history as diffs, then the cursor
/// excerpt with `<|editable_region_start|>`/`<|editable_region_end|>`
/// markers and a `<|user_cursor_is_here|>` marker spliced in.
struct ZetaPrompt;

impl ZetaPrompt {
    const INSTRUCTIONS: &str = "Reference the user excerpt, user edits, and the snippets to understand the developer's intent. Update the editable region of the user excerpt by predicting and completing the changes they would have made next. This may be a deletion, addition, or modification of code.";
    const CONTEXT_START: &str = "### Context:";
    const CONTEXT_FILE: &str = "<|context_file|>";
    const SNIPPET: &str = "<|snippet|>";
    const USER_EDITS: &str = "### User Edits:";
    const USER_EDITED_FILE: &str = "User edited file";
    const USER_EXCERPT: &str = "### User Excerpt:";
    const EDITABLE_REGION_START_WITH_NEWLINE: &str = "<|editable_region_start|>\n";
    const EDITABLE_REGION_END_WITH_NEWLINE: &str = "\n<|editable_region_end|>";
    const USER_CURSOR_MARKER: &str = "<|user_cursor_is_here|>";
}

impl PromptFormatter for ZetaPrompt {
    /// Builds the full prompt string: instructions, context snippets from
    /// related files, the user's edit history, and finally the excerpt with
    /// the editable-region and cursor markers inserted at byte offsets.
    fn render(&self, data: &PromptData) -> String {
        let mut prompt = String::new();
        prompt.push_str(Self::INSTRUCTIONS);
        prompt.push_str("\n\n");

        prompt.push_str(Self::CONTEXT_START);
        prompt.push_str("\n\n");
        for file in &data.included_files {
            for excerpt in &file.excerpts {
                // Writing to a String is infallible; the Result is ignored.
                writeln!(
                    &mut prompt,
                    "{} {}",
                    Self::CONTEXT_FILE,
                    file.path.to_string_lossy()
                )
                .ok();
                prompt.push_str(Self::SNIPPET);
                prompt.push('\n');
                prompt.push_str(&excerpt.text);
                prompt.push_str("\n\n");
            }
        }

        prompt.push_str(Self::USER_EDITS);
        prompt.push_str("\n\n");
        for event in &data.events {
            match event.as_ref() {
                Event::BufferChange {
                    path,
                    // TODO: surface renames (old_path differing from path) in
                    // the prompt; currently unused, so bound to `_`.
                    old_path: _,
                    diff,
                    ..
                } => {
                    write!(
                        &mut prompt,
                        "{} \"{}\"\n\n```diff\n{diff}```\n\n",
                        Self::USER_EDITED_FILE,
                        path.display()
                    )
                    .ok();
                }
            }
        }

        prompt.push_str(Self::USER_EXCERPT);
        write!(
            &mut prompt,
            "\n\"{}\"\n\n",
            &data.cursor_path.to_string_lossy()
        )
        .ok();

        // NOTE(review): these are byte offsets into `excerpt`; string slicing
        // panics if an offset is not a UTF-8 char boundary — assumed to be
        // guaranteed by the request builder, TODO confirm.
        prompt.push_str(&data.excerpt[..data.editable_range_in_excerpt.start]);
        prompt.push_str(Self::EDITABLE_REGION_START_WITH_NEWLINE);
        prompt.push_str(
            &data.excerpt[data.editable_range_in_excerpt.start..data.cursor_offset_in_excerpt],
        );
        prompt.push_str(Self::USER_CURSOR_MARKER);
        prompt.push_str(
            &data.excerpt[data.cursor_offset_in_excerpt..data.editable_range_in_excerpt.end],
        );
        prompt.push_str(Self::EDITABLE_REGION_END_WITH_NEWLINE);
        prompt.push_str(&data.excerpt[data.editable_range_in_excerpt.end..]);
        prompt
    }

    /// Stop generation on the chat end-of-turn token; leave sampling
    /// parameters (temperature, top_p) to the caller's defaults.
    fn generation_params() -> GenerationParams {
        GenerationParams {
            stop: Some(vec!["<|im_end|>".into()]),
            temperature: None,
            top_p: None,
        }
    }
}

View File

@@ -119,6 +119,7 @@ impl Mercury {
tools: vec![],
prompt_cache_key: None,
reasoning_effort: None,
max_tokens: None,
};
let buf = serde_json::to_vec(&request_body)?;

View File

@@ -1,5 +1,6 @@
#[cfg(feature = "eval-support")]
use crate::EvalCacheEntryKind;
use crate::cursor_excerpt::editable_and_context_ranges_for_cursor_position;
use crate::open_ai_response::text_from_response;
use crate::prediction::EditPredictionResult;
use crate::{
@@ -14,7 +15,7 @@ use edit_prediction_context::{EditPredictionExcerpt, Line};
use edit_prediction_context::{RelatedExcerpt, RelatedFile};
use futures::channel::oneshot;
use gpui::{Entity, Task, prelude::*};
use language::{Anchor, BufferSnapshot};
use language::{Anchor, BufferSnapshot, OffsetRangeExt};
use language::{Buffer, Point, ToOffset as _, ToPoint};
use project::{Project, ProjectItem as _};
use release_channel::AppVersion;
@@ -25,6 +26,9 @@ use std::{
time::{Duration, Instant},
};
pub(crate) const MAX_CONTEXT_TOKENS: usize = 150;
pub(crate) const MAX_REWRITE_TOKENS: usize = 350;
pub fn request_prediction_with_zeta2(
store: &mut EditPredictionStore,
project: &Entity<Project>,
@@ -74,41 +78,20 @@ pub fn request_prediction_with_zeta2(
let excerpt_options = options.context;
let Some(excerpt) = EditPredictionExcerpt::select_from_buffer(
let (editable_range, context_range) = editable_and_context_ranges_for_cursor_position(
cursor_point,
&active_snapshot,
&excerpt_options,
) else {
return Ok((None, None));
};
MAX_REWRITE_TOKENS,
MAX_CONTEXT_TOKENS,
);
let excerpt = active_snapshot
.text_for_range(context_range.clone())
.collect::<String>();
let excerpt_anchor_range = active_snapshot.anchor_after(excerpt.range.start)
..active_snapshot.anchor_before(excerpt.range.end);
let related_excerpt = RelatedExcerpt {
anchor_range: excerpt_anchor_range.clone(),
point_range: Point::new(excerpt.line_range.start.0, 0)
..Point::new(excerpt.line_range.end.0, 0),
text: active_snapshot.as_rope().slice(excerpt.range),
};
if let Some(buffer_ix) = included_files
.iter()
.position(|file| file.buffer.entity_id() == active_buffer.entity_id())
{
let file = &mut included_files[buffer_ix];
file.excerpts.push(related_excerpt);
file.merge_excerpts();
let last_ix = included_files.len() - 1;
included_files.swap(buffer_ix, last_ix);
} else {
let active_file = RelatedFile {
path: active_project_path,
buffer: active_buffer.downgrade(),
excerpts: vec![related_excerpt],
max_row: active_snapshot.max_point().row,
};
included_files.push(active_file);
}
let context_offset = context_range.start.to_offset(&active_snapshot);
let editable_offset_range = editable_range.to_offset(&active_snapshot);
let excerpt_anchor_range = active_snapshot.anchor_after(context_offset)
..active_snapshot.anchor_before(context_offset);
let included_files = included_files
.iter()
@@ -128,9 +111,7 @@ pub fn request_prediction_with_zeta2(
let cloud_request = predict_edits_v3::PredictEditsRequest {
excerpt_path,
excerpt: String::new(),
excerpt_line_range: Line(0)..Line(0),
excerpt_range: 0..0,
excerpt,
cursor_point: predict_edits_v3::Point {
line: predict_edits_v3::Line(cursor_point.row),
column: cursor_point.column,
@@ -141,9 +122,11 @@ pub fn request_prediction_with_zeta2(
debug_info: debug_tx.is_some(),
prompt_max_bytes: Some(options.max_prompt_bytes),
prompt_format: options.prompt_format,
excerpt_parent: None,
git_info: None,
trigger,
editable_range_in_excerpt: (editable_offset_range.start - context_offset)
..(editable_offset_range.end - context_offset),
cursor_offset_in_excerpt: cursor_offset - context_offset,
};
let prompt_result = cloud_zeta2_prompt::build_prompt(&cloud_request);
@@ -190,6 +173,9 @@ pub fn request_prediction_with_zeta2(
}
let prompt = prompt_result?;
eprintln!("prompt:\n{prompt}");
let generation_params =
cloud_zeta2_prompt::generation_params(cloud_request.prompt_format);
let request = open_ai::Request {
@@ -199,6 +185,7 @@ pub fn request_prediction_with_zeta2(
}],
stream: false,
max_completion_tokens: None,
max_tokens: Some(1024 * 4),
stop: generation_params.stop.unwrap_or_default(),
temperature: generation_params.temperature.or(Some(0.7)),
tool_choice: None,
@@ -261,17 +248,40 @@ pub fn request_prediction_with_zeta2(
}
};
let (_, edits) = match options.prompt_format {
let edits = match options.prompt_format {
PromptFormat::Minimal | PromptFormat::MinimalQwen | PromptFormat::SeedCoder1120 => {
if output_text.contains("--- a/\n+++ b/\nNo edits") {
let edits = vec![];
(&active_snapshot, edits)
vec![]
} else {
crate::udiff::parse_diff(&output_text, get_buffer_from_context).await?
crate::udiff::parse_diff(&output_text, get_buffer_from_context)
.await?
.1
}
}
PromptFormat::OldTextNewText => {
crate::xml_edits::parse_xml_edits(&output_text, get_buffer_from_context).await?
crate::xml_edits::parse_xml_edits(&output_text, get_buffer_from_context)
.await?
.1
}
PromptFormat::Zeta => {
let old_text = active_snapshot
.text_for_range(editable_offset_range.clone())
.collect::<String>();
let new_text = output_text.trim_end_matches("<|im_end|>");
eprintln!("OUTPUT:\n<old_text>\n{old_text}\n</old_text>\n<new_text>\n{new_text}\n</new_text>");
language::text_diff(&old_text, &new_text)
.into_iter()
.map(|(range, text)| {
(
active_snapshot
.anchor_after(editable_offset_range.start + range.start)
..active_snapshot
.anchor_before(editable_offset_range.start + range.end),
text,
)
})
.collect()
}
_ => {
bail!("unsupported prompt format {}", options.prompt_format)

View File

@@ -197,11 +197,12 @@ fn zeta2_args_to_options(args: &Zeta2Args) -> edit_prediction::ZetaOptions {
#[derive(clap::ValueEnum, Default, Debug, Clone, Copy)]
enum PromptFormat {
OnlySnippets,
#[default]
OldTextNewText,
Minimal,
MinimalQwen,
SeedCoder1120,
#[default]
Zeta,
}
impl Into<predict_edits_v3::PromptFormat> for PromptFormat {
@@ -212,6 +213,7 @@ impl Into<predict_edits_v3::PromptFormat> for PromptFormat {
Self::Minimal => predict_edits_v3::PromptFormat::Minimal,
Self::MinimalQwen => predict_edits_v3::PromptFormat::MinimalQwen,
Self::SeedCoder1120 => predict_edits_v3::PromptFormat::SeedCoder1120,
Self::Zeta => predict_edits_v3::PromptFormat::Zeta,
}
}
}

View File

@@ -129,6 +129,7 @@ pub async fn perform_predict(
store.update(cx, |store, _cx| {
let mut options = store.options().clone();
options.prompt_format = prompt_format.into();
store.set_use_context(true);
store.set_options(options);
})?;
@@ -143,7 +144,7 @@ pub async fn perform_predict(
let mut start_time = None;
let mut retrieval_finished_at = None;
while let Some(event) = debug_rx.next().await {
match event {
match dbg!(event) {
edit_prediction::DebugEvent::ContextRetrievalStarted(info) => {
start_time = Some(info.timestamp);
fs::write(
@@ -219,9 +220,14 @@ pub async fn perform_predict(
}
});
store.update(cx, |store, cx| {
store.refresh_context(&project, &cursor_buffer, cursor_anchor, cx)
})?;
store
.update(cx, |store, cx| {
store.refresh_context(&project, &cursor_buffer, cursor_anchor, cx);
store.project_context_updates(&project)
})?
.unwrap()
.recv()
.await;
}
let prediction = store

View File

@@ -440,6 +440,7 @@ pub fn into_open_ai(
stop: request.stop,
temperature: request.temperature.or(Some(1.0)),
max_completion_tokens: max_output_tokens,
max_tokens: None,
parallel_tool_calls: if supports_parallel_tool_calls && !request.tools.is_empty() {
// Disable parallel tool calls, as the Agent currently expects a maximum of one per turn.
Some(false)

View File

@@ -264,6 +264,8 @@ pub struct Request {
pub stream: bool,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub max_completion_tokens: Option<u64>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub max_tokens: Option<u64>,
#[serde(default, skip_serializing_if = "Vec::is_empty")]
pub stop: Vec<String>,
#[serde(default, skip_serializing_if = "Option::is_none")]