Compare commits

...

21 Commits

Author SHA1 Message Date
Max Brunsfeld
9a0d6622f5 Rename zeta2 crate to zeta for now 2025-11-24 17:59:15 -08:00
Max Brunsfeld
2c74ed2c1c Remove the old zeta crate 2025-11-24 17:49:58 -08:00
Max Brunsfeld
987714dac9 Remove dbg 2025-11-24 17:48:09 -08:00
Max Brunsfeld
b8525ee54d Put back feature gating of edit prediction actions 2025-11-24 17:47:31 -08:00
Max Brunsfeld
e549da6962 Resurrect zeta onboarding modal 2025-11-24 17:43:22 -08:00
Max Brunsfeld
379abe9f7e Unused deps 2025-11-24 17:38:07 -08:00
Max Brunsfeld
2d20452c3a Get zeta1's tests running 2025-11-24 17:37:14 -08:00
Max Brunsfeld
4287437714 Provide EditPredictionInputs on debug event, not cloud request 2025-11-24 17:08:49 -08:00
Max Brunsfeld
5813dee688 Clippy 2025-11-24 16:19:52 -08:00
Max Brunsfeld
4c54e7a6d3 Remove unused deps 2025-11-24 16:17:42 -08:00
Max Brunsfeld
4d79f6dbf2 Eagerly convert events into the format used for prompting 2025-11-24 15:57:41 -08:00
Max Brunsfeld
c3d1e36eba Show prediction inputs in rate predictions modal
Co-authored-by: Agus Zubiaga <agus@zed.dev>
Co-authored-by: Ben Kunkle <ben@zed.dev>
2025-11-24 15:04:41 -08:00
Max Brunsfeld
6638641d94 Implement prediction rating, display prediction inputs
Co-authored-by: Ben Kunkle <ben@zed.dev>
2025-11-24 14:22:05 -08:00
Max Brunsfeld
cfde8d1110 Merge branch 'main' into zeta1-in-zeta2 2025-11-24 12:08:47 -08:00
Max Brunsfeld
930ff3e255 Start work on integrating completion rating modal 2025-11-22 21:01:13 -08:00
Max Brunsfeld
9b6b2d7bf0 Fix crash in cancelling of completions 2025-11-22 08:46:11 -08:00
Max Brunsfeld
ca74a1fa15 Start work on removing duplicated cloud request code 2025-11-21 17:07:59 -08:00
Max Brunsfeld
a3cb30d2c2 Port outcome reporting to zeta2
Co-authored-by: Ben Kunkle <ben@zed.dev>
2025-11-21 17:07:59 -08:00
Max Brunsfeld
936b23b5eb Allow switching between zeta1 and zeta2 via settings
And the edit prediction menu. Now, the zeta2 feature flag just makes it
*possible* to use zeta 2, but does not automatically switch you.

Co-authored-by: Ben Kunkle <ben@zed.dev>
2025-11-21 14:48:54 -08:00
Max Brunsfeld
58e10147d3 Checkpoint - zeta1 in zeta2
Co-authored-by: Ben Kunkle <ben@zed.dev>
2025-11-21 14:27:49 -08:00
Agus Zubiaga
daef3aaad1 checkpoint: Start porting zeta1 to zeta2 crate
Co-authored-by: Max Brunsfeld <maxbrunsfeld@gmail.com>
Co-authored-by: Ben Kunkle <ben@zed.dev>
2025-11-21 18:36:43 -03:00
41 changed files with 5030 additions and 6225 deletions

65
Cargo.lock generated
View File

