Compare commits
3 Commits
rayon-over
...
send-Windo
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
388d893b36 | ||
|
|
1688ad36dd | ||
|
|
891980b52b |
37
Cargo.lock
generated
37
Cargo.lock
generated
@@ -777,6 +777,9 @@ dependencies = [
|
||||
"serde",
|
||||
"serde_repr",
|
||||
"url",
|
||||
"wayland-backend",
|
||||
"wayland-client",
|
||||
"wayland-protocols 0.32.6",
|
||||
"zbus",
|
||||
]
|
||||
|
||||
@@ -1411,7 +1414,6 @@ dependencies = [
|
||||
"log",
|
||||
"parking_lot",
|
||||
"rodio",
|
||||
"rubato",
|
||||
"serde",
|
||||
"settings",
|
||||
"smol",
|
||||
@@ -5126,6 +5128,7 @@ dependencies = [
|
||||
"client",
|
||||
"gpui",
|
||||
"language",
|
||||
"project",
|
||||
"workspace-hack",
|
||||
]
|
||||
|
||||
@@ -7120,7 +7123,7 @@ dependencies = [
|
||||
"wayland-backend",
|
||||
"wayland-client",
|
||||
"wayland-cursor",
|
||||
"wayland-protocols",
|
||||
"wayland-protocols 0.31.2",
|
||||
"wayland-protocols-plasma",
|
||||
"windows 0.61.1",
|
||||
"windows-core 0.61.0",
|
||||
@@ -13511,18 +13514,6 @@ version = "0.3.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ad8388ea1a9e0ea807e442e8263a699e7edcb320ecbcd21b4fa8ff859acce3ba"
|
||||
|
||||
[[package]]
|
||||
name = "rubato"
|
||||
version = "0.16.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "5258099699851cfd0082aeb645feb9c084d9a5e1f1b8d5372086b989fc5e56a1"
|
||||
dependencies = [
|
||||
"num-complex",
|
||||
"num-integer",
|
||||
"num-traits",
|
||||
"realfft",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rules_library"
|
||||
version = "0.1.0"
|
||||
@@ -18412,6 +18403,18 @@ dependencies = [
|
||||
"wayland-scanner",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "wayland-protocols"
|
||||
version = "0.32.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0781cf46869b37e36928f7b432273c0995aa8aed9552c556fb18754420541efc"
|
||||
dependencies = [
|
||||
"bitflags 2.9.0",
|
||||
"wayland-backend",
|
||||
"wayland-client",
|
||||
"wayland-scanner",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "wayland-protocols-plasma"
|
||||
version = "0.2.0"
|
||||
@@ -18421,7 +18424,7 @@ dependencies = [
|
||||
"bitflags 2.9.0",
|
||||
"wayland-backend",
|
||||
"wayland-client",
|
||||
"wayland-protocols",
|
||||
"wayland-protocols 0.31.2",
|
||||
"wayland-scanner",
|
||||
]
|
||||
|
||||
@@ -19670,6 +19673,7 @@ dependencies = [
|
||||
"aho-corasick",
|
||||
"anstream",
|
||||
"arrayvec",
|
||||
"ashpd 0.11.0",
|
||||
"async-compression",
|
||||
"async-std",
|
||||
"async-tungstenite",
|
||||
@@ -19843,6 +19847,8 @@ dependencies = [
|
||||
"wasmtime",
|
||||
"wasmtime-cranelift",
|
||||
"wasmtime-environ",
|
||||
"wayland-backend",
|
||||
"wayland-sys",
|
||||
"winapi",
|
||||
"windows-core 0.61.0",
|
||||
"windows-numerics",
|
||||
@@ -19850,6 +19856,7 @@ dependencies = [
|
||||
"windows-sys 0.52.0",
|
||||
"windows-sys 0.59.0",
|
||||
"windows-sys 0.61.0",
|
||||
"zbus_macros",
|
||||
"zeroize",
|
||||
"zvariant",
|
||||
]
|
||||
|
||||
@@ -76,7 +76,7 @@ pub enum MessageEditorEvent {
|
||||
|
||||
impl EventEmitter<MessageEditorEvent> for MessageEditor {}
|
||||
|
||||
const COMMAND_HINT_INLAY_ID: u32 = 0;
|
||||
const COMMAND_HINT_INLAY_ID: usize = 0;
|
||||
|
||||
impl MessageEditor {
|
||||
pub fn new(
|
||||
|
||||
@@ -22,7 +22,6 @@ denoise = { path = "../denoise" }
|
||||
log.workspace = true
|
||||
parking_lot.workspace = true
|
||||
rodio = { workspace = true, features = [ "wav", "playback", "wav_output" ] }
|
||||
rubato = "0.16.2"
|
||||
serde.workspace = true
|
||||
settings.workspace = true
|
||||
smol.workspace = true
|
||||
|
||||
@@ -1,17 +1,26 @@
|
||||
use std::{num::NonZero, time::Duration};
|
||||
use std::{
|
||||
num::NonZero,
|
||||
sync::{
|
||||
Arc, Mutex,
|
||||
atomic::{AtomicBool, Ordering},
|
||||
},
|
||||
time::Duration,
|
||||
};
|
||||
|
||||
use crossbeam::queue::ArrayQueue;
|
||||
use denoise::{Denoiser, DenoiserError};
|
||||
use log::warn;
|
||||
use rodio::{ChannelCount, Sample, SampleRate, Source, conversions::ChannelCountConverter, nz};
|
||||
|
||||
use crate::rodio_ext::resample::FixedResampler;
|
||||
pub use replayable::{Replay, ReplayDurationTooShort, Replayable};
|
||||
|
||||
mod replayable;
|
||||
mod resample;
|
||||
use rodio::{
|
||||
ChannelCount, Sample, SampleRate, Source, conversions::SampleRateConverter, nz,
|
||||
source::UniformSourceIterator,
|
||||
};
|
||||
|
||||
const MAX_CHANNELS: usize = 8;
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
#[error("Replay duration is too short must be >= 100ms")]
|
||||
pub struct ReplayDurationTooShort;
|
||||
|
||||
// These all require constant sources (so the span is infinitely long)
|
||||
// this is not guaranteed by rodio however we know it to be true in all our
|
||||
// applications. Rodio desperately needs a constant source concept.
|
||||
@@ -32,8 +41,8 @@ pub trait RodioExt: Source + Sized {
|
||||
self,
|
||||
channel_count: ChannelCount,
|
||||
sample_rate: SampleRate,
|
||||
) -> ConstantChannelCount<FixedResampler<Self>>;
|
||||
fn constant_samplerate(self, sample_rate: SampleRate) -> FixedResampler<Self>;
|
||||
) -> UniformSourceIterator<Self>;
|
||||
fn constant_samplerate(self, sample_rate: SampleRate) -> ConstantSampleRate<Self>;
|
||||
fn possibly_disconnected_channels_to_mono(self) -> ToMono<Self>;
|
||||
}
|
||||
|
||||
@@ -72,7 +81,38 @@ impl<S: Source> RodioExt for S {
|
||||
self,
|
||||
duration: Duration,
|
||||
) -> Result<(Replay, Replayable<Self>), ReplayDurationTooShort> {
|
||||
replayable::replayable(self, duration)
|
||||
if duration < Duration::from_millis(100) {
|
||||
return Err(ReplayDurationTooShort);
|
||||
}
|
||||
|
||||
let samples_per_second = self.sample_rate().get() as usize * self.channels().get() as usize;
|
||||
let samples_to_queue = duration.as_secs_f64() * samples_per_second as f64;
|
||||
let samples_to_queue =
|
||||
(samples_to_queue as usize).next_multiple_of(self.channels().get().into());
|
||||
|
||||
let chunk_size =
|
||||
(samples_per_second.div_ceil(10)).next_multiple_of(self.channels().get() as usize);
|
||||
let chunks_to_queue = samples_to_queue.div_ceil(chunk_size);
|
||||
|
||||
let is_active = Arc::new(AtomicBool::new(true));
|
||||
let queue = Arc::new(ReplayQueue::new(chunks_to_queue, chunk_size));
|
||||
Ok((
|
||||
Replay {
|
||||
rx: Arc::clone(&queue),
|
||||
buffer: Vec::new().into_iter(),
|
||||
sleep_duration: duration / 2,
|
||||
sample_rate: self.sample_rate(),
|
||||
channel_count: self.channels(),
|
||||
source_is_active: is_active.clone(),
|
||||
},
|
||||
Replayable {
|
||||
tx: queue,
|
||||
inner: self,
|
||||
buffer: Vec::with_capacity(chunk_size),
|
||||
chunk_size,
|
||||
is_active,
|
||||
},
|
||||
))
|
||||
}
|
||||
fn take_samples(self, n: usize) -> TakeSamples<S> {
|
||||
TakeSamples {
|
||||
@@ -88,37 +128,37 @@ impl<S: Source> RodioExt for S {
|
||||
self,
|
||||
channel_count: ChannelCount,
|
||||
sample_rate: SampleRate,
|
||||
) -> ConstantChannelCount<FixedResampler<Self>> {
|
||||
ConstantChannelCount::new(self.constant_samplerate(sample_rate), channel_count)
|
||||
) -> UniformSourceIterator<Self> {
|
||||
UniformSourceIterator::new(self, channel_count, sample_rate)
|
||||
}
|
||||
fn constant_samplerate(self, sample_rate: SampleRate) -> FixedResampler<Self> {
|
||||
FixedResampler::new(self, sample_rate)
|
||||
fn constant_samplerate(self, sample_rate: SampleRate) -> ConstantSampleRate<Self> {
|
||||
ConstantSampleRate::new(self, sample_rate)
|
||||
}
|
||||
fn possibly_disconnected_channels_to_mono(self) -> ToMono<Self> {
|
||||
ToMono::new(self)
|
||||
}
|
||||
}
|
||||
|
||||
pub struct ConstantChannelCount<S: Source> {
|
||||
inner: ChannelCountConverter<S>,
|
||||
pub struct ConstantSampleRate<S: Source> {
|
||||
inner: SampleRateConverter<S>,
|
||||
channels: ChannelCount,
|
||||
sample_rate: SampleRate,
|
||||
}
|
||||
|
||||
impl<S: Source> ConstantChannelCount<S> {
|
||||
fn new(source: S, target_channels: ChannelCount) -> Self {
|
||||
let input_channels = source.channels();
|
||||
let sample_rate = source.sample_rate();
|
||||
let inner = ChannelCountConverter::new(source, input_channels, target_channels);
|
||||
impl<S: Source> ConstantSampleRate<S> {
|
||||
fn new(source: S, target_rate: SampleRate) -> Self {
|
||||
let input_sample_rate = source.sample_rate();
|
||||
let channels = source.channels();
|
||||
let inner = SampleRateConverter::new(source, input_sample_rate, target_rate, channels);
|
||||
Self {
|
||||
sample_rate,
|
||||
inner,
|
||||
channels: target_channels,
|
||||
channels,
|
||||
sample_rate: target_rate,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: Source> Iterator for ConstantChannelCount<S> {
|
||||
impl<S: Source> Iterator for ConstantSampleRate<S> {
|
||||
type Item = rodio::Sample;
|
||||
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
@@ -130,7 +170,7 @@ impl<S: Source> Iterator for ConstantChannelCount<S> {
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: Source> Source for ConstantChannelCount<S> {
|
||||
impl<S: Source> Source for ConstantSampleRate<S> {
|
||||
fn current_span_len(&self) -> Option<usize> {
|
||||
None
|
||||
}
|
||||
@@ -267,6 +307,53 @@ impl<S: Source> Source for TakeSamples<S> {
|
||||
}
|
||||
}
|
||||
|
||||
/// constant source, only works on a single span
|
||||
#[derive(Debug)]
|
||||
struct ReplayQueue {
|
||||
inner: ArrayQueue<Vec<Sample>>,
|
||||
normal_chunk_len: usize,
|
||||
/// The last chunk in the queue may be smaller than
|
||||
/// the normal chunk size. This is always equal to the
|
||||
/// size of the last element in the queue.
|
||||
/// (so normally chunk_size)
|
||||
last_chunk: Mutex<Vec<Sample>>,
|
||||
}
|
||||
|
||||
impl ReplayQueue {
|
||||
fn new(queue_len: usize, chunk_size: usize) -> Self {
|
||||
Self {
|
||||
inner: ArrayQueue::new(queue_len),
|
||||
normal_chunk_len: chunk_size,
|
||||
last_chunk: Mutex::new(Vec::new()),
|
||||
}
|
||||
}
|
||||
/// Returns the length in samples
|
||||
fn len(&self) -> usize {
|
||||
self.inner.len().saturating_sub(1) * self.normal_chunk_len
|
||||
+ self
|
||||
.last_chunk
|
||||
.lock()
|
||||
.expect("Self::push_last can not poison this lock")
|
||||
.len()
|
||||
}
|
||||
|
||||
fn pop(&self) -> Option<Vec<Sample>> {
|
||||
self.inner.pop() // removes element that was inserted first
|
||||
}
|
||||
|
||||
fn push_last(&self, mut samples: Vec<Sample>) {
|
||||
let mut last_chunk = self
|
||||
.last_chunk
|
||||
.lock()
|
||||
.expect("Self::len can not poison this lock");
|
||||
std::mem::swap(&mut *last_chunk, &mut samples);
|
||||
}
|
||||
|
||||
fn push_normal(&self, samples: Vec<Sample>) {
|
||||
let _pushed_out_of_ringbuf = self.inner.force_push(samples);
|
||||
}
|
||||
}
|
||||
|
||||
/// constant source, only works on a single span
|
||||
pub struct ProcessBuffer<const N: usize, S, F>
|
||||
where
|
||||
@@ -400,15 +487,147 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
/// constant source, only works on a single span
|
||||
#[derive(Debug)]
|
||||
pub struct Replayable<S: Source> {
|
||||
inner: S,
|
||||
buffer: Vec<Sample>,
|
||||
chunk_size: usize,
|
||||
tx: Arc<ReplayQueue>,
|
||||
is_active: Arc<AtomicBool>,
|
||||
}
|
||||
|
||||
impl<S: Source> Iterator for Replayable<S> {
|
||||
type Item = Sample;
|
||||
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
if let Some(sample) = self.inner.next() {
|
||||
self.buffer.push(sample);
|
||||
// If the buffer is full send it
|
||||
if self.buffer.len() == self.chunk_size {
|
||||
self.tx.push_normal(std::mem::take(&mut self.buffer));
|
||||
}
|
||||
Some(sample)
|
||||
} else {
|
||||
let last_chunk = std::mem::take(&mut self.buffer);
|
||||
self.tx.push_last(last_chunk);
|
||||
self.is_active.store(false, Ordering::Relaxed);
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
fn size_hint(&self) -> (usize, Option<usize>) {
|
||||
self.inner.size_hint()
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: Source> Source for Replayable<S> {
|
||||
fn current_span_len(&self) -> Option<usize> {
|
||||
self.inner.current_span_len()
|
||||
}
|
||||
|
||||
fn channels(&self) -> ChannelCount {
|
||||
self.inner.channels()
|
||||
}
|
||||
|
||||
fn sample_rate(&self) -> SampleRate {
|
||||
self.inner.sample_rate()
|
||||
}
|
||||
|
||||
fn total_duration(&self) -> Option<Duration> {
|
||||
self.inner.total_duration()
|
||||
}
|
||||
}
|
||||
|
||||
/// constant source, only works on a single span
|
||||
#[derive(Debug)]
|
||||
pub struct Replay {
|
||||
rx: Arc<ReplayQueue>,
|
||||
buffer: std::vec::IntoIter<Sample>,
|
||||
sleep_duration: Duration,
|
||||
sample_rate: SampleRate,
|
||||
channel_count: ChannelCount,
|
||||
source_is_active: Arc<AtomicBool>,
|
||||
}
|
||||
|
||||
impl Replay {
|
||||
pub fn source_is_active(&self) -> bool {
|
||||
// - source could return None and not drop
|
||||
// - source could be dropped before returning None
|
||||
self.source_is_active.load(Ordering::Relaxed) && Arc::strong_count(&self.rx) < 2
|
||||
}
|
||||
|
||||
/// Duration of what is in the buffer and can be returned without blocking.
|
||||
pub fn duration_ready(&self) -> Duration {
|
||||
let samples_per_second = self.channels().get() as u32 * self.sample_rate().get();
|
||||
|
||||
let seconds_queued = self.samples_ready() as f64 / samples_per_second as f64;
|
||||
Duration::from_secs_f64(seconds_queued)
|
||||
}
|
||||
|
||||
/// Number of samples in the buffer and can be returned without blocking.
|
||||
pub fn samples_ready(&self) -> usize {
|
||||
self.rx.len() + self.buffer.len()
|
||||
}
|
||||
}
|
||||
|
||||
impl Iterator for Replay {
|
||||
type Item = Sample;
|
||||
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
if let Some(sample) = self.buffer.next() {
|
||||
return Some(sample);
|
||||
}
|
||||
|
||||
loop {
|
||||
if let Some(new_buffer) = self.rx.pop() {
|
||||
self.buffer = new_buffer.into_iter();
|
||||
return self.buffer.next();
|
||||
}
|
||||
|
||||
if !self.source_is_active() {
|
||||
return None;
|
||||
}
|
||||
|
||||
// The queue does not support blocking on a next item. We want this queue as it
|
||||
// is quite fast and provides a fixed size. We know how many samples are in a
|
||||
// buffer so if we do not get one now we must be getting one after `sleep_duration`.
|
||||
std::thread::sleep(self.sleep_duration);
|
||||
}
|
||||
}
|
||||
|
||||
fn size_hint(&self) -> (usize, Option<usize>) {
|
||||
((self.rx.len() + self.buffer.len()), None)
|
||||
}
|
||||
}
|
||||
|
||||
impl Source for Replay {
|
||||
fn current_span_len(&self) -> Option<usize> {
|
||||
None // source is not compatible with spans
|
||||
}
|
||||
|
||||
fn channels(&self) -> ChannelCount {
|
||||
self.channel_count
|
||||
}
|
||||
|
||||
fn sample_rate(&self) -> SampleRate {
|
||||
self.sample_rate
|
||||
}
|
||||
|
||||
fn total_duration(&self) -> Option<Duration> {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use rodio::{nz, static_buffer::StaticSamplesBuffer};
|
||||
|
||||
use super::*;
|
||||
|
||||
pub const SAMPLES: [Sample; 5] = [0.0, 1.0, 2.0, 3.0, 4.0];
|
||||
const SAMPLES: [Sample; 5] = [0.0, 1.0, 2.0, 3.0, 4.0];
|
||||
|
||||
pub fn test_source() -> StaticSamplesBuffer {
|
||||
fn test_source() -> StaticSamplesBuffer {
|
||||
StaticSamplesBuffer::new(nz!(1), nz!(1), &SAMPLES)
|
||||
}
|
||||
|
||||
@@ -471,4 +690,74 @@ mod tests {
|
||||
assert_eq!(yielded, SAMPLES.len())
|
||||
}
|
||||
}
|
||||
|
||||
mod instant_replay {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn continues_after_history() {
|
||||
let input = test_source();
|
||||
|
||||
let (mut replay, mut source) = input
|
||||
.replayable(Duration::from_secs(3))
|
||||
.expect("longer than 100ms");
|
||||
|
||||
source.by_ref().take(3).count();
|
||||
let yielded: Vec<Sample> = replay.by_ref().take(3).collect();
|
||||
assert_eq!(&yielded, &SAMPLES[0..3],);
|
||||
|
||||
source.count();
|
||||
let yielded: Vec<Sample> = replay.collect();
|
||||
assert_eq!(&yielded, &SAMPLES[3..5],);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn keeps_only_latest() {
|
||||
let input = test_source();
|
||||
|
||||
let (mut replay, mut source) = input
|
||||
.replayable(Duration::from_secs(2))
|
||||
.expect("longer than 100ms");
|
||||
|
||||
source.by_ref().take(5).count(); // get all items but do not end the source
|
||||
let yielded: Vec<Sample> = replay.by_ref().take(2).collect();
|
||||
assert_eq!(&yielded, &SAMPLES[3..5]);
|
||||
source.count(); // exhaust source
|
||||
assert_eq!(replay.next(), None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn keeps_correct_amount_of_seconds() {
|
||||
let input = StaticSamplesBuffer::new(nz!(1), nz!(16_000), &[0.0; 40_000]);
|
||||
|
||||
let (replay, mut source) = input
|
||||
.replayable(Duration::from_secs(2))
|
||||
.expect("longer than 100ms");
|
||||
|
||||
// exhaust but do not yet end source
|
||||
source.by_ref().take(40_000).count();
|
||||
|
||||
// take all samples we can without blocking
|
||||
let ready = replay.samples_ready();
|
||||
let n_yielded = replay.take_samples(ready).count();
|
||||
|
||||
let max = source.sample_rate().get() * source.channels().get() as u32 * 2;
|
||||
let margin = 16_000 / 10; // 100ms
|
||||
assert!(n_yielded as u32 >= max - margin);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn samples_ready() {
|
||||
let input = StaticSamplesBuffer::new(nz!(1), nz!(16_000), &[0.0; 40_000]);
|
||||
let (mut replay, source) = input
|
||||
.replayable(Duration::from_secs(2))
|
||||
.expect("longer than 100ms");
|
||||
assert_eq!(replay.by_ref().samples_ready(), 0);
|
||||
|
||||
source.take(8000).count(); // half a second
|
||||
let margin = 16_000 / 10; // 100ms
|
||||
let ready = replay.samples_ready();
|
||||
assert!(ready >= 8000 - margin);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,308 +0,0 @@
|
||||
use std::{
|
||||
sync::{
|
||||
Arc, Mutex,
|
||||
atomic::{AtomicBool, Ordering},
|
||||
},
|
||||
time::Duration,
|
||||
};
|
||||
|
||||
use crossbeam::queue::ArrayQueue;
|
||||
use rodio::{ChannelCount, Sample, SampleRate, Source};
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
#[error("Replay duration is too short must be >= 100ms")]
|
||||
pub struct ReplayDurationTooShort;
|
||||
|
||||
pub fn replayable<S: Source>(
|
||||
source: S,
|
||||
duration: Duration,
|
||||
) -> Result<(Replay, Replayable<S>), ReplayDurationTooShort> {
|
||||
if duration < Duration::from_millis(100) {
|
||||
return Err(ReplayDurationTooShort);
|
||||
}
|
||||
|
||||
let samples_per_second = source.sample_rate().get() as usize * source.channels().get() as usize;
|
||||
let samples_to_queue = duration.as_secs_f64() * samples_per_second as f64;
|
||||
let samples_to_queue =
|
||||
(samples_to_queue as usize).next_multiple_of(source.channels().get().into());
|
||||
|
||||
let chunk_size =
|
||||
(samples_per_second.div_ceil(10)).next_multiple_of(source.channels().get() as usize);
|
||||
let chunks_to_queue = samples_to_queue.div_ceil(chunk_size);
|
||||
|
||||
let is_active = Arc::new(AtomicBool::new(true));
|
||||
let queue = Arc::new(ReplayQueue::new(chunks_to_queue, chunk_size));
|
||||
Ok((
|
||||
Replay {
|
||||
rx: Arc::clone(&queue),
|
||||
buffer: Vec::new().into_iter(),
|
||||
sleep_duration: duration / 2,
|
||||
sample_rate: source.sample_rate(),
|
||||
channel_count: source.channels(),
|
||||
source_is_active: is_active.clone(),
|
||||
},
|
||||
Replayable {
|
||||
tx: queue,
|
||||
inner: source,
|
||||
buffer: Vec::with_capacity(chunk_size),
|
||||
chunk_size,
|
||||
is_active,
|
||||
},
|
||||
))
|
||||
}
|
||||
|
||||
/// constant source, only works on a single span
|
||||
#[derive(Debug)]
|
||||
struct ReplayQueue {
|
||||
inner: ArrayQueue<Vec<Sample>>,
|
||||
normal_chunk_len: usize,
|
||||
/// The last chunk in the queue may be smaller than
|
||||
/// the normal chunk size. This is always equal to the
|
||||
/// size of the last element in the queue.
|
||||
/// (so normally chunk_size)
|
||||
last_chunk: Mutex<Vec<Sample>>,
|
||||
}
|
||||
|
||||
impl ReplayQueue {
|
||||
fn new(queue_len: usize, chunk_size: usize) -> Self {
|
||||
Self {
|
||||
inner: ArrayQueue::new(queue_len),
|
||||
normal_chunk_len: chunk_size,
|
||||
last_chunk: Mutex::new(Vec::new()),
|
||||
}
|
||||
}
|
||||
/// Returns the length in samples
|
||||
fn len(&self) -> usize {
|
||||
self.inner.len().saturating_sub(1) * self.normal_chunk_len
|
||||
+ self
|
||||
.last_chunk
|
||||
.lock()
|
||||
.expect("Self::push_last can not poison this lock")
|
||||
.len()
|
||||
}
|
||||
|
||||
fn pop(&self) -> Option<Vec<Sample>> {
|
||||
self.inner.pop() // removes element that was inserted first
|
||||
}
|
||||
|
||||
fn push_last(&self, mut samples: Vec<Sample>) {
|
||||
let mut last_chunk = self
|
||||
.last_chunk
|
||||
.lock()
|
||||
.expect("Self::len can not poison this lock");
|
||||
std::mem::swap(&mut *last_chunk, &mut samples);
|
||||
}
|
||||
|
||||
fn push_normal(&self, samples: Vec<Sample>) {
|
||||
let _pushed_out_of_ringbuf = self.inner.force_push(samples);
|
||||
}
|
||||
}
|
||||
|
||||
/// constant source, only works on a single span
|
||||
#[derive(Debug)]
|
||||
pub struct Replayable<S: Source> {
|
||||
inner: S,
|
||||
buffer: Vec<Sample>,
|
||||
chunk_size: usize,
|
||||
tx: Arc<ReplayQueue>,
|
||||
is_active: Arc<AtomicBool>,
|
||||
}
|
||||
|
||||
impl<S: Source> Iterator for Replayable<S> {
|
||||
type Item = Sample;
|
||||
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
if let Some(sample) = self.inner.next() {
|
||||
self.buffer.push(sample);
|
||||
// If the buffer is full send it
|
||||
if self.buffer.len() == self.chunk_size {
|
||||
self.tx.push_normal(std::mem::take(&mut self.buffer));
|
||||
}
|
||||
Some(sample)
|
||||
} else {
|
||||
let last_chunk = std::mem::take(&mut self.buffer);
|
||||
self.tx.push_last(last_chunk);
|
||||
self.is_active.store(false, Ordering::Relaxed);
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
fn size_hint(&self) -> (usize, Option<usize>) {
|
||||
self.inner.size_hint()
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: Source> Source for Replayable<S> {
|
||||
fn current_span_len(&self) -> Option<usize> {
|
||||
self.inner.current_span_len()
|
||||
}
|
||||
|
||||
fn channels(&self) -> ChannelCount {
|
||||
self.inner.channels()
|
||||
}
|
||||
|
||||
fn sample_rate(&self) -> SampleRate {
|
||||
self.inner.sample_rate()
|
||||
}
|
||||
|
||||
fn total_duration(&self) -> Option<Duration> {
|
||||
self.inner.total_duration()
|
||||
}
|
||||
}
|
||||
|
||||
/// constant source, only works on a single span
|
||||
#[derive(Debug)]
|
||||
pub struct Replay {
|
||||
rx: Arc<ReplayQueue>,
|
||||
buffer: std::vec::IntoIter<Sample>,
|
||||
sleep_duration: Duration,
|
||||
sample_rate: SampleRate,
|
||||
channel_count: ChannelCount,
|
||||
source_is_active: Arc<AtomicBool>,
|
||||
}
|
||||
|
||||
impl Replay {
|
||||
pub fn source_is_active(&self) -> bool {
|
||||
// - source could return None and not drop
|
||||
// - source could be dropped before returning None
|
||||
self.source_is_active.load(Ordering::Relaxed) && Arc::strong_count(&self.rx) < 2
|
||||
}
|
||||
|
||||
/// Duration of what is in the buffer and can be returned without blocking.
|
||||
pub fn duration_ready(&self) -> Duration {
|
||||
let samples_per_second = self.channels().get() as u32 * self.sample_rate().get();
|
||||
|
||||
let seconds_queued = self.samples_ready() as f64 / samples_per_second as f64;
|
||||
Duration::from_secs_f64(seconds_queued)
|
||||
}
|
||||
|
||||
/// Number of samples in the buffer and can be returned without blocking.
|
||||
pub fn samples_ready(&self) -> usize {
|
||||
self.rx.len() + self.buffer.len()
|
||||
}
|
||||
}
|
||||
|
||||
impl Iterator for Replay {
|
||||
type Item = Sample;
|
||||
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
if let Some(sample) = self.buffer.next() {
|
||||
return Some(sample);
|
||||
}
|
||||
|
||||
loop {
|
||||
if let Some(new_buffer) = self.rx.pop() {
|
||||
self.buffer = new_buffer.into_iter();
|
||||
return self.buffer.next();
|
||||
}
|
||||
|
||||
if !self.source_is_active() {
|
||||
return None;
|
||||
}
|
||||
|
||||
// The queue does not support blocking on a next item. We want this queue as it
|
||||
// is quite fast and provides a fixed size. We know how many samples are in a
|
||||
// buffer so if we do not get one now we must be getting one after `sleep_duration`.
|
||||
std::thread::sleep(self.sleep_duration);
|
||||
}
|
||||
}
|
||||
|
||||
fn size_hint(&self) -> (usize, Option<usize>) {
|
||||
((self.rx.len() + self.buffer.len()), None)
|
||||
}
|
||||
}
|
||||
|
||||
impl Source for Replay {
|
||||
fn current_span_len(&self) -> Option<usize> {
|
||||
None // source is not compatible with spans
|
||||
}
|
||||
|
||||
fn channels(&self) -> ChannelCount {
|
||||
self.channel_count
|
||||
}
|
||||
|
||||
fn sample_rate(&self) -> SampleRate {
|
||||
self.sample_rate
|
||||
}
|
||||
|
||||
fn total_duration(&self) -> Option<Duration> {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use rodio::{nz, static_buffer::StaticSamplesBuffer};
|
||||
|
||||
use super::*;
|
||||
use crate::{
|
||||
RodioExt,
|
||||
rodio_ext::tests::{SAMPLES, test_source},
|
||||
};
|
||||
|
||||
#[test]
|
||||
fn continues_after_history() {
|
||||
let input = test_source();
|
||||
|
||||
let (mut replay, mut source) = input
|
||||
.replayable(Duration::from_secs(3))
|
||||
.expect("longer than 100ms");
|
||||
|
||||
source.by_ref().take(3).count();
|
||||
let yielded: Vec<Sample> = replay.by_ref().take(3).collect();
|
||||
assert_eq!(&yielded, &SAMPLES[0..3],);
|
||||
|
||||
source.count();
|
||||
let yielded: Vec<Sample> = replay.collect();
|
||||
assert_eq!(&yielded, &SAMPLES[3..5],);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn keeps_only_latest() {
|
||||
let input = test_source();
|
||||
|
||||
let (mut replay, mut source) = input
|
||||
.replayable(Duration::from_secs(2))
|
||||
.expect("longer than 100ms");
|
||||
|
||||
source.by_ref().take(5).count(); // get all items but do not end the source
|
||||
let yielded: Vec<Sample> = replay.by_ref().take(2).collect();
|
||||
assert_eq!(&yielded, &SAMPLES[3..5]);
|
||||
source.count(); // exhaust source
|
||||
assert_eq!(replay.next(), None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn keeps_correct_amount_of_seconds() {
|
||||
let input = StaticSamplesBuffer::new(nz!(1), nz!(16_000), &[0.0; 40_000]);
|
||||
|
||||
let (replay, mut source) = input
|
||||
.replayable(Duration::from_secs(2))
|
||||
.expect("longer than 100ms");
|
||||
|
||||
// exhaust but do not yet end source
|
||||
source.by_ref().take(40_000).count();
|
||||
|
||||
// take all samples we can without blocking
|
||||
let ready = replay.samples_ready();
|
||||
let n_yielded = replay.take_samples(ready).count();
|
||||
|
||||
let max = source.sample_rate().get() * source.channels().get() as u32 * 2;
|
||||
let margin = 16_000 / 10; // 100ms
|
||||
assert!(n_yielded as u32 >= max - margin);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn samples_ready() {
|
||||
let input = StaticSamplesBuffer::new(nz!(1), nz!(16_000), &[0.0; 40_000]);
|
||||
let (mut replay, source) = input
|
||||
.replayable(Duration::from_secs(2))
|
||||
.expect("longer than 100ms");
|
||||
assert_eq!(replay.by_ref().samples_ready(), 0);
|
||||
|
||||
source.take(8000).count(); // half a second
|
||||
let margin = 16_000 / 10; // 100ms
|
||||
let ready = replay.samples_ready();
|
||||
assert!(ready >= 8000 - margin);
|
||||
}
|
||||
}
|
||||
@@ -1,98 +0,0 @@
|
||||
use std::time::Duration;
|
||||
|
||||
use rodio::{Sample, SampleRate, Source};
|
||||
use rubato::{FftFixedInOut, Resampler};
|
||||
|
||||
pub struct FixedResampler<S> {
|
||||
input: S,
|
||||
next_channel: usize,
|
||||
next_frame: usize,
|
||||
output_buffer: Vec<Vec<Sample>>,
|
||||
input_buffer: Vec<Vec<Sample>>,
|
||||
target_sample_rate: SampleRate,
|
||||
resampler: FftFixedInOut<Sample>,
|
||||
}
|
||||
|
||||
impl<S: Source> FixedResampler<S> {
|
||||
pub fn new(input: S, target_sample_rate: SampleRate) -> Self {
|
||||
let chunk_size_in =
|
||||
Duration::from_millis(50).as_secs_f32() * input.sample_rate().get() as f32;
|
||||
let chunk_size_in = chunk_size_in.ceil() as usize;
|
||||
|
||||
let resampler = FftFixedInOut::new(
|
||||
input.sample_rate().get() as usize,
|
||||
target_sample_rate.get() as usize,
|
||||
chunk_size_in,
|
||||
input.channels().get() as usize,
|
||||
)
|
||||
.expect(
|
||||
"sample rates are non zero, and we are not changing it so there is no resample ratio",
|
||||
);
|
||||
|
||||
Self {
|
||||
next_channel: 0,
|
||||
next_frame: 0,
|
||||
output_buffer: resampler.output_buffer_allocate(true),
|
||||
input_buffer: resampler.input_buffer_allocate(false),
|
||||
target_sample_rate,
|
||||
resampler,
|
||||
input,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: Source> Source for FixedResampler<S> {
|
||||
fn current_span_len(&self) -> Option<usize> {
|
||||
None
|
||||
}
|
||||
|
||||
fn channels(&self) -> rodio::ChannelCount {
|
||||
self.input.channels()
|
||||
}
|
||||
|
||||
fn sample_rate(&self) -> rodio::SampleRate {
|
||||
self.target_sample_rate
|
||||
}
|
||||
|
||||
fn total_duration(&self) -> Option<std::time::Duration> {
|
||||
self.input.total_duration()
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: Source> FixedResampler<S> {
|
||||
fn next_sample(&mut self) -> Option<Sample> {
|
||||
let sample = self.output_buffer[self.next_channel]
|
||||
.get(self.next_frame)
|
||||
.copied();
|
||||
self.next_channel = (self.next_channel + 1) % self.input.channels().get() as usize;
|
||||
self.next_frame += 1;
|
||||
|
||||
sample
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: Source> Iterator for FixedResampler<S> {
|
||||
type Item = Sample;
|
||||
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
if let Some(sample) = self.next_sample() {
|
||||
return Some(sample);
|
||||
}
|
||||
|
||||
for input_channel in &mut self.input_buffer {
|
||||
input_channel.clear();
|
||||
}
|
||||
|
||||
for _ in 0..self.resampler.input_frames_next() {
|
||||
for input_channel in &mut self.input_buffer {
|
||||
input_channel.push(self.input.next()?);
|
||||
}
|
||||
}
|
||||
|
||||
self.resampler
|
||||
.process_into_buffer(&mut self.input_buffer, &mut self.output_buffer, None).expect("Input and output buffer channels are correct as they have been set by the resampler. The buffer for each channel is the same length. The buffer length is what is requested the resampler.");
|
||||
|
||||
self.next_frame = 0;
|
||||
self.next_sample()
|
||||
}
|
||||
}
|
||||
@@ -43,24 +43,15 @@ pub struct PredictEditsRequest {
|
||||
pub prompt_format: PromptFormat,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, EnumIter)]
|
||||
#[derive(Default, Debug, Clone, Copy, Serialize, Deserialize, PartialEq, EnumIter)]
|
||||
pub enum PromptFormat {
|
||||
#[default]
|
||||
MarkedExcerpt,
|
||||
LabeledSections,
|
||||
/// Prompt format intended for use via zeta_cli
|
||||
OnlySnippets,
|
||||
}
|
||||
|
||||
impl PromptFormat {
|
||||
pub const DEFAULT: PromptFormat = PromptFormat::LabeledSections;
|
||||
}
|
||||
|
||||
impl Default for PromptFormat {
|
||||
fn default() -> Self {
|
||||
Self::DEFAULT
|
||||
}
|
||||
}
|
||||
|
||||
impl PromptFormat {
|
||||
pub fn iter() -> impl Iterator<Item = Self> {
|
||||
<Self as strum::IntoEnumIterator>::iter()
|
||||
|
||||
@@ -3,6 +3,7 @@ use anyhow::Result;
|
||||
use edit_prediction::{Direction, EditPrediction, EditPredictionProvider};
|
||||
use gpui::{App, Context, Entity, EntityId, Task};
|
||||
use language::{Buffer, OffsetRangeExt, ToOffset, language_settings::AllLanguageSettings};
|
||||
use project::Project;
|
||||
use settings::Settings;
|
||||
use std::{path::Path, time::Duration};
|
||||
|
||||
@@ -83,6 +84,7 @@ impl EditPredictionProvider for CopilotCompletionProvider {
|
||||
|
||||
fn refresh(
|
||||
&mut self,
|
||||
_project: Option<Entity<Project>>,
|
||||
buffer: Entity<Buffer>,
|
||||
cursor_position: language::Anchor,
|
||||
debounce: bool,
|
||||
@@ -247,7 +249,7 @@ impl EditPredictionProvider for CopilotCompletionProvider {
|
||||
None
|
||||
} else {
|
||||
let position = cursor_position.bias_right(buffer);
|
||||
Some(EditPrediction::Local {
|
||||
Some(EditPrediction {
|
||||
id: None,
|
||||
edits: vec![(position..position, completion_text.into())],
|
||||
edit_preview: None,
|
||||
|
||||
@@ -154,7 +154,6 @@ pub struct CrashInfo {
|
||||
pub struct InitCrashHandler {
|
||||
pub session_id: String,
|
||||
pub zed_version: String,
|
||||
pub binary: String,
|
||||
pub release_channel: String,
|
||||
pub commit_sha: String,
|
||||
}
|
||||
|
||||
@@ -473,7 +473,7 @@ fn generate_big_table_of_actions() -> String {
|
||||
output.push_str(action.name);
|
||||
output.push_str("</code><br>\n");
|
||||
if !action.deprecated_aliases.is_empty() {
|
||||
output.push_str("Deprecated Alias(es): ");
|
||||
output.push_str("Deprecated Aliases:");
|
||||
for alias in action.deprecated_aliases.iter() {
|
||||
output.push_str("<code>");
|
||||
output.push_str(alias);
|
||||
|
||||
@@ -15,4 +15,5 @@ path = "src/edit_prediction.rs"
|
||||
client.workspace = true
|
||||
gpui.workspace = true
|
||||
language.workspace = true
|
||||
project.workspace = true
|
||||
workspace-hack.workspace = true
|
||||
|
||||
@@ -3,6 +3,7 @@ use std::ops::Range;
|
||||
use client::EditPredictionUsage;
|
||||
use gpui::{App, Context, Entity, SharedString};
|
||||
use language::Buffer;
|
||||
use project::Project;
|
||||
|
||||
// TODO: Find a better home for `Direction`.
|
||||
//
|
||||
@@ -15,19 +16,11 @@ pub enum Direction {
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub enum EditPrediction {
|
||||
/// Edits within the buffer that requested the prediction
|
||||
Local {
|
||||
id: Option<SharedString>,
|
||||
edits: Vec<(Range<language::Anchor>, String)>,
|
||||
edit_preview: Option<language::EditPreview>,
|
||||
},
|
||||
/// Jump to a different file from the one that requested the prediction
|
||||
Jump {
|
||||
id: Option<SharedString>,
|
||||
snapshot: language::BufferSnapshot,
|
||||
target: language::Anchor,
|
||||
},
|
||||
pub struct EditPrediction {
|
||||
/// The ID of the completion, if it has one.
|
||||
pub id: Option<SharedString>,
|
||||
pub edits: Vec<(Range<language::Anchor>, String)>,
|
||||
pub edit_preview: Option<language::EditPreview>,
|
||||
}
|
||||
|
||||
pub enum DataCollectionState {
|
||||
@@ -90,6 +83,7 @@ pub trait EditPredictionProvider: 'static + Sized {
|
||||
fn is_refreshing(&self) -> bool;
|
||||
fn refresh(
|
||||
&mut self,
|
||||
project: Option<Entity<Project>>,
|
||||
buffer: Entity<Buffer>,
|
||||
cursor_position: language::Anchor,
|
||||
debounce: bool,
|
||||
@@ -130,6 +124,7 @@ pub trait EditPredictionProviderHandle {
|
||||
fn is_refreshing(&self, cx: &App) -> bool;
|
||||
fn refresh(
|
||||
&self,
|
||||
project: Option<Entity<Project>>,
|
||||
buffer: Entity<Buffer>,
|
||||
cursor_position: language::Anchor,
|
||||
debounce: bool,
|
||||
@@ -203,13 +198,14 @@ where
|
||||
|
||||
fn refresh(
|
||||
&self,
|
||||
project: Option<Entity<Project>>,
|
||||
buffer: Entity<Buffer>,
|
||||
cursor_position: language::Anchor,
|
||||
debounce: bool,
|
||||
cx: &mut App,
|
||||
) {
|
||||
self.update(cx, |this, cx| {
|
||||
this.refresh(buffer, cursor_position, debounce, cx)
|
||||
this.refresh(project, buffer, cursor_position, debounce, cx)
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
@@ -1443,7 +1443,11 @@ impl<'a> Iterator for FoldChunks<'a> {
|
||||
[(self.inlay_offset - buffer_chunk_start).0..(chunk_end - buffer_chunk_start).0];
|
||||
|
||||
let bit_end = (chunk_end - buffer_chunk_start).0;
|
||||
let mask = 1u128.unbounded_shl(bit_end as u32).wrapping_sub(1);
|
||||
let mask = if bit_end >= 128 {
|
||||
u128::MAX
|
||||
} else {
|
||||
(1u128 << bit_end) - 1
|
||||
};
|
||||
|
||||
chunk.tabs = (chunk.tabs >> (self.inlay_offset - buffer_chunk_start).0) & mask;
|
||||
chunk.chars = (chunk.chars >> (self.inlay_offset - buffer_chunk_start).0) & mask;
|
||||
|
||||
@@ -8,7 +8,7 @@ use multi_buffer::{
|
||||
use std::{
|
||||
cmp,
|
||||
ops::{Add, AddAssign, Range, Sub, SubAssign},
|
||||
sync::{Arc, OnceLock},
|
||||
sync::Arc,
|
||||
};
|
||||
use sum_tree::{Bias, Cursor, Dimensions, SumTree};
|
||||
use text::{ChunkBitmaps, Patch, Rope};
|
||||
@@ -41,17 +41,12 @@ enum Transform {
|
||||
pub struct Inlay {
|
||||
pub id: InlayId,
|
||||
pub position: Anchor,
|
||||
pub content: InlayContent,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum InlayContent {
|
||||
Text(text::Rope),
|
||||
Color(Hsla),
|
||||
pub text: text::Rope,
|
||||
color: Option<Hsla>,
|
||||
}
|
||||
|
||||
impl Inlay {
|
||||
pub fn hint(id: u32, position: Anchor, hint: &project::InlayHint) -> Self {
|
||||
pub fn hint(id: usize, position: Anchor, hint: &project::InlayHint) -> Self {
|
||||
let mut text = hint.text();
|
||||
if hint.padding_right && text.reversed_chars_at(text.len()).next() != Some(' ') {
|
||||
text.push(" ");
|
||||
@@ -62,57 +57,51 @@ impl Inlay {
|
||||
Self {
|
||||
id: InlayId::Hint(id),
|
||||
position,
|
||||
content: InlayContent::Text(text),
|
||||
text,
|
||||
color: None,
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(any(test, feature = "test-support"))]
|
||||
pub fn mock_hint(id: u32, position: Anchor, text: impl Into<Rope>) -> Self {
|
||||
pub fn mock_hint(id: usize, position: Anchor, text: impl Into<Rope>) -> Self {
|
||||
Self {
|
||||
id: InlayId::Hint(id),
|
||||
position,
|
||||
content: InlayContent::Text(text.into()),
|
||||
text: text.into(),
|
||||
color: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn color(id: u32, position: Anchor, color: Rgba) -> Self {
|
||||
pub fn color(id: usize, position: Anchor, color: Rgba) -> Self {
|
||||
Self {
|
||||
id: InlayId::Color(id),
|
||||
position,
|
||||
content: InlayContent::Color(color.into()),
|
||||
text: Rope::from("◼"),
|
||||
color: Some(Hsla::from(color)),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn edit_prediction<T: Into<Rope>>(id: u32, position: Anchor, text: T) -> Self {
|
||||
pub fn edit_prediction<T: Into<Rope>>(id: usize, position: Anchor, text: T) -> Self {
|
||||
Self {
|
||||
id: InlayId::EditPrediction(id),
|
||||
position,
|
||||
content: InlayContent::Text(text.into()),
|
||||
text: text.into(),
|
||||
color: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn debugger<T: Into<Rope>>(id: u32, position: Anchor, text: T) -> Self {
|
||||
pub fn debugger<T: Into<Rope>>(id: usize, position: Anchor, text: T) -> Self {
|
||||
Self {
|
||||
id: InlayId::DebuggerValue(id),
|
||||
position,
|
||||
content: InlayContent::Text(text.into()),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn text(&self) -> &Rope {
|
||||
static COLOR_TEXT: OnceLock<Rope> = OnceLock::new();
|
||||
match &self.content {
|
||||
InlayContent::Text(text) => text,
|
||||
InlayContent::Color(_) => COLOR_TEXT.get_or_init(|| Rope::from("◼")),
|
||||
text: text.into(),
|
||||
color: None,
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(any(test, feature = "test-support"))]
|
||||
pub fn get_color(&self) -> Option<Hsla> {
|
||||
match self.content {
|
||||
InlayContent::Color(color) => Some(color),
|
||||
_ => None,
|
||||
}
|
||||
self.color
|
||||
}
|
||||
}
|
||||
|
||||
@@ -127,7 +116,7 @@ impl sum_tree::Item for Transform {
|
||||
},
|
||||
Transform::Inlay(inlay) => TransformSummary {
|
||||
input: TextSummary::default(),
|
||||
output: inlay.text().summary(),
|
||||
output: inlay.text.summary(),
|
||||
},
|
||||
}
|
||||
}
|
||||
@@ -365,7 +354,7 @@ impl<'a> Iterator for InlayChunks<'a> {
|
||||
let mut renderer = None;
|
||||
let mut highlight_style = match inlay.id {
|
||||
InlayId::EditPrediction(_) => self.highlight_styles.edit_prediction.map(|s| {
|
||||
if inlay.text().chars().all(|c| c.is_whitespace()) {
|
||||
if inlay.text.chars().all(|c| c.is_whitespace()) {
|
||||
s.whitespace
|
||||
} else {
|
||||
s.insertion
|
||||
@@ -374,7 +363,7 @@ impl<'a> Iterator for InlayChunks<'a> {
|
||||
InlayId::Hint(_) => self.highlight_styles.inlay_hint,
|
||||
InlayId::DebuggerValue(_) => self.highlight_styles.inlay_hint,
|
||||
InlayId::Color(_) => {
|
||||
if let InlayContent::Color(color) = inlay.content {
|
||||
if let Some(color) = inlay.color {
|
||||
renderer = Some(ChunkRenderer {
|
||||
id: ChunkRendererId::Inlay(inlay.id),
|
||||
render: Arc::new(move |cx| {
|
||||
@@ -421,7 +410,7 @@ impl<'a> Iterator for InlayChunks<'a> {
|
||||
let start = offset_in_inlay;
|
||||
let end = cmp::min(self.max_output_offset, self.transforms.end().0)
|
||||
- self.transforms.start().0;
|
||||
let chunks = inlay.text().chunks_in_range(start.0..end.0);
|
||||
let chunks = inlay.text.chunks_in_range(start.0..end.0);
|
||||
text::ChunkWithBitmaps(chunks)
|
||||
});
|
||||
let ChunkBitmaps {
|
||||
@@ -717,7 +706,7 @@ impl InlayMap {
|
||||
|
||||
for inlay_to_insert in to_insert {
|
||||
// Avoid inserting empty inlays.
|
||||
if inlay_to_insert.text().is_empty() {
|
||||
if inlay_to_insert.text.is_empty() {
|
||||
continue;
|
||||
}
|
||||
|
||||
@@ -755,7 +744,7 @@ impl InlayMap {
|
||||
#[cfg(test)]
|
||||
pub(crate) fn randomly_mutate(
|
||||
&mut self,
|
||||
next_inlay_id: &mut u32,
|
||||
next_inlay_id: &mut usize,
|
||||
rng: &mut rand::rngs::StdRng,
|
||||
) -> (InlaySnapshot, Vec<InlayEdit>) {
|
||||
use rand::prelude::*;
|
||||
@@ -833,7 +822,7 @@ impl InlaySnapshot {
|
||||
InlayPoint(cursor.start().1.0 + (buffer_end - buffer_start))
|
||||
}
|
||||
Some(Transform::Inlay(inlay)) => {
|
||||
let overshoot = inlay.text().offset_to_point(overshoot);
|
||||
let overshoot = inlay.text.offset_to_point(overshoot);
|
||||
InlayPoint(cursor.start().1.0 + overshoot)
|
||||
}
|
||||
None => self.max_point(),
|
||||
@@ -863,7 +852,7 @@ impl InlaySnapshot {
|
||||
InlayOffset(cursor.start().1.0 + (buffer_offset_end - buffer_offset_start))
|
||||
}
|
||||
Some(Transform::Inlay(inlay)) => {
|
||||
let overshoot = inlay.text().point_to_offset(overshoot);
|
||||
let overshoot = inlay.text.point_to_offset(overshoot);
|
||||
InlayOffset(cursor.start().1.0 + overshoot)
|
||||
}
|
||||
None => self.len(),
|
||||
@@ -1075,7 +1064,7 @@ impl InlaySnapshot {
|
||||
Some(Transform::Inlay(inlay)) => {
|
||||
let suffix_start = overshoot;
|
||||
let suffix_end = cmp::min(cursor.end().0, range.end).0 - cursor.start().0.0;
|
||||
summary = inlay.text().cursor(suffix_start).summary(suffix_end);
|
||||
summary = inlay.text.cursor(suffix_start).summary(suffix_end);
|
||||
cursor.next();
|
||||
}
|
||||
None => {}
|
||||
@@ -1097,7 +1086,7 @@ impl InlaySnapshot {
|
||||
}
|
||||
Some(Transform::Inlay(inlay)) => {
|
||||
let prefix_end = overshoot;
|
||||
summary += inlay.text().cursor(0).summary::<TextSummary>(prefix_end);
|
||||
summary += inlay.text.cursor(0).summary::<TextSummary>(prefix_end);
|
||||
}
|
||||
None => {}
|
||||
}
|
||||
@@ -1280,7 +1269,7 @@ mod tests {
|
||||
resolve_state: ResolveState::Resolved,
|
||||
},
|
||||
)
|
||||
.text()
|
||||
.text
|
||||
.to_string(),
|
||||
"a",
|
||||
"Should not pad label if not requested"
|
||||
@@ -1300,7 +1289,7 @@ mod tests {
|
||||
resolve_state: ResolveState::Resolved,
|
||||
},
|
||||
)
|
||||
.text()
|
||||
.text
|
||||
.to_string(),
|
||||
" a ",
|
||||
"Should pad label for every side requested"
|
||||
@@ -1320,7 +1309,7 @@ mod tests {
|
||||
resolve_state: ResolveState::Resolved,
|
||||
},
|
||||
)
|
||||
.text()
|
||||
.text
|
||||
.to_string(),
|
||||
" a ",
|
||||
"Should not change already padded label"
|
||||
@@ -1340,7 +1329,7 @@ mod tests {
|
||||
resolve_state: ResolveState::Resolved,
|
||||
},
|
||||
)
|
||||
.text()
|
||||
.text
|
||||
.to_string(),
|
||||
" a ",
|
||||
"Should not change already padded label"
|
||||
@@ -1363,7 +1352,7 @@ mod tests {
|
||||
resolve_state: ResolveState::Resolved,
|
||||
},
|
||||
)
|
||||
.text()
|
||||
.text
|
||||
.to_string(),
|
||||
" 🎨 ",
|
||||
"Should pad single emoji correctly"
|
||||
@@ -1761,7 +1750,7 @@ mod tests {
|
||||
.collect::<Vec<_>>();
|
||||
let mut expected_text = Rope::from(&buffer_snapshot.text());
|
||||
for (offset, inlay) in inlays.iter().rev() {
|
||||
expected_text.replace(*offset..*offset, &inlay.text().to_string());
|
||||
expected_text.replace(*offset..*offset, &inlay.text.to_string());
|
||||
}
|
||||
assert_eq!(inlay_snapshot.text(), expected_text.to_string());
|
||||
|
||||
@@ -1814,7 +1803,7 @@ mod tests {
|
||||
.into_iter()
|
||||
.filter_map(|i| {
|
||||
let (_, inlay) = &inlays[i];
|
||||
let inlay_text_len = inlay.text().len();
|
||||
let inlay_text_len = inlay.text.len();
|
||||
match inlay_text_len {
|
||||
0 => None,
|
||||
1 => Some(InlayHighlight {
|
||||
@@ -1823,7 +1812,7 @@ mod tests {
|
||||
range: 0..1,
|
||||
}),
|
||||
n => {
|
||||
let inlay_text = inlay.text().to_string();
|
||||
let inlay_text = inlay.text.to_string();
|
||||
let mut highlight_end = rng.random_range(1..n);
|
||||
let mut highlight_start = rng.random_range(0..highlight_end);
|
||||
while !inlay_text.is_char_boundary(highlight_end) {
|
||||
@@ -2149,7 +2138,8 @@ mod tests {
|
||||
let inlay = Inlay {
|
||||
id: InlayId::Hint(0),
|
||||
position,
|
||||
content: InlayContent::Text(text::Rope::from(inlay_text)),
|
||||
text: text::Rope::from(inlay_text),
|
||||
color: None,
|
||||
};
|
||||
|
||||
let (inlay_snapshot, _) = inlay_map.splice(&[], vec![inlay]);
|
||||
@@ -2263,7 +2253,8 @@ mod tests {
|
||||
let inlay = Inlay {
|
||||
id: InlayId::Hint(0),
|
||||
position,
|
||||
content: InlayContent::Text(text::Rope::from(test_case.inlay_text)),
|
||||
text: text::Rope::from(test_case.inlay_text),
|
||||
color: None,
|
||||
};
|
||||
|
||||
let (inlay_snapshot, _) = inlay_map.splice(&[], vec![inlay]);
|
||||
|
||||
@@ -53,12 +53,9 @@ pub fn replacement(c: char) -> Option<&'static str> {
|
||||
} else if contains(c, PRESERVE) {
|
||||
None
|
||||
} else {
|
||||
Some(FIXED_WIDTH_SPACE)
|
||||
Some("\u{2007}") // fixed width space
|
||||
}
|
||||
}
|
||||
|
||||
const FIXED_WIDTH_SPACE: &str = "\u{2007}";
|
||||
|
||||
// IDEOGRAPHIC SPACE is common alongside Chinese and other wide character sets.
|
||||
// We don't highlight this for now (as it already shows up wide in the editor),
|
||||
// but could if we tracked state in the classifier.
|
||||
@@ -120,11 +117,11 @@ const PRESERVE: &[(char, char)] = &[
|
||||
];
|
||||
|
||||
fn contains(c: char, list: &[(char, char)]) -> bool {
|
||||
for &(start, end) in list {
|
||||
if c < start {
|
||||
for (start, end) in list {
|
||||
if c < *start {
|
||||
return false;
|
||||
}
|
||||
if c <= end {
|
||||
if c <= *end {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -54,7 +54,9 @@ impl TabMap {
|
||||
new_snapshot.version += 1;
|
||||
}
|
||||
|
||||
let tab_edits = if old_snapshot.tab_size == new_snapshot.tab_size {
|
||||
let mut tab_edits = Vec::with_capacity(fold_edits.len());
|
||||
|
||||
if old_snapshot.tab_size == new_snapshot.tab_size {
|
||||
// Expand each edit to include the next tab on the same line as the edit,
|
||||
// and any subsequent tabs on that line that moved across the tab expansion
|
||||
// boundary.
|
||||
@@ -110,7 +112,7 @@ impl TabMap {
|
||||
let _old_alloc_ptr = fold_edits.as_ptr();
|
||||
// Combine any edits that overlap due to the expansion.
|
||||
let mut fold_edits = fold_edits.into_iter();
|
||||
if let Some(mut first_edit) = fold_edits.next() {
|
||||
let fold_edits = if let Some(mut first_edit) = fold_edits.next() {
|
||||
// This code relies on reusing allocations from the Vec<_> - at the time of writing .flatten() prevents them.
|
||||
#[allow(clippy::filter_map_identity)]
|
||||
let mut v: Vec<_> = fold_edits
|
||||
@@ -130,30 +132,29 @@ impl TabMap {
|
||||
.collect();
|
||||
v.push(first_edit);
|
||||
debug_assert_eq!(v.as_ptr(), _old_alloc_ptr, "Fold edits were reallocated");
|
||||
v.into_iter()
|
||||
.map(|fold_edit| {
|
||||
let old_start = fold_edit.old.start.to_point(&old_snapshot.fold_snapshot);
|
||||
let old_end = fold_edit.old.end.to_point(&old_snapshot.fold_snapshot);
|
||||
let new_start = fold_edit.new.start.to_point(&new_snapshot.fold_snapshot);
|
||||
let new_end = fold_edit.new.end.to_point(&new_snapshot.fold_snapshot);
|
||||
TabEdit {
|
||||
old: old_snapshot.to_tab_point(old_start)
|
||||
..old_snapshot.to_tab_point(old_end),
|
||||
new: new_snapshot.to_tab_point(new_start)
|
||||
..new_snapshot.to_tab_point(new_end),
|
||||
}
|
||||
})
|
||||
.collect()
|
||||
v
|
||||
} else {
|
||||
vec![]
|
||||
};
|
||||
|
||||
for fold_edit in fold_edits {
|
||||
let old_start = fold_edit.old.start.to_point(&old_snapshot.fold_snapshot);
|
||||
let old_end = fold_edit.old.end.to_point(&old_snapshot.fold_snapshot);
|
||||
let new_start = fold_edit.new.start.to_point(&new_snapshot.fold_snapshot);
|
||||
let new_end = fold_edit.new.end.to_point(&new_snapshot.fold_snapshot);
|
||||
tab_edits.push(TabEdit {
|
||||
old: old_snapshot.to_tab_point(old_start)..old_snapshot.to_tab_point(old_end),
|
||||
new: new_snapshot.to_tab_point(new_start)..new_snapshot.to_tab_point(new_end),
|
||||
});
|
||||
}
|
||||
} else {
|
||||
new_snapshot.version += 1;
|
||||
vec![TabEdit {
|
||||
tab_edits.push(TabEdit {
|
||||
old: TabPoint::zero()..old_snapshot.max_point(),
|
||||
new: TabPoint::zero()..new_snapshot.max_point(),
|
||||
}]
|
||||
};
|
||||
});
|
||||
}
|
||||
|
||||
*old_snapshot = new_snapshot;
|
||||
(old_snapshot.clone(), tab_edits)
|
||||
}
|
||||
@@ -194,28 +195,37 @@ impl TabSnapshot {
|
||||
.fold_snapshot
|
||||
.text_summary_for_range(input_start..input_end);
|
||||
|
||||
let mut first_line_chars = 0;
|
||||
let line_end = if range.start.row() == range.end.row() {
|
||||
range.end
|
||||
} else {
|
||||
self.max_point()
|
||||
};
|
||||
let first_line_chars = self
|
||||
for c in self
|
||||
.chunks(range.start..line_end, false, Highlights::default())
|
||||
.flat_map(|chunk| chunk.text.chars())
|
||||
.take_while(|&c| c != '\n')
|
||||
.count() as u32;
|
||||
{
|
||||
if c == '\n' {
|
||||
break;
|
||||
}
|
||||
first_line_chars += 1;
|
||||
}
|
||||
|
||||
let last_line_chars = if range.start.row() == range.end.row() {
|
||||
first_line_chars
|
||||
let mut last_line_chars = 0;
|
||||
if range.start.row() == range.end.row() {
|
||||
last_line_chars = first_line_chars;
|
||||
} else {
|
||||
self.chunks(
|
||||
TabPoint::new(range.end.row(), 0)..range.end,
|
||||
false,
|
||||
Highlights::default(),
|
||||
)
|
||||
.flat_map(|chunk| chunk.text.chars())
|
||||
.count() as u32
|
||||
};
|
||||
for _ in self
|
||||
.chunks(
|
||||
TabPoint::new(range.end.row(), 0)..range.end,
|
||||
false,
|
||||
Highlights::default(),
|
||||
)
|
||||
.flat_map(|chunk| chunk.text.chars())
|
||||
{
|
||||
last_line_chars += 1;
|
||||
}
|
||||
}
|
||||
|
||||
TextSummary {
|
||||
lines: range.end.0 - range.start.0,
|
||||
@@ -504,17 +514,15 @@ impl<'a> std::ops::AddAssign<&'a Self> for TextSummary {
|
||||
|
||||
pub struct TabChunks<'a> {
|
||||
snapshot: &'a TabSnapshot,
|
||||
max_expansion_column: u32,
|
||||
max_output_position: Point,
|
||||
tab_size: NonZeroU32,
|
||||
// region: iteration state
|
||||
fold_chunks: FoldChunks<'a>,
|
||||
chunk: Chunk<'a>,
|
||||
column: u32,
|
||||
max_expansion_column: u32,
|
||||
output_position: Point,
|
||||
input_column: u32,
|
||||
max_output_position: Point,
|
||||
tab_size: NonZeroU32,
|
||||
inside_leading_tab: bool,
|
||||
// endregion: iteration state
|
||||
}
|
||||
|
||||
impl TabChunks<'_> {
|
||||
|
||||
@@ -2,6 +2,7 @@ use edit_prediction::EditPredictionProvider;
|
||||
use gpui::{Entity, prelude::*};
|
||||
use indoc::indoc;
|
||||
use multi_buffer::{Anchor, MultiBufferSnapshot, ToPoint};
|
||||
use project::Project;
|
||||
use std::ops::Range;
|
||||
use text::{Point, ToOffset};
|
||||
|
||||
@@ -260,7 +261,7 @@ async fn test_edit_prediction_jump_disabled_for_non_zed_providers(cx: &mut gpui:
|
||||
EditPrediction::Edit { .. } => {
|
||||
// This is expected for non-Zed providers
|
||||
}
|
||||
EditPrediction::MoveWithin { .. } | EditPrediction::MoveOutside { .. } => {
|
||||
EditPrediction::Move { .. } => {
|
||||
panic!(
|
||||
"Non-Zed providers should not show Move predictions (jump functionality)"
|
||||
);
|
||||
@@ -298,7 +299,7 @@ fn assert_editor_active_move_completion(
|
||||
.as_ref()
|
||||
.expect("editor has no active completion");
|
||||
|
||||
if let EditPrediction::MoveWithin { target, .. } = &completion_state.completion {
|
||||
if let EditPrediction::Move { target, .. } = &completion_state.completion {
|
||||
assert(editor.buffer().read(cx).snapshot(cx), *target);
|
||||
} else {
|
||||
panic!("expected move completion");
|
||||
@@ -325,7 +326,7 @@ fn propose_edits<T: ToOffset>(
|
||||
|
||||
cx.update(|_, cx| {
|
||||
provider.update(cx, |provider, _| {
|
||||
provider.set_edit_prediction(Some(edit_prediction::EditPrediction::Local {
|
||||
provider.set_edit_prediction(Some(edit_prediction::EditPrediction {
|
||||
id: None,
|
||||
edits: edits.collect(),
|
||||
edit_preview: None,
|
||||
@@ -356,7 +357,7 @@ fn propose_edits_non_zed<T: ToOffset>(
|
||||
|
||||
cx.update(|_, cx| {
|
||||
provider.update(cx, |provider, _| {
|
||||
provider.set_edit_prediction(Some(edit_prediction::EditPrediction::Local {
|
||||
provider.set_edit_prediction(Some(edit_prediction::EditPrediction {
|
||||
id: None,
|
||||
edits: edits.collect(),
|
||||
edit_preview: None,
|
||||
@@ -417,6 +418,7 @@ impl EditPredictionProvider for FakeEditPredictionProvider {
|
||||
|
||||
fn refresh(
|
||||
&mut self,
|
||||
_project: Option<Entity<Project>>,
|
||||
_buffer: gpui::Entity<language::Buffer>,
|
||||
_cursor_position: language::Anchor,
|
||||
_debounce: bool,
|
||||
@@ -490,6 +492,7 @@ impl EditPredictionProvider for FakeNonZedEditPredictionProvider {
|
||||
|
||||
fn refresh(
|
||||
&mut self,
|
||||
_project: Option<Entity<Project>>,
|
||||
_buffer: gpui::Entity<language::Buffer>,
|
||||
_cursor_position: language::Anchor,
|
||||
_debounce: bool,
|
||||
|
||||
@@ -279,15 +279,15 @@ impl InlineValueCache {
|
||||
|
||||
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
|
||||
pub enum InlayId {
|
||||
EditPrediction(u32),
|
||||
DebuggerValue(u32),
|
||||
EditPrediction(usize),
|
||||
DebuggerValue(usize),
|
||||
// LSP
|
||||
Hint(u32),
|
||||
Color(u32),
|
||||
Hint(usize),
|
||||
Color(usize),
|
||||
}
|
||||
|
||||
impl InlayId {
|
||||
fn id(&self) -> u32 {
|
||||
fn id(&self) -> usize {
|
||||
match self {
|
||||
Self::EditPrediction(id) => *id,
|
||||
Self::DebuggerValue(id) => *id,
|
||||
@@ -638,23 +638,17 @@ enum EditPrediction {
|
||||
display_mode: EditDisplayMode,
|
||||
snapshot: BufferSnapshot,
|
||||
},
|
||||
/// Move to a specific location in the active editor
|
||||
MoveWithin {
|
||||
Move {
|
||||
target: Anchor,
|
||||
snapshot: BufferSnapshot,
|
||||
},
|
||||
/// Move to a specific location in a different editor (not the active one)
|
||||
MoveOutside {
|
||||
target: language::Anchor,
|
||||
snapshot: BufferSnapshot,
|
||||
},
|
||||
}
|
||||
|
||||
struct EditPredictionState {
|
||||
inlay_ids: Vec<InlayId>,
|
||||
completion: EditPrediction,
|
||||
completion_id: Option<SharedString>,
|
||||
invalidation_range: Option<Range<Anchor>>,
|
||||
invalidation_range: Range<Anchor>,
|
||||
}
|
||||
|
||||
enum EditPredictionSettings {
|
||||
@@ -1124,8 +1118,7 @@ pub struct Editor {
|
||||
edit_prediction_indent_conflict: bool,
|
||||
edit_prediction_requires_modifier_in_indent_conflict: bool,
|
||||
inlay_hint_cache: InlayHintCache,
|
||||
next_inlay_id: u32,
|
||||
next_color_inlay_id: u32,
|
||||
next_inlay_id: usize,
|
||||
_subscriptions: Vec<Subscription>,
|
||||
pixel_position_of_newest_cursor: Option<gpui::Point<Pixels>>,
|
||||
gutter_dimensions: GutterDimensions,
|
||||
@@ -1189,6 +1182,7 @@ pub struct Editor {
|
||||
pub change_list: ChangeList,
|
||||
inline_value_cache: InlineValueCache,
|
||||
selection_drag_state: SelectionDragState,
|
||||
next_color_inlay_id: usize,
|
||||
colors: Option<LspColorData>,
|
||||
folding_newlines: Task<()>,
|
||||
pub lookup_key: Option<Box<dyn Any + Send + Sync>>,
|
||||
@@ -7181,7 +7175,13 @@ impl Editor {
|
||||
return None;
|
||||
}
|
||||
|
||||
provider.refresh(buffer, cursor_buffer_position, debounce, cx);
|
||||
provider.refresh(
|
||||
self.project.clone(),
|
||||
buffer,
|
||||
cursor_buffer_position,
|
||||
debounce,
|
||||
cx,
|
||||
);
|
||||
Some(())
|
||||
}
|
||||
|
||||
@@ -7424,8 +7424,10 @@ impl Editor {
|
||||
return;
|
||||
};
|
||||
|
||||
self.report_edit_prediction_event(active_edit_prediction.completion_id.clone(), true, cx);
|
||||
|
||||
match &active_edit_prediction.completion {
|
||||
EditPrediction::MoveWithin { target, .. } => {
|
||||
EditPrediction::Move { target, .. } => {
|
||||
let target = *target;
|
||||
|
||||
if let Some(position_map) = &self.last_position_map {
|
||||
@@ -7467,19 +7469,7 @@ impl Editor {
|
||||
}
|
||||
}
|
||||
}
|
||||
EditPrediction::MoveOutside { snapshot, target } => {
|
||||
if let Some(workspace) = self.workspace() {
|
||||
Self::open_editor_at_anchor(snapshot, *target, &workspace, window, cx)
|
||||
.detach_and_log_err(cx);
|
||||
}
|
||||
}
|
||||
EditPrediction::Edit { edits, .. } => {
|
||||
self.report_edit_prediction_event(
|
||||
active_edit_prediction.completion_id.clone(),
|
||||
true,
|
||||
cx,
|
||||
);
|
||||
|
||||
if let Some(provider) = self.edit_prediction_provider() {
|
||||
provider.accept(cx);
|
||||
}
|
||||
@@ -7532,8 +7522,10 @@ impl Editor {
|
||||
return;
|
||||
}
|
||||
|
||||
self.report_edit_prediction_event(active_edit_prediction.completion_id.clone(), true, cx);
|
||||
|
||||
match &active_edit_prediction.completion {
|
||||
EditPrediction::MoveWithin { target, .. } => {
|
||||
EditPrediction::Move { target, .. } => {
|
||||
let target = *target;
|
||||
self.change_selections(
|
||||
SelectionEffects::scroll(Autoscroll::newest()),
|
||||
@@ -7544,19 +7536,7 @@ impl Editor {
|
||||
},
|
||||
);
|
||||
}
|
||||
EditPrediction::MoveOutside { snapshot, target } => {
|
||||
if let Some(workspace) = self.workspace() {
|
||||
Self::open_editor_at_anchor(snapshot, *target, &workspace, window, cx)
|
||||
.detach_and_log_err(cx);
|
||||
}
|
||||
}
|
||||
EditPrediction::Edit { edits, .. } => {
|
||||
self.report_edit_prediction_event(
|
||||
active_edit_prediction.completion_id.clone(),
|
||||
true,
|
||||
cx,
|
||||
);
|
||||
|
||||
// Find an insertion that starts at the cursor position.
|
||||
let snapshot = self.buffer.read(cx).snapshot(cx);
|
||||
let cursor_offset = self.selections.newest::<usize>(cx).head();
|
||||
@@ -7651,36 +7631,6 @@ impl Editor {
|
||||
);
|
||||
}
|
||||
|
||||
fn open_editor_at_anchor(
|
||||
snapshot: &language::BufferSnapshot,
|
||||
target: language::Anchor,
|
||||
workspace: &Entity<Workspace>,
|
||||
window: &mut Window,
|
||||
cx: &mut App,
|
||||
) -> Task<Result<()>> {
|
||||
workspace.update(cx, |workspace, cx| {
|
||||
let path = snapshot.file().map(|file| file.full_path(cx));
|
||||
let Some(path) =
|
||||
path.and_then(|path| workspace.project().read(cx).find_project_path(path, cx))
|
||||
else {
|
||||
return Task::ready(Err(anyhow::anyhow!("Project path not found")));
|
||||
};
|
||||
let target = text::ToPoint::to_point(&target, snapshot);
|
||||
let item = workspace.open_path(path, None, true, window, cx);
|
||||
window.spawn(cx, async move |cx| {
|
||||
let Some(editor) = item.await?.downcast::<Editor>() else {
|
||||
return Ok(());
|
||||
};
|
||||
editor
|
||||
.update_in(cx, |editor, window, cx| {
|
||||
editor.go_to_singleton_buffer_point(target, window, cx);
|
||||
})
|
||||
.ok();
|
||||
anyhow::Ok(())
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
pub fn has_active_edit_prediction(&self) -> bool {
|
||||
self.active_edit_prediction.is_some()
|
||||
}
|
||||
@@ -7896,10 +7846,7 @@ impl Editor {
|
||||
.active_edit_prediction
|
||||
.as_ref()
|
||||
.is_some_and(|completion| {
|
||||
let Some(invalidation_range) = completion.invalidation_range.as_ref() else {
|
||||
return false;
|
||||
};
|
||||
let invalidation_range = invalidation_range.to_offset(&multibuffer);
|
||||
let invalidation_range = completion.invalidation_range.to_offset(&multibuffer);
|
||||
let invalidation_range = invalidation_range.start..=invalidation_range.end;
|
||||
!invalidation_range.contains(&offset_selection.head())
|
||||
})
|
||||
@@ -7935,31 +7882,8 @@ impl Editor {
|
||||
}
|
||||
|
||||
let edit_prediction = provider.suggest(&buffer, cursor_buffer_position, cx)?;
|
||||
|
||||
let (completion_id, edits, edit_preview) = match edit_prediction {
|
||||
edit_prediction::EditPrediction::Local {
|
||||
id,
|
||||
edits,
|
||||
edit_preview,
|
||||
} => (id, edits, edit_preview),
|
||||
edit_prediction::EditPrediction::Jump {
|
||||
id,
|
||||
snapshot,
|
||||
target,
|
||||
} => {
|
||||
self.stale_edit_prediction_in_menu = None;
|
||||
self.active_edit_prediction = Some(EditPredictionState {
|
||||
inlay_ids: vec![],
|
||||
completion: EditPrediction::MoveOutside { snapshot, target },
|
||||
completion_id: id,
|
||||
invalidation_range: None,
|
||||
});
|
||||
cx.notify();
|
||||
return Some(());
|
||||
}
|
||||
};
|
||||
|
||||
let edits = edits
|
||||
let edits = edit_prediction
|
||||
.edits
|
||||
.into_iter()
|
||||
.flat_map(|(range, new_text)| {
|
||||
let start = multibuffer.anchor_in_excerpt(excerpt_id, range.start)?;
|
||||
@@ -8004,7 +7928,7 @@ impl Editor {
|
||||
invalidation_row_range =
|
||||
move_invalidation_row_range.unwrap_or(edit_start_row..edit_end_row);
|
||||
let target = first_edit_start;
|
||||
EditPrediction::MoveWithin { target, snapshot }
|
||||
EditPrediction::Move { target, snapshot }
|
||||
} else {
|
||||
let show_completions_in_buffer = !self.edit_prediction_visible_in_cursor_popover(true)
|
||||
&& !self.edit_predictions_hidden_for_vim_mode;
|
||||
@@ -8053,7 +7977,7 @@ impl Editor {
|
||||
|
||||
EditPrediction::Edit {
|
||||
edits,
|
||||
edit_preview,
|
||||
edit_preview: edit_prediction.edit_preview,
|
||||
display_mode,
|
||||
snapshot,
|
||||
}
|
||||
@@ -8070,8 +7994,8 @@ impl Editor {
|
||||
self.active_edit_prediction = Some(EditPredictionState {
|
||||
inlay_ids,
|
||||
completion,
|
||||
completion_id,
|
||||
invalidation_range: Some(invalidation_range),
|
||||
completion_id: edit_prediction.id,
|
||||
invalidation_range,
|
||||
});
|
||||
|
||||
cx.notify();
|
||||
@@ -8657,7 +8581,7 @@ impl Editor {
|
||||
}
|
||||
|
||||
match &active_edit_prediction.completion {
|
||||
EditPrediction::MoveWithin { target, .. } => {
|
||||
EditPrediction::Move { target, .. } => {
|
||||
let target_display_point = target.to_display_point(editor_snapshot);
|
||||
|
||||
if self.edit_prediction_requires_modifier() {
|
||||
@@ -8742,28 +8666,6 @@ impl Editor {
|
||||
window,
|
||||
cx,
|
||||
),
|
||||
EditPrediction::MoveOutside { snapshot, .. } => {
|
||||
let file_name = snapshot
|
||||
.file()
|
||||
.map(|file| file.file_name(cx))
|
||||
.unwrap_or("untitled");
|
||||
let mut element = self
|
||||
.render_edit_prediction_line_popover(
|
||||
format!("Jump to {file_name}"),
|
||||
Some(IconName::ZedPredict),
|
||||
window,
|
||||
cx,
|
||||
)
|
||||
.into_any();
|
||||
|
||||
let size = element.layout_as_root(AvailableSpace::min_size(), window, cx);
|
||||
let origin_x = text_bounds.size.width / 2. - size.width / 2.;
|
||||
let origin_y = text_bounds.size.height - size.height - px(30.);
|
||||
let origin = text_bounds.origin + gpui::Point::new(origin_x, origin_y);
|
||||
element.prepaint_at(origin, window, cx);
|
||||
|
||||
Some((element, origin))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -8828,13 +8730,13 @@ impl Editor {
|
||||
.items_end()
|
||||
.when(flag_on_right, |el| el.items_start())
|
||||
.child(if flag_on_right {
|
||||
self.render_edit_prediction_line_popover("Jump", None, window, cx)
|
||||
self.render_edit_prediction_line_popover("Jump", None, window, cx)?
|
||||
.rounded_bl(px(0.))
|
||||
.rounded_tl(px(0.))
|
||||
.border_l_2()
|
||||
.border_color(border_color)
|
||||
} else {
|
||||
self.render_edit_prediction_line_popover("Jump", None, window, cx)
|
||||
self.render_edit_prediction_line_popover("Jump", None, window, cx)?
|
||||
.rounded_br(px(0.))
|
||||
.rounded_tr(px(0.))
|
||||
.border_r_2()
|
||||
@@ -8874,7 +8776,7 @@ impl Editor {
|
||||
cx: &mut App,
|
||||
) -> Option<(AnyElement, gpui::Point<Pixels>)> {
|
||||
let mut element = self
|
||||
.render_edit_prediction_line_popover("Scroll", Some(scroll_icon), window, cx)
|
||||
.render_edit_prediction_line_popover("Scroll", Some(scroll_icon), window, cx)?
|
||||
.into_any();
|
||||
|
||||
let size = element.layout_as_root(AvailableSpace::min_size(), window, cx);
|
||||
@@ -8914,7 +8816,7 @@ impl Editor {
|
||||
Some(IconName::ArrowUp),
|
||||
window,
|
||||
cx,
|
||||
)
|
||||
)?
|
||||
.into_any();
|
||||
|
||||
let size = element.layout_as_root(AvailableSpace::min_size(), window, cx);
|
||||
@@ -8933,7 +8835,7 @@ impl Editor {
|
||||
Some(IconName::ArrowDown),
|
||||
window,
|
||||
cx,
|
||||
)
|
||||
)?
|
||||
.into_any();
|
||||
|
||||
let size = element.layout_as_root(AvailableSpace::min_size(), window, cx);
|
||||
@@ -8980,7 +8882,7 @@ impl Editor {
|
||||
);
|
||||
|
||||
let mut element = self
|
||||
.render_edit_prediction_line_popover(label, None, window, cx)
|
||||
.render_edit_prediction_line_popover(label, None, window, cx)?
|
||||
.into_any();
|
||||
|
||||
let size = element.layout_as_root(AvailableSpace::min_size(), window, cx);
|
||||
@@ -9007,7 +8909,7 @@ impl Editor {
|
||||
};
|
||||
|
||||
element = self
|
||||
.render_edit_prediction_line_popover(label, Some(icon), window, cx)
|
||||
.render_edit_prediction_line_popover(label, Some(icon), window, cx)?
|
||||
.into_any();
|
||||
|
||||
let size = element.layout_as_root(AvailableSpace::min_size(), window, cx);
|
||||
@@ -9261,13 +9163,13 @@ impl Editor {
|
||||
icon: Option<IconName>,
|
||||
window: &mut Window,
|
||||
cx: &App,
|
||||
) -> Stateful<Div> {
|
||||
) -> Option<Stateful<Div>> {
|
||||
let padding_right = if icon.is_some() { px(4.) } else { px(8.) };
|
||||
|
||||
let keybind = self.render_edit_prediction_accept_keybind(window, cx);
|
||||
let has_keybind = keybind.is_some();
|
||||
|
||||
h_flex()
|
||||
let result = h_flex()
|
||||
.id("ep-line-popover")
|
||||
.py_0p5()
|
||||
.pl_1()
|
||||
@@ -9313,7 +9215,9 @@ impl Editor {
|
||||
.mt(px(1.5))
|
||||
.child(Icon::new(icon).size(IconSize::Small)),
|
||||
)
|
||||
})
|
||||
});
|
||||
|
||||
Some(result)
|
||||
}
|
||||
|
||||
fn edit_prediction_line_popover_bg_color(cx: &App) -> Hsla {
|
||||
@@ -9377,7 +9281,7 @@ impl Editor {
|
||||
.rounded_tl(px(0.))
|
||||
.overflow_hidden()
|
||||
.child(div().px_1p5().child(match &prediction.completion {
|
||||
EditPrediction::MoveWithin { target, snapshot } => {
|
||||
EditPrediction::Move { target, snapshot } => {
|
||||
use text::ToPoint as _;
|
||||
if target.text_anchor.to_point(snapshot).row > cursor_point.row
|
||||
{
|
||||
@@ -9386,10 +9290,6 @@ impl Editor {
|
||||
Icon::new(IconName::ZedPredictUp)
|
||||
}
|
||||
}
|
||||
EditPrediction::MoveOutside { .. } => {
|
||||
// TODO [zeta2] custom icon for external jump?
|
||||
Icon::new(provider_icon)
|
||||
}
|
||||
EditPrediction::Edit { .. } => Icon::new(provider_icon),
|
||||
}))
|
||||
.child(
|
||||
@@ -9572,7 +9472,7 @@ impl Editor {
|
||||
.unwrap_or(true);
|
||||
|
||||
match &completion.completion {
|
||||
EditPrediction::MoveWithin {
|
||||
EditPrediction::Move {
|
||||
target, snapshot, ..
|
||||
} => {
|
||||
if !supports_jump {
|
||||
@@ -9594,20 +9494,7 @@ impl Editor {
|
||||
.child(Label::new("Jump to Edit")),
|
||||
)
|
||||
}
|
||||
EditPrediction::MoveOutside { snapshot, .. } => {
|
||||
let file_name = snapshot
|
||||
.file()
|
||||
.map(|file| file.file_name(cx))
|
||||
.unwrap_or("untitled");
|
||||
Some(
|
||||
h_flex()
|
||||
.px_2()
|
||||
.gap_2()
|
||||
.flex_1()
|
||||
.child(Icon::new(IconName::ZedPredict))
|
||||
.child(Label::new(format!("Jump to {file_name}"))),
|
||||
)
|
||||
}
|
||||
|
||||
EditPrediction::Edit {
|
||||
edits,
|
||||
edit_preview,
|
||||
@@ -12451,7 +12338,7 @@ impl Editor {
|
||||
}
|
||||
});
|
||||
});
|
||||
let item = self.cut_common(false, window, cx);
|
||||
let item = self.cut_common(true, window, cx);
|
||||
cx.set_global(KillRing(item))
|
||||
}
|
||||
|
||||
@@ -20644,7 +20531,7 @@ impl Editor {
|
||||
Anchor::in_buffer(excerpt_id, buffer_id, hint.position),
|
||||
hint.text(),
|
||||
);
|
||||
if !inlay.text().chars().contains(&'\n') {
|
||||
if !inlay.text.chars().contains(&'\n') {
|
||||
new_inlays.push(inlay);
|
||||
}
|
||||
});
|
||||
@@ -21531,7 +21418,7 @@ impl Editor {
|
||||
{
|
||||
self.hide_context_menu(window, cx);
|
||||
}
|
||||
self.take_active_edit_prediction(cx);
|
||||
self.discard_edit_prediction(false, cx);
|
||||
cx.emit(EditorEvent::Blurred);
|
||||
cx.notify();
|
||||
}
|
||||
|
||||
@@ -6742,14 +6742,6 @@ async fn test_cut_line_ends(cx: &mut TestAppContext) {
|
||||
|
||||
let mut cx = EditorTestContext::new(cx).await;
|
||||
|
||||
cx.set_state(indoc! {"The quick brownˇ"});
|
||||
cx.update_editor(|e, window, cx| e.cut_to_end_of_line(&CutToEndOfLine::default(), window, cx));
|
||||
cx.assert_editor_state(indoc! {"The quick brownˇ"});
|
||||
|
||||
cx.set_state(indoc! {"The emacs foxˇ"});
|
||||
cx.update_editor(|e, window, cx| e.kill_ring_cut(&KillRingCut, window, cx));
|
||||
cx.assert_editor_state(indoc! {"The emacs foxˇ"});
|
||||
|
||||
cx.set_state(indoc! {"
|
||||
The quick« brownˇ»
|
||||
fox jumps overˇ
|
||||
@@ -8272,7 +8264,7 @@ async fn test_undo_edit_prediction_scrolls_to_edit_pos(cx: &mut TestAppContext)
|
||||
|
||||
cx.update(|_, cx| {
|
||||
provider.update(cx, |provider, _| {
|
||||
provider.set_edit_prediction(Some(edit_prediction::EditPrediction::Local {
|
||||
provider.set_edit_prediction(Some(edit_prediction::EditPrediction {
|
||||
id: None,
|
||||
edits: vec![(edit_position..edit_position, "X".into())],
|
||||
edit_preview: None,
|
||||
|
||||
@@ -370,7 +370,7 @@ pub fn update_inlay_link_and_hover_points(
|
||||
inlay: hovered_hint.id,
|
||||
inlay_position: hovered_hint.position,
|
||||
range: extra_shift_left
|
||||
..hovered_hint.text().len() + extra_shift_right,
|
||||
..hovered_hint.text.len() + extra_shift_right,
|
||||
},
|
||||
},
|
||||
window,
|
||||
|
||||
@@ -3571,7 +3571,7 @@ pub mod tests {
|
||||
editor
|
||||
.visible_inlay_hints(cx)
|
||||
.into_iter()
|
||||
.map(|hint| hint.text().to_string())
|
||||
.map(|hint| hint.text.to_string())
|
||||
.collect()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -8,53 +8,34 @@ pub async fn get_messages(working_directory: &Path, shas: &[Oid]) -> Result<Hash
|
||||
return Ok(HashMap::default());
|
||||
}
|
||||
|
||||
let output = if cfg!(windows) {
|
||||
// Windows has a maximum invocable command length, so we chunk the input.
|
||||
// Actual max is 32767, but we leave some room for the rest of the command as we aren't in precise control of what std might do here
|
||||
const MAX_CMD_LENGTH: usize = 30000;
|
||||
// 40 bytes of hash, 2 quotes and a separating space
|
||||
const SHA_LENGTH: usize = 40 + 2 + 1;
|
||||
const MAX_ENTRIES_PER_INVOCATION: usize = MAX_CMD_LENGTH / SHA_LENGTH;
|
||||
|
||||
let mut result = vec![];
|
||||
for shas in shas.chunks(MAX_ENTRIES_PER_INVOCATION) {
|
||||
let partial = get_messages_impl(working_directory, shas).await?;
|
||||
result.extend(partial);
|
||||
}
|
||||
result
|
||||
} else {
|
||||
get_messages_impl(working_directory, shas).await?
|
||||
};
|
||||
|
||||
Ok(shas
|
||||
.iter()
|
||||
.cloned()
|
||||
.zip(output)
|
||||
.collect::<HashMap<Oid, String>>())
|
||||
}
|
||||
|
||||
async fn get_messages_impl(working_directory: &Path, shas: &[Oid]) -> Result<Vec<String>> {
|
||||
const MARKER: &str = "<MARKER>";
|
||||
let mut cmd = util::command::new_smol_command("git");
|
||||
cmd.current_dir(working_directory)
|
||||
|
||||
let output = util::command::new_smol_command("git")
|
||||
.current_dir(working_directory)
|
||||
.arg("show")
|
||||
.arg("-s")
|
||||
.arg(format!("--format=%B{}", MARKER))
|
||||
.args(shas.iter().map(ToString::to_string));
|
||||
let output = cmd
|
||||
.args(shas.iter().map(ToString::to_string))
|
||||
.output()
|
||||
.await
|
||||
.with_context(|| format!("starting git blame process: {:?}", cmd))?;
|
||||
.context("starting git blame process")?;
|
||||
|
||||
anyhow::ensure!(
|
||||
output.status.success(),
|
||||
"'git show' failed with error {:?}",
|
||||
output.status
|
||||
);
|
||||
Ok(String::from_utf8_lossy(&output.stdout)
|
||||
.trim()
|
||||
.split_terminator(MARKER)
|
||||
.map(|str| str.trim().replace("<", "<").replace(">", ">"))
|
||||
.collect::<Vec<_>>())
|
||||
|
||||
Ok(shas
|
||||
.iter()
|
||||
.cloned()
|
||||
.zip(
|
||||
String::from_utf8_lossy(&output.stdout)
|
||||
.trim()
|
||||
.split_terminator(MARKER)
|
||||
.map(|str| str.trim().replace("<", "<").replace(">", ">")),
|
||||
)
|
||||
.collect::<HashMap<Oid, String>>())
|
||||
}
|
||||
|
||||
/// Parse the output of `git diff --name-status -z`
|
||||
|
||||
@@ -455,11 +455,7 @@ impl CommitModal {
|
||||
if can_commit {
|
||||
Tooltip::with_meta_in(
|
||||
tooltip,
|
||||
Some(if is_amend_pending {
|
||||
&git::Amend
|
||||
} else {
|
||||
&git::Commit
|
||||
}),
|
||||
Some(&git::Commit),
|
||||
format!(
|
||||
"git commit{}{}",
|
||||
if is_amend_pending { " --amend" } else { "" },
|
||||
|
||||
@@ -3458,7 +3458,7 @@ impl GitPanel {
|
||||
if can_commit {
|
||||
Tooltip::with_meta_in(
|
||||
tooltip,
|
||||
Some(if amend { &git::Amend } else { &git::Commit }),
|
||||
Some(&git::Commit),
|
||||
format!(
|
||||
"git commit{}{}",
|
||||
if amend { " --amend" } else { "" },
|
||||
|
||||
@@ -38,7 +38,7 @@ wayland = [
|
||||
"blade-macros",
|
||||
"blade-util",
|
||||
"bytemuck",
|
||||
"ashpd",
|
||||
"ashpd/wayland",
|
||||
"cosmic-text",
|
||||
"font-kit",
|
||||
"calloop-wayland-source",
|
||||
|
||||
@@ -73,6 +73,13 @@ pub trait LinuxClient {
|
||||
fn active_window(&self) -> Option<AnyWindowHandle>;
|
||||
fn window_stack(&self) -> Option<Vec<AnyWindowHandle>>;
|
||||
fn run(&self);
|
||||
|
||||
#[cfg(any(feature = "wayland", feature = "x11"))]
|
||||
fn window_identifier(
|
||||
&self,
|
||||
) -> impl Future<Output = Option<ashpd::WindowIdentifier>> + Send + 'static {
|
||||
std::future::ready::<Option<ashpd::WindowIdentifier>>(None)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
@@ -290,6 +297,9 @@ impl<P: LinuxClient + 'static> Platform for P {
|
||||
#[cfg(not(any(feature = "wayland", feature = "x11")))]
|
||||
let _ = (done_tx.send(Ok(None)), options);
|
||||
|
||||
#[cfg(any(feature = "wayland", feature = "x11"))]
|
||||
let identifier = self.window_identifier();
|
||||
|
||||
#[cfg(any(feature = "wayland", feature = "x11"))]
|
||||
self.foreground_executor()
|
||||
.spawn(async move {
|
||||
@@ -300,6 +310,7 @@ impl<P: LinuxClient + 'static> Platform for P {
|
||||
};
|
||||
|
||||
let request = match ashpd::desktop::file_chooser::OpenFileRequest::default()
|
||||
.identifier(identifier.await)
|
||||
.modal(true)
|
||||
.title(title)
|
||||
.accept_label(options.prompt.as_ref().map(crate::SharedString::as_str))
|
||||
@@ -346,6 +357,9 @@ impl<P: LinuxClient + 'static> Platform for P {
|
||||
#[cfg(not(any(feature = "wayland", feature = "x11")))]
|
||||
let _ = (done_tx.send(Ok(None)), directory, suggested_name);
|
||||
|
||||
#[cfg(any(feature = "wayland", feature = "x11"))]
|
||||
let identifier = self.window_identifier();
|
||||
|
||||
#[cfg(any(feature = "wayland", feature = "x11"))]
|
||||
self.foreground_executor()
|
||||
.spawn({
|
||||
@@ -355,6 +369,7 @@ impl<P: LinuxClient + 'static> Platform for P {
|
||||
async move {
|
||||
let mut request_builder =
|
||||
ashpd::desktop::file_chooser::SaveFileRequest::default()
|
||||
.identifier(identifier.await)
|
||||
.modal(true)
|
||||
.title("Save File")
|
||||
.current_folder(directory)
|
||||
|
||||
@@ -7,6 +7,7 @@ use std::{
|
||||
time::{Duration, Instant},
|
||||
};
|
||||
|
||||
use ashpd::WindowIdentifier;
|
||||
use calloop::{
|
||||
EventLoop, LoopHandle,
|
||||
timer::{TimeoutAction, Timer},
|
||||
@@ -858,6 +859,20 @@ impl LinuxClient for WaylandClient {
|
||||
fn compositor_name(&self) -> &'static str {
|
||||
"Wayland"
|
||||
}
|
||||
|
||||
fn window_identifier(&self) -> impl Future<Output = Option<WindowIdentifier>> + Send + 'static {
|
||||
async fn inner(surface: Option<wl_surface::WlSurface>) -> Option<WindowIdentifier> {
|
||||
if let Some(surface) = surface {
|
||||
ashpd::WindowIdentifier::from_wayland(&surface).await
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
let client_state = self.0.borrow();
|
||||
let active_window = client_state.keyboard_focused_window.as_ref();
|
||||
inner(active_window.map(|aw| aw.surface()))
|
||||
}
|
||||
}
|
||||
|
||||
impl Dispatch<wl_registry::WlRegistry, GlobalListContents> for WaylandClientStatePtr {
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
use crate::{Capslock, xcb_flush};
|
||||
use anyhow::{Context as _, anyhow};
|
||||
use ashpd::WindowIdentifier;
|
||||
use calloop::{
|
||||
EventLoop, LoopHandle, RegistrationToken,
|
||||
generic::{FdWrapper, Generic},
|
||||
@@ -1660,6 +1661,16 @@ impl LinuxClient for X11Client {
|
||||
|
||||
Some(handles)
|
||||
}
|
||||
|
||||
fn window_identifier(&self) -> impl Future<Output = Option<WindowIdentifier>> + Send + 'static {
|
||||
let state = self.0.borrow();
|
||||
state
|
||||
.keyboard_focused_window
|
||||
.and_then(|focused_window| state.windows.get(&focused_window))
|
||||
.map(|window| window.window.x_window as u64)
|
||||
.map(|x_window| std::future::ready(Some(WindowIdentifier::from_xid(x_window))))
|
||||
.unwrap_or(std::future::ready(None))
|
||||
}
|
||||
}
|
||||
|
||||
impl X11ClientState {
|
||||
|
||||
@@ -284,7 +284,7 @@ pub(crate) struct X11WindowStatePtr {
|
||||
pub state: Rc<RefCell<X11WindowState>>,
|
||||
pub(crate) callbacks: Rc<RefCell<Callbacks>>,
|
||||
xcb: Rc<XCBConnection>,
|
||||
x_window: xproto::Window,
|
||||
pub(crate) x_window: xproto::Window,
|
||||
}
|
||||
|
||||
impl rwh::HasWindowHandle for RawWindow {
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
use std::{str::FromStr, sync::Arc};
|
||||
|
||||
use anyhow::{Context as _, Result};
|
||||
use gpui::{App, AsyncApp, BorrowAppContext as _, Entity, WeakEntity};
|
||||
use gpui::{App, AsyncApp, BorrowAppContext as _, Entity, SharedString, WeakEntity};
|
||||
use language::LanguageRegistry;
|
||||
use project::LspStore;
|
||||
|
||||
@@ -103,21 +103,12 @@ pub fn resolve_schema_request_inner(
|
||||
.into_iter()
|
||||
.map(|name| name.to_string())
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let mut icon_theme_names = vec![];
|
||||
let mut theme_names = vec![];
|
||||
if let Some(registry) = theme::ThemeRegistry::try_global(cx) {
|
||||
icon_theme_names.extend(
|
||||
registry
|
||||
.list_icon_themes()
|
||||
.into_iter()
|
||||
.map(|icon_theme| icon_theme.name),
|
||||
);
|
||||
theme_names.extend(registry.list_names());
|
||||
}
|
||||
let icon_theme_names = icon_theme_names.as_slice();
|
||||
let theme_names = theme_names.as_slice();
|
||||
|
||||
let icon_theme_names = &theme::ThemeRegistry::global(cx)
|
||||
.list_icon_themes()
|
||||
.into_iter()
|
||||
.map(|icon_theme| icon_theme.name)
|
||||
.collect::<Vec<SharedString>>();
|
||||
let theme_names = &theme::ThemeRegistry::global(cx).list_names();
|
||||
cx.global::<settings::SettingsStore>().json_schema(
|
||||
&settings::SettingsJsonSchemaParams {
|
||||
language_names,
|
||||
|
||||
@@ -1008,16 +1008,6 @@ fn get_venv_parent_dir(env: &PythonEnvironment) -> Option<PathBuf> {
|
||||
venv.parent().map(|parent| parent.to_path_buf())
|
||||
}
|
||||
|
||||
fn wr_distance(wr: &PathBuf, venv: Option<&PathBuf>) -> usize {
|
||||
if let Some(venv) = venv
|
||||
&& let Ok(p) = venv.strip_prefix(wr)
|
||||
{
|
||||
p.components().count()
|
||||
} else {
|
||||
usize::MAX
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl ToolchainLister for PythonToolchainProvider {
|
||||
async fn list(
|
||||
@@ -1079,7 +1069,12 @@ impl ToolchainLister for PythonToolchainProvider {
|
||||
let proj_ordering = || {
|
||||
let lhs_project = lhs.project.clone().or_else(|| get_venv_parent_dir(lhs));
|
||||
let rhs_project = rhs.project.clone().or_else(|| get_venv_parent_dir(rhs));
|
||||
wr_distance(&wr, lhs_project.as_ref()).cmp(&wr_distance(&wr, rhs_project.as_ref()))
|
||||
match (&lhs_project, &rhs_project) {
|
||||
(Some(l), Some(r)) => (r == &wr).cmp(&(l == &wr)),
|
||||
(Some(l), None) if l == &wr => Ordering::Less,
|
||||
(None, Some(r)) if r == &wr => Ordering::Greater,
|
||||
_ => Ordering::Equal,
|
||||
}
|
||||
};
|
||||
|
||||
// Compare environment priorities
|
||||
|
||||
@@ -54,7 +54,7 @@ use util::post_inc;
|
||||
const NEWLINES: &[u8] = &[b'\n'; u8::MAX as usize];
|
||||
|
||||
#[derive(Debug, Default, Clone, Copy, Hash, PartialEq, Eq, PartialOrd, Ord)]
|
||||
pub struct ExcerptId(u32);
|
||||
pub struct ExcerptId(usize);
|
||||
|
||||
/// One or more [`Buffers`](Buffer) being edited in a single view.
|
||||
///
|
||||
@@ -7202,7 +7202,7 @@ impl ExcerptId {
|
||||
}
|
||||
|
||||
pub fn max() -> Self {
|
||||
Self(u32::MAX)
|
||||
Self(usize::MAX)
|
||||
}
|
||||
|
||||
pub fn to_proto(self) -> u64 {
|
||||
@@ -7222,7 +7222,7 @@ impl ExcerptId {
|
||||
|
||||
impl From<ExcerptId> for usize {
|
||||
fn from(val: ExcerptId) -> Self {
|
||||
val.0 as usize
|
||||
val.0
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -328,12 +328,7 @@ impl PickerDelegate for RecentProjectsDelegate {
|
||||
&Default::default(),
|
||||
cx.background_executor().clone(),
|
||||
));
|
||||
self.matches.sort_unstable_by(|a, b| {
|
||||
b.score
|
||||
.partial_cmp(&a.score) // Descending score
|
||||
.unwrap_or(std::cmp::Ordering::Equal)
|
||||
.then_with(|| a.candidate_id.cmp(&b.candidate_id)) // Ascending candidate_id for ties
|
||||
});
|
||||
self.matches.sort_unstable_by_key(|m| m.candidate_id);
|
||||
|
||||
if self.reset_selected_match_index {
|
||||
self.selected_match_index = self
|
||||
|
||||
@@ -349,7 +349,6 @@ pub fn execute_run(
|
||||
.spawn(crashes::init(crashes::InitCrashHandler {
|
||||
session_id: id,
|
||||
zed_version: VERSION.to_owned(),
|
||||
binary: "zed-remote-server".to_string(),
|
||||
release_channel: release_channel::RELEASE_CHANNEL_NAME.clone(),
|
||||
commit_sha: option_env!("ZED_COMMIT_SHA").unwrap_or("no_sha").to_owned(),
|
||||
}))
|
||||
@@ -544,7 +543,6 @@ pub(crate) fn execute_proxy(
|
||||
smol::spawn(crashes::init(crashes::InitCrashHandler {
|
||||
session_id: id,
|
||||
zed_version: VERSION.to_owned(),
|
||||
binary: "zed-remote-server".to_string(),
|
||||
release_channel: release_channel::RELEASE_CHANNEL_NAME.clone(),
|
||||
commit_sha: option_env!("ZED_COMMIT_SHA").unwrap_or("no_sha").to_owned(),
|
||||
}))
|
||||
|
||||
@@ -2216,27 +2216,6 @@ mod tests {
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_rayon_stack_overflow() {
|
||||
let sz = 2usize.pow(30);
|
||||
let layout = std::alloc::Layout::from_size_align(sz, 16).unwrap();
|
||||
// SAFETY: Size is nonzero.
|
||||
let massive_alloc = unsafe { std::alloc::alloc(layout) };
|
||||
// SAFETY: `sz` is the same as used for the allocation and the pointer is unaliased.
|
||||
unsafe { massive_alloc.write_bytes(b'A', sz) };
|
||||
// SAFETY: `massive_alloc` is initialised by the write above for `sz * sizeof(u8)`.
|
||||
let text =
|
||||
unsafe { std::str::from_utf8_unchecked(std::slice::from_raw_parts(massive_alloc, sz)) };
|
||||
let pool: rayon::ThreadPool = rayon::ThreadPoolBuilder::new()
|
||||
.num_threads(1)
|
||||
.build()
|
||||
.unwrap();
|
||||
pool.install(|| {
|
||||
let mut rope = Rope::new();
|
||||
rope.push_large(text);
|
||||
});
|
||||
}
|
||||
|
||||
fn clip_offset(text: &str, mut offset: usize, bias: Bias) -> usize {
|
||||
while !text.is_char_boundary(offset) {
|
||||
match bias {
|
||||
|
||||
@@ -22,6 +22,7 @@ gpui.workspace = true
|
||||
language.workspace = true
|
||||
log.workspace = true
|
||||
postage.workspace = true
|
||||
project.workspace = true
|
||||
serde.workspace = true
|
||||
serde_json.workspace = true
|
||||
settings.workspace = true
|
||||
|
||||
@@ -4,6 +4,7 @@ use edit_prediction::{Direction, EditPrediction, EditPredictionProvider};
|
||||
use futures::StreamExt as _;
|
||||
use gpui::{App, Context, Entity, EntityId, Task};
|
||||
use language::{Anchor, Buffer, BufferSnapshot};
|
||||
use project::Project;
|
||||
use std::{
|
||||
ops::{AddAssign, Range},
|
||||
path::Path,
|
||||
@@ -93,7 +94,7 @@ fn completion_from_diff(
|
||||
edits.push((edit_range, edit_text));
|
||||
}
|
||||
|
||||
EditPrediction::Local {
|
||||
EditPrediction {
|
||||
id: None,
|
||||
edits,
|
||||
edit_preview: None,
|
||||
@@ -131,6 +132,7 @@ impl EditPredictionProvider for SupermavenCompletionProvider {
|
||||
|
||||
fn refresh(
|
||||
&mut self,
|
||||
_project: Option<Entity<Project>>,
|
||||
buffer_handle: Entity<Buffer>,
|
||||
cursor_position: Anchor,
|
||||
debounce: bool,
|
||||
|
||||
@@ -62,23 +62,26 @@ impl Anchor {
|
||||
}
|
||||
|
||||
pub fn bias(&self, bias: Bias, buffer: &BufferSnapshot) -> Anchor {
|
||||
match bias {
|
||||
Bias::Left => self.bias_left(buffer),
|
||||
Bias::Right => self.bias_right(buffer),
|
||||
if bias == Bias::Left {
|
||||
self.bias_left(buffer)
|
||||
} else {
|
||||
self.bias_right(buffer)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn bias_left(&self, buffer: &BufferSnapshot) -> Anchor {
|
||||
match self.bias {
|
||||
Bias::Left => *self,
|
||||
Bias::Right => buffer.anchor_before(self),
|
||||
if self.bias == Bias::Left {
|
||||
*self
|
||||
} else {
|
||||
buffer.anchor_before(self)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn bias_right(&self, buffer: &BufferSnapshot) -> Anchor {
|
||||
match self.bias {
|
||||
Bias::Left => buffer.anchor_after(self),
|
||||
Bias::Right => *self,
|
||||
if self.bias == Bias::Right {
|
||||
*self
|
||||
} else {
|
||||
buffer.anchor_after(self)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -93,7 +96,7 @@ impl Anchor {
|
||||
pub fn is_valid(&self, buffer: &BufferSnapshot) -> bool {
|
||||
if *self == Anchor::MIN || *self == Anchor::MAX {
|
||||
true
|
||||
} else if self.buffer_id.is_none_or(|id| id != buffer.remote_id) {
|
||||
} else if self.buffer_id != Some(buffer.remote_id) {
|
||||
false
|
||||
} else {
|
||||
let Some(fragment_id) = buffer.try_fragment_id_for_anchor(self) else {
|
||||
|
||||
@@ -1047,8 +1047,8 @@ impl PickerDelegate for ToolchainSelectorDelegate {
|
||||
let toolchain = toolchain.clone();
|
||||
let scope = scope.clone();
|
||||
|
||||
this.end_slot(IconButton::new(id, IconName::Trash).on_click(cx.listener(
|
||||
move |this, _, _, cx| {
|
||||
this.end_slot(IconButton::new(id, IconName::Trash))
|
||||
.on_click(cx.listener(move |this, _, _, cx| {
|
||||
this.delegate.project.update(cx, |this, cx| {
|
||||
this.remove_toolchain(toolchain.clone(), scope.clone(), cx)
|
||||
});
|
||||
@@ -1076,8 +1076,7 @@ impl PickerDelegate for ToolchainSelectorDelegate {
|
||||
}
|
||||
cx.stop_propagation();
|
||||
cx.notify();
|
||||
},
|
||||
)))
|
||||
}))
|
||||
}),
|
||||
)
|
||||
}
|
||||
|
||||
@@ -997,13 +997,6 @@ impl WorkspaceDb {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
conn.exec_bound(
|
||||
sql!(
|
||||
DELETE FROM user_toolchains WHERE workspace_id = ?1;
|
||||
)
|
||||
)?(workspace.id).context("Clearing old user toolchains")?;
|
||||
|
||||
for (scope, toolchains) in workspace.user_toolchains {
|
||||
for toolchain in toolchains {
|
||||
let query = sql!(INSERT OR REPLACE INTO user_toolchains(remote_connection_id, workspace_id, worktree_id, relative_worktree_path, language_name, name, path, raw_json) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8));
|
||||
|
||||
@@ -270,7 +270,6 @@ pub fn main() {
|
||||
.spawn(crashes::init(InitCrashHandler {
|
||||
session_id: session_id.clone(),
|
||||
zed_version: app_version.to_string(),
|
||||
binary: "zed".to_string(),
|
||||
release_channel: release_channel::RELEASE_CHANNEL_NAME.clone(),
|
||||
commit_sha: app_commit_sha
|
||||
.as_ref()
|
||||
|
||||
@@ -330,7 +330,6 @@ async fn upload_minidump(
|
||||
metadata.init.release_channel.clone(),
|
||||
)
|
||||
.text("sentry[tags][version]", metadata.init.zed_version.clone())
|
||||
.text("sentry[tags][binary]", metadata.init.binary.clone())
|
||||
.text("sentry[release]", metadata.init.commit_sha.clone())
|
||||
.text("platform", "rust");
|
||||
let mut panic_message = "".to_owned();
|
||||
|
||||
@@ -205,48 +205,42 @@ fn assign_edit_prediction_provider(
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(project) = editor.project() {
|
||||
if std::env::var("ZED_ZETA2").is_ok() {
|
||||
let zeta = zeta2::Zeta::global(client, &user_store, cx);
|
||||
let provider = cx.new(|cx| {
|
||||
zeta2::ZetaEditPredictionProvider::new(
|
||||
project.clone(),
|
||||
&client,
|
||||
&user_store,
|
||||
cx,
|
||||
)
|
||||
if std::env::var("ZED_ZETA2").is_ok() {
|
||||
let zeta = zeta2::Zeta::global(client, &user_store, cx);
|
||||
let provider = cx.new(|cx| {
|
||||
zeta2::ZetaEditPredictionProvider::new(
|
||||
editor.project(),
|
||||
&client,
|
||||
&user_store,
|
||||
cx,
|
||||
)
|
||||
});
|
||||
|
||||
if let Some(buffer) = &singleton_buffer
|
||||
&& buffer.read(cx).file().is_some()
|
||||
&& let Some(project) = editor.project()
|
||||
{
|
||||
zeta.update(cx, |zeta, cx| {
|
||||
zeta.register_buffer(buffer, project, cx);
|
||||
});
|
||||
|
||||
// TODO [zeta2] handle multibuffers
|
||||
if let Some(buffer) = &singleton_buffer
|
||||
&& buffer.read(cx).file().is_some()
|
||||
{
|
||||
zeta.update(cx, |zeta, cx| {
|
||||
zeta.register_buffer(buffer, project, cx);
|
||||
});
|
||||
}
|
||||
|
||||
editor.set_edit_prediction_provider(Some(provider), window, cx);
|
||||
} else {
|
||||
let zeta = zeta::Zeta::register(worktree, client.clone(), user_store, cx);
|
||||
|
||||
if let Some(buffer) = &singleton_buffer
|
||||
&& buffer.read(cx).file().is_some()
|
||||
{
|
||||
zeta.update(cx, |zeta, cx| {
|
||||
zeta.register_buffer(buffer, project, cx);
|
||||
});
|
||||
}
|
||||
|
||||
let provider = cx.new(|_| {
|
||||
zeta::ZetaEditPredictionProvider::new(
|
||||
zeta,
|
||||
project.clone(),
|
||||
singleton_buffer,
|
||||
)
|
||||
});
|
||||
editor.set_edit_prediction_provider(Some(provider), window, cx);
|
||||
}
|
||||
|
||||
editor.set_edit_prediction_provider(Some(provider), window, cx);
|
||||
} else {
|
||||
let zeta = zeta::Zeta::register(worktree, client.clone(), user_store, cx);
|
||||
|
||||
if let Some(buffer) = &singleton_buffer
|
||||
&& buffer.read(cx).file().is_some()
|
||||
&& let Some(project) = editor.project()
|
||||
{
|
||||
zeta.update(cx, |zeta, cx| {
|
||||
zeta.register_buffer(buffer, project, cx);
|
||||
});
|
||||
}
|
||||
|
||||
let provider =
|
||||
cx.new(|_| zeta::ZetaEditPredictionProvider::new(zeta, singleton_buffer));
|
||||
editor.set_edit_prediction_provider(Some(provider), window, cx);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1316,17 +1316,12 @@ pub struct ZetaEditPredictionProvider {
|
||||
next_pending_completion_id: usize,
|
||||
current_completion: Option<CurrentEditPrediction>,
|
||||
last_request_timestamp: Instant,
|
||||
project: Entity<Project>,
|
||||
}
|
||||
|
||||
impl ZetaEditPredictionProvider {
|
||||
pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
|
||||
|
||||
pub fn new(
|
||||
zeta: Entity<Zeta>,
|
||||
project: Entity<Project>,
|
||||
singleton_buffer: Option<Entity<Buffer>>,
|
||||
) -> Self {
|
||||
pub fn new(zeta: Entity<Zeta>, singleton_buffer: Option<Entity<Buffer>>) -> Self {
|
||||
Self {
|
||||
zeta,
|
||||
singleton_buffer,
|
||||
@@ -1334,7 +1329,6 @@ impl ZetaEditPredictionProvider {
|
||||
next_pending_completion_id: 0,
|
||||
current_completion: None,
|
||||
last_request_timestamp: Instant::now(),
|
||||
project,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1400,6 +1394,7 @@ impl edit_prediction::EditPredictionProvider for ZetaEditPredictionProvider {
|
||||
|
||||
fn refresh(
|
||||
&mut self,
|
||||
project: Option<Entity<Project>>,
|
||||
buffer: Entity<Buffer>,
|
||||
position: language::Anchor,
|
||||
_debounce: bool,
|
||||
@@ -1408,6 +1403,9 @@ impl edit_prediction::EditPredictionProvider for ZetaEditPredictionProvider {
|
||||
if self.zeta.read(cx).update_required {
|
||||
return;
|
||||
}
|
||||
let Some(project) = project else {
|
||||
return;
|
||||
};
|
||||
|
||||
if self
|
||||
.zeta
|
||||
@@ -1435,7 +1433,6 @@ impl edit_prediction::EditPredictionProvider for ZetaEditPredictionProvider {
|
||||
self.next_pending_completion_id += 1;
|
||||
let last_request_timestamp = self.last_request_timestamp;
|
||||
|
||||
let project = self.project.clone();
|
||||
let task = cx.spawn(async move |this, cx| {
|
||||
if let Some(timeout) = (last_request_timestamp + Self::THROTTLE_TIMEOUT)
|
||||
.checked_duration_since(Instant::now())
|
||||
@@ -1607,7 +1604,7 @@ impl edit_prediction::EditPredictionProvider for ZetaEditPredictionProvider {
|
||||
}
|
||||
}
|
||||
|
||||
Some(edit_prediction::EditPrediction::Local {
|
||||
Some(edit_prediction::EditPrediction {
|
||||
id: Some(completion.id.to_string().into()),
|
||||
edits: edits[edit_start_ix..edit_end_ix].to_vec(),
|
||||
edit_preview: Some(completion.edit_preview.clone()),
|
||||
|
||||
@@ -1,18 +1,35 @@
|
||||
use std::{borrow::Cow, ops::Range, path::Path, sync::Arc};
|
||||
use std::{borrow::Cow, ops::Range, sync::Arc};
|
||||
|
||||
use anyhow::Context as _;
|
||||
use cloud_llm_client::predict_edits_v3;
|
||||
use gpui::{App, AsyncApp, Entity};
|
||||
use language::{
|
||||
Anchor, Buffer, BufferSnapshot, EditPreview, OffsetRangeExt, TextBufferSnapshot, text_diff,
|
||||
};
|
||||
use project::Project;
|
||||
use util::ResultExt;
|
||||
use language::{Anchor, BufferSnapshot, EditPreview, OffsetRangeExt, text_diff};
|
||||
use uuid::Uuid;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct EditPrediction {
|
||||
pub id: EditPredictionId,
|
||||
pub edits: Arc<[(Range<Anchor>, String)]>,
|
||||
pub snapshot: BufferSnapshot,
|
||||
pub edit_preview: EditPreview,
|
||||
}
|
||||
|
||||
impl EditPrediction {
|
||||
pub fn interpolate(
|
||||
&self,
|
||||
new_snapshot: &BufferSnapshot,
|
||||
) -> Option<Vec<(Range<Anchor>, String)>> {
|
||||
interpolate_edits(&self.snapshot, new_snapshot, self.edits.clone())
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Hash)]
|
||||
pub struct EditPredictionId(Uuid);
|
||||
|
||||
impl From<Uuid> for EditPredictionId {
|
||||
fn from(value: Uuid) -> Self {
|
||||
EditPredictionId(value)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<EditPredictionId> for gpui::ElementId {
|
||||
fn from(value: EditPredictionId) -> Self {
|
||||
gpui::ElementId::Uuid(value.0)
|
||||
@@ -25,122 +42,9 @@ impl std::fmt::Display for EditPredictionId {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct EditPrediction {
|
||||
pub id: EditPredictionId,
|
||||
pub path: Arc<Path>,
|
||||
pub edits: Arc<[(Range<Anchor>, String)]>,
|
||||
pub snapshot: BufferSnapshot,
|
||||
pub edit_preview: EditPreview,
|
||||
// We keep a reference to the buffer so that we do not need to reload it from disk when applying the prediction.
|
||||
_buffer: Entity<Buffer>,
|
||||
}
|
||||
|
||||
impl EditPrediction {
|
||||
pub async fn from_response(
|
||||
response: predict_edits_v3::PredictEditsResponse,
|
||||
active_buffer_old_snapshot: &TextBufferSnapshot,
|
||||
active_buffer: &Entity<Buffer>,
|
||||
project: &Entity<Project>,
|
||||
cx: &mut AsyncApp,
|
||||
) -> Option<Self> {
|
||||
// TODO only allow cloud to return one path
|
||||
let Some(path) = response.edits.first().map(|e| e.path.clone()) else {
|
||||
return None;
|
||||
};
|
||||
|
||||
let is_same_path = active_buffer
|
||||
.read_with(cx, |buffer, cx| buffer_path_eq(buffer, &path, cx))
|
||||
.ok()?;
|
||||
|
||||
let (buffer, edits, snapshot, edit_preview_task) = if is_same_path {
|
||||
active_buffer
|
||||
.read_with(cx, |buffer, cx| {
|
||||
let new_snapshot = buffer.snapshot();
|
||||
let edits = edits_from_response(&response.edits, &active_buffer_old_snapshot);
|
||||
let edits: Arc<[_]> =
|
||||
interpolate_edits(active_buffer_old_snapshot, &new_snapshot, edits)?.into();
|
||||
|
||||
Some((
|
||||
active_buffer.clone(),
|
||||
edits.clone(),
|
||||
new_snapshot,
|
||||
buffer.preview_edits(edits, cx),
|
||||
))
|
||||
})
|
||||
.ok()??
|
||||
} else {
|
||||
let buffer_handle = project
|
||||
.update(cx, |project, cx| {
|
||||
let project_path = project
|
||||
.find_project_path(&path, cx)
|
||||
.context("Failed to find project path for zeta edit")?;
|
||||
anyhow::Ok(project.open_buffer(project_path, cx))
|
||||
})
|
||||
.ok()?
|
||||
.log_err()?
|
||||
.await
|
||||
.context("Failed to open buffer for zeta edit")
|
||||
.log_err()?;
|
||||
|
||||
buffer_handle
|
||||
.read_with(cx, |buffer, cx| {
|
||||
let snapshot = buffer.snapshot();
|
||||
let edits = edits_from_response(&response.edits, &snapshot);
|
||||
if edits.is_empty() {
|
||||
return None;
|
||||
}
|
||||
Some((
|
||||
buffer_handle.clone(),
|
||||
edits.clone(),
|
||||
snapshot,
|
||||
buffer.preview_edits(edits, cx),
|
||||
))
|
||||
})
|
||||
.ok()??
|
||||
};
|
||||
|
||||
let edit_preview = edit_preview_task.await;
|
||||
|
||||
Some(EditPrediction {
|
||||
id: EditPredictionId(response.request_id),
|
||||
path,
|
||||
edits,
|
||||
snapshot,
|
||||
edit_preview,
|
||||
_buffer: buffer,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn interpolate(
|
||||
&self,
|
||||
new_snapshot: &TextBufferSnapshot,
|
||||
) -> Option<Vec<(Range<Anchor>, String)>> {
|
||||
interpolate_edits(&self.snapshot, new_snapshot, self.edits.clone())
|
||||
}
|
||||
|
||||
pub fn targets_buffer(&self, buffer: &Buffer, cx: &App) -> bool {
|
||||
buffer_path_eq(buffer, &self.path, cx)
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for EditPrediction {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("EditPrediction")
|
||||
.field("id", &self.id)
|
||||
.field("path", &self.path)
|
||||
.field("edits", &self.edits)
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
pub fn buffer_path_eq(buffer: &Buffer, path: &Path, cx: &App) -> bool {
|
||||
buffer.file().map(|p| p.full_path(cx)).as_deref() == Some(path)
|
||||
}
|
||||
|
||||
pub fn interpolate_edits(
|
||||
old_snapshot: &TextBufferSnapshot,
|
||||
new_snapshot: &TextBufferSnapshot,
|
||||
old_snapshot: &BufferSnapshot,
|
||||
new_snapshot: &BufferSnapshot,
|
||||
current_edits: Arc<[(Range<Anchor>, String)]>,
|
||||
) -> Option<Vec<(Range<Anchor>, String)>> {
|
||||
let mut edits = Vec::new();
|
||||
@@ -184,13 +88,14 @@ pub fn interpolate_edits(
|
||||
if edits.is_empty() { None } else { Some(edits) }
|
||||
}
|
||||
|
||||
fn edits_from_response(
|
||||
pub fn edits_from_response(
|
||||
edits: &[predict_edits_v3::Edit],
|
||||
snapshot: &TextBufferSnapshot,
|
||||
snapshot: &BufferSnapshot,
|
||||
) -> Arc<[(Range<Anchor>, String)]> {
|
||||
edits
|
||||
.iter()
|
||||
.flat_map(|edit| {
|
||||
// TODO multi-file edits
|
||||
let old_text = snapshot.text_for_range(edit.range.clone());
|
||||
|
||||
excerpt_edits_from_response(
|
||||
@@ -208,7 +113,7 @@ fn excerpt_edits_from_response(
|
||||
old_text: Cow<str>,
|
||||
new_text: &str,
|
||||
offset: usize,
|
||||
snapshot: &TextBufferSnapshot,
|
||||
snapshot: &BufferSnapshot,
|
||||
) -> impl Iterator<Item = (Range<Anchor>, String)> {
|
||||
text_diff(&old_text, new_text)
|
||||
.into_iter()
|
||||
@@ -316,8 +221,6 @@ mod tests {
|
||||
id: EditPredictionId(Uuid::new_v4()),
|
||||
edits,
|
||||
snapshot: cx.read(|cx| buffer.read(cx).snapshot()),
|
||||
path: Path::new("test.txt").into(),
|
||||
_buffer: buffer.clone(),
|
||||
edit_preview,
|
||||
};
|
||||
|
||||
|
||||
@@ -4,44 +4,76 @@ use std::{
|
||||
time::{Duration, Instant},
|
||||
};
|
||||
|
||||
use anyhow::Context as _;
|
||||
use arrayvec::ArrayVec;
|
||||
use client::{Client, UserStore};
|
||||
use edit_prediction::{DataCollectionState, Direction, EditPredictionProvider};
|
||||
use gpui::{App, Entity, Task, prelude::*};
|
||||
use language::ToPoint as _;
|
||||
use gpui::{App, Entity, EntityId, Task, prelude::*};
|
||||
use language::{BufferSnapshot, ToPoint as _};
|
||||
use project::Project;
|
||||
use util::ResultExt as _;
|
||||
|
||||
use crate::{BufferEditPrediction, Zeta};
|
||||
use crate::{Zeta, prediction::EditPrediction};
|
||||
|
||||
pub struct ZetaEditPredictionProvider {
|
||||
zeta: Entity<Zeta>,
|
||||
current_prediction: Option<CurrentEditPrediction>,
|
||||
next_pending_prediction_id: usize,
|
||||
pending_predictions: ArrayVec<PendingPrediction, 2>,
|
||||
last_request_timestamp: Instant,
|
||||
project: Entity<Project>,
|
||||
}
|
||||
|
||||
impl ZetaEditPredictionProvider {
|
||||
pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
|
||||
|
||||
pub fn new(
|
||||
project: Entity<Project>,
|
||||
project: Option<&Entity<Project>>,
|
||||
client: &Arc<Client>,
|
||||
user_store: &Entity<UserStore>,
|
||||
cx: &mut App,
|
||||
) -> Self {
|
||||
let zeta = Zeta::global(client, user_store, cx);
|
||||
zeta.update(cx, |zeta, cx| {
|
||||
zeta.register_project(&project, cx);
|
||||
});
|
||||
if let Some(project) = project {
|
||||
zeta.update(cx, |zeta, cx| {
|
||||
zeta.register_project(project, cx);
|
||||
});
|
||||
}
|
||||
|
||||
Self {
|
||||
zeta,
|
||||
current_prediction: None,
|
||||
next_pending_prediction_id: 0,
|
||||
pending_predictions: ArrayVec::new(),
|
||||
last_request_timestamp: Instant::now(),
|
||||
project: project,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct CurrentEditPrediction {
|
||||
buffer_id: EntityId,
|
||||
prediction: EditPrediction,
|
||||
}
|
||||
|
||||
impl CurrentEditPrediction {
|
||||
fn should_replace_prediction(&self, old_prediction: &Self, snapshot: &BufferSnapshot) -> bool {
|
||||
if self.buffer_id != old_prediction.buffer_id {
|
||||
return true;
|
||||
}
|
||||
|
||||
let Some(old_edits) = old_prediction.prediction.interpolate(snapshot) else {
|
||||
return true;
|
||||
};
|
||||
let Some(new_edits) = self.prediction.interpolate(snapshot) else {
|
||||
return false;
|
||||
};
|
||||
|
||||
if old_edits.len() == 1 && new_edits.len() == 1 {
|
||||
let (old_range, old_text) = &old_edits[0];
|
||||
let (new_range, new_text) = &new_edits[0];
|
||||
new_range == old_range && new_text.starts_with(old_text)
|
||||
} else {
|
||||
true
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -96,31 +128,42 @@ impl EditPredictionProvider for ZetaEditPredictionProvider {
|
||||
|
||||
fn refresh(
|
||||
&mut self,
|
||||
project: Option<Entity<project::Project>>,
|
||||
buffer: Entity<language::Buffer>,
|
||||
cursor_position: language::Anchor,
|
||||
_debounce: bool,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
let zeta = self.zeta.read(cx);
|
||||
let Some(project) = project else {
|
||||
return;
|
||||
};
|
||||
|
||||
if zeta.user_store.read_with(cx, |user_store, _cx| {
|
||||
user_store.account_too_young() || user_store.has_overdue_invoices()
|
||||
}) {
|
||||
if self
|
||||
.zeta
|
||||
.read(cx)
|
||||
.user_store
|
||||
.read_with(cx, |user_store, _cx| {
|
||||
user_store.account_too_young() || user_store.has_overdue_invoices()
|
||||
})
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
if let Some(current) = zeta.current_prediction_for_buffer(&buffer, &self.project, cx)
|
||||
&& let BufferEditPrediction::Local { prediction } = current
|
||||
&& prediction.interpolate(buffer.read(cx)).is_some()
|
||||
{
|
||||
return;
|
||||
if let Some(current_prediction) = self.current_prediction.as_ref() {
|
||||
let snapshot = buffer.read(cx).snapshot();
|
||||
if current_prediction
|
||||
.prediction
|
||||
.interpolate(&snapshot)
|
||||
.is_some()
|
||||
{
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
let pending_prediction_id = self.next_pending_prediction_id;
|
||||
self.next_pending_prediction_id += 1;
|
||||
let last_request_timestamp = self.last_request_timestamp;
|
||||
|
||||
let project = self.project.clone();
|
||||
let task = cx.spawn(async move |this, cx| {
|
||||
if let Some(timeout) = (last_request_timestamp + Self::THROTTLE_TIMEOUT)
|
||||
.checked_duration_since(Instant::now())
|
||||
@@ -128,16 +171,25 @@ impl EditPredictionProvider for ZetaEditPredictionProvider {
|
||||
cx.background_executor().timer(timeout).await;
|
||||
}
|
||||
|
||||
let refresh_task = this.update(cx, |this, cx| {
|
||||
let prediction_request = this.update(cx, |this, cx| {
|
||||
this.last_request_timestamp = Instant::now();
|
||||
this.zeta.update(cx, |zeta, cx| {
|
||||
zeta.refresh_prediction(&project, &buffer, cursor_position, cx)
|
||||
zeta.request_prediction(&project, &buffer, cursor_position, cx)
|
||||
})
|
||||
});
|
||||
|
||||
if let Some(refresh_task) = refresh_task.ok() {
|
||||
refresh_task.await.log_err();
|
||||
}
|
||||
let prediction = match prediction_request {
|
||||
Ok(prediction_request) => {
|
||||
let prediction_request = prediction_request.await;
|
||||
prediction_request.map(|c| {
|
||||
c.map(|prediction| CurrentEditPrediction {
|
||||
buffer_id: buffer.entity_id(),
|
||||
prediction,
|
||||
})
|
||||
})
|
||||
}
|
||||
Err(error) => Err(error),
|
||||
};
|
||||
|
||||
this.update(cx, |this, cx| {
|
||||
if this.pending_predictions[0].id == pending_prediction_id {
|
||||
@@ -146,6 +198,24 @@ impl EditPredictionProvider for ZetaEditPredictionProvider {
|
||||
this.pending_predictions.clear();
|
||||
}
|
||||
|
||||
let Some(new_prediction) = prediction
|
||||
.context("edit prediction failed")
|
||||
.log_err()
|
||||
.flatten()
|
||||
else {
|
||||
cx.notify();
|
||||
return;
|
||||
};
|
||||
|
||||
if let Some(old_prediction) = this.current_prediction.as_ref() {
|
||||
let snapshot = buffer.read(cx).snapshot();
|
||||
if new_prediction.should_replace_prediction(old_prediction, &snapshot) {
|
||||
this.current_prediction = Some(new_prediction);
|
||||
}
|
||||
} else {
|
||||
this.current_prediction = Some(new_prediction);
|
||||
}
|
||||
|
||||
cx.notify();
|
||||
})
|
||||
.ok();
|
||||
@@ -178,18 +248,15 @@ impl EditPredictionProvider for ZetaEditPredictionProvider {
|
||||
) {
|
||||
}
|
||||
|
||||
fn accept(&mut self, cx: &mut Context<Self>) {
|
||||
self.zeta.update(cx, |zeta, _cx| {
|
||||
zeta.accept_current_prediction(&self.project);
|
||||
});
|
||||
fn accept(&mut self, _cx: &mut Context<Self>) {
|
||||
// TODO [zeta2] report accept
|
||||
self.current_prediction.take();
|
||||
self.pending_predictions.clear();
|
||||
}
|
||||
|
||||
fn discard(&mut self, cx: &mut Context<Self>) {
|
||||
self.zeta.update(cx, |zeta, _cx| {
|
||||
zeta.discard_current_prediction(&self.project);
|
||||
});
|
||||
fn discard(&mut self, _cx: &mut Context<Self>) {
|
||||
self.pending_predictions.clear();
|
||||
self.current_prediction.take();
|
||||
}
|
||||
|
||||
fn suggest(
|
||||
@@ -198,44 +265,36 @@ impl EditPredictionProvider for ZetaEditPredictionProvider {
|
||||
cursor_position: language::Anchor,
|
||||
cx: &mut Context<Self>,
|
||||
) -> Option<edit_prediction::EditPrediction> {
|
||||
let prediction =
|
||||
self.zeta
|
||||
.read(cx)
|
||||
.current_prediction_for_buffer(buffer, &self.project, cx)?;
|
||||
let CurrentEditPrediction {
|
||||
buffer_id,
|
||||
prediction,
|
||||
..
|
||||
} = self.current_prediction.as_mut()?;
|
||||
|
||||
let prediction = match prediction {
|
||||
BufferEditPrediction::Local { prediction } => prediction,
|
||||
BufferEditPrediction::Jump { prediction } => {
|
||||
return Some(edit_prediction::EditPrediction::Jump {
|
||||
id: Some(prediction.id.to_string().into()),
|
||||
snapshot: prediction.snapshot.clone(),
|
||||
target: prediction.edits.first().unwrap().0.start,
|
||||
});
|
||||
}
|
||||
};
|
||||
// Invalidate previous prediction if it was generated for a different buffer.
|
||||
if *buffer_id != buffer.entity_id() {
|
||||
self.current_prediction.take();
|
||||
return None;
|
||||
}
|
||||
|
||||
let buffer = buffer.read(cx);
|
||||
let snapshot = buffer.snapshot();
|
||||
|
||||
let Some(edits) = prediction.interpolate(&snapshot) else {
|
||||
self.zeta.update(cx, |zeta, _cx| {
|
||||
zeta.discard_current_prediction(&self.project);
|
||||
});
|
||||
let Some(edits) = prediction.interpolate(&buffer.snapshot()) else {
|
||||
self.current_prediction.take();
|
||||
return None;
|
||||
};
|
||||
|
||||
let cursor_row = cursor_position.to_point(&snapshot).row;
|
||||
let cursor_row = cursor_position.to_point(buffer).row;
|
||||
let (closest_edit_ix, (closest_edit_range, _)) =
|
||||
edits.iter().enumerate().min_by_key(|(_, (range, _))| {
|
||||
let distance_from_start = cursor_row.abs_diff(range.start.to_point(&snapshot).row);
|
||||
let distance_from_end = cursor_row.abs_diff(range.end.to_point(&snapshot).row);
|
||||
let distance_from_start = cursor_row.abs_diff(range.start.to_point(buffer).row);
|
||||
let distance_from_end = cursor_row.abs_diff(range.end.to_point(buffer).row);
|
||||
cmp::min(distance_from_start, distance_from_end)
|
||||
})?;
|
||||
|
||||
let mut edit_start_ix = closest_edit_ix;
|
||||
for (range, _) in edits[..edit_start_ix].iter().rev() {
|
||||
let distance_from_closest_edit = closest_edit_range.start.to_point(&snapshot).row
|
||||
- range.end.to_point(&snapshot).row;
|
||||
let distance_from_closest_edit =
|
||||
closest_edit_range.start.to_point(buffer).row - range.end.to_point(buffer).row;
|
||||
if distance_from_closest_edit <= 1 {
|
||||
edit_start_ix -= 1;
|
||||
} else {
|
||||
@@ -246,7 +305,7 @@ impl EditPredictionProvider for ZetaEditPredictionProvider {
|
||||
let mut edit_end_ix = closest_edit_ix + 1;
|
||||
for (range, _) in &edits[edit_end_ix..] {
|
||||
let distance_from_closest_edit =
|
||||
range.start.to_point(buffer).row - closest_edit_range.end.to_point(&snapshot).row;
|
||||
range.start.to_point(buffer).row - closest_edit_range.end.to_point(buffer).row;
|
||||
if distance_from_closest_edit <= 1 {
|
||||
edit_end_ix += 1;
|
||||
} else {
|
||||
@@ -254,7 +313,7 @@ impl EditPredictionProvider for ZetaEditPredictionProvider {
|
||||
}
|
||||
}
|
||||
|
||||
Some(edit_prediction::EditPrediction::Local {
|
||||
Some(edit_prediction::EditPrediction {
|
||||
id: Some(prediction.id.to_string().into()),
|
||||
edits: edits[edit_start_ix..edit_end_ix].to_vec(),
|
||||
edit_preview: Some(prediction.edit_preview.clone()),
|
||||
|
||||
@@ -17,8 +17,8 @@ use gpui::{
|
||||
App, Entity, EntityId, Global, SemanticVersion, SharedString, Subscription, Task, WeakEntity,
|
||||
http_client, prelude::*,
|
||||
};
|
||||
use language::BufferSnapshot;
|
||||
use language::{Buffer, DiagnosticSet, LanguageServerId, ToOffset as _, ToPoint};
|
||||
use language::{BufferSnapshot, TextBufferSnapshot};
|
||||
use language_model::{LlmApiToken, RefreshLlmTokenListener};
|
||||
use project::Project;
|
||||
use release_channel::AppVersion;
|
||||
@@ -35,7 +35,7 @@ use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_noti
|
||||
mod prediction;
|
||||
mod provider;
|
||||
|
||||
use crate::prediction::EditPrediction;
|
||||
use crate::prediction::{EditPrediction, edits_from_response, interpolate_edits};
|
||||
pub use provider::ZetaEditPredictionProvider;
|
||||
|
||||
const BUFFER_CHANGE_GROUPING_INTERVAL: Duration = Duration::from_secs(1);
|
||||
@@ -53,7 +53,7 @@ pub const DEFAULT_OPTIONS: ZetaOptions = ZetaOptions {
|
||||
excerpt: DEFAULT_EXCERPT_OPTIONS,
|
||||
max_prompt_bytes: DEFAULT_MAX_PROMPT_BYTES,
|
||||
max_diagnostic_bytes: 2048,
|
||||
prompt_format: PromptFormat::DEFAULT,
|
||||
prompt_format: PromptFormat::MarkedExcerpt,
|
||||
};
|
||||
|
||||
#[derive(Clone)]
|
||||
@@ -94,47 +94,6 @@ struct ZetaProject {
|
||||
syntax_index: Entity<SyntaxIndex>,
|
||||
events: VecDeque<Event>,
|
||||
registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
|
||||
current_prediction: Option<CurrentEditPrediction>,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct CurrentEditPrediction {
|
||||
pub requested_by_buffer_id: EntityId,
|
||||
pub prediction: EditPrediction,
|
||||
}
|
||||
|
||||
impl CurrentEditPrediction {
|
||||
fn should_replace_prediction(
|
||||
&self,
|
||||
old_prediction: &Self,
|
||||
snapshot: &TextBufferSnapshot,
|
||||
) -> bool {
|
||||
if self.requested_by_buffer_id != old_prediction.requested_by_buffer_id {
|
||||
return true;
|
||||
}
|
||||
|
||||
let Some(old_edits) = old_prediction.prediction.interpolate(snapshot) else {
|
||||
return true;
|
||||
};
|
||||
|
||||
let Some(new_edits) = self.prediction.interpolate(snapshot) else {
|
||||
return false;
|
||||
};
|
||||
if old_edits.len() == 1 && new_edits.len() == 1 {
|
||||
let (old_range, old_text) = &old_edits[0];
|
||||
let (new_range, new_text) = &new_edits[0];
|
||||
new_range == old_range && new_text.starts_with(old_text)
|
||||
} else {
|
||||
true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// A prediction from the perspective of a buffer.
|
||||
#[derive(Debug)]
|
||||
enum BufferEditPrediction<'a> {
|
||||
Local { prediction: &'a EditPrediction },
|
||||
Jump { prediction: &'a EditPrediction },
|
||||
}
|
||||
|
||||
struct RegisteredBuffer {
|
||||
@@ -245,7 +204,6 @@ impl Zeta {
|
||||
syntax_index: cx.new(|cx| SyntaxIndex::new(project, cx)),
|
||||
events: VecDeque::new(),
|
||||
registered_buffers: HashMap::new(),
|
||||
current_prediction: None,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -347,83 +305,7 @@ impl Zeta {
|
||||
events.push_back(event);
|
||||
}
|
||||
|
||||
fn current_prediction_for_buffer(
|
||||
&self,
|
||||
buffer: &Entity<Buffer>,
|
||||
project: &Entity<Project>,
|
||||
cx: &App,
|
||||
) -> Option<BufferEditPrediction<'_>> {
|
||||
let project_state = self.projects.get(&project.entity_id())?;
|
||||
|
||||
let CurrentEditPrediction {
|
||||
requested_by_buffer_id,
|
||||
prediction,
|
||||
} = project_state.current_prediction.as_ref()?;
|
||||
|
||||
if prediction.targets_buffer(buffer.read(cx), cx) {
|
||||
Some(BufferEditPrediction::Local { prediction })
|
||||
} else if *requested_by_buffer_id == buffer.entity_id() {
|
||||
Some(BufferEditPrediction::Jump { prediction })
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
fn accept_current_prediction(&mut self, project: &Entity<Project>) {
|
||||
if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
|
||||
project_state.current_prediction.take();
|
||||
};
|
||||
// TODO report accepted
|
||||
}
|
||||
|
||||
fn discard_current_prediction(&mut self, project: &Entity<Project>) {
|
||||
if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
|
||||
project_state.current_prediction.take();
|
||||
};
|
||||
}
|
||||
|
||||
pub fn refresh_prediction(
|
||||
&mut self,
|
||||
project: &Entity<Project>,
|
||||
buffer: &Entity<Buffer>,
|
||||
position: language::Anchor,
|
||||
cx: &mut Context<Self>,
|
||||
) -> Task<Result<()>> {
|
||||
let request_task = self.request_prediction(project, buffer, position, cx);
|
||||
let buffer = buffer.clone();
|
||||
let project = project.clone();
|
||||
|
||||
cx.spawn(async move |this, cx| {
|
||||
if let Some(prediction) = request_task.await? {
|
||||
this.update(cx, |this, cx| {
|
||||
let project_state = this
|
||||
.projects
|
||||
.get_mut(&project.entity_id())
|
||||
.context("Project not found")?;
|
||||
|
||||
let new_prediction = CurrentEditPrediction {
|
||||
requested_by_buffer_id: buffer.entity_id(),
|
||||
prediction: prediction,
|
||||
};
|
||||
|
||||
if project_state
|
||||
.current_prediction
|
||||
.as_ref()
|
||||
.is_none_or(|old_prediction| {
|
||||
new_prediction
|
||||
.should_replace_prediction(&old_prediction, buffer.read(cx))
|
||||
})
|
||||
{
|
||||
project_state.current_prediction = Some(new_prediction);
|
||||
}
|
||||
anyhow::Ok(())
|
||||
})??;
|
||||
}
|
||||
Ok(())
|
||||
})
|
||||
}
|
||||
|
||||
fn request_prediction(
|
||||
pub fn request_prediction(
|
||||
&mut self,
|
||||
project: &Entity<Project>,
|
||||
buffer: &Entity<Buffer>,
|
||||
@@ -457,7 +339,7 @@ impl Zeta {
|
||||
state
|
||||
.events
|
||||
.iter()
|
||||
.filter_map(|event| match event {
|
||||
.map(|event| match event {
|
||||
Event::BufferChange {
|
||||
old_snapshot,
|
||||
new_snapshot,
|
||||
@@ -474,20 +356,15 @@ impl Zeta {
|
||||
}
|
||||
});
|
||||
|
||||
// TODO [zeta2] move to bg?
|
||||
let diff =
|
||||
language::unified_diff(&old_snapshot.text(), &new_snapshot.text());
|
||||
|
||||
if path == old_path && diff.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(predict_edits_v3::Event::BufferChange {
|
||||
old_path,
|
||||
path,
|
||||
diff,
|
||||
//todo: Actually detect if this edit was predicted or not
|
||||
predicted: false,
|
||||
})
|
||||
predict_edits_v3::Event::BufferChange {
|
||||
old_path,
|
||||
path,
|
||||
diff: language::unified_diff(
|
||||
&old_snapshot.text(),
|
||||
&new_snapshot.text(),
|
||||
),
|
||||
//todo: Actually detect if this edit was predicted or not
|
||||
predicted: false,
|
||||
}
|
||||
}
|
||||
})
|
||||
@@ -575,63 +452,74 @@ impl Zeta {
|
||||
.ok();
|
||||
}
|
||||
|
||||
anyhow::Ok(Some(response?))
|
||||
let (response, usage) = response?;
|
||||
let edits = edits_from_response(&response.edits, &snapshot);
|
||||
|
||||
anyhow::Ok(Some((response.request_id, edits, usage)))
|
||||
}
|
||||
});
|
||||
|
||||
let buffer = buffer.clone();
|
||||
|
||||
cx.spawn({
|
||||
let project = project.clone();
|
||||
async move |this, cx| {
|
||||
match request_task.await {
|
||||
Ok(Some((response, usage))) => {
|
||||
if let Some(usage) = usage {
|
||||
this.update(cx, |this, cx| {
|
||||
this.user_store.update(cx, |user_store, cx| {
|
||||
user_store.update_edit_prediction_usage(usage, cx);
|
||||
});
|
||||
cx.spawn(async move |this, cx| {
|
||||
match request_task.await {
|
||||
Ok(Some((id, edits, usage))) => {
|
||||
if let Some(usage) = usage {
|
||||
this.update(cx, |this, cx| {
|
||||
this.user_store.update(cx, |user_store, cx| {
|
||||
user_store.update_edit_prediction_usage(usage, cx);
|
||||
});
|
||||
})
|
||||
.ok();
|
||||
}
|
||||
|
||||
// TODO telemetry: duration, etc
|
||||
let Some((edits, snapshot, edit_preview_task)) =
|
||||
buffer.read_with(cx, |buffer, cx| {
|
||||
let new_snapshot = buffer.snapshot();
|
||||
let edits: Arc<[_]> =
|
||||
interpolate_edits(&snapshot, &new_snapshot, edits)?.into();
|
||||
Some((edits.clone(), new_snapshot, buffer.preview_edits(edits, cx)))
|
||||
})?
|
||||
else {
|
||||
return Ok(None);
|
||||
};
|
||||
|
||||
Ok(Some(EditPrediction {
|
||||
id: id.into(),
|
||||
edits,
|
||||
snapshot,
|
||||
edit_preview: edit_preview_task.await,
|
||||
}))
|
||||
}
|
||||
Ok(None) => Ok(None),
|
||||
Err(err) => {
|
||||
if err.is::<ZedUpdateRequiredError>() {
|
||||
cx.update(|cx| {
|
||||
this.update(cx, |this, _cx| {
|
||||
this.update_required = true;
|
||||
})
|
||||
.ok();
|
||||
}
|
||||
|
||||
let prediction = EditPrediction::from_response(
|
||||
response, &snapshot, &buffer, &project, cx,
|
||||
)
|
||||
.await;
|
||||
|
||||
// TODO telemetry: duration, etc
|
||||
Ok(prediction)
|
||||
let error_message: SharedString = err.to_string().into();
|
||||
show_app_notification(
|
||||
NotificationId::unique::<ZedUpdateRequiredError>(),
|
||||
cx,
|
||||
move |cx| {
|
||||
cx.new(|cx| {
|
||||
ErrorMessagePrompt::new(error_message.clone(), cx)
|
||||
.with_link_button(
|
||||
"Update Zed",
|
||||
"https://zed.dev/releases",
|
||||
)
|
||||
})
|
||||
},
|
||||
);
|
||||
})
|
||||
.ok();
|
||||
}
|
||||
Ok(None) => Ok(None),
|
||||
Err(err) => {
|
||||
if err.is::<ZedUpdateRequiredError>() {
|
||||
cx.update(|cx| {
|
||||
this.update(cx, |this, _cx| {
|
||||
this.update_required = true;
|
||||
})
|
||||
.ok();
|
||||
|
||||
let error_message: SharedString = err.to_string().into();
|
||||
show_app_notification(
|
||||
NotificationId::unique::<ZedUpdateRequiredError>(),
|
||||
cx,
|
||||
move |cx| {
|
||||
cx.new(|cx| {
|
||||
ErrorMessagePrompt::new(error_message.clone(), cx)
|
||||
.with_link_button(
|
||||
"Update Zed",
|
||||
"https://zed.dev/releases",
|
||||
)
|
||||
})
|
||||
},
|
||||
);
|
||||
})
|
||||
.ok();
|
||||
}
|
||||
|
||||
Err(err)
|
||||
}
|
||||
Err(err)
|
||||
}
|
||||
}
|
||||
})
|
||||
@@ -966,113 +854,13 @@ mod tests {
|
||||
};
|
||||
use indoc::indoc;
|
||||
use language::{LanguageServerId, OffsetRangeExt as _};
|
||||
use pretty_assertions::{assert_eq, assert_matches};
|
||||
use project::{FakeFs, Project};
|
||||
use serde_json::json;
|
||||
use settings::SettingsStore;
|
||||
use util::path;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::{BufferEditPrediction, Zeta};
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_current_state(cx: &mut TestAppContext) {
|
||||
let (zeta, mut req_rx) = init_test(cx);
|
||||
let fs = FakeFs::new(cx.executor());
|
||||
fs.insert_tree(
|
||||
"/root",
|
||||
json!({
|
||||
"1.txt": "Hello!\nHow\nBye",
|
||||
"2.txt": "Hola!\nComo\nAdios"
|
||||
}),
|
||||
)
|
||||
.await;
|
||||
let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
|
||||
|
||||
zeta.update(cx, |zeta, cx| {
|
||||
zeta.register_project(&project, cx);
|
||||
});
|
||||
|
||||
let buffer1 = project
|
||||
.update(cx, |project, cx| {
|
||||
let path = project.find_project_path(path!("root/1.txt"), cx).unwrap();
|
||||
project.open_buffer(path, cx)
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
let snapshot1 = buffer1.read_with(cx, |buffer, _cx| buffer.snapshot());
|
||||
let position = snapshot1.anchor_before(language::Point::new(1, 3));
|
||||
|
||||
// Prediction for current file
|
||||
|
||||
let prediction_task = zeta.update(cx, |zeta, cx| {
|
||||
zeta.refresh_prediction(&project, &buffer1, position, cx)
|
||||
});
|
||||
let (_request, respond_tx) = req_rx.next().await.unwrap();
|
||||
respond_tx
|
||||
.send(predict_edits_v3::PredictEditsResponse {
|
||||
request_id: Uuid::new_v4(),
|
||||
edits: vec![predict_edits_v3::Edit {
|
||||
path: Path::new(path!("root/1.txt")).into(),
|
||||
range: 0..snapshot1.len(),
|
||||
content: "Hello!\nHow are you?\nBye".into(),
|
||||
}],
|
||||
debug_info: None,
|
||||
})
|
||||
.unwrap();
|
||||
prediction_task.await.unwrap();
|
||||
|
||||
zeta.read_with(cx, |zeta, cx| {
|
||||
let prediction = zeta
|
||||
.current_prediction_for_buffer(&buffer1, &project, cx)
|
||||
.unwrap();
|
||||
assert_matches!(prediction, BufferEditPrediction::Local { .. });
|
||||
});
|
||||
|
||||
// Prediction for another file
|
||||
|
||||
let prediction_task = zeta.update(cx, |zeta, cx| {
|
||||
zeta.refresh_prediction(&project, &buffer1, position, cx)
|
||||
});
|
||||
let (_request, respond_tx) = req_rx.next().await.unwrap();
|
||||
respond_tx
|
||||
.send(predict_edits_v3::PredictEditsResponse {
|
||||
request_id: Uuid::new_v4(),
|
||||
edits: vec![predict_edits_v3::Edit {
|
||||
path: Path::new(path!("root/2.txt")).into(),
|
||||
range: 0..snapshot1.len(),
|
||||
content: "Hola!\nComo estas?\nAdios".into(),
|
||||
}],
|
||||
debug_info: None,
|
||||
})
|
||||
.unwrap();
|
||||
prediction_task.await.unwrap();
|
||||
|
||||
zeta.read_with(cx, |zeta, cx| {
|
||||
let prediction = zeta
|
||||
.current_prediction_for_buffer(&buffer1, &project, cx)
|
||||
.unwrap();
|
||||
assert_matches!(
|
||||
prediction,
|
||||
BufferEditPrediction::Jump { prediction } if prediction.path.as_ref() == Path::new(path!("root/2.txt"))
|
||||
);
|
||||
});
|
||||
|
||||
let buffer2 = project
|
||||
.update(cx, |project, cx| {
|
||||
let path = project.find_project_path(path!("root/2.txt"), cx).unwrap();
|
||||
project.open_buffer(path, cx)
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
zeta.read_with(cx, |zeta, cx| {
|
||||
let prediction = zeta
|
||||
.current_prediction_for_buffer(&buffer2, &project, cx)
|
||||
.unwrap();
|
||||
assert_matches!(prediction, BufferEditPrediction::Local { .. });
|
||||
});
|
||||
}
|
||||
use crate::Zeta;
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_simple_request(cx: &mut TestAppContext) {
|
||||
@@ -1353,7 +1141,6 @@ mod tests {
|
||||
|
||||
let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
|
||||
let zeta = Zeta::global(&client, &user_store, cx);
|
||||
|
||||
(zeta, req_rx)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -185,7 +185,7 @@ impl Zeta2Inspector {
|
||||
cx.background_executor().timer(THROTTLE_TIME).await;
|
||||
if let Some(task) = zeta
|
||||
.update(cx, |zeta, cx| {
|
||||
zeta.refresh_prediction(&project, &buffer, position, cx)
|
||||
zeta.request_prediction(&project, &buffer, position, cx)
|
||||
})
|
||||
.ok()
|
||||
{
|
||||
|
||||
@@ -410,6 +410,7 @@ tower = { version = "0.5", default-features = false, features = ["timeout", "uti
|
||||
[target.x86_64-unknown-linux-gnu.dependencies]
|
||||
aes = { version = "0.8", default-features = false, features = ["zeroize"] }
|
||||
ahash = { version = "0.8", default-features = false, features = ["compile-time-rng"] }
|
||||
ashpd = { version = "0.11", default-features = false, features = ["async-std", "wayland"] }
|
||||
bytemuck = { version = "1", default-features = false, features = ["min_const_generics"] }
|
||||
cipher = { version = "0.4", default-features = false, features = ["block-padding", "rand_core", "zeroize"] }
|
||||
codespan-reporting = { version = "0.12" }
|
||||
@@ -448,12 +449,15 @@ tokio-rustls = { version = "0.26", default-features = false, features = ["loggin
|
||||
tokio-socks = { version = "0.5", features = ["futures-io"] }
|
||||
tokio-stream = { version = "0.1", features = ["fs"] }
|
||||
tower = { version = "0.5", default-features = false, features = ["timeout", "util"] }
|
||||
wayland-backend = { version = "0.3", default-features = false, features = ["client_system", "dlopen"] }
|
||||
wayland-sys = { version = "0.31", default-features = false, features = ["client", "dlopen"] }
|
||||
zeroize = { version = "1", features = ["zeroize_derive"] }
|
||||
zvariant = { version = "5", features = ["enumflags2", "gvariant", "url"] }
|
||||
|
||||
[target.x86_64-unknown-linux-gnu.build-dependencies]
|
||||
aes = { version = "0.8", default-features = false, features = ["zeroize"] }
|
||||
ahash = { version = "0.8", default-features = false, features = ["compile-time-rng"] }
|
||||
ashpd = { version = "0.11", default-features = false, features = ["async-std", "wayland"] }
|
||||
bytemuck = { version = "1", default-features = false, features = ["min_const_generics"] }
|
||||
cipher = { version = "0.4", default-features = false, features = ["block-padding", "rand_core", "zeroize"] }
|
||||
codespan-reporting = { version = "0.12" }
|
||||
@@ -490,12 +494,16 @@ tokio-rustls = { version = "0.26", default-features = false, features = ["loggin
|
||||
tokio-socks = { version = "0.5", features = ["futures-io"] }
|
||||
tokio-stream = { version = "0.1", features = ["fs"] }
|
||||
tower = { version = "0.5", default-features = false, features = ["timeout", "util"] }
|
||||
wayland-backend = { version = "0.3", default-features = false, features = ["client_system", "dlopen"] }
|
||||
wayland-sys = { version = "0.31", default-features = false, features = ["client", "dlopen"] }
|
||||
zbus_macros = { version = "5", features = ["gvariant"] }
|
||||
zeroize = { version = "1", features = ["zeroize_derive"] }
|
||||
zvariant = { version = "5", features = ["enumflags2", "gvariant", "url"] }
|
||||
|
||||
[target.aarch64-unknown-linux-gnu.dependencies]
|
||||
aes = { version = "0.8", default-features = false, features = ["zeroize"] }
|
||||
ahash = { version = "0.8", default-features = false, features = ["compile-time-rng"] }
|
||||
ashpd = { version = "0.11", default-features = false, features = ["async-std", "wayland"] }
|
||||
bytemuck = { version = "1", default-features = false, features = ["min_const_generics"] }
|
||||
cipher = { version = "0.4", default-features = false, features = ["block-padding", "rand_core", "zeroize"] }
|
||||
codespan-reporting = { version = "0.12" }
|
||||
@@ -534,12 +542,15 @@ tokio-rustls = { version = "0.26", default-features = false, features = ["loggin
|
||||
tokio-socks = { version = "0.5", features = ["futures-io"] }
|
||||
tokio-stream = { version = "0.1", features = ["fs"] }
|
||||
tower = { version = "0.5", default-features = false, features = ["timeout", "util"] }
|
||||
wayland-backend = { version = "0.3", default-features = false, features = ["client_system", "dlopen"] }
|
||||
wayland-sys = { version = "0.31", default-features = false, features = ["client", "dlopen"] }
|
||||
zeroize = { version = "1", features = ["zeroize_derive"] }
|
||||
zvariant = { version = "5", features = ["enumflags2", "gvariant", "url"] }
|
||||
|
||||
[target.aarch64-unknown-linux-gnu.build-dependencies]
|
||||
aes = { version = "0.8", default-features = false, features = ["zeroize"] }
|
||||
ahash = { version = "0.8", default-features = false, features = ["compile-time-rng"] }
|
||||
ashpd = { version = "0.11", default-features = false, features = ["async-std", "wayland"] }
|
||||
bytemuck = { version = "1", default-features = false, features = ["min_const_generics"] }
|
||||
cipher = { version = "0.4", default-features = false, features = ["block-padding", "rand_core", "zeroize"] }
|
||||
codespan-reporting = { version = "0.12" }
|
||||
@@ -576,6 +587,9 @@ tokio-rustls = { version = "0.26", default-features = false, features = ["loggin
|
||||
tokio-socks = { version = "0.5", features = ["futures-io"] }
|
||||
tokio-stream = { version = "0.1", features = ["fs"] }
|
||||
tower = { version = "0.5", default-features = false, features = ["timeout", "util"] }
|
||||
wayland-backend = { version = "0.3", default-features = false, features = ["client_system", "dlopen"] }
|
||||
wayland-sys = { version = "0.31", default-features = false, features = ["client", "dlopen"] }
|
||||
zbus_macros = { version = "5", features = ["gvariant"] }
|
||||
zeroize = { version = "1", features = ["zeroize_derive"] }
|
||||
zvariant = { version = "5", features = ["enumflags2", "gvariant", "url"] }
|
||||
|
||||
@@ -635,6 +649,7 @@ windows-sys-d4189bed749088b6 = { package = "windows-sys", version = "0.61", feat
|
||||
[target.x86_64-unknown-linux-musl.dependencies]
|
||||
aes = { version = "0.8", default-features = false, features = ["zeroize"] }
|
||||
ahash = { version = "0.8", default-features = false, features = ["compile-time-rng"] }
|
||||
ashpd = { version = "0.11", default-features = false, features = ["async-std", "wayland"] }
|
||||
bytemuck = { version = "1", default-features = false, features = ["min_const_generics"] }
|
||||
cipher = { version = "0.4", default-features = false, features = ["block-padding", "rand_core", "zeroize"] }
|
||||
codespan-reporting = { version = "0.12" }
|
||||
@@ -673,12 +688,15 @@ tokio-rustls = { version = "0.26", default-features = false, features = ["loggin
|
||||
tokio-socks = { version = "0.5", features = ["futures-io"] }
|
||||
tokio-stream = { version = "0.1", features = ["fs"] }
|
||||
tower = { version = "0.5", default-features = false, features = ["timeout", "util"] }
|
||||
wayland-backend = { version = "0.3", default-features = false, features = ["client_system", "dlopen"] }
|
||||
wayland-sys = { version = "0.31", default-features = false, features = ["client", "dlopen"] }
|
||||
zeroize = { version = "1", features = ["zeroize_derive"] }
|
||||
zvariant = { version = "5", features = ["enumflags2", "gvariant", "url"] }
|
||||
|
||||
[target.x86_64-unknown-linux-musl.build-dependencies]
|
||||
aes = { version = "0.8", default-features = false, features = ["zeroize"] }
|
||||
ahash = { version = "0.8", default-features = false, features = ["compile-time-rng"] }
|
||||
ashpd = { version = "0.11", default-features = false, features = ["async-std", "wayland"] }
|
||||
bytemuck = { version = "1", default-features = false, features = ["min_const_generics"] }
|
||||
cipher = { version = "0.4", default-features = false, features = ["block-padding", "rand_core", "zeroize"] }
|
||||
codespan-reporting = { version = "0.12" }
|
||||
@@ -715,6 +733,9 @@ tokio-rustls = { version = "0.26", default-features = false, features = ["loggin
|
||||
tokio-socks = { version = "0.5", features = ["futures-io"] }
|
||||
tokio-stream = { version = "0.1", features = ["fs"] }
|
||||
tower = { version = "0.5", default-features = false, features = ["timeout", "util"] }
|
||||
wayland-backend = { version = "0.3", default-features = false, features = ["client_system", "dlopen"] }
|
||||
wayland-sys = { version = "0.31", default-features = false, features = ["client", "dlopen"] }
|
||||
zbus_macros = { version = "5", features = ["gvariant"] }
|
||||
zeroize = { version = "1", features = ["zeroize_derive"] }
|
||||
zvariant = { version = "5", features = ["enumflags2", "gvariant", "url"] }
|
||||
|
||||
|
||||
Reference in New Issue
Block a user