Compare commits
6 Commits
migrate-in
...
blink-mana
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
2acd48f439 | ||
|
|
1f9d5ef684 | ||
|
|
83f0a3fd13 | ||
|
|
7ecbf8cf60 | ||
|
|
fb0fcd86fd | ||
|
|
36708c910a |
@@ -1,8 +1,8 @@
|
||||
use gpui::Context;
|
||||
use gpui::{Context, FocusHandle};
|
||||
use settings::SettingsStore;
|
||||
use smol::Timer;
|
||||
use std::time::Duration;
|
||||
use ui::App;
|
||||
use ui::{App, Window};
|
||||
|
||||
pub struct BlinkManager {
|
||||
blink_interval: Duration,
|
||||
@@ -11,21 +11,34 @@ pub struct BlinkManager {
|
||||
blinking_paused: bool,
|
||||
/// Whether the cursor should be visibly rendered or not.
|
||||
visible: bool,
|
||||
/// Whether the blinking currently enabled.
|
||||
enabled: bool,
|
||||
/// The focus handle to use to determine if the cursor should be blinking.
|
||||
focus_handle: FocusHandle,
|
||||
/// Whether the blinking is enabled in the settings.
|
||||
blink_enabled_in_settings: fn(&App) -> bool,
|
||||
is_enabled: Box<dyn Fn(&App) -> bool>,
|
||||
}
|
||||
|
||||
impl BlinkManager {
|
||||
pub fn new(
|
||||
blink_interval: Duration,
|
||||
blink_enabled_in_settings: fn(&App) -> bool,
|
||||
focus_handle: FocusHandle,
|
||||
is_enabled: impl Fn(&App) -> bool + 'static,
|
||||
window: &mut Window,
|
||||
cx: &mut Context<Self>,
|
||||
) -> Self {
|
||||
// Make sure we blink the cursors if the setting is re-enabled
|
||||
cx.observe_global::<SettingsStore>(move |this, cx| {
|
||||
this.blink_cursors(this.blink_epoch, cx)
|
||||
cx.observe_global_in::<SettingsStore>(window, move |this, window, cx| {
|
||||
this.refresh(window, cx);
|
||||
})
|
||||
.detach();
|
||||
|
||||
cx.on_focus(&focus_handle, window, move |this, window, cx| {
|
||||
this.visible = false;
|
||||
this.refresh(window, cx);
|
||||
})
|
||||
.detach();
|
||||
|
||||
cx.on_blur(&focus_handle, window, move |this, _window, _cx| {
|
||||
this.visible = false;
|
||||
})
|
||||
.detach();
|
||||
|
||||
@@ -34,48 +47,64 @@ impl BlinkManager {
|
||||
blink_epoch: 0,
|
||||
blinking_paused: false,
|
||||
visible: true,
|
||||
enabled: false,
|
||||
blink_enabled_in_settings,
|
||||
focus_handle,
|
||||
is_enabled: Box::new(is_enabled),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn refresh(&mut self, window: &mut Window, cx: &mut Context<Self>) {
|
||||
self.blink_cursors(self.blink_epoch, window, cx)
|
||||
}
|
||||
|
||||
fn next_blink_epoch(&mut self) -> usize {
|
||||
self.blink_epoch += 1;
|
||||
self.blink_epoch
|
||||
}
|
||||
|
||||
pub fn pause_blinking(&mut self, cx: &mut Context<Self>) {
|
||||
pub fn pause_blinking(&mut self, window: &mut Window, cx: &mut Context<Self>) {
|
||||
self.show_cursor(cx);
|
||||
|
||||
let epoch = self.next_blink_epoch();
|
||||
let interval = self.blink_interval;
|
||||
cx.spawn(async move |this, cx| {
|
||||
cx.spawn_in(window, async move |this, cx| {
|
||||
Timer::after(interval).await;
|
||||
this.update(cx, |this, cx| this.resume_cursor_blinking(epoch, cx))
|
||||
this.update_in(cx, |this, window, cx| {
|
||||
this.resume_cursor_blinking(epoch, window, cx)
|
||||
})
|
||||
})
|
||||
.detach();
|
||||
}
|
||||
|
||||
fn resume_cursor_blinking(&mut self, epoch: usize, cx: &mut Context<Self>) {
|
||||
fn resume_cursor_blinking(
|
||||
&mut self,
|
||||
epoch: usize,
|
||||
window: &mut Window,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
if epoch == self.blink_epoch {
|
||||
self.blinking_paused = false;
|
||||
self.blink_cursors(epoch, cx);
|
||||
self.blink_cursors(epoch, window, cx);
|
||||
}
|
||||
}
|
||||
|
||||
fn blink_cursors(&mut self, epoch: usize, cx: &mut Context<Self>) {
|
||||
if (self.blink_enabled_in_settings)(cx) {
|
||||
if epoch == self.blink_epoch && self.enabled && !self.blinking_paused {
|
||||
fn blink_cursors(&mut self, epoch: usize, window: &mut Window, cx: &mut Context<Self>) {
|
||||
if (self.is_enabled)(cx) {
|
||||
if epoch == self.blink_epoch
|
||||
&& self.focus_handle.is_focused(window)
|
||||
&& !self.blinking_paused
|
||||
{
|
||||
self.visible = !self.visible;
|
||||
cx.notify();
|
||||
|
||||
let epoch = self.next_blink_epoch();
|
||||
let interval = self.blink_interval;
|
||||
cx.spawn(async move |this, cx| {
|
||||
cx.spawn_in(window, async move |this, cx| {
|
||||
Timer::after(interval).await;
|
||||
if let Some(this) = this.upgrade() {
|
||||
this.update(cx, |this, cx| this.blink_cursors(epoch, cx))
|
||||
.ok();
|
||||
this.update_in(cx, |this, window, cx| {
|
||||
this.blink_cursors(epoch, window, cx)
|
||||
})
|
||||
.ok();
|
||||
}
|
||||
})
|
||||
.detach();
|
||||
@@ -92,25 +121,6 @@ impl BlinkManager {
|
||||
}
|
||||
}
|
||||
|
||||
/// Enable the blinking of the cursor.
|
||||
pub fn enable(&mut self, cx: &mut Context<Self>) {
|
||||
if self.enabled {
|
||||
return;
|
||||
}
|
||||
|
||||
self.enabled = true;
|
||||
// Set cursors as invisible and start blinking: this causes cursors
|
||||
// to be visible during the next render.
|
||||
self.visible = false;
|
||||
self.blink_cursors(self.blink_epoch, cx);
|
||||
}
|
||||
|
||||
/// Disable the blinking of the cursor.
|
||||
pub fn disable(&mut self, _cx: &mut Context<Self>) {
|
||||
self.visible = false;
|
||||
self.enabled = false;
|
||||
}
|
||||
|
||||
pub fn visible(&self) -> bool {
|
||||
self.visible
|
||||
}
|
||||
|
||||
@@ -1886,16 +1886,26 @@ impl Editor {
|
||||
|
||||
let selections = SelectionsCollection::new();
|
||||
|
||||
let focus_handle = cx.focus_handle();
|
||||
|
||||
let blink_manager = cx.new(|cx| {
|
||||
let mut blink_manager = BlinkManager::new(
|
||||
CURSOR_BLINK_INTERVAL,
|
||||
|cx| EditorSettings::get_global(cx).cursor_blink,
|
||||
cx,
|
||||
);
|
||||
if is_minimap {
|
||||
blink_manager.disable(cx);
|
||||
BlinkManager::new(
|
||||
CURSOR_BLINK_INTERVAL,
|
||||
focus_handle.clone(),
|
||||
|_cx| false,
|
||||
window,
|
||||
cx,
|
||||
)
|
||||
} else {
|
||||
BlinkManager::new(
|
||||
CURSOR_BLINK_INTERVAL,
|
||||
focus_handle.clone(),
|
||||
|cx| EditorSettings::get_global(cx).cursor_blink,
|
||||
window,
|
||||
cx,
|
||||
)
|
||||
}
|
||||
blink_manager
|
||||
});
|
||||
|
||||
let soft_wrap_mode_override =
|
||||
@@ -2095,7 +2105,6 @@ impl Editor {
|
||||
|
||||
let inlay_hint_settings =
|
||||
inlay_hint_settings(selections.newest_anchor().head(), &buffer_snapshot, cx);
|
||||
let focus_handle = cx.focus_handle();
|
||||
if !is_minimap {
|
||||
cx.on_focus(&focus_handle, window, Self::handle_focus)
|
||||
.detach();
|
||||
@@ -2292,15 +2301,7 @@ impl Editor {
|
||||
cx.observe_global_in::<SettingsStore>(window, Self::settings_changed),
|
||||
observe_buffer_font_size_adjustment(cx, |_, cx| cx.notify()),
|
||||
cx.observe_window_activation(window, |editor, window, cx| {
|
||||
let active = window.is_window_active();
|
||||
editor.blink_manager.update(cx, |blink_manager, cx| {
|
||||
if active {
|
||||
blink_manager.enable(cx);
|
||||
} else {
|
||||
blink_manager.disable(cx);
|
||||
}
|
||||
});
|
||||
if active {
|
||||
if window.is_window_active() {
|
||||
editor.show_mouse_cursor(cx);
|
||||
}
|
||||
}),
|
||||
@@ -3321,7 +3322,9 @@ impl Editor {
|
||||
}
|
||||
}
|
||||
|
||||
self.blink_manager.update(cx, BlinkManager::pause_blinking);
|
||||
self.blink_manager.update(cx, |blink_manager, cx| {
|
||||
blink_manager.pause_blinking(window, cx)
|
||||
});
|
||||
cx.emit(EditorEvent::SelectionsChanged { local });
|
||||
|
||||
let selections = &self.selections.disjoint_anchors_arc();
|
||||
@@ -22161,7 +22164,6 @@ impl Editor {
|
||||
blame.update(cx, GitBlame::focus)
|
||||
}
|
||||
|
||||
self.blink_manager.update(cx, BlinkManager::enable);
|
||||
self.show_cursor_names(window, cx);
|
||||
self.buffer.update(cx, |buffer, cx| {
|
||||
buffer.finalize_last_transaction(cx);
|
||||
@@ -22209,7 +22211,6 @@ impl Editor {
|
||||
}
|
||||
|
||||
pub fn handle_blur(&mut self, window: &mut Window, cx: &mut Context<Self>) {
|
||||
self.blink_manager.update(cx, BlinkManager::disable);
|
||||
self.buffer
|
||||
.update(cx, |buffer, cx| buffer.remove_active_selections(cx));
|
||||
|
||||
|
||||
@@ -340,11 +340,11 @@ impl LspLogView {
|
||||
* Configuration: {CONFIGURATION}",
|
||||
NAME = info.status.name,
|
||||
ID = info.id,
|
||||
BINARY = info.status.binary.as_ref().map_or_else(
|
||||
|| "Unknown".to_string(),
|
||||
|binary| serde_json::to_string_pretty(binary)
|
||||
.unwrap_or_else(|e| format!("Failed to serialize binary info: {e:#}"))
|
||||
),
|
||||
BINARY = info
|
||||
.status
|
||||
.binary
|
||||
.as_ref()
|
||||
.map_or_else(|| "Unknown".to_string(), |binary| format!("{:#?}", binary)),
|
||||
WORKSPACE_FOLDERS = info
|
||||
.status
|
||||
.workspace_folders
|
||||
|
||||
@@ -234,17 +234,25 @@ impl TerminalView {
|
||||
|
||||
let scroll_handle = TerminalScrollHandle::new(terminal.read(cx));
|
||||
|
||||
let blink_manager = cx.new(|cx| {
|
||||
BlinkManager::new(
|
||||
CURSOR_BLINK_INTERVAL,
|
||||
|cx| {
|
||||
!matches!(
|
||||
TerminalSettings::get_global(cx).blinking,
|
||||
TerminalBlink::Off
|
||||
)
|
||||
},
|
||||
cx,
|
||||
)
|
||||
let blink_manager = cx.new({
|
||||
let weak_this = cx.weak_entity();
|
||||
let focus_handle = focus_handle.clone();
|
||||
|
||||
move |cx| {
|
||||
BlinkManager::new(
|
||||
CURSOR_BLINK_INTERVAL,
|
||||
focus_handle,
|
||||
move |cx| match TerminalSettings::get_global(cx).blinking {
|
||||
TerminalBlink::Off => false,
|
||||
TerminalBlink::On => true,
|
||||
TerminalBlink::TerminalControlled => weak_this
|
||||
.read_with(cx, |this, _cx| this.blinking_terminal_enabled)
|
||||
.unwrap_or(false),
|
||||
},
|
||||
window,
|
||||
cx,
|
||||
)
|
||||
}
|
||||
});
|
||||
|
||||
let _subscriptions = vec![
|
||||
@@ -434,11 +442,6 @@ impl TerminalView {
|
||||
let breadcrumb_visibility_changed = self.show_breadcrumbs != settings.toolbar.breadcrumbs;
|
||||
self.show_breadcrumbs = settings.toolbar.breadcrumbs;
|
||||
|
||||
let should_blink = match settings.blinking {
|
||||
TerminalBlink::Off => false,
|
||||
TerminalBlink::On => true,
|
||||
TerminalBlink::TerminalControlled => self.blinking_terminal_enabled,
|
||||
};
|
||||
let new_cursor_shape = settings.cursor_shape;
|
||||
let old_cursor_shape = self.cursor_shape;
|
||||
if old_cursor_shape != new_cursor_shape {
|
||||
@@ -448,15 +451,6 @@ impl TerminalView {
|
||||
});
|
||||
}
|
||||
|
||||
self.blink_manager.update(
|
||||
cx,
|
||||
if should_blink {
|
||||
BlinkManager::enable
|
||||
} else {
|
||||
BlinkManager::disable
|
||||
},
|
||||
);
|
||||
|
||||
if breadcrumb_visibility_changed {
|
||||
cx.emit(ItemEvent::UpdateBreadcrumbs);
|
||||
}
|
||||
@@ -649,14 +643,17 @@ impl TerminalView {
|
||||
// When focused, check blinking settings and blink manager state
|
||||
match TerminalSettings::get_global(cx).blinking {
|
||||
TerminalBlink::Off => true,
|
||||
TerminalBlink::On | TerminalBlink::TerminalControlled => {
|
||||
self.blink_manager.read(cx).visible()
|
||||
TerminalBlink::TerminalControlled => {
|
||||
!self.blinking_terminal_enabled || self.blink_manager.read(cx).visible()
|
||||
}
|
||||
TerminalBlink::On => self.blink_manager.read(cx).visible(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn pause_cursor_blinking(&mut self, _window: &mut Window, cx: &mut Context<Self>) {
|
||||
self.blink_manager.update(cx, BlinkManager::pause_blinking);
|
||||
pub fn pause_cursor_blinking(&mut self, window: &mut Window, cx: &mut Context<Self>) {
|
||||
self.blink_manager.update(cx, |blink_manager, cx| {
|
||||
blink_manager.pause_blinking(window, cx)
|
||||
});
|
||||
}
|
||||
|
||||
pub fn terminal(&self) -> &Entity<Terminal> {
|
||||
@@ -872,21 +869,9 @@ fn subscribe_for_terminal_events(
|
||||
|
||||
Event::BlinkChanged(blinking) => {
|
||||
terminal_view.blinking_terminal_enabled = *blinking;
|
||||
|
||||
// If in terminal-controlled mode and focused, update blink manager
|
||||
if matches!(
|
||||
TerminalSettings::get_global(cx).blinking,
|
||||
TerminalBlink::TerminalControlled
|
||||
) && terminal_view.focus_handle.is_focused(window)
|
||||
{
|
||||
terminal_view.blink_manager.update(cx, |manager, cx| {
|
||||
if *blinking {
|
||||
manager.enable(cx);
|
||||
} else {
|
||||
manager.disable(cx);
|
||||
}
|
||||
});
|
||||
}
|
||||
terminal_view
|
||||
.blink_manager
|
||||
.update(cx, |this, cx| this.refresh(window, cx));
|
||||
}
|
||||
|
||||
Event::TitleChanged => {
|
||||
@@ -1012,22 +997,11 @@ impl TerminalView {
|
||||
terminal.focus_in();
|
||||
});
|
||||
|
||||
let should_blink = match TerminalSettings::get_global(cx).blinking {
|
||||
TerminalBlink::Off => false,
|
||||
TerminalBlink::On => true,
|
||||
TerminalBlink::TerminalControlled => self.blinking_terminal_enabled,
|
||||
};
|
||||
|
||||
if should_blink {
|
||||
self.blink_manager.update(cx, BlinkManager::enable);
|
||||
}
|
||||
|
||||
window.invalidate_character_coordinates();
|
||||
cx.notify();
|
||||
}
|
||||
|
||||
fn focus_out(&mut self, _window: &mut Window, cx: &mut Context<Self>) {
|
||||
self.blink_manager.update(cx, BlinkManager::disable);
|
||||
self.terminal.update(cx, |terminal, _| {
|
||||
terminal.focus_out();
|
||||
terminal.set_cursor_shape(CursorShape::Hollow);
|
||||
|
||||
@@ -77,7 +77,7 @@ impl EditPredictionProvider for ZetaEditPredictionProvider {
|
||||
) -> bool {
|
||||
let zeta = self.zeta.read(cx);
|
||||
if zeta.edit_prediction_model == ZetaEditPredictionModel::Sweep {
|
||||
zeta.sweep_api_token.is_some()
|
||||
zeta.sweep_ai.api_token.is_some()
|
||||
} else {
|
||||
true
|
||||
}
|
||||
|
||||
@@ -1,10 +1,269 @@
|
||||
use std::fmt;
|
||||
use std::{path::Path, sync::Arc};
|
||||
|
||||
use anyhow::{Context as _, Result};
|
||||
use cloud_llm_client::predict_edits_v3::Event;
|
||||
use futures::AsyncReadExt as _;
|
||||
use gpui::{
|
||||
App, AppContext as _, Entity, Task,
|
||||
http_client::{self, AsyncBody, Method},
|
||||
};
|
||||
use language::{Buffer, BufferSnapshot, Point, ToOffset as _, ToPoint as _};
|
||||
use lsp::DiagnosticSeverity;
|
||||
use project::{Project, ProjectPath};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::{
|
||||
collections::VecDeque,
|
||||
fmt::{self, Write as _},
|
||||
ops::Range,
|
||||
path::Path,
|
||||
sync::Arc,
|
||||
time::Instant,
|
||||
};
|
||||
use util::ResultExt as _;
|
||||
|
||||
use crate::{EditPrediction, EditPredictionId, EditPredictionInputs};
|
||||
|
||||
const SWEEP_API_URL: &str = "https://autocomplete.sweep.dev/backend/next_edit_autocomplete";
|
||||
|
||||
pub struct SweepAi {
|
||||
pub api_token: Option<String>,
|
||||
pub debug_info: Arc<str>,
|
||||
}
|
||||
|
||||
impl SweepAi {
|
||||
pub fn new(cx: &App) -> Self {
|
||||
SweepAi {
|
||||
api_token: std::env::var("SWEEP_AI_TOKEN")
|
||||
.context("No SWEEP_AI_TOKEN environment variable set")
|
||||
.log_err(),
|
||||
debug_info: debug_info(cx),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn request_prediction_with_sweep(
|
||||
&self,
|
||||
project: &Entity<Project>,
|
||||
active_buffer: &Entity<Buffer>,
|
||||
snapshot: BufferSnapshot,
|
||||
position: language::Anchor,
|
||||
events: Vec<Arc<Event>>,
|
||||
recent_paths: &VecDeque<ProjectPath>,
|
||||
diagnostic_search_range: Range<Point>,
|
||||
cx: &mut App,
|
||||
) -> Task<Result<Option<EditPrediction>>> {
|
||||
let debug_info = self.debug_info.clone();
|
||||
let Some(api_token) = self.api_token.clone() else {
|
||||
return Task::ready(Ok(None));
|
||||
};
|
||||
let full_path: Arc<Path> = snapshot
|
||||
.file()
|
||||
.map(|file| file.full_path(cx))
|
||||
.unwrap_or_else(|| "untitled".into())
|
||||
.into();
|
||||
|
||||
let project_file = project::File::from_dyn(snapshot.file());
|
||||
let repo_name = project_file
|
||||
.map(|file| file.worktree.read(cx).root_name_str())
|
||||
.unwrap_or("untitled")
|
||||
.into();
|
||||
let offset = position.to_offset(&snapshot);
|
||||
|
||||
let recent_buffers = recent_paths.iter().cloned();
|
||||
let http_client = cx.http_client();
|
||||
|
||||
let recent_buffer_snapshots = recent_buffers
|
||||
.filter_map(|project_path| {
|
||||
let buffer = project.read(cx).get_open_buffer(&project_path, cx)?;
|
||||
if active_buffer == &buffer {
|
||||
None
|
||||
} else {
|
||||
Some(buffer.read(cx).snapshot())
|
||||
}
|
||||
})
|
||||
.take(3)
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let cursor_point = position.to_point(&snapshot);
|
||||
let buffer_snapshotted_at = Instant::now();
|
||||
|
||||
let result = cx.background_spawn(async move {
|
||||
let text = snapshot.text();
|
||||
|
||||
let mut recent_changes = String::new();
|
||||
for event in &events {
|
||||
write_event(event.as_ref(), &mut recent_changes).unwrap();
|
||||
}
|
||||
|
||||
let mut file_chunks = recent_buffer_snapshots
|
||||
.into_iter()
|
||||
.map(|snapshot| {
|
||||
let end_point = Point::new(30, 0).min(snapshot.max_point());
|
||||
FileChunk {
|
||||
content: snapshot.text_for_range(Point::zero()..end_point).collect(),
|
||||
file_path: snapshot
|
||||
.file()
|
||||
.map(|f| f.path().as_unix_str())
|
||||
.unwrap_or("untitled")
|
||||
.to_string(),
|
||||
start_line: 0,
|
||||
end_line: end_point.row as usize,
|
||||
timestamp: snapshot.file().and_then(|file| {
|
||||
Some(
|
||||
file.disk_state()
|
||||
.mtime()?
|
||||
.to_seconds_and_nanos_for_persistence()?
|
||||
.0,
|
||||
)
|
||||
}),
|
||||
}
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let diagnostic_entries = snapshot.diagnostics_in_range(diagnostic_search_range, false);
|
||||
let mut diagnostic_content = String::new();
|
||||
let mut diagnostic_count = 0;
|
||||
|
||||
for entry in diagnostic_entries {
|
||||
let start_point: Point = entry.range.start;
|
||||
|
||||
let severity = match entry.diagnostic.severity {
|
||||
DiagnosticSeverity::ERROR => "error",
|
||||
DiagnosticSeverity::WARNING => "warning",
|
||||
DiagnosticSeverity::INFORMATION => "info",
|
||||
DiagnosticSeverity::HINT => "hint",
|
||||
_ => continue,
|
||||
};
|
||||
|
||||
diagnostic_count += 1;
|
||||
|
||||
writeln!(
|
||||
&mut diagnostic_content,
|
||||
"{} at line {}: {}",
|
||||
severity,
|
||||
start_point.row + 1,
|
||||
entry.diagnostic.message
|
||||
)?;
|
||||
}
|
||||
|
||||
if !diagnostic_content.is_empty() {
|
||||
file_chunks.push(FileChunk {
|
||||
file_path: format!("Diagnostics for {}", full_path.display()),
|
||||
start_line: 0,
|
||||
end_line: diagnostic_count,
|
||||
content: diagnostic_content,
|
||||
timestamp: None,
|
||||
});
|
||||
}
|
||||
|
||||
let request_body = AutocompleteRequest {
|
||||
debug_info,
|
||||
repo_name,
|
||||
file_path: full_path.clone(),
|
||||
file_contents: text.clone(),
|
||||
original_file_contents: text,
|
||||
cursor_position: offset,
|
||||
recent_changes: recent_changes.clone(),
|
||||
changes_above_cursor: true,
|
||||
multiple_suggestions: false,
|
||||
branch: None,
|
||||
file_chunks,
|
||||
retrieval_chunks: vec![],
|
||||
recent_user_actions: vec![],
|
||||
// TODO
|
||||
privacy_mode_enabled: false,
|
||||
};
|
||||
|
||||
let mut buf: Vec<u8> = Vec::new();
|
||||
let writer = brotli::CompressorWriter::new(&mut buf, 4096, 11, 22);
|
||||
serde_json::to_writer(writer, &request_body)?;
|
||||
let body: AsyncBody = buf.into();
|
||||
|
||||
let inputs = EditPredictionInputs {
|
||||
events,
|
||||
included_files: vec![cloud_llm_client::predict_edits_v3::IncludedFile {
|
||||
path: full_path.clone(),
|
||||
max_row: cloud_llm_client::predict_edits_v3::Line(snapshot.max_point().row),
|
||||
excerpts: vec![cloud_llm_client::predict_edits_v3::Excerpt {
|
||||
start_line: cloud_llm_client::predict_edits_v3::Line(0),
|
||||
text: request_body.file_contents.into(),
|
||||
}],
|
||||
}],
|
||||
cursor_point: cloud_llm_client::predict_edits_v3::Point {
|
||||
column: cursor_point.column,
|
||||
line: cloud_llm_client::predict_edits_v3::Line(cursor_point.row),
|
||||
},
|
||||
cursor_path: full_path.clone(),
|
||||
};
|
||||
|
||||
let request = http_client::Request::builder()
|
||||
.uri(SWEEP_API_URL)
|
||||
.header("Content-Type", "application/json")
|
||||
.header("Authorization", format!("Bearer {}", api_token))
|
||||
.header("Connection", "keep-alive")
|
||||
.header("Content-Encoding", "br")
|
||||
.method(Method::POST)
|
||||
.body(body)?;
|
||||
|
||||
let mut response = http_client.send(request).await?;
|
||||
|
||||
let mut body: Vec<u8> = Vec::new();
|
||||
response.body_mut().read_to_end(&mut body).await?;
|
||||
|
||||
let response_received_at = Instant::now();
|
||||
if !response.status().is_success() {
|
||||
anyhow::bail!(
|
||||
"Request failed with status: {:?}\nBody: {}",
|
||||
response.status(),
|
||||
String::from_utf8_lossy(&body),
|
||||
);
|
||||
};
|
||||
|
||||
let response: AutocompleteResponse = serde_json::from_slice(&body)?;
|
||||
|
||||
let old_text = snapshot
|
||||
.text_for_range(response.start_index..response.end_index)
|
||||
.collect::<String>();
|
||||
let edits = language::text_diff(&old_text, &response.completion)
|
||||
.into_iter()
|
||||
.map(|(range, text)| {
|
||||
(
|
||||
snapshot.anchor_after(response.start_index + range.start)
|
||||
..snapshot.anchor_before(response.start_index + range.end),
|
||||
text,
|
||||
)
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
anyhow::Ok((
|
||||
response.autocomplete_id,
|
||||
edits,
|
||||
snapshot,
|
||||
response_received_at,
|
||||
inputs,
|
||||
))
|
||||
});
|
||||
|
||||
let buffer = active_buffer.clone();
|
||||
|
||||
cx.spawn(async move |cx| {
|
||||
let (id, edits, old_snapshot, response_received_at, inputs) = result.await?;
|
||||
anyhow::Ok(
|
||||
EditPrediction::new(
|
||||
EditPredictionId(id.into()),
|
||||
&buffer,
|
||||
&old_snapshot,
|
||||
edits.into(),
|
||||
buffer_snapshotted_at,
|
||||
response_received_at,
|
||||
inputs,
|
||||
cx,
|
||||
)
|
||||
.await,
|
||||
)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
pub struct AutocompleteRequest {
|
||||
struct AutocompleteRequest {
|
||||
pub debug_info: Arc<str>,
|
||||
pub repo_name: String,
|
||||
pub branch: Option<String>,
|
||||
@@ -22,7 +281,7 @@ pub struct AutocompleteRequest {
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
pub struct FileChunk {
|
||||
struct FileChunk {
|
||||
pub file_path: String,
|
||||
pub start_line: usize,
|
||||
pub end_line: usize,
|
||||
@@ -31,7 +290,7 @@ pub struct FileChunk {
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
pub struct RetrievalChunk {
|
||||
struct RetrievalChunk {
|
||||
pub file_path: String,
|
||||
pub start_line: usize,
|
||||
pub end_line: usize,
|
||||
@@ -40,7 +299,7 @@ pub struct RetrievalChunk {
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
pub struct UserAction {
|
||||
struct UserAction {
|
||||
pub action_type: ActionType,
|
||||
pub line_number: usize,
|
||||
pub offset: usize,
|
||||
@@ -51,7 +310,7 @@ pub struct UserAction {
|
||||
#[allow(dead_code)]
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize)]
|
||||
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
|
||||
pub enum ActionType {
|
||||
enum ActionType {
|
||||
CursorMovement,
|
||||
InsertChar,
|
||||
DeleteChar,
|
||||
@@ -60,7 +319,7 @@ pub enum ActionType {
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
pub struct AutocompleteResponse {
|
||||
struct AutocompleteResponse {
|
||||
pub autocomplete_id: String,
|
||||
pub start_index: usize,
|
||||
pub end_index: usize,
|
||||
@@ -80,7 +339,7 @@ pub struct AutocompleteResponse {
|
||||
|
||||
#[allow(dead_code)]
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
pub struct AdditionalCompletion {
|
||||
struct AdditionalCompletion {
|
||||
pub start_index: usize,
|
||||
pub end_index: usize,
|
||||
pub completion: String,
|
||||
@@ -90,7 +349,7 @@ pub struct AdditionalCompletion {
|
||||
pub finish_reason: Option<String>,
|
||||
}
|
||||
|
||||
pub(crate) fn write_event(
|
||||
fn write_event(
|
||||
event: &cloud_llm_client::predict_edits_v3::Event,
|
||||
f: &mut impl fmt::Write,
|
||||
) -> fmt::Result {
|
||||
@@ -115,7 +374,7 @@ pub(crate) fn write_event(
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn debug_info(cx: &gpui::App) -> Arc<str> {
|
||||
fn debug_info(cx: &gpui::App) -> Arc<str> {
|
||||
format!(
|
||||
"Zed v{version} ({sha}) - OS: {os} - Zed v{version}",
|
||||
version = release_channel::AppVersion::global(cx),
|
||||
|
||||
@@ -30,7 +30,6 @@ use language::{
|
||||
};
|
||||
use language::{BufferSnapshot, OffsetRangeExt};
|
||||
use language_model::{LlmApiToken, RefreshLlmTokenListener};
|
||||
use lsp::DiagnosticSeverity;
|
||||
use open_ai::FunctionDefinition;
|
||||
use project::{DisableAiSettings, Project, ProjectPath, WorktreeId};
|
||||
use release_channel::AppVersion;
|
||||
@@ -42,7 +41,6 @@ use std::collections::{VecDeque, hash_map};
|
||||
use telemetry_events::EditPredictionRating;
|
||||
use workspace::Workspace;
|
||||
|
||||
use std::fmt::Write as _;
|
||||
use std::ops::Range;
|
||||
use std::path::Path;
|
||||
use std::rc::Rc;
|
||||
@@ -80,6 +78,7 @@ use crate::rate_prediction_modal::{
|
||||
NextEdit, PreviousEdit, RatePredictionsModal, ThumbsDownActivePrediction,
|
||||
ThumbsUpActivePrediction,
|
||||
};
|
||||
use crate::sweep_ai::SweepAi;
|
||||
use crate::zeta1::request_prediction_with_zeta1;
|
||||
pub use provider::ZetaEditPredictionProvider;
|
||||
|
||||
@@ -171,7 +170,7 @@ impl FeatureFlag for Zeta2FeatureFlag {
|
||||
const NAME: &'static str = "zeta2";
|
||||
|
||||
fn enabled_for_staff() -> bool {
|
||||
false
|
||||
true
|
||||
}
|
||||
}
|
||||
|
||||
@@ -192,8 +191,7 @@ pub struct Zeta {
|
||||
#[cfg(feature = "eval-support")]
|
||||
eval_cache: Option<Arc<dyn EvalCache>>,
|
||||
edit_prediction_model: ZetaEditPredictionModel,
|
||||
sweep_api_token: Option<String>,
|
||||
sweep_ai_debug_info: Arc<str>,
|
||||
sweep_ai: SweepAi,
|
||||
data_collection_choice: DataCollectionChoice,
|
||||
rejected_predictions: Vec<EditPredictionRejection>,
|
||||
reject_predictions_tx: mpsc::UnboundedSender<()>,
|
||||
@@ -202,7 +200,7 @@ pub struct Zeta {
|
||||
rated_predictions: HashSet<EditPredictionId>,
|
||||
}
|
||||
|
||||
#[derive(Default, PartialEq, Eq)]
|
||||
#[derive(Copy, Clone, Default, PartialEq, Eq)]
|
||||
pub enum ZetaEditPredictionModel {
|
||||
#[default]
|
||||
Zeta1,
|
||||
@@ -499,11 +497,8 @@ impl Zeta {
|
||||
#[cfg(feature = "eval-support")]
|
||||
eval_cache: None,
|
||||
edit_prediction_model: ZetaEditPredictionModel::Zeta2,
|
||||
sweep_api_token: std::env::var("SWEEP_AI_TOKEN")
|
||||
.context("No SWEEP_AI_TOKEN environment variable set")
|
||||
.log_err(),
|
||||
sweep_ai: SweepAi::new(cx),
|
||||
data_collection_choice,
|
||||
sweep_ai_debug_info: sweep_ai::debug_info(cx),
|
||||
rejected_predictions: Vec::new(),
|
||||
reject_predictions_debounce_task: None,
|
||||
reject_predictions_tx: reject_tx,
|
||||
@@ -517,7 +512,7 @@ impl Zeta {
|
||||
}
|
||||
|
||||
pub fn has_sweep_api_token(&self) -> bool {
|
||||
self.sweep_api_token.is_some()
|
||||
self.sweep_ai.api_token.is_some()
|
||||
}
|
||||
|
||||
#[cfg(feature = "eval-support")]
|
||||
@@ -643,7 +638,9 @@ impl Zeta {
|
||||
}
|
||||
}
|
||||
project::Event::DiagnosticsUpdated { .. } => {
|
||||
self.refresh_prediction_from_diagnostics(project, cx);
|
||||
if cx.has_flag::<Zeta2FeatureFlag>() {
|
||||
self.refresh_prediction_from_diagnostics(project, cx);
|
||||
}
|
||||
}
|
||||
_ => (),
|
||||
}
|
||||
@@ -1126,7 +1123,6 @@ impl Zeta {
|
||||
zeta_project.next_pending_prediction_id += 1;
|
||||
let last_request = zeta_project.last_prediction_refresh;
|
||||
|
||||
// TODO report cancelled requests like in zeta1
|
||||
let task = cx.spawn(async move |this, cx| {
|
||||
if let Some((last_entity, last_timestamp)) = last_request
|
||||
&& throttle_entity == last_entity
|
||||
@@ -1136,6 +1132,12 @@ impl Zeta {
|
||||
cx.background_executor().timer(timeout).await;
|
||||
}
|
||||
|
||||
this.update(cx, |this, cx| {
|
||||
this.get_or_init_zeta_project(&project, cx)
|
||||
.last_prediction_refresh = Some((throttle_entity, Instant::now()));
|
||||
})
|
||||
.ok();
|
||||
|
||||
let edit_prediction_id = do_refresh(this.clone(), cx).await.log_err().flatten();
|
||||
|
||||
// When a prediction completes, remove it from the pending list, and cancel
|
||||
@@ -1183,249 +1185,77 @@ impl Zeta {
|
||||
position: language::Anchor,
|
||||
cx: &mut Context<Self>,
|
||||
) -> Task<Result<Option<EditPrediction>>> {
|
||||
match self.edit_prediction_model {
|
||||
ZetaEditPredictionModel::Zeta1 => {
|
||||
request_prediction_with_zeta1(self, project, active_buffer, position, cx)
|
||||
}
|
||||
ZetaEditPredictionModel::Zeta2 => {
|
||||
self.request_prediction_with_zeta2(project, active_buffer, position, cx)
|
||||
}
|
||||
ZetaEditPredictionModel::Sweep => {
|
||||
self.request_prediction_with_sweep(project, active_buffer, position, true, cx)
|
||||
}
|
||||
}
|
||||
self.request_prediction_internal(
|
||||
project.clone(),
|
||||
active_buffer.clone(),
|
||||
position,
|
||||
cx.has_flag::<Zeta2FeatureFlag>(),
|
||||
cx,
|
||||
)
|
||||
}
|
||||
|
||||
fn request_prediction_with_sweep(
|
||||
fn request_prediction_internal(
|
||||
&mut self,
|
||||
project: &Entity<Project>,
|
||||
active_buffer: &Entity<Buffer>,
|
||||
project: Entity<Project>,
|
||||
active_buffer: Entity<Buffer>,
|
||||
position: language::Anchor,
|
||||
allow_jump: bool,
|
||||
cx: &mut Context<Self>,
|
||||
) -> Task<Result<Option<EditPrediction>>> {
|
||||
let snapshot = active_buffer.read(cx).snapshot();
|
||||
let debug_info = self.sweep_ai_debug_info.clone();
|
||||
let Some(api_token) = self.sweep_api_token.clone() else {
|
||||
return Task::ready(Ok(None));
|
||||
};
|
||||
let full_path: Arc<Path> = snapshot
|
||||
.file()
|
||||
.map(|file| file.full_path(cx))
|
||||
.unwrap_or_else(|| "untitled".into())
|
||||
.into();
|
||||
|
||||
let project_file = project::File::from_dyn(snapshot.file());
|
||||
let repo_name = project_file
|
||||
.map(|file| file.worktree.read(cx).root_name_str())
|
||||
.unwrap_or("untitled")
|
||||
.into();
|
||||
let offset = position.to_offset(&snapshot);
|
||||
|
||||
let project_state = self.get_or_init_zeta_project(project, cx);
|
||||
let events = project_state.events(cx);
|
||||
let has_events = !events.is_empty();
|
||||
let recent_buffers = project_state.recent_paths.iter().cloned();
|
||||
let http_client = cx.http_client();
|
||||
|
||||
let recent_buffer_snapshots = recent_buffers
|
||||
.filter_map(|project_path| {
|
||||
let buffer = project.read(cx).get_open_buffer(&project_path, cx)?;
|
||||
if active_buffer == &buffer {
|
||||
None
|
||||
} else {
|
||||
Some(buffer.read(cx).snapshot())
|
||||
}
|
||||
})
|
||||
.take(3)
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
const DIAGNOSTIC_LINES_RANGE: u32 = 20;
|
||||
|
||||
self.get_or_init_zeta_project(&project, cx);
|
||||
let zeta_project = self.projects.get(&project.entity_id()).unwrap();
|
||||
let events = zeta_project.events(cx);
|
||||
let has_events = !events.is_empty();
|
||||
|
||||
let snapshot = active_buffer.read(cx).snapshot();
|
||||
let cursor_point = position.to_point(&snapshot);
|
||||
let diagnostic_search_start = cursor_point.row.saturating_sub(DIAGNOSTIC_LINES_RANGE);
|
||||
let diagnostic_search_end = cursor_point.row + DIAGNOSTIC_LINES_RANGE;
|
||||
let diagnostic_search_range =
|
||||
Point::new(diagnostic_search_start, 0)..Point::new(diagnostic_search_end, 0);
|
||||
let buffer_snapshotted_at = Instant::now();
|
||||
|
||||
let result = cx.background_spawn({
|
||||
let snapshot = snapshot.clone();
|
||||
let diagnostic_search_range = diagnostic_search_range.clone();
|
||||
async move {
|
||||
let text = snapshot.text();
|
||||
|
||||
let mut recent_changes = String::new();
|
||||
for event in &events {
|
||||
sweep_ai::write_event(event.as_ref(), &mut recent_changes).unwrap();
|
||||
}
|
||||
|
||||
let mut file_chunks = recent_buffer_snapshots
|
||||
.into_iter()
|
||||
.map(|snapshot| {
|
||||
let end_point = Point::new(30, 0).min(snapshot.max_point());
|
||||
sweep_ai::FileChunk {
|
||||
content: snapshot.text_for_range(Point::zero()..end_point).collect(),
|
||||
file_path: snapshot
|
||||
.file()
|
||||
.map(|f| f.path().as_unix_str())
|
||||
.unwrap_or("untitled")
|
||||
.to_string(),
|
||||
start_line: 0,
|
||||
end_line: end_point.row as usize,
|
||||
timestamp: snapshot.file().and_then(|file| {
|
||||
Some(
|
||||
file.disk_state()
|
||||
.mtime()?
|
||||
.to_seconds_and_nanos_for_persistence()?
|
||||
.0,
|
||||
)
|
||||
}),
|
||||
}
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let diagnostic_entries =
|
||||
snapshot.diagnostics_in_range(diagnostic_search_range, false);
|
||||
let mut diagnostic_content = String::new();
|
||||
let mut diagnostic_count = 0;
|
||||
|
||||
for entry in diagnostic_entries {
|
||||
let start_point: Point = entry.range.start;
|
||||
|
||||
let severity = match entry.diagnostic.severity {
|
||||
DiagnosticSeverity::ERROR => "error",
|
||||
DiagnosticSeverity::WARNING => "warning",
|
||||
DiagnosticSeverity::INFORMATION => "info",
|
||||
DiagnosticSeverity::HINT => "hint",
|
||||
_ => continue,
|
||||
};
|
||||
|
||||
diagnostic_count += 1;
|
||||
|
||||
writeln!(
|
||||
&mut diagnostic_content,
|
||||
"{} at line {}: {}",
|
||||
severity,
|
||||
start_point.row + 1,
|
||||
entry.diagnostic.message
|
||||
)?;
|
||||
}
|
||||
|
||||
if !diagnostic_content.is_empty() {
|
||||
file_chunks.push(sweep_ai::FileChunk {
|
||||
file_path: format!("Diagnostics for {}", full_path.display()),
|
||||
start_line: 0,
|
||||
end_line: diagnostic_count,
|
||||
content: diagnostic_content,
|
||||
timestamp: None,
|
||||
});
|
||||
}
|
||||
|
||||
let request_body = sweep_ai::AutocompleteRequest {
|
||||
debug_info,
|
||||
repo_name,
|
||||
file_path: full_path.clone(),
|
||||
file_contents: text.clone(),
|
||||
original_file_contents: text,
|
||||
cursor_position: offset,
|
||||
recent_changes: recent_changes.clone(),
|
||||
changes_above_cursor: true,
|
||||
multiple_suggestions: false,
|
||||
branch: None,
|
||||
file_chunks,
|
||||
retrieval_chunks: vec![],
|
||||
recent_user_actions: vec![],
|
||||
// TODO
|
||||
privacy_mode_enabled: false,
|
||||
};
|
||||
|
||||
let mut buf: Vec<u8> = Vec::new();
|
||||
let writer = brotli::CompressorWriter::new(&mut buf, 4096, 11, 22);
|
||||
serde_json::to_writer(writer, &request_body)?;
|
||||
let body: AsyncBody = buf.into();
|
||||
|
||||
let inputs = EditPredictionInputs {
|
||||
events,
|
||||
included_files: vec![cloud_llm_client::predict_edits_v3::IncludedFile {
|
||||
path: full_path.clone(),
|
||||
max_row: cloud_llm_client::predict_edits_v3::Line(snapshot.max_point().row),
|
||||
excerpts: vec![cloud_llm_client::predict_edits_v3::Excerpt {
|
||||
start_line: cloud_llm_client::predict_edits_v3::Line(0),
|
||||
text: request_body.file_contents.into(),
|
||||
}],
|
||||
}],
|
||||
cursor_point: cloud_llm_client::predict_edits_v3::Point {
|
||||
column: cursor_point.column,
|
||||
line: cloud_llm_client::predict_edits_v3::Line(cursor_point.row),
|
||||
},
|
||||
cursor_path: full_path.clone(),
|
||||
};
|
||||
|
||||
const SWEEP_API_URL: &str =
|
||||
"https://autocomplete.sweep.dev/backend/next_edit_autocomplete";
|
||||
|
||||
let request = http_client::Request::builder()
|
||||
.uri(SWEEP_API_URL)
|
||||
.header("Content-Type", "application/json")
|
||||
.header("Authorization", format!("Bearer {}", api_token))
|
||||
.header("Connection", "keep-alive")
|
||||
.header("Content-Encoding", "br")
|
||||
.method(Method::POST)
|
||||
.body(body)?;
|
||||
|
||||
let mut response = http_client.send(request).await?;
|
||||
|
||||
let mut body: Vec<u8> = Vec::new();
|
||||
response.body_mut().read_to_end(&mut body).await?;
|
||||
|
||||
let response_received_at = Instant::now();
|
||||
if !response.status().is_success() {
|
||||
anyhow::bail!(
|
||||
"Request failed with status: {:?}\nBody: {}",
|
||||
response.status(),
|
||||
String::from_utf8_lossy(&body),
|
||||
);
|
||||
};
|
||||
|
||||
let response: sweep_ai::AutocompleteResponse = serde_json::from_slice(&body)?;
|
||||
|
||||
let old_text = snapshot
|
||||
.text_for_range(response.start_index..response.end_index)
|
||||
.collect::<String>();
|
||||
let edits = language::text_diff(&old_text, &response.completion)
|
||||
.into_iter()
|
||||
.map(|(range, text)| {
|
||||
(
|
||||
snapshot.anchor_after(response.start_index + range.start)
|
||||
..snapshot.anchor_before(response.start_index + range.end),
|
||||
text,
|
||||
)
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
anyhow::Ok((
|
||||
response.autocomplete_id,
|
||||
edits,
|
||||
snapshot,
|
||||
response_received_at,
|
||||
inputs,
|
||||
))
|
||||
}
|
||||
});
|
||||
|
||||
let buffer = active_buffer.clone();
|
||||
let project = project.clone();
|
||||
let active_buffer = active_buffer.clone();
|
||||
let task = match self.edit_prediction_model {
|
||||
ZetaEditPredictionModel::Zeta1 => request_prediction_with_zeta1(
|
||||
self,
|
||||
&project,
|
||||
&active_buffer,
|
||||
snapshot.clone(),
|
||||
position,
|
||||
events,
|
||||
cx,
|
||||
),
|
||||
ZetaEditPredictionModel::Zeta2 => self.request_prediction_with_zeta2(
|
||||
&project,
|
||||
&active_buffer,
|
||||
snapshot.clone(),
|
||||
position,
|
||||
events,
|
||||
cx,
|
||||
),
|
||||
ZetaEditPredictionModel::Sweep => self.sweep_ai.request_prediction_with_sweep(
|
||||
&project,
|
||||
&active_buffer,
|
||||
snapshot.clone(),
|
||||
position,
|
||||
events,
|
||||
&zeta_project.recent_paths,
|
||||
diagnostic_search_range.clone(),
|
||||
cx,
|
||||
),
|
||||
};
|
||||
|
||||
cx.spawn(async move |this, cx| {
|
||||
let (id, edits, old_snapshot, response_received_at, inputs) = result.await?;
|
||||
let prediction = task
|
||||
.await?
|
||||
.filter(|prediction| !prediction.edits.is_empty());
|
||||
|
||||
if edits.is_empty() {
|
||||
if prediction.is_none() && allow_jump {
|
||||
let cursor_point = position.to_point(&snapshot);
|
||||
if has_events
|
||||
&& allow_jump
|
||||
&& let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
|
||||
active_buffer,
|
||||
active_buffer.clone(),
|
||||
&snapshot,
|
||||
diagnostic_search_range,
|
||||
cursor_point,
|
||||
@@ -1436,9 +1266,9 @@ impl Zeta {
|
||||
{
|
||||
return this
|
||||
.update(cx, |this, cx| {
|
||||
this.request_prediction_with_sweep(
|
||||
&project,
|
||||
&jump_buffer,
|
||||
this.request_prediction_internal(
|
||||
project,
|
||||
jump_buffer,
|
||||
jump_position,
|
||||
false,
|
||||
cx,
|
||||
@@ -1450,19 +1280,7 @@ impl Zeta {
|
||||
return anyhow::Ok(None);
|
||||
}
|
||||
|
||||
anyhow::Ok(
|
||||
EditPrediction::new(
|
||||
EditPredictionId(id.into()),
|
||||
&buffer,
|
||||
&old_snapshot,
|
||||
edits.into(),
|
||||
buffer_snapshotted_at,
|
||||
response_received_at,
|
||||
inputs,
|
||||
cx,
|
||||
)
|
||||
.await,
|
||||
)
|
||||
Ok(prediction)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -1549,7 +1367,9 @@ impl Zeta {
|
||||
&mut self,
|
||||
project: &Entity<Project>,
|
||||
active_buffer: &Entity<Buffer>,
|
||||
active_snapshot: BufferSnapshot,
|
||||
position: language::Anchor,
|
||||
events: Vec<Arc<Event>>,
|
||||
cx: &mut Context<Self>,
|
||||
) -> Task<Result<Option<EditPrediction>>> {
|
||||
let project_state = self.projects.get(&project.entity_id());
|
||||
@@ -1561,7 +1381,6 @@ impl Zeta {
|
||||
.map(|syntax_index| syntax_index.read_with(cx, |index, _cx| index.state().clone()))
|
||||
});
|
||||
let options = self.options.clone();
|
||||
let active_snapshot = active_buffer.read(cx).snapshot();
|
||||
let buffer_snapshotted_at = Instant::now();
|
||||
let Some(excerpt_path) = active_snapshot
|
||||
.file()
|
||||
@@ -1579,10 +1398,6 @@ impl Zeta {
|
||||
.collect::<Vec<_>>();
|
||||
let debug_tx = self.debug_tx.clone();
|
||||
|
||||
let events = project_state
|
||||
.map(|state| state.events(cx))
|
||||
.unwrap_or_default();
|
||||
|
||||
let diagnostics = active_snapshot.diagnostic_sets().clone();
|
||||
|
||||
let file = active_buffer.read(cx).file();
|
||||
|
||||
@@ -32,19 +32,17 @@ pub(crate) fn request_prediction_with_zeta1(
|
||||
zeta: &mut Zeta,
|
||||
project: &Entity<Project>,
|
||||
buffer: &Entity<Buffer>,
|
||||
snapshot: BufferSnapshot,
|
||||
position: language::Anchor,
|
||||
events: Vec<Arc<Event>>,
|
||||
cx: &mut Context<Zeta>,
|
||||
) -> Task<Result<Option<EditPrediction>>> {
|
||||
let buffer = buffer.clone();
|
||||
let buffer_snapshotted_at = Instant::now();
|
||||
let snapshot = buffer.read(cx).snapshot();
|
||||
let client = zeta.client.clone();
|
||||
let llm_token = zeta.llm_token.clone();
|
||||
let app_version = AppVersion::global(cx);
|
||||
|
||||
let zeta_project = zeta.get_or_init_zeta_project(project, cx);
|
||||
let events = Arc::new(zeta_project.events(cx));
|
||||
|
||||
let (git_info, can_collect_file) = if let Some(file) = snapshot.file() {
|
||||
let can_collect_file = zeta.can_collect_file(project, file, cx);
|
||||
let git_info = if can_collect_file {
|
||||
|
||||
@@ -42,43 +42,48 @@ actions!(
|
||||
|
||||
pub fn init(cx: &mut App) {
|
||||
cx.observe_new(move |workspace: &mut Workspace, _, _cx| {
|
||||
workspace.register_action(move |workspace, _: &OpenZeta2Inspector, window, cx| {
|
||||
let project = workspace.project();
|
||||
workspace.split_item(
|
||||
SplitDirection::Right,
|
||||
Box::new(cx.new(|cx| {
|
||||
Zeta2Inspector::new(
|
||||
&project,
|
||||
workspace.client(),
|
||||
workspace.user_store(),
|
||||
window,
|
||||
cx,
|
||||
)
|
||||
})),
|
||||
window,
|
||||
cx,
|
||||
);
|
||||
});
|
||||
})
|
||||
.detach();
|
||||
|
||||
cx.observe_new(move |workspace: &mut Workspace, _, _cx| {
|
||||
workspace.register_action(move |workspace, _: &OpenZeta2ContextView, window, cx| {
|
||||
let project = workspace.project();
|
||||
workspace.split_item(
|
||||
SplitDirection::Right,
|
||||
Box::new(cx.new(|cx| {
|
||||
Zeta2ContextView::new(
|
||||
project.clone(),
|
||||
workspace.client(),
|
||||
workspace.user_store(),
|
||||
window,
|
||||
cx,
|
||||
)
|
||||
})),
|
||||
window,
|
||||
cx,
|
||||
);
|
||||
workspace.register_action_renderer(|div, _, _, cx| {
|
||||
let has_flag = cx.has_flag::<Zeta2FeatureFlag>();
|
||||
div.when(has_flag, |div| {
|
||||
div.on_action(
|
||||
cx.listener(move |workspace, _: &OpenZeta2Inspector, window, cx| {
|
||||
let project = workspace.project();
|
||||
workspace.split_item(
|
||||
SplitDirection::Right,
|
||||
Box::new(cx.new(|cx| {
|
||||
Zeta2Inspector::new(
|
||||
&project,
|
||||
workspace.client(),
|
||||
workspace.user_store(),
|
||||
window,
|
||||
cx,
|
||||
)
|
||||
})),
|
||||
window,
|
||||
cx,
|
||||
)
|
||||
}),
|
||||
)
|
||||
.on_action(cx.listener(
|
||||
move |workspace, _: &OpenZeta2ContextView, window, cx| {
|
||||
let project = workspace.project();
|
||||
workspace.split_item(
|
||||
SplitDirection::Right,
|
||||
Box::new(cx.new(|cx| {
|
||||
Zeta2ContextView::new(
|
||||
project.clone(),
|
||||
workspace.client(),
|
||||
workspace.user_store(),
|
||||
window,
|
||||
cx,
|
||||
)
|
||||
})),
|
||||
window,
|
||||
cx,
|
||||
);
|
||||
},
|
||||
))
|
||||
})
|
||||
});
|
||||
})
|
||||
.detach();
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
use std::{
|
||||
collections::{BTreeSet, HashMap},
|
||||
collections::HashMap,
|
||||
io::{IsTerminal, Write},
|
||||
sync::Arc,
|
||||
};
|
||||
@@ -125,21 +125,10 @@ fn write_aggregated_scores(
|
||||
.peekable();
|
||||
let has_edit_predictions = edit_predictions.peek().is_some();
|
||||
let aggregated_result = EvaluationResult {
|
||||
context: Scores::aggregate(successful.iter().map(|r| &r.context)),
|
||||
edit_prediction: has_edit_predictions.then(|| Scores::aggregate(edit_predictions)),
|
||||
prompt_len: successful.iter().map(|r| r.prompt_len).sum::<usize>() / successful.len(),
|
||||
generated_len: successful.iter().map(|r| r.generated_len).sum::<usize>()
|
||||
/ successful.len(),
|
||||
context_lines_found_in_context: successful
|
||||
.iter()
|
||||
.map(|r| r.context_lines_found_in_context)
|
||||
.sum::<usize>()
|
||||
/ successful.len(),
|
||||
context_lines_in_expected_patch: successful
|
||||
.iter()
|
||||
.map(|r| r.context_lines_in_expected_patch)
|
||||
.sum::<usize>()
|
||||
/ successful.len(),
|
||||
};
|
||||
|
||||
writeln!(w, "\n{}", "-".repeat(80))?;
|
||||
@@ -261,11 +250,8 @@ fn write_eval_result(
|
||||
#[derive(Debug, Default)]
|
||||
pub struct EvaluationResult {
|
||||
pub edit_prediction: Option<Scores>,
|
||||
pub context: Scores,
|
||||
pub prompt_len: usize,
|
||||
pub generated_len: usize,
|
||||
pub context_lines_in_expected_patch: usize,
|
||||
pub context_lines_found_in_context: usize,
|
||||
}
|
||||
|
||||
#[derive(Default, Debug)]
|
||||
@@ -363,14 +349,6 @@ impl std::fmt::Display for EvaluationResult {
|
||||
|
||||
impl EvaluationResult {
|
||||
fn fmt_markdown(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(
|
||||
f,
|
||||
r#"
|
||||
### Context Scores
|
||||
{}
|
||||
"#,
|
||||
self.context.to_markdown(),
|
||||
)?;
|
||||
if let Some(prediction) = &self.edit_prediction {
|
||||
write!(
|
||||
f,
|
||||
@@ -387,34 +365,18 @@ impl EvaluationResult {
|
||||
writeln!(f, "### Scores\n")?;
|
||||
writeln!(
|
||||
f,
|
||||
" Prompt Generated RetrievedContext PatchContext TP FP FN Precision Recall F1"
|
||||
" Prompt Generated TP FP FN Precision Recall F1"
|
||||
)?;
|
||||
writeln!(
|
||||
f,
|
||||
"─────────────────────────────────────────────────────────────────────────────────────────────────────────────────────"
|
||||
)?;
|
||||
writeln!(
|
||||
f,
|
||||
"Context Retrieval {:<7} {:<9} {:<16} {:<16} {:<6} {:<6} {:<6} {:>10.2} {:>7.2} {:>7.2}",
|
||||
"",
|
||||
"",
|
||||
"",
|
||||
"",
|
||||
self.context.true_positives,
|
||||
self.context.false_positives,
|
||||
self.context.false_negatives,
|
||||
self.context.precision() * 100.0,
|
||||
self.context.recall() * 100.0,
|
||||
self.context.f1_score() * 100.0
|
||||
"───────────────────────────────────────────────────────────────────────────────────────────────"
|
||||
)?;
|
||||
if let Some(edit_prediction) = &self.edit_prediction {
|
||||
writeln!(
|
||||
f,
|
||||
"Edit Prediction {:<7} {:<9} {:<16} {:<16} {:<6} {:<6} {:<6} {:>10.2} {:>7.2} {:>7.2}",
|
||||
"Edit Prediction {:<7} {:<9} {:<6} {:<6} {:<6} {:>9.2} {:>8.2} {:>7.2}",
|
||||
self.prompt_len,
|
||||
self.generated_len,
|
||||
self.context_lines_found_in_context,
|
||||
self.context_lines_in_expected_patch,
|
||||
edit_prediction.true_positives,
|
||||
edit_prediction.false_positives,
|
||||
edit_prediction.false_negatives,
|
||||
@@ -434,53 +396,6 @@ fn evaluate(example: &Example, preds: &PredictionDetails, predict: bool) -> Eval
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let actual_context_lines: HashSet<_> = preds
|
||||
.excerpts
|
||||
.iter()
|
||||
.flat_map(|excerpt| {
|
||||
excerpt
|
||||
.text
|
||||
.lines()
|
||||
.map(|line| format!("{}: {line}", excerpt.path.display()))
|
||||
})
|
||||
.collect();
|
||||
|
||||
let mut false_positive_lines = actual_context_lines.clone();
|
||||
|
||||
for entry in &example.expected_context {
|
||||
let mut best_alternative_score: Option<Scores> = None;
|
||||
|
||||
for alternative in &entry.alternatives {
|
||||
let expected: HashSet<_> = alternative
|
||||
.excerpts
|
||||
.iter()
|
||||
.flat_map(|excerpt| {
|
||||
excerpt
|
||||
.text
|
||||
.lines()
|
||||
.map(|line| format!("{}: {line}", excerpt.path.display()))
|
||||
})
|
||||
.collect();
|
||||
|
||||
let scores = Scores::new(&expected, &actual_context_lines);
|
||||
|
||||
false_positive_lines.retain(|line| !expected.contains(line));
|
||||
|
||||
if best_alternative_score
|
||||
.as_ref()
|
||||
.is_none_or(|best| scores.recall() > best.recall())
|
||||
{
|
||||
best_alternative_score = Some(scores);
|
||||
}
|
||||
}
|
||||
|
||||
let best_alternative = best_alternative_score.unwrap_or_default();
|
||||
eval_result.context.false_negatives += best_alternative.false_negatives;
|
||||
eval_result.context.true_positives += best_alternative.true_positives;
|
||||
}
|
||||
|
||||
eval_result.context.false_positives = false_positive_lines.len();
|
||||
|
||||
if predict {
|
||||
// todo: alternatives for patches
|
||||
let expected_patch = example
|
||||
@@ -493,25 +408,6 @@ fn evaluate(example: &Example, preds: &PredictionDetails, predict: bool) -> Eval
|
||||
.filter(|line| matches!(line, DiffLine::Addition(_) | DiffLine::Deletion(_)))
|
||||
.map(|line| line.to_string())
|
||||
.collect();
|
||||
let expected_context_lines = expected_patch
|
||||
.iter()
|
||||
.filter_map(|line| {
|
||||
if let DiffLine::Context(str) = line {
|
||||
Some(String::from(*str))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.collect::<BTreeSet<_>>();
|
||||
let actual_context_lines = preds
|
||||
.excerpts
|
||||
.iter()
|
||||
.flat_map(|excerpt| excerpt.text.lines().map(ToOwned::to_owned))
|
||||
.collect::<BTreeSet<_>>();
|
||||
|
||||
let matched = expected_context_lines
|
||||
.intersection(&actual_context_lines)
|
||||
.count();
|
||||
|
||||
let actual_patch_lines = preds
|
||||
.diff
|
||||
@@ -522,8 +418,6 @@ fn evaluate(example: &Example, preds: &PredictionDetails, predict: bool) -> Eval
|
||||
.collect();
|
||||
|
||||
eval_result.edit_prediction = Some(Scores::new(&expected_patch_lines, &actual_patch_lines));
|
||||
eval_result.context_lines_in_expected_patch = expected_context_lines.len();
|
||||
eval_result.context_lines_found_in_context = matched;
|
||||
}
|
||||
|
||||
eval_result
|
||||
|
||||
@@ -14,7 +14,6 @@ use anyhow::{Context as _, Result, anyhow};
|
||||
use clap::ValueEnum;
|
||||
use cloud_zeta2_prompt::CURSOR_MARKER;
|
||||
use collections::HashMap;
|
||||
use edit_prediction_context::Line;
|
||||
use futures::{
|
||||
AsyncWriteExt as _,
|
||||
lock::{Mutex, OwnedMutexGuard},
|
||||
@@ -53,7 +52,6 @@ pub struct Example {
|
||||
pub cursor_position: String,
|
||||
pub edit_history: String,
|
||||
pub expected_patch: String,
|
||||
pub expected_context: Vec<ExpectedContextEntry>,
|
||||
}
|
||||
|
||||
pub type ActualExcerpt = Excerpt;
|
||||
@@ -64,25 +62,6 @@ pub struct Excerpt {
|
||||
pub text: String,
|
||||
}
|
||||
|
||||
#[derive(Default, Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct ExpectedContextEntry {
|
||||
pub heading: String,
|
||||
pub alternatives: Vec<ExpectedExcerptSet>,
|
||||
}
|
||||
|
||||
#[derive(Default, Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct ExpectedExcerptSet {
|
||||
pub heading: String,
|
||||
pub excerpts: Vec<ExpectedExcerpt>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct ExpectedExcerpt {
|
||||
pub path: PathBuf,
|
||||
pub text: String,
|
||||
pub required_lines: Vec<Line>,
|
||||
}
|
||||
|
||||
#[derive(ValueEnum, Debug, Clone)]
|
||||
pub enum ExampleFormat {
|
||||
Json,
|
||||
@@ -132,7 +111,6 @@ impl NamedExample {
|
||||
cursor_position: String::new(),
|
||||
edit_history: String::new(),
|
||||
expected_patch: String::new(),
|
||||
expected_context: Vec::new(),
|
||||
},
|
||||
};
|
||||
|
||||
@@ -197,30 +175,10 @@ impl NamedExample {
|
||||
};
|
||||
}
|
||||
Event::End(TagEnd::Heading(HeadingLevel::H3)) => {
|
||||
let heading = mem::take(&mut text);
|
||||
match current_section {
|
||||
Section::ExpectedExcerpts => {
|
||||
named.example.expected_context.push(ExpectedContextEntry {
|
||||
heading,
|
||||
alternatives: Vec::new(),
|
||||
});
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
mem::take(&mut text);
|
||||
}
|
||||
Event::End(TagEnd::Heading(HeadingLevel::H4)) => {
|
||||
let heading = mem::take(&mut text);
|
||||
match current_section {
|
||||
Section::ExpectedExcerpts => {
|
||||
let expected_context = &mut named.example.expected_context;
|
||||
let last_entry = expected_context.last_mut().unwrap();
|
||||
last_entry.alternatives.push(ExpectedExcerptSet {
|
||||
heading,
|
||||
excerpts: Vec::new(),
|
||||
})
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
mem::take(&mut text);
|
||||
}
|
||||
Event::End(TagEnd::Heading(level)) => {
|
||||
anyhow::bail!("Unexpected heading level: {level}");
|
||||
@@ -253,41 +211,7 @@ impl NamedExample {
|
||||
named.example.cursor_position = mem::take(&mut text);
|
||||
}
|
||||
Section::ExpectedExcerpts => {
|
||||
let text = mem::take(&mut text);
|
||||
for excerpt in text.split("\n…\n") {
|
||||
let (mut text, required_lines) = extract_required_lines(&excerpt);
|
||||
if !text.ends_with('\n') {
|
||||
text.push('\n');
|
||||
}
|
||||
|
||||
if named.example.expected_context.is_empty() {
|
||||
named.example.expected_context.push(Default::default());
|
||||
}
|
||||
|
||||
let alternatives = &mut named
|
||||
.example
|
||||
.expected_context
|
||||
.last_mut()
|
||||
.unwrap()
|
||||
.alternatives;
|
||||
|
||||
if alternatives.is_empty() {
|
||||
alternatives.push(ExpectedExcerptSet {
|
||||
heading: String::new(),
|
||||
excerpts: vec![],
|
||||
});
|
||||
}
|
||||
|
||||
alternatives
|
||||
.last_mut()
|
||||
.unwrap()
|
||||
.excerpts
|
||||
.push(ExpectedExcerpt {
|
||||
path: block_info.into(),
|
||||
text,
|
||||
required_lines,
|
||||
});
|
||||
}
|
||||
mem::take(&mut text);
|
||||
}
|
||||
Section::ExpectedPatch => {
|
||||
named.example.expected_patch = mem::take(&mut text);
|
||||
@@ -561,47 +485,6 @@ impl NamedExample {
|
||||
}
|
||||
}
|
||||
|
||||
fn extract_required_lines(text: &str) -> (String, Vec<Line>) {
|
||||
const MARKER: &str = "[ZETA]";
|
||||
let mut new_text = String::new();
|
||||
let mut required_lines = Vec::new();
|
||||
let mut skipped_lines = 0_u32;
|
||||
|
||||
for (row, mut line) in text.split('\n').enumerate() {
|
||||
if let Some(marker_column) = line.find(MARKER) {
|
||||
let mut strip_column = marker_column;
|
||||
|
||||
while strip_column > 0 {
|
||||
let prev_char = line[strip_column - 1..].chars().next().unwrap();
|
||||
if prev_char.is_whitespace() || ['/', '#'].contains(&prev_char) {
|
||||
strip_column -= 1;
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
let metadata = &line[marker_column + MARKER.len()..];
|
||||
if metadata.contains("required") {
|
||||
required_lines.push(Line(row as u32 - skipped_lines));
|
||||
}
|
||||
|
||||
if strip_column == 0 {
|
||||
skipped_lines += 1;
|
||||
continue;
|
||||
}
|
||||
|
||||
line = &line[..strip_column];
|
||||
}
|
||||
|
||||
new_text.push_str(line);
|
||||
new_text.push('\n');
|
||||
}
|
||||
|
||||
new_text.pop();
|
||||
|
||||
(new_text, required_lines)
|
||||
}
|
||||
|
||||
async fn run_git(repo_path: &Path, args: &[&str]) -> Result<String> {
|
||||
let output = smol::process::Command::new("git")
|
||||
.current_dir(repo_path)
|
||||
@@ -656,37 +539,6 @@ impl Display for NamedExample {
|
||||
)?;
|
||||
}
|
||||
|
||||
if !self.example.expected_context.is_empty() {
|
||||
write!(f, "\n## {EXPECTED_CONTEXT_HEADING}\n\n")?;
|
||||
|
||||
for entry in &self.example.expected_context {
|
||||
write!(f, "\n### {}\n\n", entry.heading)?;
|
||||
|
||||
let skip_h4 =
|
||||
entry.alternatives.len() == 1 && entry.alternatives[0].heading.is_empty();
|
||||
|
||||
for excerpt_set in &entry.alternatives {
|
||||
if !skip_h4 {
|
||||
write!(f, "\n#### {}\n\n", excerpt_set.heading)?;
|
||||
}
|
||||
|
||||
for excerpt in &excerpt_set.excerpts {
|
||||
write!(
|
||||
f,
|
||||
"`````{}{}\n{}`````\n\n",
|
||||
excerpt
|
||||
.path
|
||||
.extension()
|
||||
.map(|ext| format!("{} ", ext.to_string_lossy()))
|
||||
.unwrap_or_default(),
|
||||
excerpt.path.display(),
|
||||
excerpt.text
|
||||
)?;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
@@ -707,38 +559,3 @@ pub async fn lock_repo(path: impl AsRef<Path>) -> OwnedMutexGuard<()> {
|
||||
.lock_owned()
|
||||
.await
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use indoc::indoc;
|
||||
use pretty_assertions::assert_eq;
|
||||
|
||||
#[test]
|
||||
fn test_extract_required_lines() {
|
||||
let input = indoc! {"
|
||||
zero
|
||||
one // [ZETA] required
|
||||
two
|
||||
// [ZETA] something
|
||||
three
|
||||
four # [ZETA] required
|
||||
five
|
||||
"};
|
||||
|
||||
let expected_updated_input = indoc! {"
|
||||
zero
|
||||
one
|
||||
two
|
||||
three
|
||||
four
|
||||
five
|
||||
"};
|
||||
|
||||
let expected_required_lines = vec![Line(1), Line(4)];
|
||||
|
||||
let (updated_input, required_lines) = extract_required_lines(input);
|
||||
assert_eq!(updated_input, expected_updated_input);
|
||||
assert_eq!(required_lines, expected_required_lines);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -128,8 +128,6 @@ pub struct PredictArguments {
|
||||
|
||||
#[derive(Clone, Debug, Args)]
|
||||
pub struct PredictionOptions {
|
||||
#[arg(long)]
|
||||
use_expected_context: bool,
|
||||
#[clap(flatten)]
|
||||
zeta2: Zeta2Args,
|
||||
#[clap(long)]
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
use crate::example::{ActualExcerpt, ExpectedExcerpt, NamedExample};
|
||||
use crate::example::{ActualExcerpt, NamedExample};
|
||||
use crate::headless::ZetaCliAppState;
|
||||
use crate::paths::{CACHE_DIR, LATEST_EXAMPLE_RUN_DIR, RUN_DIR, print_run_data_dir};
|
||||
use crate::{
|
||||
@@ -7,16 +7,13 @@ use crate::{
|
||||
use ::serde::Serialize;
|
||||
use anyhow::{Context, Result, anyhow};
|
||||
use cloud_zeta2_prompt::{CURSOR_MARKER, write_codeblock};
|
||||
use collections::HashMap;
|
||||
use futures::StreamExt as _;
|
||||
use gpui::{AppContext, AsyncApp, Entity};
|
||||
use language::{Anchor, Buffer, Point};
|
||||
use project::Project;
|
||||
use project::buffer_store::BufferStoreEvent;
|
||||
use serde::Deserialize;
|
||||
use std::fs;
|
||||
use std::io::{IsTerminal, Write};
|
||||
use std::ops::Range;
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
use std::sync::Mutex;
|
||||
@@ -204,15 +201,12 @@ pub async fn perform_predict(
|
||||
let mut result = result.lock().unwrap();
|
||||
result.generated_len = response.chars().count();
|
||||
|
||||
if !options.use_expected_context {
|
||||
result.planning_search_time = Some(
|
||||
search_queries_generated_at.unwrap() - start_time.unwrap(),
|
||||
);
|
||||
result.running_search_time = Some(
|
||||
search_queries_executed_at.unwrap()
|
||||
- search_queries_generated_at.unwrap(),
|
||||
);
|
||||
}
|
||||
result.planning_search_time =
|
||||
Some(search_queries_generated_at.unwrap() - start_time.unwrap());
|
||||
result.running_search_time = Some(
|
||||
search_queries_executed_at.unwrap()
|
||||
- search_queries_generated_at.unwrap(),
|
||||
);
|
||||
result.prediction_time = prediction_finished_at - prediction_started_at;
|
||||
result.total_time = prediction_finished_at - start_time.unwrap();
|
||||
|
||||
@@ -224,37 +218,10 @@ pub async fn perform_predict(
|
||||
}
|
||||
});
|
||||
|
||||
if options.use_expected_context {
|
||||
let context_excerpts_tasks = example
|
||||
.example
|
||||
.expected_context
|
||||
.iter()
|
||||
.flat_map(|section| {
|
||||
section.alternatives[0].excerpts.iter().map(|excerpt| {
|
||||
resolve_context_entry(project.clone(), excerpt.clone(), cx.clone())
|
||||
})
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
let context_excerpts_vec =
|
||||
futures::future::try_join_all(context_excerpts_tasks).await?;
|
||||
|
||||
let mut context_excerpts = HashMap::default();
|
||||
for (buffer, mut excerpts) in context_excerpts_vec {
|
||||
context_excerpts
|
||||
.entry(buffer)
|
||||
.or_insert(Vec::new())
|
||||
.append(&mut excerpts);
|
||||
}
|
||||
|
||||
zeta.update(cx, |zeta, _cx| {
|
||||
zeta.set_context(project.clone(), context_excerpts)
|
||||
})?;
|
||||
} else {
|
||||
zeta.update(cx, |zeta, cx| {
|
||||
zeta.refresh_context(project.clone(), cursor_buffer.clone(), cursor_anchor, cx)
|
||||
})?
|
||||
.await?;
|
||||
}
|
||||
zeta.update(cx, |zeta, cx| {
|
||||
zeta.refresh_context(project.clone(), cursor_buffer.clone(), cursor_anchor, cx)
|
||||
})?
|
||||
.await?;
|
||||
}
|
||||
|
||||
let prediction = zeta
|
||||
@@ -274,38 +241,6 @@ pub async fn perform_predict(
|
||||
anyhow::Ok(result)
|
||||
}
|
||||
|
||||
async fn resolve_context_entry(
|
||||
project: Entity<Project>,
|
||||
excerpt: ExpectedExcerpt,
|
||||
mut cx: AsyncApp,
|
||||
) -> Result<(Entity<Buffer>, Vec<Range<Anchor>>)> {
|
||||
let buffer = project
|
||||
.update(&mut cx, |project, cx| {
|
||||
let project_path = project.find_project_path(&excerpt.path, cx).unwrap();
|
||||
project.open_buffer(project_path, cx)
|
||||
})?
|
||||
.await?;
|
||||
|
||||
let ranges = buffer.read_with(&mut cx, |buffer, _| {
|
||||
let full_text = buffer.text();
|
||||
let offset = full_text
|
||||
.find(&excerpt.text)
|
||||
.expect("Expected context not found");
|
||||
let point = buffer.offset_to_point(offset);
|
||||
excerpt
|
||||
.required_lines
|
||||
.iter()
|
||||
.map(|line| {
|
||||
let row = point.row + line.0;
|
||||
let range = Point::new(row, 0)..Point::new(row + 1, 0);
|
||||
buffer.anchor_after(range.start)..buffer.anchor_before(range.end)
|
||||
})
|
||||
.collect()
|
||||
})?;
|
||||
|
||||
Ok((buffer, ranges))
|
||||
}
|
||||
|
||||
struct RunCache {
|
||||
cache_mode: CacheMode,
|
||||
example_run_dir: PathBuf,
|
||||
|
||||
Reference in New Issue
Block a user