Improve edit prediction example capture (#45536)

This PR improves the `edit prediction: Capture Example` in several ways:
* fixed bugs in how the uncommitted diff was calculated
* added a `edit_predictions.examples_dir` setting that can be set in
order to have the action automatically save examples into the given
folder
* moved the action into the `edit_predictions` crate, in preparation for
collecting this data passively from end users, when they have opted in
to data sharing, similar to what we did for Zeta 1

Release Notes:

- N/A
This commit is contained in:
Max Brunsfeld
2025-12-22 12:40:02 -08:00
committed by GitHub
parent dd521a96fb
commit 07ada58466
19 changed files with 771 additions and 216 deletions

12
Cargo.lock generated
View File

@@ -5212,6 +5212,7 @@ dependencies = [
"anyhow",
"arrayvec",
"brotli",
"buffer_diff",
"client",
"clock",
"cloud_api_types",
@@ -5249,7 +5250,9 @@ dependencies = [
"strum 0.27.2",
"telemetry",
"telemetry_events",
"text",
"thiserror 2.0.17",
"time",
"ui",
"util",
"uuid",
@@ -5354,8 +5357,10 @@ dependencies = [
"anyhow",
"buffer_diff",
"client",
"clock",
"cloud_llm_client",
"codestral",
"collections",
"command_palette_hooks",
"copilot",
"edit_prediction",
@@ -5364,18 +5369,20 @@ dependencies = [
"feature_flags",
"fs",
"futures 0.3.31",
"git",
"gpui",
"indoc",
"language",
"log",
"language_model",
"lsp",
"markdown",
"menu",
"multi_buffer",
"paths",
"pretty_assertions",
"project",
"regex",
"release_channel",
"semver",
"serde_json",
"settings",
"supermaven",
@@ -5388,6 +5395,7 @@ dependencies = [
"workspace",
"zed_actions",
"zeta_prompt",
"zlog",
]
[[package]]

View File

@@ -314,6 +314,12 @@ impl BufferDiffSnapshot {
self.inner.hunks.is_empty()
}
pub fn base_text_string(&self) -> Option<String> {
self.inner
.base_text_exists
.then(|| self.inner.base_text.text())
}
pub fn secondary_diff(&self) -> Option<&BufferDiffSnapshot> {
self.secondary_diff.as_deref()
}

View File

@@ -19,6 +19,7 @@ ai_onboarding.workspace = true
anyhow.workspace = true
arrayvec.workspace = true
brotli.workspace = true
buffer_diff.workspace = true
client.workspace = true
cloud_llm_client.workspace = true
collections.workspace = true
@@ -52,7 +53,9 @@ settings.workspace = true
strum.workspace = true
telemetry.workspace = true
telemetry_events.workspace = true
text.workspace = true
thiserror.workspace = true
time.workspace = true
ui.workspace = true
util.workspace = true
uuid.workspace = true

View File

@@ -0,0 +1,375 @@
use crate::{
EditPredictionStore, StoredEvent,
cursor_excerpt::editable_and_context_ranges_for_cursor_position, example_spec::ExampleSpec,
};
use anyhow::Result;
use buffer_diff::BufferDiffSnapshot;
use collections::HashMap;
use gpui::{App, Entity, Task};
use language::{Buffer, ToPoint as _};
use project::Project;
use std::{collections::hash_map, fmt::Write as _, path::Path, sync::Arc};
use text::{BufferSnapshot as TextBufferSnapshot, ToOffset as _};
pub fn capture_example(
project: Entity<Project>,
buffer: Entity<Buffer>,
cursor_anchor: language::Anchor,
last_event_is_expected_patch: bool,
cx: &mut App,
) -> Option<Task<Result<ExampleSpec>>> {
let ep_store = EditPredictionStore::try_global(cx)?;
let snapshot = buffer.read(cx).snapshot();
let file = snapshot.file()?;
let worktree_id = file.worktree_id(cx);
let repository = project.read(cx).active_repository(cx)?;
let repository_snapshot = repository.read(cx).snapshot();
let worktree = project.read(cx).worktree_for_id(worktree_id, cx)?;
let cursor_path = worktree.read(cx).root_name().join(file.path());
if worktree.read(cx).abs_path() != repository_snapshot.work_directory_abs_path {
return None;
}
let repository_url = repository_snapshot
.remote_origin_url
.clone()
.or_else(|| repository_snapshot.remote_upstream_url.clone())?;
let revision = repository_snapshot.head_commit.as_ref()?.sha.to_string();
let mut events = ep_store.update(cx, |store, cx| {
store.edit_history_for_project_with_pause_split_last_event(&project, cx)
});
let git_store = project.read(cx).git_store().clone();
Some(cx.spawn(async move |mut cx| {
let snapshots_by_path = collect_snapshots(&project, &git_store, &events, &mut cx).await?;
let cursor_excerpt = cx
.background_executor()
.spawn(async move { compute_cursor_excerpt(&snapshot, cursor_anchor) })
.await;
let uncommitted_diff = cx
.background_executor()
.spawn(async move { compute_uncommitted_diff(snapshots_by_path) })
.await;
let mut edit_history = String::new();
let mut expected_patch = String::new();
if last_event_is_expected_patch {
if let Some(stored_event) = events.pop() {
zeta_prompt::write_event(&mut expected_patch, &stored_event.event);
}
}
for stored_event in &events {
zeta_prompt::write_event(&mut edit_history, &stored_event.event);
if !edit_history.ends_with('\n') {
edit_history.push('\n');
}
}
let name = generate_timestamp_name();
Ok(ExampleSpec {
name,
repository_url,
revision,
uncommitted_diff,
cursor_path: cursor_path.as_std_path().into(),
cursor_position: cursor_excerpt,
edit_history,
expected_patch,
})
}))
}
fn compute_cursor_excerpt(
snapshot: &language::BufferSnapshot,
cursor_anchor: language::Anchor,
) -> String {
let cursor_point = cursor_anchor.to_point(snapshot);
let (_editable_range, context_range) =
editable_and_context_ranges_for_cursor_position(cursor_point, snapshot, 100, 50);
let context_start_offset = context_range.start.to_offset(snapshot);
let cursor_offset = cursor_anchor.to_offset(snapshot);
let cursor_offset_in_excerpt = cursor_offset.saturating_sub(context_start_offset);
let mut excerpt = snapshot.text_for_range(context_range).collect::<String>();
if cursor_offset_in_excerpt <= excerpt.len() {
excerpt.insert_str(cursor_offset_in_excerpt, zeta_prompt::CURSOR_MARKER);
}
excerpt
}
async fn collect_snapshots(
project: &Entity<Project>,
git_store: &Entity<project::git_store::GitStore>,
events: &[StoredEvent],
cx: &mut gpui::AsyncApp,
) -> Result<HashMap<Arc<Path>, (TextBufferSnapshot, BufferDiffSnapshot)>> {
let mut snapshots_by_path = HashMap::default();
for stored_event in events {
let zeta_prompt::Event::BufferChange { path, .. } = stored_event.event.as_ref();
if let Some((project_path, full_path)) = project.read_with(cx, |project, cx| {
let project_path = project.find_project_path(path, cx)?;
let full_path = project
.worktree_for_id(project_path.worktree_id, cx)?
.read(cx)
.root_name()
.join(&project_path.path)
.as_std_path()
.into();
Some((project_path, full_path))
})? {
if let hash_map::Entry::Vacant(entry) = snapshots_by_path.entry(full_path) {
let buffer = project
.update(cx, |project, cx| {
project.open_buffer(project_path.clone(), cx)
})?
.await?;
let diff = git_store
.update(cx, |git_store, cx| {
git_store.open_uncommitted_diff(buffer.clone(), cx)
})?
.await?;
let diff_snapshot = diff.update(cx, |diff, cx| diff.snapshot(cx))?;
entry.insert((stored_event.old_snapshot.clone(), diff_snapshot));
}
}
}
Ok(snapshots_by_path)
}
fn compute_uncommitted_diff(
snapshots_by_path: HashMap<Arc<Path>, (TextBufferSnapshot, BufferDiffSnapshot)>,
) -> String {
let mut uncommitted_diff = String::new();
for (full_path, (before_text, diff_snapshot)) in snapshots_by_path {
if let Some(head_text) = &diff_snapshot.base_text_string() {
let file_diff = language::unified_diff(head_text, &before_text.text());
if !file_diff.is_empty() {
let path_str = full_path.to_string_lossy();
writeln!(uncommitted_diff, "--- a/{path_str}").ok();
writeln!(uncommitted_diff, "+++ b/{path_str}").ok();
uncommitted_diff.push_str(&file_diff);
if !uncommitted_diff.ends_with('\n') {
uncommitted_diff.push('\n');
}
}
}
}
uncommitted_diff
}
fn generate_timestamp_name() -> String {
let format = time::format_description::parse("[year]-[month]-[day] [hour]:[minute]:[second]");
match format {
Ok(format) => {
let now = time::OffsetDateTime::now_local()
.unwrap_or_else(|_| time::OffsetDateTime::now_utc());
now.format(&format)
.unwrap_or_else(|_| "unknown-time".to_string())
}
Err(_) => "unknown-time".to_string(),
}
}
#[cfg(test)]
mod tests {
use super::*;
use client::{Client, UserStore};
use clock::FakeSystemClock;
use gpui::{AppContext as _, TestAppContext, http_client::FakeHttpClient};
use indoc::indoc;
use language::{Anchor, Point};
use project::{FakeFs, Project};
use serde_json::json;
use settings::SettingsStore;
use std::path::Path;
#[gpui::test]
async fn test_capture_example(cx: &mut TestAppContext) {
init_test(cx);
let fs = FakeFs::new(cx.executor());
let committed_contents = indoc! {"
fn main() {
one();
two();
three();
four();
five();
six();
seven();
eight();
nine();
}
"};
let disk_contents = indoc! {"
fn main() {
// comment 1
one();
two();
three();
four();
five();
six();
seven();
eight();
// comment 2
nine();
}
"};
fs.insert_tree(
"/project",
json!({
".git": {},
"src": {
"main.rs": disk_contents,
}
}),
)
.await;
fs.set_head_for_repo(
Path::new("/project/.git"),
&[("src/main.rs", committed_contents.to_string())],
"abc123def456",
);
fs.set_remote_for_repo(
Path::new("/project/.git"),
"origin",
"https://github.com/test/repo.git",
);
let project = Project::test(fs.clone(), ["/project".as_ref()], cx).await;
let buffer = project
.update(cx, |project, cx| {
project.open_local_buffer("/project/src/main.rs", cx)
})
.await
.unwrap();
let ep_store = cx.read(|cx| EditPredictionStore::try_global(cx).unwrap());
ep_store.update(cx, |ep_store, cx| {
ep_store.register_buffer(&buffer, &project, cx)
});
cx.run_until_parked();
buffer.update(cx, |buffer, cx| {
let point = Point::new(6, 0);
buffer.edit([(point..point, " // comment 3\n")], None, cx);
let point = Point::new(4, 0);
buffer.edit([(point..point, " // comment 4\n")], None, cx);
pretty_assertions::assert_eq!(
buffer.text(),
indoc! {"
fn main() {
// comment 1
one();
two();
// comment 4
three();
four();
// comment 3
five();
six();
seven();
eight();
// comment 2
nine();
}
"}
);
});
cx.run_until_parked();
let mut example = cx
.update(|cx| {
capture_example(project.clone(), buffer.clone(), Anchor::MIN, false, cx).unwrap()
})
.await
.unwrap();
example.name = "test".to_string();
pretty_assertions::assert_eq!(
example,
ExampleSpec {
name: "test".to_string(),
repository_url: "https://github.com/test/repo.git".to_string(),
revision: "abc123def456".to_string(),
uncommitted_diff: indoc! {"
--- a/project/src/main.rs
+++ b/project/src/main.rs
@@ -1,4 +1,5 @@
fn main() {
+ // comment 1
one();
two();
three();
@@ -7,5 +8,6 @@
six();
seven();
eight();
+ // comment 2
nine();
}
"}
.to_string(),
cursor_path: Path::new("project/src/main.rs").into(),
cursor_position: indoc! {"
<|user_cursor|>fn main() {
// comment 1
one();
two();
// comment 4
three();
four();
// comment 3
five();
six();
seven();
eight();
// comment 2
nine();
}
"}
.to_string(),
edit_history: indoc! {"
--- a/project/src/main.rs
+++ b/project/src/main.rs
@@ -2,8 +2,10 @@
// comment 1
one();
two();
+ // comment 4
three();
four();
+ // comment 3
five();
six();
seven();
"}
.to_string(),
expected_patch: "".to_string(),
}
);
}
fn init_test(cx: &mut TestAppContext) {
cx.update(|cx| {
let settings_store = SettingsStore::test(cx);
cx.set_global(settings_store);
zlog::init_test();
let http_client = FakeHttpClient::with_404_response();
let client = Client::new(Arc::new(FakeSystemClock::new()), http_client, cx);
language_model::init(client.clone(), cx);
let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
EditPredictionStore::global(&client, &user_store, cx);
})
}
}

View File

@@ -35,6 +35,7 @@ use semver::Version;
use serde::de::DeserializeOwned;
use settings::{EditPredictionProvider, SettingsStore, update_settings_file};
use std::collections::{VecDeque, hash_map};
use text::Edit;
use workspace::Workspace;
use std::ops::Range;
@@ -57,9 +58,9 @@ pub mod open_ai_response;
mod prediction;
pub mod sweep_ai;
#[cfg(any(test, feature = "test-support", feature = "cli-support"))]
pub mod udiff;
mod capture_example;
mod zed_edit_prediction_delegate;
pub mod zeta1;
pub mod zeta2;
@@ -74,6 +75,7 @@ pub use crate::prediction::EditPrediction;
pub use crate::prediction::EditPredictionId;
use crate::prediction::EditPredictionResult;
pub use crate::sweep_ai::SweepAi;
pub use capture_example::capture_example;
pub use language_model::ApiKeyState;
pub use telemetry_events::EditPredictionRating;
pub use zed_edit_prediction_delegate::ZedEditPredictionDelegate;
@@ -231,8 +233,15 @@ pub struct EditPredictionFinishedDebugEvent {
pub type RequestDebugInfo = predict_edits_v3::DebugInfo;
/// An event with associated metadata for reconstructing buffer state.
#[derive(Clone)]
pub struct StoredEvent {
pub event: Arc<zeta_prompt::Event>,
pub old_snapshot: TextBufferSnapshot,
}
struct ProjectState {
events: VecDeque<Arc<zeta_prompt::Event>>,
events: VecDeque<StoredEvent>,
last_event: Option<LastEvent>,
recent_paths: VecDeque<ProjectPath>,
registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
@@ -248,7 +257,7 @@ struct ProjectState {
}
impl ProjectState {
pub fn events(&self, cx: &App) -> Vec<Arc<zeta_prompt::Event>> {
pub fn events(&self, cx: &App) -> Vec<StoredEvent> {
self.events
.iter()
.cloned()
@@ -260,7 +269,7 @@ impl ProjectState {
.collect()
}
pub fn events_split_by_pause(&self, cx: &App) -> Vec<Arc<zeta_prompt::Event>> {
pub fn events_split_by_pause(&self, cx: &App) -> Vec<StoredEvent> {
self.events
.iter()
.cloned()
@@ -415,7 +424,7 @@ impl LastEvent {
&self,
license_detection_watchers: &HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
cx: &App,
) -> Option<Arc<zeta_prompt::Event>> {
) -> Option<StoredEvent> {
let path = buffer_path_with_id_fallback(self.new_file.as_ref(), &self.new_snapshot, cx);
let old_path = buffer_path_with_id_fallback(self.old_file.as_ref(), &self.old_snapshot, cx);
@@ -430,19 +439,22 @@ impl LastEvent {
})
});
let diff = language::unified_diff(&self.old_snapshot.text(), &self.new_snapshot.text());
let diff = compute_diff_between_snapshots(&self.old_snapshot, &self.new_snapshot)?;
if path == old_path && diff.is_empty() {
None
} else {
Some(Arc::new(zeta_prompt::Event::BufferChange {
old_path,
path,
diff,
in_open_source_repo,
// TODO: Actually detect if this edit was predicted or not
predicted: false,
}))
Some(StoredEvent {
event: Arc::new(zeta_prompt::Event::BufferChange {
old_path,
path,
diff,
in_open_source_repo,
// TODO: Actually detect if this edit was predicted or not
predicted: false,
}),
old_snapshot: self.old_snapshot.clone(),
})
}
}
@@ -475,6 +487,52 @@ impl LastEvent {
}
}
pub(crate) fn compute_diff_between_snapshots(
old_snapshot: &TextBufferSnapshot,
new_snapshot: &TextBufferSnapshot,
) -> Option<String> {
let edits: Vec<Edit<usize>> = new_snapshot
.edits_since::<usize>(&old_snapshot.version)
.collect();
let (first_edit, last_edit) = edits.first().zip(edits.last())?;
let old_start_point = old_snapshot.offset_to_point(first_edit.old.start);
let old_end_point = old_snapshot.offset_to_point(last_edit.old.end);
let new_start_point = new_snapshot.offset_to_point(first_edit.new.start);
let new_end_point = new_snapshot.offset_to_point(last_edit.new.end);
const CONTEXT_LINES: u32 = 3;
let old_context_start_row = old_start_point.row.saturating_sub(CONTEXT_LINES);
let new_context_start_row = new_start_point.row.saturating_sub(CONTEXT_LINES);
let old_context_end_row =
(old_end_point.row + 1 + CONTEXT_LINES).min(old_snapshot.max_point().row);
let new_context_end_row =
(new_end_point.row + 1 + CONTEXT_LINES).min(new_snapshot.max_point().row);
let old_start_line_offset = old_snapshot.point_to_offset(Point::new(old_context_start_row, 0));
let new_start_line_offset = new_snapshot.point_to_offset(Point::new(new_context_start_row, 0));
let old_end_line_offset = old_snapshot
.point_to_offset(Point::new(old_context_end_row + 1, 0).min(old_snapshot.max_point()));
let new_end_line_offset = new_snapshot
.point_to_offset(Point::new(new_context_end_row + 1, 0).min(new_snapshot.max_point()));
let old_edit_range = old_start_line_offset..old_end_line_offset;
let new_edit_range = new_start_line_offset..new_end_line_offset;
let old_region_text: String = old_snapshot.text_for_range(old_edit_range).collect();
let new_region_text: String = new_snapshot.text_for_range(new_edit_range).collect();
let diff = language::unified_diff_with_offsets(
&old_region_text,
&new_region_text,
old_context_start_row,
new_context_start_row,
);
Some(diff)
}
fn buffer_path_with_id_fallback(
file: Option<&Arc<dyn File>>,
snapshot: &TextBufferSnapshot,
@@ -643,7 +701,7 @@ impl EditPredictionStore {
&self,
project: &Entity<Project>,
cx: &App,
) -> Vec<Arc<zeta_prompt::Event>> {
) -> Vec<StoredEvent> {
self.projects
.get(&project.entity_id())
.map(|project_state| project_state.events(cx))
@@ -654,7 +712,7 @@ impl EditPredictionStore {
&self,
project: &Entity<Project>,
cx: &App,
) -> Vec<Arc<zeta_prompt::Event>> {
) -> Vec<StoredEvent> {
self.projects
.get(&project.entity_id())
.map(|project_state| project_state.events_split_by_pause(cx))
@@ -1536,8 +1594,10 @@ impl EditPredictionStore {
self.get_or_init_project(&project, cx);
let project_state = self.projects.get(&project.entity_id()).unwrap();
let events = project_state.events(cx);
let has_events = !events.is_empty();
let stored_events = project_state.events(cx);
let has_events = !stored_events.is_empty();
let events: Vec<Arc<zeta_prompt::Event>> =
stored_events.into_iter().map(|e| e.event).collect();
let debug_tx = project_state.debug_tx.clone();
let snapshot = active_buffer.read(cx).snapshot();

View File

@@ -1,5 +1,5 @@
use super::*;
use crate::{udiff::apply_diff_to_string, zeta1::MAX_EVENT_TOKENS};
use crate::{compute_diff_between_snapshots, udiff::apply_diff_to_string, zeta1::MAX_EVENT_TOKENS};
use client::{UserStore, test::FakeServer};
use clock::{FakeSystemClock, ReplicaId};
use cloud_api_types::{CreateLlmTokenResponse, LlmToken};
@@ -360,7 +360,7 @@ async fn test_edit_history_getter_pause_splits_last_event(cx: &mut TestAppContex
ep_store.edit_history_for_project(&project, cx)
});
assert_eq!(events.len(), 1);
let zeta_prompt::Event::BufferChange { diff, .. } = events[0].as_ref();
let zeta_prompt::Event::BufferChange { diff, .. } = events[0].event.as_ref();
assert_eq!(
diff.as_str(),
indoc! {"
@@ -377,7 +377,7 @@ async fn test_edit_history_getter_pause_splits_last_event(cx: &mut TestAppContex
ep_store.edit_history_for_project_with_pause_split_last_event(&project, cx)
});
assert_eq!(events.len(), 2);
let zeta_prompt::Event::BufferChange { diff, .. } = events[0].as_ref();
let zeta_prompt::Event::BufferChange { diff, .. } = events[0].event.as_ref();
assert_eq!(
diff.as_str(),
indoc! {"
@@ -389,7 +389,7 @@ async fn test_edit_history_getter_pause_splits_last_event(cx: &mut TestAppContex
"}
);
let zeta_prompt::Event::BufferChange { diff, .. } = events[1].as_ref();
let zeta_prompt::Event::BufferChange { diff, .. } = events[1].event.as_ref();
assert_eq!(
diff.as_str(),
indoc! {"
@@ -2082,6 +2082,74 @@ async fn test_unauthenticated_with_custom_url_allows_prediction_impl(cx: &mut Te
);
}
#[gpui::test]
fn test_compute_diff_between_snapshots(cx: &mut TestAppContext) {
let buffer = cx.new(|cx| {
Buffer::local(
indoc! {"
zero
one
two
three
four
five
six
seven
eight
nine
ten
eleven
twelve
thirteen
fourteen
fifteen
sixteen
seventeen
eighteen
nineteen
twenty
twenty-one
twenty-two
twenty-three
twenty-four
"},
cx,
)
});
let old_snapshot = buffer.read_with(cx, |buffer, _| buffer.text_snapshot());
buffer.update(cx, |buffer, cx| {
let point = Point::new(12, 0);
buffer.edit([(point..point, "SECOND INSERTION\n")], None, cx);
let point = Point::new(8, 0);
buffer.edit([(point..point, "FIRST INSERTION\n")], None, cx);
});
let new_snapshot = buffer.read_with(cx, |buffer, _| buffer.text_snapshot());
let diff = compute_diff_between_snapshots(&old_snapshot, &new_snapshot).unwrap();
assert_eq!(
diff,
indoc! {"
@@ -6,10 +6,12 @@
five
six
seven
+FIRST INSERTION
eight
nine
ten
eleven
+SECOND INSERTION
twelve
thirteen
fourteen
"}
);
}
#[ctor::ctor]
fn init_logger() {
zlog::init_test();

View File

@@ -1,7 +1,7 @@
use serde::{Deserialize, Serialize};
use std::{fmt::Write as _, mem, path::Path, sync::Arc};
#[derive(Clone, Debug, Serialize, Deserialize)]
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct ExampleSpec {
#[serde(default)]
pub name: String,

View File

@@ -45,6 +45,11 @@ pub async fn run_format_prompt(
let snapshot = state.buffer.read_with(&cx, |buffer, _| buffer.snapshot())?;
let project = state.project.clone();
let (_, input) = ep_store.update(&mut cx, |ep_store, cx| {
let events = ep_store
.edit_history_for_project(&project, cx)
.into_iter()
.map(|e| e.event)
.collect();
anyhow::Ok(zeta2_prompt_input(
&snapshot,
example
@@ -53,7 +58,7 @@ pub async fn run_format_prompt(
.context("context must be set")?
.files
.clone(),
ep_store.edit_history_for_project(&project, cx),
events,
example.spec.cursor_path.clone(),
example
.buffer

View File

@@ -15,8 +15,7 @@ doctest = false
[dependencies]
anyhow.workspace = true
buffer_diff.workspace = true
git.workspace = true
log.workspace = true
collections.workspace = true
time.workspace = true
client.workspace = true
cloud_llm_client.workspace = true
@@ -50,11 +49,18 @@ zed_actions.workspace = true
zeta_prompt.workspace = true
[dev-dependencies]
clock.workspace = true
copilot = { workspace = true, features = ["test-support"] }
editor = { workspace = true, features = ["test-support"] }
futures.workspace = true
indoc.workspace = true
language_model.workspace = true
lsp = { workspace = true, features = ["test-support"] }
pretty_assertions.workspace = true
project = { workspace = true, features = ["test-support"] }
release_channel.workspace = true
semver.workspace = true
serde_json.workspace = true
theme = { workspace = true, features = ["test-support"] }
workspace = { workspace = true, features = ["test-support"] }
zlog.workspace = true

View File

@@ -915,11 +915,8 @@ impl EditPredictionButton {
.when(
cx.has_flag::<PredictEditsRatePredictionsFeatureFlag>(),
|this| {
this.action(
"Capture Edit Prediction Example",
CaptureExample.boxed_clone(),
)
.action("Rate Predictions", RatePredictions.boxed_clone())
this.action("Capture Prediction Example", CaptureExample.boxed_clone())
.action("Rate Predictions", RatePredictions.boxed_clone())
},
);
}

View File

@@ -2,25 +2,17 @@ mod edit_prediction_button;
mod edit_prediction_context_view;
mod rate_prediction_modal;
use std::any::{Any as _, TypeId};
use std::path::Path;
use std::sync::Arc;
use command_palette_hooks::CommandPaletteFilter;
use edit_prediction::{
EditPredictionStore, ResetOnboarding, Zeta2FeatureFlag, example_spec::ExampleSpec,
};
use edit_prediction::{ResetOnboarding, Zeta2FeatureFlag, capture_example};
use edit_prediction_context_view::EditPredictionContextView;
use editor::Editor;
use feature_flags::FeatureFlagAppExt as _;
use git::repository::DiffType;
use gpui::{Window, actions};
use language::ToPoint as _;
use log;
use gpui::actions;
use language::language_settings::AllLanguageSettings;
use project::DisableAiSettings;
use rate_prediction_modal::RatePredictionsModal;
use settings::{Settings as _, SettingsStore};
use text::ToOffset as _;
use std::any::{Any as _, TypeId};
use ui::{App, prelude::*};
use workspace::{SplitDirection, Workspace};
@@ -56,7 +48,9 @@ pub fn init(cx: &mut App) {
}
});
workspace.register_action(capture_edit_prediction_example);
workspace.register_action(|workspace, _: &CaptureExample, window, cx| {
capture_example_as_markdown(workspace, window, cx);
});
workspace.register_action_renderer(|div, _, _, cx| {
let has_flag = cx.has_flag::<Zeta2FeatureFlag>();
div.when(has_flag, |div| {
@@ -138,182 +132,48 @@ fn feature_gate_predict_edits_actions(cx: &mut App) {
.detach();
}
fn capture_edit_prediction_example(
fn capture_example_as_markdown(
workspace: &mut Workspace,
_: &CaptureExample,
window: &mut Window,
cx: &mut Context<Workspace>,
) {
let Some(ep_store) = EditPredictionStore::try_global(cx) else {
return;
};
let project = workspace.project().clone();
let (worktree_root, repository) = {
let project_ref = project.read(cx);
let worktree_root = project_ref
.visible_worktrees(cx)
.next()
.map(|worktree| worktree.read(cx).abs_path());
let repository = project_ref.active_repository(cx);
(worktree_root, repository)
};
let (Some(worktree_root), Some(repository)) = (worktree_root, repository) else {
log::error!("CaptureExampleSpec: missing worktree or active repository");
return;
};
let repository_snapshot = repository.read(cx).snapshot();
if worktree_root.as_ref() != repository_snapshot.work_directory_abs_path.as_ref() {
log::error!(
"repository is not at worktree root (repo={:?}, worktree={:?})",
repository_snapshot.work_directory_abs_path,
worktree_root
);
return;
}
let Some(repository_url) = repository_snapshot
.remote_origin_url
.clone()
.or_else(|| repository_snapshot.remote_upstream_url.clone())
else {
log::error!("active repository has no origin/upstream remote url");
return;
};
let Some(revision) = repository_snapshot
.head_commit
.as_ref()
.map(|commit| commit.sha.to_string())
else {
log::error!("active repository has no head commit");
return;
};
let mut events = ep_store.update(cx, |store, cx| {
store.edit_history_for_project_with_pause_split_last_event(&project, cx)
});
let Some(editor) = workspace.active_item_as::<Editor>(cx) else {
log::error!("no active editor");
return;
};
let Some(project_path) = editor.read(cx).project_path(cx) else {
log::error!("active editor has no project path");
return;
};
let Some((buffer, cursor_anchor)) = editor
.read(cx)
.buffer()
.read(cx)
.text_anchor_for_position(editor.read(cx).selections.newest_anchor().head(), cx)
else {
log::error!("failed to resolve cursor buffer/anchor");
return;
};
let snapshot = buffer.read(cx).snapshot();
let cursor_point = cursor_anchor.to_point(&snapshot);
let (_editable_range, context_range) =
edit_prediction::cursor_excerpt::editable_and_context_ranges_for_cursor_position(
cursor_point,
&snapshot,
100,
50,
);
let cursor_path: Arc<Path> = repository
.read(cx)
.project_path_to_repo_path(&project_path, cx)
.map(|repo_path| Path::new(repo_path.as_unix_str()).into())
.unwrap_or_else(|| Path::new(project_path.path.as_unix_str()).into());
let cursor_position = {
let context_start_offset = context_range.start.to_offset(&snapshot);
let cursor_offset = cursor_anchor.to_offset(&snapshot);
let cursor_offset_in_excerpt = cursor_offset.saturating_sub(context_start_offset);
let mut excerpt = snapshot.text_for_range(context_range).collect::<String>();
if cursor_offset_in_excerpt <= excerpt.len() {
excerpt.insert_str(cursor_offset_in_excerpt, zeta_prompt::CURSOR_MARKER);
}
excerpt
};
) -> Option<()> {
let markdown_language = workspace
.app_state()
.languages
.language_for_name("Markdown");
let fs = workspace.app_state().fs.clone();
let project = workspace.project().clone();
let editor = workspace.active_item_as::<Editor>(cx)?;
let editor = editor.read(cx);
let (buffer, cursor_anchor) = editor
.buffer()
.read(cx)
.text_anchor_for_position(editor.selections.newest_anchor().head(), cx)?;
let example = capture_example(project.clone(), buffer, cursor_anchor, true, cx)?;
let examples_dir = AllLanguageSettings::get_global(cx)
.edit_predictions
.examples_dir
.clone();
cx.spawn_in(window, async move |workspace_entity, cx| {
let markdown_language = markdown_language.await?;
let example_spec = example.await?;
let buffer = if let Some(dir) = examples_dir {
fs.create_dir(&dir).await.ok();
let mut path = dir.join(&example_spec.name.replace(' ', "--").replace(':', "-"));
path.set_extension("md");
project.update(cx, |project, cx| project.open_local_buffer(&path, cx))
} else {
project.update(cx, |project, cx| project.create_buffer(false, cx))
}?
.await?;
let uncommitted_diff_rx = repository.update(cx, |repository, cx| {
repository.diff(DiffType::HeadToWorktree, cx)
})?;
let uncommitted_diff = match uncommitted_diff_rx.await {
Ok(Ok(diff)) => diff,
Ok(Err(error)) => {
log::error!("failed to compute uncommitted diff: {error:#}");
return Ok(());
}
Err(error) => {
log::error!("uncommitted diff channel dropped: {error:#}");
return Ok(());
}
};
let mut edit_history = String::new();
let mut expected_patch = String::new();
if let Some(last_event) = events.pop() {
for event in &events {
zeta_prompt::write_event(&mut edit_history, event);
if !edit_history.ends_with('\n') {
edit_history.push('\n');
}
edit_history.push('\n');
}
zeta_prompt::write_event(&mut expected_patch, &last_event);
}
let format =
time::format_description::parse("[year]-[month]-[day] [hour]:[minute]:[second]");
let name = match format {
Ok(format) => {
let now = time::OffsetDateTime::now_local()
.unwrap_or_else(|_| time::OffsetDateTime::now_utc());
now.format(&format)
.unwrap_or_else(|_| "unknown-time".to_string())
}
Err(_) => "unknown-time".to_string(),
};
let markdown = ExampleSpec {
name,
repository_url,
revision,
uncommitted_diff,
cursor_path,
cursor_position,
edit_history,
expected_patch,
}
.to_markdown();
let buffer = project
.update(cx, |project, cx| project.create_buffer(false, cx))?
.await?;
buffer.update(cx, |buffer, cx| {
buffer.set_text(markdown, cx);
buffer.set_text(example_spec.to_markdown(), cx);
buffer.set_language(Some(markdown_language), cx);
})?;
workspace_entity.update_in(cx, |workspace, window, cx| {
workspace.add_item_to_active_pane(
Box::new(
@@ -327,4 +187,5 @@ fn capture_edit_prediction_example(
})
})
.detach_and_log_err(cx);
None
}

View File

@@ -156,8 +156,16 @@ impl GitRepository for FakeGitRepository {
})
}
fn remote_url(&self, _name: &str) -> BoxFuture<'_, Option<String>> {
async move { None }.boxed()
fn remote_url(&self, name: &str) -> BoxFuture<'_, Option<String>> {
let name = name.to_string();
let fut = self.with_state_async(false, move |state| {
state
.remotes
.get(&name)
.context("remote not found")
.cloned()
});
async move { fut.await.ok() }.boxed()
}
fn diff_tree(&self, _request: DiffTreeType) -> BoxFuture<'_, Result<TreeDiff>> {

View File

@@ -1857,6 +1857,18 @@ impl FakeFs {
.unwrap();
}
pub fn set_remote_for_repo(
&self,
dot_git: &Path,
name: impl Into<String>,
url: impl Into<String>,
) {
self.with_git_state(dot_git, true, |state| {
state.remotes.insert(name.into(), url.into());
})
.unwrap();
}
pub fn insert_branches(&self, dot_git: &Path, branches: &[&str]) {
self.with_git_state(dot_git, true, |state| {
if let Some(first) = branches.first()

View File

@@ -67,7 +67,7 @@ use task::RunnableTag;
pub use task_context::{ContextLocation, ContextProvider, RunnableRange};
pub use text_diff::{
DiffOptions, apply_diff_patch, line_diff, text_diff, text_diff_with_options, unified_diff,
word_diff_ranges,
unified_diff_with_offsets, word_diff_ranges,
};
use theme::SyntaxTheme;
pub use toolchain::{

View File

@@ -392,6 +392,7 @@ pub struct EditPredictionSettings {
/// Whether edit predictions are enabled in the assistant panel.
/// This setting has no effect if globally disabled.
pub enabled_in_text_threads: bool,
pub examples_dir: Option<Arc<Path>>,
}
impl EditPredictionSettings {
@@ -699,6 +700,7 @@ impl settings::Settings for AllLanguageSettings {
copilot: copilot_settings,
codestral: codestral_settings,
enabled_in_text_threads,
examples_dir: edit_predictions.examples_dir,
},
defaults: default_language_settings,
languages,

View File

@@ -1,25 +1,139 @@
use crate::{CharClassifier, CharKind, CharScopeContext, LanguageScope};
use anyhow::{Context, anyhow};
use imara_diff::{
Algorithm, UnifiedDiffBuilder, diff,
intern::{InternedInput, Token},
Algorithm, Sink, diff,
intern::{InternedInput, Interner, Token},
sources::lines_with_terminator,
};
use std::{iter, ops::Range, sync::Arc};
use std::{fmt::Write, iter, ops::Range, sync::Arc};
const MAX_WORD_DIFF_LEN: usize = 512;
const MAX_WORD_DIFF_LINE_COUNT: usize = 8;
/// Computes a diff between two strings, returning a unified diff string.
pub fn unified_diff(old_text: &str, new_text: &str) -> String {
unified_diff_with_offsets(old_text, new_text, 0, 0)
}
/// Computes a diff between two strings, returning a unified diff string with
/// hunk headers adjusted to reflect the given starting line numbers (1-indexed).
pub fn unified_diff_with_offsets(
old_text: &str,
new_text: &str,
old_start_line: u32,
new_start_line: u32,
) -> String {
let input = InternedInput::new(old_text, new_text);
diff(
Algorithm::Histogram,
&input,
UnifiedDiffBuilder::new(&input),
OffsetUnifiedDiffBuilder::new(&input, old_start_line, new_start_line),
)
}
/// A unified diff builder that applies line number offsets to hunk headers.
struct OffsetUnifiedDiffBuilder<'a> {
before: &'a [Token],
after: &'a [Token],
interner: &'a Interner<&'a str>,
pos: u32,
before_hunk_start: u32,
after_hunk_start: u32,
before_hunk_len: u32,
after_hunk_len: u32,
old_line_offset: u32,
new_line_offset: u32,
buffer: String,
dst: String,
}
impl<'a> OffsetUnifiedDiffBuilder<'a> {
fn new(input: &'a InternedInput<&'a str>, old_line_offset: u32, new_line_offset: u32) -> Self {
Self {
before_hunk_start: 0,
after_hunk_start: 0,
before_hunk_len: 0,
after_hunk_len: 0,
old_line_offset,
new_line_offset,
buffer: String::with_capacity(8),
dst: String::new(),
interner: &input.interner,
before: &input.before,
after: &input.after,
pos: 0,
}
}
fn print_tokens(&mut self, tokens: &[Token], prefix: char) {
for &token in tokens {
writeln!(&mut self.buffer, "{prefix}{}", self.interner[token]).unwrap();
}
}
fn flush(&mut self) {
if self.before_hunk_len == 0 && self.after_hunk_len == 0 {
return;
}
let end = (self.pos + 3).min(self.before.len() as u32);
self.update_pos(end, end);
writeln!(
&mut self.dst,
"@@ -{},{} +{},{} @@",
self.before_hunk_start + 1 + self.old_line_offset,
self.before_hunk_len,
self.after_hunk_start + 1 + self.new_line_offset,
self.after_hunk_len,
)
.unwrap();
write!(&mut self.dst, "{}", &self.buffer).unwrap();
self.buffer.clear();
self.before_hunk_len = 0;
self.after_hunk_len = 0;
}
fn update_pos(&mut self, print_to: u32, move_to: u32) {
self.print_tokens(&self.before[self.pos as usize..print_to as usize], ' ');
let len = print_to - self.pos;
self.pos = move_to;
self.before_hunk_len += len;
self.after_hunk_len += len;
}
}
impl Sink for OffsetUnifiedDiffBuilder<'_> {
type Out = String;
fn process_change(&mut self, before: Range<u32>, after: Range<u32>) {
if before.start - self.pos > 6 {
self.flush();
}
if self.before_hunk_len == 0 && self.after_hunk_len == 0 {
self.pos = before.start.saturating_sub(3);
self.before_hunk_start = self.pos;
self.after_hunk_start = after.start.saturating_sub(3);
}
self.update_pos(before.start, before.end);
self.before_hunk_len += before.end - before.start;
self.after_hunk_len += after.end - after.start;
self.print_tokens(
&self.before[before.start as usize..before.end as usize],
'-',
);
self.print_tokens(&self.after[after.start as usize..after.end as usize], '+');
}
fn finish(mut self) -> Self::Out {
self.flush();
self.dst
}
}
/// Computes a diff between two strings, returning a vector of old and new row
/// ranges.
pub fn line_diff(old_text: &str, new_text: &str) -> Vec<(Range<u32>, Range<u32>)> {
@@ -327,4 +441,30 @@ mod tests {
let patch = unified_diff(old_text, new_text);
assert_eq!(apply_diff_patch(old_text, &patch).unwrap(), new_text);
}
#[test]
fn test_unified_diff_with_offsets() {
let old_text = "foo\nbar\nbaz\n";
let new_text = "foo\nBAR\nbaz\n";
let expected_diff_body = " foo\n-bar\n+BAR\n baz\n";
let diff_no_offset = unified_diff(old_text, new_text);
assert_eq!(
diff_no_offset,
format!("@@ -1,3 +1,3 @@\n{}", expected_diff_body)
);
let diff_with_offset = unified_diff_with_offsets(old_text, new_text, 9, 11);
assert_eq!(
diff_with_offset,
format!("@@ -10,3 +12,3 @@\n{}", expected_diff_body)
);
let diff_with_offset = unified_diff_with_offsets(old_text, new_text, 99, 104);
assert_eq!(
diff_with_offset,
format!("@@ -100,3 +105,3 @@\n{}", expected_diff_body)
);
}
}

View File

@@ -5756,6 +5756,7 @@ impl Repository {
cx.spawn(|_: &mut AsyncApp| async move { rx.await? })
}
fn load_blob_content(&mut self, oid: Oid, cx: &App) -> Task<Result<String>> {
let repository_id = self.snapshot.id;
let rx = self.send_job(None, move |state, _| async move {

View File

@@ -56,6 +56,7 @@ merge_from_overwrites!(
std::sync::Arc<str>,
gpui::SharedString,
std::path::PathBuf,
std::sync::Arc<std::path::Path>,
gpui::Modifiers,
gpui::FontFeatures,
gpui::FontWeight

View File

@@ -1,4 +1,4 @@
use std::num::NonZeroU32;
use std::{num::NonZeroU32, path::Path};
use collections::{HashMap, HashSet};
use gpui::{Modifiers, SharedString};
@@ -167,6 +167,8 @@ pub struct EditPredictionSettingsContent {
/// Whether edit predictions are enabled in the assistant prompt editor.
/// This has no effect if globally disabled.
pub enabled_in_text_threads: Option<bool>,
/// The directory where manually captured edit prediction examples are stored.
pub examples_dir: Option<Arc<Path>>,
}
#[with_fallible_options]