Add experimental mercury edit prediction provider (#44256)
Release Notes: - N/A --------- Co-authored-by: Ben Kunkle <ben@zed.dev> Co-authored-by: Max Brunsfeld <maxbrunsfeld@gmail.com>
This commit is contained in:
11
assets/icons/inception.svg
Normal file
11
assets/icons/inception.svg
Normal file
@@ -0,0 +1,11 @@
|
||||
<svg width="28" height="28" viewBox="0 0 28 28" fill="none" id="svg1378540956_510">
|
||||
<g clip-path="url(#svg1378540956_510_clip0_1_1506)" transform="translate(4, 4) scale(0.857)">
|
||||
<path d="M17.0547 0.372066H8.52652L-0.00165176 8.90024V17.4284H8.52652V8.90024H17.0547V0.372066Z" fill="#1A1C20"></path>
|
||||
<path d="M10.1992 27.6279H18.7274L27.2556 19.0998V10.5716H18.7274V19.0998H10.1992V27.6279Z" fill="#1A1C20"></path>
|
||||
</g>
|
||||
<defs>
|
||||
<clipPath id="svg1378540956_510_clip0_1_1506">
|
||||
<rect width="27.2559" height="27.2559" fill="white" transform="translate(0 0.37207)"></rect>
|
||||
</clipPath>
|
||||
</defs>
|
||||
</svg>
|
||||
|
After Width: | Height: | Size: 593 B |
78
crates/edit_prediction/src/cursor_excerpt.rs
Normal file
78
crates/edit_prediction/src/cursor_excerpt.rs
Normal file
@@ -0,0 +1,78 @@
|
||||
use language::{BufferSnapshot, Point};
|
||||
use std::ops::Range;
|
||||
|
||||
pub fn editable_and_context_ranges_for_cursor_position(
|
||||
position: Point,
|
||||
snapshot: &BufferSnapshot,
|
||||
editable_region_token_limit: usize,
|
||||
context_token_limit: usize,
|
||||
) -> (Range<Point>, Range<Point>) {
|
||||
let mut scope_range = position..position;
|
||||
let mut remaining_edit_tokens = editable_region_token_limit;
|
||||
|
||||
while let Some(parent) = snapshot.syntax_ancestor(scope_range.clone()) {
|
||||
let parent_tokens = guess_token_count(parent.byte_range().len());
|
||||
let parent_point_range = Point::new(
|
||||
parent.start_position().row as u32,
|
||||
parent.start_position().column as u32,
|
||||
)
|
||||
..Point::new(
|
||||
parent.end_position().row as u32,
|
||||
parent.end_position().column as u32,
|
||||
);
|
||||
if parent_point_range == scope_range {
|
||||
break;
|
||||
} else if parent_tokens <= editable_region_token_limit {
|
||||
scope_range = parent_point_range;
|
||||
remaining_edit_tokens = editable_region_token_limit - parent_tokens;
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
let editable_range = expand_range(snapshot, scope_range, remaining_edit_tokens);
|
||||
let context_range = expand_range(snapshot, editable_range.clone(), context_token_limit);
|
||||
(editable_range, context_range)
|
||||
}
|
||||
|
||||
fn expand_range(
|
||||
snapshot: &BufferSnapshot,
|
||||
range: Range<Point>,
|
||||
mut remaining_tokens: usize,
|
||||
) -> Range<Point> {
|
||||
let mut expanded_range = range;
|
||||
expanded_range.start.column = 0;
|
||||
expanded_range.end.column = snapshot.line_len(expanded_range.end.row);
|
||||
loop {
|
||||
let mut expanded = false;
|
||||
|
||||
if remaining_tokens > 0 && expanded_range.start.row > 0 {
|
||||
expanded_range.start.row -= 1;
|
||||
let line_tokens =
|
||||
guess_token_count(snapshot.line_len(expanded_range.start.row) as usize);
|
||||
remaining_tokens = remaining_tokens.saturating_sub(line_tokens);
|
||||
expanded = true;
|
||||
}
|
||||
|
||||
if remaining_tokens > 0 && expanded_range.end.row < snapshot.max_point().row {
|
||||
expanded_range.end.row += 1;
|
||||
expanded_range.end.column = snapshot.line_len(expanded_range.end.row);
|
||||
let line_tokens = guess_token_count(expanded_range.end.column as usize);
|
||||
remaining_tokens = remaining_tokens.saturating_sub(line_tokens);
|
||||
expanded = true;
|
||||
}
|
||||
|
||||
if !expanded {
|
||||
break;
|
||||
}
|
||||
}
|
||||
expanded_range
|
||||
}
|
||||
|
||||
/// Typical number of string bytes per token for the purposes of limiting model input. This is
|
||||
/// intentionally low to err on the side of underestimating limits.
|
||||
pub(crate) const BYTES_PER_TOKEN_GUESS: usize = 3;
|
||||
|
||||
pub fn guess_token_count(bytes: usize) -> usize {
|
||||
bytes / BYTES_PER_TOKEN_GUESS
|
||||
}
|
||||
@@ -51,8 +51,11 @@ use thiserror::Error;
|
||||
use util::{RangeExt as _, ResultExt as _};
|
||||
use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
|
||||
|
||||
mod cursor_excerpt;
|
||||
mod license_detection;
|
||||
pub mod mercury;
|
||||
mod onboarding_modal;
|
||||
pub mod open_ai_response;
|
||||
mod prediction;
|
||||
pub mod sweep_ai;
|
||||
pub mod udiff;
|
||||
@@ -65,6 +68,7 @@ pub mod zeta2;
|
||||
mod edit_prediction_tests;
|
||||
|
||||
use crate::license_detection::LicenseDetectionWatcher;
|
||||
use crate::mercury::Mercury;
|
||||
use crate::onboarding_modal::ZedPredictModal;
|
||||
pub use crate::prediction::EditPrediction;
|
||||
pub use crate::prediction::EditPredictionId;
|
||||
@@ -96,6 +100,12 @@ impl FeatureFlag for SweepFeatureFlag {
|
||||
const NAME: &str = "sweep-ai";
|
||||
}
|
||||
|
||||
pub struct MercuryFeatureFlag;
|
||||
|
||||
impl FeatureFlag for MercuryFeatureFlag {
|
||||
const NAME: &str = "mercury";
|
||||
}
|
||||
|
||||
pub const DEFAULT_OPTIONS: ZetaOptions = ZetaOptions {
|
||||
context: EditPredictionExcerptOptions {
|
||||
max_bytes: 512,
|
||||
@@ -157,6 +167,7 @@ pub struct EditPredictionStore {
|
||||
eval_cache: Option<Arc<dyn EvalCache>>,
|
||||
edit_prediction_model: EditPredictionModel,
|
||||
pub sweep_ai: SweepAi,
|
||||
pub mercury: Mercury,
|
||||
data_collection_choice: DataCollectionChoice,
|
||||
reject_predictions_tx: mpsc::UnboundedSender<EditPredictionRejection>,
|
||||
shown_predictions: VecDeque<EditPrediction>,
|
||||
@@ -169,6 +180,7 @@ pub enum EditPredictionModel {
|
||||
Zeta1,
|
||||
Zeta2,
|
||||
Sweep,
|
||||
Mercury,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
@@ -474,6 +486,7 @@ impl EditPredictionStore {
|
||||
eval_cache: None,
|
||||
edit_prediction_model: EditPredictionModel::Zeta2,
|
||||
sweep_ai: SweepAi::new(cx),
|
||||
mercury: Mercury::new(cx),
|
||||
data_collection_choice,
|
||||
reject_predictions_tx: reject_tx,
|
||||
rated_predictions: Default::default(),
|
||||
@@ -509,6 +522,15 @@ impl EditPredictionStore {
|
||||
.is_some()
|
||||
}
|
||||
|
||||
pub fn has_mercury_api_token(&self) -> bool {
|
||||
self.mercury
|
||||
.api_token
|
||||
.clone()
|
||||
.now_or_never()
|
||||
.flatten()
|
||||
.is_some()
|
||||
}
|
||||
|
||||
#[cfg(feature = "eval-support")]
|
||||
pub fn with_eval_cache(&mut self, cache: Arc<dyn EvalCache>) {
|
||||
self.eval_cache = Some(cache);
|
||||
@@ -868,7 +890,7 @@ impl EditPredictionStore {
|
||||
fn accept_current_prediction(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
|
||||
match self.edit_prediction_model {
|
||||
EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 => {}
|
||||
EditPredictionModel::Sweep => return,
|
||||
EditPredictionModel::Sweep | EditPredictionModel::Mercury => return,
|
||||
}
|
||||
|
||||
let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
|
||||
@@ -1013,7 +1035,7 @@ impl EditPredictionStore {
|
||||
) {
|
||||
match self.edit_prediction_model {
|
||||
EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 => {}
|
||||
EditPredictionModel::Sweep => return,
|
||||
EditPredictionModel::Sweep | EditPredictionModel::Mercury => return,
|
||||
}
|
||||
|
||||
self.reject_predictions_tx
|
||||
@@ -1373,6 +1395,17 @@ impl EditPredictionStore {
|
||||
diagnostic_search_range.clone(),
|
||||
cx,
|
||||
),
|
||||
EditPredictionModel::Mercury => self.mercury.request_prediction(
|
||||
&project,
|
||||
&active_buffer,
|
||||
snapshot.clone(),
|
||||
position,
|
||||
events,
|
||||
&project_state.recent_paths,
|
||||
related_files,
|
||||
diagnostic_search_range.clone(),
|
||||
cx,
|
||||
),
|
||||
};
|
||||
|
||||
cx.spawn(async move |this, cx| {
|
||||
|
||||
@@ -1620,7 +1620,7 @@ async fn test_no_data_collection_for_events_in_uncollectable_buffers(cx: &mut Te
|
||||
buffer.edit(
|
||||
[(
|
||||
0..0,
|
||||
" ".repeat(MAX_EVENT_TOKENS * zeta1::BYTES_PER_TOKEN_GUESS),
|
||||
" ".repeat(MAX_EVENT_TOKENS * cursor_excerpt::BYTES_PER_TOKEN_GUESS),
|
||||
)],
|
||||
None,
|
||||
cx,
|
||||
|
||||
340
crates/edit_prediction/src/mercury.rs
Normal file
340
crates/edit_prediction/src/mercury.rs
Normal file
@@ -0,0 +1,340 @@
|
||||
use anyhow::{Context as _, Result};
|
||||
use cloud_llm_client::predict_edits_v3::Event;
|
||||
use credentials_provider::CredentialsProvider;
|
||||
use edit_prediction_context::RelatedFile;
|
||||
use futures::{AsyncReadExt as _, FutureExt, future::Shared};
|
||||
use gpui::{
|
||||
App, AppContext as _, Entity, Task,
|
||||
http_client::{self, AsyncBody, Method},
|
||||
};
|
||||
use language::{Buffer, BufferSnapshot, OffsetRangeExt as _, Point, ToPoint as _};
|
||||
use project::{Project, ProjectPath};
|
||||
use std::{
|
||||
collections::VecDeque, fmt::Write as _, mem, ops::Range, path::Path, sync::Arc, time::Instant,
|
||||
};
|
||||
|
||||
use crate::{
|
||||
EditPredictionId, EditPredictionInputs, open_ai_response::text_from_response,
|
||||
prediction::EditPredictionResult,
|
||||
};
|
||||
|
||||
const MERCURY_API_URL: &str = "https://api.inceptionlabs.ai/v1/edit/completions";
|
||||
const MAX_CONTEXT_TOKENS: usize = 150;
|
||||
const MAX_REWRITE_TOKENS: usize = 350;
|
||||
|
||||
pub struct Mercury {
|
||||
pub api_token: Shared<Task<Option<String>>>,
|
||||
}
|
||||
|
||||
impl Mercury {
|
||||
pub fn new(cx: &App) -> Self {
|
||||
Mercury {
|
||||
api_token: load_api_token(cx).shared(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn set_api_token(&mut self, api_token: Option<String>, cx: &mut App) -> Task<Result<()>> {
|
||||
self.api_token = Task::ready(api_token.clone()).shared();
|
||||
store_api_token_in_keychain(api_token, cx)
|
||||
}
|
||||
|
||||
pub fn request_prediction(
|
||||
&self,
|
||||
_project: &Entity<Project>,
|
||||
active_buffer: &Entity<Buffer>,
|
||||
snapshot: BufferSnapshot,
|
||||
position: language::Anchor,
|
||||
events: Vec<Arc<Event>>,
|
||||
_recent_paths: &VecDeque<ProjectPath>,
|
||||
related_files: Vec<RelatedFile>,
|
||||
_diagnostic_search_range: Range<Point>,
|
||||
cx: &mut App,
|
||||
) -> Task<Result<Option<EditPredictionResult>>> {
|
||||
let Some(api_token) = self.api_token.clone().now_or_never().flatten() else {
|
||||
return Task::ready(Ok(None));
|
||||
};
|
||||
let full_path: Arc<Path> = snapshot
|
||||
.file()
|
||||
.map(|file| file.full_path(cx))
|
||||
.unwrap_or_else(|| "untitled".into())
|
||||
.into();
|
||||
|
||||
let http_client = cx.http_client();
|
||||
let cursor_point = position.to_point(&snapshot);
|
||||
let buffer_snapshotted_at = Instant::now();
|
||||
|
||||
let result = cx.background_spawn(async move {
|
||||
let (editable_range, context_range) =
|
||||
crate::cursor_excerpt::editable_and_context_ranges_for_cursor_position(
|
||||
cursor_point,
|
||||
&snapshot,
|
||||
MAX_CONTEXT_TOKENS,
|
||||
MAX_REWRITE_TOKENS,
|
||||
);
|
||||
|
||||
let offset_range = editable_range.to_offset(&snapshot);
|
||||
let prompt = build_prompt(
|
||||
&events,
|
||||
&related_files,
|
||||
&snapshot,
|
||||
full_path.as_ref(),
|
||||
cursor_point,
|
||||
editable_range,
|
||||
context_range.clone(),
|
||||
);
|
||||
|
||||
let inputs = EditPredictionInputs {
|
||||
events: events,
|
||||
included_files: vec![cloud_llm_client::predict_edits_v3::RelatedFile {
|
||||
path: full_path.clone(),
|
||||
max_row: cloud_llm_client::predict_edits_v3::Line(snapshot.max_point().row),
|
||||
excerpts: vec![cloud_llm_client::predict_edits_v3::Excerpt {
|
||||
start_line: cloud_llm_client::predict_edits_v3::Line(
|
||||
context_range.start.row,
|
||||
),
|
||||
text: snapshot
|
||||
.text_for_range(context_range.clone())
|
||||
.collect::<String>()
|
||||
.into(),
|
||||
}],
|
||||
}],
|
||||
cursor_point: cloud_llm_client::predict_edits_v3::Point {
|
||||
column: cursor_point.column,
|
||||
line: cloud_llm_client::predict_edits_v3::Line(cursor_point.row),
|
||||
},
|
||||
cursor_path: full_path.clone(),
|
||||
};
|
||||
|
||||
let request_body = open_ai::Request {
|
||||
model: "mercury-coder".into(),
|
||||
messages: vec![open_ai::RequestMessage::User {
|
||||
content: open_ai::MessageContent::Plain(prompt),
|
||||
}],
|
||||
stream: false,
|
||||
max_completion_tokens: None,
|
||||
stop: vec![],
|
||||
temperature: None,
|
||||
tool_choice: None,
|
||||
parallel_tool_calls: None,
|
||||
tools: vec![],
|
||||
prompt_cache_key: None,
|
||||
reasoning_effort: None,
|
||||
};
|
||||
|
||||
let buf = serde_json::to_vec(&request_body)?;
|
||||
let body: AsyncBody = buf.into();
|
||||
|
||||
let request = http_client::Request::builder()
|
||||
.uri(MERCURY_API_URL)
|
||||
.header("Content-Type", "application/json")
|
||||
.header("Authorization", format!("Bearer {}", api_token))
|
||||
.header("Connection", "keep-alive")
|
||||
.method(Method::POST)
|
||||
.body(body)
|
||||
.context("Failed to create request")?;
|
||||
|
||||
let mut response = http_client
|
||||
.send(request)
|
||||
.await
|
||||
.context("Failed to send request")?;
|
||||
|
||||
let mut body: Vec<u8> = Vec::new();
|
||||
response
|
||||
.body_mut()
|
||||
.read_to_end(&mut body)
|
||||
.await
|
||||
.context("Failed to read response body")?;
|
||||
|
||||
let response_received_at = Instant::now();
|
||||
if !response.status().is_success() {
|
||||
anyhow::bail!(
|
||||
"Request failed with status: {:?}\nBody: {}",
|
||||
response.status(),
|
||||
String::from_utf8_lossy(&body),
|
||||
);
|
||||
};
|
||||
|
||||
let mut response: open_ai::Response =
|
||||
serde_json::from_slice(&body).context("Failed to parse response")?;
|
||||
|
||||
let id = mem::take(&mut response.id);
|
||||
let response_str = text_from_response(response).unwrap_or_default();
|
||||
|
||||
let response_str = response_str.strip_prefix("```\n").unwrap_or(&response_str);
|
||||
let response_str = response_str.strip_suffix("\n```").unwrap_or(&response_str);
|
||||
|
||||
let mut edits = Vec::new();
|
||||
const NO_PREDICTION_OUTPUT: &str = "None";
|
||||
|
||||
if response_str != NO_PREDICTION_OUTPUT {
|
||||
let old_text = snapshot
|
||||
.text_for_range(offset_range.clone())
|
||||
.collect::<String>();
|
||||
edits.extend(
|
||||
language::text_diff(&old_text, &response_str)
|
||||
.into_iter()
|
||||
.map(|(range, text)| {
|
||||
(
|
||||
snapshot.anchor_after(offset_range.start + range.start)
|
||||
..snapshot.anchor_before(offset_range.start + range.end),
|
||||
text,
|
||||
)
|
||||
}),
|
||||
);
|
||||
}
|
||||
|
||||
anyhow::Ok((id, edits, snapshot, response_received_at, inputs))
|
||||
});
|
||||
|
||||
let buffer = active_buffer.clone();
|
||||
|
||||
cx.spawn(async move |cx| {
|
||||
let (id, edits, old_snapshot, response_received_at, inputs) =
|
||||
result.await.context("Mercury edit prediction failed")?;
|
||||
anyhow::Ok(Some(
|
||||
EditPredictionResult::new(
|
||||
EditPredictionId(id.into()),
|
||||
&buffer,
|
||||
&old_snapshot,
|
||||
edits.into(),
|
||||
buffer_snapshotted_at,
|
||||
response_received_at,
|
||||
inputs,
|
||||
cx,
|
||||
)
|
||||
.await,
|
||||
))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
fn build_prompt(
|
||||
events: &[Arc<Event>],
|
||||
related_files: &[RelatedFile],
|
||||
cursor_buffer: &BufferSnapshot,
|
||||
cursor_buffer_path: &Path,
|
||||
cursor_point: Point,
|
||||
editable_range: Range<Point>,
|
||||
context_range: Range<Point>,
|
||||
) -> String {
|
||||
const RECENTLY_VIEWED_SNIPPETS_START: &str = "<|recently_viewed_code_snippets|>\n";
|
||||
const RECENTLY_VIEWED_SNIPPETS_END: &str = "<|/recently_viewed_code_snippets|>\n";
|
||||
const RECENTLY_VIEWED_SNIPPET_START: &str = "<|recently_viewed_code_snippet|>\n";
|
||||
const RECENTLY_VIEWED_SNIPPET_END: &str = "<|/recently_viewed_code_snippet|>\n";
|
||||
const CURRENT_FILE_CONTENT_START: &str = "<|current_file_content|>\n";
|
||||
const CURRENT_FILE_CONTENT_END: &str = "<|/current_file_content|>\n";
|
||||
const CODE_TO_EDIT_START: &str = "<|code_to_edit|>\n";
|
||||
const CODE_TO_EDIT_END: &str = "<|/code_to_edit|>\n";
|
||||
const EDIT_DIFF_HISTORY_START: &str = "<|edit_diff_history|>\n";
|
||||
const EDIT_DIFF_HISTORY_END: &str = "<|/edit_diff_history|>\n";
|
||||
const CURSOR_TAG: &str = "<|cursor|>";
|
||||
const CODE_SNIPPET_FILE_PATH_PREFIX: &str = "code_snippet_file_path: ";
|
||||
const CURRENT_FILE_PATH_PREFIX: &str = "current_file_path: ";
|
||||
|
||||
let mut prompt = String::new();
|
||||
|
||||
push_delimited(
|
||||
&mut prompt,
|
||||
RECENTLY_VIEWED_SNIPPETS_START..RECENTLY_VIEWED_SNIPPETS_END,
|
||||
|prompt| {
|
||||
for related_file in related_files {
|
||||
for related_excerpt in &related_file.excerpts {
|
||||
push_delimited(
|
||||
prompt,
|
||||
RECENTLY_VIEWED_SNIPPET_START..RECENTLY_VIEWED_SNIPPET_END,
|
||||
|prompt| {
|
||||
prompt.push_str(CODE_SNIPPET_FILE_PATH_PREFIX);
|
||||
prompt.push_str(related_file.path.path.as_unix_str());
|
||||
prompt.push('\n');
|
||||
prompt.push_str(&related_excerpt.text.to_string());
|
||||
},
|
||||
);
|
||||
}
|
||||
}
|
||||
},
|
||||
);
|
||||
|
||||
push_delimited(
|
||||
&mut prompt,
|
||||
CURRENT_FILE_CONTENT_START..CURRENT_FILE_CONTENT_END,
|
||||
|prompt| {
|
||||
prompt.push_str(CURRENT_FILE_PATH_PREFIX);
|
||||
prompt.push_str(cursor_buffer_path.as_os_str().to_string_lossy().as_ref());
|
||||
prompt.push('\n');
|
||||
|
||||
let prefix_range = context_range.start..editable_range.start;
|
||||
let suffix_range = editable_range.end..context_range.end;
|
||||
|
||||
prompt.extend(cursor_buffer.text_for_range(prefix_range));
|
||||
push_delimited(prompt, CODE_TO_EDIT_START..CODE_TO_EDIT_END, |prompt| {
|
||||
let range_before_cursor = editable_range.start..cursor_point;
|
||||
let range_after_cursor = cursor_point..editable_range.end;
|
||||
prompt.extend(cursor_buffer.text_for_range(range_before_cursor));
|
||||
prompt.push_str(CURSOR_TAG);
|
||||
prompt.extend(cursor_buffer.text_for_range(range_after_cursor));
|
||||
});
|
||||
prompt.extend(cursor_buffer.text_for_range(suffix_range));
|
||||
},
|
||||
);
|
||||
|
||||
push_delimited(
|
||||
&mut prompt,
|
||||
EDIT_DIFF_HISTORY_START..EDIT_DIFF_HISTORY_END,
|
||||
|prompt| {
|
||||
for event in events {
|
||||
writeln!(prompt, "{event}").unwrap();
|
||||
}
|
||||
},
|
||||
);
|
||||
|
||||
prompt
|
||||
}
|
||||
|
||||
fn push_delimited(prompt: &mut String, delimiters: Range<&str>, cb: impl FnOnce(&mut String)) {
|
||||
prompt.push_str(delimiters.start);
|
||||
cb(prompt);
|
||||
prompt.push_str(delimiters.end);
|
||||
}
|
||||
|
||||
pub const MERCURY_CREDENTIALS_URL: &str = "https://api.inceptionlabs.ai/v1/edit/completions";
|
||||
pub const MERCURY_CREDENTIALS_USERNAME: &str = "mercury-api-token";
|
||||
|
||||
pub fn load_api_token(cx: &App) -> Task<Option<String>> {
|
||||
if let Some(api_token) = std::env::var("MERCURY_AI_TOKEN")
|
||||
.ok()
|
||||
.filter(|value| !value.is_empty())
|
||||
{
|
||||
return Task::ready(Some(api_token));
|
||||
}
|
||||
let credentials_provider = <dyn CredentialsProvider>::global(cx);
|
||||
cx.spawn(async move |cx| {
|
||||
let (_, credentials) = credentials_provider
|
||||
.read_credentials(MERCURY_CREDENTIALS_URL, &cx)
|
||||
.await
|
||||
.ok()??;
|
||||
String::from_utf8(credentials).ok()
|
||||
})
|
||||
}
|
||||
|
||||
fn store_api_token_in_keychain(api_token: Option<String>, cx: &App) -> Task<Result<()>> {
|
||||
let credentials_provider = <dyn CredentialsProvider>::global(cx);
|
||||
|
||||
cx.spawn(async move |cx| {
|
||||
if let Some(api_token) = api_token {
|
||||
credentials_provider
|
||||
.write_credentials(
|
||||
MERCURY_CREDENTIALS_URL,
|
||||
MERCURY_CREDENTIALS_USERNAME,
|
||||
api_token.as_bytes(),
|
||||
cx,
|
||||
)
|
||||
.await
|
||||
.context("Failed to save Mercury API token to system keychain")
|
||||
} else {
|
||||
credentials_provider
|
||||
.delete_credentials(MERCURY_CREDENTIALS_URL, cx)
|
||||
.await
|
||||
.context("Failed to delete Mercury API token from system keychain")
|
||||
}
|
||||
})
|
||||
}
|
||||
31
crates/edit_prediction/src/open_ai_response.rs
Normal file
31
crates/edit_prediction/src/open_ai_response.rs
Normal file
@@ -0,0 +1,31 @@
|
||||
pub fn text_from_response(mut res: open_ai::Response) -> Option<String> {
|
||||
let choice = res.choices.pop()?;
|
||||
let output_text = match choice.message {
|
||||
open_ai::RequestMessage::Assistant {
|
||||
content: Some(open_ai::MessageContent::Plain(content)),
|
||||
..
|
||||
} => content,
|
||||
open_ai::RequestMessage::Assistant {
|
||||
content: Some(open_ai::MessageContent::Multipart(mut content)),
|
||||
..
|
||||
} => {
|
||||
if content.is_empty() {
|
||||
log::error!("No output from Baseten completion response");
|
||||
return None;
|
||||
}
|
||||
|
||||
match content.remove(0) {
|
||||
open_ai::MessagePart::Text { text } => text,
|
||||
open_ai::MessagePart::Image { .. } => {
|
||||
log::error!("Expected text, got an image");
|
||||
return None;
|
||||
}
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
log::error!("Invalid response message: {:?}", choice.message);
|
||||
return None;
|
||||
}
|
||||
};
|
||||
Some(output_text)
|
||||
}
|
||||
@@ -1,9 +1,8 @@
|
||||
mod input_excerpt;
|
||||
|
||||
use std::{fmt::Write, ops::Range, path::Path, sync::Arc, time::Instant};
|
||||
|
||||
use crate::{
|
||||
EditPredictionId, EditPredictionStore, ZedUpdateRequiredError,
|
||||
cursor_excerpt::{editable_and_context_ranges_for_cursor_position, guess_token_count},
|
||||
prediction::{EditPredictionInputs, EditPredictionResult},
|
||||
};
|
||||
use anyhow::{Context as _, Result};
|
||||
@@ -12,7 +11,6 @@ use cloud_llm_client::{
|
||||
predict_edits_v3::Event,
|
||||
};
|
||||
use gpui::{App, AppContext as _, AsyncApp, Context, Entity, SharedString, Task};
|
||||
use input_excerpt::excerpt_for_cursor_position;
|
||||
use language::{
|
||||
Anchor, Buffer, BufferSnapshot, OffsetRangeExt as _, Point, ToPoint as _, text_diff,
|
||||
};
|
||||
@@ -495,10 +493,174 @@ pub fn format_event(event: &Event) -> String {
|
||||
}
|
||||
}
|
||||
|
||||
/// Typical number of string bytes per token for the purposes of limiting model input. This is
|
||||
/// intentionally low to err on the side of underestimating limits.
|
||||
pub(crate) const BYTES_PER_TOKEN_GUESS: usize = 3;
|
||||
|
||||
fn guess_token_count(bytes: usize) -> usize {
|
||||
bytes / BYTES_PER_TOKEN_GUESS
|
||||
#[derive(Debug)]
|
||||
pub struct InputExcerpt {
|
||||
pub context_range: Range<Point>,
|
||||
pub editable_range: Range<Point>,
|
||||
pub prompt: String,
|
||||
}
|
||||
|
||||
pub fn excerpt_for_cursor_position(
|
||||
position: Point,
|
||||
path: &str,
|
||||
snapshot: &BufferSnapshot,
|
||||
editable_region_token_limit: usize,
|
||||
context_token_limit: usize,
|
||||
) -> InputExcerpt {
|
||||
let (editable_range, context_range) = editable_and_context_ranges_for_cursor_position(
|
||||
position,
|
||||
snapshot,
|
||||
editable_region_token_limit,
|
||||
context_token_limit,
|
||||
);
|
||||
|
||||
let mut prompt = String::new();
|
||||
|
||||
writeln!(&mut prompt, "```{path}").unwrap();
|
||||
if context_range.start == Point::zero() {
|
||||
writeln!(&mut prompt, "{START_OF_FILE_MARKER}").unwrap();
|
||||
}
|
||||
|
||||
for chunk in snapshot.chunks(context_range.start..editable_range.start, false) {
|
||||
prompt.push_str(chunk.text);
|
||||
}
|
||||
|
||||
push_editable_range(position, snapshot, editable_range.clone(), &mut prompt);
|
||||
|
||||
for chunk in snapshot.chunks(editable_range.end..context_range.end, false) {
|
||||
prompt.push_str(chunk.text);
|
||||
}
|
||||
write!(prompt, "\n```").unwrap();
|
||||
|
||||
InputExcerpt {
|
||||
context_range,
|
||||
editable_range,
|
||||
prompt,
|
||||
}
|
||||
}
|
||||
|
||||
fn push_editable_range(
|
||||
cursor_position: Point,
|
||||
snapshot: &BufferSnapshot,
|
||||
editable_range: Range<Point>,
|
||||
prompt: &mut String,
|
||||
) {
|
||||
writeln!(prompt, "{EDITABLE_REGION_START_MARKER}").unwrap();
|
||||
for chunk in snapshot.chunks(editable_range.start..cursor_position, false) {
|
||||
prompt.push_str(chunk.text);
|
||||
}
|
||||
prompt.push_str(CURSOR_MARKER);
|
||||
for chunk in snapshot.chunks(cursor_position..editable_range.end, false) {
|
||||
prompt.push_str(chunk.text);
|
||||
}
|
||||
write!(prompt, "\n{EDITABLE_REGION_END_MARKER}").unwrap();
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use gpui::{App, AppContext};
|
||||
use indoc::indoc;
|
||||
use language::{Buffer, Language, LanguageConfig, LanguageMatcher, tree_sitter_rust};
|
||||
use std::sync::Arc;
|
||||
|
||||
#[gpui::test]
|
||||
fn test_excerpt_for_cursor_position(cx: &mut App) {
|
||||
let text = indoc! {r#"
|
||||
fn foo() {
|
||||
let x = 42;
|
||||
println!("Hello, world!");
|
||||
}
|
||||
|
||||
fn bar() {
|
||||
let x = 42;
|
||||
let mut sum = 0;
|
||||
for i in 0..x {
|
||||
sum += i;
|
||||
}
|
||||
println!("Sum: {}", sum);
|
||||
return sum;
|
||||
}
|
||||
|
||||
fn generate_random_numbers() -> Vec<i32> {
|
||||
let mut rng = rand::thread_rng();
|
||||
let mut numbers = Vec::new();
|
||||
for _ in 0..5 {
|
||||
numbers.push(rng.random_range(1..101));
|
||||
}
|
||||
numbers
|
||||
}
|
||||
"#};
|
||||
let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx));
|
||||
let snapshot = buffer.read(cx).snapshot();
|
||||
|
||||
// Ensure we try to fit the largest possible syntax scope, resorting to line-based expansion
|
||||
// when a larger scope doesn't fit the editable region.
|
||||
let excerpt = excerpt_for_cursor_position(Point::new(12, 5), "main.rs", &snapshot, 50, 32);
|
||||
assert_eq!(
|
||||
excerpt.prompt,
|
||||
indoc! {r#"
|
||||
```main.rs
|
||||
let x = 42;
|
||||
println!("Hello, world!");
|
||||
<|editable_region_start|>
|
||||
}
|
||||
|
||||
fn bar() {
|
||||
let x = 42;
|
||||
let mut sum = 0;
|
||||
for i in 0..x {
|
||||
sum += i;
|
||||
}
|
||||
println!("Sum: {}", sum);
|
||||
r<|user_cursor_is_here|>eturn sum;
|
||||
}
|
||||
|
||||
fn generate_random_numbers() -> Vec<i32> {
|
||||
<|editable_region_end|>
|
||||
let mut rng = rand::thread_rng();
|
||||
let mut numbers = Vec::new();
|
||||
```"#}
|
||||
);
|
||||
|
||||
// The `bar` function won't fit within the editable region, so we resort to line-based expansion.
|
||||
let excerpt = excerpt_for_cursor_position(Point::new(12, 5), "main.rs", &snapshot, 40, 32);
|
||||
assert_eq!(
|
||||
excerpt.prompt,
|
||||
indoc! {r#"
|
||||
```main.rs
|
||||
fn bar() {
|
||||
let x = 42;
|
||||
let mut sum = 0;
|
||||
<|editable_region_start|>
|
||||
for i in 0..x {
|
||||
sum += i;
|
||||
}
|
||||
println!("Sum: {}", sum);
|
||||
r<|user_cursor_is_here|>eturn sum;
|
||||
}
|
||||
|
||||
fn generate_random_numbers() -> Vec<i32> {
|
||||
let mut rng = rand::thread_rng();
|
||||
<|editable_region_end|>
|
||||
let mut numbers = Vec::new();
|
||||
for _ in 0..5 {
|
||||
numbers.push(rng.random_range(1..101));
|
||||
```"#}
|
||||
);
|
||||
}
|
||||
|
||||
fn rust_lang() -> Language {
|
||||
Language::new(
|
||||
LanguageConfig {
|
||||
name: "Rust".into(),
|
||||
matcher: LanguageMatcher {
|
||||
path_suffixes: vec!["rs".to_string()],
|
||||
..Default::default()
|
||||
},
|
||||
..Default::default()
|
||||
},
|
||||
Some(tree_sitter_rust::LANGUAGE.into()),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,231 +0,0 @@
|
||||
use super::{
|
||||
CURSOR_MARKER, EDITABLE_REGION_END_MARKER, EDITABLE_REGION_START_MARKER, START_OF_FILE_MARKER,
|
||||
guess_token_count,
|
||||
};
|
||||
use language::{BufferSnapshot, Point};
|
||||
use std::{fmt::Write, ops::Range};
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct InputExcerpt {
|
||||
pub context_range: Range<Point>,
|
||||
pub editable_range: Range<Point>,
|
||||
pub prompt: String,
|
||||
}
|
||||
|
||||
pub fn excerpt_for_cursor_position(
|
||||
position: Point,
|
||||
path: &str,
|
||||
snapshot: &BufferSnapshot,
|
||||
editable_region_token_limit: usize,
|
||||
context_token_limit: usize,
|
||||
) -> InputExcerpt {
|
||||
let mut scope_range = position..position;
|
||||
let mut remaining_edit_tokens = editable_region_token_limit;
|
||||
|
||||
while let Some(parent) = snapshot.syntax_ancestor(scope_range.clone()) {
|
||||
let parent_tokens = guess_token_count(parent.byte_range().len());
|
||||
let parent_point_range = Point::new(
|
||||
parent.start_position().row as u32,
|
||||
parent.start_position().column as u32,
|
||||
)
|
||||
..Point::new(
|
||||
parent.end_position().row as u32,
|
||||
parent.end_position().column as u32,
|
||||
);
|
||||
if parent_point_range == scope_range {
|
||||
break;
|
||||
} else if parent_tokens <= editable_region_token_limit {
|
||||
scope_range = parent_point_range;
|
||||
remaining_edit_tokens = editable_region_token_limit - parent_tokens;
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
let editable_range = expand_range(snapshot, scope_range, remaining_edit_tokens);
|
||||
let context_range = expand_range(snapshot, editable_range.clone(), context_token_limit);
|
||||
|
||||
let mut prompt = String::new();
|
||||
|
||||
writeln!(&mut prompt, "```{path}").unwrap();
|
||||
if context_range.start == Point::zero() {
|
||||
writeln!(&mut prompt, "{START_OF_FILE_MARKER}").unwrap();
|
||||
}
|
||||
|
||||
for chunk in snapshot.chunks(context_range.start..editable_range.start, false) {
|
||||
prompt.push_str(chunk.text);
|
||||
}
|
||||
|
||||
push_editable_range(position, snapshot, editable_range.clone(), &mut prompt);
|
||||
|
||||
for chunk in snapshot.chunks(editable_range.end..context_range.end, false) {
|
||||
prompt.push_str(chunk.text);
|
||||
}
|
||||
write!(prompt, "\n```").unwrap();
|
||||
|
||||
InputExcerpt {
|
||||
context_range,
|
||||
editable_range,
|
||||
prompt,
|
||||
}
|
||||
}
|
||||
|
||||
fn push_editable_range(
|
||||
cursor_position: Point,
|
||||
snapshot: &BufferSnapshot,
|
||||
editable_range: Range<Point>,
|
||||
prompt: &mut String,
|
||||
) {
|
||||
writeln!(prompt, "{EDITABLE_REGION_START_MARKER}").unwrap();
|
||||
for chunk in snapshot.chunks(editable_range.start..cursor_position, false) {
|
||||
prompt.push_str(chunk.text);
|
||||
}
|
||||
prompt.push_str(CURSOR_MARKER);
|
||||
for chunk in snapshot.chunks(cursor_position..editable_range.end, false) {
|
||||
prompt.push_str(chunk.text);
|
||||
}
|
||||
write!(prompt, "\n{EDITABLE_REGION_END_MARKER}").unwrap();
|
||||
}
|
||||
|
||||
fn expand_range(
|
||||
snapshot: &BufferSnapshot,
|
||||
range: Range<Point>,
|
||||
mut remaining_tokens: usize,
|
||||
) -> Range<Point> {
|
||||
let mut expanded_range = range;
|
||||
expanded_range.start.column = 0;
|
||||
expanded_range.end.column = snapshot.line_len(expanded_range.end.row);
|
||||
loop {
|
||||
let mut expanded = false;
|
||||
|
||||
if remaining_tokens > 0 && expanded_range.start.row > 0 {
|
||||
expanded_range.start.row -= 1;
|
||||
let line_tokens =
|
||||
guess_token_count(snapshot.line_len(expanded_range.start.row) as usize);
|
||||
remaining_tokens = remaining_tokens.saturating_sub(line_tokens);
|
||||
expanded = true;
|
||||
}
|
||||
|
||||
if remaining_tokens > 0 && expanded_range.end.row < snapshot.max_point().row {
|
||||
expanded_range.end.row += 1;
|
||||
expanded_range.end.column = snapshot.line_len(expanded_range.end.row);
|
||||
let line_tokens = guess_token_count(expanded_range.end.column as usize);
|
||||
remaining_tokens = remaining_tokens.saturating_sub(line_tokens);
|
||||
expanded = true;
|
||||
}
|
||||
|
||||
if !expanded {
|
||||
break;
|
||||
}
|
||||
}
|
||||
expanded_range
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use gpui::{App, AppContext};
|
||||
use indoc::indoc;
|
||||
use language::{Buffer, Language, LanguageConfig, LanguageMatcher, tree_sitter_rust};
|
||||
use std::sync::Arc;
|
||||
|
||||
#[gpui::test]
|
||||
fn test_excerpt_for_cursor_position(cx: &mut App) {
|
||||
let text = indoc! {r#"
|
||||
fn foo() {
|
||||
let x = 42;
|
||||
println!("Hello, world!");
|
||||
}
|
||||
|
||||
fn bar() {
|
||||
let x = 42;
|
||||
let mut sum = 0;
|
||||
for i in 0..x {
|
||||
sum += i;
|
||||
}
|
||||
println!("Sum: {}", sum);
|
||||
return sum;
|
||||
}
|
||||
|
||||
fn generate_random_numbers() -> Vec<i32> {
|
||||
let mut rng = rand::thread_rng();
|
||||
let mut numbers = Vec::new();
|
||||
for _ in 0..5 {
|
||||
numbers.push(rng.random_range(1..101));
|
||||
}
|
||||
numbers
|
||||
}
|
||||
"#};
|
||||
let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx));
|
||||
let snapshot = buffer.read(cx).snapshot();
|
||||
|
||||
// Ensure we try to fit the largest possible syntax scope, resorting to line-based expansion
|
||||
// when a larger scope doesn't fit the editable region.
|
||||
let excerpt = excerpt_for_cursor_position(Point::new(12, 5), "main.rs", &snapshot, 50, 32);
|
||||
assert_eq!(
|
||||
excerpt.prompt,
|
||||
indoc! {r#"
|
||||
```main.rs
|
||||
let x = 42;
|
||||
println!("Hello, world!");
|
||||
<|editable_region_start|>
|
||||
}
|
||||
|
||||
fn bar() {
|
||||
let x = 42;
|
||||
let mut sum = 0;
|
||||
for i in 0..x {
|
||||
sum += i;
|
||||
}
|
||||
println!("Sum: {}", sum);
|
||||
r<|user_cursor_is_here|>eturn sum;
|
||||
}
|
||||
|
||||
fn generate_random_numbers() -> Vec<i32> {
|
||||
<|editable_region_end|>
|
||||
let mut rng = rand::thread_rng();
|
||||
let mut numbers = Vec::new();
|
||||
```"#}
|
||||
);
|
||||
|
||||
// The `bar` function won't fit within the editable region, so we resort to line-based expansion.
|
||||
let excerpt = excerpt_for_cursor_position(Point::new(12, 5), "main.rs", &snapshot, 40, 32);
|
||||
assert_eq!(
|
||||
excerpt.prompt,
|
||||
indoc! {r#"
|
||||
```main.rs
|
||||
fn bar() {
|
||||
let x = 42;
|
||||
let mut sum = 0;
|
||||
<|editable_region_start|>
|
||||
for i in 0..x {
|
||||
sum += i;
|
||||
}
|
||||
println!("Sum: {}", sum);
|
||||
r<|user_cursor_is_here|>eturn sum;
|
||||
}
|
||||
|
||||
fn generate_random_numbers() -> Vec<i32> {
|
||||
let mut rng = rand::thread_rng();
|
||||
<|editable_region_end|>
|
||||
let mut numbers = Vec::new();
|
||||
for _ in 0..5 {
|
||||
numbers.push(rng.random_range(1..101));
|
||||
```"#}
|
||||
);
|
||||
}
|
||||
|
||||
fn rust_lang() -> Language {
|
||||
Language::new(
|
||||
LanguageConfig {
|
||||
name: "Rust".into(),
|
||||
matcher: LanguageMatcher {
|
||||
path_suffixes: vec!["rs".to_string()],
|
||||
..Default::default()
|
||||
},
|
||||
..Default::default()
|
||||
},
|
||||
Some(tree_sitter_rust::LANGUAGE.into()),
|
||||
)
|
||||
}
|
||||
}
|
||||
@@ -1,5 +1,6 @@
|
||||
#[cfg(feature = "eval-support")]
|
||||
use crate::EvalCacheEntryKind;
|
||||
use crate::open_ai_response::text_from_response;
|
||||
use crate::prediction::EditPredictionResult;
|
||||
use crate::{
|
||||
DebugEvent, EDIT_PREDICTIONS_MODEL_ID, EditPredictionId, EditPredictionInputs,
|
||||
@@ -199,7 +200,7 @@ pub fn request_prediction_with_zeta2(
|
||||
stream: false,
|
||||
max_completion_tokens: None,
|
||||
stop: generation_params.stop.unwrap_or_default(),
|
||||
temperature: generation_params.temperature.unwrap_or(0.7),
|
||||
temperature: generation_params.temperature.or(Some(0.7)),
|
||||
tool_choice: None,
|
||||
parallel_tool_calls: None,
|
||||
tools: vec![],
|
||||
@@ -324,35 +325,3 @@ pub fn request_prediction_with_zeta2(
|
||||
))
|
||||
})
|
||||
}
|
||||
|
||||
pub fn text_from_response(mut res: open_ai::Response) -> Option<String> {
|
||||
let choice = res.choices.pop()?;
|
||||
let output_text = match choice.message {
|
||||
open_ai::RequestMessage::Assistant {
|
||||
content: Some(open_ai::MessageContent::Plain(content)),
|
||||
..
|
||||
} => content,
|
||||
open_ai::RequestMessage::Assistant {
|
||||
content: Some(open_ai::MessageContent::Multipart(mut content)),
|
||||
..
|
||||
} => {
|
||||
if content.is_empty() {
|
||||
log::error!("No output from Baseten completion response");
|
||||
return None;
|
||||
}
|
||||
|
||||
match content.remove(0) {
|
||||
open_ai::MessagePart::Text { text } => text,
|
||||
open_ai::MessagePart::Image { .. } => {
|
||||
log::error!("Expected text, got an image");
|
||||
return None;
|
||||
}
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
log::error!("Invalid response message: {:?}", choice.message);
|
||||
return None;
|
||||
}
|
||||
};
|
||||
Some(output_text)
|
||||
}
|
||||
|
||||
@@ -198,8 +198,9 @@ pub async fn perform_predict(
|
||||
|
||||
let response =
|
||||
request.response_rx.await?.0.map_err(|err| anyhow!(err))?;
|
||||
let response = edit_prediction::zeta2::text_from_response(response)
|
||||
.unwrap_or_default();
|
||||
let response =
|
||||
edit_prediction::open_ai_response::text_from_response(response)
|
||||
.unwrap_or_default();
|
||||
let prediction_finished_at = Instant::now();
|
||||
fs::write(example_run_dir.join("prediction_response.md"), &response)?;
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@ use client::{Client, UserStore, zed_urls};
|
||||
use cloud_llm_client::UsageLimit;
|
||||
use codestral::CodestralEditPredictionDelegate;
|
||||
use copilot::{Copilot, Status};
|
||||
use edit_prediction::{SweepFeatureFlag, Zeta2FeatureFlag};
|
||||
use edit_prediction::{MercuryFeatureFlag, SweepFeatureFlag, Zeta2FeatureFlag};
|
||||
use edit_prediction_types::EditPredictionDelegateHandle;
|
||||
use editor::{
|
||||
Editor, MultiBufferOffset, SelectionEffects, actions::ShowEditPrediction, scroll::Autoscroll,
|
||||
@@ -23,6 +23,7 @@ use language::{
|
||||
use project::DisableAiSettings;
|
||||
use regex::Regex;
|
||||
use settings::{
|
||||
EXPERIMENTAL_MERCURY_EDIT_PREDICTION_PROVIDER_NAME,
|
||||
EXPERIMENTAL_SWEEP_EDIT_PREDICTION_PROVIDER_NAME,
|
||||
EXPERIMENTAL_ZETA2_EDIT_PREDICTION_PROVIDER_NAME, Settings, SettingsStore,
|
||||
update_settings_file,
|
||||
@@ -44,7 +45,7 @@ use workspace::{
|
||||
use zed_actions::OpenBrowser;
|
||||
|
||||
use crate::{
|
||||
RatePredictions, SweepApiKeyModal,
|
||||
ExternalProviderApiKeyModal, RatePredictions,
|
||||
rate_prediction_modal::PredictEditsRatePredictionsFeatureFlag,
|
||||
};
|
||||
|
||||
@@ -311,21 +312,31 @@ impl Render for EditPredictionButton {
|
||||
provider @ (EditPredictionProvider::Experimental(_) | EditPredictionProvider::Zed) => {
|
||||
let enabled = self.editor_enabled.unwrap_or(true);
|
||||
|
||||
let is_sweep = matches!(
|
||||
provider,
|
||||
let ep_icon;
|
||||
let mut missing_token = false;
|
||||
|
||||
match provider {
|
||||
EditPredictionProvider::Experimental(
|
||||
EXPERIMENTAL_SWEEP_EDIT_PREDICTION_PROVIDER_NAME
|
||||
)
|
||||
);
|
||||
|
||||
let sweep_missing_token = is_sweep
|
||||
&& !edit_prediction::EditPredictionStore::try_global(cx)
|
||||
.map_or(false, |ep_store| ep_store.read(cx).has_sweep_api_token());
|
||||
|
||||
let ep_icon = match (is_sweep, enabled) {
|
||||
(true, _) => IconName::SweepAi,
|
||||
(false, true) => IconName::ZedPredict,
|
||||
(false, false) => IconName::ZedPredictDisabled,
|
||||
EXPERIMENTAL_SWEEP_EDIT_PREDICTION_PROVIDER_NAME,
|
||||
) => {
|
||||
ep_icon = IconName::SweepAi;
|
||||
missing_token = edit_prediction::EditPredictionStore::try_global(cx)
|
||||
.is_some_and(|ep_store| !ep_store.read(cx).has_sweep_api_token());
|
||||
}
|
||||
EditPredictionProvider::Experimental(
|
||||
EXPERIMENTAL_MERCURY_EDIT_PREDICTION_PROVIDER_NAME,
|
||||
) => {
|
||||
ep_icon = IconName::Inception;
|
||||
missing_token = edit_prediction::EditPredictionStore::try_global(cx)
|
||||
.is_some_and(|ep_store| !ep_store.read(cx).has_mercury_api_token());
|
||||
}
|
||||
_ => {
|
||||
ep_icon = if enabled {
|
||||
IconName::ZedPredict
|
||||
} else {
|
||||
IconName::ZedPredictDisabled
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
if edit_prediction::should_show_upsell_modal() {
|
||||
@@ -369,7 +380,7 @@ impl Render for EditPredictionButton {
|
||||
let show_editor_predictions = self.editor_show_predictions;
|
||||
let user = self.user_store.read(cx).current_user();
|
||||
|
||||
let indicator_color = if sweep_missing_token {
|
||||
let indicator_color = if missing_token {
|
||||
Some(Color::Error)
|
||||
} else if enabled && (!show_editor_predictions || over_limit) {
|
||||
Some(if over_limit {
|
||||
@@ -532,6 +543,12 @@ impl EditPredictionButton {
|
||||
));
|
||||
}
|
||||
|
||||
if cx.has_flag::<MercuryFeatureFlag>() {
|
||||
providers.push(EditPredictionProvider::Experimental(
|
||||
EXPERIMENTAL_MERCURY_EDIT_PREDICTION_PROVIDER_NAME,
|
||||
));
|
||||
}
|
||||
|
||||
if cx.has_flag::<Zeta2FeatureFlag>() {
|
||||
providers.push(EditPredictionProvider::Experimental(
|
||||
EXPERIMENTAL_ZETA2_EDIT_PREDICTION_PROVIDER_NAME,
|
||||
@@ -628,7 +645,66 @@ impl EditPredictionButton {
|
||||
if let Some(workspace) = window.root::<Workspace>().flatten() {
|
||||
workspace.update(cx, |workspace, cx| {
|
||||
workspace.toggle_modal(window, cx, |window, cx| {
|
||||
SweepApiKeyModal::new(window, cx)
|
||||
ExternalProviderApiKeyModal::new(
|
||||
window,
|
||||
cx,
|
||||
|api_key, store, cx| {
|
||||
store
|
||||
.sweep_ai
|
||||
.set_api_token(api_key, cx)
|
||||
.detach_and_log_err(cx);
|
||||
},
|
||||
)
|
||||
});
|
||||
});
|
||||
};
|
||||
} else {
|
||||
set_completion_provider(fs.clone(), cx, provider);
|
||||
}
|
||||
});
|
||||
|
||||
menu.item(entry)
|
||||
}
|
||||
EditPredictionProvider::Experimental(
|
||||
EXPERIMENTAL_MERCURY_EDIT_PREDICTION_PROVIDER_NAME,
|
||||
) => {
|
||||
let has_api_token = edit_prediction::EditPredictionStore::try_global(cx)
|
||||
.map_or(false, |ep_store| ep_store.read(cx).has_mercury_api_token());
|
||||
|
||||
let should_open_modal = !has_api_token || is_current;
|
||||
|
||||
let entry = if has_api_token {
|
||||
ContextMenuEntry::new("Mercury")
|
||||
.toggleable(IconPosition::Start, is_current)
|
||||
} else {
|
||||
ContextMenuEntry::new("Mercury")
|
||||
.icon(IconName::XCircle)
|
||||
.icon_color(Color::Error)
|
||||
.documentation_aside(
|
||||
DocumentationSide::Left,
|
||||
DocumentationEdge::Bottom,
|
||||
|_| {
|
||||
Label::new("Click to configure your Mercury API token")
|
||||
.into_any_element()
|
||||
},
|
||||
)
|
||||
};
|
||||
|
||||
let entry = entry.handler(move |window, cx| {
|
||||
if should_open_modal {
|
||||
if let Some(workspace) = window.root::<Workspace>().flatten() {
|
||||
workspace.update(cx, |workspace, cx| {
|
||||
workspace.toggle_modal(window, cx, |window, cx| {
|
||||
ExternalProviderApiKeyModal::new(
|
||||
window,
|
||||
cx,
|
||||
|api_key, store, cx| {
|
||||
store
|
||||
.mercury
|
||||
.set_api_token(api_key, cx)
|
||||
.detach_and_log_err(cx);
|
||||
},
|
||||
)
|
||||
});
|
||||
});
|
||||
};
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
mod edit_prediction_button;
|
||||
mod edit_prediction_context_view;
|
||||
mod external_provider_api_token_modal;
|
||||
mod rate_prediction_modal;
|
||||
mod sweep_api_token_modal;
|
||||
|
||||
use std::any::{Any as _, TypeId};
|
||||
|
||||
@@ -17,7 +17,7 @@ use ui::{App, prelude::*};
|
||||
use workspace::{SplitDirection, Workspace};
|
||||
|
||||
pub use edit_prediction_button::{EditPredictionButton, ToggleMenu};
|
||||
pub use sweep_api_token_modal::SweepApiKeyModal;
|
||||
pub use external_provider_api_token_modal::ExternalProviderApiKeyModal;
|
||||
|
||||
use crate::rate_prediction_modal::PredictEditsRatePredictionsFeatureFlag;
|
||||
|
||||
|
||||
@@ -6,18 +6,24 @@ use ui::{Button, ButtonStyle, Clickable, Headline, HeadlineSize, prelude::*};
|
||||
use ui_input::InputField;
|
||||
use workspace::ModalView;
|
||||
|
||||
pub struct SweepApiKeyModal {
|
||||
pub struct ExternalProviderApiKeyModal {
|
||||
api_key_input: Entity<InputField>,
|
||||
focus_handle: FocusHandle,
|
||||
on_confirm: Box<dyn Fn(Option<String>, &mut EditPredictionStore, &mut App)>,
|
||||
}
|
||||
|
||||
impl SweepApiKeyModal {
|
||||
pub fn new(window: &mut Window, cx: &mut Context<Self>) -> Self {
|
||||
let api_key_input = cx.new(|cx| InputField::new(window, cx, "Enter your Sweep API token"));
|
||||
impl ExternalProviderApiKeyModal {
|
||||
pub fn new(
|
||||
window: &mut Window,
|
||||
cx: &mut Context<Self>,
|
||||
on_confirm: impl Fn(Option<String>, &mut EditPredictionStore, &mut App) + 'static,
|
||||
) -> Self {
|
||||
let api_key_input = cx.new(|cx| InputField::new(window, cx, "Enter your API key"));
|
||||
|
||||
Self {
|
||||
api_key_input,
|
||||
focus_handle: cx.focus_handle(),
|
||||
on_confirm: Box::new(on_confirm),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -30,39 +36,34 @@ impl SweepApiKeyModal {
|
||||
let api_key = (!api_key.trim().is_empty()).then_some(api_key);
|
||||
|
||||
if let Some(ep_store) = EditPredictionStore::try_global(cx) {
|
||||
ep_store.update(cx, |ep_store, cx| {
|
||||
ep_store
|
||||
.sweep_ai
|
||||
.set_api_token(api_key, cx)
|
||||
.detach_and_log_err(cx);
|
||||
});
|
||||
ep_store.update(cx, |ep_store, cx| (self.on_confirm)(api_key, ep_store, cx))
|
||||
}
|
||||
|
||||
cx.emit(DismissEvent);
|
||||
}
|
||||
}
|
||||
|
||||
impl EventEmitter<DismissEvent> for SweepApiKeyModal {}
|
||||
impl EventEmitter<DismissEvent> for ExternalProviderApiKeyModal {}
|
||||
|
||||
impl ModalView for SweepApiKeyModal {}
|
||||
impl ModalView for ExternalProviderApiKeyModal {}
|
||||
|
||||
impl Focusable for SweepApiKeyModal {
|
||||
impl Focusable for ExternalProviderApiKeyModal {
|
||||
fn focus_handle(&self, _cx: &App) -> FocusHandle {
|
||||
self.focus_handle.clone()
|
||||
}
|
||||
}
|
||||
|
||||
impl Render for SweepApiKeyModal {
|
||||
impl Render for ExternalProviderApiKeyModal {
|
||||
fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
|
||||
v_flex()
|
||||
.key_context("SweepApiKeyModal")
|
||||
.key_context("ExternalApiKeyModal")
|
||||
.on_action(cx.listener(Self::cancel))
|
||||
.on_action(cx.listener(Self::confirm))
|
||||
.elevation_2(cx)
|
||||
.w(px(400.))
|
||||
.p_4()
|
||||
.gap_3()
|
||||
.child(Headline::new("Sweep API Token").size(HeadlineSize::Small))
|
||||
.child(Headline::new("API Token").size(HeadlineSize::Small))
|
||||
.child(self.api_key_input.clone())
|
||||
.child(
|
||||
h_flex()
|
||||
@@ -34,8 +34,8 @@ pub enum IconName {
|
||||
ArrowRightLeft,
|
||||
ArrowUp,
|
||||
ArrowUpRight,
|
||||
Attach,
|
||||
AtSign,
|
||||
Attach,
|
||||
AudioOff,
|
||||
AudioOn,
|
||||
Backspace,
|
||||
@@ -45,8 +45,8 @@ pub enum IconName {
|
||||
BellRing,
|
||||
Binary,
|
||||
Blocks,
|
||||
BoltOutlined,
|
||||
BoltFilled,
|
||||
BoltOutlined,
|
||||
Book,
|
||||
BookCopy,
|
||||
CaseSensitive,
|
||||
@@ -80,9 +80,9 @@ pub enum IconName {
|
||||
Debug,
|
||||
DebugBreakpoint,
|
||||
DebugContinue,
|
||||
DebugDetach,
|
||||
DebugDisabledBreakpoint,
|
||||
DebugDisabledLogBreakpoint,
|
||||
DebugDetach,
|
||||
DebugIgnoreBreakpoints,
|
||||
DebugLogBreakpoint,
|
||||
DebugPause,
|
||||
@@ -140,6 +140,7 @@ pub enum IconName {
|
||||
Hash,
|
||||
HistoryRerun,
|
||||
Image,
|
||||
Inception,
|
||||
Indicator,
|
||||
Info,
|
||||
Json,
|
||||
@@ -147,6 +148,7 @@ pub enum IconName {
|
||||
Library,
|
||||
LineHeight,
|
||||
Link,
|
||||
Linux,
|
||||
ListCollapse,
|
||||
ListFilter,
|
||||
ListTodo,
|
||||
@@ -172,8 +174,8 @@ pub enum IconName {
|
||||
PencilUnavailable,
|
||||
Person,
|
||||
Pin,
|
||||
PlayOutlined,
|
||||
PlayFilled,
|
||||
PlayOutlined,
|
||||
Plus,
|
||||
Power,
|
||||
Public,
|
||||
@@ -259,15 +261,14 @@ pub enum IconName {
|
||||
ZedAssistant,
|
||||
ZedBurnMode,
|
||||
ZedBurnModeOn,
|
||||
ZedSrcCustom,
|
||||
ZedSrcExtension,
|
||||
ZedPredict,
|
||||
ZedPredictDisabled,
|
||||
ZedPredictDown,
|
||||
ZedPredictError,
|
||||
ZedPredictUp,
|
||||
ZedSrcCustom,
|
||||
ZedSrcExtension,
|
||||
ZedXCopilot,
|
||||
Linux,
|
||||
}
|
||||
|
||||
impl IconName {
|
||||
|
||||
@@ -438,7 +438,7 @@ pub fn into_open_ai(
|
||||
messages,
|
||||
stream,
|
||||
stop: request.stop,
|
||||
temperature: request.temperature.unwrap_or(1.0),
|
||||
temperature: request.temperature.or(Some(1.0)),
|
||||
max_completion_tokens: max_output_tokens,
|
||||
parallel_tool_calls: if supports_parallel_tool_calls && !request.tools.is_empty() {
|
||||
// Disable parallel tool calls, as the Agent currently expects a maximum of one per turn.
|
||||
|
||||
@@ -266,7 +266,8 @@ pub struct Request {
|
||||
pub max_completion_tokens: Option<u64>,
|
||||
#[serde(default, skip_serializing_if = "Vec::is_empty")]
|
||||
pub stop: Vec<String>,
|
||||
pub temperature: f32,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub temperature: Option<f32>,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub tool_choice: Option<ToolChoice>,
|
||||
/// Whether to enable parallel function calling during tool use.
|
||||
|
||||
@@ -81,6 +81,7 @@ pub enum EditPredictionProvider {
|
||||
|
||||
pub const EXPERIMENTAL_SWEEP_EDIT_PREDICTION_PROVIDER_NAME: &str = "sweep";
|
||||
pub const EXPERIMENTAL_ZETA2_EDIT_PREDICTION_PROVIDER_NAME: &str = "zeta2";
|
||||
pub const EXPERIMENTAL_MERCURY_EDIT_PREDICTION_PROVIDER_NAME: &str = "mercury";
|
||||
|
||||
impl<'de> Deserialize<'de> for EditPredictionProvider {
|
||||
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
|
||||
@@ -111,6 +112,13 @@ impl<'de> Deserialize<'de> for EditPredictionProvider {
|
||||
EXPERIMENTAL_SWEEP_EDIT_PREDICTION_PROVIDER_NAME,
|
||||
)
|
||||
}
|
||||
Content::Experimental(name)
|
||||
if name == EXPERIMENTAL_MERCURY_EDIT_PREDICTION_PROVIDER_NAME =>
|
||||
{
|
||||
EditPredictionProvider::Experimental(
|
||||
EXPERIMENTAL_MERCURY_EDIT_PREDICTION_PROVIDER_NAME,
|
||||
)
|
||||
}
|
||||
Content::Experimental(name)
|
||||
if name == EXPERIMENTAL_ZETA2_EDIT_PREDICTION_PROVIDER_NAME =>
|
||||
{
|
||||
|
||||
@@ -9,6 +9,7 @@ use gpui::{AnyWindowHandle, App, AppContext as _, Context, Entity, WeakEntity};
|
||||
use language::language_settings::{EditPredictionProvider, all_language_settings};
|
||||
use language_models::MistralLanguageModelProvider;
|
||||
use settings::{
|
||||
EXPERIMENTAL_MERCURY_EDIT_PREDICTION_PROVIDER_NAME,
|
||||
EXPERIMENTAL_SWEEP_EDIT_PREDICTION_PROVIDER_NAME,
|
||||
EXPERIMENTAL_ZETA2_EDIT_PREDICTION_PROVIDER_NAME, SettingsStore,
|
||||
};
|
||||
@@ -219,6 +220,10 @@ fn assign_edit_prediction_provider(
|
||||
&& cx.has_flag::<Zeta2FeatureFlag>()
|
||||
{
|
||||
edit_prediction::EditPredictionModel::Zeta2
|
||||
} else if name == EXPERIMENTAL_MERCURY_EDIT_PREDICTION_PROVIDER_NAME
|
||||
&& cx.has_flag::<Zeta2FeatureFlag>()
|
||||
{
|
||||
edit_prediction::EditPredictionModel::Mercury
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user