Compare commits

...

6 Commits

Author SHA1 Message Date
Agus Zubiaga
2acd48f439 Checkpoint: Refactoring blink manager 2025-11-25 17:24:27 -03:00
Agus Zubiaga
1f9d5ef684 Always display terminal cursor when blinking is disabled (#43487)
Fixes an issue where the terminal cursor wouldn't always be displayed in
the default `blink: "terminal_controlled"` mode unless the terminal
requested cursor blinking.

Release Notes:

- N/A
2025-11-25 19:49:16 +00:00
Peter Tripp
83f0a3fd13 Redact sensitive environment variables in LSP Logs: Server Info (#43480)
Follow-up to: 
- https://github.com/zed-industries/zed/pull/43436
- https://github.com/zed-industries/zed/pull/42831

The changes in #42831 resulted in a regression where environment
variables in the Server Info view were no longer redacted. The changes in
#43436 were insufficient as I was still seeing sensitive values in
Nightly e6fe95b4f2 (which includes
#43436).

CC: @SomeoneToIgnore (Hi! 👋 Thanks for keeping this redaction
functionality alive)

Release Notes:

- N/A
2025-11-25 21:00:31 +02:00
Ben Kunkle
7ecbf8cf60 zeta2: Remove expected context from evals (#43430)
Closes #ISSUE

Release Notes:

- N/A *or* Added/Fixed/Improved ...
2025-11-25 13:44:04 -05:00
Max Brunsfeld
fb0fcd86fd Add missing update of last_prediction_refresh (#43483)
Fixes a regression introduced in
https://github.com/zed-industries/zed/pull/43284 where edit predictions
stopped being throttled at all 😬

Release Notes:

- N/A

Co-authored-by: Ben Kunkle <ben@zed.dev>
2025-11-25 10:43:46 -08:00
Max Brunsfeld
36708c910a Separate experimental edit prediction jumps feature from the Sweep AI prediction provider (#43481)
Release Notes:

- N/A

---------

Co-authored-by: Ben Kunkle <ben@zed.dev>
2025-11-25 10:36:45 -08:00
13 changed files with 511 additions and 805 deletions

View File

@@ -1,8 +1,8 @@
use gpui::Context;
use gpui::{Context, FocusHandle};
use settings::SettingsStore;
use smol::Timer;
use std::time::Duration;
use ui::App;
use ui::{App, Window};
pub struct BlinkManager {
blink_interval: Duration,
@@ -11,21 +11,34 @@ pub struct BlinkManager {
blinking_paused: bool,
/// Whether the cursor should be visibly rendered or not.
visible: bool,
/// Whether the blinking currently enabled.
enabled: bool,
/// The focus handle to use to determine if the cursor should be blinking.
focus_handle: FocusHandle,
/// Whether the blinking is enabled in the settings.
blink_enabled_in_settings: fn(&App) -> bool,
is_enabled: Box<dyn Fn(&App) -> bool>,
}
impl BlinkManager {
pub fn new(
blink_interval: Duration,
blink_enabled_in_settings: fn(&App) -> bool,
focus_handle: FocusHandle,
is_enabled: impl Fn(&App) -> bool + 'static,
window: &mut Window,
cx: &mut Context<Self>,
) -> Self {
// Make sure we blink the cursors if the setting is re-enabled
cx.observe_global::<SettingsStore>(move |this, cx| {
this.blink_cursors(this.blink_epoch, cx)
cx.observe_global_in::<SettingsStore>(window, move |this, window, cx| {
this.refresh(window, cx);
})
.detach();
cx.on_focus(&focus_handle, window, move |this, window, cx| {
this.visible = false;
this.refresh(window, cx);
})
.detach();
cx.on_blur(&focus_handle, window, move |this, _window, _cx| {
this.visible = false;
})
.detach();
@@ -34,48 +47,64 @@ impl BlinkManager {
blink_epoch: 0,
blinking_paused: false,
visible: true,
enabled: false,
blink_enabled_in_settings,
focus_handle,
is_enabled: Box::new(is_enabled),
}
}
pub fn refresh(&mut self, window: &mut Window, cx: &mut Context<Self>) {
self.blink_cursors(self.blink_epoch, window, cx)
}
fn next_blink_epoch(&mut self) -> usize {
self.blink_epoch += 1;
self.blink_epoch
}
pub fn pause_blinking(&mut self, cx: &mut Context<Self>) {
pub fn pause_blinking(&mut self, window: &mut Window, cx: &mut Context<Self>) {
self.show_cursor(cx);
let epoch = self.next_blink_epoch();
let interval = self.blink_interval;
cx.spawn(async move |this, cx| {
cx.spawn_in(window, async move |this, cx| {
Timer::after(interval).await;
this.update(cx, |this, cx| this.resume_cursor_blinking(epoch, cx))
this.update_in(cx, |this, window, cx| {
this.resume_cursor_blinking(epoch, window, cx)
})
})
.detach();
}
fn resume_cursor_blinking(&mut self, epoch: usize, cx: &mut Context<Self>) {
fn resume_cursor_blinking(
&mut self,
epoch: usize,
window: &mut Window,
cx: &mut Context<Self>,
) {
if epoch == self.blink_epoch {
self.blinking_paused = false;
self.blink_cursors(epoch, cx);
self.blink_cursors(epoch, window, cx);
}
}
fn blink_cursors(&mut self, epoch: usize, cx: &mut Context<Self>) {
if (self.blink_enabled_in_settings)(cx) {
if epoch == self.blink_epoch && self.enabled && !self.blinking_paused {
fn blink_cursors(&mut self, epoch: usize, window: &mut Window, cx: &mut Context<Self>) {
if (self.is_enabled)(cx) {
if epoch == self.blink_epoch
&& self.focus_handle.is_focused(window)
&& !self.blinking_paused
{
self.visible = !self.visible;
cx.notify();
let epoch = self.next_blink_epoch();
let interval = self.blink_interval;
cx.spawn(async move |this, cx| {
cx.spawn_in(window, async move |this, cx| {
Timer::after(interval).await;
if let Some(this) = this.upgrade() {
this.update(cx, |this, cx| this.blink_cursors(epoch, cx))
.ok();
this.update_in(cx, |this, window, cx| {
this.blink_cursors(epoch, window, cx)
})
.ok();
}
})
.detach();
@@ -92,25 +121,6 @@ impl BlinkManager {
}
}
/// Enable the blinking of the cursor.
pub fn enable(&mut self, cx: &mut Context<Self>) {
if self.enabled {
return;
}
self.enabled = true;
// Set cursors as invisible and start blinking: this causes cursors
// to be visible during the next render.
self.visible = false;
self.blink_cursors(self.blink_epoch, cx);
}
/// Disable the blinking of the cursor.
pub fn disable(&mut self, _cx: &mut Context<Self>) {
self.visible = false;
self.enabled = false;
}
pub fn visible(&self) -> bool {
self.visible
}

View File

@@ -1886,16 +1886,26 @@ impl Editor {
let selections = SelectionsCollection::new();
let focus_handle = cx.focus_handle();
let blink_manager = cx.new(|cx| {
let mut blink_manager = BlinkManager::new(
CURSOR_BLINK_INTERVAL,
|cx| EditorSettings::get_global(cx).cursor_blink,
cx,
);
if is_minimap {
blink_manager.disable(cx);
BlinkManager::new(
CURSOR_BLINK_INTERVAL,
focus_handle.clone(),
|_cx| false,
window,
cx,
)
} else {
BlinkManager::new(
CURSOR_BLINK_INTERVAL,
focus_handle.clone(),
|cx| EditorSettings::get_global(cx).cursor_blink,
window,
cx,
)
}
blink_manager
});
let soft_wrap_mode_override =
@@ -2095,7 +2105,6 @@ impl Editor {
let inlay_hint_settings =
inlay_hint_settings(selections.newest_anchor().head(), &buffer_snapshot, cx);
let focus_handle = cx.focus_handle();
if !is_minimap {
cx.on_focus(&focus_handle, window, Self::handle_focus)
.detach();
@@ -2292,15 +2301,7 @@ impl Editor {
cx.observe_global_in::<SettingsStore>(window, Self::settings_changed),
observe_buffer_font_size_adjustment(cx, |_, cx| cx.notify()),
cx.observe_window_activation(window, |editor, window, cx| {
let active = window.is_window_active();
editor.blink_manager.update(cx, |blink_manager, cx| {
if active {
blink_manager.enable(cx);
} else {
blink_manager.disable(cx);
}
});
if active {
if window.is_window_active() {
editor.show_mouse_cursor(cx);
}
}),
@@ -3321,7 +3322,9 @@ impl Editor {
}
}
self.blink_manager.update(cx, BlinkManager::pause_blinking);
self.blink_manager.update(cx, |blink_manager, cx| {
blink_manager.pause_blinking(window, cx)
});
cx.emit(EditorEvent::SelectionsChanged { local });
let selections = &self.selections.disjoint_anchors_arc();
@@ -22161,7 +22164,6 @@ impl Editor {
blame.update(cx, GitBlame::focus)
}
self.blink_manager.update(cx, BlinkManager::enable);
self.show_cursor_names(window, cx);
self.buffer.update(cx, |buffer, cx| {
buffer.finalize_last_transaction(cx);
@@ -22209,7 +22211,6 @@ impl Editor {
}
pub fn handle_blur(&mut self, window: &mut Window, cx: &mut Context<Self>) {
self.blink_manager.update(cx, BlinkManager::disable);
self.buffer
.update(cx, |buffer, cx| buffer.remove_active_selections(cx));

View File

@@ -340,11 +340,11 @@ impl LspLogView {
* Configuration: {CONFIGURATION}",
NAME = info.status.name,
ID = info.id,
BINARY = info.status.binary.as_ref().map_or_else(
|| "Unknown".to_string(),
|binary| serde_json::to_string_pretty(binary)
.unwrap_or_else(|e| format!("Failed to serialize binary info: {e:#}"))
),
BINARY = info
.status
.binary
.as_ref()
.map_or_else(|| "Unknown".to_string(), |binary| format!("{:#?}", binary)),
WORKSPACE_FOLDERS = info
.status
.workspace_folders

View File

@@ -234,17 +234,25 @@ impl TerminalView {
let scroll_handle = TerminalScrollHandle::new(terminal.read(cx));
let blink_manager = cx.new(|cx| {
BlinkManager::new(
CURSOR_BLINK_INTERVAL,
|cx| {
!matches!(
TerminalSettings::get_global(cx).blinking,
TerminalBlink::Off
)
},
cx,
)
let blink_manager = cx.new({
let weak_this = cx.weak_entity();
let focus_handle = focus_handle.clone();
move |cx| {
BlinkManager::new(
CURSOR_BLINK_INTERVAL,
focus_handle,
move |cx| match TerminalSettings::get_global(cx).blinking {
TerminalBlink::Off => false,
TerminalBlink::On => true,
TerminalBlink::TerminalControlled => weak_this
.read_with(cx, |this, _cx| this.blinking_terminal_enabled)
.unwrap_or(false),
},
window,
cx,
)
}
});
let _subscriptions = vec![
@@ -434,11 +442,6 @@ impl TerminalView {
let breadcrumb_visibility_changed = self.show_breadcrumbs != settings.toolbar.breadcrumbs;
self.show_breadcrumbs = settings.toolbar.breadcrumbs;
let should_blink = match settings.blinking {
TerminalBlink::Off => false,
TerminalBlink::On => true,
TerminalBlink::TerminalControlled => self.blinking_terminal_enabled,
};
let new_cursor_shape = settings.cursor_shape;
let old_cursor_shape = self.cursor_shape;
if old_cursor_shape != new_cursor_shape {
@@ -448,15 +451,6 @@ impl TerminalView {
});
}
self.blink_manager.update(
cx,
if should_blink {
BlinkManager::enable
} else {
BlinkManager::disable
},
);
if breadcrumb_visibility_changed {
cx.emit(ItemEvent::UpdateBreadcrumbs);
}
@@ -649,14 +643,17 @@ impl TerminalView {
// When focused, check blinking settings and blink manager state
match TerminalSettings::get_global(cx).blinking {
TerminalBlink::Off => true,
TerminalBlink::On | TerminalBlink::TerminalControlled => {
self.blink_manager.read(cx).visible()
TerminalBlink::TerminalControlled => {
!self.blinking_terminal_enabled || self.blink_manager.read(cx).visible()
}
TerminalBlink::On => self.blink_manager.read(cx).visible(),
}
}
pub fn pause_cursor_blinking(&mut self, _window: &mut Window, cx: &mut Context<Self>) {
self.blink_manager.update(cx, BlinkManager::pause_blinking);
pub fn pause_cursor_blinking(&mut self, window: &mut Window, cx: &mut Context<Self>) {
self.blink_manager.update(cx, |blink_manager, cx| {
blink_manager.pause_blinking(window, cx)
});
}
pub fn terminal(&self) -> &Entity<Terminal> {
@@ -872,21 +869,9 @@ fn subscribe_for_terminal_events(
Event::BlinkChanged(blinking) => {
terminal_view.blinking_terminal_enabled = *blinking;
// If in terminal-controlled mode and focused, update blink manager
if matches!(
TerminalSettings::get_global(cx).blinking,
TerminalBlink::TerminalControlled
) && terminal_view.focus_handle.is_focused(window)
{
terminal_view.blink_manager.update(cx, |manager, cx| {
if *blinking {
manager.enable(cx);
} else {
manager.disable(cx);
}
});
}
terminal_view
.blink_manager
.update(cx, |this, cx| this.refresh(window, cx));
}
Event::TitleChanged => {
@@ -1012,22 +997,11 @@ impl TerminalView {
terminal.focus_in();
});
let should_blink = match TerminalSettings::get_global(cx).blinking {
TerminalBlink::Off => false,
TerminalBlink::On => true,
TerminalBlink::TerminalControlled => self.blinking_terminal_enabled,
};
if should_blink {
self.blink_manager.update(cx, BlinkManager::enable);
}
window.invalidate_character_coordinates();
cx.notify();
}
fn focus_out(&mut self, _window: &mut Window, cx: &mut Context<Self>) {
self.blink_manager.update(cx, BlinkManager::disable);
self.terminal.update(cx, |terminal, _| {
terminal.focus_out();
terminal.set_cursor_shape(CursorShape::Hollow);

View File

@@ -77,7 +77,7 @@ impl EditPredictionProvider for ZetaEditPredictionProvider {
) -> bool {
let zeta = self.zeta.read(cx);
if zeta.edit_prediction_model == ZetaEditPredictionModel::Sweep {
zeta.sweep_api_token.is_some()
zeta.sweep_ai.api_token.is_some()
} else {
true
}

View File

@@ -1,10 +1,269 @@
use std::fmt;
use std::{path::Path, sync::Arc};
use anyhow::{Context as _, Result};
use cloud_llm_client::predict_edits_v3::Event;
use futures::AsyncReadExt as _;
use gpui::{
App, AppContext as _, Entity, Task,
http_client::{self, AsyncBody, Method},
};
use language::{Buffer, BufferSnapshot, Point, ToOffset as _, ToPoint as _};
use lsp::DiagnosticSeverity;
use project::{Project, ProjectPath};
use serde::{Deserialize, Serialize};
use std::{
collections::VecDeque,
fmt::{self, Write as _},
ops::Range,
path::Path,
sync::Arc,
time::Instant,
};
use util::ResultExt as _;
use crate::{EditPrediction, EditPredictionId, EditPredictionInputs};
const SWEEP_API_URL: &str = "https://autocomplete.sweep.dev/backend/next_edit_autocomplete";
pub struct SweepAi {
pub api_token: Option<String>,
pub debug_info: Arc<str>,
}
impl SweepAi {
pub fn new(cx: &App) -> Self {
SweepAi {
api_token: std::env::var("SWEEP_AI_TOKEN")
.context("No SWEEP_AI_TOKEN environment variable set")
.log_err(),
debug_info: debug_info(cx),
}
}
pub fn request_prediction_with_sweep(
&self,
project: &Entity<Project>,
active_buffer: &Entity<Buffer>,
snapshot: BufferSnapshot,
position: language::Anchor,
events: Vec<Arc<Event>>,
recent_paths: &VecDeque<ProjectPath>,
diagnostic_search_range: Range<Point>,
cx: &mut App,
) -> Task<Result<Option<EditPrediction>>> {
let debug_info = self.debug_info.clone();
let Some(api_token) = self.api_token.clone() else {
return Task::ready(Ok(None));
};
let full_path: Arc<Path> = snapshot
.file()
.map(|file| file.full_path(cx))
.unwrap_or_else(|| "untitled".into())
.into();
let project_file = project::File::from_dyn(snapshot.file());
let repo_name = project_file
.map(|file| file.worktree.read(cx).root_name_str())
.unwrap_or("untitled")
.into();
let offset = position.to_offset(&snapshot);
let recent_buffers = recent_paths.iter().cloned();
let http_client = cx.http_client();
let recent_buffer_snapshots = recent_buffers
.filter_map(|project_path| {
let buffer = project.read(cx).get_open_buffer(&project_path, cx)?;
if active_buffer == &buffer {
None
} else {
Some(buffer.read(cx).snapshot())
}
})
.take(3)
.collect::<Vec<_>>();
let cursor_point = position.to_point(&snapshot);
let buffer_snapshotted_at = Instant::now();
let result = cx.background_spawn(async move {
let text = snapshot.text();
let mut recent_changes = String::new();
for event in &events {
write_event(event.as_ref(), &mut recent_changes).unwrap();
}
let mut file_chunks = recent_buffer_snapshots
.into_iter()
.map(|snapshot| {
let end_point = Point::new(30, 0).min(snapshot.max_point());
FileChunk {
content: snapshot.text_for_range(Point::zero()..end_point).collect(),
file_path: snapshot
.file()
.map(|f| f.path().as_unix_str())
.unwrap_or("untitled")
.to_string(),
start_line: 0,
end_line: end_point.row as usize,
timestamp: snapshot.file().and_then(|file| {
Some(
file.disk_state()
.mtime()?
.to_seconds_and_nanos_for_persistence()?
.0,
)
}),
}
})
.collect::<Vec<_>>();
let diagnostic_entries = snapshot.diagnostics_in_range(diagnostic_search_range, false);
let mut diagnostic_content = String::new();
let mut diagnostic_count = 0;
for entry in diagnostic_entries {
let start_point: Point = entry.range.start;
let severity = match entry.diagnostic.severity {
DiagnosticSeverity::ERROR => "error",
DiagnosticSeverity::WARNING => "warning",
DiagnosticSeverity::INFORMATION => "info",
DiagnosticSeverity::HINT => "hint",
_ => continue,
};
diagnostic_count += 1;
writeln!(
&mut diagnostic_content,
"{} at line {}: {}",
severity,
start_point.row + 1,
entry.diagnostic.message
)?;
}
if !diagnostic_content.is_empty() {
file_chunks.push(FileChunk {
file_path: format!("Diagnostics for {}", full_path.display()),
start_line: 0,
end_line: diagnostic_count,
content: diagnostic_content,
timestamp: None,
});
}
let request_body = AutocompleteRequest {
debug_info,
repo_name,
file_path: full_path.clone(),
file_contents: text.clone(),
original_file_contents: text,
cursor_position: offset,
recent_changes: recent_changes.clone(),
changes_above_cursor: true,
multiple_suggestions: false,
branch: None,
file_chunks,
retrieval_chunks: vec![],
recent_user_actions: vec![],
// TODO
privacy_mode_enabled: false,
};
let mut buf: Vec<u8> = Vec::new();
let writer = brotli::CompressorWriter::new(&mut buf, 4096, 11, 22);
serde_json::to_writer(writer, &request_body)?;
let body: AsyncBody = buf.into();
let inputs = EditPredictionInputs {
events,
included_files: vec![cloud_llm_client::predict_edits_v3::IncludedFile {
path: full_path.clone(),
max_row: cloud_llm_client::predict_edits_v3::Line(snapshot.max_point().row),
excerpts: vec![cloud_llm_client::predict_edits_v3::Excerpt {
start_line: cloud_llm_client::predict_edits_v3::Line(0),
text: request_body.file_contents.into(),
}],
}],
cursor_point: cloud_llm_client::predict_edits_v3::Point {
column: cursor_point.column,
line: cloud_llm_client::predict_edits_v3::Line(cursor_point.row),
},
cursor_path: full_path.clone(),
};
let request = http_client::Request::builder()
.uri(SWEEP_API_URL)
.header("Content-Type", "application/json")
.header("Authorization", format!("Bearer {}", api_token))
.header("Connection", "keep-alive")
.header("Content-Encoding", "br")
.method(Method::POST)
.body(body)?;
let mut response = http_client.send(request).await?;
let mut body: Vec<u8> = Vec::new();
response.body_mut().read_to_end(&mut body).await?;
let response_received_at = Instant::now();
if !response.status().is_success() {
anyhow::bail!(
"Request failed with status: {:?}\nBody: {}",
response.status(),
String::from_utf8_lossy(&body),
);
};
let response: AutocompleteResponse = serde_json::from_slice(&body)?;
let old_text = snapshot
.text_for_range(response.start_index..response.end_index)
.collect::<String>();
let edits = language::text_diff(&old_text, &response.completion)
.into_iter()
.map(|(range, text)| {
(
snapshot.anchor_after(response.start_index + range.start)
..snapshot.anchor_before(response.start_index + range.end),
text,
)
})
.collect::<Vec<_>>();
anyhow::Ok((
response.autocomplete_id,
edits,
snapshot,
response_received_at,
inputs,
))
});
let buffer = active_buffer.clone();
cx.spawn(async move |cx| {
let (id, edits, old_snapshot, response_received_at, inputs) = result.await?;
anyhow::Ok(
EditPrediction::new(
EditPredictionId(id.into()),
&buffer,
&old_snapshot,
edits.into(),
buffer_snapshotted_at,
response_received_at,
inputs,
cx,
)
.await,
)
})
}
}
#[derive(Debug, Clone, Serialize)]
pub struct AutocompleteRequest {
struct AutocompleteRequest {
pub debug_info: Arc<str>,
pub repo_name: String,
pub branch: Option<String>,
@@ -22,7 +281,7 @@ pub struct AutocompleteRequest {
}
#[derive(Debug, Clone, Serialize)]
pub struct FileChunk {
struct FileChunk {
pub file_path: String,
pub start_line: usize,
pub end_line: usize,
@@ -31,7 +290,7 @@ pub struct FileChunk {
}
#[derive(Debug, Clone, Serialize)]
pub struct RetrievalChunk {
struct RetrievalChunk {
pub file_path: String,
pub start_line: usize,
pub end_line: usize,
@@ -40,7 +299,7 @@ pub struct RetrievalChunk {
}
#[derive(Debug, Clone, Serialize)]
pub struct UserAction {
struct UserAction {
pub action_type: ActionType,
pub line_number: usize,
pub offset: usize,
@@ -51,7 +310,7 @@ pub struct UserAction {
#[allow(dead_code)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize)]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
pub enum ActionType {
enum ActionType {
CursorMovement,
InsertChar,
DeleteChar,
@@ -60,7 +319,7 @@ pub enum ActionType {
}
#[derive(Debug, Clone, Deserialize)]
pub struct AutocompleteResponse {
struct AutocompleteResponse {
pub autocomplete_id: String,
pub start_index: usize,
pub end_index: usize,
@@ -80,7 +339,7 @@ pub struct AutocompleteResponse {
#[allow(dead_code)]
#[derive(Debug, Clone, Deserialize)]
pub struct AdditionalCompletion {
struct AdditionalCompletion {
pub start_index: usize,
pub end_index: usize,
pub completion: String,
@@ -90,7 +349,7 @@ pub struct AdditionalCompletion {
pub finish_reason: Option<String>,
}
pub(crate) fn write_event(
fn write_event(
event: &cloud_llm_client::predict_edits_v3::Event,
f: &mut impl fmt::Write,
) -> fmt::Result {
@@ -115,7 +374,7 @@ pub(crate) fn write_event(
}
}
pub(crate) fn debug_info(cx: &gpui::App) -> Arc<str> {
fn debug_info(cx: &gpui::App) -> Arc<str> {
format!(
"Zed v{version} ({sha}) - OS: {os} - Zed v{version}",
version = release_channel::AppVersion::global(cx),

View File

@@ -30,7 +30,6 @@ use language::{
};
use language::{BufferSnapshot, OffsetRangeExt};
use language_model::{LlmApiToken, RefreshLlmTokenListener};
use lsp::DiagnosticSeverity;
use open_ai::FunctionDefinition;
use project::{DisableAiSettings, Project, ProjectPath, WorktreeId};
use release_channel::AppVersion;
@@ -42,7 +41,6 @@ use std::collections::{VecDeque, hash_map};
use telemetry_events::EditPredictionRating;
use workspace::Workspace;
use std::fmt::Write as _;
use std::ops::Range;
use std::path::Path;
use std::rc::Rc;
@@ -80,6 +78,7 @@ use crate::rate_prediction_modal::{
NextEdit, PreviousEdit, RatePredictionsModal, ThumbsDownActivePrediction,
ThumbsUpActivePrediction,
};
use crate::sweep_ai::SweepAi;
use crate::zeta1::request_prediction_with_zeta1;
pub use provider::ZetaEditPredictionProvider;
@@ -171,7 +170,7 @@ impl FeatureFlag for Zeta2FeatureFlag {
const NAME: &'static str = "zeta2";
fn enabled_for_staff() -> bool {
false
true
}
}
@@ -192,8 +191,7 @@ pub struct Zeta {
#[cfg(feature = "eval-support")]
eval_cache: Option<Arc<dyn EvalCache>>,
edit_prediction_model: ZetaEditPredictionModel,
sweep_api_token: Option<String>,
sweep_ai_debug_info: Arc<str>,
sweep_ai: SweepAi,
data_collection_choice: DataCollectionChoice,
rejected_predictions: Vec<EditPredictionRejection>,
reject_predictions_tx: mpsc::UnboundedSender<()>,
@@ -202,7 +200,7 @@ pub struct Zeta {
rated_predictions: HashSet<EditPredictionId>,
}
#[derive(Default, PartialEq, Eq)]
#[derive(Copy, Clone, Default, PartialEq, Eq)]
pub enum ZetaEditPredictionModel {
#[default]
Zeta1,
@@ -499,11 +497,8 @@ impl Zeta {
#[cfg(feature = "eval-support")]
eval_cache: None,
edit_prediction_model: ZetaEditPredictionModel::Zeta2,
sweep_api_token: std::env::var("SWEEP_AI_TOKEN")
.context("No SWEEP_AI_TOKEN environment variable set")
.log_err(),
sweep_ai: SweepAi::new(cx),
data_collection_choice,
sweep_ai_debug_info: sweep_ai::debug_info(cx),
rejected_predictions: Vec::new(),
reject_predictions_debounce_task: None,
reject_predictions_tx: reject_tx,
@@ -517,7 +512,7 @@ impl Zeta {
}
pub fn has_sweep_api_token(&self) -> bool {
self.sweep_api_token.is_some()
self.sweep_ai.api_token.is_some()
}
#[cfg(feature = "eval-support")]
@@ -643,7 +638,9 @@ impl Zeta {
}
}
project::Event::DiagnosticsUpdated { .. } => {
self.refresh_prediction_from_diagnostics(project, cx);
if cx.has_flag::<Zeta2FeatureFlag>() {
self.refresh_prediction_from_diagnostics(project, cx);
}
}
_ => (),
}
@@ -1126,7 +1123,6 @@ impl Zeta {
zeta_project.next_pending_prediction_id += 1;
let last_request = zeta_project.last_prediction_refresh;
// TODO report cancelled requests like in zeta1
let task = cx.spawn(async move |this, cx| {
if let Some((last_entity, last_timestamp)) = last_request
&& throttle_entity == last_entity
@@ -1136,6 +1132,12 @@ impl Zeta {
cx.background_executor().timer(timeout).await;
}
this.update(cx, |this, cx| {
this.get_or_init_zeta_project(&project, cx)
.last_prediction_refresh = Some((throttle_entity, Instant::now()));
})
.ok();
let edit_prediction_id = do_refresh(this.clone(), cx).await.log_err().flatten();
// When a prediction completes, remove it from the pending list, and cancel
@@ -1183,249 +1185,77 @@ impl Zeta {
position: language::Anchor,
cx: &mut Context<Self>,
) -> Task<Result<Option<EditPrediction>>> {
match self.edit_prediction_model {
ZetaEditPredictionModel::Zeta1 => {
request_prediction_with_zeta1(self, project, active_buffer, position, cx)
}
ZetaEditPredictionModel::Zeta2 => {
self.request_prediction_with_zeta2(project, active_buffer, position, cx)
}
ZetaEditPredictionModel::Sweep => {
self.request_prediction_with_sweep(project, active_buffer, position, true, cx)
}
}
self.request_prediction_internal(
project.clone(),
active_buffer.clone(),
position,
cx.has_flag::<Zeta2FeatureFlag>(),
cx,
)
}
fn request_prediction_with_sweep(
fn request_prediction_internal(
&mut self,
project: &Entity<Project>,
active_buffer: &Entity<Buffer>,
project: Entity<Project>,
active_buffer: Entity<Buffer>,
position: language::Anchor,
allow_jump: bool,
cx: &mut Context<Self>,
) -> Task<Result<Option<EditPrediction>>> {
let snapshot = active_buffer.read(cx).snapshot();
let debug_info = self.sweep_ai_debug_info.clone();
let Some(api_token) = self.sweep_api_token.clone() else {
return Task::ready(Ok(None));
};
let full_path: Arc<Path> = snapshot
.file()
.map(|file| file.full_path(cx))
.unwrap_or_else(|| "untitled".into())
.into();
let project_file = project::File::from_dyn(snapshot.file());
let repo_name = project_file
.map(|file| file.worktree.read(cx).root_name_str())
.unwrap_or("untitled")
.into();
let offset = position.to_offset(&snapshot);
let project_state = self.get_or_init_zeta_project(project, cx);
let events = project_state.events(cx);
let has_events = !events.is_empty();
let recent_buffers = project_state.recent_paths.iter().cloned();
let http_client = cx.http_client();
let recent_buffer_snapshots = recent_buffers
.filter_map(|project_path| {
let buffer = project.read(cx).get_open_buffer(&project_path, cx)?;
if active_buffer == &buffer {
None
} else {
Some(buffer.read(cx).snapshot())
}
})
.take(3)
.collect::<Vec<_>>();
const DIAGNOSTIC_LINES_RANGE: u32 = 20;
self.get_or_init_zeta_project(&project, cx);
let zeta_project = self.projects.get(&project.entity_id()).unwrap();
let events = zeta_project.events(cx);
let has_events = !events.is_empty();
let snapshot = active_buffer.read(cx).snapshot();
let cursor_point = position.to_point(&snapshot);
let diagnostic_search_start = cursor_point.row.saturating_sub(DIAGNOSTIC_LINES_RANGE);
let diagnostic_search_end = cursor_point.row + DIAGNOSTIC_LINES_RANGE;
let diagnostic_search_range =
Point::new(diagnostic_search_start, 0)..Point::new(diagnostic_search_end, 0);
let buffer_snapshotted_at = Instant::now();
let result = cx.background_spawn({
let snapshot = snapshot.clone();
let diagnostic_search_range = diagnostic_search_range.clone();
async move {
let text = snapshot.text();
let mut recent_changes = String::new();
for event in &events {
sweep_ai::write_event(event.as_ref(), &mut recent_changes).unwrap();
}
let mut file_chunks = recent_buffer_snapshots
.into_iter()
.map(|snapshot| {
let end_point = Point::new(30, 0).min(snapshot.max_point());
sweep_ai::FileChunk {
content: snapshot.text_for_range(Point::zero()..end_point).collect(),
file_path: snapshot
.file()
.map(|f| f.path().as_unix_str())
.unwrap_or("untitled")
.to_string(),
start_line: 0,
end_line: end_point.row as usize,
timestamp: snapshot.file().and_then(|file| {
Some(
file.disk_state()
.mtime()?
.to_seconds_and_nanos_for_persistence()?
.0,
)
}),
}
})
.collect::<Vec<_>>();
let diagnostic_entries =
snapshot.diagnostics_in_range(diagnostic_search_range, false);
let mut diagnostic_content = String::new();
let mut diagnostic_count = 0;
for entry in diagnostic_entries {
let start_point: Point = entry.range.start;
let severity = match entry.diagnostic.severity {
DiagnosticSeverity::ERROR => "error",
DiagnosticSeverity::WARNING => "warning",
DiagnosticSeverity::INFORMATION => "info",
DiagnosticSeverity::HINT => "hint",
_ => continue,
};
diagnostic_count += 1;
writeln!(
&mut diagnostic_content,
"{} at line {}: {}",
severity,
start_point.row + 1,
entry.diagnostic.message
)?;
}
if !diagnostic_content.is_empty() {
file_chunks.push(sweep_ai::FileChunk {
file_path: format!("Diagnostics for {}", full_path.display()),
start_line: 0,
end_line: diagnostic_count,
content: diagnostic_content,
timestamp: None,
});
}
let request_body = sweep_ai::AutocompleteRequest {
debug_info,
repo_name,
file_path: full_path.clone(),
file_contents: text.clone(),
original_file_contents: text,
cursor_position: offset,
recent_changes: recent_changes.clone(),
changes_above_cursor: true,
multiple_suggestions: false,
branch: None,
file_chunks,
retrieval_chunks: vec![],
recent_user_actions: vec![],
// TODO
privacy_mode_enabled: false,
};
let mut buf: Vec<u8> = Vec::new();
let writer = brotli::CompressorWriter::new(&mut buf, 4096, 11, 22);
serde_json::to_writer(writer, &request_body)?;
let body: AsyncBody = buf.into();
let inputs = EditPredictionInputs {
events,
included_files: vec![cloud_llm_client::predict_edits_v3::IncludedFile {
path: full_path.clone(),
max_row: cloud_llm_client::predict_edits_v3::Line(snapshot.max_point().row),
excerpts: vec![cloud_llm_client::predict_edits_v3::Excerpt {
start_line: cloud_llm_client::predict_edits_v3::Line(0),
text: request_body.file_contents.into(),
}],
}],
cursor_point: cloud_llm_client::predict_edits_v3::Point {
column: cursor_point.column,
line: cloud_llm_client::predict_edits_v3::Line(cursor_point.row),
},
cursor_path: full_path.clone(),
};
const SWEEP_API_URL: &str =
"https://autocomplete.sweep.dev/backend/next_edit_autocomplete";
let request = http_client::Request::builder()
.uri(SWEEP_API_URL)
.header("Content-Type", "application/json")
.header("Authorization", format!("Bearer {}", api_token))
.header("Connection", "keep-alive")
.header("Content-Encoding", "br")
.method(Method::POST)
.body(body)?;
let mut response = http_client.send(request).await?;
let mut body: Vec<u8> = Vec::new();
response.body_mut().read_to_end(&mut body).await?;
let response_received_at = Instant::now();
if !response.status().is_success() {
anyhow::bail!(
"Request failed with status: {:?}\nBody: {}",
response.status(),
String::from_utf8_lossy(&body),
);
};
let response: sweep_ai::AutocompleteResponse = serde_json::from_slice(&body)?;
let old_text = snapshot
.text_for_range(response.start_index..response.end_index)
.collect::<String>();
let edits = language::text_diff(&old_text, &response.completion)
.into_iter()
.map(|(range, text)| {
(
snapshot.anchor_after(response.start_index + range.start)
..snapshot.anchor_before(response.start_index + range.end),
text,
)
})
.collect::<Vec<_>>();
anyhow::Ok((
response.autocomplete_id,
edits,
snapshot,
response_received_at,
inputs,
))
}
});
let buffer = active_buffer.clone();
let project = project.clone();
let active_buffer = active_buffer.clone();
let task = match self.edit_prediction_model {
ZetaEditPredictionModel::Zeta1 => request_prediction_with_zeta1(
self,
&project,
&active_buffer,
snapshot.clone(),
position,
events,
cx,
),
ZetaEditPredictionModel::Zeta2 => self.request_prediction_with_zeta2(
&project,
&active_buffer,
snapshot.clone(),
position,
events,
cx,
),
ZetaEditPredictionModel::Sweep => self.sweep_ai.request_prediction_with_sweep(
&project,
&active_buffer,
snapshot.clone(),
position,
events,
&zeta_project.recent_paths,
diagnostic_search_range.clone(),
cx,
),
};
cx.spawn(async move |this, cx| {
let (id, edits, old_snapshot, response_received_at, inputs) = result.await?;
let prediction = task
.await?
.filter(|prediction| !prediction.edits.is_empty());
if edits.is_empty() {
if prediction.is_none() && allow_jump {
let cursor_point = position.to_point(&snapshot);
if has_events
&& allow_jump
&& let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
active_buffer,
active_buffer.clone(),
&snapshot,
diagnostic_search_range,
cursor_point,
@@ -1436,9 +1266,9 @@ impl Zeta {
{
return this
.update(cx, |this, cx| {
this.request_prediction_with_sweep(
&project,
&jump_buffer,
this.request_prediction_internal(
project,
jump_buffer,
jump_position,
false,
cx,
@@ -1450,19 +1280,7 @@ impl Zeta {
return anyhow::Ok(None);
}
anyhow::Ok(
EditPrediction::new(
EditPredictionId(id.into()),
&buffer,
&old_snapshot,
edits.into(),
buffer_snapshotted_at,
response_received_at,
inputs,
cx,
)
.await,
)
Ok(prediction)
})
}
@@ -1549,7 +1367,9 @@ impl Zeta {
&mut self,
project: &Entity<Project>,
active_buffer: &Entity<Buffer>,
active_snapshot: BufferSnapshot,
position: language::Anchor,
events: Vec<Arc<Event>>,
cx: &mut Context<Self>,
) -> Task<Result<Option<EditPrediction>>> {
let project_state = self.projects.get(&project.entity_id());
@@ -1561,7 +1381,6 @@ impl Zeta {
.map(|syntax_index| syntax_index.read_with(cx, |index, _cx| index.state().clone()))
});
let options = self.options.clone();
let active_snapshot = active_buffer.read(cx).snapshot();
let buffer_snapshotted_at = Instant::now();
let Some(excerpt_path) = active_snapshot
.file()
@@ -1579,10 +1398,6 @@ impl Zeta {
.collect::<Vec<_>>();
let debug_tx = self.debug_tx.clone();
let events = project_state
.map(|state| state.events(cx))
.unwrap_or_default();
let diagnostics = active_snapshot.diagnostic_sets().clone();
let file = active_buffer.read(cx).file();

View File

@@ -32,19 +32,17 @@ pub(crate) fn request_prediction_with_zeta1(
zeta: &mut Zeta,
project: &Entity<Project>,
buffer: &Entity<Buffer>,
snapshot: BufferSnapshot,
position: language::Anchor,
events: Vec<Arc<Event>>,
cx: &mut Context<Zeta>,
) -> Task<Result<Option<EditPrediction>>> {
let buffer = buffer.clone();
let buffer_snapshotted_at = Instant::now();
let snapshot = buffer.read(cx).snapshot();
let client = zeta.client.clone();
let llm_token = zeta.llm_token.clone();
let app_version = AppVersion::global(cx);
let zeta_project = zeta.get_or_init_zeta_project(project, cx);
let events = Arc::new(zeta_project.events(cx));
let (git_info, can_collect_file) = if let Some(file) = snapshot.file() {
let can_collect_file = zeta.can_collect_file(project, file, cx);
let git_info = if can_collect_file {

View File

@@ -42,43 +42,48 @@ actions!(
pub fn init(cx: &mut App) {
cx.observe_new(move |workspace: &mut Workspace, _, _cx| {
workspace.register_action(move |workspace, _: &OpenZeta2Inspector, window, cx| {
let project = workspace.project();
workspace.split_item(
SplitDirection::Right,
Box::new(cx.new(|cx| {
Zeta2Inspector::new(
&project,
workspace.client(),
workspace.user_store(),
window,
cx,
)
})),
window,
cx,
);
});
})
.detach();
cx.observe_new(move |workspace: &mut Workspace, _, _cx| {
workspace.register_action(move |workspace, _: &OpenZeta2ContextView, window, cx| {
let project = workspace.project();
workspace.split_item(
SplitDirection::Right,
Box::new(cx.new(|cx| {
Zeta2ContextView::new(
project.clone(),
workspace.client(),
workspace.user_store(),
window,
cx,
)
})),
window,
cx,
);
workspace.register_action_renderer(|div, _, _, cx| {
let has_flag = cx.has_flag::<Zeta2FeatureFlag>();
div.when(has_flag, |div| {
div.on_action(
cx.listener(move |workspace, _: &OpenZeta2Inspector, window, cx| {
let project = workspace.project();
workspace.split_item(
SplitDirection::Right,
Box::new(cx.new(|cx| {
Zeta2Inspector::new(
&project,
workspace.client(),
workspace.user_store(),
window,
cx,
)
})),
window,
cx,
)
}),
)
.on_action(cx.listener(
move |workspace, _: &OpenZeta2ContextView, window, cx| {
let project = workspace.project();
workspace.split_item(
SplitDirection::Right,
Box::new(cx.new(|cx| {
Zeta2ContextView::new(
project.clone(),
workspace.client(),
workspace.user_store(),
window,
cx,
)
})),
window,
cx,
);
},
))
})
});
})
.detach();

View File

@@ -1,5 +1,5 @@
use std::{
collections::{BTreeSet, HashMap},
collections::HashMap,
io::{IsTerminal, Write},
sync::Arc,
};
@@ -125,21 +125,10 @@ fn write_aggregated_scores(
.peekable();
let has_edit_predictions = edit_predictions.peek().is_some();
let aggregated_result = EvaluationResult {
context: Scores::aggregate(successful.iter().map(|r| &r.context)),
edit_prediction: has_edit_predictions.then(|| Scores::aggregate(edit_predictions)),
prompt_len: successful.iter().map(|r| r.prompt_len).sum::<usize>() / successful.len(),
generated_len: successful.iter().map(|r| r.generated_len).sum::<usize>()
/ successful.len(),
context_lines_found_in_context: successful
.iter()
.map(|r| r.context_lines_found_in_context)
.sum::<usize>()
/ successful.len(),
context_lines_in_expected_patch: successful
.iter()
.map(|r| r.context_lines_in_expected_patch)
.sum::<usize>()
/ successful.len(),
};
writeln!(w, "\n{}", "-".repeat(80))?;
@@ -261,11 +250,8 @@ fn write_eval_result(
#[derive(Debug, Default)]
pub struct EvaluationResult {
pub edit_prediction: Option<Scores>,
pub context: Scores,
pub prompt_len: usize,
pub generated_len: usize,
pub context_lines_in_expected_patch: usize,
pub context_lines_found_in_context: usize,
}
#[derive(Default, Debug)]
@@ -363,14 +349,6 @@ impl std::fmt::Display for EvaluationResult {
impl EvaluationResult {
fn fmt_markdown(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(
f,
r#"
### Context Scores
{}
"#,
self.context.to_markdown(),
)?;
if let Some(prediction) = &self.edit_prediction {
write!(
f,
@@ -387,34 +365,18 @@ impl EvaluationResult {
writeln!(f, "### Scores\n")?;
writeln!(
f,
" Prompt Generated RetrievedContext PatchContext TP FP FN Precision Recall F1"
" Prompt Generated TP FP FN Precision Recall F1"
)?;
writeln!(
f,
"─────────────────────────────────────────────────────────────────────────────────────────────────────────────────────"
)?;
writeln!(
f,
"Context Retrieval {:<7} {:<9} {:<16} {:<16} {:<6} {:<6} {:<6} {:>10.2} {:>7.2} {:>7.2}",
"",
"",
"",
"",
self.context.true_positives,
self.context.false_positives,
self.context.false_negatives,
self.context.precision() * 100.0,
self.context.recall() * 100.0,
self.context.f1_score() * 100.0
"───────────────────────────────────────────────────────────────────────────────────────────────"
)?;
if let Some(edit_prediction) = &self.edit_prediction {
writeln!(
f,
"Edit Prediction {:<7} {:<9} {:<16} {:<16} {:<6} {:<6} {:<6} {:>10.2} {:>7.2} {:>7.2}",
"Edit Prediction {:<7} {:<9} {:<6} {:<6} {:<6} {:>9.2} {:>8.2} {:>7.2}",
self.prompt_len,
self.generated_len,
self.context_lines_found_in_context,
self.context_lines_in_expected_patch,
edit_prediction.true_positives,
edit_prediction.false_positives,
edit_prediction.false_negatives,
@@ -434,53 +396,6 @@ fn evaluate(example: &Example, preds: &PredictionDetails, predict: bool) -> Eval
..Default::default()
};
let actual_context_lines: HashSet<_> = preds
.excerpts
.iter()
.flat_map(|excerpt| {
excerpt
.text
.lines()
.map(|line| format!("{}: {line}", excerpt.path.display()))
})
.collect();
let mut false_positive_lines = actual_context_lines.clone();
for entry in &example.expected_context {
let mut best_alternative_score: Option<Scores> = None;
for alternative in &entry.alternatives {
let expected: HashSet<_> = alternative
.excerpts
.iter()
.flat_map(|excerpt| {
excerpt
.text
.lines()
.map(|line| format!("{}: {line}", excerpt.path.display()))
})
.collect();
let scores = Scores::new(&expected, &actual_context_lines);
false_positive_lines.retain(|line| !expected.contains(line));
if best_alternative_score
.as_ref()
.is_none_or(|best| scores.recall() > best.recall())
{
best_alternative_score = Some(scores);
}
}
let best_alternative = best_alternative_score.unwrap_or_default();
eval_result.context.false_negatives += best_alternative.false_negatives;
eval_result.context.true_positives += best_alternative.true_positives;
}
eval_result.context.false_positives = false_positive_lines.len();
if predict {
// todo: alternatives for patches
let expected_patch = example
@@ -493,25 +408,6 @@ fn evaluate(example: &Example, preds: &PredictionDetails, predict: bool) -> Eval
.filter(|line| matches!(line, DiffLine::Addition(_) | DiffLine::Deletion(_)))
.map(|line| line.to_string())
.collect();
let expected_context_lines = expected_patch
.iter()
.filter_map(|line| {
if let DiffLine::Context(str) = line {
Some(String::from(*str))
} else {
None
}
})
.collect::<BTreeSet<_>>();
let actual_context_lines = preds
.excerpts
.iter()
.flat_map(|excerpt| excerpt.text.lines().map(ToOwned::to_owned))
.collect::<BTreeSet<_>>();
let matched = expected_context_lines
.intersection(&actual_context_lines)
.count();
let actual_patch_lines = preds
.diff
@@ -522,8 +418,6 @@ fn evaluate(example: &Example, preds: &PredictionDetails, predict: bool) -> Eval
.collect();
eval_result.edit_prediction = Some(Scores::new(&expected_patch_lines, &actual_patch_lines));
eval_result.context_lines_in_expected_patch = expected_context_lines.len();
eval_result.context_lines_found_in_context = matched;
}
eval_result

View File

@@ -14,7 +14,6 @@ use anyhow::{Context as _, Result, anyhow};
use clap::ValueEnum;
use cloud_zeta2_prompt::CURSOR_MARKER;
use collections::HashMap;
use edit_prediction_context::Line;
use futures::{
AsyncWriteExt as _,
lock::{Mutex, OwnedMutexGuard},
@@ -53,7 +52,6 @@ pub struct Example {
pub cursor_position: String,
pub edit_history: String,
pub expected_patch: String,
pub expected_context: Vec<ExpectedContextEntry>,
}
pub type ActualExcerpt = Excerpt;
@@ -64,25 +62,6 @@ pub struct Excerpt {
pub text: String,
}
#[derive(Default, Clone, Debug, Serialize, Deserialize)]
pub struct ExpectedContextEntry {
pub heading: String,
pub alternatives: Vec<ExpectedExcerptSet>,
}
#[derive(Default, Clone, Debug, Serialize, Deserialize)]
pub struct ExpectedExcerptSet {
pub heading: String,
pub excerpts: Vec<ExpectedExcerpt>,
}
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ExpectedExcerpt {
pub path: PathBuf,
pub text: String,
pub required_lines: Vec<Line>,
}
#[derive(ValueEnum, Debug, Clone)]
pub enum ExampleFormat {
Json,
@@ -132,7 +111,6 @@ impl NamedExample {
cursor_position: String::new(),
edit_history: String::new(),
expected_patch: String::new(),
expected_context: Vec::new(),
},
};
@@ -197,30 +175,10 @@ impl NamedExample {
};
}
Event::End(TagEnd::Heading(HeadingLevel::H3)) => {
let heading = mem::take(&mut text);
match current_section {
Section::ExpectedExcerpts => {
named.example.expected_context.push(ExpectedContextEntry {
heading,
alternatives: Vec::new(),
});
}
_ => {}
}
mem::take(&mut text);
}
Event::End(TagEnd::Heading(HeadingLevel::H4)) => {
let heading = mem::take(&mut text);
match current_section {
Section::ExpectedExcerpts => {
let expected_context = &mut named.example.expected_context;
let last_entry = expected_context.last_mut().unwrap();
last_entry.alternatives.push(ExpectedExcerptSet {
heading,
excerpts: Vec::new(),
})
}
_ => {}
}
mem::take(&mut text);
}
Event::End(TagEnd::Heading(level)) => {
anyhow::bail!("Unexpected heading level: {level}");
@@ -253,41 +211,7 @@ impl NamedExample {
named.example.cursor_position = mem::take(&mut text);
}
Section::ExpectedExcerpts => {
let text = mem::take(&mut text);
for excerpt in text.split("\n\n") {
let (mut text, required_lines) = extract_required_lines(&excerpt);
if !text.ends_with('\n') {
text.push('\n');
}
if named.example.expected_context.is_empty() {
named.example.expected_context.push(Default::default());
}
let alternatives = &mut named
.example
.expected_context
.last_mut()
.unwrap()
.alternatives;
if alternatives.is_empty() {
alternatives.push(ExpectedExcerptSet {
heading: String::new(),
excerpts: vec![],
});
}
alternatives
.last_mut()
.unwrap()
.excerpts
.push(ExpectedExcerpt {
path: block_info.into(),
text,
required_lines,
});
}
mem::take(&mut text);
}
Section::ExpectedPatch => {
named.example.expected_patch = mem::take(&mut text);
@@ -561,47 +485,6 @@ impl NamedExample {
}
}
fn extract_required_lines(text: &str) -> (String, Vec<Line>) {
const MARKER: &str = "[ZETA]";
let mut new_text = String::new();
let mut required_lines = Vec::new();
let mut skipped_lines = 0_u32;
for (row, mut line) in text.split('\n').enumerate() {
if let Some(marker_column) = line.find(MARKER) {
let mut strip_column = marker_column;
while strip_column > 0 {
let prev_char = line[strip_column - 1..].chars().next().unwrap();
if prev_char.is_whitespace() || ['/', '#'].contains(&prev_char) {
strip_column -= 1;
} else {
break;
}
}
let metadata = &line[marker_column + MARKER.len()..];
if metadata.contains("required") {
required_lines.push(Line(row as u32 - skipped_lines));
}
if strip_column == 0 {
skipped_lines += 1;
continue;
}
line = &line[..strip_column];
}
new_text.push_str(line);
new_text.push('\n');
}
new_text.pop();
(new_text, required_lines)
}
async fn run_git(repo_path: &Path, args: &[&str]) -> Result<String> {
let output = smol::process::Command::new("git")
.current_dir(repo_path)
@@ -656,37 +539,6 @@ impl Display for NamedExample {
)?;
}
if !self.example.expected_context.is_empty() {
write!(f, "\n## {EXPECTED_CONTEXT_HEADING}\n\n")?;
for entry in &self.example.expected_context {
write!(f, "\n### {}\n\n", entry.heading)?;
let skip_h4 =
entry.alternatives.len() == 1 && entry.alternatives[0].heading.is_empty();
for excerpt_set in &entry.alternatives {
if !skip_h4 {
write!(f, "\n#### {}\n\n", excerpt_set.heading)?;
}
for excerpt in &excerpt_set.excerpts {
write!(
f,
"`````{}{}\n{}`````\n\n",
excerpt
.path
.extension()
.map(|ext| format!("{} ", ext.to_string_lossy()))
.unwrap_or_default(),
excerpt.path.display(),
excerpt.text
)?;
}
}
}
}
Ok(())
}
}
@@ -707,38 +559,3 @@ pub async fn lock_repo(path: impl AsRef<Path>) -> OwnedMutexGuard<()> {
.lock_owned()
.await
}
#[cfg(test)]
mod tests {
use super::*;
use indoc::indoc;
use pretty_assertions::assert_eq;
#[test]
fn test_extract_required_lines() {
let input = indoc! {"
zero
one // [ZETA] required
two
// [ZETA] something
three
four # [ZETA] required
five
"};
let expected_updated_input = indoc! {"
zero
one
two
three
four
five
"};
let expected_required_lines = vec![Line(1), Line(4)];
let (updated_input, required_lines) = extract_required_lines(input);
assert_eq!(updated_input, expected_updated_input);
assert_eq!(required_lines, expected_required_lines);
}
}

View File

@@ -128,8 +128,6 @@ pub struct PredictArguments {
#[derive(Clone, Debug, Args)]
pub struct PredictionOptions {
#[arg(long)]
use_expected_context: bool,
#[clap(flatten)]
zeta2: Zeta2Args,
#[clap(long)]

View File

@@ -1,4 +1,4 @@
use crate::example::{ActualExcerpt, ExpectedExcerpt, NamedExample};
use crate::example::{ActualExcerpt, NamedExample};
use crate::headless::ZetaCliAppState;
use crate::paths::{CACHE_DIR, LATEST_EXAMPLE_RUN_DIR, RUN_DIR, print_run_data_dir};
use crate::{
@@ -7,16 +7,13 @@ use crate::{
use ::serde::Serialize;
use anyhow::{Context, Result, anyhow};
use cloud_zeta2_prompt::{CURSOR_MARKER, write_codeblock};
use collections::HashMap;
use futures::StreamExt as _;
use gpui::{AppContext, AsyncApp, Entity};
use language::{Anchor, Buffer, Point};
use project::Project;
use project::buffer_store::BufferStoreEvent;
use serde::Deserialize;
use std::fs;
use std::io::{IsTerminal, Write};
use std::ops::Range;
use std::path::PathBuf;
use std::sync::Arc;
use std::sync::Mutex;
@@ -204,15 +201,12 @@ pub async fn perform_predict(
let mut result = result.lock().unwrap();
result.generated_len = response.chars().count();
if !options.use_expected_context {
result.planning_search_time = Some(
search_queries_generated_at.unwrap() - start_time.unwrap(),
);
result.running_search_time = Some(
search_queries_executed_at.unwrap()
- search_queries_generated_at.unwrap(),
);
}
result.planning_search_time =
Some(search_queries_generated_at.unwrap() - start_time.unwrap());
result.running_search_time = Some(
search_queries_executed_at.unwrap()
- search_queries_generated_at.unwrap(),
);
result.prediction_time = prediction_finished_at - prediction_started_at;
result.total_time = prediction_finished_at - start_time.unwrap();
@@ -224,37 +218,10 @@ pub async fn perform_predict(
}
});
if options.use_expected_context {
let context_excerpts_tasks = example
.example
.expected_context
.iter()
.flat_map(|section| {
section.alternatives[0].excerpts.iter().map(|excerpt| {
resolve_context_entry(project.clone(), excerpt.clone(), cx.clone())
})
})
.collect::<Vec<_>>();
let context_excerpts_vec =
futures::future::try_join_all(context_excerpts_tasks).await?;
let mut context_excerpts = HashMap::default();
for (buffer, mut excerpts) in context_excerpts_vec {
context_excerpts
.entry(buffer)
.or_insert(Vec::new())
.append(&mut excerpts);
}
zeta.update(cx, |zeta, _cx| {
zeta.set_context(project.clone(), context_excerpts)
})?;
} else {
zeta.update(cx, |zeta, cx| {
zeta.refresh_context(project.clone(), cursor_buffer.clone(), cursor_anchor, cx)
})?
.await?;
}
zeta.update(cx, |zeta, cx| {
zeta.refresh_context(project.clone(), cursor_buffer.clone(), cursor_anchor, cx)
})?
.await?;
}
let prediction = zeta
@@ -274,38 +241,6 @@ pub async fn perform_predict(
anyhow::Ok(result)
}
async fn resolve_context_entry(
project: Entity<Project>,
excerpt: ExpectedExcerpt,
mut cx: AsyncApp,
) -> Result<(Entity<Buffer>, Vec<Range<Anchor>>)> {
let buffer = project
.update(&mut cx, |project, cx| {
let project_path = project.find_project_path(&excerpt.path, cx).unwrap();
project.open_buffer(project_path, cx)
})?
.await?;
let ranges = buffer.read_with(&mut cx, |buffer, _| {
let full_text = buffer.text();
let offset = full_text
.find(&excerpt.text)
.expect("Expected context not found");
let point = buffer.offset_to_point(offset);
excerpt
.required_lines
.iter()
.map(|line| {
let row = point.row + line.0;
let range = Point::new(row, 0)..Point::new(row + 1, 0);
buffer.anchor_after(range.start)..buffer.anchor_before(range.end)
})
.collect()
})?;
Ok((buffer, ranges))
}
struct RunCache {
cache_mode: CacheMode,
example_run_dir: PathBuf,