Add experimental mercury edit prediction provider (#44256)

Release Notes:

- N/A

---------

Co-authored-by: Ben Kunkle <ben@zed.dev>
Co-authored-by: Max Brunsfeld <maxbrunsfeld@gmail.com>
This commit is contained in:
Agus Zubiaga
2025-12-06 07:08:44 -03:00
committed by GitHub
parent 51b7d06a27
commit f08fd732a7
18 changed files with 809 additions and 323 deletions

View File

@@ -0,0 +1,11 @@
<svg width="28" height="28" viewBox="0 0 28 28" fill="none" id="svg1378540956_510">
<g clip-path="url(#svg1378540956_510_clip0_1_1506)" transform="translate(4, 4) scale(0.857)">
<path d="M17.0547 0.372066H8.52652L-0.00165176 8.90024V17.4284H8.52652V8.90024H17.0547V0.372066Z" fill="#1A1C20"></path>
<path d="M10.1992 27.6279H18.7274L27.2556 19.0998V10.5716H18.7274V19.0998H10.1992V27.6279Z" fill="#1A1C20"></path>
</g>
<defs>
<clipPath id="svg1378540956_510_clip0_1_1506">
<rect width="27.2559" height="27.2559" fill="white" transform="translate(0 0.37207)"></rect>
</clipPath>
</defs>
</svg>

After

Width:  |  Height:  |  Size: 593 B

View File

@@ -0,0 +1,78 @@
use language::{BufferSnapshot, Point};
use std::ops::Range;
/// Computes the editable range around `position` together with a wider context range.
///
/// Starting from the empty range at the cursor, the scope is repeatedly widened to the
/// enclosing syntax node for as long as that node's estimated token count stays within
/// `editable_region_token_limit`. The resulting scope is then grown line by line with
/// whatever is left of that limit to form the editable range, and the editable range is
/// grown once more with `context_token_limit` to form the context range.
///
/// Returns `(editable_range, context_range)`.
pub fn editable_and_context_ranges_for_cursor_position(
    position: Point,
    snapshot: &BufferSnapshot,
    editable_region_token_limit: usize,
    context_token_limit: usize,
) -> (Range<Point>, Range<Point>) {
    let mut scope = position..position;
    let mut leftover_edit_tokens = editable_region_token_limit;

    while let Some(ancestor) = snapshot.syntax_ancestor(scope.clone()) {
        let ancestor_start = ancestor.start_position();
        let ancestor_end = ancestor.end_position();
        let ancestor_range = Point::new(ancestor_start.row as u32, ancestor_start.column as u32)
            ..Point::new(ancestor_end.row as u32, ancestor_end.column as u32);

        // An ancestor that covers exactly the current scope cannot widen it any further.
        if ancestor_range == scope {
            break;
        }

        // Stop at the first ancestor that no longer fits in the editable budget.
        let ancestor_tokens = guess_token_count(ancestor.byte_range().len());
        if ancestor_tokens > editable_region_token_limit {
            break;
        }

        leftover_edit_tokens = editable_region_token_limit - ancestor_tokens;
        scope = ancestor_range;
    }

    let editable_range = expand_range(snapshot, scope, leftover_edit_tokens);
    let context_range = expand_range(snapshot, editable_range.clone(), context_token_limit);
    (editable_range, context_range)
}
/// Widens `range` to whole lines and then grows it outwards one line at a time.
///
/// Each pass adds the line above (if any) and then the line below (if any), charging the
/// estimated token count of every added line against `remaining_tokens`. The budget is only
/// checked before a line is added, so a line is still added whenever any budget remains.
/// Growth stops once a pass adds no line in either direction.
fn expand_range(
    snapshot: &BufferSnapshot,
    range: Range<Point>,
    mut remaining_tokens: usize,
) -> Range<Point> {
    let Range { mut start, mut end } = range;
    start.column = 0;
    end.column = snapshot.line_len(end.row);
    let last_row = snapshot.max_point().row;

    loop {
        let grow_up = remaining_tokens > 0 && start.row > 0;
        if grow_up {
            start.row -= 1;
            let cost = guess_token_count(snapshot.line_len(start.row) as usize);
            remaining_tokens = remaining_tokens.saturating_sub(cost);
        }

        // Evaluated after the upward step on purpose: that step may have used up the budget.
        let grow_down = remaining_tokens > 0 && end.row < last_row;
        if grow_down {
            end.row += 1;
            end.column = snapshot.line_len(end.row);
            let cost = guess_token_count(end.column as usize);
            remaining_tokens = remaining_tokens.saturating_sub(cost);
        }

        if !(grow_up || grow_down) {
            break;
        }
    }

    start..end
}
/// Typical number of string bytes per token for the purposes of limiting model input. This is
/// intentionally low to err on the side of underestimating limits.
pub(crate) const BYTES_PER_TOKEN_GUESS: usize = 3;
/// Estimates how many tokens `bytes` bytes of text occupy by dividing by
/// [`BYTES_PER_TOKEN_GUESS`]. Integer division rounds down, so inputs shorter than
/// `BYTES_PER_TOKEN_GUESS` bytes count as zero tokens.
pub fn guess_token_count(bytes: usize) -> usize {
bytes / BYTES_PER_TOKEN_GUESS
}

View File

@@ -51,8 +51,11 @@ use thiserror::Error;
use util::{RangeExt as _, ResultExt as _};
use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
mod cursor_excerpt;
mod license_detection;
pub mod mercury;
mod onboarding_modal;
pub mod open_ai_response;
mod prediction;
pub mod sweep_ai;
pub mod udiff;
@@ -65,6 +68,7 @@ pub mod zeta2;
mod edit_prediction_tests;
use crate::license_detection::LicenseDetectionWatcher;
use crate::mercury::Mercury;
use crate::onboarding_modal::ZedPredictModal;
pub use crate::prediction::EditPrediction;
pub use crate::prediction::EditPredictionId;
@@ -96,6 +100,12 @@ impl FeatureFlag for SweepFeatureFlag {
const NAME: &str = "sweep-ai";
}
/// Feature flag for the experimental Mercury edit prediction provider.
pub struct MercuryFeatureFlag;
impl FeatureFlag for MercuryFeatureFlag {
// Name under which the flag is registered.
const NAME: &str = "mercury";
}
pub const DEFAULT_OPTIONS: ZetaOptions = ZetaOptions {
context: EditPredictionExcerptOptions {
max_bytes: 512,
@@ -157,6 +167,7 @@ pub struct EditPredictionStore {
eval_cache: Option<Arc<dyn EvalCache>>,
edit_prediction_model: EditPredictionModel,
pub sweep_ai: SweepAi,
pub mercury: Mercury,
data_collection_choice: DataCollectionChoice,
reject_predictions_tx: mpsc::UnboundedSender<EditPredictionRejection>,
shown_predictions: VecDeque<EditPrediction>,
@@ -169,6 +180,7 @@ pub enum EditPredictionModel {
Zeta1,
Zeta2,
Sweep,
Mercury,
}
#[derive(Debug, Clone, PartialEq)]
@@ -474,6 +486,7 @@ impl EditPredictionStore {
eval_cache: None,
edit_prediction_model: EditPredictionModel::Zeta2,
sweep_ai: SweepAi::new(cx),
mercury: Mercury::new(cx),
data_collection_choice,
reject_predictions_tx: reject_tx,
rated_predictions: Default::default(),
@@ -509,6 +522,15 @@ impl EditPredictionStore {
.is_some()
}
/// Returns `true` when the Mercury API token task has already resolved to a token.
///
/// The shared task is polled once without waiting; a task that is still pending and a task
/// that resolved without a token both yield `false`.
pub fn has_mercury_api_token(&self) -> bool {
    let polled = self.mercury.api_token.clone().now_or_never();
    matches!(polled, Some(Some(_)))
}
#[cfg(feature = "eval-support")]
pub fn with_eval_cache(&mut self, cache: Arc<dyn EvalCache>) {
self.eval_cache = Some(cache);
@@ -868,7 +890,7 @@ impl EditPredictionStore {
fn accept_current_prediction(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
match self.edit_prediction_model {
EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 => {}
EditPredictionModel::Sweep => return,
EditPredictionModel::Sweep | EditPredictionModel::Mercury => return,
}
let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
@@ -1013,7 +1035,7 @@ impl EditPredictionStore {
) {
match self.edit_prediction_model {
EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 => {}
EditPredictionModel::Sweep => return,
EditPredictionModel::Sweep | EditPredictionModel::Mercury => return,
}
self.reject_predictions_tx
@@ -1373,6 +1395,17 @@ impl EditPredictionStore {
diagnostic_search_range.clone(),
cx,
),
EditPredictionModel::Mercury => self.mercury.request_prediction(
&project,
&active_buffer,
snapshot.clone(),
position,
events,
&project_state.recent_paths,
related_files,
diagnostic_search_range.clone(),
cx,
),
};
cx.spawn(async move |this, cx| {

View File

@@ -1620,7 +1620,7 @@ async fn test_no_data_collection_for_events_in_uncollectable_buffers(cx: &mut Te
buffer.edit(
[(
0..0,
" ".repeat(MAX_EVENT_TOKENS * zeta1::BYTES_PER_TOKEN_GUESS),
" ".repeat(MAX_EVENT_TOKENS * cursor_excerpt::BYTES_PER_TOKEN_GUESS),
)],
None,
cx,

View File

@@ -0,0 +1,340 @@
use anyhow::{Context as _, Result};
use cloud_llm_client::predict_edits_v3::Event;
use credentials_provider::CredentialsProvider;
use edit_prediction_context::RelatedFile;
use futures::{AsyncReadExt as _, FutureExt, future::Shared};
use gpui::{
App, AppContext as _, Entity, Task,
http_client::{self, AsyncBody, Method},
};
use language::{Buffer, BufferSnapshot, OffsetRangeExt as _, Point, ToPoint as _};
use project::{Project, ProjectPath};
use std::{
collections::VecDeque, fmt::Write as _, mem, ops::Range, path::Path, sync::Arc, time::Instant,
};
use crate::{
EditPredictionId, EditPredictionInputs, open_ai_response::text_from_response,
prediction::EditPredictionResult,
};
/// Endpoint that Mercury edit prediction requests are POSTed to.
const MERCURY_API_URL: &str = "https://api.inceptionlabs.ai/v1/edit/completions";
// Token budgets handed to the cursor excerpt helper when it chooses the editable and
// context ranges around the cursor.
const MAX_CONTEXT_TOKENS: usize = 150;
const MAX_REWRITE_TOKENS: usize = 350;
/// State for the Mercury edit prediction provider.
pub struct Mercury {
/// Shared task resolving to the API token, or to `None` when no token is available.
pub api_token: Shared<Task<Option<String>>>,
}
impl Mercury {
pub fn new(cx: &App) -> Self {
Mercury {
api_token: load_api_token(cx).shared(),
}
}
pub fn set_api_token(&mut self, api_token: Option<String>, cx: &mut App) -> Task<Result<()>> {
self.api_token = Task::ready(api_token.clone()).shared();
store_api_token_in_keychain(api_token, cx)
}
pub fn request_prediction(
&self,
_project: &Entity<Project>,
active_buffer: &Entity<Buffer>,
snapshot: BufferSnapshot,
position: language::Anchor,
events: Vec<Arc<Event>>,
_recent_paths: &VecDeque<ProjectPath>,
related_files: Vec<RelatedFile>,
_diagnostic_search_range: Range<Point>,
cx: &mut App,
) -> Task<Result<Option<EditPredictionResult>>> {
let Some(api_token) = self.api_token.clone().now_or_never().flatten() else {
return Task::ready(Ok(None));
};
let full_path: Arc<Path> = snapshot
.file()
.map(|file| file.full_path(cx))
.unwrap_or_else(|| "untitled".into())
.into();
let http_client = cx.http_client();
let cursor_point = position.to_point(&snapshot);
let buffer_snapshotted_at = Instant::now();
let result = cx.background_spawn(async move {
let (editable_range, context_range) =
crate::cursor_excerpt::editable_and_context_ranges_for_cursor_position(
cursor_point,
&snapshot,
MAX_CONTEXT_TOKENS,
MAX_REWRITE_TOKENS,
);
let offset_range = editable_range.to_offset(&snapshot);
let prompt = build_prompt(
&events,
&related_files,
&snapshot,
full_path.as_ref(),
cursor_point,
editable_range,
context_range.clone(),
);
let inputs = EditPredictionInputs {
events: events,
included_files: vec![cloud_llm_client::predict_edits_v3::RelatedFile {
path: full_path.clone(),
max_row: cloud_llm_client::predict_edits_v3::Line(snapshot.max_point().row),
excerpts: vec![cloud_llm_client::predict_edits_v3::Excerpt {
start_line: cloud_llm_client::predict_edits_v3::Line(
context_range.start.row,
),
text: snapshot
.text_for_range(context_range.clone())
.collect::<String>()
.into(),
}],
}],
cursor_point: cloud_llm_client::predict_edits_v3::Point {
column: cursor_point.column,
line: cloud_llm_client::predict_edits_v3::Line(cursor_point.row),
},
cursor_path: full_path.clone(),
};
let request_body = open_ai::Request {
model: "mercury-coder".into(),
messages: vec![open_ai::RequestMessage::User {
content: open_ai::MessageContent::Plain(prompt),
}],
stream: false,
max_completion_tokens: None,
stop: vec![],
temperature: None,
tool_choice: None,
parallel_tool_calls: None,
tools: vec![],
prompt_cache_key: None,
reasoning_effort: None,
};
let buf = serde_json::to_vec(&request_body)?;
let body: AsyncBody = buf.into();
let request = http_client::Request::builder()
.uri(MERCURY_API_URL)
.header("Content-Type", "application/json")
.header("Authorization", format!("Bearer {}", api_token))
.header("Connection", "keep-alive")
.method(Method::POST)
.body(body)
.context("Failed to create request")?;
let mut response = http_client
.send(request)
.await
.context("Failed to send request")?;
let mut body: Vec<u8> = Vec::new();
response
.body_mut()
.read_to_end(&mut body)
.await
.context("Failed to read response body")?;
let response_received_at = Instant::now();
if !response.status().is_success() {
anyhow::bail!(
"Request failed with status: {:?}\nBody: {}",
response.status(),
String::from_utf8_lossy(&body),
);
};
let mut response: open_ai::Response =
serde_json::from_slice(&body).context("Failed to parse response")?;
let id = mem::take(&mut response.id);
let response_str = text_from_response(response).unwrap_or_default();
let response_str = response_str.strip_prefix("```\n").unwrap_or(&response_str);
let response_str = response_str.strip_suffix("\n```").unwrap_or(&response_str);
let mut edits = Vec::new();
const NO_PREDICTION_OUTPUT: &str = "None";
if response_str != NO_PREDICTION_OUTPUT {
let old_text = snapshot
.text_for_range(offset_range.clone())
.collect::<String>();
edits.extend(
language::text_diff(&old_text, &response_str)
.into_iter()
.map(|(range, text)| {
(
snapshot.anchor_after(offset_range.start + range.start)
..snapshot.anchor_before(offset_range.start + range.end),
text,
)
}),
);
}
anyhow::Ok((id, edits, snapshot, response_received_at, inputs))
});
let buffer = active_buffer.clone();
cx.spawn(async move |cx| {
let (id, edits, old_snapshot, response_received_at, inputs) =
result.await.context("Mercury edit prediction failed")?;
anyhow::Ok(Some(
EditPredictionResult::new(
EditPredictionId(id.into()),
&buffer,
&old_snapshot,
edits.into(),
buffer_snapshotted_at,
response_received_at,
inputs,
cx,
)
.await,
))
})
}
}
/// Builds the Mercury prompt from three delimited sections, in this order:
///
/// 1. every excerpt of every related file, each preceded by its file path;
/// 2. the current file path followed by the context excerpt, with the editable region
///    wrapped in its own delimiters and the cursor tag inserted at `cursor_point`;
/// 3. the edit history, one event per line.
fn build_prompt(
    events: &[Arc<Event>],
    related_files: &[RelatedFile],
    cursor_buffer: &BufferSnapshot,
    cursor_buffer_path: &Path,
    cursor_point: Point,
    editable_range: Range<Point>,
    context_range: Range<Point>,
) -> String {
    const RECENTLY_VIEWED_SNIPPETS_START: &str = "<|recently_viewed_code_snippets|>\n";
    const RECENTLY_VIEWED_SNIPPETS_END: &str = "<|/recently_viewed_code_snippets|>\n";
    const RECENTLY_VIEWED_SNIPPET_START: &str = "<|recently_viewed_code_snippet|>\n";
    const RECENTLY_VIEWED_SNIPPET_END: &str = "<|/recently_viewed_code_snippet|>\n";
    const CURRENT_FILE_CONTENT_START: &str = "<|current_file_content|>\n";
    const CURRENT_FILE_CONTENT_END: &str = "<|/current_file_content|>\n";
    const CODE_TO_EDIT_START: &str = "<|code_to_edit|>\n";
    const CODE_TO_EDIT_END: &str = "<|/code_to_edit|>\n";
    const EDIT_DIFF_HISTORY_START: &str = "<|edit_diff_history|>\n";
    const EDIT_DIFF_HISTORY_END: &str = "<|/edit_diff_history|>\n";
    const CURSOR_TAG: &str = "<|cursor|>";
    const CODE_SNIPPET_FILE_PATH_PREFIX: &str = "code_snippet_file_path: ";
    const CURRENT_FILE_PATH_PREFIX: &str = "current_file_path: ";

    // Split the context excerpt around the editable region, and the editable region around
    // the cursor, before any text is emitted.
    let before_editable = context_range.start..editable_range.start;
    let after_editable = editable_range.end..context_range.end;
    let editable_before_cursor = editable_range.start..cursor_point;
    let editable_after_cursor = cursor_point..editable_range.end;

    let mut prompt = String::new();

    // Section 1: recently viewed snippets.
    push_delimited(
        &mut prompt,
        RECENTLY_VIEWED_SNIPPETS_START..RECENTLY_VIEWED_SNIPPETS_END,
        |out| {
            for file in related_files {
                for excerpt in &file.excerpts {
                    push_delimited(
                        out,
                        RECENTLY_VIEWED_SNIPPET_START..RECENTLY_VIEWED_SNIPPET_END,
                        |out| {
                            out.push_str(CODE_SNIPPET_FILE_PATH_PREFIX);
                            out.push_str(file.path.path.as_unix_str());
                            out.push('\n');
                            out.push_str(&excerpt.text.to_string());
                        },
                    );
                }
            }
        },
    );

    // Section 2: the current file, with the editable region and cursor marked.
    push_delimited(
        &mut prompt,
        CURRENT_FILE_CONTENT_START..CURRENT_FILE_CONTENT_END,
        |out| {
            out.push_str(CURRENT_FILE_PATH_PREFIX);
            out.push_str(&cursor_buffer_path.to_string_lossy());
            out.push('\n');

            for chunk in cursor_buffer.text_for_range(before_editable) {
                out.push_str(chunk);
            }
            push_delimited(out, CODE_TO_EDIT_START..CODE_TO_EDIT_END, |out| {
                for chunk in cursor_buffer.text_for_range(editable_before_cursor) {
                    out.push_str(chunk);
                }
                out.push_str(CURSOR_TAG);
                for chunk in cursor_buffer.text_for_range(editable_after_cursor) {
                    out.push_str(chunk);
                }
            });
            for chunk in cursor_buffer.text_for_range(after_editable) {
                out.push_str(chunk);
            }
        },
    );

    // Section 3: edit history.
    push_delimited(
        &mut prompt,
        EDIT_DIFF_HISTORY_START..EDIT_DIFF_HISTORY_END,
        |out| {
            for event in events {
                writeln!(out, "{event}").unwrap();
            }
        },
    );

    prompt
}
/// Appends `delimiters.start` to `prompt`, lets `cb` append the section body, and then
/// appends `delimiters.end`.
fn push_delimited(prompt: &mut String, delimiters: Range<&str>, cb: impl FnOnce(&mut String)) {
    let Range {
        start: opening,
        end: closing,
    } = delimiters;
    prompt.push_str(opening);
    cb(prompt);
    prompt.push_str(closing);
}
/// URL key under which the Mercury API token is written to, read from and deleted from the
/// credentials provider.
pub const MERCURY_CREDENTIALS_URL: &str = "https://api.inceptionlabs.ai/v1/edit/completions";
/// Username stored alongside the token when it is written to the credentials provider.
pub const MERCURY_CREDENTIALS_USERNAME: &str = "mercury-api-token";
/// Loads the Mercury API token.
///
/// A non-empty `MERCURY_AI_TOKEN` environment variable takes precedence and is returned as
/// an already-resolved task. Otherwise the credentials provider is queried asynchronously;
/// a failed read, missing credentials, or credentials that are not valid UTF-8 all resolve
/// to `None`.
pub fn load_api_token(cx: &App) -> Task<Option<String>> {
    match std::env::var("MERCURY_AI_TOKEN") {
        Ok(token) if !token.is_empty() => return Task::ready(Some(token)),
        _ => {}
    }

    let provider = <dyn CredentialsProvider>::global(cx);
    cx.spawn(async move |cx| {
        let stored = provider
            .read_credentials(MERCURY_CREDENTIALS_URL, &cx)
            .await
            .ok()??;
        let (_, secret) = stored;
        String::from_utf8(secret).ok()
    })
}
/// Persists `api_token` through the credentials provider: `Some` writes the token under
/// [`MERCURY_CREDENTIALS_URL`], `None` deletes the stored entry.
fn store_api_token_in_keychain(api_token: Option<String>, cx: &App) -> Task<Result<()>> {
    let provider = <dyn CredentialsProvider>::global(cx);
    cx.spawn(async move |cx| {
        match api_token {
            Some(token) => provider
                .write_credentials(
                    MERCURY_CREDENTIALS_URL,
                    MERCURY_CREDENTIALS_USERNAME,
                    token.as_bytes(),
                    cx,
                )
                .await
                .context("Failed to save Mercury API token to system keychain"),
            None => provider
                .delete_credentials(MERCURY_CREDENTIALS_URL, cx)
                .await
                .context("Failed to delete Mercury API token from system keychain"),
        }
    })
}

View File

@@ -0,0 +1,31 @@
pub fn text_from_response(mut res: open_ai::Response) -> Option<String> {
let choice = res.choices.pop()?;
let output_text = match choice.message {
open_ai::RequestMessage::Assistant {
content: Some(open_ai::MessageContent::Plain(content)),
..
} => content,
open_ai::RequestMessage::Assistant {
content: Some(open_ai::MessageContent::Multipart(mut content)),
..
} => {
if content.is_empty() {
log::error!("No output from Baseten completion response");
return None;
}
match content.remove(0) {
open_ai::MessagePart::Text { text } => text,
open_ai::MessagePart::Image { .. } => {
log::error!("Expected text, got an image");
return None;
}
}
}
_ => {
log::error!("Invalid response message: {:?}", choice.message);
return None;
}
};
Some(output_text)
}

View File

@@ -1,9 +1,8 @@
mod input_excerpt;
use std::{fmt::Write, ops::Range, path::Path, sync::Arc, time::Instant};
use crate::{
EditPredictionId, EditPredictionStore, ZedUpdateRequiredError,
cursor_excerpt::{editable_and_context_ranges_for_cursor_position, guess_token_count},
prediction::{EditPredictionInputs, EditPredictionResult},
};
use anyhow::{Context as _, Result};
@@ -12,7 +11,6 @@ use cloud_llm_client::{
predict_edits_v3::Event,
};
use gpui::{App, AppContext as _, AsyncApp, Context, Entity, SharedString, Task};
use input_excerpt::excerpt_for_cursor_position;
use language::{
Anchor, Buffer, BufferSnapshot, OffsetRangeExt as _, Point, ToPoint as _, text_diff,
};
@@ -495,10 +493,174 @@ pub fn format_event(event: &Event) -> String {
}
}
/// Typical number of string bytes per token for the purposes of limiting model input. This is
/// intentionally low to err on the side of underestimating limits.
pub(crate) const BYTES_PER_TOKEN_GUESS: usize = 3;
fn guess_token_count(bytes: usize) -> usize {
bytes / BYTES_PER_TOKEN_GUESS
/// Prompt excerpt built around a cursor position, together with the ranges it was built from.
#[derive(Debug)]
pub struct InputExcerpt {
/// Range of the buffer whose text is included in `prompt`.
pub context_range: Range<Point>,
/// Sub-range of the excerpt that is wrapped in the editable-region markers.
pub editable_range: Range<Point>,
/// The rendered excerpt, fenced and annotated with the editable-region and cursor markers.
pub prompt: String,
}
/// Renders the fenced prompt excerpt for a cursor at `position`.
///
/// The editable and context ranges come from
/// `editable_and_context_ranges_for_cursor_position`. The excerpt opens with a fence line
/// naming `path`, adds the start-of-file marker when the context begins at the very start of
/// the buffer, then emits the context before the editable region, the marked editable region,
/// and the context after it, and finally closes the fence.
pub fn excerpt_for_cursor_position(
    position: Point,
    path: &str,
    snapshot: &BufferSnapshot,
    editable_region_token_limit: usize,
    context_token_limit: usize,
) -> InputExcerpt {
    let (editable_range, context_range) = editable_and_context_ranges_for_cursor_position(
        position,
        snapshot,
        editable_region_token_limit,
        context_token_limit,
    );
    let leading_context = context_range.start..editable_range.start;
    let trailing_context = editable_range.end..context_range.end;

    let mut prompt = String::new();
    prompt.push_str("```");
    prompt.push_str(path);
    prompt.push('\n');
    if context_range.start == Point::zero() {
        writeln!(&mut prompt, "{START_OF_FILE_MARKER}").unwrap();
    }

    prompt.extend(
        snapshot
            .chunks(leading_context, false)
            .map(|chunk| chunk.text),
    );
    push_editable_range(position, snapshot, editable_range.clone(), &mut prompt);
    prompt.extend(
        snapshot
            .chunks(trailing_context, false)
            .map(|chunk| chunk.text),
    );
    prompt.push_str("\n```");

    InputExcerpt {
        context_range,
        editable_range,
        prompt,
    }
}
/// Appends the editable region to `prompt`: the start marker on its own line, the region's
/// text with the cursor marker inserted at `cursor_position`, then a newline and the end
/// marker.
fn push_editable_range(
    cursor_position: Point,
    snapshot: &BufferSnapshot,
    editable_range: Range<Point>,
    prompt: &mut String,
) {
    let before_cursor = editable_range.start..cursor_position;
    let after_cursor = cursor_position..editable_range.end;

    writeln!(prompt, "{EDITABLE_REGION_START_MARKER}").unwrap();
    prompt.extend(snapshot.chunks(before_cursor, false).map(|chunk| chunk.text));
    prompt.push_str(CURSOR_MARKER);
    prompt.extend(snapshot.chunks(after_cursor, false).map(|chunk| chunk.text));
    write!(prompt, "\n{EDITABLE_REGION_END_MARKER}").unwrap();
}
#[cfg(test)]
mod tests {
use super::*;
use gpui::{App, AppContext};
use indoc::indoc;
use language::{Buffer, Language, LanguageConfig, LanguageMatcher, tree_sitter_rust};
use std::sync::Arc;
/// Checks the rendered prompt for the same cursor position under two editable-region
/// token limits (50 and 40) with a context limit of 32.
#[gpui::test]
fn test_excerpt_for_cursor_position(cx: &mut App) {
let text = indoc! {r#"
fn foo() {
let x = 42;
println!("Hello, world!");
}
fn bar() {
let x = 42;
let mut sum = 0;
for i in 0..x {
sum += i;
}
println!("Sum: {}", sum);
return sum;
}
fn generate_random_numbers() -> Vec<i32> {
let mut rng = rand::thread_rng();
let mut numbers = Vec::new();
for _ in 0..5 {
numbers.push(rng.random_range(1..101));
}
numbers
}
"#};
let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx));
let snapshot = buffer.read(cx).snapshot();
// Ensure we try to fit the largest possible syntax scope, resorting to line-based expansion
// when a larger scope doesn't fit the editable region.
let excerpt = excerpt_for_cursor_position(Point::new(12, 5), "main.rs", &snapshot, 50, 32);
assert_eq!(
excerpt.prompt,
indoc! {r#"
```main.rs
let x = 42;
println!("Hello, world!");
<|editable_region_start|>
}
fn bar() {
let x = 42;
let mut sum = 0;
for i in 0..x {
sum += i;
}
println!("Sum: {}", sum);
r<|user_cursor_is_here|>eturn sum;
}
fn generate_random_numbers() -> Vec<i32> {
<|editable_region_end|>
let mut rng = rand::thread_rng();
let mut numbers = Vec::new();
```"#}
);
// The `bar` function won't fit within the editable region, so we resort to line-based expansion.
let excerpt = excerpt_for_cursor_position(Point::new(12, 5), "main.rs", &snapshot, 40, 32);
assert_eq!(
excerpt.prompt,
indoc! {r#"
```main.rs
fn bar() {
let x = 42;
let mut sum = 0;
<|editable_region_start|>
for i in 0..x {
sum += i;
}
println!("Sum: {}", sum);
r<|user_cursor_is_here|>eturn sum;
}
fn generate_random_numbers() -> Vec<i32> {
let mut rng = rand::thread_rng();
<|editable_region_end|>
let mut numbers = Vec::new();
for _ in 0..5 {
numbers.push(rng.random_range(1..101));
```"#}
);
}
/// Builds a minimal Rust `Language` (name, `rs` path suffix, tree-sitter grammar) so the
/// test buffer gets a syntax tree.
fn rust_lang() -> Language {
Language::new(
LanguageConfig {
name: "Rust".into(),
matcher: LanguageMatcher {
path_suffixes: vec!["rs".to_string()],
..Default::default()
},
..Default::default()
},
Some(tree_sitter_rust::LANGUAGE.into()),
)
}
}

View File

@@ -1,231 +0,0 @@
use super::{
CURSOR_MARKER, EDITABLE_REGION_END_MARKER, EDITABLE_REGION_START_MARKER, START_OF_FILE_MARKER,
guess_token_count,
};
use language::{BufferSnapshot, Point};
use std::{fmt::Write, ops::Range};
#[derive(Debug)]
pub struct InputExcerpt {
pub context_range: Range<Point>,
pub editable_range: Range<Point>,
pub prompt: String,
}
pub fn excerpt_for_cursor_position(
position: Point,
path: &str,
snapshot: &BufferSnapshot,
editable_region_token_limit: usize,
context_token_limit: usize,
) -> InputExcerpt {
let mut scope_range = position..position;
let mut remaining_edit_tokens = editable_region_token_limit;
while let Some(parent) = snapshot.syntax_ancestor(scope_range.clone()) {
let parent_tokens = guess_token_count(parent.byte_range().len());
let parent_point_range = Point::new(
parent.start_position().row as u32,
parent.start_position().column as u32,
)
..Point::new(
parent.end_position().row as u32,
parent.end_position().column as u32,
);
if parent_point_range == scope_range {
break;
} else if parent_tokens <= editable_region_token_limit {
scope_range = parent_point_range;
remaining_edit_tokens = editable_region_token_limit - parent_tokens;
} else {
break;
}
}
let editable_range = expand_range(snapshot, scope_range, remaining_edit_tokens);
let context_range = expand_range(snapshot, editable_range.clone(), context_token_limit);
let mut prompt = String::new();
writeln!(&mut prompt, "```{path}").unwrap();
if context_range.start == Point::zero() {
writeln!(&mut prompt, "{START_OF_FILE_MARKER}").unwrap();
}
for chunk in snapshot.chunks(context_range.start..editable_range.start, false) {
prompt.push_str(chunk.text);
}
push_editable_range(position, snapshot, editable_range.clone(), &mut prompt);
for chunk in snapshot.chunks(editable_range.end..context_range.end, false) {
prompt.push_str(chunk.text);
}
write!(prompt, "\n```").unwrap();
InputExcerpt {
context_range,
editable_range,
prompt,
}
}
fn push_editable_range(
cursor_position: Point,
snapshot: &BufferSnapshot,
editable_range: Range<Point>,
prompt: &mut String,
) {
writeln!(prompt, "{EDITABLE_REGION_START_MARKER}").unwrap();
for chunk in snapshot.chunks(editable_range.start..cursor_position, false) {
prompt.push_str(chunk.text);
}
prompt.push_str(CURSOR_MARKER);
for chunk in snapshot.chunks(cursor_position..editable_range.end, false) {
prompt.push_str(chunk.text);
}
write!(prompt, "\n{EDITABLE_REGION_END_MARKER}").unwrap();
}
fn expand_range(
snapshot: &BufferSnapshot,
range: Range<Point>,
mut remaining_tokens: usize,
) -> Range<Point> {
let mut expanded_range = range;
expanded_range.start.column = 0;
expanded_range.end.column = snapshot.line_len(expanded_range.end.row);
loop {
let mut expanded = false;
if remaining_tokens > 0 && expanded_range.start.row > 0 {
expanded_range.start.row -= 1;
let line_tokens =
guess_token_count(snapshot.line_len(expanded_range.start.row) as usize);
remaining_tokens = remaining_tokens.saturating_sub(line_tokens);
expanded = true;
}
if remaining_tokens > 0 && expanded_range.end.row < snapshot.max_point().row {
expanded_range.end.row += 1;
expanded_range.end.column = snapshot.line_len(expanded_range.end.row);
let line_tokens = guess_token_count(expanded_range.end.column as usize);
remaining_tokens = remaining_tokens.saturating_sub(line_tokens);
expanded = true;
}
if !expanded {
break;
}
}
expanded_range
}
#[cfg(test)]
mod tests {
use super::*;
use gpui::{App, AppContext};
use indoc::indoc;
use language::{Buffer, Language, LanguageConfig, LanguageMatcher, tree_sitter_rust};
use std::sync::Arc;
#[gpui::test]
fn test_excerpt_for_cursor_position(cx: &mut App) {
let text = indoc! {r#"
fn foo() {
let x = 42;
println!("Hello, world!");
}
fn bar() {
let x = 42;
let mut sum = 0;
for i in 0..x {
sum += i;
}
println!("Sum: {}", sum);
return sum;
}
fn generate_random_numbers() -> Vec<i32> {
let mut rng = rand::thread_rng();
let mut numbers = Vec::new();
for _ in 0..5 {
numbers.push(rng.random_range(1..101));
}
numbers
}
"#};
let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx));
let snapshot = buffer.read(cx).snapshot();
// Ensure we try to fit the largest possible syntax scope, resorting to line-based expansion
// when a larger scope doesn't fit the editable region.
let excerpt = excerpt_for_cursor_position(Point::new(12, 5), "main.rs", &snapshot, 50, 32);
assert_eq!(
excerpt.prompt,
indoc! {r#"
```main.rs
let x = 42;
println!("Hello, world!");
<|editable_region_start|>
}
fn bar() {
let x = 42;
let mut sum = 0;
for i in 0..x {
sum += i;
}
println!("Sum: {}", sum);
r<|user_cursor_is_here|>eturn sum;
}
fn generate_random_numbers() -> Vec<i32> {
<|editable_region_end|>
let mut rng = rand::thread_rng();
let mut numbers = Vec::new();
```"#}
);
// The `bar` function won't fit within the editable region, so we resort to line-based expansion.
let excerpt = excerpt_for_cursor_position(Point::new(12, 5), "main.rs", &snapshot, 40, 32);
assert_eq!(
excerpt.prompt,
indoc! {r#"
```main.rs
fn bar() {
let x = 42;
let mut sum = 0;
<|editable_region_start|>
for i in 0..x {
sum += i;
}
println!("Sum: {}", sum);
r<|user_cursor_is_here|>eturn sum;
}
fn generate_random_numbers() -> Vec<i32> {
let mut rng = rand::thread_rng();
<|editable_region_end|>
let mut numbers = Vec::new();
for _ in 0..5 {
numbers.push(rng.random_range(1..101));
```"#}
);
}
fn rust_lang() -> Language {
Language::new(
LanguageConfig {
name: "Rust".into(),
matcher: LanguageMatcher {
path_suffixes: vec!["rs".to_string()],
..Default::default()
},
..Default::default()
},
Some(tree_sitter_rust::LANGUAGE.into()),
)
}
}

View File

@@ -1,5 +1,6 @@
#[cfg(feature = "eval-support")]
use crate::EvalCacheEntryKind;
use crate::open_ai_response::text_from_response;
use crate::prediction::EditPredictionResult;
use crate::{
DebugEvent, EDIT_PREDICTIONS_MODEL_ID, EditPredictionId, EditPredictionInputs,
@@ -199,7 +200,7 @@ pub fn request_prediction_with_zeta2(
stream: false,
max_completion_tokens: None,
stop: generation_params.stop.unwrap_or_default(),
temperature: generation_params.temperature.unwrap_or(0.7),
temperature: generation_params.temperature.or(Some(0.7)),
tool_choice: None,
parallel_tool_calls: None,
tools: vec![],
@@ -324,35 +325,3 @@ pub fn request_prediction_with_zeta2(
))
})
}
pub fn text_from_response(mut res: open_ai::Response) -> Option<String> {
let choice = res.choices.pop()?;
let output_text = match choice.message {
open_ai::RequestMessage::Assistant {
content: Some(open_ai::MessageContent::Plain(content)),
..
} => content,
open_ai::RequestMessage::Assistant {
content: Some(open_ai::MessageContent::Multipart(mut content)),
..
} => {
if content.is_empty() {
log::error!("No output from Baseten completion response");
return None;
}
match content.remove(0) {
open_ai::MessagePart::Text { text } => text,
open_ai::MessagePart::Image { .. } => {
log::error!("Expected text, got an image");
return None;
}
}
}
_ => {
log::error!("Invalid response message: {:?}", choice.message);
return None;
}
};
Some(output_text)
}

View File

@@ -198,8 +198,9 @@ pub async fn perform_predict(
let response =
request.response_rx.await?.0.map_err(|err| anyhow!(err))?;
let response = edit_prediction::zeta2::text_from_response(response)
.unwrap_or_default();
let response =
edit_prediction::open_ai_response::text_from_response(response)
.unwrap_or_default();
let prediction_finished_at = Instant::now();
fs::write(example_run_dir.join("prediction_response.md"), &response)?;

View File

@@ -3,7 +3,7 @@ use client::{Client, UserStore, zed_urls};
use cloud_llm_client::UsageLimit;
use codestral::CodestralEditPredictionDelegate;
use copilot::{Copilot, Status};
use edit_prediction::{SweepFeatureFlag, Zeta2FeatureFlag};
use edit_prediction::{MercuryFeatureFlag, SweepFeatureFlag, Zeta2FeatureFlag};
use edit_prediction_types::EditPredictionDelegateHandle;
use editor::{
Editor, MultiBufferOffset, SelectionEffects, actions::ShowEditPrediction, scroll::Autoscroll,
@@ -23,6 +23,7 @@ use language::{
use project::DisableAiSettings;
use regex::Regex;
use settings::{
EXPERIMENTAL_MERCURY_EDIT_PREDICTION_PROVIDER_NAME,
EXPERIMENTAL_SWEEP_EDIT_PREDICTION_PROVIDER_NAME,
EXPERIMENTAL_ZETA2_EDIT_PREDICTION_PROVIDER_NAME, Settings, SettingsStore,
update_settings_file,
@@ -44,7 +45,7 @@ use workspace::{
use zed_actions::OpenBrowser;
use crate::{
RatePredictions, SweepApiKeyModal,
ExternalProviderApiKeyModal, RatePredictions,
rate_prediction_modal::PredictEditsRatePredictionsFeatureFlag,
};
@@ -311,21 +312,31 @@ impl Render for EditPredictionButton {
provider @ (EditPredictionProvider::Experimental(_) | EditPredictionProvider::Zed) => {
let enabled = self.editor_enabled.unwrap_or(true);
let is_sweep = matches!(
provider,
let ep_icon;
let mut missing_token = false;
match provider {
EditPredictionProvider::Experimental(
EXPERIMENTAL_SWEEP_EDIT_PREDICTION_PROVIDER_NAME
)
);
let sweep_missing_token = is_sweep
&& !edit_prediction::EditPredictionStore::try_global(cx)
.map_or(false, |ep_store| ep_store.read(cx).has_sweep_api_token());
let ep_icon = match (is_sweep, enabled) {
(true, _) => IconName::SweepAi,
(false, true) => IconName::ZedPredict,
(false, false) => IconName::ZedPredictDisabled,
EXPERIMENTAL_SWEEP_EDIT_PREDICTION_PROVIDER_NAME,
) => {
ep_icon = IconName::SweepAi;
missing_token = edit_prediction::EditPredictionStore::try_global(cx)
.is_some_and(|ep_store| !ep_store.read(cx).has_sweep_api_token());
}
EditPredictionProvider::Experimental(
EXPERIMENTAL_MERCURY_EDIT_PREDICTION_PROVIDER_NAME,
) => {
ep_icon = IconName::Inception;
missing_token = edit_prediction::EditPredictionStore::try_global(cx)
.is_some_and(|ep_store| !ep_store.read(cx).has_mercury_api_token());
}
_ => {
ep_icon = if enabled {
IconName::ZedPredict
} else {
IconName::ZedPredictDisabled
};
}
};
if edit_prediction::should_show_upsell_modal() {
@@ -369,7 +380,7 @@ impl Render for EditPredictionButton {
let show_editor_predictions = self.editor_show_predictions;
let user = self.user_store.read(cx).current_user();
let indicator_color = if sweep_missing_token {
let indicator_color = if missing_token {
Some(Color::Error)
} else if enabled && (!show_editor_predictions || over_limit) {
Some(if over_limit {
@@ -532,6 +543,12 @@ impl EditPredictionButton {
));
}
if cx.has_flag::<MercuryFeatureFlag>() {
providers.push(EditPredictionProvider::Experimental(
EXPERIMENTAL_MERCURY_EDIT_PREDICTION_PROVIDER_NAME,
));
}
if cx.has_flag::<Zeta2FeatureFlag>() {
providers.push(EditPredictionProvider::Experimental(
EXPERIMENTAL_ZETA2_EDIT_PREDICTION_PROVIDER_NAME,
@@ -628,7 +645,66 @@ impl EditPredictionButton {
if let Some(workspace) = window.root::<Workspace>().flatten() {
workspace.update(cx, |workspace, cx| {
workspace.toggle_modal(window, cx, |window, cx| {
SweepApiKeyModal::new(window, cx)
ExternalProviderApiKeyModal::new(
window,
cx,
|api_key, store, cx| {
store
.sweep_ai
.set_api_token(api_key, cx)
.detach_and_log_err(cx);
},
)
});
});
};
} else {
set_completion_provider(fs.clone(), cx, provider);
}
});
menu.item(entry)
}
EditPredictionProvider::Experimental(
EXPERIMENTAL_MERCURY_EDIT_PREDICTION_PROVIDER_NAME,
) => {
// Treat a missing global store the same as a missing token, using the same
// `is_some_and` form as the token checks in `render`.
let has_api_token = edit_prediction::EditPredictionStore::try_global(cx)
.is_some_and(|ep_store| ep_store.read(cx).has_mercury_api_token());
let should_open_modal = !has_api_token || is_current;
let entry = if has_api_token {
ContextMenuEntry::new("Mercury")
.toggleable(IconPosition::Start, is_current)
} else {
ContextMenuEntry::new("Mercury")
.icon(IconName::XCircle)
.icon_color(Color::Error)
.documentation_aside(
DocumentationSide::Left,
DocumentationEdge::Bottom,
|_| {
Label::new("Click to configure your Mercury API token")
.into_any_element()
},
)
};
let entry = entry.handler(move |window, cx| {
if should_open_modal {
if let Some(workspace) = window.root::<Workspace>().flatten() {
workspace.update(cx, |workspace, cx| {
workspace.toggle_modal(window, cx, |window, cx| {
ExternalProviderApiKeyModal::new(
window,
cx,
|api_key, store, cx| {
store
.mercury
.set_api_token(api_key, cx)
.detach_and_log_err(cx);
},
)
});
});
};

View File

@@ -1,7 +1,7 @@
mod edit_prediction_button;
mod edit_prediction_context_view;
mod external_provider_api_token_modal;
mod rate_prediction_modal;
mod sweep_api_token_modal;
use std::any::{Any as _, TypeId};
@@ -17,7 +17,7 @@ use ui::{App, prelude::*};
use workspace::{SplitDirection, Workspace};
pub use edit_prediction_button::{EditPredictionButton, ToggleMenu};
pub use sweep_api_token_modal::SweepApiKeyModal;
pub use external_provider_api_token_modal::ExternalProviderApiKeyModal;
use crate::rate_prediction_modal::PredictEditsRatePredictionsFeatureFlag;

View File

@@ -6,18 +6,24 @@ use ui::{Button, ButtonStyle, Clickable, Headline, HeadlineSize, prelude::*};
use ui_input::InputField;
use workspace::ModalView;
pub struct SweepApiKeyModal {
/// Modal that asks the user for an API key for an externally hosted edit
/// prediction provider.
///
/// The modal itself is provider-agnostic: what happens to the entered key is
/// decided by the `on_confirm` callback supplied at construction time.
pub struct ExternalProviderApiKeyModal {
/// Input field the key is typed into.
api_key_input: Entity<InputField>,
focus_handle: FocusHandle,
/// Callback that receives the entered key (`None` when the input is blank)
/// together with the global `EditPredictionStore`, if one exists.
on_confirm: Box<dyn Fn(Option<String>, &mut EditPredictionStore, &mut App)>,
}
impl SweepApiKeyModal {
pub fn new(window: &mut Window, cx: &mut Context<Self>) -> Self {
let api_key_input = cx.new(|cx| InputField::new(window, cx, "Enter your Sweep API token"));
impl ExternalProviderApiKeyModal {
/// Builds the modal, creating its key input field and focus handle.
///
/// `on_confirm` is boxed and stored; it is later handed the entered key
/// (`None` when blank) and the `EditPredictionStore`, so each caller decides
/// which provider's token gets updated.
pub fn new(
window: &mut Window,
cx: &mut Context<Self>,
on_confirm: impl Fn(Option<String>, &mut EditPredictionStore, &mut App) + 'static,
) -> Self {
// NOTE(review): the string is presumably the field's placeholder/label — confirm against `InputField::new`.
let api_key_input = cx.new(|cx| InputField::new(window, cx, "Enter your API key"));
Self {
api_key_input,
focus_handle: cx.focus_handle(),
on_confirm: Box::new(on_confirm),
}
}
@@ -30,39 +36,34 @@ impl SweepApiKeyModal {
let api_key = (!api_key.trim().is_empty()).then_some(api_key);
if let Some(ep_store) = EditPredictionStore::try_global(cx) {
ep_store.update(cx, |ep_store, cx| {
ep_store
.sweep_ai
.set_api_token(api_key, cx)
.detach_and_log_err(cx);
});
ep_store.update(cx, |ep_store, cx| (self.on_confirm)(api_key, ep_store, cx))
}
cx.emit(DismissEvent);
}
}
impl EventEmitter<DismissEvent> for SweepApiKeyModal {}
impl EventEmitter<DismissEvent> for ExternalProviderApiKeyModal {}
impl ModalView for SweepApiKeyModal {}
impl ModalView for ExternalProviderApiKeyModal {}
impl Focusable for SweepApiKeyModal {
impl Focusable for ExternalProviderApiKeyModal {
fn focus_handle(&self, _cx: &App) -> FocusHandle {
self.focus_handle.clone()
}
}
impl Render for SweepApiKeyModal {
impl Render for ExternalProviderApiKeyModal {
fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
v_flex()
.key_context("SweepApiKeyModal")
.key_context("ExternalApiKeyModal")
.on_action(cx.listener(Self::cancel))
.on_action(cx.listener(Self::confirm))
.elevation_2(cx)
.w(px(400.))
.p_4()
.gap_3()
.child(Headline::new("Sweep API Token").size(HeadlineSize::Small))
.child(Headline::new("API Token").size(HeadlineSize::Small))
.child(self.api_key_input.clone())
.child(
h_flex()

View File

@@ -34,8 +34,8 @@ pub enum IconName {
ArrowRightLeft,
ArrowUp,
ArrowUpRight,
Attach,
AtSign,
Attach,
AudioOff,
AudioOn,
Backspace,
@@ -45,8 +45,8 @@ pub enum IconName {
BellRing,
Binary,
Blocks,
BoltOutlined,
BoltFilled,
BoltOutlined,
Book,
BookCopy,
CaseSensitive,
@@ -80,9 +80,9 @@ pub enum IconName {
Debug,
DebugBreakpoint,
DebugContinue,
DebugDetach,
DebugDisabledBreakpoint,
DebugDisabledLogBreakpoint,
DebugDetach,
DebugIgnoreBreakpoints,
DebugLogBreakpoint,
DebugPause,
@@ -140,6 +140,7 @@ pub enum IconName {
Hash,
HistoryRerun,
Image,
Inception,
Indicator,
Info,
Json,
@@ -147,6 +148,7 @@ pub enum IconName {
Library,
LineHeight,
Link,
Linux,
ListCollapse,
ListFilter,
ListTodo,
@@ -172,8 +174,8 @@ pub enum IconName {
PencilUnavailable,
Person,
Pin,
PlayOutlined,
PlayFilled,
PlayOutlined,
Plus,
Power,
Public,
@@ -259,15 +261,14 @@ pub enum IconName {
ZedAssistant,
ZedBurnMode,
ZedBurnModeOn,
ZedSrcCustom,
ZedSrcExtension,
ZedPredict,
ZedPredictDisabled,
ZedPredictDown,
ZedPredictError,
ZedPredictUp,
ZedSrcCustom,
ZedSrcExtension,
ZedXCopilot,
Linux,
}
impl IconName {

View File

@@ -438,7 +438,7 @@ pub fn into_open_ai(
messages,
stream,
stop: request.stop,
temperature: request.temperature.unwrap_or(1.0),
temperature: request.temperature.or(Some(1.0)),
max_completion_tokens: max_output_tokens,
parallel_tool_calls: if supports_parallel_tool_calls && !request.tools.is_empty() {
// Disable parallel tool calls, as the Agent currently expects a maximum of one per turn.

View File

@@ -266,7 +266,8 @@ pub struct Request {
pub max_completion_tokens: Option<u64>,
#[serde(default, skip_serializing_if = "Vec::is_empty")]
pub stop: Vec<String>,
pub temperature: f32,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub temperature: Option<f32>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub tool_choice: Option<ToolChoice>,
/// Whether to enable parallel function calling during tool use.

View File

@@ -81,6 +81,7 @@ pub enum EditPredictionProvider {
pub const EXPERIMENTAL_SWEEP_EDIT_PREDICTION_PROVIDER_NAME: &str = "sweep";
pub const EXPERIMENTAL_ZETA2_EDIT_PREDICTION_PROVIDER_NAME: &str = "zeta2";
pub const EXPERIMENTAL_MERCURY_EDIT_PREDICTION_PROVIDER_NAME: &str = "mercury";
impl<'de> Deserialize<'de> for EditPredictionProvider {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
@@ -111,6 +112,13 @@ impl<'de> Deserialize<'de> for EditPredictionProvider {
EXPERIMENTAL_SWEEP_EDIT_PREDICTION_PROVIDER_NAME,
)
}
Content::Experimental(name)
if name == EXPERIMENTAL_MERCURY_EDIT_PREDICTION_PROVIDER_NAME =>
{
EditPredictionProvider::Experimental(
EXPERIMENTAL_MERCURY_EDIT_PREDICTION_PROVIDER_NAME,
)
}
Content::Experimental(name)
if name == EXPERIMENTAL_ZETA2_EDIT_PREDICTION_PROVIDER_NAME =>
{

View File

@@ -9,6 +9,7 @@ use gpui::{AnyWindowHandle, App, AppContext as _, Context, Entity, WeakEntity};
use language::language_settings::{EditPredictionProvider, all_language_settings};
use language_models::MistralLanguageModelProvider;
use settings::{
EXPERIMENTAL_MERCURY_EDIT_PREDICTION_PROVIDER_NAME,
EXPERIMENTAL_SWEEP_EDIT_PREDICTION_PROVIDER_NAME,
EXPERIMENTAL_ZETA2_EDIT_PREDICTION_PROVIDER_NAME, SettingsStore,
};
@@ -219,6 +220,10 @@ fn assign_edit_prediction_provider(
&& cx.has_flag::<Zeta2FeatureFlag>()
{
edit_prediction::EditPredictionModel::Zeta2
} else if name == EXPERIMENTAL_MERCURY_EDIT_PREDICTION_PROVIDER_NAME
&& cx.has_flag::<edit_prediction::MercuryFeatureFlag>()
{
// Gate Mercury on its own feature flag. The provider menu lists Mercury
// under `MercuryFeatureFlag`; checking `Zeta2FeatureFlag` here meant a user
// with only the Mercury flag fell through to the `return false` branch.
// NOTE(review): assumes `MercuryFeatureFlag` is exported by the
// `edit_prediction` crate — confirm, or import it alongside the other flags.
edit_prediction::EditPredictionModel::Mercury
} else {
return false;
}