Compare commits

...

4 Commits

Author SHA1 Message Date
Agus Zubiaga
0400e4b3d7 Remove unnecessary blurb 2025-11-20 13:23:22 -03:00
Agus Zubiaga
e25e7d6cb4 Add instructions about keeping the search broad
Co-authored-by: Ben Kunkle <ben@zed.dev>
2025-11-20 13:20:44 -03:00
Agus Zubiaga
0d2cd0a58a Rewrite jump prompt
Co-authored-by: Ben Kunkle <ben@zed.dev>
2025-11-20 12:47:15 -03:00
Agus Zubiaga
4403af9e8c Checkpoint: Adding sweep jump 2025-11-20 10:45:48 -03:00
4 changed files with 572 additions and 108 deletions

1
Cargo.lock generated
View File

@@ -21747,6 +21747,7 @@ dependencies = [
"pretty_assertions",
"project",
"release_channel",
"schemars 1.0.4",
"serde",
"serde_json",
"settings",

View File

@@ -33,8 +33,10 @@ language.workspace = true
language_model.workspace = true
log.workspace = true
open_ai.workspace = true
pretty_assertions.workspace = true
project.workspace = true
release_channel.workspace = true
schemars.workspace = true
serde.workspace = true
serde_json.workspace = true
smol.workspace = true
@@ -44,7 +46,6 @@ util.workspace = true
uuid.workspace = true
workspace.workspace = true
worktree.workspace = true
pretty_assertions.workspace = true
[dev-dependencies]
clock = { workspace = true, features = ["test-support"] }

427
crates/zeta2/src/jump.rs Normal file
View File

@@ -0,0 +1,427 @@
use anyhow::{Context as _, Result};
use cloud_llm_client::predict_edits_v3::Excerpt;
use collections::HashMap;
use edit_prediction_context::{EditPredictionExcerpt, Line};
use gpui::http_client::{Method, Request};
use gpui::{AppContext, AsyncApp, Entity, http_client::HttpClient};
use indoc::indoc;
use language::{Anchor, Buffer, BufferSnapshot, OffsetRangeExt as _, Point};
use open_ai::{FunctionDefinition, MessageContent};
use project::Project;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use smol::io::AsyncReadExt;
use std::path::PathBuf;
use std::{
collections::VecDeque,
fmt::Write,
path::Path,
sync::{Arc, LazyLock},
};
use crate::Event;
use crate::assemble_excerpts::assemble_excerpts;
use crate::retrieval_search::run_retrieval_searches;
use cloud_zeta2_prompt::write_codeblock;
/// Search for relevant code
///
/// Always run all queries at once with a single invocation of this tool.
#[derive(Clone, Deserialize, Serialize, JsonSchema)]
pub struct SearchToolInput {
    /// An array of queries to run in parallel for gathering context
    #[schemars(length(max = 5))]
    pub queries: Box<[SearchToolQuery]>,
}

/// Search for relevant code by path and their content
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema, Hash)]
pub struct SearchToolQuery {
    /// A glob pattern to match file paths in the codebase to search in.
    pub glob: String,
    /// A regular expression to match code contents within the matched files.
    pub regex: String,
}
/// JSON schema and description for the search tool, derived once from
/// [`SearchToolInput`] and cached for the lifetime of the process.
pub static TOOL_SCHEMA: LazyLock<(serde_json::Value, String)> = LazyLock::new(|| {
    let schema = language_model::tool_schema::root_schema_for::<SearchToolInput>(
        language_model::LanguageModelToolSchemaFormat::JsonSchemaSubset,
    );
    // The tool description is taken from the doc comment on `SearchToolInput`,
    // which schemars emits as the schema's top-level `description`.
    let description = schema
        .get("description")
        .and_then(|description| description.as_str())
        .expect("SearchToolInput must have a doc comment to use as the tool description")
        .to_string();
    (schema.into(), description)
});
/// A predicted target for the user's next edit: a buffer and a position within it.
pub struct JumpLocation {
pub buffer: Entity<Buffer>,
// Anchor at the predicted position inside `buffer`.
pub anchor: Anchor,
}
/// Wraps an OpenAI-style request body with OpenRouter's extra `provider` field.
#[derive(Serialize)]
struct OpenRouterWrapper {
#[serde(flatten)]
request: open_ai::Request,
provider: OpenRouterProvider,
}
/// OpenRouter provider-routing options attached to the request.
#[derive(Serialize)]
pub struct OpenRouterProvider {
// Presumably restricts routing to these providers only — matches OpenRouter's
// provider-routing API; confirm against their docs if extended.
only: Option<Vec<String>>,
}
/// Predicts a single far-away location that likely needs an edit next, based on
/// the user's recent edit history.
///
/// Makes two round-trips to the model:
/// 1. A search request: the model is prompted with the edit history and asked to
///    call the search tool with queries for relevant code.
/// 2. A jump request: the search results are appended to the conversation and the
///    model is asked to answer with `None` or a fenced ```jump `path:line` target.
///
/// Returns `Ok(None)` when the searches produce no results or the model decides
/// that no further edit is needed.
pub async fn predict_jump(
    active_full_path: Arc<Path>,
    cursor_position: Point,
    events: VecDeque<Event>,
    project: Entity<Project>,
    http_client: Arc<dyn HttpClient>,
    cx: &mut AsyncApp,
) -> Result<Option<JumpLocation>> {
    log::debug!("requesting jump");

    // todo!
    let events = cx.update(|cx| {
        events
            .into_iter()
            .filter_map(|event| event.to_request_event(cx))
            .collect::<Vec<_>>()
    })?;

    let search_queries = cx.background_spawn({
        let http_client = http_client.clone();
        let active_full_path = active_full_path.clone();
        async move {
            let prompt = build_jump_prompt(&active_full_path, cursor_position, &events);
            log::debug!("jump prompt:\n{prompt}");
            let (tool_schema, tool_description) = TOOL_SCHEMA.clone();

            let request_body = OpenRouterWrapper {
                request: open_ai::Request {
                    model: "qwen/qwen3-coder-30b-a3b-instruct".into(),
                    messages: vec![open_ai::RequestMessage::User {
                        content: open_ai::MessageContent::Plain(prompt),
                    }],
                    stream: false,
                    max_completion_tokens: None,
                    stop: Default::default(),
                    temperature: 0.7,
                    tool_choice: None,
                    parallel_tool_calls: None,
                    tools: vec![open_ai::ToolDefinition::Function {
                        function: FunctionDefinition {
                            name: cloud_zeta2_prompt::retrieval_prompt::TOOL_NAME.to_string(),
                            description: Some(tool_description),
                            parameters: Some(tool_schema),
                        },
                    }],
                    prompt_cache_key: None,
                    reasoning_effort: None,
                },
                provider: OpenRouterProvider {
                    only: Some(vec!["nebius/fp8".into()]),
                },
            };

            let api_key = std::env::var("OPENROUTER_API_KEY")
                .context("OPENROUTER_API_KEY is not set")?;
            let request = Request::builder()
                .method(Method::POST)
                .uri("https://openrouter.ai/api/v1/chat/completions")
                .header("Authorization", format!("Bearer {api_key}"))
                .header("Content-Type", "application/json")
                .header("HTTP-Referer", "https://zed.dev")
                .header("X-Title", "Zed Editor")
                .body(serde_json::to_string(&request_body)?.into())?;

            let mut response = http_client.send(request).await?;
            let mut buf = Vec::new();
            response.body_mut().read_to_end(&mut buf).await?;

            if !response.status().is_success() {
                anyhow::bail!("Jump request failed: {}", String::from_utf8_lossy(&buf));
            }

            let response: open_ai::Response = serde_json::from_slice(&buf)?;
            log::debug!("jump search response: {response:?}");
            anyhow::Ok((request_body, response))
        }
    });

    let (mut request_body, mut response) = search_queries.await?;

    let choice = response
        .choices
        .pop()
        .context("No choices in jump response")?;
    let open_ai::RequestMessage::Assistant {
        content: _,
        tool_calls,
    } = &choice.message
    else {
        anyhow::bail!("Jump response didn't include an assistant message");
    };

    // Collect search queries across all tool calls; remember the first call's id
    // so the tool result message below can reference it.
    let mut queries: Vec<cloud_zeta2_prompt::retrieval_prompt::SearchToolQuery> = Vec::new();
    let mut tool_call_id = None;
    for tool_call in tool_calls {
        tool_call_id.get_or_insert(tool_call.id.clone());
        let open_ai::ToolCallContent::Function { function } = &tool_call.content;
        if function.name != cloud_zeta2_prompt::retrieval_prompt::TOOL_NAME {
            log::warn!(
                "Jump response tried to call an unknown tool: {}",
                function.name
            );
            continue;
        }
        let input: SearchToolInput = serde_json::from_str(&function.arguments)
            .with_context(|| format!("invalid search json {}", &function.arguments))?;
        queries.extend(input.queries.into_iter().map(|query| {
            cloud_zeta2_prompt::retrieval_prompt::SearchToolQuery {
                glob: query.glob,
                syntax_node: vec![],
                content: Some(query.regex),
            }
        }));
    }

    let Some(tool_call_id) = tool_call_id else {
        anyhow::bail!("No searches in jump response");
    };
    if queries.is_empty() {
        anyhow::bail!("No queries in jump response");
    }

    let results = run_retrieval_searches(
        queries,
        project.clone(),
        #[cfg(feature = "eval-support")]
        None,
        cx,
    )
    .await?;
    if results.is_empty() {
        return anyhow::Ok(None);
    }

    // todo! move to background
    let mut combined_results = String::new();
    let mut result_buffers = HashMap::default();
    for (buffer, ranges) in results {
        let (snapshot, full_path) = buffer.read_with(cx, |buffer, cx| {
            (
                buffer.snapshot(),
                buffer
                    .file()
                    .map(|file| file.full_path(cx))
                    .unwrap_or_else(|| PathBuf::from("untitled")),
            )
        })?;
        // Convert matched ranges to whole-line ranges for excerpt assembly.
        let ranges = ranges
            .into_iter()
            .map(|range| {
                let point_range = range.to_point(&snapshot);
                Line(point_range.start.row)..Line(point_range.end.row)
            })
            .collect::<Vec<_>>();
        let excerpts = assemble_excerpts(&snapshot, ranges);
        write_codeblock(
            &full_path,
            &excerpts,
            &[],
            Line(snapshot.max_point().row),
            true,
            &mut combined_results,
        );
        result_buffers.insert(full_path.clone(), (buffer, snapshot));
    }
    log::debug!("jump search results:\n{combined_results}");

    // Second round-trip: feed the tool results back and ask for the jump target.
    // Tools are cleared so the model must answer directly.
    request_body.request.tools.clear();
    request_body.request.messages.extend([
        choice.message,
        open_ai::RequestMessage::Tool {
            content: MessageContent::Plain(combined_results),
            tool_call_id,
        },
        open_ai::RequestMessage::User {
            content: MessageContent::Plain(JUMP_INSTRUCTIONS.into()),
        },
    ]);

    let api_key =
        std::env::var("OPENROUTER_API_KEY").context("OPENROUTER_API_KEY is not set")?;
    let request = Request::builder()
        .method(Method::POST)
        .uri("https://openrouter.ai/api/v1/chat/completions")
        .header("Authorization", format!("Bearer {api_key}"))
        .header("Content-Type", "application/json")
        .header("HTTP-Referer", "https://zed.dev")
        .header("X-Title", "Zed Editor")
        .body(serde_json::to_string(&request_body)?.into())?;

    let mut response = http_client.send(request).await?;
    let mut buf = Vec::new();
    response.body_mut().read_to_end(&mut buf).await?;
    log::debug!("jump response: {}", String::from_utf8_lossy(&buf));
    if !response.status().is_success() {
        anyhow::bail!("Jump request failed: {}", String::from_utf8_lossy(&buf));
    }

    let mut response: open_ai::Response = serde_json::from_slice(&buf)?;
    if response.choices.is_empty() {
        return anyhow::Ok(None);
    }
    let choice = response
        .choices
        .pop()
        .context("No choices in jump response")?;
    let open_ai::RequestMessage::Assistant {
        content: Some(MessageContent::Plain(response)),
        tool_calls: _,
    } = &choice.message
    else {
        anyhow::bail!("Jump response didn't include an assistant message");
    };

    // The model is instructed to answer either with `None` or with a fenced
    // ```jump block containing `path:line`. No fence means no further edit.
    let Some((_, after_fence)) = response.trim().split_once("```jump") else {
        return anyhow::Ok(None);
    };
    let (file_path, line) = after_fence
        .split_once("```")
        .context("Missing closing fence")?
        .0
        .trim()
        .split_once(':')
        .context("Invalid jump response")?;
    let line = line.parse::<u32>()?;

    let (buffer, snapshot) = result_buffers
        .get(Path::new(file_path))
        .context("File not found in search results")?;
    anyhow::Ok(Some(JumpLocation {
        buffer: buffer.clone(),
        // The response uses 1-based line numbers; anchor at the start of that line.
        anchor: snapshot.anchor_after(Point::new(line.saturating_sub(1), 0)),
    }))
}
/// Renders [`SEARCH_INSTRUCTIONS`], substituting the cursor location and the
/// edit history (each event formatted as a ```diff code block).
pub fn build_jump_prompt(
    active_full_path: &Path,
    cursor_position: Point,
    events: &[cloud_llm_client::predict_edits_v3::Event],
) -> String {
    let edit_history = events
        .iter()
        .map(|event| format!("```diff\n{event}```\n\n"))
        .collect::<String>();
    let cursor_path = active_full_path.display().to_string();
    // Rows are 0-based internally; the prompt uses 1-based line numbers.
    let cursor_line = (cursor_position.row + 1).to_string();
    SEARCH_INSTRUCTIONS
        .replace("{CURSOR_PATH}", &cursor_path)
        .replace("{CURSOR_LINE}", &cursor_line)
        .replace("{EDIT_HISTORY}", edit_history.trim_end_matches("\n\n"))
}
/// Prompt for the first round-trip: asks the model to analyze the edit history
/// and call the search tool. The `{CURSOR_PATH}`, `{CURSOR_LINE}`, and
/// `{EDIT_HISTORY}` placeholders are substituted by [`build_jump_prompt`].
const SEARCH_INSTRUCTIONS: &str = indoc! {r#"
You are part of an edit prediction system in a code editor.
The user has made a series of changes, and your role is to predict a single far-away location
that needs to be edited as a result.

## Cursor location

The cursor is currently located at `{CURSOR_PATH}:{CURSOR_LINE}`.
Assume all necessary changes near this location are or will be done, and focus on changes that
are needed in other parts of the codebase.

## Edit history

Carefully analyze the edit history in order to infer what the user is currently trying to accomplish,
and gather facts about the changes they have made.

{EDIT_HISTORY}

Use the `search` tool to find more information about the changes and potential locations for the next edit.

### When you find changes to usages

- Did they pass a new argument to a function whose declaration hasn't been updated yet?
- Did they use a method on a type/class that hasn't been added yet?
- Did they use a method from a class/interface/trait that hasn't been implemented/derived on the type yet?
- Did they start using a package or library that hasn't been added to the project yet?

If the code suggests the item in question already existed, but is now being used in a different way,
search for its declaration in order to determine whether changes are necessary.

Alternatively, if the changes suggest the item is newly used, you should perform two parallel searches:

1. Search for the declaration of the item to see whether it already exists and whether its definition needs to be updated.
2. Search for the class/type/module/configuration where it _should_ be defined, so that you can suggest jumping
to it if needs to be added.

### When you find changes to declarations

- Did they change the definition of a type/class/table by adding, removing, or altering fields?
- Did they add an argument to a function?
- Did they split a function into multiple functions? Or merge multiple functions into one?
- Did they change the type of a field or function argument?
- Did they move a field from one type to another?

In these cases, you should search for usages of the affected item, so that you can see their current state
and suggest jumping to them if necessary.
If the affected item is public, make sure to include other files that may reference it in your search.
If the name of the affected item is unique enough, search for it in the entire project.
"#};
/// Prompt for the second round-trip: sent after the search results, asking the
/// model to either answer `None` or emit a fenced ```jump `path:line` target.
/// The `{project_name}` placeholder is shown to the model verbatim as part of
/// the example path format; it is not substituted here.
const JUMP_INSTRUCTIONS: &str = indoc! {"
Now analyze the search results, and explain your findings in 1 or 2 sentences.
If no more edits are needed, output `None`.
If another edit is needed, output the target file path and line number, like this:
```jump
{project_name}/path/to/file.rs:123
```
"};

View File

@@ -43,6 +43,7 @@ use util::{LogErrorFuture, ResultExt as _, TryFutureExt};
use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
pub mod assemble_excerpts;
mod jump;
mod prediction;
mod provider;
pub mod retrieval_search;
@@ -761,7 +762,7 @@ impl Zeta {
self.request_prediction_with_zed_cloud(project, active_buffer, position, cx)
}
ZetaEditPredictionModel::Sweep => {
self.request_prediction_with_sweep(project, active_buffer, position, cx)
self.request_prediction_with_sweep(project, active_buffer, position, true, cx)
}
}
}
@@ -769,11 +770,12 @@ impl Zeta {
fn request_prediction_with_sweep(
&mut self,
project: &Entity<Project>,
active_buffer: &Entity<Buffer>,
target_buffer: &Entity<Buffer>,
position: language::Anchor,
allow_jump: bool,
cx: &mut Context<Self>,
) -> Task<Result<Option<EditPrediction>>> {
let snapshot = active_buffer.read(cx).snapshot();
let snapshot = target_buffer.read(cx).snapshot();
let debug_info = self.sweep_ai_debug_info.clone();
let Some(api_token) = self.sweep_api_token.clone() else {
return Task::ready(Ok(None));
@@ -799,7 +801,7 @@ impl Zeta {
let recent_buffer_snapshots = recent_buffers
.filter_map(|project_path| {
let buffer = project.read(cx).get_open_buffer(&project_path, cx)?;
if active_buffer == &buffer {
if target_buffer == &buffer {
None
} else {
Some(buffer.read(cx).snapshot())
@@ -808,119 +810,152 @@ impl Zeta {
.take(3)
.collect::<Vec<_>>();
let result = cx.background_spawn(async move {
let text = snapshot.text();
let result = cx.background_spawn({
let full_path = full_path.clone();
let snapshot = snapshot.clone();
// todo!
let events = events.clone();
async move {
let text = snapshot.text();
let mut recent_changes = String::new();
for event in events {
sweep_ai::write_event(event, &mut recent_changes).unwrap();
let mut recent_changes = String::new();
for event in events {
sweep_ai::write_event(event, &mut recent_changes).unwrap();
}
let file_chunks = recent_buffer_snapshots
.into_iter()
.map(|snapshot| {
let end_point = language::Point::new(30, 0).min(snapshot.max_point());
sweep_ai::FileChunk {
content: snapshot
.text_for_range(language::Point::zero()..end_point)
.collect(),
file_path: snapshot
.file()
.map(|f| f.path().as_unix_str())
.unwrap_or("untitled")
.to_string(),
start_line: 0,
end_line: end_point.row as usize,
timestamp: snapshot.file().and_then(|file| {
Some(
file.disk_state()
.mtime()?
.to_seconds_and_nanos_for_persistence()?
.0,
)
}),
}
})
.collect();
let request_body = sweep_ai::AutocompleteRequest {
debug_info,
repo_name,
file_path: full_path.clone(),
file_contents: text.clone(),
original_file_contents: text,
cursor_position: offset,
recent_changes: recent_changes.clone(),
changes_above_cursor: true,
multiple_suggestions: false,
branch: None,
file_chunks,
retrieval_chunks: vec![],
recent_user_actions: vec![],
// TODO
privacy_mode_enabled: false,
};
let mut buf: Vec<u8> = Vec::new();
let writer = brotli::CompressorWriter::new(&mut buf, 4096, 11, 22);
serde_json::to_writer(writer, &request_body)?;
let body: AsyncBody = buf.into();
const SWEEP_API_URL: &str =
"https://autocomplete.sweep.dev/backend/next_edit_autocomplete";
let request = http_client::Request::builder()
.uri(SWEEP_API_URL)
.header("Content-Type", "application/json")
.header("Authorization", format!("Bearer {}", api_token))
.header("Connection", "keep-alive")
.header("Content-Encoding", "br")
.method(Method::POST)
.body(body)?;
let mut response = http_client.send(request).await?;
let mut body: Vec<u8> = Vec::new();
response.body_mut().read_to_end(&mut body).await?;
if !response.status().is_success() {
anyhow::bail!(
"Request failed with status: {:?}\nBody: {}",
response.status(),
String::from_utf8_lossy(&body),
);
};
let response: sweep_ai::AutocompleteResponse = serde_json::from_slice(&body)?;
let old_text = snapshot
.text_for_range(response.start_index..response.end_index)
.collect::<String>();
let edits = language::text_diff(&old_text, &response.completion)
.into_iter()
.map(|(range, text)| {
(
snapshot.anchor_after(response.start_index + range.start)
..snapshot.anchor_before(response.start_index + range.end),
text,
)
})
.collect::<Vec<_>>();
anyhow::Ok((response.autocomplete_id, edits, snapshot))
}
let file_chunks = recent_buffer_snapshots
.into_iter()
.map(|snapshot| {
let end_point = language::Point::new(30, 0).min(snapshot.max_point());
sweep_ai::FileChunk {
content: snapshot
.text_for_range(language::Point::zero()..end_point)
.collect(),
file_path: snapshot
.file()
.map(|f| f.path().as_unix_str())
.unwrap_or("untitled")
.to_string(),
start_line: 0,
end_line: end_point.row as usize,
timestamp: snapshot.file().and_then(|file| {
Some(
file.disk_state()
.mtime()?
.to_seconds_and_nanos_for_persistence()?
.0,
)
}),
}
})
.collect();
let request_body = sweep_ai::AutocompleteRequest {
debug_info,
repo_name,
file_path: full_path.clone(),
file_contents: text.clone(),
original_file_contents: text,
cursor_position: offset,
recent_changes: recent_changes.clone(),
changes_above_cursor: true,
multiple_suggestions: false,
branch: None,
file_chunks,
retrieval_chunks: vec![],
recent_user_actions: vec![],
// TODO
privacy_mode_enabled: false,
};
let mut buf: Vec<u8> = Vec::new();
let writer = brotli::CompressorWriter::new(&mut buf, 4096, 11, 22);
serde_json::to_writer(writer, &request_body)?;
let body: AsyncBody = buf.into();
const SWEEP_API_URL: &str =
"https://autocomplete.sweep.dev/backend/next_edit_autocomplete";
let request = http_client::Request::builder()
.uri(SWEEP_API_URL)
.header("Content-Type", "application/json")
.header("Authorization", format!("Bearer {}", api_token))
.header("Connection", "keep-alive")
.header("Content-Encoding", "br")
.method(Method::POST)
.body(body)?;
let mut response = http_client.send(request).await?;
let mut body: Vec<u8> = Vec::new();
response.body_mut().read_to_end(&mut body).await?;
if !response.status().is_success() {
anyhow::bail!(
"Request failed with status: {:?}\nBody: {}",
response.status(),
String::from_utf8_lossy(&body),
);
};
let response: sweep_ai::AutocompleteResponse = serde_json::from_slice(&body)?;
let old_text = snapshot
.text_for_range(response.start_index..response.end_index)
.collect::<String>();
let edits = language::text_diff(&old_text, &response.completion)
.into_iter()
.map(|(range, text)| {
(
snapshot.anchor_after(response.start_index + range.start)
..snapshot.anchor_before(response.start_index + range.end),
text,
)
})
.collect::<Vec<_>>();
anyhow::Ok((response.autocomplete_id, edits, snapshot))
});
let buffer = active_buffer.clone();
let target_buffer = target_buffer.clone();
let http_client = cx.http_client();
let project = project.clone();
cx.spawn(async move |_, cx| {
cx.spawn(async move |this, cx| {
let (id, edits, old_snapshot) = result.await?;
if edits.is_empty() {
if edits.is_empty() && !events.is_empty() && allow_jump {
let cursor_point = position.to_point(&snapshot);
let jump_result = jump::predict_jump(
full_path,
cursor_point,
events,
project.clone(),
http_client,
cx,
)
.await?;
if let Some(jump) = jump_result {
return this
.update(cx, |this, cx| {
this.request_prediction_with_sweep(
&project,
&jump.buffer,
jump.anchor,
false,
cx,
)
})?
.await;
}
return anyhow::Ok(None);
}
let Some((edits, new_snapshot, preview_task)) =
buffer.read_with(cx, |buffer, cx| {
target_buffer.read_with(cx, |buffer, cx| {
let new_snapshot = buffer.snapshot();
let edits: Arc<[(Range<Anchor>, Arc<str>)]> =
@@ -939,7 +974,7 @@ impl Zeta {
edits,
snapshot: new_snapshot,
edit_preview: preview_task.await,
buffer,
buffer: target_buffer,
};
anyhow::Ok(Some(prediction))