@@ -5309,7 +5309,6 @@ dependencies = [
"workspace",
"zed_actions",
"zeta",
"zeta2",
]
[[package]]
@@ -21316,7 +21315,6 @@ dependencies = [
"zed_actions",
"zed_env_vars",
"zeta",
"zeta2",
"zeta2_tools",
"zlog",
"zlog_settings",
@@ -21636,48 +21634,52 @@ dependencies = [
"ai_onboarding",
"anyhow",
"arrayvec",
"call",
"brotli",
"buffer_diff",
"client",
"clock",
"cloud_api_types",
"cloud_llm_client",
"cloud_zeta2_prompt",
"collections",
"command_palette_hooks",
"copilot",
"ctor",
"db",
"edit_prediction",
"edit_prediction_context",
"editor",
"feature_flags",
"fs",
"futures 0.3.31",
"gpui",
"http_client",
"indoc",
"itertools 0.14.0",
"language",
"language_model",
"log",
"lsp",
"markdown",
"menu",
"open_ai",
"parking_lot",
"postage",
"pretty_assertions",
"project",
"rand 0.9.2",
"regex",
"release_channel",
"reqwest_client",
"rpc",
"semver",
"serde",
"serde_json",
"settings",
"smol",
"strsim",
"strum 0.27.2",
"telemetry",
"telemetry_events",
"theme",
"thiserror 2.0.17",
"tree-sitter-go",
"tree-sitter-rust",
"ui",
"util",
"uuid",
@@ -21687,53 +21689,11 @@ dependencies = [
"zlog",
]
[[package]]
name = "zeta2"
version = "0.1.0"
dependencies = [
"anyhow",
"arrayvec",
"brotli",
"chrono",
"client",
"clock",
"cloud_llm_client",
"cloud_zeta2_prompt",
"collections",
"edit_prediction",
"edit_prediction_context",
"feature_flags",
"futures 0.3.31",
"gpui",
"indoc",
"language",
"language_model",
"log",
"lsp",
"open_ai",
"pretty_assertions",
"project",
"release_channel",
"semver",
"serde",
"serde_json",
"settings",
"smol",
"strsim",
"thiserror 2.0.17",
"util",
"uuid",
"workspace",
"worktree",
"zlog",
]
[[package]]
name = "zeta2_tools"
version = "0.1.0"
dependencies = [
"anyhow",
"chrono",
"clap",
"client",
"cloud_llm_client",
@@ -21746,9 +21706,7 @@ dependencies = [
"gpui",
"indoc",
"language",
"log",
"multi_buffer",
"ordered-float 2.10.1",
"pretty_assertions",
"project",
"serde",
@@ -21760,7 +21718,7 @@ dependencies = [
"ui_input",
"util",
"workspace",
"zeta2",
"zeta",
"zlog",
]
@@ -21810,7 +21768,6 @@ dependencies = [
"util",
"watch",
"zeta",
"zeta2",
"zlog",
]

View File

@@ -201,7 +201,6 @@ members = [
"crates/zed_actions",
"crates/zed_env_vars",
"crates/zeta",
"crates/zeta2",
"crates/zeta_cli",
"crates/zlog",
"crates/zlog_settings",
@@ -433,7 +432,6 @@ zed = { path = "crates/zed" }
zed_actions = { path = "crates/zed_actions" }
zed_env_vars = { path = "crates/zed_env_vars" }
zeta = { path = "crates/zeta" }
zeta2 = { path = "crates/zeta2" }
zlog = { path = "crates/zlog" }
zlog_settings = { path = "crates/zlog_settings" }

View File

@@ -1218,23 +1218,23 @@
}
},
{
"context": "RateCompletionModal",
"context": "RatePredictionsModal",
"use_key_equivalents": true,
"bindings": {
"cmd-shift-enter": "zeta::ThumbsUpActiveCompletion",
"cmd-shift-backspace": "zeta::ThumbsDownActiveCompletion",
"cmd-shift-enter": "zeta::ThumbsUpActivePrediction",
"cmd-shift-backspace": "zeta::ThumbsDownActivePrediction",
"shift-down": "zeta::NextEdit",
"shift-up": "zeta::PreviousEdit",
"right": "zeta::PreviewCompletion"
"right": "zeta::PreviewPrediction"
}
},
{
"context": "RateCompletionModal > Editor",
"context": "RatePredictionsModal > Editor",
"use_key_equivalents": true,
"bindings": {
"escape": "zeta::FocusCompletions",
"cmd-shift-enter": "zeta::ThumbsUpActiveCompletion",
"cmd-shift-backspace": "zeta::ThumbsDownActiveCompletion"
"escape": "zeta::FocusPredictions",
"cmd-shift-enter": "zeta::ThumbsUpActivePrediction",
"cmd-shift-backspace": "zeta::ThumbsDownActivePrediction"
}
},
{

View File

@@ -3,7 +3,7 @@ use serde::{Deserialize, Serialize};
use std::{
fmt::{Display, Write as _},
ops::{Add, Range, Sub},
path::{Path, PathBuf},
path::Path,
sync::Arc,
};
use strum::EnumIter;
@@ -17,7 +17,7 @@ pub struct PlanContextRetrievalRequest {
pub excerpt_path: Arc<Path>,
pub excerpt_line_range: Range<Line>,
pub cursor_file_max_row: Line,
pub events: Vec<Event>,
pub events: Vec<Arc<Event>>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
@@ -36,7 +36,7 @@ pub struct PredictEditsRequest {
pub signatures: Vec<Signature>,
#[serde(skip_serializing_if = "Vec::is_empty", default)]
pub referenced_declarations: Vec<ReferencedDeclaration>,
pub events: Vec<Event>,
pub events: Vec<Arc<Event>>,
#[serde(default)]
pub can_collect_data: bool,
#[serde(skip_serializing_if = "Vec::is_empty", default)]
@@ -120,10 +120,11 @@ impl std::fmt::Display for PromptFormat {
#[serde(tag = "event")]
pub enum Event {
BufferChange {
path: Option<PathBuf>,
old_path: Option<PathBuf>,
path: Arc<Path>,
old_path: Arc<Path>,
diff: String,
predicted: bool,
in_open_source_repo: bool,
},
}
@@ -135,23 +136,21 @@ impl Display for Event {
old_path,
diff,
predicted,
..
} => {
let new_path = path.as_deref().unwrap_or(Path::new("untitled"));
let old_path = old_path.as_deref().unwrap_or(new_path);
if *predicted {
write!(
f,
"// User accepted prediction:\n--- a/{}\n+++ b/{}\n{diff}",
DiffPathFmt(old_path),
DiffPathFmt(new_path)
DiffPathFmt(path)
)
} else {
write!(
f,
"--- a/{}\n+++ b/{}\n{diff}",
DiffPathFmt(old_path),
DiffPathFmt(new_path)
DiffPathFmt(path)
)
}
}
@@ -300,10 +299,11 @@ mod tests {
#[test]
fn test_event_display() {
let ev = Event::BufferChange {
path: None,
old_path: None,
path: Path::new("untitled").into(),
old_path: Path::new("untitled").into(),
diff: "@@ -1,2 +1,2 @@\n-a\n-b\n".into(),
predicted: false,
in_open_source_repo: true,
};
assert_eq!(
ev.to_string(),
@@ -317,10 +317,11 @@ mod tests {
);
let ev = Event::BufferChange {
path: Some(PathBuf::from("foo/bar.txt")),
old_path: Some(PathBuf::from("foo/bar.txt")),
path: Path::new("foo/bar.txt").into(),
old_path: Path::new("foo/bar.txt").into(),
diff: "@@ -1,2 +1,2 @@\n-a\n-b\n".into(),
predicted: false,
in_open_source_repo: true,
};
assert_eq!(
ev.to_string(),
@@ -334,10 +335,11 @@ mod tests {
);
let ev = Event::BufferChange {
path: Some(PathBuf::from("abc.txt")),
old_path: Some(PathBuf::from("123.txt")),
path: Path::new("abc.txt").into(),
old_path: Path::new("123.txt").into(),
diff: "@@ -1,2 +1,2 @@\n-a\n-b\n".into(),
predicted: false,
in_open_source_repo: true,
};
assert_eq!(
ev.to_string(),
@@ -351,10 +353,11 @@ mod tests {
);
let ev = Event::BufferChange {
path: Some(PathBuf::from("abc.txt")),
old_path: Some(PathBuf::from("123.txt")),
path: Path::new("abc.txt").into(),
old_path: Path::new("123.txt").into(),
diff: "@@ -1,2 +1,2 @@\n-a\n-b\n".into(),
predicted: true,
in_open_source_repo: true,
};
assert_eq!(
ev.to_string(),

View File

@@ -432,7 +432,7 @@ pub fn write_excerpts<'a>(
}
}
pub fn push_events(output: &mut String, events: &[predict_edits_v3::Event]) {
pub fn push_events(output: &mut String, events: &[Arc<predict_edits_v3::Event>]) {
if events.is_empty() {
return;
};
@@ -910,7 +910,7 @@ fn declaration_size(declaration: &ReferencedDeclaration, style: DeclarationStyle
}
struct PromptData {
events: Vec<Event>,
events: Vec<Arc<Event>>,
cursor_point: Point,
cursor_path: Arc<Path>, // TODO: make a common struct with cursor_point
included_files: Vec<IncludedFile>,

View File

@@ -35,7 +35,6 @@ ui.workspace = true
workspace.workspace = true
zed_actions.workspace = true
zeta.workspace = true
zeta2.workspace = true
[dev-dependencies]
copilot = { workspace = true, features = ["test-support"] }

View File

@@ -21,7 +21,9 @@ use language::{
use project::DisableAiSettings;
use regex::Regex;
use settings::{
EXPERIMENTAL_SWEEP_EDIT_PREDICTION_PROVIDER_NAME, Settings, SettingsStore, update_settings_file,
EXPERIMENTAL_SWEEP_EDIT_PREDICTION_PROVIDER_NAME,
EXPERIMENTAL_ZETA2_EDIT_PREDICTION_PROVIDER_NAME, Settings, SettingsStore,
update_settings_file,
};
use std::{
sync::{Arc, LazyLock},
@@ -38,7 +40,7 @@ use workspace::{
};
use zed_actions::OpenBrowser;
use zeta::RateCompletions;
use zeta2::SweepFeatureFlag;
use zeta::{SweepFeatureFlag, Zeta2FeatureFlag};
actions!(
edit_prediction,
@@ -300,10 +302,7 @@ impl Render for EditPredictionButton {
.with_handle(self.popover_menu_handle.clone()),
)
}
provider @ (EditPredictionProvider::Experimental(
EXPERIMENTAL_SWEEP_EDIT_PREDICTION_PROVIDER_NAME,
)
| EditPredictionProvider::Zed) => {
provider @ (EditPredictionProvider::Experimental(_) | EditPredictionProvider::Zed) => {
let enabled = self.editor_enabled.unwrap_or(true);
let is_sweep = matches!(
@@ -430,9 +429,7 @@ impl Render for EditPredictionButton {
div().child(popover_menu.into_any_element())
}
EditPredictionProvider::None | EditPredictionProvider::Experimental(_) => {
div().hidden()
}
EditPredictionProvider::None => div().hidden(),
}
}
}
@@ -497,6 +494,12 @@ impl EditPredictionButton {
));
}
if cx.has_flag::<Zeta2FeatureFlag>() {
providers.push(EditPredictionProvider::Experimental(
EXPERIMENTAL_ZETA2_EDIT_PREDICTION_PROVIDER_NAME,
));
}
providers
}
@@ -554,7 +557,7 @@ impl EditPredictionButton {
EditPredictionProvider::Experimental(
EXPERIMENTAL_SWEEP_EDIT_PREDICTION_PROVIDER_NAME,
) => {
let has_api_token = zeta2::Zeta::try_global(cx)
let has_api_token = zeta::Zeta::try_global(cx)
.map_or(false, |zeta| zeta.read(cx).has_sweep_api_token());
let entry = ContextMenuEntry::new("Sweep")
@@ -571,6 +574,11 @@ impl EditPredictionButton {
menu.item(entry)
}
EditPredictionProvider::Experimental(
EXPERIMENTAL_ZETA2_EDIT_PREDICTION_PROVIDER_NAME,
) => menu.entry("Zeta2", None, move |_, cx| {
set_completion_provider(fs.clone(), cx, provider);
}),
EditPredictionProvider::None | EditPredictionProvider::Experimental(_) => {
continue;
}

View File

@@ -13,6 +13,7 @@ use crate::{
},
task_context::RunnableRange,
text_diff::text_diff,
unified_diff,
};
pub use crate::{
Grammar, Language, LanguageRegistry,
@@ -745,6 +746,33 @@ pub struct EditPreview {
}
impl EditPreview {
pub fn as_unified_diff(&self, edits: &[(Range<Anchor>, impl AsRef<str>)]) -> Option<String> {
let (first, _) = edits.first()?;
let (last, _) = edits.last()?;
let start = first.start.to_point(&self.old_snapshot);
let old_end = last.end.to_point(&self.old_snapshot);
let new_end = last
.end
.bias_right(&self.old_snapshot)
.to_point(&self.applied_edits_snapshot);
let start = Point::new(start.row.saturating_sub(3), 0);
let old_end = Point::new(old_end.row + 3, 0).min(self.old_snapshot.max_point());
let new_end = Point::new(new_end.row + 3, 0).min(self.applied_edits_snapshot.max_point());
Some(unified_diff(
&self
.old_snapshot
.text_for_range(start..old_end)
.collect::<String>(),
&self
.applied_edits_snapshot
.text_for_range(start..new_end)
.collect::<String>(),
))
}
pub fn highlight_edits(
&self,
current_snapshot: &BufferSnapshot,
@@ -758,6 +786,8 @@ impl EditPreview {
let mut highlighted_text = HighlightedTextBuilder::default();
let visible_range_in_preview_snapshot =
visible_range_in_preview_snapshot.to_offset(&self.applied_edits_snapshot);
let mut offset_in_preview_snapshot = visible_range_in_preview_snapshot.start;
let insertion_highlight_style = HighlightStyle {
@@ -825,7 +855,19 @@ impl EditPreview {
highlighted_text.build()
}
fn compute_visible_range<T>(&self, edits: &[(Range<Anchor>, T)]) -> Option<Range<usize>> {
pub fn build_result_buffer(&self, cx: &mut App) -> Entity<Buffer> {
cx.new(|cx| {
let mut buffer = Buffer::local_normalized(
self.applied_edits_snapshot.as_rope().clone(),
self.applied_edits_snapshot.line_ending(),
cx,
);
buffer.set_language(self.syntax_snapshot.root_language(), cx);
buffer
})
}
pub fn compute_visible_range<T>(&self, edits: &[(Range<Anchor>, T)]) -> Option<Range<Point>> {
let (first, _) = edits.first()?;
let (last, _) = edits.last()?;
@@ -842,7 +884,7 @@ impl EditPreview {
let range = Point::new(start.row, 0)
..Point::new(end.row, self.applied_edits_snapshot.line_len(end.row));
Some(range.to_offset(&self.applied_edits_snapshot))
Some(range)
}
}

View File

@@ -279,6 +279,13 @@ impl SyntaxSnapshot {
self.layers.is_empty()
}
pub fn root_language(&self) -> Option<Arc<Language>> {
match &self.layers.first()?.content {
SyntaxLayerContent::Parsed { language, .. } => Some(language.clone()),
SyntaxLayerContent::Pending { .. } => None,
}
}
pub fn update_count(&self) -> usize {
self.update_count
}

View File

@@ -78,6 +78,7 @@ pub enum EditPredictionProvider {
}
pub const EXPERIMENTAL_SWEEP_EDIT_PREDICTION_PROVIDER_NAME: &str = "sweep";
pub const EXPERIMENTAL_ZETA2_EDIT_PREDICTION_PROVIDER_NAME: &str = "zeta2";
impl<'de> Deserialize<'de> for EditPredictionProvider {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
@@ -101,17 +102,25 @@ impl<'de> Deserialize<'de> for EditPredictionProvider {
Content::Supermaven => EditPredictionProvider::Supermaven,
Content::Zed => EditPredictionProvider::Zed,
Content::Codestral => EditPredictionProvider::Codestral,
Content::Experimental(name)
if name == EXPERIMENTAL_SWEEP_EDIT_PREDICTION_PROVIDER_NAME =>
{
EditPredictionProvider::Experimental(
EXPERIMENTAL_SWEEP_EDIT_PREDICTION_PROVIDER_NAME,
)
}
Content::Experimental(name)
if name == EXPERIMENTAL_ZETA2_EDIT_PREDICTION_PROVIDER_NAME =>
{
EditPredictionProvider::Experimental(
EXPERIMENTAL_ZETA2_EDIT_PREDICTION_PROVIDER_NAME,
)
}
Content::Experimental(name) => {
if name == EXPERIMENTAL_SWEEP_EDIT_PREDICTION_PROVIDER_NAME {
EditPredictionProvider::Experimental(
EXPERIMENTAL_SWEEP_EDIT_PREDICTION_PROVIDER_NAME,
)
} else {
return Err(D::Error::custom(format!(
"Unknown experimental edit prediction provider: {}",
name
)));
}
return Err(D::Error::custom(format!(
"Unknown experimental edit prediction provider: {}",
name
)));
}
})
}

View File

@@ -161,7 +161,6 @@ workspace.workspace = true
zed_actions.workspace = true
zed_env_vars.workspace = true
zeta.workspace = true
zeta2.workspace = true
zlog.workspace = true
zlog_settings.workspace = true
chrono.workspace = true

View File

@@ -7,13 +7,14 @@ use feature_flags::FeatureFlagAppExt;
use gpui::{AnyWindowHandle, App, AppContext as _, Context, Entity, WeakEntity};
use language::language_settings::{EditPredictionProvider, all_language_settings};
use language_models::MistralLanguageModelProvider;
use settings::{EXPERIMENTAL_SWEEP_EDIT_PREDICTION_PROVIDER_NAME, SettingsStore};
use settings::{
EXPERIMENTAL_SWEEP_EDIT_PREDICTION_PROVIDER_NAME,
EXPERIMENTAL_ZETA2_EDIT_PREDICTION_PROVIDER_NAME, SettingsStore,
};
use std::{cell::RefCell, rc::Rc, sync::Arc};
use supermaven::{Supermaven, SupermavenCompletionProvider};
use ui::Window;
use zeta::ZetaEditPredictionProvider;
use zeta2::SweepFeatureFlag;
use zeta2::Zeta2FeatureFlag;
use zeta::{SweepFeatureFlag, Zeta2FeatureFlag, ZetaEditPredictionProvider};
pub fn init(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut App) {
let editors: Rc<RefCell<HashMap<WeakEntity<Editor>, AnyWindowHandle>>> = Rc::default();
@@ -100,9 +101,7 @@ pub fn init(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut App) {
}
fn clear_zeta_edit_history(_: &zeta::ClearHistory, cx: &mut App) {
if let Some(zeta) = zeta::Zeta::global(cx) {
zeta.update(cx, |zeta, _| zeta.clear_history());
} else if let Some(zeta) = zeta2::Zeta::try_global(cx) {
if let Some(zeta) = zeta::Zeta::try_global(cx) {
zeta.update(cx, |zeta, _| zeta.clear_history());
}
}
@@ -204,86 +203,41 @@ fn assign_edit_prediction_provider(
editor.set_edit_prediction_provider(Some(provider), window, cx);
}
value @ (EditPredictionProvider::Experimental(_) | EditPredictionProvider::Zed) => {
let zeta2 = zeta2::Zeta::global(client, &user_store, cx);
let zeta = zeta::Zeta::global(client, &user_store, cx);
if let Some(project) = editor.project() {
let mut worktree = None;
if let Some(buffer) = &singleton_buffer
&& let Some(file) = buffer.read(cx).file()
{
let id = file.worktree_id(cx);
worktree = project.read(cx).worktree_for_id(id, cx);
}
if let EditPredictionProvider::Experimental(name) = value
&& name == EXPERIMENTAL_SWEEP_EDIT_PREDICTION_PROVIDER_NAME
&& cx.has_flag::<SweepFeatureFlag>()
{
let provider = cx.new(|cx| {
zeta2::ZetaEditPredictionProvider::new(
project.clone(),
&client,
&user_store,
cx,
)
});
if let Some(buffer) = &singleton_buffer
&& buffer.read(cx).file().is_some()
{
zeta2.update(cx, |zeta, cx| {
zeta.set_edit_prediction_model(zeta2::ZetaEditPredictionModel::Sweep);
zeta.register_buffer(buffer, project, cx);
});
}
editor.set_edit_prediction_provider(Some(provider), window, cx);
} else if user_store.read(cx).current_user().is_some() {
if cx.has_flag::<Zeta2FeatureFlag>() {
let zeta = zeta2::Zeta::global(client, &user_store, cx);
let provider = cx.new(|cx| {
zeta2::ZetaEditPredictionProvider::new(
project.clone(),
&client,
&user_store,
cx,
)
});
// TODO [zeta2] handle multibuffers
if let Some(buffer) = &singleton_buffer
&& buffer.read(cx).file().is_some()
if let Some(project) = editor.project()
&& let Some(buffer) = &singleton_buffer
&& buffer.read(cx).file().is_some()
{
let has_model = zeta.update(cx, |zeta, cx| {
let model = if let EditPredictionProvider::Experimental(name) = value {
if name == EXPERIMENTAL_SWEEP_EDIT_PREDICTION_PROVIDER_NAME
&& cx.has_flag::<SweepFeatureFlag>()
{
zeta.update(cx, |zeta, cx| {
zeta.set_edit_prediction_model(
zeta2::ZetaEditPredictionModel::ZedCloud,
);
zeta.register_buffer(buffer, project, cx);
});
zeta::ZetaEditPredictionModel::Sweep
} else if name == EXPERIMENTAL_ZETA2_EDIT_PREDICTION_PROVIDER_NAME
&& cx.has_flag::<Zeta2FeatureFlag>()
{
zeta::ZetaEditPredictionModel::Zeta2
} else {
return false;
}
editor.set_edit_prediction_provider(Some(provider), window, cx);
} else if user_store.read(cx).current_user().is_some() {
zeta::ZetaEditPredictionModel::Zeta1
} else {
let zeta = zeta::Zeta::register(worktree, client.clone(), user_store, cx);
return false;
};
if let Some(buffer) = &singleton_buffer
&& buffer.read(cx).file().is_some()
{
zeta.update(cx, |zeta, cx| {
zeta.register_buffer(buffer, project, cx);
});
}
zeta.set_edit_prediction_model(model);
zeta.register_buffer(buffer, project, cx);
true
});
let provider = cx.new(|cx| {
zeta::ZetaEditPredictionProvider::new(
zeta,
project.clone(),
singleton_buffer,
cx,
)
});
editor.set_edit_prediction_provider(Some(provider), window, cx);
}
if has_model {
let provider = cx.new(|cx| {
ZetaEditPredictionProvider::new(project.clone(), &client, &user_store, cx)
});
editor.set_edit_prediction_provider(Some(provider), window, cx);
}
}
}

View File

@@ -4,81 +4,80 @@ version = "0.1.0"
edition.workspace = true
publish.workspace = true
license = "GPL-3.0-or-later"
exclude = ["fixtures"]
[lints]
workspace = true
[lib]
path = "src/zeta.rs"
doctest = false
[features]
test-support = []
eval-support = []
[dependencies]
ai_onboarding.workspace = true
anyhow.workspace = true
arrayvec.workspace = true
brotli.workspace = true
buffer_diff.workspace = true
client.workspace = true
cloud_llm_client.workspace = true
cloud_zeta2_prompt.workspace = true
copilot.workspace = true
collections.workspace = true
command_palette_hooks.workspace = true
copilot.workspace = true
db.workspace = true
edit_prediction.workspace = true
edit_prediction_context.workspace = true
editor.workspace = true
feature_flags.workspace = true
fs.workspace = true
futures.workspace = true
gpui.workspace = true
http_client.workspace = true
indoc.workspace = true
itertools.workspace = true
language.workspace = true
language_model.workspace = true
log.workspace = true
lsp.workspace = true
markdown.workspace = true
menu.workspace = true
open_ai.workspace = true
pretty_assertions.workspace = true
postage.workspace = true
project.workspace = true
rand.workspace = true
regex.workspace = true
release_channel.workspace = true
regex.workspace = true
semver.workspace = true
serde.workspace = true
serde_json.workspace = true
settings.workspace = true
smol.workspace = true
strsim.workspace = true
strum.workspace = true
telemetry.workspace = true
telemetry_events.workspace = true
theme.workspace = true
thiserror.workspace = true
ui.workspace = true
util.workspace = true
ui.workspace = true
uuid.workspace = true
workspace.workspace = true
worktree.workspace = true
zed_actions.workspace = true
[dev-dependencies]
call = { workspace = true, features = ["test-support"] }
client = { workspace = true, features = ["test-support"] }
clock = { workspace = true, features = ["test-support"] }
cloud_api_types.workspace = true
collections = { workspace = true, features = ["test-support"] }
cloud_llm_client = { workspace = true, features = ["test-support"] }
ctor.workspace = true
editor = { workspace = true, features = ["test-support"] }
gpui = { workspace = true, features = ["test-support"] }
http_client = { workspace = true, features = ["test-support"] }
indoc.workspace = true
language = { workspace = true, features = ["test-support"] }
language_model = { workspace = true, features = ["test-support"] }
lsp.workspace = true
parking_lot.workspace = true
reqwest_client = { workspace = true, features = ["test-support"] }
rpc = { workspace = true, features = ["test-support"] }
project = { workspace = true, features = ["test-support"] }
settings = { workspace = true, features = ["test-support"] }
theme = { workspace = true, features = ["test-support"] }
tree-sitter-go.workspace = true
tree-sitter-rust.workspace = true
workspace = { workspace = true, features = ["test-support"] }
worktree = { workspace = true, features = ["test-support"] }
zlog.workspace = true

View File

@@ -1,173 +0,0 @@
use std::cmp;
use crate::EditPrediction;
use gpui::{
AnyElement, App, BorderStyle, Bounds, Corners, Edges, HighlightStyle, Hsla, StyledText,
TextLayout, TextStyle, point, prelude::*, quad, size,
};
use language::OffsetRangeExt;
use settings::Settings;
use theme::ThemeSettings;
use ui::prelude::*;
pub struct CompletionDiffElement {
element: AnyElement,
text_layout: TextLayout,
cursor_offset: usize,
}
impl CompletionDiffElement {
pub fn new(completion: &EditPrediction, cx: &App) -> Self {
let mut diff = completion
.snapshot
.text_for_range(completion.excerpt_range.clone())
.collect::<String>();
let mut cursor_offset_in_diff = None;
let mut delta = 0;
let mut diff_highlights = Vec::new();
for (old_range, new_text) in completion.edits.iter() {
let old_range = old_range.to_offset(&completion.snapshot);
if cursor_offset_in_diff.is_none() && completion.cursor_offset <= old_range.end {
cursor_offset_in_diff =
Some(completion.cursor_offset - completion.excerpt_range.start + delta);
}
let old_start_in_diff = old_range.start - completion.excerpt_range.start + delta;
let old_end_in_diff = old_range.end - completion.excerpt_range.start + delta;
if old_start_in_diff < old_end_in_diff {
diff_highlights.push((
old_start_in_diff..old_end_in_diff,
HighlightStyle {
background_color: Some(cx.theme().status().deleted_background),
strikethrough: Some(gpui::StrikethroughStyle {
thickness: px(1.),
color: Some(cx.theme().colors().text_muted),
}),
..Default::default()
},
));
}
if !new_text.is_empty() {
diff.insert_str(old_end_in_diff, new_text);
diff_highlights.push((
old_end_in_diff..old_end_in_diff + new_text.len(),
HighlightStyle {
background_color: Some(cx.theme().status().created_background),
..Default::default()
},
));
delta += new_text.len();
}
}
let cursor_offset_in_diff = cursor_offset_in_diff
.unwrap_or_else(|| completion.cursor_offset - completion.excerpt_range.start + delta);
let settings = ThemeSettings::get_global(cx).clone();
let text_style = TextStyle {
color: cx.theme().colors().editor_foreground,
font_size: settings.buffer_font_size(cx).into(),
font_family: settings.buffer_font.family,
font_features: settings.buffer_font.features,
font_fallbacks: settings.buffer_font.fallbacks,
line_height: relative(settings.buffer_line_height.value()),
font_weight: settings.buffer_font.weight,
font_style: settings.buffer_font.style,
..Default::default()
};
let element = StyledText::new(diff).with_default_highlights(&text_style, diff_highlights);
let text_layout = element.layout().clone();
CompletionDiffElement {
element: element.into_any_element(),
text_layout,
cursor_offset: cursor_offset_in_diff,
}
}
}
impl IntoElement for CompletionDiffElement {
type Element = Self;
fn into_element(self) -> Self {
self
}
}
impl Element for CompletionDiffElement {
type RequestLayoutState = ();
type PrepaintState = ();
fn id(&self) -> Option<ElementId> {
None
}
fn source_location(&self) -> Option<&'static core::panic::Location<'static>> {
None
}
fn request_layout(
&mut self,
_id: Option<&gpui::GlobalElementId>,
_inspector_id: Option<&gpui::InspectorElementId>,
window: &mut Window,
cx: &mut App,
) -> (gpui::LayoutId, Self::RequestLayoutState) {
(self.element.request_layout(window, cx), ())
}
fn prepaint(
&mut self,
_id: Option<&gpui::GlobalElementId>,
_inspector_id: Option<&gpui::InspectorElementId>,
_bounds: gpui::Bounds<Pixels>,
_request_layout: &mut Self::RequestLayoutState,
window: &mut Window,
cx: &mut App,
) -> Self::PrepaintState {
self.element.prepaint(window, cx);
}
fn paint(
&mut self,
_id: Option<&gpui::GlobalElementId>,
_inspector_id: Option<&gpui::InspectorElementId>,
_bounds: gpui::Bounds<Pixels>,
_request_layout: &mut Self::RequestLayoutState,
_prepaint: &mut Self::PrepaintState,
window: &mut Window,
cx: &mut App,
) {
if let Some(position) = self.text_layout.position_for_index(self.cursor_offset) {
let bounds = self.text_layout.bounds();
let line_height = self.text_layout.line_height();
let line_width = self
.text_layout
.line_layout_for_index(self.cursor_offset)
.map_or(bounds.size.width, |layout| layout.width());
window.paint_quad(quad(
Bounds::new(
point(bounds.origin.x, position.y),
size(cmp::max(bounds.size.width, line_width), line_height),
),
Corners::default(),
cx.theme().colors().editor_active_line_background,
Edges::default(),
Hsla::transparent_black(),
BorderStyle::default(),
));
self.element.paint(window, cx);
window.paint_quad(quad(
Bounds::new(position, size(px(2.), line_height)),
Corners::default(),
cx.theme().players().local().cursor,
Edges::default(),
Hsla::transparent_black(),
BorderStyle::default(),
));
}
}
}

View File

@@ -1,110 +0,0 @@
use std::any::{Any, TypeId};
use command_palette_hooks::CommandPaletteFilter;
use feature_flags::{FeatureFlagAppExt as _, PredictEditsRateCompletionsFeatureFlag};
use gpui::actions;
use language::language_settings::EditPredictionProvider;
use project::DisableAiSettings;
use settings::{Settings, SettingsStore, update_settings_file};
use ui::App;
use workspace::Workspace;
use crate::{RateCompletionModal, onboarding_modal::ZedPredictModal};
actions!(
edit_prediction,
[
/// Resets the edit prediction onboarding state.
ResetOnboarding,
/// Opens the rate completions modal.
RateCompletions
]
);
pub fn init(cx: &mut App) {
feature_gate_predict_edits_actions(cx);
cx.observe_new(move |workspace: &mut Workspace, _, _cx| {
workspace.register_action(|workspace, _: &RateCompletions, window, cx| {
if cx.has_flag::<PredictEditsRateCompletionsFeatureFlag>() {
RateCompletionModal::toggle(workspace, window, cx);
}
});
workspace.register_action(
move |workspace, _: &zed_actions::OpenZedPredictOnboarding, window, cx| {
ZedPredictModal::toggle(
workspace,
workspace.user_store().clone(),
workspace.client().clone(),
window,
cx,
)
},
);
workspace.register_action(|workspace, _: &ResetOnboarding, _window, cx| {
update_settings_file(workspace.app_state().fs.clone(), cx, move |settings, _| {
settings
.project
.all_languages
.features
.get_or_insert_default()
.edit_prediction_provider = Some(EditPredictionProvider::None)
});
});
})
.detach();
}
fn feature_gate_predict_edits_actions(cx: &mut App) {
let rate_completion_action_types = [TypeId::of::<RateCompletions>()];
let reset_onboarding_action_types = [TypeId::of::<ResetOnboarding>()];
let zeta_all_action_types = [
TypeId::of::<RateCompletions>(),
TypeId::of::<ResetOnboarding>(),
zed_actions::OpenZedPredictOnboarding.type_id(),
TypeId::of::<crate::ClearHistory>(),
TypeId::of::<crate::ThumbsUpActiveCompletion>(),
TypeId::of::<crate::ThumbsDownActiveCompletion>(),
TypeId::of::<crate::NextEdit>(),
TypeId::of::<crate::PreviousEdit>(),
];
CommandPaletteFilter::update_global(cx, |filter, _cx| {
filter.hide_action_types(&rate_completion_action_types);
filter.hide_action_types(&reset_onboarding_action_types);
filter.hide_action_types(&[zed_actions::OpenZedPredictOnboarding.type_id()]);
});
cx.observe_global::<SettingsStore>(move |cx| {
let is_ai_disabled = DisableAiSettings::get_global(cx).disable_ai;
let has_feature_flag = cx.has_flag::<PredictEditsRateCompletionsFeatureFlag>();
CommandPaletteFilter::update_global(cx, |filter, _cx| {
if is_ai_disabled {
filter.hide_action_types(&zeta_all_action_types);
} else if has_feature_flag {
filter.show_action_types(&rate_completion_action_types);
} else {
filter.hide_action_types(&rate_completion_action_types);
}
});
})
.detach();
cx.observe_flag::<PredictEditsRateCompletionsFeatureFlag, _>(move |is_enabled, cx| {
if !DisableAiSettings::get_global(cx).disable_ai {
if is_enabled {
CommandPaletteFilter::update_global(cx, |filter, _cx| {
filter.show_action_types(&rate_completion_action_types);
});
} else {
CommandPaletteFilter::update_global(cx, |filter, _cx| {
filter.hide_action_types(&rate_completion_action_types);
});
}
}
})
.detach();
}

View File

@@ -1,6 +1,6 @@
use std::sync::Arc;
use crate::{ZedPredictUpsell, onboarding_event};
use crate::ZedPredictUpsell;
use ai_onboarding::EditPredictionOnboarding;
use client::{Client, UserStore};
use db::kvp::Dismissable;
@@ -14,6 +14,16 @@ use settings::update_settings_file;
use ui::{Vector, VectorName, prelude::*};
use workspace::{ModalView, Workspace};
#[macro_export]
macro_rules! onboarding_event {
($name:expr) => {
telemetry::event!($name, source = "Edit Prediction Onboarding");
};
($name:expr, $($key:ident $(= $value:expr)?),+ $(,)?) => {
telemetry::event!($name, source = "Edit Prediction Onboarding", $($key $(= $value)?),+);
};
}
/// Introduces user to Zed's Edit Prediction feature
pub struct ZedPredictModal {
onboarding: Entity<EditPredictionOnboarding>,

View File

@@ -1,9 +0,0 @@
#[macro_export]
macro_rules! onboarding_event {
($name:expr) => {
telemetry::event!($name, source = "Edit Prediction Onboarding");
};
($name:expr, $($key:ident $(= $value:expr)?),+ $(,)?) => {
telemetry::event!($name, source = "Edit Prediction Onboarding", $($key $(= $value)?),+);
};
}

View File

@@ -1,7 +1,13 @@
use std::{ops::Range, sync::Arc};
use std::{
ops::Range,
path::Path,
sync::Arc,
time::{Duration, Instant},
};
use gpui::{AsyncApp, Entity, SharedString};
use language::{Anchor, Buffer, BufferSnapshot, EditPreview, OffsetRangeExt, TextBufferSnapshot};
use serde::Serialize;
#[derive(Clone, Default, Debug, PartialEq, Eq, Hash)]
pub struct EditPredictionId(pub SharedString);
@@ -26,6 +32,17 @@ pub struct EditPrediction {
pub edit_preview: EditPreview,
// We keep a reference to the buffer so that we do not need to reload it from disk when applying the prediction.
pub buffer: Entity<Buffer>,
pub buffer_snapshotted_at: Instant,
pub response_received_at: Instant,
pub inputs: EditPredictionInputs,
}
#[derive(Debug, Clone, Serialize)]
pub struct EditPredictionInputs {
pub events: Vec<Arc<cloud_llm_client::predict_edits_v3::Event>>,
pub included_files: Vec<cloud_llm_client::predict_edits_v3::IncludedFile>,
pub cursor_point: cloud_llm_client::predict_edits_v3::Point,
pub cursor_path: Arc<Path>,
}
impl EditPrediction {
@@ -33,14 +50,17 @@ impl EditPrediction {
id: EditPredictionId,
edited_buffer: &Entity<Buffer>,
edited_buffer_snapshot: &BufferSnapshot,
edits: Vec<(Range<Anchor>, Arc<str>)>,
edits: Arc<[(Range<Anchor>, Arc<str>)]>,
buffer_snapshotted_at: Instant,
response_received_at: Instant,
inputs: EditPredictionInputs,
cx: &mut AsyncApp,
) -> Option<Self> {
let (edits, snapshot, edit_preview_task) = edited_buffer
.read_with(cx, |buffer, cx| {
let new_snapshot = buffer.snapshot();
let edits: Arc<[_]> =
interpolate_edits(&edited_buffer_snapshot, &new_snapshot, edits.into())?.into();
interpolate_edits(&edited_buffer_snapshot, &new_snapshot, edits)?.into();
Some((edits.clone(), new_snapshot, buffer.preview_edits(edits, cx)))
})
@@ -53,7 +73,10 @@ impl EditPrediction {
edits,
snapshot,
edit_preview,
inputs,
buffer: edited_buffer.clone(),
buffer_snapshotted_at,
response_received_at,
})
}
@@ -67,6 +90,10 @@ impl EditPrediction {
pub fn targets_buffer(&self, buffer: &Buffer) -> bool {
self.snapshot.remote_id() == buffer.remote_id()
}
pub fn latency(&self) -> Duration {
self.response_received_at - self.buffer_snapshotted_at
}
}
impl std::fmt::Debug for EditPrediction {
@@ -147,6 +174,17 @@ mod tests {
snapshot: cx.read(|cx| buffer.read(cx).snapshot()),
buffer: buffer.clone(),
edit_preview,
inputs: EditPredictionInputs {
events: vec![],
included_files: vec![],
cursor_point: cloud_llm_client::predict_edits_v3::Point {
line: cloud_llm_client::predict_edits_v3::Line(0),
column: 0,
},
cursor_path: Path::new("path.txt").into(),
},
buffer_snapshotted_at: Instant::now(),
response_received_at: Instant::now(),
};
cx.update(|cx| {

View File

@@ -131,8 +131,14 @@ impl EditPredictionProvider for ZetaEditPredictionProvider {
}
fn discard(&mut self, cx: &mut Context<Self>) {
self.zeta.update(cx, |zeta, _cx| {
zeta.discard_current_prediction(&self.project);
self.zeta.update(cx, |zeta, cx| {
zeta.discard_current_prediction(&self.project, cx);
});
}
fn did_show(&mut self, cx: &mut Context<Self>) {
self.zeta.update(cx, |zeta, cx| {
zeta.did_show_current_prediction(&self.project, cx);
});
}
@@ -162,8 +168,8 @@ impl EditPredictionProvider for ZetaEditPredictionProvider {
let snapshot = buffer.snapshot();
let Some(edits) = prediction.interpolate(&snapshot) else {
self.zeta.update(cx, |zeta, _cx| {
zeta.discard_current_prediction(&self.project);
self.zeta.update(cx, |zeta, cx| {
zeta.discard_current_prediction(&self.project, cx);
});
return None;
};

View File

@@ -1,8 +1,18 @@
use crate::{CompletionDiffElement, EditPrediction, EditPredictionRating, Zeta};
use editor::Editor;
use gpui::{App, DismissEvent, Entity, EventEmitter, FocusHandle, Focusable, actions, prelude::*};
use language::language_settings;
use crate::{EditPrediction, EditPredictionRating, Zeta};
use buffer_diff::{BufferDiff, BufferDiffSnapshot};
use cloud_zeta2_prompt::write_codeblock;
use editor::{Editor, ExcerptRange, MultiBuffer};
use gpui::{
App, BorderStyle, DismissEvent, EdgesRefinement, Entity, EventEmitter, FocusHandle, Focusable,
Length, StyleRefinement, TextStyleRefinement, Window, actions, prelude::*,
};
use language::{LanguageRegistry, Point, language_settings};
use markdown::{Markdown, MarkdownStyle};
use settings::Settings as _;
use std::fmt::Write;
use std::sync::Arc;
use std::time::Duration;
use theme::ThemeSettings;
use ui::{KeyBinding, List, ListItem, ListItemSpacing, Tooltip, prelude::*};
use workspace::{ModalView, Workspace};
@@ -10,41 +20,44 @@ actions!(
zeta,
[
/// Rates the active completion with a thumbs up.
ThumbsUpActiveCompletion,
ThumbsUpActivePrediction,
/// Rates the active completion with a thumbs down.
ThumbsDownActiveCompletion,
ThumbsDownActivePrediction,
/// Navigates to the next edit in the completion history.
NextEdit,
/// Navigates to the previous edit in the completion history.
PreviousEdit,
/// Focuses on the completions list.
FocusCompletions,
FocusPredictions,
/// Previews the selected completion.
PreviewCompletion,
PreviewPrediction,
]
);
pub struct RateCompletionModal {
pub struct RatePredictionsModal {
zeta: Entity<Zeta>,
active_completion: Option<ActiveCompletion>,
language_registry: Arc<LanguageRegistry>,
active_prediction: Option<ActivePrediction>,
selected_index: usize,
diff_editor: Entity<Editor>,
focus_handle: FocusHandle,
_subscription: gpui::Subscription,
current_view: RateCompletionView,
current_view: RatePredictionView,
}
struct ActiveCompletion {
completion: EditPrediction,
struct ActivePrediction {
prediction: EditPrediction,
feedback_editor: Entity<Editor>,
formatted_inputs: Entity<Markdown>,
}
#[derive(Debug, Clone, Copy, PartialEq, PartialOrd)]
enum RateCompletionView {
enum RatePredictionView {
SuggestedEdits,
RawInput,
}
impl RateCompletionView {
impl RatePredictionView {
pub fn name(&self) -> &'static str {
match self {
Self::SuggestedEdits => "Suggested Edits",
@@ -53,25 +66,42 @@ impl RateCompletionView {
}
}
impl RateCompletionModal {
impl RatePredictionsModal {
pub fn toggle(workspace: &mut Workspace, window: &mut Window, cx: &mut Context<Workspace>) {
if let Some(zeta) = Zeta::global(cx) {
workspace.toggle_modal(window, cx, |_window, cx| RateCompletionModal::new(zeta, cx));
if let Some(zeta) = Zeta::try_global(cx) {
let language_registry = workspace.app_state().languages.clone();
workspace.toggle_modal(window, cx, |window, cx| {
RatePredictionsModal::new(zeta, language_registry, window, cx)
});
telemetry::event!("Rate Completion Modal Open", source = "Edit Prediction");
telemetry::event!("Rate Prediction Modal Open", source = "Edit Prediction");
}
}
pub fn new(zeta: Entity<Zeta>, cx: &mut Context<Self>) -> Self {
pub fn new(
zeta: Entity<Zeta>,
language_registry: Arc<LanguageRegistry>,
window: &mut Window,
cx: &mut Context<Self>,
) -> Self {
let subscription = cx.observe(&zeta, |_, _, cx| cx.notify());
Self {
zeta,
language_registry,
selected_index: 0,
focus_handle: cx.focus_handle(),
active_completion: None,
active_prediction: None,
_subscription: subscription,
current_view: RateCompletionView::SuggestedEdits,
diff_editor: cx.new(|cx| {
let multibuffer = cx.new(|_| MultiBuffer::new(language::Capability::ReadOnly));
let mut editor = Editor::for_multibuffer(multibuffer, None, window, cx);
editor.disable_inline_diagnostics();
editor.set_expand_all_diff_hunks(cx);
editor.set_show_git_diff_gutter(false, cx);
editor
}),
current_view: RatePredictionView::SuggestedEdits,
}
}
@@ -83,7 +113,7 @@ impl RateCompletionModal {
self.selected_index += 1;
self.selected_index = usize::min(
self.selected_index,
self.zeta.read(cx).shown_completions().count(),
self.zeta.read(cx).shown_predictions().count(),
);
cx.notify();
}
@@ -102,7 +132,7 @@ impl RateCompletionModal {
let next_index = self
.zeta
.read(cx)
.shown_completions()
.shown_predictions()
.skip(self.selected_index)
.enumerate()
.skip(1) // Skip straight to the next item
@@ -122,7 +152,7 @@ impl RateCompletionModal {
let prev_index = self
.zeta
.read(cx)
.shown_completions()
.shown_predictions()
.rev()
.skip((completions_len - 1) - self.selected_index)
.enumerate()
@@ -149,14 +179,14 @@ impl RateCompletionModal {
pub fn thumbs_up_active(
&mut self,
_: &ThumbsUpActiveCompletion,
_: &ThumbsUpActivePrediction,
window: &mut Window,
cx: &mut Context<Self>,
) {
self.zeta.update(cx, |zeta, cx| {
if let Some(active) = &self.active_completion {
zeta.rate_completion(
&active.completion,
if let Some(active) = &self.active_prediction {
zeta.rate_prediction(
&active.prediction,
EditPredictionRating::Positive,
active.feedback_editor.read(cx).text(cx),
cx,
@@ -165,9 +195,9 @@ impl RateCompletionModal {
});
let current_completion = self
.active_completion
.active_prediction
.as_ref()
.map(|completion| completion.completion.clone());
.map(|completion| completion.prediction.clone());
self.select_completion(current_completion, false, window, cx);
self.select_next_edit(&Default::default(), window, cx);
self.confirm(&Default::default(), window, cx);
@@ -177,18 +207,18 @@ impl RateCompletionModal {
pub fn thumbs_down_active(
&mut self,
_: &ThumbsDownActiveCompletion,
_: &ThumbsDownActivePrediction,
window: &mut Window,
cx: &mut Context<Self>,
) {
if let Some(active) = &self.active_completion {
if let Some(active) = &self.active_prediction {
if active.feedback_editor.read(cx).text(cx).is_empty() {
return;
}
self.zeta.update(cx, |zeta, cx| {
zeta.rate_completion(
&active.completion,
zeta.rate_prediction(
&active.prediction,
EditPredictionRating::Negative,
active.feedback_editor.read(cx).text(cx),
cx,
@@ -197,9 +227,9 @@ impl RateCompletionModal {
}
let current_completion = self
.active_completion
.active_prediction
.as_ref()
.map(|completion| completion.completion.clone());
.map(|completion| completion.prediction.clone());
self.select_completion(current_completion, false, window, cx);
self.select_next_edit(&Default::default(), window, cx);
self.confirm(&Default::default(), window, cx);
@@ -209,7 +239,7 @@ impl RateCompletionModal {
fn focus_completions(
&mut self,
_: &FocusCompletions,
_: &FocusPredictions,
window: &mut Window,
cx: &mut Context<Self>,
) {
@@ -219,14 +249,14 @@ impl RateCompletionModal {
fn preview_completion(
&mut self,
_: &PreviewCompletion,
_: &PreviewPrediction,
window: &mut Window,
cx: &mut Context<Self>,
) {
let completion = self
.zeta
.read(cx)
.shown_completions()
.shown_predictions()
.skip(self.selected_index)
.take(1)
.next()
@@ -239,7 +269,7 @@ impl RateCompletionModal {
let completion = self
.zeta
.read(cx)
.shown_completions()
.shown_predictions()
.skip(self.selected_index)
.take(1)
.next()
@@ -250,54 +280,145 @@ impl RateCompletionModal {
pub fn select_completion(
&mut self,
completion: Option<EditPrediction>,
prediction: Option<EditPrediction>,
focus: bool,
window: &mut Window,
cx: &mut Context<Self>,
) {
// Avoid resetting completion rating if it's already selected.
if let Some(completion) = completion.as_ref() {
if let Some(prediction) = prediction {
self.selected_index = self
.zeta
.read(cx)
.shown_completions()
.shown_predictions()
.enumerate()
.find(|(_, completion_b)| completion.id == completion_b.id)
.find(|(_, completion_b)| prediction.id == completion_b.id)
.map(|(ix, _)| ix)
.unwrap_or(self.selected_index);
cx.notify();
if let Some(prev_completion) = self.active_completion.as_ref()
&& completion.id == prev_completion.completion.id
if let Some(prev_prediction) = self.active_prediction.as_ref()
&& prediction.id == prev_prediction.prediction.id
{
if focus {
window.focus(&prev_completion.feedback_editor.focus_handle(cx));
window.focus(&prev_prediction.feedback_editor.focus_handle(cx));
}
return;
}
self.diff_editor.update(cx, |editor, cx| {
let new_buffer = prediction.edit_preview.build_result_buffer(cx);
let new_buffer_snapshot = new_buffer.read(cx).snapshot();
let old_buffer_snapshot = prediction.snapshot.clone();
let new_buffer_id = new_buffer_snapshot.remote_id();
let range = prediction
.edit_preview
.compute_visible_range(&prediction.edits)
.unwrap_or(Point::zero()..Point::zero());
let start = Point::new(range.start.row.saturating_sub(5), 0);
let end = Point::new(range.end.row + 5, 0).min(new_buffer_snapshot.max_point());
let diff = cx.new::<BufferDiff>(|cx| {
let diff_snapshot = BufferDiffSnapshot::new_with_base_buffer(
new_buffer_snapshot.text.clone(),
Some(old_buffer_snapshot.text().into()),
old_buffer_snapshot.clone(),
cx,
);
let diff = BufferDiff::new(&new_buffer_snapshot, cx);
cx.spawn(async move |diff, cx| {
let diff_snapshot = diff_snapshot.await;
diff.update(cx, |diff, cx| {
diff.set_snapshot(diff_snapshot, &new_buffer_snapshot.text, cx);
})
})
.detach();
diff
});
editor.disable_header_for_buffer(new_buffer_id, cx);
editor.buffer().update(cx, |multibuffer, cx| {
multibuffer.clear(cx);
multibuffer.push_excerpts(
new_buffer,
vec![ExcerptRange {
context: start..end,
primary: start..end,
}],
cx,
);
multibuffer.add_diff(diff, cx);
});
});
let mut formatted_inputs = String::new();
write!(&mut formatted_inputs, "## Events\n\n").unwrap();
for event in &prediction.inputs.events {
write!(&mut formatted_inputs, "```diff\n{event}```\n\n").unwrap();
}
write!(&mut formatted_inputs, "## Included files\n\n").unwrap();
for included_file in &prediction.inputs.included_files {
let cursor_insertions = &[(prediction.inputs.cursor_point, "<|CURSOR|>")];
write!(
&mut formatted_inputs,
"### {}\n\n",
included_file.path.display()
)
.unwrap();
write_codeblock(
&included_file.path,
&included_file.excerpts,
if included_file.path == prediction.inputs.cursor_path {
cursor_insertions
} else {
&[]
},
included_file.max_row,
false,
&mut formatted_inputs,
);
}
self.active_prediction = Some(ActivePrediction {
prediction,
feedback_editor: cx.new(|cx| {
let mut editor = Editor::multi_line(window, cx);
editor.disable_scrollbars_and_minimap(window, cx);
editor.set_soft_wrap_mode(language_settings::SoftWrap::EditorWidth, cx);
editor.set_show_line_numbers(false, cx);
editor.set_show_git_diff_gutter(false, cx);
editor.set_show_code_actions(false, cx);
editor.set_show_runnables(false, cx);
editor.set_show_breakpoints(false, cx);
editor.set_show_wrap_guides(false, cx);
editor.set_show_indent_guides(false, cx);
editor.set_show_edit_predictions(Some(false), window, cx);
editor.set_placeholder_text("Add your feedback…", window, cx);
if focus {
cx.focus_self(window);
}
editor
}),
formatted_inputs: cx.new(|cx| {
Markdown::new(
formatted_inputs.into(),
Some(self.language_registry.clone()),
None,
cx,
)
}),
});
} else {
self.active_prediction = None;
}
self.active_completion = completion.map(|completion| ActiveCompletion {
completion,
feedback_editor: cx.new(|cx| {
let mut editor = Editor::multi_line(window, cx);
editor.disable_scrollbars_and_minimap(window, cx);
editor.set_soft_wrap_mode(language_settings::SoftWrap::EditorWidth, cx);
editor.set_show_line_numbers(false, cx);
editor.set_show_git_diff_gutter(false, cx);
editor.set_show_code_actions(false, cx);
editor.set_show_runnables(false, cx);
editor.set_show_breakpoints(false, cx);
editor.set_show_wrap_guides(false, cx);
editor.set_show_indent_guides(false, cx);
editor.set_show_edit_predictions(Some(false), window, cx);
editor.set_placeholder_text("Add your feedback…", window, cx);
if focus {
cx.focus_self(window);
}
editor
}),
});
cx.notify();
}
@@ -312,33 +433,31 @@ impl RateCompletionModal {
.child(
Button::new(
ElementId::Name("suggested-edits".into()),
RateCompletionView::SuggestedEdits.name(),
RatePredictionView::SuggestedEdits.name(),
)
.label_size(LabelSize::Small)
.on_click(cx.listener(move |this, _, _window, cx| {
this.current_view = RateCompletionView::SuggestedEdits;
this.current_view = RatePredictionView::SuggestedEdits;
cx.notify();
}))
.toggle_state(self.current_view == RateCompletionView::SuggestedEdits),
.toggle_state(self.current_view == RatePredictionView::SuggestedEdits),
)
.child(
Button::new(
ElementId::Name("raw-input".into()),
RateCompletionView::RawInput.name(),
RatePredictionView::RawInput.name(),
)
.label_size(LabelSize::Small)
.on_click(cx.listener(move |this, _, _window, cx| {
this.current_view = RateCompletionView::RawInput;
this.current_view = RatePredictionView::RawInput;
cx.notify();
}))
.toggle_state(self.current_view == RateCompletionView::RawInput),
.toggle_state(self.current_view == RatePredictionView::RawInput),
)
}
fn render_suggested_edits(&self, cx: &mut Context<Self>) -> Option<gpui::Stateful<Div>> {
let active_completion = self.active_completion.as_ref()?;
let bg_color = cx.theme().colors().editor_background;
Some(
div()
.id("diff")
@@ -347,14 +466,18 @@ impl RateCompletionModal {
.bg(bg_color)
.overflow_scroll()
.whitespace_nowrap()
.child(CompletionDiffElement::new(
&active_completion.completion,
cx,
)),
.child(self.diff_editor.clone()),
)
}
fn render_raw_input(&self, cx: &mut Context<Self>) -> Option<gpui::Stateful<Div>> {
fn render_raw_input(
&self,
window: &mut Window,
cx: &mut Context<Self>,
) -> Option<gpui::Stateful<Div>> {
let theme_settings = ThemeSettings::get_global(cx);
let buffer_font_size = theme_settings.buffer_font_size(cx);
Some(
v_flex()
.size_full()
@@ -368,30 +491,81 @@ impl RateCompletionModal {
.size_full()
.bg(cx.theme().colors().editor_background)
.overflow_scroll()
.child(if let Some(active_completion) = &self.active_completion {
format!(
"{}\n{}",
active_completion.completion.input_events,
active_completion.completion.input_excerpt
.child(if let Some(active_prediction) = &self.active_prediction {
markdown::MarkdownElement::new(
active_prediction.formatted_inputs.clone(),
MarkdownStyle {
base_text_style: window.text_style(),
syntax: cx.theme().syntax().clone(),
code_block: StyleRefinement {
text: Some(TextStyleRefinement {
font_family: Some(
theme_settings.buffer_font.family.clone(),
),
font_size: Some(buffer_font_size.into()),
..Default::default()
}),
padding: EdgesRefinement {
top: Some(DefiniteLength::Absolute(
AbsoluteLength::Pixels(px(8.)),
)),
left: Some(DefiniteLength::Absolute(
AbsoluteLength::Pixels(px(8.)),
)),
right: Some(DefiniteLength::Absolute(
AbsoluteLength::Pixels(px(8.)),
)),
bottom: Some(DefiniteLength::Absolute(
AbsoluteLength::Pixels(px(8.)),
)),
},
margin: EdgesRefinement {
top: Some(Length::Definite(px(8.).into())),
left: Some(Length::Definite(px(0.).into())),
right: Some(Length::Definite(px(0.).into())),
bottom: Some(Length::Definite(px(12.).into())),
},
border_style: Some(BorderStyle::Solid),
border_widths: EdgesRefinement {
top: Some(AbsoluteLength::Pixels(px(1.))),
left: Some(AbsoluteLength::Pixels(px(1.))),
right: Some(AbsoluteLength::Pixels(px(1.))),
bottom: Some(AbsoluteLength::Pixels(px(1.))),
},
border_color: Some(cx.theme().colors().border_variant),
background: Some(
cx.theme().colors().editor_background.into(),
),
..Default::default()
},
..Default::default()
},
)
.into_any_element()
} else {
"No active completion".to_string()
div()
.child("No active completion".to_string())
.into_any_element()
}),
)
.id("raw-input-view"),
)
}
fn render_active_completion(&mut self, cx: &mut Context<Self>) -> Option<impl IntoElement> {
let active_completion = self.active_completion.as_ref()?;
let completion_id = active_completion.completion.id;
fn render_active_completion(
&mut self,
window: &mut Window,
cx: &mut Context<Self>,
) -> Option<impl IntoElement> {
let active_prediction = self.active_prediction.as_ref()?;
let completion_id = active_prediction.prediction.id.clone();
let focus_handle = &self.focus_handle(cx);
let border_color = cx.theme().colors().border;
let bg_color = cx.theme().colors().editor_background;
let rated = self.zeta.read(cx).is_completion_rated(completion_id);
let feedback_empty = active_completion
let rated = self.zeta.read(cx).is_prediction_rated(&completion_id);
let feedback_empty = active_prediction
.feedback_editor
.read(cx)
.text(cx)
@@ -412,10 +586,10 @@ impl RateCompletionModal {
.child(self.render_view_nav(cx))
.when_some(
match self.current_view {
RateCompletionView::SuggestedEdits => {
RatePredictionView::SuggestedEdits => {
self.render_suggested_edits(cx)
}
RateCompletionView::RawInput => self.render_raw_input(cx),
RatePredictionView::RawInput => self.render_raw_input(window, cx),
},
|this, element| this.child(element),
),
@@ -450,7 +624,7 @@ impl RateCompletionModal {
.h_40()
.pt_1()
.bg(bg_color)
.child(active_completion.feedback_editor.clone()),
.child(active_prediction.feedback_editor.clone()),
)
})
.child(
@@ -472,7 +646,7 @@ impl RateCompletionModal {
)
.child(Label::new("Rated completion.").color(Color::Muted)),
)
} else if active_completion.completion.edits.is_empty() {
} else if active_prediction.prediction.edits.is_empty() {
Some(
label_container
.child(
@@ -489,7 +663,7 @@ impl RateCompletionModal {
h_flex()
.gap_1()
.child(
Button::new("bad", "Bad Completion")
Button::new("bad", "Bad Prediction")
.icon(IconName::ThumbsDown)
.icon_size(IconSize::Small)
.icon_position(IconPosition::Start)
@@ -500,14 +674,14 @@ impl RateCompletionModal {
))
})
.key_binding(KeyBinding::for_action_in(
&ThumbsDownActiveCompletion,
&ThumbsDownActivePrediction,
focus_handle,
cx,
))
.on_click(cx.listener(move |this, _, window, cx| {
if this.active_completion.is_some() {
if this.active_prediction.is_some() {
this.thumbs_down_active(
&ThumbsDownActiveCompletion,
&ThumbsDownActivePrediction,
window,
cx,
);
@@ -515,20 +689,20 @@ impl RateCompletionModal {
})),
)
.child(
Button::new("good", "Good Completion")
Button::new("good", "Good Prediction")
.icon(IconName::ThumbsUp)
.icon_size(IconSize::Small)
.icon_position(IconPosition::Start)
.disabled(rated)
.key_binding(KeyBinding::for_action_in(
&ThumbsUpActiveCompletion,
&ThumbsUpActivePrediction,
focus_handle,
cx,
))
.on_click(cx.listener(move |this, _, window, cx| {
if this.active_completion.is_some() {
if this.active_prediction.is_some() {
this.thumbs_up_active(
&ThumbsUpActiveCompletion,
&ThumbsUpActivePrediction,
window,
cx,
);
@@ -543,34 +717,32 @@ impl RateCompletionModal {
fn render_shown_completions(&self, cx: &Context<Self>) -> impl Iterator<Item = ListItem> {
self.zeta
.read(cx)
.shown_completions()
.shown_predictions()
.cloned()
.enumerate()
.map(|(index, completion)| {
let selected = self
.active_completion
.active_prediction
.as_ref()
.is_some_and(|selected| selected.completion.id == completion.id);
let rated = self.zeta.read(cx).is_completion_rated(completion.id);
.is_some_and(|selected| selected.prediction.id == completion.id);
let rated = self.zeta.read(cx).is_prediction_rated(&completion.id);
let (icon_name, icon_color, tooltip_text) =
match (rated, completion.edits.is_empty()) {
(true, _) => (IconName::Check, Color::Success, "Rated Completion"),
(true, _) => (IconName::Check, Color::Success, "Rated Prediction"),
(false, true) => (IconName::File, Color::Muted, "No Edits Produced"),
(false, false) => (IconName::FileDiff, Color::Accent, "Edits Available"),
};
let file_name = completion
.path
.file_name()
.map(|f| f.to_string_lossy().into_owned())
.unwrap_or("untitled".to_string());
let file_path = completion
.path
.parent()
.map(|p| p.to_string_lossy().into_owned());
let file = completion.buffer.read(cx).file();
let file_name = file
.as_ref()
.map_or(SharedString::new_static("untitled"), |file| {
file.file_name(cx).to_string().into()
});
let file_path = file.map(|file| file.path().as_unix_str().to_string());
ListItem::new(completion.id)
ListItem::new(completion.id.clone())
.inset(true)
.spacing(ListItemSpacing::Sparse)
.focused(index == self.selected_index)
@@ -615,12 +787,12 @@ impl RateCompletionModal {
}
}
impl Render for RateCompletionModal {
impl Render for RatePredictionsModal {
fn render(&mut self, window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
let border_color = cx.theme().colors().border;
h_flex()
.key_context("RateCompletionModal")
.key_context("RatePredictionModal")
.track_focus(&self.focus_handle)
.on_action(cx.listener(Self::dismiss))
.on_action(cx.listener(Self::confirm))
@@ -688,20 +860,20 @@ impl Render for RateCompletionModal {
),
),
)
.children(self.render_active_completion(cx))
.children(self.render_active_completion(window, cx))
.on_mouse_down_out(cx.listener(|_, _, _, cx| cx.emit(DismissEvent)))
}
}
impl EventEmitter<DismissEvent> for RateCompletionModal {}
impl EventEmitter<DismissEvent> for RatePredictionsModal {}
impl Focusable for RateCompletionModal {
impl Focusable for RatePredictionsModal {
fn focus_handle(&self, _cx: &App) -> FocusHandle {
self.focus_handle.clone()
}
}
impl ModalView for RateCompletionModal {}
impl ModalView for RatePredictionsModal {}
fn format_time_ago(elapsed: Duration) -> String {
let seconds = elapsed.as_secs();

View File

@@ -2,7 +2,6 @@ use std::fmt;
use std::{path::Path, sync::Arc};
use serde::{Deserialize, Serialize};
use util::rel_path::RelPath;
#[derive(Debug, Clone, Serialize)]
pub struct AutocompleteRequest {
@@ -91,34 +90,24 @@ pub struct AdditionalCompletion {
pub finish_reason: Option<String>,
}
pub(crate) fn write_event(event: crate::Event, f: &mut impl fmt::Write) -> fmt::Result {
pub(crate) fn write_event(
event: &cloud_llm_client::predict_edits_v3::Event,
f: &mut impl fmt::Write,
) -> fmt::Result {
match event {
crate::Event::BufferChange {
old_snapshot,
new_snapshot,
cloud_llm_client::predict_edits_v3::Event::BufferChange {
old_path,
path,
diff,
..
} => {
let old_path = old_snapshot
.file()
.map(|f| f.path().as_ref())
.unwrap_or(RelPath::unix("untitled").unwrap());
let new_path = new_snapshot
.file()
.map(|f| f.path().as_ref())
.unwrap_or(RelPath::unix("untitled").unwrap());
if old_path != new_path {
if old_path != path {
// TODO confirm how to do this for sweep
// writeln!(f, "User renamed {:?} to {:?}\n", old_path, new_path)?;
}
let diff = language::unified_diff(&old_snapshot.text(), &new_snapshot.text());
if !diff.is_empty() {
write!(
f,
"File: {}:\n{}\n",
new_path.display(util::paths::PathStyle::Posix),
diff
)?
write!(f, "File: {}:\n{}\n", path.display(), diff)?
}
fmt::Result::Ok(())

File diff suppressed because it is too large Load Diff

500
crates/zeta/src/zeta1.rs Normal file
View File

@@ -0,0 +1,500 @@
mod input_excerpt;
use std::{fmt::Write, ops::Range, path::Path, sync::Arc, time::Instant};
use crate::{
EditPredictionId, ZedUpdateRequiredError, Zeta,
prediction::{EditPrediction, EditPredictionInputs},
};
use anyhow::{Context as _, Result};
use cloud_llm_client::{
PredictEditsBody, PredictEditsGitInfo, PredictEditsResponse, predict_edits_v3::Event,
};
use gpui::{App, AppContext as _, AsyncApp, Context, Entity, SharedString, Task};
use input_excerpt::excerpt_for_cursor_position;
use language::{
Anchor, Buffer, BufferSnapshot, OffsetRangeExt as _, Point, ToPoint as _, text_diff,
};
use project::{Project, ProjectPath};
use release_channel::AppVersion;
use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
const CURSOR_MARKER: &str = "<|user_cursor_is_here|>";
const START_OF_FILE_MARKER: &str = "<|start_of_file|>";
const EDITABLE_REGION_START_MARKER: &str = "<|editable_region_start|>";
const EDITABLE_REGION_END_MARKER: &str = "<|editable_region_end|>";
pub(crate) const MAX_CONTEXT_TOKENS: usize = 150;
pub(crate) const MAX_REWRITE_TOKENS: usize = 350;
pub(crate) const MAX_EVENT_TOKENS: usize = 500;
pub(crate) fn request_prediction_with_zeta1(
zeta: &mut Zeta,
project: &Entity<Project>,
buffer: &Entity<Buffer>,
position: language::Anchor,
cx: &mut Context<Zeta>,
) -> Task<Result<Option<EditPrediction>>> {
let buffer = buffer.clone();
let buffer_snapshotted_at = Instant::now();
let snapshot = buffer.read(cx).snapshot();
let client = zeta.client.clone();
let llm_token = zeta.llm_token.clone();
let app_version = AppVersion::global(cx);
let zeta_project = zeta.get_or_init_zeta_project(project, cx);
let events = Arc::new(zeta_project.events(cx));
let (git_info, can_collect_file) = if let Some(file) = snapshot.file() {
let can_collect_file = zeta.can_collect_file(project, file, cx);
let git_info = if can_collect_file {
git_info_for_file(project, &ProjectPath::from_file(file.as_ref(), cx), cx)
} else {
None
};
(git_info, can_collect_file)
} else {
(None, false)
};
let full_path: Arc<Path> = snapshot
.file()
.map(|f| Arc::from(f.full_path(cx).as_path()))
.unwrap_or_else(|| Arc::from(Path::new("untitled")));
let full_path_str = full_path.to_string_lossy().into_owned();
let cursor_point = position.to_point(&snapshot);
let prompt_for_events = {
let events = events.clone();
move || prompt_for_events_impl(&events, MAX_EVENT_TOKENS)
};
let gather_task = gather_context(
full_path_str,
&snapshot,
cursor_point,
prompt_for_events,
cx,
);
cx.spawn(async move |this, cx| {
let GatherContextOutput {
mut body,
context_range,
editable_range,
included_events_count,
} = gather_task.await?;
let done_gathering_context_at = Instant::now();
let included_events = &events[events.len() - included_events_count..events.len()];
body.can_collect_data = can_collect_file
&& this
.read_with(cx, |this, _| this.can_collect_events(included_events))
.unwrap_or(false);
if body.can_collect_data {
body.git_info = git_info;
}
log::debug!(
"Events:\n{}\nExcerpt:\n{:?}",
body.input_events,
body.input_excerpt
);
let http_client = client.http_client();
let response = Zeta::send_api_request::<PredictEditsResponse>(
|request| {
let uri = if let Ok(predict_edits_url) = std::env::var("ZED_PREDICT_EDITS_URL") {
predict_edits_url
} else {
http_client
.build_zed_llm_url("/predict_edits/v2", &[])?
.as_str()
.into()
};
Ok(request
.uri(uri)
.body(serde_json::to_string(&body)?.into())?)
},
client,
llm_token,
app_version,
)
.await;
let inputs = EditPredictionInputs {
events: included_events.into(),
included_files: vec![cloud_llm_client::predict_edits_v3::IncludedFile {
path: full_path.clone(),
max_row: cloud_llm_client::predict_edits_v3::Line(snapshot.max_point().row),
excerpts: vec![cloud_llm_client::predict_edits_v3::Excerpt {
start_line: cloud_llm_client::predict_edits_v3::Line(context_range.start.row),
text: snapshot
.text_for_range(context_range)
.collect::<String>()
.into(),
}],
}],
cursor_point: cloud_llm_client::predict_edits_v3::Point {
column: cursor_point.column,
line: cloud_llm_client::predict_edits_v3::Line(cursor_point.row),
},
cursor_path: full_path,
};
// let response = perform_predict_edits(PerformPredictEditsParams {
// client,
// llm_token,
// app_version,
// body,
// })
// .await;
let (response, usage) = match response {
Ok(response) => response,
Err(err) => {
if err.is::<ZedUpdateRequiredError>() {
cx.update(|cx| {
this.update(cx, |zeta, _cx| {
zeta.update_required = true;
})
.ok();
let error_message: SharedString = err.to_string().into();
show_app_notification(
NotificationId::unique::<ZedUpdateRequiredError>(),
cx,
move |cx| {
cx.new(|cx| {
ErrorMessagePrompt::new(error_message.clone(), cx)
.with_link_button("Update Zed", "https://zed.dev/releases")
})
},
);
})
.ok();
}
return Err(err);
}
};
let received_response_at = Instant::now();
log::debug!("completion response: {}", &response.output_excerpt);
if let Some(usage) = usage {
this.update(cx, |this, cx| {
this.user_store.update(cx, |user_store, cx| {
user_store.update_edit_prediction_usage(usage, cx);
});
})
.ok();
}
let edit_prediction = process_completion_response(
response,
buffer,
&snapshot,
editable_range,
inputs,
buffer_snapshotted_at,
received_response_at,
cx,
)
.await;
let finished_at = Instant::now();
// record latency for ~1% of requests
if rand::random::<u8>() <= 2 {
telemetry::event!(
"Edit Prediction Request",
context_latency = done_gathering_context_at
.duration_since(buffer_snapshotted_at)
.as_millis(),
request_latency = received_response_at
.duration_since(done_gathering_context_at)
.as_millis(),
process_latency = finished_at.duration_since(received_response_at).as_millis()
);
}
edit_prediction
})
}
fn process_completion_response(
prediction_response: PredictEditsResponse,
buffer: Entity<Buffer>,
snapshot: &BufferSnapshot,
editable_range: Range<usize>,
inputs: EditPredictionInputs,
buffer_snapshotted_at: Instant,
received_response_at: Instant,
cx: &AsyncApp,
) -> Task<Result<Option<EditPrediction>>> {
let snapshot = snapshot.clone();
let request_id = prediction_response.request_id;
let output_excerpt = prediction_response.output_excerpt;
cx.spawn(async move |cx| {
let output_excerpt: Arc<str> = output_excerpt.into();
let edits: Arc<[(Range<Anchor>, Arc<str>)]> = cx
.background_spawn({
let output_excerpt = output_excerpt.clone();
let editable_range = editable_range.clone();
let snapshot = snapshot.clone();
async move { parse_edits(output_excerpt, editable_range, &snapshot) }
})
.await?
.into();
Ok(EditPrediction::new(
EditPredictionId(request_id.into()),
&buffer,
&snapshot,
edits,
buffer_snapshotted_at,
received_response_at,
inputs,
cx,
)
.await)
})
}
fn parse_edits(
output_excerpt: Arc<str>,
editable_range: Range<usize>,
snapshot: &BufferSnapshot,
) -> Result<Vec<(Range<Anchor>, Arc<str>)>> {
let content = output_excerpt.replace(CURSOR_MARKER, "");
let start_markers = content
.match_indices(EDITABLE_REGION_START_MARKER)
.collect::<Vec<_>>();
anyhow::ensure!(
start_markers.len() == 1,
"expected exactly one start marker, found {}",
start_markers.len()
);
let end_markers = content
.match_indices(EDITABLE_REGION_END_MARKER)
.collect::<Vec<_>>();
anyhow::ensure!(
end_markers.len() == 1,
"expected exactly one end marker, found {}",
end_markers.len()
);
let sof_markers = content
.match_indices(START_OF_FILE_MARKER)
.collect::<Vec<_>>();
anyhow::ensure!(
sof_markers.len() <= 1,
"expected at most one start-of-file marker, found {}",
sof_markers.len()
);
let codefence_start = start_markers[0].0;
let content = &content[codefence_start..];
let newline_ix = content.find('\n').context("could not find newline")?;
let content = &content[newline_ix + 1..];
let codefence_end = content
.rfind(&format!("\n{EDITABLE_REGION_END_MARKER}"))
.context("could not find end marker")?;
let new_text = &content[..codefence_end];
let old_text = snapshot
.text_for_range(editable_range.clone())
.collect::<String>();
Ok(compute_edits(
old_text,
new_text,
editable_range.start,
snapshot,
))
}
pub fn compute_edits(
old_text: String,
new_text: &str,
offset: usize,
snapshot: &BufferSnapshot,
) -> Vec<(Range<Anchor>, Arc<str>)> {
text_diff(&old_text, new_text)
.into_iter()
.map(|(mut old_range, new_text)| {
old_range.start += offset;
old_range.end += offset;
let prefix_len = common_prefix(
snapshot.chars_for_range(old_range.clone()),
new_text.chars(),
);
old_range.start += prefix_len;
let suffix_len = common_prefix(
snapshot.reversed_chars_for_range(old_range.clone()),
new_text[prefix_len..].chars().rev(),
);
old_range.end = old_range.end.saturating_sub(suffix_len);
let new_text = new_text[prefix_len..new_text.len() - suffix_len].into();
let range = if old_range.is_empty() {
let anchor = snapshot.anchor_after(old_range.start);
anchor..anchor
} else {
snapshot.anchor_after(old_range.start)..snapshot.anchor_before(old_range.end)
};
(range, new_text)
})
.collect()
}
fn common_prefix<T1: Iterator<Item = char>, T2: Iterator<Item = char>>(a: T1, b: T2) -> usize {
a.zip(b)
.take_while(|(a, b)| a == b)
.map(|(a, _)| a.len_utf8())
.sum()
}
fn git_info_for_file(
project: &Entity<Project>,
project_path: &ProjectPath,
cx: &App,
) -> Option<PredictEditsGitInfo> {
let git_store = project.read(cx).git_store().read(cx);
if let Some((repository, _repo_path)) =
git_store.repository_and_path_for_project_path(project_path, cx)
{
let repository = repository.read(cx);
let head_sha = repository
.head_commit
.as_ref()
.map(|head_commit| head_commit.sha.to_string());
let remote_origin_url = repository.remote_origin_url.clone();
let remote_upstream_url = repository.remote_upstream_url.clone();
if head_sha.is_none() && remote_origin_url.is_none() && remote_upstream_url.is_none() {
return None;
}
Some(PredictEditsGitInfo {
head_sha,
remote_origin_url,
remote_upstream_url,
})
} else {
None
}
}
pub struct GatherContextOutput {
pub body: PredictEditsBody,
pub context_range: Range<Point>,
pub editable_range: Range<usize>,
pub included_events_count: usize,
}
pub fn gather_context(
full_path_str: String,
snapshot: &BufferSnapshot,
cursor_point: language::Point,
prompt_for_events: impl FnOnce() -> (String, usize) + Send + 'static,
cx: &App,
) -> Task<Result<GatherContextOutput>> {
cx.background_spawn({
let snapshot = snapshot.clone();
async move {
let input_excerpt = excerpt_for_cursor_position(
cursor_point,
&full_path_str,
&snapshot,
MAX_REWRITE_TOKENS,
MAX_CONTEXT_TOKENS,
);
let (input_events, included_events_count) = prompt_for_events();
let editable_range = input_excerpt.editable_range.to_offset(&snapshot);
let body = PredictEditsBody {
input_events,
input_excerpt: input_excerpt.prompt,
can_collect_data: false,
diagnostic_groups: None,
git_info: None,
outline: None,
speculated_output: None,
};
Ok(GatherContextOutput {
body,
context_range: input_excerpt.context_range,
editable_range,
included_events_count,
})
}
})
}
fn prompt_for_events_impl(events: &[Arc<Event>], mut remaining_tokens: usize) -> (String, usize) {
let mut result = String::new();
for (ix, event) in events.iter().rev().enumerate() {
let event_string = format_event(event.as_ref());
let event_tokens = guess_token_count(event_string.len());
if event_tokens > remaining_tokens {
return (result, ix);
}
if !result.is_empty() {
result.insert_str(0, "\n\n");
}
result.insert_str(0, &event_string);
remaining_tokens -= event_tokens;
}
return (result, events.len());
}
pub fn format_event(event: &Event) -> String {
match event {
Event::BufferChange {
path,
old_path,
diff,
..
} => {
let mut prompt = String::new();
if old_path != path {
writeln!(
prompt,
"User renamed {} to {}\n",
old_path.display(),
path.display()
)
.unwrap();
}
if !diff.is_empty() {
write!(
prompt,
"User edited {}:\n```diff\n{}\n```",
path.display(),
diff
)
.unwrap();
}
prompt
}
}
}
/// Typical number of string bytes per token for the purposes of limiting model input. This is
/// intentionally low to err on the side of underestimating limits.
pub(crate) const BYTES_PER_TOKEN_GUESS: usize = 3;
fn guess_token_count(bytes: usize) -> usize {
bytes / BYTES_PER_TOKEN_GUESS
}

View File

@@ -1,4 +1,4 @@
use crate::{
use super::{
CURSOR_MARKER, EDITABLE_REGION_END_MARKER, EDITABLE_REGION_START_MARKER, START_OF_FILE_MARKER,
guess_token_count,
};
@@ -7,6 +7,7 @@ use std::{fmt::Write, ops::Range};
#[derive(Debug)]
pub struct InputExcerpt {
pub context_range: Range<Point>,
pub editable_range: Range<Point>,
pub prompt: String,
}
@@ -63,6 +64,7 @@ pub fn excerpt_for_cursor_position(
write!(prompt, "\n```").unwrap();
InputExcerpt {
context_range,
editable_range,
prompt,
}
@@ -124,7 +126,7 @@ mod tests {
use super::*;
use gpui::{App, AppContext};
use indoc::indoc;
use language::{Buffer, Language, LanguageConfig, LanguageMatcher};
use language::{Buffer, Language, LanguageConfig, LanguageMatcher, tree_sitter_rust};
use std::sync::Arc;
#[gpui::test]

View File

@@ -0,0 +1,671 @@
use client::test::FakeServer;
use clock::{FakeSystemClock, ReplicaId};
use cloud_api_types::{CreateLlmTokenResponse, LlmToken};
use cloud_llm_client::{PredictEditsBody, PredictEditsResponse};
use gpui::TestAppContext;
use http_client::FakeHttpClient;
use indoc::indoc;
use language::Point;
use parking_lot::Mutex;
use serde_json::json;
use settings::SettingsStore;
use util::{path, rel_path::rel_path};
use crate::zeta1::MAX_EVENT_TOKENS;
use super::*;
const BSD_0_TXT: &str = include_str!("../license_examples/0bsd.txt");
#[gpui::test]
async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
let buffer = cx.new(|cx| Buffer::local("Lorem ipsum dolor", cx));
let edits: Arc<[(Range<Anchor>, Arc<str>)]> = cx.update(|cx| {
to_completion_edits([(2..5, "REM".into()), (9..11, "".into())], &buffer, cx).into()
});
let edit_preview = cx
.read(|cx| buffer.read(cx).preview_edits(edits.clone(), cx))
.await;
let completion = EditPrediction {
edits,
edit_preview,
buffer: buffer.clone(),
snapshot: cx.read(|cx| buffer.read(cx).snapshot()),
id: EditPredictionId("the-id".into()),
inputs: EditPredictionInputs {
events: Default::default(),
included_files: Default::default(),
cursor_point: cloud_llm_client::predict_edits_v3::Point {
line: Line(0),
column: 0,
},
cursor_path: Path::new("").into(),
},
buffer_snapshotted_at: Instant::now(),
response_received_at: Instant::now(),
};
cx.update(|cx| {
assert_eq!(
from_completion_edits(
&completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
&buffer,
cx
),
vec![(2..5, "REM".into()), (9..11, "".into())]
);
buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "")], None, cx));
assert_eq!(
from_completion_edits(
&completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
&buffer,
cx
),
vec![(2..2, "REM".into()), (6..8, "".into())]
);
buffer.update(cx, |buffer, cx| buffer.undo(cx));
assert_eq!(
from_completion_edits(
&completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
&buffer,
cx
),
vec![(2..5, "REM".into()), (9..11, "".into())]
);
buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "R")], None, cx));
assert_eq!(
from_completion_edits(
&completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
&buffer,
cx
),
vec![(3..3, "EM".into()), (7..9, "".into())]
);
buffer.update(cx, |buffer, cx| buffer.edit([(3..3, "E")], None, cx));
assert_eq!(
from_completion_edits(
&completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
&buffer,
cx
),
vec![(4..4, "M".into()), (8..10, "".into())]
);
buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "M")], None, cx));
assert_eq!(
from_completion_edits(
&completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
&buffer,
cx
),
vec![(9..11, "".into())]
);
buffer.update(cx, |buffer, cx| buffer.edit([(4..5, "")], None, cx));
assert_eq!(
from_completion_edits(
&completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
&buffer,
cx
),
vec![(4..4, "M".into()), (8..10, "".into())]
);
buffer.update(cx, |buffer, cx| buffer.edit([(8..10, "")], None, cx));
assert_eq!(
from_completion_edits(
&completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
&buffer,
cx
),
vec![(4..4, "M".into())]
);
buffer.update(cx, |buffer, cx| buffer.edit([(4..6, "")], None, cx));
assert_eq!(completion.interpolate(&buffer.read(cx).snapshot()), None);
})
}
#[gpui::test]
async fn test_clean_up_diff(cx: &mut TestAppContext) {
init_test(cx);
assert_eq!(
apply_edit_prediction(
indoc! {"
fn main() {
let word_1 = \"lorem\";
let range = word.len()..word.len();
}
"},
indoc! {"
<|editable_region_start|>
fn main() {
let word_1 = \"lorem\";
let range = word_1.len()..word_1.len();
}
<|editable_region_end|>
"},
cx,
)
.await,
indoc! {"
fn main() {
let word_1 = \"lorem\";
let range = word_1.len()..word_1.len();
}
"},
);
assert_eq!(
apply_edit_prediction(
indoc! {"
fn main() {
let story = \"the quick\"
}
"},
indoc! {"
<|editable_region_start|>
fn main() {
let story = \"the quick brown fox jumps over the lazy dog\";
}
<|editable_region_end|>
"},
cx,
)
.await,
indoc! {"
fn main() {
let story = \"the quick brown fox jumps over the lazy dog\";
}
"},
);
}
#[gpui::test]
async fn test_edit_prediction_end_of_buffer(cx: &mut TestAppContext) {
init_test(cx);
let buffer_content = "lorem\n";
let completion_response = indoc! {"
```animals.js
<|start_of_file|>
<|editable_region_start|>
lorem
ipsum
<|editable_region_end|>
```"};
assert_eq!(
apply_edit_prediction(buffer_content, completion_response, cx).await,
"lorem\nipsum"
);
}
#[gpui::test]
async fn test_can_collect_data(cx: &mut TestAppContext) {
init_test(cx);
let fs = project::FakeFs::new(cx.executor());
fs.insert_tree(path!("/project"), json!({ "LICENSE": BSD_0_TXT }))
.await;
let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
let buffer = project
.update(cx, |project, cx| {
project.open_local_buffer(path!("/project/src/main.rs"), cx)
})
.await
.unwrap();
let (zeta, captured_request, _) = make_test_zeta(&project, cx).await;
zeta.update(cx, |zeta, _cx| {
zeta.data_collection_choice = DataCollectionChoice::Enabled
});
run_edit_prediction(&buffer, &project, &zeta, cx).await;
assert_eq!(
captured_request.lock().clone().unwrap().can_collect_data,
true
);
zeta.update(cx, |zeta, _cx| {
zeta.data_collection_choice = DataCollectionChoice::Disabled
});
run_edit_prediction(&buffer, &project, &zeta, cx).await;
assert_eq!(
captured_request.lock().clone().unwrap().can_collect_data,
false
);
}
#[gpui::test]
async fn test_no_data_collection_for_remote_file(cx: &mut TestAppContext) {
init_test(cx);
let fs = project::FakeFs::new(cx.executor());
let project = Project::test(fs.clone(), [], cx).await;
let buffer = cx.new(|_cx| {
Buffer::remote(
language::BufferId::new(1).unwrap(),
ReplicaId::new(1),
language::Capability::ReadWrite,
"fn main() {\n println!(\"Hello\");\n}",
)
});
let (zeta, captured_request, _) = make_test_zeta(&project, cx).await;
zeta.update(cx, |zeta, _cx| {
zeta.data_collection_choice = DataCollectionChoice::Enabled
});
run_edit_prediction(&buffer, &project, &zeta, cx).await;
assert_eq!(
captured_request.lock().clone().unwrap().can_collect_data,
false
);
}
#[gpui::test]
async fn test_no_data_collection_for_private_file(cx: &mut TestAppContext) {
init_test(cx);
let fs = project::FakeFs::new(cx.executor());
fs.insert_tree(
path!("/project"),
json!({
"LICENSE": BSD_0_TXT,
".env": "SECRET_KEY=secret"
}),
)
.await;
let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
let buffer = project
.update(cx, |project, cx| {
project.open_local_buffer("/project/.env", cx)
})
.await
.unwrap();
let (zeta, captured_request, _) = make_test_zeta(&project, cx).await;
zeta.update(cx, |zeta, _cx| {
zeta.data_collection_choice = DataCollectionChoice::Enabled
});
run_edit_prediction(&buffer, &project, &zeta, cx).await;
assert_eq!(
captured_request.lock().clone().unwrap().can_collect_data,
false
);
}
#[gpui::test]
async fn test_no_data_collection_for_untitled_buffer(cx: &mut TestAppContext) {
init_test(cx);
let fs = project::FakeFs::new(cx.executor());
let project = Project::test(fs.clone(), [], cx).await;
let buffer = cx.new(|cx| Buffer::local("", cx));
let (zeta, captured_request, _) = make_test_zeta(&project, cx).await;
zeta.update(cx, |zeta, _cx| {
zeta.data_collection_choice = DataCollectionChoice::Enabled
});
run_edit_prediction(&buffer, &project, &zeta, cx).await;
assert_eq!(
captured_request.lock().clone().unwrap().can_collect_data,
false
);
}
#[gpui::test]
async fn test_no_data_collection_when_closed_source(cx: &mut TestAppContext) {
init_test(cx);
let fs = project::FakeFs::new(cx.executor());
fs.insert_tree(path!("/project"), json!({ "main.rs": "fn main() {}" }))
.await;
let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
let buffer = project
.update(cx, |project, cx| {
project.open_local_buffer("/project/main.rs", cx)
})
.await
.unwrap();
let (zeta, captured_request, _) = make_test_zeta(&project, cx).await;
zeta.update(cx, |zeta, _cx| {
zeta.data_collection_choice = DataCollectionChoice::Enabled
});
run_edit_prediction(&buffer, &project, &zeta, cx).await;
assert_eq!(
captured_request.lock().clone().unwrap().can_collect_data,
false
);
}
#[gpui::test]
async fn test_data_collection_status_changes_on_move(cx: &mut TestAppContext) {
init_test(cx);
let fs = project::FakeFs::new(cx.executor());
fs.insert_tree(
path!("/open_source_worktree"),
json!({ "LICENSE": BSD_0_TXT, "main.rs": "" }),
)
.await;
fs.insert_tree(path!("/closed_source_worktree"), json!({ "main.rs": "" }))
.await;
let project = Project::test(
fs.clone(),
[
path!("/open_source_worktree").as_ref(),
path!("/closed_source_worktree").as_ref(),
],
cx,
)
.await;
let buffer = project
.update(cx, |project, cx| {
project.open_local_buffer(path!("/open_source_worktree/main.rs"), cx)
})
.await
.unwrap();
let (zeta, captured_request, _) = make_test_zeta(&project, cx).await;
zeta.update(cx, |zeta, _cx| {
zeta.data_collection_choice = DataCollectionChoice::Enabled
});
run_edit_prediction(&buffer, &project, &zeta, cx).await;
assert_eq!(
captured_request.lock().clone().unwrap().can_collect_data,
true
);
let closed_source_file = project
.update(cx, |project, cx| {
let worktree2 = project
.worktree_for_root_name("closed_source_worktree", cx)
.unwrap();
worktree2.update(cx, |worktree2, cx| {
worktree2.load_file(rel_path("main.rs"), cx)
})
})
.await
.unwrap()
.file;
buffer.update(cx, |buffer, cx| {
buffer.file_updated(closed_source_file, cx);
});
run_edit_prediction(&buffer, &project, &zeta, cx).await;
assert_eq!(
captured_request.lock().clone().unwrap().can_collect_data,
false
);
}
#[gpui::test]
async fn test_no_data_collection_for_events_in_uncollectable_buffers(cx: &mut TestAppContext) {
init_test(cx);
let fs = project::FakeFs::new(cx.executor());
fs.insert_tree(
path!("/worktree1"),
json!({ "LICENSE": BSD_0_TXT, "main.rs": "", "other.rs": "" }),
)
.await;
fs.insert_tree(path!("/worktree2"), json!({ "private.rs": "" }))
.await;
let project = Project::test(
fs.clone(),
[path!("/worktree1").as_ref(), path!("/worktree2").as_ref()],
cx,
)
.await;
let buffer = project
.update(cx, |project, cx| {
project.open_local_buffer(path!("/worktree1/main.rs"), cx)
})
.await
.unwrap();
let private_buffer = project
.update(cx, |project, cx| {
project.open_local_buffer(path!("/worktree2/file.rs"), cx)
})
.await
.unwrap();
let (zeta, captured_request, _) = make_test_zeta(&project, cx).await;
zeta.update(cx, |zeta, _cx| {
zeta.data_collection_choice = DataCollectionChoice::Enabled
});
run_edit_prediction(&buffer, &project, &zeta, cx).await;
assert_eq!(
captured_request.lock().clone().unwrap().can_collect_data,
true
);
// this has a side effect of registering the buffer to watch for edits
run_edit_prediction(&private_buffer, &project, &zeta, cx).await;
assert_eq!(
captured_request.lock().clone().unwrap().can_collect_data,
false
);
private_buffer.update(cx, |private_buffer, cx| {
private_buffer.edit([(0..0, "An edit for the history!")], None, cx);
});
run_edit_prediction(&buffer, &project, &zeta, cx).await;
assert_eq!(
captured_request.lock().clone().unwrap().can_collect_data,
false
);
// make an edit that uses too many bytes, causing private_buffer edit to not be able to be
// included
buffer.update(cx, |buffer, cx| {
buffer.edit(
[(
0..0,
" ".repeat(MAX_EVENT_TOKENS * zeta1::BYTES_PER_TOKEN_GUESS),
)],
None,
cx,
);
});
run_edit_prediction(&buffer, &project, &zeta, cx).await;
assert_eq!(
captured_request.lock().clone().unwrap().can_collect_data,
true
);
}
fn init_test(cx: &mut TestAppContext) {
cx.update(|cx| {
let settings_store = SettingsStore::test(cx);
cx.set_global(settings_store);
});
}
async fn apply_edit_prediction(
buffer_content: &str,
completion_response: &str,
cx: &mut TestAppContext,
) -> String {
let fs = project::FakeFs::new(cx.executor());
let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
let buffer = cx.new(|cx| Buffer::local(buffer_content, cx));
let (zeta, _, response) = make_test_zeta(&project, cx).await;
*response.lock() = completion_response.to_string();
let edit_prediction = run_edit_prediction(&buffer, &project, &zeta, cx).await;
buffer.update(cx, |buffer, cx| {
buffer.edit(edit_prediction.edits.iter().cloned(), None, cx)
});
buffer.read_with(cx, |buffer, _| buffer.text())
}
async fn run_edit_prediction(
buffer: &Entity<Buffer>,
project: &Entity<Project>,
zeta: &Entity<Zeta>,
cx: &mut TestAppContext,
) -> EditPrediction {
let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 0)));
zeta.update(cx, |zeta, cx| zeta.register_buffer(buffer, &project, cx));
cx.background_executor.run_until_parked();
let prediction_task = zeta.update(cx, |zeta, cx| {
zeta.request_prediction(&project, buffer, cursor, cx)
});
prediction_task.await.unwrap().unwrap()
}
async fn make_test_zeta(
project: &Entity<Project>,
cx: &mut TestAppContext,
) -> (
Entity<Zeta>,
Arc<Mutex<Option<PredictEditsBody>>>,
Arc<Mutex<String>>,
) {
let default_response = indoc! {"
```main.rs
<|start_of_file|>
<|editable_region_start|>
hello world
<|editable_region_end|>
```"
};
let captured_request: Arc<Mutex<Option<PredictEditsBody>>> = Arc::new(Mutex::new(None));
let completion_response: Arc<Mutex<String>> =
Arc::new(Mutex::new(default_response.to_string()));
let http_client = FakeHttpClient::create({
let captured_request = captured_request.clone();
let completion_response = completion_response.clone();
let mut next_request_id = 0;
move |req| {
let captured_request = captured_request.clone();
let completion_response = completion_response.clone();
async move {
match (req.method(), req.uri().path()) {
(&Method::POST, "/client/llm_tokens") => Ok(http_client::Response::builder()
.status(200)
.body(
serde_json::to_string(&CreateLlmTokenResponse {
token: LlmToken("the-llm-token".to_string()),
})
.unwrap()
.into(),
)
.unwrap()),
(&Method::POST, "/predict_edits/v2") => {
let mut request_body = String::new();
req.into_body().read_to_string(&mut request_body).await?;
*captured_request.lock() =
Some(serde_json::from_str(&request_body).unwrap());
next_request_id += 1;
Ok(http_client::Response::builder()
.status(200)
.body(
serde_json::to_string(&PredictEditsResponse {
request_id: format!("request-{next_request_id}"),
output_excerpt: completion_response.lock().clone(),
})
.unwrap()
.into(),
)
.unwrap())
}
_ => Ok(http_client::Response::builder()
.status(404)
.body("Not Found".into())
.unwrap()),
}
}
}
});
let client = cx.update(|cx| Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
cx.update(|cx| {
RefreshLlmTokenListener::register(client.clone(), cx);
});
let _server = FakeServer::for_client(42, &client, cx).await;
let zeta = cx.new(|cx| {
let mut zeta = Zeta::new(client, project.read(cx).user_store(), cx);
zeta.set_edit_prediction_model(ZetaEditPredictionModel::Zeta1);
let worktrees = project.read(cx).worktrees(cx).collect::<Vec<_>>();
for worktree in worktrees {
let worktree_id = worktree.read(cx).id();
zeta.get_or_init_zeta_project(project, cx)
.license_detection_watchers
.entry(worktree_id)
.or_insert_with(|| Rc::new(LicenseDetectionWatcher::new(&worktree, cx)));
}
zeta
});
(zeta, captured_request, completion_response)
}
fn to_completion_edits(
iterator: impl IntoIterator<Item = (Range<usize>, Arc<str>)>,
buffer: &Entity<Buffer>,
cx: &App,
) -> Vec<(Range<Anchor>, Arc<str>)> {
let buffer = buffer.read(cx);
iterator
.into_iter()
.map(|(range, text)| {
(
buffer.anchor_after(range.start)..buffer.anchor_before(range.end),
text,
)
})
.collect()
}
fn from_completion_edits(
editor_edits: &[(Range<Anchor>, Arc<str>)],
buffer: &Entity<Buffer>,
cx: &App,
) -> Vec<(Range<usize>, Arc<str>)> {
let buffer = buffer.read(cx);
editor_edits
.iter()
.map(|(range, text)| {
(
range.start.to_offset(buffer)..range.end.to_offset(buffer),
text.clone(),
)
})
.collect()
}
#[ctor::ctor]
fn init_logger() {
zlog::init_test();
}

View File

@@ -1,61 +0,0 @@
[package]
name = "zeta2"
version = "0.1.0"
edition.workspace = true
publish.workspace = true
license = "GPL-3.0-or-later"
[lints]
workspace = true
[lib]
path = "src/zeta2.rs"
[features]
eval-support = []
[dependencies]
anyhow.workspace = true
arrayvec.workspace = true
brotli.workspace = true
chrono.workspace = true
client.workspace = true
cloud_llm_client.workspace = true
cloud_zeta2_prompt.workspace = true
collections.workspace = true
edit_prediction.workspace = true
edit_prediction_context.workspace = true
feature_flags.workspace = true
futures.workspace = true
gpui.workspace = true
indoc.workspace = true
language.workspace = true
language_model.workspace = true
log.workspace = true
lsp.workspace = true
open_ai.workspace = true
pretty_assertions.workspace = true
project.workspace = true
release_channel.workspace = true
semver.workspace = true
serde.workspace = true
serde_json.workspace = true
smol.workspace = true
strsim.workspace = true
thiserror.workspace = true
util.workspace = true
uuid.workspace = true
workspace.workspace = true
worktree.workspace = true
[dev-dependencies]
clock = { workspace = true, features = ["test-support"] }
cloud_llm_client = { workspace = true, features = ["test-support"] }
gpui = { workspace = true, features = ["test-support"] }
lsp.workspace = true
indoc.workspace = true
language = { workspace = true, features = ["test-support"] }
language_model = { workspace = true, features = ["test-support"] }
project = { workspace = true, features = ["test-support"] }
settings = { workspace = true, features = ["test-support"] }
zlog.workspace = true

View File

@@ -1 +0,0 @@
../../LICENSE-GPL

File diff suppressed because it is too large Load Diff

View File

@@ -13,7 +13,6 @@ path = "src/zeta2_tools.rs"
[dependencies]
anyhow.workspace = true
chrono.workspace = true
client.workspace = true
cloud_llm_client.workspace = true
cloud_zeta2_prompt.workspace = true
@@ -24,9 +23,7 @@ feature_flags.workspace = true
futures.workspace = true
gpui.workspace = true
language.workspace = true
log.workspace = true
multi_buffer.workspace = true
ordered-float.workspace = true
project.workspace = true
serde.workspace = true
serde_json.workspace = true
@@ -36,7 +33,7 @@ ui.workspace = true
ui_input.workspace = true
util.workspace = true
workspace.workspace = true
zeta2.workspace = true
zeta.workspace = true
[dev-dependencies]
clap.workspace = true

View File

@@ -25,7 +25,7 @@ use ui::{
v_flex,
};
use workspace::Item;
use zeta2::{
use zeta::{
Zeta, ZetaContextRetrievalDebugInfo, ZetaContextRetrievalStartedDebugInfo, ZetaDebugInfo,
ZetaSearchQueryDebugInfo,
};

View File

@@ -1,30 +1,26 @@
mod zeta2_context_view;
use std::{cmp::Reverse, path::PathBuf, str::FromStr, sync::Arc};
use std::{str::FromStr, sync::Arc, time::Duration};
use chrono::TimeDelta;
use client::{Client, UserStore};
use cloud_llm_client::predict_edits_v3::{
DeclarationScoreComponents, PredictEditsRequest, PromptFormat,
};
use cloud_llm_client::predict_edits_v3::PromptFormat;
use collections::HashMap;
use editor::{Editor, EditorEvent, EditorMode, ExcerptRange, MultiBuffer};
use editor::{Editor, EditorEvent, EditorMode, MultiBuffer};
use feature_flags::FeatureFlagAppExt as _;
use futures::{FutureExt, StreamExt as _, channel::oneshot, future::Shared};
use gpui::{
CursorStyle, Empty, Entity, EventEmitter, FocusHandle, Focusable, Subscription, Task,
WeakEntity, actions, prelude::*,
Empty, Entity, EventEmitter, FocusHandle, Focusable, Subscription, Task, WeakEntity, actions,
prelude::*,
};
use language::{Buffer, DiskState};
use ordered_float::OrderedFloat;
use project::{Project, WorktreeId, telemetry_snapshot::TelemetrySnapshot};
use language::Buffer;
use project::{Project, telemetry_snapshot::TelemetrySnapshot};
use ui::{ButtonLike, ContextMenu, ContextMenuEntry, DropdownMenu, KeyBinding, prelude::*};
use ui_input::InputField;
use util::{ResultExt, paths::PathStyle, rel_path::RelPath};
use util::ResultExt;
use workspace::{Item, SplitDirection, Workspace};
use zeta2::{
AgenticContextOptions, ContextMode, DEFAULT_SYNTAX_CONTEXT_OPTIONS, Zeta, Zeta2FeatureFlag,
ZetaDebugInfo, ZetaEditPredictionDebugInfo, ZetaOptions,
use zeta::{
AgenticContextOptions, ContextMode, DEFAULT_SYNTAX_CONTEXT_OPTIONS, EditPredictionInputs, Zeta,
Zeta2FeatureFlag, ZetaDebugInfo, ZetaEditPredictionDebugInfo, ZetaOptions,
};
use edit_prediction_context::{EditPredictionContextOptions, EditPredictionExcerptOptions};
@@ -99,7 +95,6 @@ pub struct Zeta2Inspector {
cursor_context_ratio_input: Entity<InputField>,
max_prompt_bytes_input: Entity<InputField>,
context_mode: ContextModeState,
active_view: ActiveView,
zeta: Entity<Zeta>,
_active_editor_subscription: Option<Subscription>,
_update_state_task: Task<()>,
@@ -113,21 +108,14 @@ pub enum ContextModeState {
},
}
#[derive(PartialEq)]
enum ActiveView {
Context,
Inference,
}
struct LastPrediction {
context_editor: Entity<Editor>,
prompt_editor: Entity<Editor>,
retrieval_time: TimeDelta,
request_time: Option<TimeDelta>,
retrieval_time: Duration,
request_time: Option<Duration>,
buffer: WeakEntity<Buffer>,
position: language::Anchor,
state: LastPredictionState,
request: PredictEditsRequest,
inputs: EditPredictionInputs,
project_snapshot: Shared<Task<Arc<TelemetrySnapshot>>>,
_task: Option<Task<()>>,
}
@@ -175,7 +163,6 @@ impl Zeta2Inspector {
focus_handle: cx.focus_handle(),
project: project.clone(),
last_prediction: None,
active_view: ActiveView::Inference,
max_excerpt_bytes_input: Self::number_input("Max Excerpt Bytes", window, cx),
min_excerpt_bytes_input: Self::number_input("Min Excerpt Bytes", window, cx),
cursor_context_ratio_input: Self::number_input("Cursor Context Ratio", window, cx),
@@ -305,7 +292,7 @@ impl Zeta2Inspector {
ContextMode::Syntax(context_options) => {
let max_retrieved_declarations = match &this.context_mode {
ContextModeState::Llm => {
zeta2::DEFAULT_SYNTAX_CONTEXT_OPTIONS.max_retrieved_declarations
zeta::DEFAULT_SYNTAX_CONTEXT_OPTIONS.max_retrieved_declarations
}
ContextModeState::Syntax {
max_retrieved_declarations,
@@ -340,22 +327,10 @@ impl Zeta2Inspector {
fn update_last_prediction(
&mut self,
prediction: zeta2::ZetaDebugInfo,
prediction: zeta::ZetaDebugInfo,
window: &mut Window,
cx: &mut Context<Self>,
) {
let project = self.project.read(cx);
let path_style = project.path_style(cx);
let Some(worktree_id) = project
.worktrees(cx)
.next()
.map(|worktree| worktree.read(cx).id())
else {
log::error!("Open a worktree to use edit prediction debug view");
self.last_prediction.take();
return;
};
self._update_state_task = cx.spawn_in(window, {
let language_registry = self.project.read(cx).languages().clone();
async move |this, cx| {
@@ -364,11 +339,10 @@ impl Zeta2Inspector {
return;
};
for ext in prediction
.request
.referenced_declarations
.inputs
.included_files
.iter()
.filter_map(|snippet| snippet.path.extension())
.chain(prediction.request.excerpt_path.extension())
.filter_map(|file| file.path.extension())
{
if !languages.contains_key(ext) {
// Most snippets are gonna be the same language,
@@ -391,90 +365,6 @@ impl Zeta2Inspector {
let json_language = language_registry.language_for_name("Json").await.log_err();
this.update_in(cx, |this, window, cx| {
let context_editor = cx.new(|cx| {
let mut excerpt_score_components = HashMap::default();
let multibuffer = cx.new(|cx| {
let mut multibuffer = MultiBuffer::new(language::Capability::ReadOnly);
let excerpt_file = Arc::new(ExcerptMetadataFile {
title: RelPath::unix("Cursor Excerpt").unwrap().into(),
path_style,
worktree_id,
});
let excerpt_buffer = cx.new(|cx| {
let mut buffer =
Buffer::local(prediction.request.excerpt.clone(), cx);
if let Some(language) = prediction
.request
.excerpt_path
.extension()
.and_then(|ext| languages.get(ext))
{
buffer.set_language(language.clone(), cx);
}
buffer.file_updated(excerpt_file, cx);
buffer
});
multibuffer.push_excerpts(
excerpt_buffer,
[ExcerptRange::new(text::Anchor::MIN..text::Anchor::MAX)],
cx,
);
let mut declarations =
prediction.request.referenced_declarations.clone();
declarations.sort_unstable_by_key(|declaration| {
Reverse(OrderedFloat(declaration.declaration_score))
});
for snippet in &declarations {
let snippet_file = Arc::new(ExcerptMetadataFile {
title: RelPath::unix(&format!(
"{} (Score: {})",
snippet.path.display(),
snippet.declaration_score
))
.unwrap()
.into(),
path_style,
worktree_id,
});
let excerpt_buffer = cx.new(|cx| {
let mut buffer = Buffer::local(snippet.text.clone(), cx);
buffer.file_updated(snippet_file, cx);
if let Some(ext) = snippet.path.extension()
&& let Some(language) = languages.get(ext)
{
buffer.set_language(language.clone(), cx);
}
buffer
});
let excerpt_ids = multibuffer.push_excerpts(
excerpt_buffer,
[ExcerptRange::new(text::Anchor::MIN..text::Anchor::MAX)],
cx,
);
let excerpt_id = excerpt_ids.first().unwrap();
excerpt_score_components
.insert(*excerpt_id, snippet.score_components.clone());
}
multibuffer
});
let mut editor =
Editor::new(EditorMode::full(), multibuffer, None, window, cx);
editor.register_addon(ZetaContextAddon {
excerpt_score_components,
});
editor
});
let ZetaEditPredictionDebugInfo {
response_rx,
position,
@@ -606,7 +496,6 @@ impl Zeta2Inspector {
let project_snapshot_task = TelemetrySnapshot::new(&this.project, cx);
this.last_prediction = Some(LastPrediction {
context_editor,
prompt_editor: cx.new(|cx| {
let buffer = cx.new(|cx| {
let mut buffer =
@@ -632,7 +521,7 @@ impl Zeta2Inspector {
.foreground_executor()
.spawn(async move { Arc::new(project_snapshot_task.await) })
.shared(),
request: prediction.request,
inputs: prediction.inputs,
_task: Some(task),
});
cx.notify();
@@ -664,9 +553,6 @@ impl Zeta2Inspector {
let Some(last_prediction) = self.last_prediction.as_mut() else {
return;
};
if !last_prediction.request.can_collect_data {
return;
}
let project_snapshot_task = last_prediction.project_snapshot.clone();
@@ -718,7 +604,7 @@ impl Zeta2Inspector {
id = request_id,
kind = kind,
text = text,
request = last_prediction.request,
request = last_prediction.inputs,
project_snapshot = project_snapshot,
);
})
@@ -727,17 +613,6 @@ impl Zeta2Inspector {
.detach();
}
fn focus_feedback(&mut self, window: &mut Window, cx: &mut Context<Self>) {
if let Some(last_prediction) = self.last_prediction.as_mut() {
if let LastPredictionState::Success {
feedback_editor, ..
} = &mut last_prediction.state
{
feedback_editor.focus_handle(cx).focus(window);
}
};
}
fn render_options(&self, window: &mut Window, cx: &mut Context<Self>) -> Div {
v_flex()
.gap_2()
@@ -747,11 +622,11 @@ impl Zeta2Inspector {
.justify_between()
.child(
ui::Button::new("reset-options", "Reset")
.disabled(self.zeta.read(cx).options() == &zeta2::DEFAULT_OPTIONS)
.disabled(self.zeta.read(cx).options() == &zeta::DEFAULT_OPTIONS)
.style(ButtonStyle::Outlined)
.size(ButtonSize::Large)
.on_click(cx.listener(|this, _, window, cx| {
this.set_options_state(&zeta2::DEFAULT_OPTIONS, window, cx);
this.set_options_state(&zeta::DEFAULT_OPTIONS, window, cx);
})),
),
)
@@ -915,42 +790,6 @@ impl Zeta2Inspector {
)
}
fn render_tabs(&self, cx: &mut Context<Self>) -> Option<AnyElement> {
if self.last_prediction.is_none() {
return None;
};
Some(
ui::ToggleButtonGroup::single_row(
"prediction",
[
ui::ToggleButtonSimple::new(
"Context",
cx.listener(|this, _, _, cx| {
this.active_view = ActiveView::Context;
cx.notify();
}),
),
ui::ToggleButtonSimple::new(
"Inference",
cx.listener(|this, _, window, cx| {
this.active_view = ActiveView::Inference;
this.focus_feedback(window, cx);
cx.notify();
}),
),
],
)
.style(ui::ToggleButtonGroupStyle::Outlined)
.selected_index(if self.active_view == ActiveView::Context {
0
} else {
1
})
.into_any_element(),
)
}
fn render_stats(&self) -> Option<Div> {
let Some(prediction) = self.last_prediction.as_ref() else {
return None;
@@ -970,15 +809,15 @@ impl Zeta2Inspector {
)
}
fn render_duration(name: &'static str, time: Option<chrono::TimeDelta>) -> Div {
fn render_duration(name: &'static str, time: Option<Duration>) -> Div {
h_flex()
.gap_1()
.child(Label::new(name).color(Color::Muted).size(LabelSize::Small))
.child(match time {
Some(time) => Label::new(if time.num_microseconds().unwrap_or(0) >= 1000 {
format!("{} ms", time.num_milliseconds())
Some(time) => Label::new(if time.as_micros() >= 1000 {
format!("{} ms", time.as_millis())
} else {
format!("{} µs", time.num_microseconds().unwrap_or(0))
format!("{} µs", time.as_micros())
})
.size(LabelSize::Small),
None => Label::new("...").size(LabelSize::Small),
@@ -1006,144 +845,135 @@ impl Zeta2Inspector {
}
fn render_last_prediction(&self, prediction: &LastPrediction, cx: &mut Context<Self>) -> Div {
match &self.active_view {
ActiveView::Context => div().size_full().child(prediction.context_editor.clone()),
ActiveView::Inference => h_flex()
.items_start()
.w_full()
.flex_1()
.border_t_1()
.border_color(cx.theme().colors().border)
.bg(cx.theme().colors().editor_background)
.child(
v_flex()
.flex_1()
.gap_2()
.p_4()
.h_full()
.child(
h_flex()
.justify_between()
.child(ui::Headline::new("Prompt").size(ui::HeadlineSize::XSmall))
.child(match prediction.state {
LastPredictionState::Requested
| LastPredictionState::Failed { .. } => ui::Chip::new("Local")
.bg_color(cx.theme().status().warning_background)
.label_color(Color::Success),
LastPredictionState::Success { .. } => ui::Chip::new("Cloud")
.bg_color(cx.theme().status().success_background)
.label_color(Color::Success),
}),
)
.child(prediction.prompt_editor.clone()),
)
.child(ui::vertical_divider())
.child(
v_flex()
.flex_1()
.gap_2()
.h_full()
.child(
v_flex()
.flex_1()
.gap_2()
.p_4()
.child(
ui::Headline::new("Model Response")
.size(ui::HeadlineSize::XSmall),
)
.child(match &prediction.state {
LastPredictionState::Success {
model_response_editor,
..
} => model_response_editor.clone().into_any_element(),
LastPredictionState::Requested => v_flex()
.gap_2()
.child(Label::new("Loading...").buffer_font(cx))
.into_any_element(),
LastPredictionState::Failed { message } => v_flex()
.gap_2()
.max_w_96()
.child(Label::new(message.clone()).buffer_font(cx))
.into_any_element(),
}),
)
.child(ui::divider())
.child(
if prediction.request.can_collect_data
&& let LastPredictionState::Success {
feedback_editor,
feedback: feedback_state,
h_flex()
.items_start()
.w_full()
.flex_1()
.border_t_1()
.border_color(cx.theme().colors().border)
.bg(cx.theme().colors().editor_background)
.child(
v_flex()
.flex_1()
.gap_2()
.p_4()
.h_full()
.child(
h_flex()
.justify_between()
.child(ui::Headline::new("Prompt").size(ui::HeadlineSize::XSmall))
.child(match prediction.state {
LastPredictionState::Requested
| LastPredictionState::Failed { .. } => ui::Chip::new("Local")
.bg_color(cx.theme().status().warning_background)
.label_color(Color::Success),
LastPredictionState::Success { .. } => ui::Chip::new("Cloud")
.bg_color(cx.theme().status().success_background)
.label_color(Color::Success),
}),
)
.child(prediction.prompt_editor.clone()),
)
.child(ui::vertical_divider())
.child(
v_flex()
.flex_1()
.gap_2()
.h_full()
.child(
v_flex()
.flex_1()
.gap_2()
.p_4()
.child(
ui::Headline::new("Model Response").size(ui::HeadlineSize::XSmall),
)
.child(match &prediction.state {
LastPredictionState::Success {
model_response_editor,
..
} = &prediction.state
{
v_flex()
.key_context("Zeta2Feedback")
.on_action(cx.listener(Self::handle_rate_positive))
.on_action(cx.listener(Self::handle_rate_negative))
} => model_response_editor.clone().into_any_element(),
LastPredictionState::Requested => v_flex()
.gap_2()
.p_2()
.child(feedback_editor.clone())
.child(
h_flex()
.justify_end()
.w_full()
.child(
ButtonLike::new("rate-positive")
.when(
*feedback_state == Some(Feedback::Positive),
|this| this.style(ButtonStyle::Filled),
.child(Label::new("Loading...").buffer_font(cx))
.into_any_element(),
LastPredictionState::Failed { message } => v_flex()
.gap_2()
.max_w_96()
.child(Label::new(message.clone()).buffer_font(cx))
.into_any_element(),
}),
)
.child(ui::divider())
.child(
if let LastPredictionState::Success {
feedback_editor,
feedback: feedback_state,
..
} = &prediction.state
{
v_flex()
.key_context("Zeta2Feedback")
.on_action(cx.listener(Self::handle_rate_positive))
.on_action(cx.listener(Self::handle_rate_negative))
.gap_2()
.p_2()
.child(feedback_editor.clone())
.child(
h_flex()
.justify_end()
.w_full()
.child(
ButtonLike::new("rate-positive")
.when(
*feedback_state == Some(Feedback::Positive),
|this| this.style(ButtonStyle::Filled),
)
.child(
KeyBinding::for_action(
&Zeta2RatePredictionPositive,
cx,
)
.child(
KeyBinding::for_action(
&Zeta2RatePredictionPositive,
cx,
)
.size(TextSize::Small.rems(cx)),
.size(TextSize::Small.rems(cx)),
)
.child(ui::Icon::new(ui::IconName::ThumbsUp))
.on_click(cx.listener(|this, _, window, cx| {
this.handle_rate_positive(
&Zeta2RatePredictionPositive,
window,
cx,
);
})),
)
.child(
ButtonLike::new("rate-negative")
.when(
*feedback_state == Some(Feedback::Negative),
|this| this.style(ButtonStyle::Filled),
)
.child(
KeyBinding::for_action(
&Zeta2RatePredictionNegative,
cx,
)
.child(ui::Icon::new(ui::IconName::ThumbsUp))
.on_click(cx.listener(
|this, _, window, cx| {
this.handle_rate_positive(
&Zeta2RatePredictionPositive,
window,
cx,
);
},
)),
)
.child(
ButtonLike::new("rate-negative")
.when(
*feedback_state == Some(Feedback::Negative),
|this| this.style(ButtonStyle::Filled),
)
.child(
KeyBinding::for_action(
&Zeta2RatePredictionNegative,
cx,
)
.size(TextSize::Small.rems(cx)),
)
.child(ui::Icon::new(ui::IconName::ThumbsDown))
.on_click(cx.listener(
|this, _, window, cx| {
this.handle_rate_negative(
&Zeta2RatePredictionNegative,
window,
cx,
);
},
)),
),
)
.into_any()
} else {
Empty.into_any_element()
},
),
),
}
.size(TextSize::Small.rems(cx)),
)
.child(ui::Icon::new(ui::IconName::ThumbsDown))
.on_click(cx.listener(|this, _, window, cx| {
this.handle_rate_negative(
&Zeta2RatePredictionNegative,
window,
cx,
);
})),
),
)
.into_any()
} else {
Empty.into_any_element()
},
),
)
}
}
@@ -1178,8 +1008,7 @@ impl Render for Zeta2Inspector {
.h_full()
.justify_between()
.child(self.render_options(window, cx))
.gap_4()
.children(self.render_tabs(cx)),
.gap_4(),
)
.child(ui::vertical_divider())
.children(self.render_stats()),
@@ -1187,104 +1016,3 @@ impl Render for Zeta2Inspector {
.child(self.render_content(window, cx))
}
}
// Using same approach as commit view
struct ExcerptMetadataFile {
title: Arc<RelPath>,
worktree_id: WorktreeId,
path_style: PathStyle,
}
impl language::File for ExcerptMetadataFile {
fn as_local(&self) -> Option<&dyn language::LocalFile> {
None
}
fn disk_state(&self) -> DiskState {
DiskState::New
}
fn path(&self) -> &Arc<RelPath> {
&self.title
}
fn full_path(&self, _: &App) -> PathBuf {
self.title.as_std_path().to_path_buf()
}
fn file_name<'a>(&'a self, _: &'a App) -> &'a str {
self.title.file_name().unwrap()
}
fn path_style(&self, _: &App) -> PathStyle {
self.path_style
}
fn worktree_id(&self, _: &App) -> WorktreeId {
self.worktree_id
}
fn to_proto(&self, _: &App) -> language::proto::File {
unimplemented!()
}
fn is_private(&self) -> bool {
false
}
}
struct ZetaContextAddon {
excerpt_score_components: HashMap<editor::ExcerptId, DeclarationScoreComponents>,
}
impl editor::Addon for ZetaContextAddon {
fn to_any(&self) -> &dyn std::any::Any {
self
}
fn render_buffer_header_controls(
&self,
excerpt_info: &multi_buffer::ExcerptInfo,
_window: &Window,
_cx: &App,
) -> Option<AnyElement> {
let score_components = self.excerpt_score_components.get(&excerpt_info.id)?.clone();
Some(
div()
.id(excerpt_info.id.to_proto() as usize)
.child(ui::Icon::new(IconName::Info))
.cursor(CursorStyle::PointingHand)
.tooltip(move |_, cx| {
cx.new(|_| ScoreComponentsTooltip::new(&score_components))
.into()
})
.into_any(),
)
}
}
struct ScoreComponentsTooltip {
text: SharedString,
}
impl ScoreComponentsTooltip {
fn new(components: &DeclarationScoreComponents) -> Self {
Self {
text: format!("{:#?}", components).into(),
}
}
}
impl Render for ScoreComponentsTooltip {
fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
div().pl_2().pt_2p5().child(
div()
.elevation_2(cx)
.py_1()
.px_2()
.child(ui::Label::new(self.text.clone()).buffer_font(cx)),
)
}
}

View File

@@ -53,8 +53,7 @@ terminal_view.workspace = true
toml.workspace = true
util.workspace = true
watch.workspace = true
zeta.workspace = true
zeta2 = { workspace = true, features = ["eval-support"] }
zeta = { workspace = true, features = ["eval-support"] }
zlog.workspace = true
[dev-dependencies]

View File

@@ -9,7 +9,7 @@ use collections::HashSet;
use gpui::{AsyncApp, Entity};
use project::Project;
use util::ResultExt as _;
use zeta2::{Zeta, udiff::DiffLine};
use zeta::{Zeta, udiff::DiffLine};
use crate::{
EvaluateArguments, PredictionOptions,

View File

@@ -26,7 +26,7 @@ use project::{Project, ProjectPath};
use pulldown_cmark::CowStr;
use serde::{Deserialize, Serialize};
use util::{paths::PathStyle, rel_path::RelPath};
use zeta2::udiff::OpenedBuffers;
use zeta::udiff::OpenedBuffers;
use crate::paths::{REPOS_DIR, WORKTREES_DIR};
@@ -557,7 +557,7 @@ impl NamedExample {
project: &Entity<Project>,
cx: &mut AsyncApp,
) -> Result<OpenedBuffers<'_>> {
zeta2::udiff::apply_diff(&self.example.edit_history, project, cx).await
zeta::udiff::apply_diff(&self.example.edit_history, project, cx).await
}
}

View File

@@ -31,7 +31,7 @@ use serde_json::json;
use std::io::{self};
use std::time::Duration;
use std::{collections::HashSet, path::PathBuf, str::FromStr, sync::Arc};
use zeta2::ContextMode;
use zeta::ContextMode;
#[derive(Parser, Debug)]
#[command(name = "zeta")]
@@ -193,13 +193,14 @@ pub struct EvaluateArguments {
#[derive(clap::ValueEnum, Default, Debug, Clone, Copy, PartialEq)]
enum PredictionProvider {
Zeta1,
#[default]
Zeta2,
Sweep,
}
fn zeta2_args_to_options(args: &Zeta2Args, omit_excerpt_overlaps: bool) -> zeta2::ZetaOptions {
zeta2::ZetaOptions {
fn zeta2_args_to_options(args: &Zeta2Args, omit_excerpt_overlaps: bool) -> zeta::ZetaOptions {
zeta::ZetaOptions {
context: ContextMode::Syntax(EditPredictionContextOptions {
max_retrieved_declarations: args.max_retrieved_definitions,
use_imports: !args.disable_imports_gathering,
@@ -397,7 +398,7 @@ async fn zeta2_syntax_context(
let output = cx
.update(|cx| {
let zeta = cx.new(|cx| {
zeta2::Zeta::new(app_state.client.clone(), app_state.user_store.clone(), cx)
zeta::Zeta::new(app_state.client.clone(), app_state.user_store.clone(), cx)
});
let indexing_done_task = zeta.update(cx, |zeta, cx| {
zeta.set_options(zeta2_args_to_options(&args.zeta2_args, true));
@@ -435,7 +436,7 @@ async fn zeta1_context(
args: ContextArgs,
app_state: &Arc<ZetaCliAppState>,
cx: &mut AsyncApp,
) -> Result<zeta::GatherContextOutput> {
) -> Result<zeta::zeta1::GatherContextOutput> {
let LoadedContext {
full_path_str,
snapshot,
@@ -450,7 +451,7 @@ async fn zeta1_context(
let prompt_for_events = move || (events, 0);
cx.update(|cx| {
zeta::gather_context(
zeta::zeta1::gather_context(
full_path_str,
&snapshot,
clipped_cursor,

View File

@@ -21,7 +21,7 @@ use std::path::PathBuf;
use std::sync::Arc;
use std::sync::Mutex;
use std::time::{Duration, Instant};
use zeta2::{EvalCache, EvalCacheEntryKind, EvalCacheKey, Zeta};
use zeta::{EvalCache, EvalCacheEntryKind, EvalCacheKey, Zeta};
pub async fn run_predict(
args: PredictArguments,
@@ -47,12 +47,13 @@ pub fn setup_zeta(
cx: &mut AsyncApp,
) -> Result<Entity<Zeta>> {
let zeta =
cx.new(|cx| zeta2::Zeta::new(app_state.client.clone(), app_state.user_store.clone(), cx))?;
cx.new(|cx| zeta::Zeta::new(app_state.client.clone(), app_state.user_store.clone(), cx))?;
zeta.update(cx, |zeta, _cx| {
let model = match provider {
PredictionProvider::Zeta2 => zeta2::ZetaEditPredictionModel::ZedCloud,
PredictionProvider::Sweep => zeta2::ZetaEditPredictionModel::Sweep,
PredictionProvider::Zeta1 => zeta::ZetaEditPredictionModel::Zeta1,
PredictionProvider::Zeta2 => zeta::ZetaEditPredictionModel::Zeta2,
PredictionProvider::Sweep => zeta::ZetaEditPredictionModel::Sweep,
};
zeta.set_edit_prediction_model(model);
})?;
@@ -142,25 +143,25 @@ pub async fn perform_predict(
let mut search_queries_executed_at = None;
while let Some(event) = debug_rx.next().await {
match event {
zeta2::ZetaDebugInfo::ContextRetrievalStarted(info) => {
zeta::ZetaDebugInfo::ContextRetrievalStarted(info) => {
start_time = Some(info.timestamp);
fs::write(
example_run_dir.join("search_prompt.md"),
&info.search_prompt,
)?;
}
zeta2::ZetaDebugInfo::SearchQueriesGenerated(info) => {
zeta::ZetaDebugInfo::SearchQueriesGenerated(info) => {
search_queries_generated_at = Some(info.timestamp);
fs::write(
example_run_dir.join("search_queries.json"),
serde_json::to_string_pretty(&info.search_queries).unwrap(),
)?;
}
zeta2::ZetaDebugInfo::SearchQueriesExecuted(info) => {
zeta::ZetaDebugInfo::SearchQueriesExecuted(info) => {
search_queries_executed_at = Some(info.timestamp);
}
zeta2::ZetaDebugInfo::ContextRetrievalFinished(_info) => {}
zeta2::ZetaDebugInfo::EditPredictionRequested(request) => {
zeta::ZetaDebugInfo::ContextRetrievalFinished(_info) => {}
zeta::ZetaDebugInfo::EditPredictionRequested(request) => {
let prediction_started_at = Instant::now();
start_time.get_or_insert(prediction_started_at);
let prompt = request.local_prompt.unwrap_or_default();
@@ -170,9 +171,9 @@ pub async fn perform_predict(
let mut result = result.lock().unwrap();
result.prompt_len = prompt.chars().count();
for included_file in request.request.included_files {
for included_file in request.inputs.included_files {
let insertions =
vec![(request.request.cursor_point, CURSOR_MARKER)];
vec![(request.inputs.cursor_point, CURSOR_MARKER)];
result.excerpts.extend(included_file.excerpts.iter().map(
|excerpt| ActualExcerpt {
path: included_file.path.components().skip(1).collect(),
@@ -182,7 +183,7 @@ pub async fn perform_predict(
write_codeblock(
&included_file.path,
included_file.excerpts.iter(),
if included_file.path == request.request.excerpt_path {
if included_file.path == request.inputs.cursor_path {
&insertions
} else {
&[]
@@ -196,7 +197,7 @@ pub async fn perform_predict(
let response =
request.response_rx.await?.0.map_err(|err| anyhow!(err))?;
let response = zeta2::text_from_response(response).unwrap_or_default();
let response = zeta::text_from_response(response).unwrap_or_default();
let prediction_finished_at = Instant::now();
fs::write(example_run_dir.join("prediction_response.md"), &response)?;
@@ -267,20 +268,7 @@ pub async fn perform_predict(
let mut result = Arc::into_inner(result).unwrap().into_inner().unwrap();
result.diff = prediction
.map(|prediction| {
let old_text = prediction.snapshot.text();
let new_text = prediction
.buffer
.update(cx, |buffer, cx| {
let branch = buffer.branch(cx);
branch.update(cx, |branch, cx| {
branch.edit(prediction.edits.iter().cloned(), None, cx);
branch.text()
})
})
.unwrap();
language::unified_diff(&old_text, &new_text)
})
.and_then(|prediction| prediction.edit_preview.as_unified_diff(&prediction.edits))
.unwrap_or_default();
anyhow::Ok(result)

View File

@@ -32,7 +32,7 @@ use std::{
time::Duration,
};
use util::paths::PathStyle;
use zeta2::ContextMode;
use zeta::ContextMode;
use crate::headless::ZetaCliAppState;
use crate::source_location::SourceLocation;
@@ -44,7 +44,7 @@ pub async fn retrieval_stats(
only_extension: Option<String>,
file_limit: Option<usize>,
skip_files: Option<usize>,
options: zeta2::ZetaOptions,
options: zeta::ZetaOptions,
cx: &mut AsyncApp,
) -> Result<String> {
let ContextMode::Syntax(context_options) = options.context.clone() else {