Compare commits

...

6 Commits

Author SHA1 Message Date
Agus Zubiaga
ed86e5c406 Move all refresh state to zeta global 2025-11-21 13:10:17 -03:00
Agus Zubiaga
9088862322 Refresh prediction on diagnostics update 2025-11-21 11:59:56 -03:00
Agus Zubiaga
972604bd0e Refresh prediction on project diagnostic update 2025-11-20 20:09:30 -03:00
Agus Zubiaga
1aa3362729 Search in current file first 2025-11-20 17:58:45 -03:00
Agus Zubiaga
f2dad964a1 Try to predict at first diagnostic location 2025-11-20 16:44:13 -03:00
Agus Zubiaga
98a3ecbb05 Include diagnostics near the cursor in sweep request 2025-11-20 16:43:15 -03:00
11 changed files with 537 additions and 255 deletions

View File

@@ -182,7 +182,7 @@ impl EditPredictionProvider for CodestralCompletionProvider {
Self::api_key(cx).is_some()
}
fn is_refreshing(&self) -> bool {
fn is_refreshing(&self, _cx: &App) -> bool {
self.pending_request.is_some()
}

View File

@@ -68,7 +68,7 @@ impl EditPredictionProvider for CopilotCompletionProvider {
false
}
fn is_refreshing(&self) -> bool {
fn is_refreshing(&self, _cx: &App) -> bool {
self.pending_refresh.is_some() && self.completions.is_empty()
}

View File

@@ -87,7 +87,7 @@ pub trait EditPredictionProvider: 'static + Sized {
cursor_position: language::Anchor,
cx: &App,
) -> bool;
fn is_refreshing(&self) -> bool;
fn is_refreshing(&self, cx: &App) -> bool;
fn refresh(
&mut self,
buffer: Entity<Buffer>,
@@ -200,7 +200,7 @@ where
}
fn is_refreshing(&self, cx: &App) -> bool {
self.read(cx).is_refreshing()
self.read(cx).is_refreshing(cx)
}
fn refresh(

View File

@@ -469,7 +469,7 @@ impl EditPredictionProvider for FakeEditPredictionProvider {
true
}
fn is_refreshing(&self) -> bool {
fn is_refreshing(&self, _cx: &gpui::App) -> bool {
false
}
@@ -542,7 +542,7 @@ impl EditPredictionProvider for FakeNonZedEditPredictionProvider {
true
}
fn is_refreshing(&self) -> bool {
fn is_refreshing(&self, _cx: &gpui::App) -> bool {
false
}

View File

@@ -129,7 +129,7 @@ impl EditPredictionProvider for SupermavenCompletionProvider {
self.supermaven.read(cx).is_enabled()
}
fn is_refreshing(&self) -> bool {
fn is_refreshing(&self, _cx: &App) -> bool {
self.pending_refresh.is_some() && self.completion_id.is_none()
}

View File

@@ -374,6 +374,7 @@ impl PartialEq<str> for RelPath {
}
}
#[derive(Default)]
pub struct RelPathComponents<'a>(&'a str);
pub struct RelPathAncestors<'a>(Option<&'a str>);

View File

@@ -1486,7 +1486,7 @@ impl edit_prediction::EditPredictionProvider for ZetaEditPredictionProvider {
) -> bool {
true
}
fn is_refreshing(&self) -> bool {
fn is_refreshing(&self, _cx: &App) -> bool {
!self.pending_completions.is_empty()
}

View File

@@ -32,7 +32,9 @@ indoc.workspace = true
language.workspace = true
language_model.workspace = true
log.workspace = true
lsp.workspace = true
open_ai.workspace = true
pretty_assertions.workspace = true
project.workspace = true
release_channel.workspace = true
serde.workspace = true
@@ -44,7 +46,6 @@ util.workspace = true
uuid.workspace = true
workspace.workspace = true
worktree.workspace = true
pretty_assertions.workspace = true
[dev-dependencies]
clock = { workspace = true, features = ["test-support"] }

View File

@@ -1,24 +1,15 @@
use std::{
cmp,
sync::Arc,
time::{Duration, Instant},
};
use std::{cmp, sync::Arc, time::Duration};
use arrayvec::ArrayVec;
use client::{Client, UserStore};
use edit_prediction::{DataCollectionState, Direction, EditPredictionProvider};
use gpui::{App, Entity, Task, prelude::*};
use gpui::{App, Entity, prelude::*};
use language::ToPoint as _;
use project::Project;
use util::ResultExt as _;
use crate::{BufferEditPrediction, Zeta, ZetaEditPredictionModel};
pub struct ZetaEditPredictionProvider {
zeta: Entity<Zeta>,
next_pending_prediction_id: usize,
pending_predictions: ArrayVec<PendingPrediction, 2>,
last_request_timestamp: Instant,
project: Entity<Project>,
}
@@ -29,28 +20,25 @@ impl ZetaEditPredictionProvider {
project: Entity<Project>,
client: &Arc<Client>,
user_store: &Entity<UserStore>,
cx: &mut App,
cx: &mut Context<Self>,
) -> Self {
let zeta = Zeta::global(client, user_store, cx);
zeta.update(cx, |zeta, cx| {
zeta.register_project(&project, cx);
});
cx.observe(&zeta, |_this, _zeta, cx| {
cx.notify();
})
.detach();
Self {
zeta,
next_pending_prediction_id: 0,
pending_predictions: ArrayVec::new(),
last_request_timestamp: Instant::now(),
project: project,
zeta,
}
}
}
struct PendingPrediction {
id: usize,
_task: Task<()>,
}
impl EditPredictionProvider for ZetaEditPredictionProvider {
fn name() -> &'static str {
"zed-predict2"
@@ -95,8 +83,8 @@ impl EditPredictionProvider for ZetaEditPredictionProvider {
}
}
fn is_refreshing(&self) -> bool {
!self.pending_predictions.is_empty()
fn is_refreshing(&self, cx: &App) -> bool {
self.zeta.read(cx).is_refreshing(&self.project)
}
fn refresh(
@@ -123,59 +111,8 @@ impl EditPredictionProvider for ZetaEditPredictionProvider {
self.zeta.update(cx, |zeta, cx| {
zeta.refresh_context_if_needed(&self.project, &buffer, cursor_position, cx);
zeta.refresh_prediction_from_buffer(self.project.clone(), buffer, cursor_position, cx)
});
let pending_prediction_id = self.next_pending_prediction_id;
self.next_pending_prediction_id += 1;
let last_request_timestamp = self.last_request_timestamp;
let project = self.project.clone();
let task = cx.spawn(async move |this, cx| {
if let Some(timeout) = (last_request_timestamp + Self::THROTTLE_TIMEOUT)
.checked_duration_since(Instant::now())
{
cx.background_executor().timer(timeout).await;
}
let refresh_task = this.update(cx, |this, cx| {
this.last_request_timestamp = Instant::now();
this.zeta.update(cx, |zeta, cx| {
zeta.refresh_prediction(&project, &buffer, cursor_position, cx)
})
});
if let Some(refresh_task) = refresh_task.ok() {
refresh_task.await.log_err();
}
this.update(cx, |this, cx| {
if this.pending_predictions[0].id == pending_prediction_id {
this.pending_predictions.remove(0);
} else {
this.pending_predictions.clear();
}
cx.notify();
})
.ok();
});
// We always maintain at most two pending predictions. When we already
// have two, we replace the newest one.
if self.pending_predictions.len() <= 1 {
self.pending_predictions.push(PendingPrediction {
id: pending_prediction_id,
_task: task,
});
} else if self.pending_predictions.len() == 2 {
self.pending_predictions.pop();
self.pending_predictions.push(PendingPrediction {
id: pending_prediction_id,
_task: task,
});
}
cx.notify();
}
fn cycle(
@@ -191,14 +128,12 @@ impl EditPredictionProvider for ZetaEditPredictionProvider {
self.zeta.update(cx, |zeta, cx| {
zeta.accept_current_prediction(&self.project, cx);
});
self.pending_predictions.clear();
}
fn discard(&mut self, cx: &mut Context<Self>) {
self.zeta.update(cx, |zeta, _cx| {
zeta.discard_current_prediction(&self.project);
});
self.pending_predictions.clear();
}
fn suggest(

View File

@@ -1,4 +1,5 @@
use anyhow::{Context as _, Result, anyhow, bail};
use arrayvec::ArrayVec;
use chrono::TimeDelta;
use client::{Client, EditPredictionUsage, UserStore};
use cloud_llm_client::predict_edits_v3::{self, PromptFormat, Signature};
@@ -19,18 +20,20 @@ use futures::AsyncReadExt as _;
use futures::channel::{mpsc, oneshot};
use gpui::http_client::{AsyncBody, Method};
use gpui::{
App, Entity, EntityId, Global, SemanticVersion, SharedString, Subscription, Task, WeakEntity,
http_client, prelude::*,
App, AsyncApp, Entity, EntityId, Global, SemanticVersion, SharedString, Subscription, Task,
WeakEntity, http_client, prelude::*,
};
use language::{Anchor, Buffer, DiagnosticSet, LanguageServerId, Point, ToOffset as _, ToPoint};
use language::{BufferSnapshot, OffsetRangeExt};
use language_model::{LlmApiToken, RefreshLlmTokenListener};
use lsp::DiagnosticSeverity;
use open_ai::FunctionDefinition;
use project::{Project, ProjectPath};
use release_channel::AppVersion;
use serde::de::DeserializeOwned;
use std::collections::{VecDeque, hash_map};
use std::fmt::Write;
use std::ops::Range;
use std::path::Path;
use std::str::FromStr as _;
@@ -39,7 +42,7 @@ use std::time::{Duration, Instant};
use std::{env, mem};
use thiserror::Error;
use util::rel_path::RelPathBuf;
use util::{LogErrorFuture, ResultExt as _, TryFutureExt};
use util::{LogErrorFuture, RangeExt as _, ResultExt as _, TryFutureExt};
use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
pub mod assemble_excerpts;
@@ -239,6 +242,9 @@ struct ZetaProject {
recent_paths: VecDeque<ProjectPath>,
registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
current_prediction: Option<CurrentEditPrediction>,
next_pending_prediction_id: usize,
pending_predictions: ArrayVec<PendingPrediction, 2>,
last_prediction_refresh: Option<(EntityId, Instant)>,
context: Option<HashMap<Entity<Buffer>, Vec<Range<Anchor>>>>,
refresh_context_task: Option<LogErrorFuture<Task<Result<()>>>>,
refresh_context_debounce_task: Option<Task<Option<()>>>,
@@ -248,7 +254,7 @@ struct ZetaProject {
#[derive(Debug, Clone)]
struct CurrentEditPrediction {
pub requested_by_buffer_id: EntityId,
pub requested_by: PredictionRequestedBy,
pub prediction: EditPrediction,
}
@@ -272,11 +278,13 @@ impl CurrentEditPrediction {
return true;
};
let requested_by_buffer_id = self.requested_by.buffer_id();
// This reduces the occurrence of UI thrash from replacing edits
//
// TODO: This is fairly arbitrary - should have a more general heuristic that handles multiple edits.
if self.requested_by_buffer_id == self.prediction.buffer.entity_id()
&& self.requested_by_buffer_id == old_prediction.prediction.buffer.entity_id()
if requested_by_buffer_id == Some(self.prediction.buffer.entity_id())
&& requested_by_buffer_id == Some(old_prediction.prediction.buffer.entity_id())
&& old_edits.len() == 1
&& new_edits.len() == 1
{
@@ -289,6 +297,26 @@ impl CurrentEditPrediction {
}
}
#[derive(Debug, Clone)]
enum PredictionRequestedBy {
DiagnosticsUpdate,
Buffer(EntityId),
}
impl PredictionRequestedBy {
pub fn buffer_id(&self) -> Option<EntityId> {
match self {
PredictionRequestedBy::DiagnosticsUpdate => None,
PredictionRequestedBy::Buffer(buffer_id) => Some(*buffer_id),
}
}
}
struct PendingPrediction {
id: usize,
_task: Task<()>,
}
/// A prediction from the perspective of a buffer.
#[derive(Debug)]
enum BufferEditPrediction<'a> {
@@ -508,31 +536,48 @@ impl Zeta {
recent_paths: VecDeque::new(),
registered_buffers: HashMap::default(),
current_prediction: None,
pending_predictions: ArrayVec::new(),
next_pending_prediction_id: 0,
last_prediction_refresh: None,
context: None,
refresh_context_task: None,
refresh_context_debounce_task: None,
refresh_context_timestamp: None,
_subscription: cx.subscribe(&project, |this, project, event, cx| {
// TODO [zeta2] init with recent paths
if let Some(zeta_project) = this.projects.get_mut(&project.entity_id()) {
if let project::Event::ActiveEntryChanged(Some(active_entry_id)) = event {
let path = project.read(cx).path_for_entry(*active_entry_id, cx);
if let Some(path) = path {
if let Some(ix) = zeta_project
.recent_paths
.iter()
.position(|probe| probe == &path)
{
zeta_project.recent_paths.remove(ix);
}
zeta_project.recent_paths.push_front(path);
}
}
}
}),
_subscription: cx.subscribe(&project, Self::handle_project_event),
})
}
fn handle_project_event(
&mut self,
project: Entity<Project>,
event: &project::Event,
cx: &mut Context<Self>,
) {
// TODO [zeta2] init with recent paths
match event {
project::Event::ActiveEntryChanged(Some(active_entry_id)) => {
let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) else {
return;
};
let path = project.read(cx).path_for_entry(*active_entry_id, cx);
if let Some(path) = path {
if let Some(ix) = zeta_project
.recent_paths
.iter()
.position(|probe| probe == &path)
{
zeta_project.recent_paths.remove(ix);
}
zeta_project.recent_paths.push_front(path);
}
}
project::Event::DiagnosticsUpdated { .. } => {
self.refresh_prediction_from_diagnostics(project, cx);
}
_ => (),
}
}
fn register_buffer_impl<'a>(
zeta_project: &'a mut ZetaProject,
buffer: &Entity<Buffer>,
@@ -645,16 +690,25 @@ impl Zeta {
let project_state = self.projects.get(&project.entity_id())?;
let CurrentEditPrediction {
requested_by_buffer_id,
requested_by,
prediction,
} = project_state.current_prediction.as_ref()?;
if prediction.targets_buffer(buffer.read(cx)) {
Some(BufferEditPrediction::Local { prediction })
} else if *requested_by_buffer_id == buffer.entity_id() {
Some(BufferEditPrediction::Jump { prediction })
} else {
None
let show_jump = match requested_by {
PredictionRequestedBy::Buffer(requested_by_buffer_id) => {
requested_by_buffer_id == &buffer.entity_id()
}
PredictionRequestedBy::DiagnosticsUpdate => true,
};
if show_jump {
Some(BufferEditPrediction::Jump { prediction })
} else {
None
}
}
}
@@ -667,6 +721,7 @@ impl Zeta {
return;
};
let request_id = prediction.prediction.id.to_string();
project_state.pending_predictions.clear();
let client = self.client.clone();
let llm_token = self.llm_token.clone();
@@ -706,49 +761,193 @@ impl Zeta {
fn discard_current_prediction(&mut self, project: &Entity<Project>) {
if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
project_state.current_prediction.take();
project_state.pending_predictions.clear();
};
}
pub fn refresh_prediction(
fn is_refreshing(&self, project: &Entity<Project>) -> bool {
self.projects
.get(&project.entity_id())
.is_some_and(|project_state| !project_state.pending_predictions.is_empty())
}
pub fn refresh_prediction_from_buffer(
&mut self,
project: &Entity<Project>,
buffer: &Entity<Buffer>,
project: Entity<Project>,
buffer: Entity<Buffer>,
position: language::Anchor,
cx: &mut Context<Self>,
) -> Task<Result<()>> {
let request_task = self.request_prediction(project, buffer, position, cx);
let buffer = buffer.clone();
let project = project.clone();
) {
self.queue_prediction_refresh(project.clone(), buffer.entity_id(), cx, move |this, cx| {
let Some(request_task) = this
.update(cx, |this, cx| {
this.request_prediction(&project, &buffer, position, cx)
})
.log_err()
else {
return Task::ready(anyhow::Ok(()));
};
cx.spawn(async move |this, cx| {
if let Some(prediction) = request_task.await? {
this.update(cx, |this, cx| {
let project_state = this
.projects
.get_mut(&project.entity_id())
.context("Project not found")?;
let project = project.clone();
cx.spawn(async move |cx| {
if let Some(prediction) = request_task.await? {
this.update(cx, |this, cx| {
let project_state = this
.projects
.get_mut(&project.entity_id())
.context("Project not found")?;
let new_prediction = CurrentEditPrediction {
requested_by_buffer_id: buffer.entity_id(),
prediction: prediction,
};
let new_prediction = CurrentEditPrediction {
requested_by: PredictionRequestedBy::Buffer(buffer.entity_id()),
prediction: prediction,
};
if project_state
.current_prediction
.as_ref()
.is_none_or(|old_prediction| {
new_prediction.should_replace_prediction(&old_prediction, cx)
})
{
project_state.current_prediction = Some(new_prediction);
}
anyhow::Ok(())
})??;
}
Ok(())
if project_state
.current_prediction
.as_ref()
.is_none_or(|old_prediction| {
new_prediction.should_replace_prediction(&old_prediction, cx)
})
{
project_state.current_prediction = Some(new_prediction);
cx.notify();
}
anyhow::Ok(())
})??;
}
Ok(())
})
})
}
pub fn refresh_prediction_from_diagnostics(
&mut self,
project: Entity<Project>,
cx: &mut Context<Self>,
) {
let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) else {
return;
};
// Prefer predictions from buffer
if zeta_project.current_prediction.is_some() {
return;
};
self.queue_prediction_refresh(project.clone(), project.entity_id(), cx, move |this, cx| {
let Some(open_buffer_task) = project
.update(cx, |project, cx| {
project
.active_entry()
.and_then(|entry| project.path_for_entry(entry, cx))
.map(|path| project.open_buffer(path, cx))
})
.log_err()
.flatten()
else {
return Task::ready(anyhow::Ok(()));
};
cx.spawn(async move |cx| {
let active_buffer = open_buffer_task.await?;
let snapshot = active_buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
active_buffer,
&snapshot,
Default::default(),
Default::default(),
&project,
cx,
)
.await?
else {
return anyhow::Ok(());
};
let Some(prediction) = this
.update(cx, |this, cx| {
this.request_prediction(&project, &jump_buffer, jump_position, cx)
})?
.await?
else {
return anyhow::Ok(());
};
this.update(cx, |this, cx| {
if let Some(zeta_project) = this.projects.get_mut(&project.entity_id()) {
zeta_project.current_prediction.get_or_insert_with(|| {
cx.notify();
CurrentEditPrediction {
requested_by: PredictionRequestedBy::DiagnosticsUpdate,
prediction,
}
});
}
})?;
anyhow::Ok(())
})
});
}
#[cfg(not(test))]
pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
#[cfg(test)]
pub const THROTTLE_TIMEOUT: Duration = Duration::ZERO;
fn queue_prediction_refresh(
&mut self,
project: Entity<Project>,
throttle_entity: EntityId,
cx: &mut Context<Self>,
do_refresh: impl FnOnce(WeakEntity<Self>, &mut AsyncApp) -> Task<Result<()>> + 'static,
) {
let zeta_project = self.get_or_init_zeta_project(&project, cx);
let pending_prediction_id = zeta_project.next_pending_prediction_id;
zeta_project.next_pending_prediction_id += 1;
let last_request = zeta_project.last_prediction_refresh;
// TODO report cancelled requests like in zeta1
let task = cx.spawn(async move |this, cx| {
if let Some((last_entity, last_timestamp)) = last_request
&& throttle_entity == last_entity
&& let Some(timeout) =
(last_timestamp + Self::THROTTLE_TIMEOUT).checked_duration_since(Instant::now())
{
cx.background_executor().timer(timeout).await;
}
do_refresh(this.clone(), cx).await.log_err();
this.update(cx, |this, cx| {
let zeta_project = this.get_or_init_zeta_project(&project, cx);
if zeta_project.pending_predictions[0].id == pending_prediction_id {
zeta_project.pending_predictions.remove(0);
} else {
zeta_project.pending_predictions.clear();
}
cx.notify();
})
.ok();
});
if zeta_project.pending_predictions.len() <= 1 {
zeta_project.pending_predictions.push(PendingPrediction {
id: pending_prediction_id,
_task: task,
});
} else if zeta_project.pending_predictions.len() == 2 {
zeta_project.pending_predictions.pop();
zeta_project.pending_predictions.push(PendingPrediction {
id: pending_prediction_id,
_task: task,
});
}
}
pub fn request_prediction(
&mut self,
project: &Entity<Project>,
@@ -761,7 +960,7 @@ impl Zeta {
self.request_prediction_with_zed_cloud(project, active_buffer, position, cx)
}
ZetaEditPredictionModel::Sweep => {
self.request_prediction_with_sweep(project, active_buffer, position, cx)
self.request_prediction_with_sweep(project, active_buffer, position, true, cx)
}
}
}
@@ -771,6 +970,7 @@ impl Zeta {
project: &Entity<Project>,
active_buffer: &Entity<Buffer>,
position: language::Anchor,
allow_jump: bool,
cx: &mut Context<Self>,
) -> Task<Result<Option<EditPrediction>>> {
let snapshot = active_buffer.read(cx).snapshot();
@@ -793,6 +993,7 @@ impl Zeta {
let project_state = self.get_or_init_zeta_project(project, cx);
let events = project_state.events.clone();
let has_events = !events.is_empty();
let recent_buffers = project_state.recent_paths.iter().cloned();
let http_client = cx.http_client();
@@ -808,114 +1009,188 @@ impl Zeta {
.take(3)
.collect::<Vec<_>>();
let result = cx.background_spawn(async move {
let text = snapshot.text();
const DIAGNOSTIC_LINES_RANGE: u32 = 20;
let mut recent_changes = String::new();
for event in events {
sweep_ai::write_event(event, &mut recent_changes).unwrap();
}
let cursor_point = position.to_point(&snapshot);
let diagnostic_search_start = cursor_point.row.saturating_sub(DIAGNOSTIC_LINES_RANGE);
let diagnostic_search_end = cursor_point.row + DIAGNOSTIC_LINES_RANGE;
let diagnostic_search_range =
Point::new(diagnostic_search_start, 0)..Point::new(diagnostic_search_end, 0);
let file_chunks = recent_buffer_snapshots
.into_iter()
.map(|snapshot| {
let end_point = language::Point::new(30, 0).min(snapshot.max_point());
sweep_ai::FileChunk {
content: snapshot
.text_for_range(language::Point::zero()..end_point)
.collect(),
file_path: snapshot
.file()
.map(|f| f.path().as_unix_str())
.unwrap_or("untitled")
.to_string(),
let result = cx.background_spawn({
let snapshot = snapshot.clone();
let diagnostic_search_range = diagnostic_search_range.clone();
async move {
let text = snapshot.text();
let mut recent_changes = String::new();
for event in events {
sweep_ai::write_event(event, &mut recent_changes).unwrap();
}
let mut file_chunks = recent_buffer_snapshots
.into_iter()
.map(|snapshot| {
let end_point = Point::new(30, 0).min(snapshot.max_point());
sweep_ai::FileChunk {
content: snapshot.text_for_range(Point::zero()..end_point).collect(),
file_path: snapshot
.file()
.map(|f| f.path().as_unix_str())
.unwrap_or("untitled")
.to_string(),
start_line: 0,
end_line: end_point.row as usize,
timestamp: snapshot.file().and_then(|file| {
Some(
file.disk_state()
.mtime()?
.to_seconds_and_nanos_for_persistence()?
.0,
)
}),
}
})
.collect::<Vec<_>>();
let diagnostic_entries =
snapshot.diagnostics_in_range(diagnostic_search_range, false);
let mut diagnostic_content = String::new();
let mut diagnostic_count = 0;
for entry in diagnostic_entries {
let start_point: Point = entry.range.start;
let severity = match entry.diagnostic.severity {
DiagnosticSeverity::ERROR => "error",
DiagnosticSeverity::WARNING => "warning",
DiagnosticSeverity::INFORMATION => "info",
DiagnosticSeverity::HINT => "hint",
_ => continue,
};
diagnostic_count += 1;
writeln!(
&mut diagnostic_content,
"{} at line {}: {}",
severity,
start_point.row + 1,
entry.diagnostic.message
)?;
}
if !diagnostic_content.is_empty() {
file_chunks.push(sweep_ai::FileChunk {
file_path: format!("Diagnostics for {}", full_path.display()),
start_line: 0,
end_line: end_point.row as usize,
timestamp: snapshot.file().and_then(|file| {
Some(
file.disk_state()
.mtime()?
.to_seconds_and_nanos_for_persistence()?
.0,
)
}),
}
})
.collect();
end_line: diagnostic_count,
content: diagnostic_content,
timestamp: None,
});
}
let request_body = sweep_ai::AutocompleteRequest {
debug_info,
repo_name,
file_path: full_path.clone(),
file_contents: text.clone(),
original_file_contents: text,
cursor_position: offset,
recent_changes: recent_changes.clone(),
changes_above_cursor: true,
multiple_suggestions: false,
branch: None,
file_chunks,
retrieval_chunks: vec![],
recent_user_actions: vec![],
// TODO
privacy_mode_enabled: false,
};
let request_body = sweep_ai::AutocompleteRequest {
debug_info,
repo_name,
file_path: full_path.clone(),
file_contents: text.clone(),
original_file_contents: text,
cursor_position: offset,
recent_changes: recent_changes.clone(),
changes_above_cursor: true,
multiple_suggestions: false,
branch: None,
file_chunks,
retrieval_chunks: vec![],
recent_user_actions: vec![],
// TODO
privacy_mode_enabled: false,
};
let mut buf: Vec<u8> = Vec::new();
let writer = brotli::CompressorWriter::new(&mut buf, 4096, 11, 22);
serde_json::to_writer(writer, &request_body)?;
let body: AsyncBody = buf.into();
let mut buf: Vec<u8> = Vec::new();
let writer = brotli::CompressorWriter::new(&mut buf, 4096, 11, 22);
serde_json::to_writer(writer, &request_body)?;
let body: AsyncBody = buf.into();
const SWEEP_API_URL: &str =
"https://autocomplete.sweep.dev/backend/next_edit_autocomplete";
const SWEEP_API_URL: &str =
"https://autocomplete.sweep.dev/backend/next_edit_autocomplete";
let request = http_client::Request::builder()
.uri(SWEEP_API_URL)
.header("Content-Type", "application/json")
.header("Authorization", format!("Bearer {}", api_token))
.header("Connection", "keep-alive")
.header("Content-Encoding", "br")
.method(Method::POST)
.body(body)?;
let request = http_client::Request::builder()
.uri(SWEEP_API_URL)
.header("Content-Type", "application/json")
.header("Authorization", format!("Bearer {}", api_token))
.header("Connection", "keep-alive")
.header("Content-Encoding", "br")
.method(Method::POST)
.body(body)?;
let mut response = http_client.send(request).await?;
let mut response = http_client.send(request).await?;
let mut body: Vec<u8> = Vec::new();
response.body_mut().read_to_end(&mut body).await?;
let mut body: Vec<u8> = Vec::new();
response.body_mut().read_to_end(&mut body).await?;
if !response.status().is_success() {
anyhow::bail!(
"Request failed with status: {:?}\nBody: {}",
response.status(),
String::from_utf8_lossy(&body),
);
};
if !response.status().is_success() {
anyhow::bail!(
"Request failed with status: {:?}\nBody: {}",
response.status(),
String::from_utf8_lossy(&body),
);
};
let response: sweep_ai::AutocompleteResponse = serde_json::from_slice(&body)?;
let response: sweep_ai::AutocompleteResponse = serde_json::from_slice(&body)?;
let old_text = snapshot
.text_for_range(response.start_index..response.end_index)
.collect::<String>();
let edits = language::text_diff(&old_text, &response.completion)
.into_iter()
.map(|(range, text)| {
(
snapshot.anchor_after(response.start_index + range.start)
..snapshot.anchor_before(response.start_index + range.end),
text,
)
})
.collect::<Vec<_>>();
let old_text = snapshot
.text_for_range(response.start_index..response.end_index)
.collect::<String>();
let edits = language::text_diff(&old_text, &response.completion)
.into_iter()
.map(|(range, text)| {
(
snapshot.anchor_after(response.start_index + range.start)
..snapshot.anchor_before(response.start_index + range.end),
text,
)
})
.collect::<Vec<_>>();
anyhow::Ok((response.autocomplete_id, edits, snapshot))
anyhow::Ok((response.autocomplete_id, edits, snapshot))
}
});
let buffer = active_buffer.clone();
let project = project.clone();
let active_buffer = active_buffer.clone();
cx.spawn(async move |_, cx| {
cx.spawn(async move |this, cx| {
let (id, edits, old_snapshot) = result.await?;
if edits.is_empty() {
if has_events
&& allow_jump
&& let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
active_buffer,
&snapshot,
diagnostic_search_range,
cursor_point,
&project,
cx,
)
.await?
{
return this
.update(cx, |this, cx| {
this.request_prediction_with_sweep(
&project,
&jump_buffer,
jump_position,
false,
cx,
)
})?
.await;
}
return anyhow::Ok(None);
}
@@ -946,6 +1221,85 @@ impl Zeta {
})
}
async fn next_diagnostic_location(
active_buffer: Entity<Buffer>,
active_buffer_snapshot: &BufferSnapshot,
active_buffer_diagnostic_search_range: Range<Point>,
active_buffer_cursor_point: Point,
project: &Entity<Project>,
cx: &mut AsyncApp,
) -> Result<Option<(Entity<Buffer>, language::Anchor)>> {
// find the closest diagnostic to the cursor that wasn't close enough to be included in the last request
let mut jump_location = active_buffer_snapshot
.diagnostic_groups(None)
.into_iter()
.filter_map(|(_, group)| {
let range = &group.entries[group.primary_ix]
.range
.to_point(&active_buffer_snapshot);
if range.overlaps(&active_buffer_diagnostic_search_range) {
None
} else {
Some(range.start)
}
})
.min_by_key(|probe| probe.row.abs_diff(active_buffer_cursor_point.row))
.map(|position| {
(
active_buffer.clone(),
active_buffer_snapshot.anchor_before(position),
)
});
if jump_location.is_none() {
let active_buffer_path = active_buffer.read_with(cx, |buffer, cx| {
let file = buffer.file()?;
Some(ProjectPath {
worktree_id: file.worktree_id(cx),
path: file.path().clone(),
})
})?;
let buffer_task = project.update(cx, |project, cx| {
let (path, _, _) = project
.diagnostic_summaries(false, cx)
.filter(|(path, _, _)| Some(path) != active_buffer_path.as_ref())
.max_by_key(|(path, _, _)| {
// find the buffer with errors that shares most parent directories
path.path
.components()
.zip(
active_buffer_path
.as_ref()
.map(|p| p.path.components())
.unwrap_or_default(),
)
.take_while(|(a, b)| a == b)
.count()
})?;
Some(project.open_buffer(path, cx))
})?;
if let Some(buffer_task) = buffer_task {
let closest_buffer = buffer_task.await?;
jump_location = closest_buffer
.read_with(cx, |buffer, _cx| {
buffer
.buffer_diagnostics(None)
.into_iter()
.min_by_key(|entry| entry.diagnostic.severity)
.map(|entry| entry.range.start)
})?
.map(|position| (closest_buffer, position));
}
}
anyhow::Ok(jump_location)
}
fn request_prediction_with_zed_cloud(
&mut self,
project: &Entity<Project>,
@@ -2159,8 +2513,8 @@ mod tests {
// Prediction for current file
let prediction_task = zeta.update(cx, |zeta, cx| {
zeta.refresh_prediction(&project, &buffer1, position, cx)
zeta.update(cx, |zeta, cx| {
zeta.refresh_prediction_from_buffer(project.clone(), buffer1.clone(), position, cx)
});
let (_request, respond_tx) = req_rx.next().await.unwrap();
@@ -2175,7 +2529,8 @@ mod tests {
Bye
"}))
.unwrap();
prediction_task.await.unwrap();
cx.run_until_parked();
zeta.read_with(cx, |zeta, cx| {
let prediction = zeta
@@ -2233,8 +2588,8 @@ mod tests {
});
// Prediction for another file
let prediction_task = zeta.update(cx, |zeta, cx| {
zeta.refresh_prediction(&project, &buffer1, position, cx)
zeta.update(cx, |zeta, cx| {
zeta.refresh_prediction_from_buffer(project.clone(), buffer1.clone(), position, cx)
});
let (_request, respond_tx) = req_rx.next().await.unwrap();
respond_tx
@@ -2247,7 +2602,8 @@ mod tests {
Adios
"#}))
.unwrap();
prediction_task.await.unwrap();
cx.run_until_parked();
zeta.read_with(cx, |zeta, cx| {
let prediction = zeta
.current_prediction_for_buffer(&buffer1, &project, cx)

View File

@@ -1,6 +1,6 @@
mod zeta2_context_view;
use std::{cmp::Reverse, path::PathBuf, str::FromStr, sync::Arc, time::Duration};
use std::{cmp::Reverse, path::PathBuf, str::FromStr, sync::Arc};
use chrono::TimeDelta;
use client::{Client, UserStore};
@@ -237,24 +237,13 @@ impl Zeta2Inspector {
fn set_zeta_options(&mut self, options: ZetaOptions, cx: &mut Context<Self>) {
self.zeta.update(cx, |this, _cx| this.set_options(options));
const DEBOUNCE_TIME: Duration = Duration::from_millis(100);
if let Some(prediction) = self.last_prediction.as_mut() {
if let Some(buffer) = prediction.buffer.upgrade() {
let position = prediction.position;
let zeta = self.zeta.clone();
let project = self.project.clone();
prediction._task = Some(cx.spawn(async move |_this, cx| {
cx.background_executor().timer(DEBOUNCE_TIME).await;
if let Some(task) = zeta
.update(cx, |zeta, cx| {
zeta.refresh_prediction(&project, &buffer, position, cx)
})
.ok()
{
task.await.log_err();
}
}));
self.zeta.update(cx, |zeta, cx| {
zeta.refresh_prediction_from_buffer(project, buffer, position, cx)
});
prediction.state = LastPredictionState::Requested;
} else {
self.last_prediction.take();