Compare commits
6 Commits
ex-test-in
...
sweep-jump
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
ed86e5c406 | ||
|
|
9088862322 | ||
|
|
972604bd0e | ||
|
|
1aa3362729 | ||
|
|
f2dad964a1 | ||
|
|
98a3ecbb05 |
@@ -182,7 +182,7 @@ impl EditPredictionProvider for CodestralCompletionProvider {
|
||||
Self::api_key(cx).is_some()
|
||||
}
|
||||
|
||||
fn is_refreshing(&self) -> bool {
|
||||
fn is_refreshing(&self, _cx: &App) -> bool {
|
||||
self.pending_request.is_some()
|
||||
}
|
||||
|
||||
|
||||
@@ -68,7 +68,7 @@ impl EditPredictionProvider for CopilotCompletionProvider {
|
||||
false
|
||||
}
|
||||
|
||||
fn is_refreshing(&self) -> bool {
|
||||
fn is_refreshing(&self, _cx: &App) -> bool {
|
||||
self.pending_refresh.is_some() && self.completions.is_empty()
|
||||
}
|
||||
|
||||
|
||||
@@ -87,7 +87,7 @@ pub trait EditPredictionProvider: 'static + Sized {
|
||||
cursor_position: language::Anchor,
|
||||
cx: &App,
|
||||
) -> bool;
|
||||
fn is_refreshing(&self) -> bool;
|
||||
fn is_refreshing(&self, cx: &App) -> bool;
|
||||
fn refresh(
|
||||
&mut self,
|
||||
buffer: Entity<Buffer>,
|
||||
@@ -200,7 +200,7 @@ where
|
||||
}
|
||||
|
||||
fn is_refreshing(&self, cx: &App) -> bool {
|
||||
self.read(cx).is_refreshing()
|
||||
self.read(cx).is_refreshing(cx)
|
||||
}
|
||||
|
||||
fn refresh(
|
||||
|
||||
@@ -469,7 +469,7 @@ impl EditPredictionProvider for FakeEditPredictionProvider {
|
||||
true
|
||||
}
|
||||
|
||||
fn is_refreshing(&self) -> bool {
|
||||
fn is_refreshing(&self, _cx: &gpui::App) -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
@@ -542,7 +542,7 @@ impl EditPredictionProvider for FakeNonZedEditPredictionProvider {
|
||||
true
|
||||
}
|
||||
|
||||
fn is_refreshing(&self) -> bool {
|
||||
fn is_refreshing(&self, _cx: &gpui::App) -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
|
||||
@@ -129,7 +129,7 @@ impl EditPredictionProvider for SupermavenCompletionProvider {
|
||||
self.supermaven.read(cx).is_enabled()
|
||||
}
|
||||
|
||||
fn is_refreshing(&self) -> bool {
|
||||
fn is_refreshing(&self, _cx: &App) -> bool {
|
||||
self.pending_refresh.is_some() && self.completion_id.is_none()
|
||||
}
|
||||
|
||||
|
||||
@@ -374,6 +374,7 @@ impl PartialEq<str> for RelPath {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
pub struct RelPathComponents<'a>(&'a str);
|
||||
|
||||
pub struct RelPathAncestors<'a>(Option<&'a str>);
|
||||
|
||||
@@ -1486,7 +1486,7 @@ impl edit_prediction::EditPredictionProvider for ZetaEditPredictionProvider {
|
||||
) -> bool {
|
||||
true
|
||||
}
|
||||
fn is_refreshing(&self) -> bool {
|
||||
fn is_refreshing(&self, _cx: &App) -> bool {
|
||||
!self.pending_completions.is_empty()
|
||||
}
|
||||
|
||||
|
||||
@@ -32,7 +32,9 @@ indoc.workspace = true
|
||||
language.workspace = true
|
||||
language_model.workspace = true
|
||||
log.workspace = true
|
||||
lsp.workspace = true
|
||||
open_ai.workspace = true
|
||||
pretty_assertions.workspace = true
|
||||
project.workspace = true
|
||||
release_channel.workspace = true
|
||||
serde.workspace = true
|
||||
@@ -44,7 +46,6 @@ util.workspace = true
|
||||
uuid.workspace = true
|
||||
workspace.workspace = true
|
||||
worktree.workspace = true
|
||||
pretty_assertions.workspace = true
|
||||
|
||||
[dev-dependencies]
|
||||
clock = { workspace = true, features = ["test-support"] }
|
||||
|
||||
@@ -1,24 +1,15 @@
|
||||
use std::{
|
||||
cmp,
|
||||
sync::Arc,
|
||||
time::{Duration, Instant},
|
||||
};
|
||||
use std::{cmp, sync::Arc, time::Duration};
|
||||
|
||||
use arrayvec::ArrayVec;
|
||||
use client::{Client, UserStore};
|
||||
use edit_prediction::{DataCollectionState, Direction, EditPredictionProvider};
|
||||
use gpui::{App, Entity, Task, prelude::*};
|
||||
use gpui::{App, Entity, prelude::*};
|
||||
use language::ToPoint as _;
|
||||
use project::Project;
|
||||
use util::ResultExt as _;
|
||||
|
||||
use crate::{BufferEditPrediction, Zeta, ZetaEditPredictionModel};
|
||||
|
||||
pub struct ZetaEditPredictionProvider {
|
||||
zeta: Entity<Zeta>,
|
||||
next_pending_prediction_id: usize,
|
||||
pending_predictions: ArrayVec<PendingPrediction, 2>,
|
||||
last_request_timestamp: Instant,
|
||||
project: Entity<Project>,
|
||||
}
|
||||
|
||||
@@ -29,28 +20,25 @@ impl ZetaEditPredictionProvider {
|
||||
project: Entity<Project>,
|
||||
client: &Arc<Client>,
|
||||
user_store: &Entity<UserStore>,
|
||||
cx: &mut App,
|
||||
cx: &mut Context<Self>,
|
||||
) -> Self {
|
||||
let zeta = Zeta::global(client, user_store, cx);
|
||||
zeta.update(cx, |zeta, cx| {
|
||||
zeta.register_project(&project, cx);
|
||||
});
|
||||
|
||||
cx.observe(&zeta, |_this, _zeta, cx| {
|
||||
cx.notify();
|
||||
})
|
||||
.detach();
|
||||
|
||||
Self {
|
||||
zeta,
|
||||
next_pending_prediction_id: 0,
|
||||
pending_predictions: ArrayVec::new(),
|
||||
last_request_timestamp: Instant::now(),
|
||||
project: project,
|
||||
zeta,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct PendingPrediction {
|
||||
id: usize,
|
||||
_task: Task<()>,
|
||||
}
|
||||
|
||||
impl EditPredictionProvider for ZetaEditPredictionProvider {
|
||||
fn name() -> &'static str {
|
||||
"zed-predict2"
|
||||
@@ -95,8 +83,8 @@ impl EditPredictionProvider for ZetaEditPredictionProvider {
|
||||
}
|
||||
}
|
||||
|
||||
fn is_refreshing(&self) -> bool {
|
||||
!self.pending_predictions.is_empty()
|
||||
fn is_refreshing(&self, cx: &App) -> bool {
|
||||
self.zeta.read(cx).is_refreshing(&self.project)
|
||||
}
|
||||
|
||||
fn refresh(
|
||||
@@ -123,59 +111,8 @@ impl EditPredictionProvider for ZetaEditPredictionProvider {
|
||||
|
||||
self.zeta.update(cx, |zeta, cx| {
|
||||
zeta.refresh_context_if_needed(&self.project, &buffer, cursor_position, cx);
|
||||
zeta.refresh_prediction_from_buffer(self.project.clone(), buffer, cursor_position, cx)
|
||||
});
|
||||
|
||||
let pending_prediction_id = self.next_pending_prediction_id;
|
||||
self.next_pending_prediction_id += 1;
|
||||
let last_request_timestamp = self.last_request_timestamp;
|
||||
|
||||
let project = self.project.clone();
|
||||
let task = cx.spawn(async move |this, cx| {
|
||||
if let Some(timeout) = (last_request_timestamp + Self::THROTTLE_TIMEOUT)
|
||||
.checked_duration_since(Instant::now())
|
||||
{
|
||||
cx.background_executor().timer(timeout).await;
|
||||
}
|
||||
|
||||
let refresh_task = this.update(cx, |this, cx| {
|
||||
this.last_request_timestamp = Instant::now();
|
||||
this.zeta.update(cx, |zeta, cx| {
|
||||
zeta.refresh_prediction(&project, &buffer, cursor_position, cx)
|
||||
})
|
||||
});
|
||||
|
||||
if let Some(refresh_task) = refresh_task.ok() {
|
||||
refresh_task.await.log_err();
|
||||
}
|
||||
|
||||
this.update(cx, |this, cx| {
|
||||
if this.pending_predictions[0].id == pending_prediction_id {
|
||||
this.pending_predictions.remove(0);
|
||||
} else {
|
||||
this.pending_predictions.clear();
|
||||
}
|
||||
|
||||
cx.notify();
|
||||
})
|
||||
.ok();
|
||||
});
|
||||
|
||||
// We always maintain at most two pending predictions. When we already
|
||||
// have two, we replace the newest one.
|
||||
if self.pending_predictions.len() <= 1 {
|
||||
self.pending_predictions.push(PendingPrediction {
|
||||
id: pending_prediction_id,
|
||||
_task: task,
|
||||
});
|
||||
} else if self.pending_predictions.len() == 2 {
|
||||
self.pending_predictions.pop();
|
||||
self.pending_predictions.push(PendingPrediction {
|
||||
id: pending_prediction_id,
|
||||
_task: task,
|
||||
});
|
||||
}
|
||||
|
||||
cx.notify();
|
||||
}
|
||||
|
||||
fn cycle(
|
||||
@@ -191,14 +128,12 @@ impl EditPredictionProvider for ZetaEditPredictionProvider {
|
||||
self.zeta.update(cx, |zeta, cx| {
|
||||
zeta.accept_current_prediction(&self.project, cx);
|
||||
});
|
||||
self.pending_predictions.clear();
|
||||
}
|
||||
|
||||
fn discard(&mut self, cx: &mut Context<Self>) {
|
||||
self.zeta.update(cx, |zeta, _cx| {
|
||||
zeta.discard_current_prediction(&self.project);
|
||||
});
|
||||
self.pending_predictions.clear();
|
||||
}
|
||||
|
||||
fn suggest(
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
use anyhow::{Context as _, Result, anyhow, bail};
|
||||
use arrayvec::ArrayVec;
|
||||
use chrono::TimeDelta;
|
||||
use client::{Client, EditPredictionUsage, UserStore};
|
||||
use cloud_llm_client::predict_edits_v3::{self, PromptFormat, Signature};
|
||||
@@ -19,18 +20,20 @@ use futures::AsyncReadExt as _;
|
||||
use futures::channel::{mpsc, oneshot};
|
||||
use gpui::http_client::{AsyncBody, Method};
|
||||
use gpui::{
|
||||
App, Entity, EntityId, Global, SemanticVersion, SharedString, Subscription, Task, WeakEntity,
|
||||
http_client, prelude::*,
|
||||
App, AsyncApp, Entity, EntityId, Global, SemanticVersion, SharedString, Subscription, Task,
|
||||
WeakEntity, http_client, prelude::*,
|
||||
};
|
||||
use language::{Anchor, Buffer, DiagnosticSet, LanguageServerId, Point, ToOffset as _, ToPoint};
|
||||
use language::{BufferSnapshot, OffsetRangeExt};
|
||||
use language_model::{LlmApiToken, RefreshLlmTokenListener};
|
||||
use lsp::DiagnosticSeverity;
|
||||
use open_ai::FunctionDefinition;
|
||||
use project::{Project, ProjectPath};
|
||||
use release_channel::AppVersion;
|
||||
use serde::de::DeserializeOwned;
|
||||
use std::collections::{VecDeque, hash_map};
|
||||
|
||||
use std::fmt::Write;
|
||||
use std::ops::Range;
|
||||
use std::path::Path;
|
||||
use std::str::FromStr as _;
|
||||
@@ -39,7 +42,7 @@ use std::time::{Duration, Instant};
|
||||
use std::{env, mem};
|
||||
use thiserror::Error;
|
||||
use util::rel_path::RelPathBuf;
|
||||
use util::{LogErrorFuture, ResultExt as _, TryFutureExt};
|
||||
use util::{LogErrorFuture, RangeExt as _, ResultExt as _, TryFutureExt};
|
||||
use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
|
||||
|
||||
pub mod assemble_excerpts;
|
||||
@@ -239,6 +242,9 @@ struct ZetaProject {
|
||||
recent_paths: VecDeque<ProjectPath>,
|
||||
registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
|
||||
current_prediction: Option<CurrentEditPrediction>,
|
||||
next_pending_prediction_id: usize,
|
||||
pending_predictions: ArrayVec<PendingPrediction, 2>,
|
||||
last_prediction_refresh: Option<(EntityId, Instant)>,
|
||||
context: Option<HashMap<Entity<Buffer>, Vec<Range<Anchor>>>>,
|
||||
refresh_context_task: Option<LogErrorFuture<Task<Result<()>>>>,
|
||||
refresh_context_debounce_task: Option<Task<Option<()>>>,
|
||||
@@ -248,7 +254,7 @@ struct ZetaProject {
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct CurrentEditPrediction {
|
||||
pub requested_by_buffer_id: EntityId,
|
||||
pub requested_by: PredictionRequestedBy,
|
||||
pub prediction: EditPrediction,
|
||||
}
|
||||
|
||||
@@ -272,11 +278,13 @@ impl CurrentEditPrediction {
|
||||
return true;
|
||||
};
|
||||
|
||||
let requested_by_buffer_id = self.requested_by.buffer_id();
|
||||
|
||||
// This reduces the occurrence of UI thrash from replacing edits
|
||||
//
|
||||
// TODO: This is fairly arbitrary - should have a more general heuristic that handles multiple edits.
|
||||
if self.requested_by_buffer_id == self.prediction.buffer.entity_id()
|
||||
&& self.requested_by_buffer_id == old_prediction.prediction.buffer.entity_id()
|
||||
if requested_by_buffer_id == Some(self.prediction.buffer.entity_id())
|
||||
&& requested_by_buffer_id == Some(old_prediction.prediction.buffer.entity_id())
|
||||
&& old_edits.len() == 1
|
||||
&& new_edits.len() == 1
|
||||
{
|
||||
@@ -289,6 +297,26 @@ impl CurrentEditPrediction {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
enum PredictionRequestedBy {
|
||||
DiagnosticsUpdate,
|
||||
Buffer(EntityId),
|
||||
}
|
||||
|
||||
impl PredictionRequestedBy {
|
||||
pub fn buffer_id(&self) -> Option<EntityId> {
|
||||
match self {
|
||||
PredictionRequestedBy::DiagnosticsUpdate => None,
|
||||
PredictionRequestedBy::Buffer(buffer_id) => Some(*buffer_id),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct PendingPrediction {
|
||||
id: usize,
|
||||
_task: Task<()>,
|
||||
}
|
||||
|
||||
/// A prediction from the perspective of a buffer.
|
||||
#[derive(Debug)]
|
||||
enum BufferEditPrediction<'a> {
|
||||
@@ -508,31 +536,48 @@ impl Zeta {
|
||||
recent_paths: VecDeque::new(),
|
||||
registered_buffers: HashMap::default(),
|
||||
current_prediction: None,
|
||||
pending_predictions: ArrayVec::new(),
|
||||
next_pending_prediction_id: 0,
|
||||
last_prediction_refresh: None,
|
||||
context: None,
|
||||
refresh_context_task: None,
|
||||
refresh_context_debounce_task: None,
|
||||
refresh_context_timestamp: None,
|
||||
_subscription: cx.subscribe(&project, |this, project, event, cx| {
|
||||
// TODO [zeta2] init with recent paths
|
||||
if let Some(zeta_project) = this.projects.get_mut(&project.entity_id()) {
|
||||
if let project::Event::ActiveEntryChanged(Some(active_entry_id)) = event {
|
||||
let path = project.read(cx).path_for_entry(*active_entry_id, cx);
|
||||
if let Some(path) = path {
|
||||
if let Some(ix) = zeta_project
|
||||
.recent_paths
|
||||
.iter()
|
||||
.position(|probe| probe == &path)
|
||||
{
|
||||
zeta_project.recent_paths.remove(ix);
|
||||
}
|
||||
zeta_project.recent_paths.push_front(path);
|
||||
}
|
||||
}
|
||||
}
|
||||
}),
|
||||
_subscription: cx.subscribe(&project, Self::handle_project_event),
|
||||
})
|
||||
}
|
||||
|
||||
fn handle_project_event(
|
||||
&mut self,
|
||||
project: Entity<Project>,
|
||||
event: &project::Event,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
// TODO [zeta2] init with recent paths
|
||||
match event {
|
||||
project::Event::ActiveEntryChanged(Some(active_entry_id)) => {
|
||||
let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) else {
|
||||
return;
|
||||
};
|
||||
let path = project.read(cx).path_for_entry(*active_entry_id, cx);
|
||||
if let Some(path) = path {
|
||||
if let Some(ix) = zeta_project
|
||||
.recent_paths
|
||||
.iter()
|
||||
.position(|probe| probe == &path)
|
||||
{
|
||||
zeta_project.recent_paths.remove(ix);
|
||||
}
|
||||
zeta_project.recent_paths.push_front(path);
|
||||
}
|
||||
}
|
||||
project::Event::DiagnosticsUpdated { .. } => {
|
||||
self.refresh_prediction_from_diagnostics(project, cx);
|
||||
}
|
||||
_ => (),
|
||||
}
|
||||
}
|
||||
|
||||
fn register_buffer_impl<'a>(
|
||||
zeta_project: &'a mut ZetaProject,
|
||||
buffer: &Entity<Buffer>,
|
||||
@@ -645,16 +690,25 @@ impl Zeta {
|
||||
let project_state = self.projects.get(&project.entity_id())?;
|
||||
|
||||
let CurrentEditPrediction {
|
||||
requested_by_buffer_id,
|
||||
requested_by,
|
||||
prediction,
|
||||
} = project_state.current_prediction.as_ref()?;
|
||||
|
||||
if prediction.targets_buffer(buffer.read(cx)) {
|
||||
Some(BufferEditPrediction::Local { prediction })
|
||||
} else if *requested_by_buffer_id == buffer.entity_id() {
|
||||
Some(BufferEditPrediction::Jump { prediction })
|
||||
} else {
|
||||
None
|
||||
let show_jump = match requested_by {
|
||||
PredictionRequestedBy::Buffer(requested_by_buffer_id) => {
|
||||
requested_by_buffer_id == &buffer.entity_id()
|
||||
}
|
||||
PredictionRequestedBy::DiagnosticsUpdate => true,
|
||||
};
|
||||
|
||||
if show_jump {
|
||||
Some(BufferEditPrediction::Jump { prediction })
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -667,6 +721,7 @@ impl Zeta {
|
||||
return;
|
||||
};
|
||||
let request_id = prediction.prediction.id.to_string();
|
||||
project_state.pending_predictions.clear();
|
||||
|
||||
let client = self.client.clone();
|
||||
let llm_token = self.llm_token.clone();
|
||||
@@ -706,49 +761,193 @@ impl Zeta {
|
||||
fn discard_current_prediction(&mut self, project: &Entity<Project>) {
|
||||
if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
|
||||
project_state.current_prediction.take();
|
||||
project_state.pending_predictions.clear();
|
||||
};
|
||||
}
|
||||
|
||||
pub fn refresh_prediction(
|
||||
fn is_refreshing(&self, project: &Entity<Project>) -> bool {
|
||||
self.projects
|
||||
.get(&project.entity_id())
|
||||
.is_some_and(|project_state| !project_state.pending_predictions.is_empty())
|
||||
}
|
||||
|
||||
pub fn refresh_prediction_from_buffer(
|
||||
&mut self,
|
||||
project: &Entity<Project>,
|
||||
buffer: &Entity<Buffer>,
|
||||
project: Entity<Project>,
|
||||
buffer: Entity<Buffer>,
|
||||
position: language::Anchor,
|
||||
cx: &mut Context<Self>,
|
||||
) -> Task<Result<()>> {
|
||||
let request_task = self.request_prediction(project, buffer, position, cx);
|
||||
let buffer = buffer.clone();
|
||||
let project = project.clone();
|
||||
) {
|
||||
self.queue_prediction_refresh(project.clone(), buffer.entity_id(), cx, move |this, cx| {
|
||||
let Some(request_task) = this
|
||||
.update(cx, |this, cx| {
|
||||
this.request_prediction(&project, &buffer, position, cx)
|
||||
})
|
||||
.log_err()
|
||||
else {
|
||||
return Task::ready(anyhow::Ok(()));
|
||||
};
|
||||
|
||||
cx.spawn(async move |this, cx| {
|
||||
if let Some(prediction) = request_task.await? {
|
||||
this.update(cx, |this, cx| {
|
||||
let project_state = this
|
||||
.projects
|
||||
.get_mut(&project.entity_id())
|
||||
.context("Project not found")?;
|
||||
let project = project.clone();
|
||||
cx.spawn(async move |cx| {
|
||||
if let Some(prediction) = request_task.await? {
|
||||
this.update(cx, |this, cx| {
|
||||
let project_state = this
|
||||
.projects
|
||||
.get_mut(&project.entity_id())
|
||||
.context("Project not found")?;
|
||||
|
||||
let new_prediction = CurrentEditPrediction {
|
||||
requested_by_buffer_id: buffer.entity_id(),
|
||||
prediction: prediction,
|
||||
};
|
||||
let new_prediction = CurrentEditPrediction {
|
||||
requested_by: PredictionRequestedBy::Buffer(buffer.entity_id()),
|
||||
prediction: prediction,
|
||||
};
|
||||
|
||||
if project_state
|
||||
.current_prediction
|
||||
.as_ref()
|
||||
.is_none_or(|old_prediction| {
|
||||
new_prediction.should_replace_prediction(&old_prediction, cx)
|
||||
})
|
||||
{
|
||||
project_state.current_prediction = Some(new_prediction);
|
||||
}
|
||||
anyhow::Ok(())
|
||||
})??;
|
||||
}
|
||||
Ok(())
|
||||
if project_state
|
||||
.current_prediction
|
||||
.as_ref()
|
||||
.is_none_or(|old_prediction| {
|
||||
new_prediction.should_replace_prediction(&old_prediction, cx)
|
||||
})
|
||||
{
|
||||
project_state.current_prediction = Some(new_prediction);
|
||||
cx.notify();
|
||||
}
|
||||
anyhow::Ok(())
|
||||
})??;
|
||||
}
|
||||
Ok(())
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
pub fn refresh_prediction_from_diagnostics(
|
||||
&mut self,
|
||||
project: Entity<Project>,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) else {
|
||||
return;
|
||||
};
|
||||
|
||||
// Prefer predictions from buffer
|
||||
if zeta_project.current_prediction.is_some() {
|
||||
return;
|
||||
};
|
||||
|
||||
self.queue_prediction_refresh(project.clone(), project.entity_id(), cx, move |this, cx| {
|
||||
let Some(open_buffer_task) = project
|
||||
.update(cx, |project, cx| {
|
||||
project
|
||||
.active_entry()
|
||||
.and_then(|entry| project.path_for_entry(entry, cx))
|
||||
.map(|path| project.open_buffer(path, cx))
|
||||
})
|
||||
.log_err()
|
||||
.flatten()
|
||||
else {
|
||||
return Task::ready(anyhow::Ok(()));
|
||||
};
|
||||
|
||||
cx.spawn(async move |cx| {
|
||||
let active_buffer = open_buffer_task.await?;
|
||||
let snapshot = active_buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
|
||||
|
||||
let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
|
||||
active_buffer,
|
||||
&snapshot,
|
||||
Default::default(),
|
||||
Default::default(),
|
||||
&project,
|
||||
cx,
|
||||
)
|
||||
.await?
|
||||
else {
|
||||
return anyhow::Ok(());
|
||||
};
|
||||
|
||||
let Some(prediction) = this
|
||||
.update(cx, |this, cx| {
|
||||
this.request_prediction(&project, &jump_buffer, jump_position, cx)
|
||||
})?
|
||||
.await?
|
||||
else {
|
||||
return anyhow::Ok(());
|
||||
};
|
||||
|
||||
this.update(cx, |this, cx| {
|
||||
if let Some(zeta_project) = this.projects.get_mut(&project.entity_id()) {
|
||||
zeta_project.current_prediction.get_or_insert_with(|| {
|
||||
cx.notify();
|
||||
CurrentEditPrediction {
|
||||
requested_by: PredictionRequestedBy::DiagnosticsUpdate,
|
||||
prediction,
|
||||
}
|
||||
});
|
||||
}
|
||||
})?;
|
||||
|
||||
anyhow::Ok(())
|
||||
})
|
||||
});
|
||||
}
|
||||
|
||||
#[cfg(not(test))]
|
||||
pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
|
||||
#[cfg(test)]
|
||||
pub const THROTTLE_TIMEOUT: Duration = Duration::ZERO;
|
||||
|
||||
fn queue_prediction_refresh(
|
||||
&mut self,
|
||||
project: Entity<Project>,
|
||||
throttle_entity: EntityId,
|
||||
cx: &mut Context<Self>,
|
||||
do_refresh: impl FnOnce(WeakEntity<Self>, &mut AsyncApp) -> Task<Result<()>> + 'static,
|
||||
) {
|
||||
let zeta_project = self.get_or_init_zeta_project(&project, cx);
|
||||
let pending_prediction_id = zeta_project.next_pending_prediction_id;
|
||||
zeta_project.next_pending_prediction_id += 1;
|
||||
let last_request = zeta_project.last_prediction_refresh;
|
||||
|
||||
// TODO report cancelled requests like in zeta1
|
||||
let task = cx.spawn(async move |this, cx| {
|
||||
if let Some((last_entity, last_timestamp)) = last_request
|
||||
&& throttle_entity == last_entity
|
||||
&& let Some(timeout) =
|
||||
(last_timestamp + Self::THROTTLE_TIMEOUT).checked_duration_since(Instant::now())
|
||||
{
|
||||
cx.background_executor().timer(timeout).await;
|
||||
}
|
||||
|
||||
do_refresh(this.clone(), cx).await.log_err();
|
||||
|
||||
this.update(cx, |this, cx| {
|
||||
let zeta_project = this.get_or_init_zeta_project(&project, cx);
|
||||
|
||||
if zeta_project.pending_predictions[0].id == pending_prediction_id {
|
||||
zeta_project.pending_predictions.remove(0);
|
||||
} else {
|
||||
zeta_project.pending_predictions.clear();
|
||||
}
|
||||
|
||||
cx.notify();
|
||||
})
|
||||
.ok();
|
||||
});
|
||||
|
||||
if zeta_project.pending_predictions.len() <= 1 {
|
||||
zeta_project.pending_predictions.push(PendingPrediction {
|
||||
id: pending_prediction_id,
|
||||
_task: task,
|
||||
});
|
||||
} else if zeta_project.pending_predictions.len() == 2 {
|
||||
zeta_project.pending_predictions.pop();
|
||||
zeta_project.pending_predictions.push(PendingPrediction {
|
||||
id: pending_prediction_id,
|
||||
_task: task,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
pub fn request_prediction(
|
||||
&mut self,
|
||||
project: &Entity<Project>,
|
||||
@@ -761,7 +960,7 @@ impl Zeta {
|
||||
self.request_prediction_with_zed_cloud(project, active_buffer, position, cx)
|
||||
}
|
||||
ZetaEditPredictionModel::Sweep => {
|
||||
self.request_prediction_with_sweep(project, active_buffer, position, cx)
|
||||
self.request_prediction_with_sweep(project, active_buffer, position, true, cx)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -771,6 +970,7 @@ impl Zeta {
|
||||
project: &Entity<Project>,
|
||||
active_buffer: &Entity<Buffer>,
|
||||
position: language::Anchor,
|
||||
allow_jump: bool,
|
||||
cx: &mut Context<Self>,
|
||||
) -> Task<Result<Option<EditPrediction>>> {
|
||||
let snapshot = active_buffer.read(cx).snapshot();
|
||||
@@ -793,6 +993,7 @@ impl Zeta {
|
||||
|
||||
let project_state = self.get_or_init_zeta_project(project, cx);
|
||||
let events = project_state.events.clone();
|
||||
let has_events = !events.is_empty();
|
||||
let recent_buffers = project_state.recent_paths.iter().cloned();
|
||||
let http_client = cx.http_client();
|
||||
|
||||
@@ -808,114 +1009,188 @@ impl Zeta {
|
||||
.take(3)
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let result = cx.background_spawn(async move {
|
||||
let text = snapshot.text();
|
||||
const DIAGNOSTIC_LINES_RANGE: u32 = 20;
|
||||
|
||||
let mut recent_changes = String::new();
|
||||
for event in events {
|
||||
sweep_ai::write_event(event, &mut recent_changes).unwrap();
|
||||
}
|
||||
let cursor_point = position.to_point(&snapshot);
|
||||
let diagnostic_search_start = cursor_point.row.saturating_sub(DIAGNOSTIC_LINES_RANGE);
|
||||
let diagnostic_search_end = cursor_point.row + DIAGNOSTIC_LINES_RANGE;
|
||||
let diagnostic_search_range =
|
||||
Point::new(diagnostic_search_start, 0)..Point::new(diagnostic_search_end, 0);
|
||||
|
||||
let file_chunks = recent_buffer_snapshots
|
||||
.into_iter()
|
||||
.map(|snapshot| {
|
||||
let end_point = language::Point::new(30, 0).min(snapshot.max_point());
|
||||
sweep_ai::FileChunk {
|
||||
content: snapshot
|
||||
.text_for_range(language::Point::zero()..end_point)
|
||||
.collect(),
|
||||
file_path: snapshot
|
||||
.file()
|
||||
.map(|f| f.path().as_unix_str())
|
||||
.unwrap_or("untitled")
|
||||
.to_string(),
|
||||
let result = cx.background_spawn({
|
||||
let snapshot = snapshot.clone();
|
||||
let diagnostic_search_range = diagnostic_search_range.clone();
|
||||
async move {
|
||||
let text = snapshot.text();
|
||||
|
||||
let mut recent_changes = String::new();
|
||||
for event in events {
|
||||
sweep_ai::write_event(event, &mut recent_changes).unwrap();
|
||||
}
|
||||
|
||||
let mut file_chunks = recent_buffer_snapshots
|
||||
.into_iter()
|
||||
.map(|snapshot| {
|
||||
let end_point = Point::new(30, 0).min(snapshot.max_point());
|
||||
sweep_ai::FileChunk {
|
||||
content: snapshot.text_for_range(Point::zero()..end_point).collect(),
|
||||
file_path: snapshot
|
||||
.file()
|
||||
.map(|f| f.path().as_unix_str())
|
||||
.unwrap_or("untitled")
|
||||
.to_string(),
|
||||
start_line: 0,
|
||||
end_line: end_point.row as usize,
|
||||
timestamp: snapshot.file().and_then(|file| {
|
||||
Some(
|
||||
file.disk_state()
|
||||
.mtime()?
|
||||
.to_seconds_and_nanos_for_persistence()?
|
||||
.0,
|
||||
)
|
||||
}),
|
||||
}
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let diagnostic_entries =
|
||||
snapshot.diagnostics_in_range(diagnostic_search_range, false);
|
||||
let mut diagnostic_content = String::new();
|
||||
let mut diagnostic_count = 0;
|
||||
|
||||
for entry in diagnostic_entries {
|
||||
let start_point: Point = entry.range.start;
|
||||
|
||||
let severity = match entry.diagnostic.severity {
|
||||
DiagnosticSeverity::ERROR => "error",
|
||||
DiagnosticSeverity::WARNING => "warning",
|
||||
DiagnosticSeverity::INFORMATION => "info",
|
||||
DiagnosticSeverity::HINT => "hint",
|
||||
_ => continue,
|
||||
};
|
||||
|
||||
diagnostic_count += 1;
|
||||
|
||||
writeln!(
|
||||
&mut diagnostic_content,
|
||||
"{} at line {}: {}",
|
||||
severity,
|
||||
start_point.row + 1,
|
||||
entry.diagnostic.message
|
||||
)?;
|
||||
}
|
||||
|
||||
if !diagnostic_content.is_empty() {
|
||||
file_chunks.push(sweep_ai::FileChunk {
|
||||
file_path: format!("Diagnostics for {}", full_path.display()),
|
||||
start_line: 0,
|
||||
end_line: end_point.row as usize,
|
||||
timestamp: snapshot.file().and_then(|file| {
|
||||
Some(
|
||||
file.disk_state()
|
||||
.mtime()?
|
||||
.to_seconds_and_nanos_for_persistence()?
|
||||
.0,
|
||||
)
|
||||
}),
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
end_line: diagnostic_count,
|
||||
content: diagnostic_content,
|
||||
timestamp: None,
|
||||
});
|
||||
}
|
||||
|
||||
let request_body = sweep_ai::AutocompleteRequest {
|
||||
debug_info,
|
||||
repo_name,
|
||||
file_path: full_path.clone(),
|
||||
file_contents: text.clone(),
|
||||
original_file_contents: text,
|
||||
cursor_position: offset,
|
||||
recent_changes: recent_changes.clone(),
|
||||
changes_above_cursor: true,
|
||||
multiple_suggestions: false,
|
||||
branch: None,
|
||||
file_chunks,
|
||||
retrieval_chunks: vec![],
|
||||
recent_user_actions: vec![],
|
||||
// TODO
|
||||
privacy_mode_enabled: false,
|
||||
};
|
||||
let request_body = sweep_ai::AutocompleteRequest {
|
||||
debug_info,
|
||||
repo_name,
|
||||
file_path: full_path.clone(),
|
||||
file_contents: text.clone(),
|
||||
original_file_contents: text,
|
||||
cursor_position: offset,
|
||||
recent_changes: recent_changes.clone(),
|
||||
changes_above_cursor: true,
|
||||
multiple_suggestions: false,
|
||||
branch: None,
|
||||
file_chunks,
|
||||
retrieval_chunks: vec![],
|
||||
recent_user_actions: vec![],
|
||||
// TODO
|
||||
privacy_mode_enabled: false,
|
||||
};
|
||||
|
||||
let mut buf: Vec<u8> = Vec::new();
|
||||
let writer = brotli::CompressorWriter::new(&mut buf, 4096, 11, 22);
|
||||
serde_json::to_writer(writer, &request_body)?;
|
||||
let body: AsyncBody = buf.into();
|
||||
let mut buf: Vec<u8> = Vec::new();
|
||||
let writer = brotli::CompressorWriter::new(&mut buf, 4096, 11, 22);
|
||||
serde_json::to_writer(writer, &request_body)?;
|
||||
let body: AsyncBody = buf.into();
|
||||
|
||||
const SWEEP_API_URL: &str =
|
||||
"https://autocomplete.sweep.dev/backend/next_edit_autocomplete";
|
||||
const SWEEP_API_URL: &str =
|
||||
"https://autocomplete.sweep.dev/backend/next_edit_autocomplete";
|
||||
|
||||
let request = http_client::Request::builder()
|
||||
.uri(SWEEP_API_URL)
|
||||
.header("Content-Type", "application/json")
|
||||
.header("Authorization", format!("Bearer {}", api_token))
|
||||
.header("Connection", "keep-alive")
|
||||
.header("Content-Encoding", "br")
|
||||
.method(Method::POST)
|
||||
.body(body)?;
|
||||
let request = http_client::Request::builder()
|
||||
.uri(SWEEP_API_URL)
|
||||
.header("Content-Type", "application/json")
|
||||
.header("Authorization", format!("Bearer {}", api_token))
|
||||
.header("Connection", "keep-alive")
|
||||
.header("Content-Encoding", "br")
|
||||
.method(Method::POST)
|
||||
.body(body)?;
|
||||
|
||||
let mut response = http_client.send(request).await?;
|
||||
let mut response = http_client.send(request).await?;
|
||||
|
||||
let mut body: Vec<u8> = Vec::new();
|
||||
response.body_mut().read_to_end(&mut body).await?;
|
||||
let mut body: Vec<u8> = Vec::new();
|
||||
response.body_mut().read_to_end(&mut body).await?;
|
||||
|
||||
if !response.status().is_success() {
|
||||
anyhow::bail!(
|
||||
"Request failed with status: {:?}\nBody: {}",
|
||||
response.status(),
|
||||
String::from_utf8_lossy(&body),
|
||||
);
|
||||
};
|
||||
if !response.status().is_success() {
|
||||
anyhow::bail!(
|
||||
"Request failed with status: {:?}\nBody: {}",
|
||||
response.status(),
|
||||
String::from_utf8_lossy(&body),
|
||||
);
|
||||
};
|
||||
|
||||
let response: sweep_ai::AutocompleteResponse = serde_json::from_slice(&body)?;
|
||||
let response: sweep_ai::AutocompleteResponse = serde_json::from_slice(&body)?;
|
||||
|
||||
let old_text = snapshot
|
||||
.text_for_range(response.start_index..response.end_index)
|
||||
.collect::<String>();
|
||||
let edits = language::text_diff(&old_text, &response.completion)
|
||||
.into_iter()
|
||||
.map(|(range, text)| {
|
||||
(
|
||||
snapshot.anchor_after(response.start_index + range.start)
|
||||
..snapshot.anchor_before(response.start_index + range.end),
|
||||
text,
|
||||
)
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
let old_text = snapshot
|
||||
.text_for_range(response.start_index..response.end_index)
|
||||
.collect::<String>();
|
||||
let edits = language::text_diff(&old_text, &response.completion)
|
||||
.into_iter()
|
||||
.map(|(range, text)| {
|
||||
(
|
||||
snapshot.anchor_after(response.start_index + range.start)
|
||||
..snapshot.anchor_before(response.start_index + range.end),
|
||||
text,
|
||||
)
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
anyhow::Ok((response.autocomplete_id, edits, snapshot))
|
||||
anyhow::Ok((response.autocomplete_id, edits, snapshot))
|
||||
}
|
||||
});
|
||||
|
||||
let buffer = active_buffer.clone();
|
||||
let project = project.clone();
|
||||
let active_buffer = active_buffer.clone();
|
||||
|
||||
cx.spawn(async move |_, cx| {
|
||||
cx.spawn(async move |this, cx| {
|
||||
let (id, edits, old_snapshot) = result.await?;
|
||||
|
||||
if edits.is_empty() {
|
||||
if has_events
|
||||
&& allow_jump
|
||||
&& let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
|
||||
active_buffer,
|
||||
&snapshot,
|
||||
diagnostic_search_range,
|
||||
cursor_point,
|
||||
&project,
|
||||
cx,
|
||||
)
|
||||
.await?
|
||||
{
|
||||
return this
|
||||
.update(cx, |this, cx| {
|
||||
this.request_prediction_with_sweep(
|
||||
&project,
|
||||
&jump_buffer,
|
||||
jump_position,
|
||||
false,
|
||||
cx,
|
||||
)
|
||||
})?
|
||||
.await;
|
||||
}
|
||||
|
||||
return anyhow::Ok(None);
|
||||
}
|
||||
|
||||
@@ -946,6 +1221,85 @@ impl Zeta {
|
||||
})
|
||||
}
|
||||
|
||||
async fn next_diagnostic_location(
|
||||
active_buffer: Entity<Buffer>,
|
||||
active_buffer_snapshot: &BufferSnapshot,
|
||||
active_buffer_diagnostic_search_range: Range<Point>,
|
||||
active_buffer_cursor_point: Point,
|
||||
project: &Entity<Project>,
|
||||
cx: &mut AsyncApp,
|
||||
) -> Result<Option<(Entity<Buffer>, language::Anchor)>> {
|
||||
// find the closest diagnostic to the cursor that wasn't close enough to be included in the last request
|
||||
let mut jump_location = active_buffer_snapshot
|
||||
.diagnostic_groups(None)
|
||||
.into_iter()
|
||||
.filter_map(|(_, group)| {
|
||||
let range = &group.entries[group.primary_ix]
|
||||
.range
|
||||
.to_point(&active_buffer_snapshot);
|
||||
if range.overlaps(&active_buffer_diagnostic_search_range) {
|
||||
None
|
||||
} else {
|
||||
Some(range.start)
|
||||
}
|
||||
})
|
||||
.min_by_key(|probe| probe.row.abs_diff(active_buffer_cursor_point.row))
|
||||
.map(|position| {
|
||||
(
|
||||
active_buffer.clone(),
|
||||
active_buffer_snapshot.anchor_before(position),
|
||||
)
|
||||
});
|
||||
|
||||
if jump_location.is_none() {
|
||||
let active_buffer_path = active_buffer.read_with(cx, |buffer, cx| {
|
||||
let file = buffer.file()?;
|
||||
|
||||
Some(ProjectPath {
|
||||
worktree_id: file.worktree_id(cx),
|
||||
path: file.path().clone(),
|
||||
})
|
||||
})?;
|
||||
|
||||
let buffer_task = project.update(cx, |project, cx| {
|
||||
let (path, _, _) = project
|
||||
.diagnostic_summaries(false, cx)
|
||||
.filter(|(path, _, _)| Some(path) != active_buffer_path.as_ref())
|
||||
.max_by_key(|(path, _, _)| {
|
||||
// find the buffer with errors that shares most parent directories
|
||||
path.path
|
||||
.components()
|
||||
.zip(
|
||||
active_buffer_path
|
||||
.as_ref()
|
||||
.map(|p| p.path.components())
|
||||
.unwrap_or_default(),
|
||||
)
|
||||
.take_while(|(a, b)| a == b)
|
||||
.count()
|
||||
})?;
|
||||
|
||||
Some(project.open_buffer(path, cx))
|
||||
})?;
|
||||
|
||||
if let Some(buffer_task) = buffer_task {
|
||||
let closest_buffer = buffer_task.await?;
|
||||
|
||||
jump_location = closest_buffer
|
||||
.read_with(cx, |buffer, _cx| {
|
||||
buffer
|
||||
.buffer_diagnostics(None)
|
||||
.into_iter()
|
||||
.min_by_key(|entry| entry.diagnostic.severity)
|
||||
.map(|entry| entry.range.start)
|
||||
})?
|
||||
.map(|position| (closest_buffer, position));
|
||||
}
|
||||
}
|
||||
|
||||
anyhow::Ok(jump_location)
|
||||
}
|
||||
|
||||
fn request_prediction_with_zed_cloud(
|
||||
&mut self,
|
||||
project: &Entity<Project>,
|
||||
@@ -2159,8 +2513,8 @@ mod tests {
|
||||
|
||||
// Prediction for current file
|
||||
|
||||
let prediction_task = zeta.update(cx, |zeta, cx| {
|
||||
zeta.refresh_prediction(&project, &buffer1, position, cx)
|
||||
zeta.update(cx, |zeta, cx| {
|
||||
zeta.refresh_prediction_from_buffer(project.clone(), buffer1.clone(), position, cx)
|
||||
});
|
||||
let (_request, respond_tx) = req_rx.next().await.unwrap();
|
||||
|
||||
@@ -2175,7 +2529,8 @@ mod tests {
|
||||
Bye
|
||||
"}))
|
||||
.unwrap();
|
||||
prediction_task.await.unwrap();
|
||||
|
||||
cx.run_until_parked();
|
||||
|
||||
zeta.read_with(cx, |zeta, cx| {
|
||||
let prediction = zeta
|
||||
@@ -2233,8 +2588,8 @@ mod tests {
|
||||
});
|
||||
|
||||
// Prediction for another file
|
||||
let prediction_task = zeta.update(cx, |zeta, cx| {
|
||||
zeta.refresh_prediction(&project, &buffer1, position, cx)
|
||||
zeta.update(cx, |zeta, cx| {
|
||||
zeta.refresh_prediction_from_buffer(project.clone(), buffer1.clone(), position, cx)
|
||||
});
|
||||
let (_request, respond_tx) = req_rx.next().await.unwrap();
|
||||
respond_tx
|
||||
@@ -2247,7 +2602,8 @@ mod tests {
|
||||
Adios
|
||||
"#}))
|
||||
.unwrap();
|
||||
prediction_task.await.unwrap();
|
||||
cx.run_until_parked();
|
||||
|
||||
zeta.read_with(cx, |zeta, cx| {
|
||||
let prediction = zeta
|
||||
.current_prediction_for_buffer(&buffer1, &project, cx)
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
mod zeta2_context_view;
|
||||
|
||||
use std::{cmp::Reverse, path::PathBuf, str::FromStr, sync::Arc, time::Duration};
|
||||
use std::{cmp::Reverse, path::PathBuf, str::FromStr, sync::Arc};
|
||||
|
||||
use chrono::TimeDelta;
|
||||
use client::{Client, UserStore};
|
||||
@@ -237,24 +237,13 @@ impl Zeta2Inspector {
|
||||
fn set_zeta_options(&mut self, options: ZetaOptions, cx: &mut Context<Self>) {
|
||||
self.zeta.update(cx, |this, _cx| this.set_options(options));
|
||||
|
||||
const DEBOUNCE_TIME: Duration = Duration::from_millis(100);
|
||||
|
||||
if let Some(prediction) = self.last_prediction.as_mut() {
|
||||
if let Some(buffer) = prediction.buffer.upgrade() {
|
||||
let position = prediction.position;
|
||||
let zeta = self.zeta.clone();
|
||||
let project = self.project.clone();
|
||||
prediction._task = Some(cx.spawn(async move |_this, cx| {
|
||||
cx.background_executor().timer(DEBOUNCE_TIME).await;
|
||||
if let Some(task) = zeta
|
||||
.update(cx, |zeta, cx| {
|
||||
zeta.refresh_prediction(&project, &buffer, position, cx)
|
||||
})
|
||||
.ok()
|
||||
{
|
||||
task.await.log_err();
|
||||
}
|
||||
}));
|
||||
self.zeta.update(cx, |zeta, cx| {
|
||||
zeta.refresh_prediction_from_buffer(project, buffer, position, cx)
|
||||
});
|
||||
prediction.state = LastPredictionState::Requested;
|
||||
} else {
|
||||
self.last_prediction.take();
|
||||
|
||||
Reference in New Issue
Block a user