Compare commits

...

15 Commits

Author SHA1 Message Date
Agus Zubiaga
bf3c5705e7 Checkpoint: Displaying debug info
Co-Authored-By: Bennet <bennet@zed.dev>
2025-09-22 18:47:03 -03:00
Bennet Bo Fenner
989ff500d9 Track edit events 2025-09-22 18:40:16 +02:00
Bennet Bo Fenner
a5ad9a9615 Port should_replace_prediction 2025-09-22 15:13:53 +02:00
Bennet Bo Fenner
ee4f8e7579 Port test_edit_prediction_basic_interpolation 2025-09-22 15:12:51 +02:00
Bennet Bo Fenner
4d9c4e187f Address check account todo 2025-09-22 14:57:13 +02:00
Bennet Bo Fenner
3944fb6ff7 Interpolate in suggest and refresh 2025-09-22 14:56:57 +02:00
Michael Sloan
49d9280344 Misc
Co-authored-by: Agus <agus@zed.dev>
2025-09-20 20:00:53 -06:00
Michael Sloan
1cc74ba885 Add ZED_ZETA2 env var 2025-09-19 16:02:05 -06:00
Michael Sloan
ad8bfbdf56 Send paths and ranges in zeta2 requests + add debug_info 2025-09-19 16:01:55 -06:00
Agus Zubiaga
439ab2575f Request completions
Co-Authored-By: Bennet <bennet@zed.dev>
2025-09-19 12:21:23 -03:00
Michael Sloan
a4024b495d Move cloud request building code to zeta2 + other misc changes 2025-09-19 00:56:14 -06:00
Michael Sloan
b511aa9274 Merge remote-tracking branch 'origin/zeta2-provider' into zeta2-cloud-request 2025-09-18 20:33:40 -06:00
Michael Sloan
cb0c4bec24 Progress preparing new cloud request + using index in excerpt selection
Co-authored-by: Agus <agus@zed.dev>
2025-09-18 17:39:23 -06:00
Agus Zubiaga
d8dd2b2977 Add zeta2 to registry 2025-09-15 12:17:01 -03:00
Agus Zubiaga
e19995431b Create zeta2 crate 2025-09-15 10:46:10 -03:00
24 changed files with 2375 additions and 675 deletions

Cargo.lock (generated, 91 changed lines)
View File

@@ -3213,6 +3213,7 @@ name = "cloud_llm_client"
version = "0.1.0"
dependencies = [
"anyhow",
"chrono",
"pretty_assertions",
"serde",
"serde_json",
@@ -5173,7 +5174,9 @@ version = "0.1.0"
dependencies = [
"anyhow",
"arrayvec",
"chrono",
"clap",
"cloud_llm_client",
"collections",
"futures 0.3.31",
"gpui",
@@ -5197,33 +5200,6 @@ dependencies = [
"zlog",
]
[[package]]
name = "edit_prediction_tools"
version = "0.1.0"
dependencies = [
"clap",
"collections",
"edit_prediction_context",
"editor",
"futures 0.3.31",
"gpui",
"indoc",
"language",
"log",
"pretty_assertions",
"project",
"serde",
"serde_json",
"settings",
"text",
"ui",
"ui_input",
"util",
"workspace",
"workspace-hack",
"zlog",
]
[[package]]
name = "editor"
version = "0.1.0"
@@ -21244,7 +21220,6 @@ dependencies = [
"debugger_ui",
"diagnostics",
"edit_prediction_button",
"edit_prediction_tools",
"editor",
"env_logger 0.11.8",
"extension",
@@ -21355,6 +21330,8 @@ dependencies = [
"zed_actions",
"zed_env_vars",
"zeta",
"zeta2",
"zeta2_tools",
"zlog",
"zlog_settings",
]
@@ -21632,6 +21609,64 @@ dependencies = [
"zlog",
]
[[package]]
name = "zeta2"
version = "0.1.0"
dependencies = [
"anyhow",
"arrayvec",
"chrono",
"client",
"cloud_llm_client",
"edit_prediction",
"edit_prediction_context",
"futures 0.3.31",
"gpui",
"language",
"language_model",
"log",
"project",
"release_channel",
"serde_json",
"thiserror 2.0.12",
"util",
"uuid",
"workspace",
"workspace-hack",
"worktree",
]
[[package]]
name = "zeta2_tools"
version = "0.1.0"
dependencies = [
"chrono",
"clap",
"client",
"collections",
"edit_prediction_context",
"editor",
"futures 0.3.31",
"gpui",
"indoc",
"language",
"log",
"markdown",
"pretty_assertions",
"project",
"serde",
"serde_json",
"settings",
"text",
"ui",
"ui_input",
"util",
"workspace",
"workspace-hack",
"zeta2",
"zlog",
]
[[package]]
name = "zeta_cli"
version = "0.1.0"

View File

@@ -58,7 +58,7 @@ members = [
"crates/edit_prediction",
"crates/edit_prediction_button",
"crates/edit_prediction_context",
"crates/edit_prediction_tools",
"crates/zeta2_tools",
"crates/editor",
"crates/eval",
"crates/explorer_command_injector",
@@ -199,6 +199,7 @@ members = [
"crates/zed_actions",
"crates/zed_env_vars",
"crates/zeta",
"crates/zeta2",
"crates/zeta_cli",
"crates/zlog",
"crates/zlog_settings",
@@ -315,7 +316,7 @@ image_viewer = { path = "crates/image_viewer" }
edit_prediction = { path = "crates/edit_prediction" }
edit_prediction_button = { path = "crates/edit_prediction_button" }
edit_prediction_context = { path = "crates/edit_prediction_context" }
edit_prediction_tools = { path = "crates/edit_prediction_tools" }
zeta2_tools = { path = "crates/zeta2_tools" }
inspector_ui = { path = "crates/inspector_ui" }
install_cli = { path = "crates/install_cli" }
jj = { path = "crates/jj" }
@@ -431,6 +432,7 @@ zed = { path = "crates/zed" }
zed_actions = { path = "crates/zed_actions" }
zed_env_vars = { path = "crates/zed_env_vars" }
zeta = { path = "crates/zeta" }
zeta2 = { path = "crates/zeta2" }
zlog = { path = "crates/zlog" }
zlog_settings = { path = "crates/zlog_settings" }

View File

@@ -13,6 +13,7 @@ path = "src/cloud_llm_client.rs"
[dependencies]
anyhow.workspace = true
chrono.workspace = true
serde = { workspace = true, features = ["derive", "rc"] }
serde_json.workspace = true
strum = { workspace = true, features = ["derive"] }

View File

@@ -1,3 +1,5 @@
pub mod predict_edits_v3;
use std::str::FromStr;
use std::sync::Arc;

View File

@@ -0,0 +1,172 @@
use chrono::Duration;
use serde::{Deserialize, Serialize};
use std::{ops::Range, path::PathBuf};
use uuid::Uuid;
use crate::PredictEditsGitInfo;
// TODO: snippet ordering within file / relative to excerpt
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PredictEditsRequest {
pub excerpt: String,
pub excerpt_path: PathBuf,
/// Within file
pub excerpt_range: Range<usize>,
/// Within `excerpt`
pub cursor_offset: usize,
/// Within `signatures`
pub excerpt_parent: Option<usize>,
pub signatures: Vec<Signature>,
pub referenced_declarations: Vec<ReferencedDeclaration>,
pub events: Vec<Event>,
#[serde(default)]
pub can_collect_data: bool,
#[serde(skip_serializing_if = "Vec::is_empty", default)]
pub diagnostic_groups: Vec<DiagnosticGroup>,
/// Info about the git repository state, only present when can_collect_data is true.
#[serde(skip_serializing_if = "Option::is_none", default)]
pub git_info: Option<PredictEditsGitInfo>,
#[serde(default)]
pub debug_info: bool,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "event")]
pub enum Event {
BufferChange {
path: Option<PathBuf>,
old_path: Option<PathBuf>,
diff: String,
predicted: bool,
},
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Signature {
pub text: String,
pub text_is_truncated: bool,
#[serde(skip_serializing_if = "Option::is_none", default)]
pub parent_index: Option<usize>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ReferencedDeclaration {
pub path: PathBuf,
pub text: String,
pub text_is_truncated: bool,
/// Range of `text` within file, potentially truncated according to `text_is_truncated`
pub range: Range<usize>,
/// Range within `text`
pub signature_range: Range<usize>,
/// Index within `signatures`.
#[serde(skip_serializing_if = "Option::is_none", default)]
pub parent_index: Option<usize>,
pub score_components: ScoreComponents,
pub signature_score: f32,
pub declaration_score: f32,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ScoreComponents {
pub is_same_file: bool,
pub is_referenced_nearby: bool,
pub is_referenced_in_breadcrumb: bool,
pub reference_count: usize,
pub same_file_declaration_count: usize,
pub declaration_count: usize,
pub reference_line_distance: u32,
pub declaration_line_distance: u32,
pub declaration_line_distance_rank: usize,
pub containing_range_vs_item_jaccard: f32,
pub containing_range_vs_signature_jaccard: f32,
pub adjacent_vs_item_jaccard: f32,
pub adjacent_vs_signature_jaccard: f32,
pub containing_range_vs_item_weighted_overlap: f32,
pub containing_range_vs_signature_weighted_overlap: f32,
pub adjacent_vs_item_weighted_overlap: f32,
pub adjacent_vs_signature_weighted_overlap: f32,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DiagnosticGroup {
pub language_server: String,
pub diagnostic_group: serde_json::Value,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PredictEditsResponse {
pub request_id: Uuid,
pub edits: Vec<Edit>,
pub debug_info: Option<DebugInfo>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DebugInfo {
pub prompt: String,
pub prompt_planning_time: Duration,
pub model_response: String,
pub inference_time: Duration,
pub parsing_time: Duration,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Edit {
pub path: PathBuf,
pub range: Range<usize>,
pub content: String,
}
/*
#[derive(Debug, Clone)]
pub struct SerializedJson<T> {
raw: Box<RawValue>,
_phantom: PhantomData<T>,
}
impl<T> SerializedJson<T>
where
T: Serialize + for<'de> Deserialize<'de>,
{
pub fn new(value: &T) -> Result<Self, serde_json::Error> {
Ok(SerializedJson {
raw: serde_json::value::to_raw_value(value)?,
_phantom: PhantomData,
})
}
pub fn deserialize(&self) -> Result<T, serde_json::Error> {
serde_json::from_str(self.raw.get())
}
pub fn as_raw(&self) -> &RawValue {
&self.raw
}
pub fn into_raw(self) -> Box<RawValue> {
self.raw
}
}
impl<T> Serialize for SerializedJson<T> {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
self.raw.serialize(serializer)
}
}
impl<'de, T> Deserialize<'de> for SerializedJson<T> {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
let raw = Box::<RawValue>::deserialize(deserializer)?;
Ok(SerializedJson {
raw,
_phantom: PhantomData,
})
}
}
*/
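
As a rough, non-authoritative sketch of the request schema defined above (not part of this diff), the following assumes the `predict_edits_v3` types are in scope and shows one way a caller might populate and serialize a minimal request; all concrete values are hypothetical:

use cloud_llm_client::predict_edits_v3::{Event, PredictEditsRequest};
use std::path::PathBuf;

fn example_request_json() -> serde_json::Result<String> {
    // Hypothetical 34-byte excerpt; the cursor sits on the `{` of `fn main()`.
    let request = PredictEditsRequest {
        excerpt: "fn main() {\n    println!(\"hi\");\n}\n".to_string(),
        excerpt_path: PathBuf::from("src/main.rs"),
        excerpt_range: 0..34, // within the file
        cursor_offset: 10,    // within `excerpt`
        excerpt_parent: None,
        signatures: Vec::new(),
        referenced_declarations: Vec::new(),
        events: vec![Event::BufferChange {
            path: Some(PathBuf::from("src/main.rs")),
            old_path: None,
            diff: "@@ -1 +1,3 @@ ...".to_string(),
            predicted: false,
        }],
        can_collect_data: false,
        diagnostic_groups: Vec::new(),
        git_info: None,   // only included when can_collect_data is true
        debug_info: true, // ask the server to return DebugInfo in the response
    };
    serde_json::to_string_pretty(&request)
}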

View File

@@ -14,6 +14,8 @@ path = "src/edit_prediction_context.rs"
[dependencies]
anyhow.workspace = true
arrayvec.workspace = true
chrono.workspace = true
cloud_llm_client.workspace = true
collections.workspace = true
futures.workspace = true
gpui.workspace = true

View File

@@ -41,6 +41,20 @@ impl Declaration {
}
}
pub fn parent(&self) -> Option<DeclarationId> {
match self {
Declaration::File { declaration, .. } => declaration.parent,
Declaration::Buffer { declaration, .. } => declaration.parent,
}
}
pub fn as_buffer(&self) -> Option<&BufferDeclaration> {
match self {
Declaration::File { .. } => None,
Declaration::Buffer { declaration, .. } => Some(declaration),
}
}
pub fn project_entry_id(&self) -> ProjectEntryId {
match self {
Declaration::File {
@@ -52,6 +66,13 @@ impl Declaration {
}
}
pub fn item_range(&self) -> Range<usize> {
match self {
Declaration::File { declaration, .. } => declaration.item_range_in_file.clone(),
Declaration::Buffer { declaration, .. } => declaration.item_range.clone(),
}
}
pub fn item_text(&self) -> (Cow<'_, str>, bool) {
match self {
Declaration::File { declaration, .. } => (
@@ -83,6 +104,16 @@ impl Declaration {
),
}
}
pub fn signature_range_in_item_text(&self) -> Range<usize> {
match self {
Declaration::File { declaration, .. } => declaration.signature_range_in_text.clone(),
Declaration::Buffer { declaration, .. } => {
declaration.signature_range.start - declaration.item_range.start
..declaration.signature_range.end - declaration.item_range.start
}
}
}
}
fn expand_range_to_line_boundaries_and_truncate(

View File

@@ -1,10 +1,11 @@
use cloud_llm_client::predict_edits_v3::ScoreComponents;
use itertools::Itertools as _;
use language::BufferSnapshot;
use ordered_float::OrderedFloat;
use serde::Serialize;
use std::{collections::HashMap, ops::Range};
use strum::EnumIter;
use text::{OffsetRangeExt, Point, ToPoint};
use text::{Point, ToPoint};
use crate::{
Declaration, EditPredictionExcerpt, EditPredictionExcerptText, Identifier,
@@ -15,19 +16,14 @@ use crate::{
const MAX_IDENTIFIER_DECLARATION_COUNT: usize = 16;
// TODO:
//
// * Consider adding declaration_file_count
#[derive(Clone, Debug)]
pub struct ScoredSnippet {
pub identifier: Identifier,
pub declaration: Declaration,
pub score_components: ScoreInputs,
pub score_components: ScoreComponents,
pub scores: Scores,
}
// TODO: Consider having "Concise" style corresponding to `concise_text`
#[derive(EnumIter, Clone, Copy, PartialEq, Eq, Hash, Debug)]
pub enum SnippetStyle {
Signature,
@@ -90,8 +86,8 @@ pub fn scored_snippets(
let declaration_count = declarations.len();
declarations
.iter()
.filter_map(|declaration| match declaration {
.into_iter()
.filter_map(|(declaration_id, declaration)| match declaration {
Declaration::Buffer {
buffer_id,
declaration: buffer_declaration,
@@ -100,24 +96,29 @@ pub fn scored_snippets(
let is_same_file = buffer_id == &current_buffer.remote_id();
if is_same_file {
range_intersection(
&buffer_declaration.item_range.to_offset(&current_buffer),
&excerpt.range,
)
.is_none()
.then(|| {
let overlaps_excerpt =
range_intersection(&buffer_declaration.item_range, &excerpt.range)
.is_some();
if overlaps_excerpt
|| excerpt
.parent_declarations
.iter()
.any(|(excerpt_parent, _)| excerpt_parent == &declaration_id)
{
None
} else {
let declaration_line = buffer_declaration
.item_range
.start
.to_point(current_buffer)
.row;
(
Some((
true,
(cursor_point.row as i32 - declaration_line as i32)
.unsigned_abs(),
declaration,
)
})
))
}
} else {
Some((false, u32::MAX, declaration))
}
@@ -238,7 +239,8 @@ fn score_snippet(
let adjacent_vs_signature_weighted_overlap =
weighted_overlap_coefficient(adjacent_identifier_occurrences, &item_signature_occurrences);
let score_components = ScoreInputs {
// TODO: Consider adding declaration_file_count
let score_components = ScoreComponents {
is_same_file,
is_referenced_nearby,
is_referenced_in_breadcrumb,
@@ -261,51 +263,30 @@ fn score_snippet(
Some(ScoredSnippet {
identifier: identifier.clone(),
declaration: declaration,
scores: score_components.score(),
scores: Scores::score(&score_components),
score_components,
})
}
#[derive(Clone, Debug, Serialize)]
pub struct ScoreInputs {
pub is_same_file: bool,
pub is_referenced_nearby: bool,
pub is_referenced_in_breadcrumb: bool,
pub reference_count: usize,
pub same_file_declaration_count: usize,
pub declaration_count: usize,
pub reference_line_distance: u32,
pub declaration_line_distance: u32,
pub declaration_line_distance_rank: usize,
pub containing_range_vs_item_jaccard: f32,
pub containing_range_vs_signature_jaccard: f32,
pub adjacent_vs_item_jaccard: f32,
pub adjacent_vs_signature_jaccard: f32,
pub containing_range_vs_item_weighted_overlap: f32,
pub containing_range_vs_signature_weighted_overlap: f32,
pub adjacent_vs_item_weighted_overlap: f32,
pub adjacent_vs_signature_weighted_overlap: f32,
}
#[derive(Clone, Debug, Serialize)]
pub struct Scores {
pub signature: f32,
pub declaration: f32,
}
impl ScoreInputs {
fn score(&self) -> Scores {
impl Scores {
fn score(components: &ScoreComponents) -> Scores {
// Score related to how likely this is the correct declaration, range 0 to 1
let accuracy_score = if self.is_same_file {
let accuracy_score = if components.is_same_file {
// TODO: use declaration_line_distance_rank
1.0 / self.same_file_declaration_count as f32
1.0 / components.same_file_declaration_count as f32
} else {
1.0 / self.declaration_count as f32
1.0 / components.declaration_count as f32
};
// Score related to the distance between the reference and cursor, range 0 to 1
let distance_score = if self.is_referenced_nearby {
1.0 / (1.0 + self.reference_line_distance as f32 / 10.0).powf(2.0)
let distance_score = if components.is_referenced_nearby {
1.0 / (1.0 + components.reference_line_distance as f32 / 10.0).powf(2.0)
} else {
// same score as ~14 lines away, rationale is to not overly penalize references from parent signatures
0.5
@@ -315,10 +296,12 @@ impl ScoreInputs {
let combined_score = 10.0 * accuracy_score * distance_score;
Scores {
signature: combined_score * self.containing_range_vs_signature_weighted_overlap,
signature: combined_score * components.containing_range_vs_signature_weighted_overlap,
// declaration score gets boosted both by being multiplied by 2 and by there being more
// weighted overlap.
declaration: 2.0 * combined_score * self.containing_range_vs_item_weighted_overlap,
declaration: 2.0
* combined_score
* components.containing_range_vs_item_weighted_overlap,
}
}
}
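
For intuition about the score formulas above, here is a small standalone sketch (not part of this diff) that recomputes the two scores for hypothetical inputs:

fn example_scores() -> (f32, f32) {
    // Same-file declaration among 2 same-file candidates -> accuracy 0.5.
    let accuracy_score: f32 = 1.0 / 2.0;
    // Referenced 5 lines from the cursor -> 1.0 / (1 + 5/10)^2 ≈ 0.44.
    let distance_score: f32 = 1.0 / (1.0_f32 + 5.0 / 10.0).powf(2.0);
    let combined_score = 10.0 * accuracy_score * distance_score; // ≈ 2.22
    // Hypothetical weighted overlaps: 0.3 vs the signature, 0.4 vs the item.
    let signature = combined_score * 0.3;         // ≈ 0.67
    let declaration = 2.0 * combined_score * 0.4; // ≈ 1.78, boosted by the 2x factor
    (signature, declaration)
}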

View File

@@ -6,62 +6,82 @@ mod reference;
mod syntax_index;
mod text_similarity;
use std::time::Instant;
pub use declaration::{BufferDeclaration, Declaration, FileDeclaration, Identifier};
pub use declaration_scoring::SnippetStyle;
pub use excerpt::{EditPredictionExcerpt, EditPredictionExcerptOptions, EditPredictionExcerptText};
use gpui::{App, AppContext as _, Entity, Task};
use language::BufferSnapshot;
pub use reference::references_in_excerpt;
pub use syntax_index::SyntaxIndex;
use text::{Point, ToOffset as _};
use crate::declaration_scoring::{ScoredSnippet, scored_snippets};
pub use declaration::*;
pub use declaration_scoring::*;
pub use excerpt::*;
pub use reference::*;
pub use syntax_index::*;
#[derive(Debug)]
#[derive(Clone, Debug)]
pub struct EditPredictionContext {
pub excerpt: EditPredictionExcerpt,
pub excerpt_text: EditPredictionExcerptText,
pub cursor_offset_in_excerpt: usize,
pub snippets: Vec<ScoredSnippet>,
pub retrieval_duration: std::time::Duration,
}
impl EditPredictionContext {
pub fn gather(
pub fn gather_context_in_background(
cursor_point: Point,
buffer: BufferSnapshot,
excerpt_options: EditPredictionExcerptOptions,
syntax_index: Entity<SyntaxIndex>,
syntax_index: Option<Entity<SyntaxIndex>>,
cx: &mut App,
) -> Task<Option<Self>> {
let start = Instant::now();
let index_state = syntax_index.read_with(cx, |index, _cx| index.state().clone());
cx.background_spawn(async move {
let index_state = index_state.lock().await;
if let Some(syntax_index) = syntax_index {
let index_state = syntax_index.read_with(cx, |index, _cx| index.state().clone());
cx.background_spawn(async move {
let index_state = index_state.lock().await;
Self::gather_context(cursor_point, &buffer, &excerpt_options, Some(&index_state))
})
} else {
cx.background_spawn(async move {
Self::gather_context(cursor_point, &buffer, &excerpt_options, None)
})
}
}
let excerpt =
EditPredictionExcerpt::select_from_buffer(cursor_point, &buffer, &excerpt_options)?;
let excerpt_text = excerpt.text(&buffer);
let references = references_in_excerpt(&excerpt, &excerpt_text, &buffer);
let cursor_offset = cursor_point.to_offset(&buffer);
pub fn gather_context(
cursor_point: Point,
buffer: &BufferSnapshot,
excerpt_options: &EditPredictionExcerptOptions,
index_state: Option<&SyntaxIndexState>,
) -> Option<Self> {
let excerpt = EditPredictionExcerpt::select_from_buffer(
cursor_point,
buffer,
excerpt_options,
index_state,
)?;
let excerpt_text = excerpt.text(buffer);
let cursor_offset_in_file = cursor_point.to_offset(buffer);
// todo! fix this to not need saturating_sub
let cursor_offset_in_excerpt = cursor_offset_in_file.saturating_sub(excerpt.range.start);
let snippets = scored_snippets(
let snippets = if let Some(index_state) = index_state {
let references = references_in_excerpt(&excerpt, &excerpt_text, buffer);
scored_snippets(
&index_state,
&excerpt,
&excerpt_text,
references,
cursor_offset,
&buffer,
);
cursor_offset_in_file,
buffer,
)
} else {
vec![]
};
Some(Self {
excerpt,
excerpt_text,
snippets,
retrieval_duration: start.elapsed(),
})
Some(Self {
excerpt,
excerpt_text,
cursor_offset_in_excerpt,
snippets,
})
}
}
@@ -101,24 +121,28 @@ mod tests {
let context = cx
.update(|cx| {
EditPredictionContext::gather(
EditPredictionContext::gather_context_in_background(
cursor_point,
buffer_snapshot,
EditPredictionExcerptOptions {
max_bytes: 40,
max_bytes: 60,
min_bytes: 10,
target_before_cursor_over_total_bytes: 0.5,
include_parent_signatures: false,
},
index,
Some(index),
cx,
)
})
.await
.unwrap();
assert_eq!(context.snippets.len(), 1);
assert_eq!(context.snippets[0].identifier.name.as_ref(), "process_data");
let mut snippet_identifiers = context
.snippets
.iter()
.map(|snippet| snippet.identifier.name.as_ref())
.collect::<Vec<_>>();
snippet_identifiers.sort();
assert_eq!(snippet_identifiers, vec!["main", "process_data"]);
drop(buffer);
}

View File

@@ -1,9 +1,11 @@
use language::BufferSnapshot;
use language::{BufferSnapshot, LanguageId};
use std::ops::Range;
use text::{OffsetRangeExt as _, Point, ToOffset as _, ToPoint as _};
use text::{Point, ToOffset as _, ToPoint as _};
use tree_sitter::{Node, TreeCursor};
use util::RangeExt;
use crate::{BufferDeclaration, declaration::DeclarationId, syntax_index::SyntaxIndexState};
// TODO:
//
// - Test parent signatures
@@ -27,14 +29,13 @@ pub struct EditPredictionExcerptOptions {
pub min_bytes: usize,
/// Target ratio of bytes before the cursor divided by total bytes in the window.
pub target_before_cursor_over_total_bytes: f32,
/// Whether to include parent signatures
pub include_parent_signatures: bool,
}
// TODO: consider merging these
#[derive(Debug, Clone)]
pub struct EditPredictionExcerpt {
pub range: Range<usize>,
pub parent_signature_ranges: Vec<Range<usize>>,
pub parent_declarations: Vec<(DeclarationId, Range<usize>)>,
pub size: usize,
}
@@ -42,6 +43,7 @@ pub struct EditPredictionExcerpt {
pub struct EditPredictionExcerptText {
pub body: String,
pub parent_signatures: Vec<String>,
pub language_id: Option<LanguageId>,
}
impl EditPredictionExcerpt {
@@ -50,20 +52,23 @@ impl EditPredictionExcerpt {
.text_for_range(self.range.clone())
.collect::<String>();
let parent_signatures = self
.parent_signature_ranges
.parent_declarations
.iter()
.map(|range| buffer.text_for_range(range.clone()).collect::<String>())
.map(|(_, range)| buffer.text_for_range(range.clone()).collect::<String>())
.collect();
let language_id = buffer.language().map(|l| l.id());
EditPredictionExcerptText {
body,
parent_signatures,
language_id,
}
}
/// Selects an excerpt around a buffer position, attempting to choose logical boundaries based
/// on TreeSitter structure and approximately targeting a goal ratio of bytes before vs after the
/// cursor. When `include_parent_signatures` is true, the excerpt also includes the signatures
/// of parent outline items.
/// cursor.
///
/// When `index` is provided, the excerpt will include the signatures of parent outline items.
///
/// First tries to use AST node boundaries to select the excerpt, and falls back on line-based
/// expansion.
@@ -73,6 +78,7 @@ impl EditPredictionExcerpt {
query_point: Point,
buffer: &BufferSnapshot,
options: &EditPredictionExcerptOptions,
syntax_index: Option<&SyntaxIndexState>,
) -> Option<Self> {
if buffer.len() <= options.max_bytes {
log::debug!(
@@ -90,17 +96,9 @@ impl EditPredictionExcerpt {
return None;
}
// TODO: Don't compute text / annotation_range / skip converting to and from anchors.
let outline_items = if options.include_parent_signatures {
buffer
.outline_items_containing(query_range.clone(), false, None)
.into_iter()
.flat_map(|item| {
Some(ExcerptOutlineItem {
item_range: item.range.to_offset(&buffer),
signature_range: item.signature_range?.to_offset(&buffer),
})
})
let parent_declarations = if let Some(syntax_index) = syntax_index {
syntax_index
.buffer_declarations_containing_range(buffer.remote_id(), query_range.clone())
.collect()
} else {
Vec::new()
@@ -109,7 +107,7 @@ impl EditPredictionExcerpt {
let excerpt_selector = ExcerptSelector {
query_offset,
query_range,
outline_items: &outline_items,
parent_declarations: &parent_declarations,
buffer,
options,
};
@@ -132,15 +130,15 @@ impl EditPredictionExcerpt {
excerpt_selector.select_lines()
}
fn new(range: Range<usize>, parent_signature_ranges: Vec<Range<usize>>) -> Self {
fn new(range: Range<usize>, parent_declarations: Vec<(DeclarationId, Range<usize>)>) -> Self {
let size = range.len()
+ parent_signature_ranges
+ parent_declarations
.iter()
.map(|r| r.len())
.map(|(_, range)| range.len())
.sum::<usize>();
Self {
range,
parent_signature_ranges,
parent_declarations,
size,
}
}
@@ -150,20 +148,14 @@ impl EditPredictionExcerpt {
// this is an issue because parent_signature_ranges may be incorrect
log::error!("bug: with_expanded_range called with disjoint range");
}
let mut parent_signature_ranges = Vec::with_capacity(self.parent_signature_ranges.len());
let mut size = new_range.len();
for range in &self.parent_signature_ranges {
if range.contains_inclusive(&new_range) {
let mut parent_declarations = Vec::with_capacity(self.parent_declarations.len());
for (declaration_id, range) in &self.parent_declarations {
if !range.contains_inclusive(&new_range) {
break;
}
parent_signature_ranges.push(range.clone());
size += range.len();
}
Self {
range: new_range,
parent_signature_ranges,
size,
parent_declarations.push((*declaration_id, range.clone()));
}
Self::new(new_range, parent_declarations)
}
fn parent_signatures_size(&self) -> usize {
@@ -174,16 +166,11 @@ impl EditPredictionExcerpt {
struct ExcerptSelector<'a> {
query_offset: usize,
query_range: Range<usize>,
outline_items: &'a [ExcerptOutlineItem],
parent_declarations: &'a [(DeclarationId, &'a BufferDeclaration)],
buffer: &'a BufferSnapshot,
options: &'a EditPredictionExcerptOptions,
}
struct ExcerptOutlineItem {
item_range: Range<usize>,
signature_range: Range<usize>,
}
impl<'a> ExcerptSelector<'a> {
/// Finds the largest node that is smaller than the window size and contains `query_range`.
fn select_tree_sitter_nodes(&self) -> Option<EditPredictionExcerpt> {
@@ -396,13 +383,13 @@ impl<'a> ExcerptSelector<'a> {
}
fn make_excerpt(&self, range: Range<usize>) -> EditPredictionExcerpt {
let parent_signature_ranges = self
.outline_items
let parent_declarations = self
.parent_declarations
.iter()
.filter(|item| item.item_range.contains_inclusive(&range))
.map(|item| item.signature_range.clone())
.filter(|(_, declaration)| declaration.item_range.contains_inclusive(&range))
.map(|(id, declaration)| (*id, declaration.signature_range.clone()))
.collect();
EditPredictionExcerpt::new(range, parent_signature_ranges)
EditPredictionExcerpt::new(range, parent_declarations)
}
/// Returns `true` if the `forward` excerpt is a better choice than the `backward` excerpt.
@@ -493,8 +480,9 @@ mod tests {
let buffer = create_buffer(&text, cx);
let cursor_point = cursor.to_point(&buffer);
let excerpt = EditPredictionExcerpt::select_from_buffer(cursor_point, &buffer, &options)
.expect("Should select an excerpt");
let excerpt =
EditPredictionExcerpt::select_from_buffer(cursor_point, &buffer, &options, None)
.expect("Should select an excerpt");
pretty_assertions::assert_eq!(
generate_marked_text(&text, std::slice::from_ref(&excerpt.range), false),
generate_marked_text(&text, &[expected_excerpt], false)
@@ -517,7 +505,6 @@ fn main() {
max_bytes: 20,
min_bytes: 10,
target_before_cursor_over_total_bytes: 0.5,
include_parent_signatures: false,
};
check_example(options, text, cx);
@@ -541,7 +528,6 @@ fn bar() {}"#;
max_bytes: 65,
min_bytes: 10,
target_before_cursor_over_total_bytes: 0.5,
include_parent_signatures: false,
};
check_example(options, text, cx);
@@ -561,7 +547,6 @@ fn main() {
max_bytes: 50,
min_bytes: 10,
target_before_cursor_over_total_bytes: 0.5,
include_parent_signatures: false,
};
check_example(options, text, cx);
@@ -583,7 +568,6 @@ fn main() {
max_bytes: 60,
min_bytes: 45,
target_before_cursor_over_total_bytes: 0.5,
include_parent_signatures: false,
};
check_example(options, text, cx);
@@ -608,7 +592,6 @@ fn main() {
max_bytes: 120,
min_bytes: 10,
target_before_cursor_over_total_bytes: 0.6,
include_parent_signatures: false,
};
check_example(options, text, cx);

View File

@@ -33,8 +33,8 @@ pub fn references_in_excerpt(
snapshot,
);
for (range, text) in excerpt
.parent_signature_ranges
for ((_, range), text) in excerpt
.parent_declarations
.iter()
.zip(excerpt_text.parent_signatures.iter())
{

View File

@@ -1,5 +1,3 @@
use std::sync::Arc;
use collections::{HashMap, HashSet};
use futures::lock::Mutex;
use gpui::{App, AppContext as _, Context, Entity, Task, WeakEntity};
@@ -8,20 +6,17 @@ use project::buffer_store::{BufferStore, BufferStoreEvent};
use project::worktree_store::{WorktreeStore, WorktreeStoreEvent};
use project::{PathChange, Project, ProjectEntryId, ProjectPath};
use slotmap::SlotMap;
use std::iter;
use std::ops::Range;
use std::sync::Arc;
use text::BufferId;
use util::{debug_panic, some_or_debug_panic};
use util::{RangeExt as _, debug_panic, some_or_debug_panic};
use crate::declaration::{
BufferDeclaration, Declaration, DeclarationId, FileDeclaration, Identifier,
};
use crate::outline::declarations_in_buffer;
// TODO:
//
// * Skip for remote projects
//
// * Consider making SyntaxIndex not an Entity.
// Potential future improvements:
//
// * Send multiple selected excerpt ranges. Challenge is that excerpt ranges influence which
@@ -40,7 +35,6 @@ use crate::outline::declarations_in_buffer;
// * Concurrent slotmap
//
// * Use queue for parsing
//
pub struct SyntaxIndex {
state: Arc<Mutex<SyntaxIndexState>>,
@@ -432,7 +426,7 @@ impl SyntaxIndexState {
pub fn declarations_for_identifier<const N: usize>(
&self,
identifier: &Identifier,
) -> Vec<Declaration> {
) -> Vec<(DeclarationId, &Declaration)> {
// make sure to not have a large stack allocation
assert!(N < 32);
@@ -454,7 +448,7 @@ impl SyntaxIndexState {
project_entry_id, ..
} => {
included_buffer_entry_ids.push(*project_entry_id);
result.push(declaration.clone());
result.push((*declaration_id, declaration));
if result.len() == N {
return Vec::new();
}
@@ -463,19 +457,19 @@ impl SyntaxIndexState {
project_entry_id, ..
} => {
if !included_buffer_entry_ids.contains(&project_entry_id) {
file_declarations.push(declaration.clone());
file_declarations.push((*declaration_id, declaration));
}
}
}
}
for declaration in file_declarations {
for (declaration_id, declaration) in file_declarations {
match declaration {
Declaration::File {
project_entry_id, ..
} => {
if !included_buffer_entry_ids.contains(&project_entry_id) {
result.push(declaration);
result.push((declaration_id, declaration));
if result.len() == N {
return Vec::new();
@@ -489,6 +483,35 @@ impl SyntaxIndexState {
result
}
pub fn buffer_declarations_containing_range(
&self,
buffer_id: BufferId,
range: Range<usize>,
) -> impl Iterator<Item = (DeclarationId, &BufferDeclaration)> {
let Some(buffer_state) = self.buffers.get(&buffer_id) else {
return itertools::Either::Left(iter::empty());
};
let iter = buffer_state
.declarations
.iter()
.filter_map(move |declaration_id| {
let Some(declaration) = self
.declarations
.get(*declaration_id)
.and_then(|d| d.as_buffer())
else {
log::error!("bug: missing buffer outline declaration");
return None;
};
if declaration.item_range.contains_inclusive(&range) {
return Some((*declaration_id, declaration));
}
return None;
});
itertools::Either::Right(iter)
}
pub fn file_declaration_count(&self, declaration: &Declaration) -> usize {
match declaration {
Declaration::File {
@@ -553,11 +576,11 @@ mod tests {
let decls = index_state.declarations_for_identifier::<8>(&main);
assert_eq!(decls.len(), 2);
let decl = expect_file_decl("c.rs", &decls[0], &project, cx);
let decl = expect_file_decl("c.rs", &decls[0].1, &project, cx);
assert_eq!(decl.identifier, main.clone());
assert_eq!(decl.item_range_in_file, 32..280);
let decl = expect_file_decl("a.rs", &decls[1], &project, cx);
let decl = expect_file_decl("a.rs", &decls[1].1, &project, cx);
assert_eq!(decl.identifier, main);
assert_eq!(decl.item_range_in_file, 0..98);
});
@@ -577,7 +600,7 @@ mod tests {
let decls = index_state.declarations_for_identifier::<8>(&test_process_data);
assert_eq!(decls.len(), 1);
let decl = expect_file_decl("c.rs", &decls[0], &project, cx);
let decl = expect_file_decl("c.rs", &decls[0].1, &project, cx);
assert_eq!(decl.identifier, test_process_data);
let parent_id = decl.parent.unwrap();
@@ -618,7 +641,7 @@ mod tests {
let decls = index_state.declarations_for_identifier::<8>(&test_process_data);
assert_eq!(decls.len(), 1);
let decl = expect_buffer_decl("c.rs", &decls[0], &project, cx);
let decl = expect_buffer_decl("c.rs", &decls[0].1, &project, cx);
assert_eq!(decl.identifier, test_process_data);
let parent_id = decl.parent.unwrap();
@@ -676,11 +699,11 @@ mod tests {
cx.update(|cx| {
let decls = index_state.declarations_for_identifier::<8>(&main);
assert_eq!(decls.len(), 2);
let decl = expect_buffer_decl("c.rs", &decls[0], &project, cx);
let decl = expect_buffer_decl("c.rs", &decls[0].1, &project, cx);
assert_eq!(decl.identifier, main);
assert_eq!(decl.item_range.to_offset(&buffer.read(cx)), 32..280);
expect_file_decl("a.rs", &decls[1], &project, cx);
expect_file_decl("a.rs", &decls[1].1, &project, cx);
});
}
@@ -695,8 +718,8 @@ mod tests {
cx.update(|cx| {
let decls = index_state.declarations_for_identifier::<8>(&main);
assert_eq!(decls.len(), 2);
expect_file_decl("c.rs", &decls[0], &project, cx);
expect_file_decl("a.rs", &decls[1], &project, cx);
expect_file_decl("c.rs", &decls[0].1, &project, cx);
expect_file_decl("a.rs", &decls[1].1, &project, cx);
});
}
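
A hedged usage sketch for the new buffer_declarations_containing_range accessor (hypothetical caller, not part of this diff), mirroring how the excerpt selector collects parent declarations. It assumes DeclarationId and SyntaxIndexState are re-exported by the crate's glob re-exports shown earlier:

use std::ops::Range;
use text::BufferId;
use edit_prediction_context::{DeclarationId, SyntaxIndexState};

fn parent_declaration_ids(
    index_state: &SyntaxIndexState,
    buffer_id: BufferId,
    excerpt_range: Range<usize>,
) -> Vec<DeclarationId> {
    // Keep only declarations whose item range fully contains the excerpt.
    index_state
        .buffer_declarations_containing_range(buffer_id, excerpt_range)
        .map(|(declaration_id, _declaration)| declaration_id)
        .collect()
}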

View File

@@ -9,8 +9,12 @@ use crate::reference::Reference;
// That implementation could actually be more efficient - no need to track words in the window that
// are not in the query.
// TODO: Consider a flat sorted Vec<(String, usize)> representation. Intersection can just walk the
// two in parallel.
static IDENTIFIER_REGEX: LazyLock<Regex> = LazyLock::new(|| Regex::new(r"\b\w+\b").unwrap());
// TODO: use &str or Cow<str> keys?
#[derive(Debug)]
pub struct IdentifierOccurrences {
identifier_to_count: HashMap<String, usize>,

View File

@@ -1,457 +0,0 @@
use std::{
collections::hash_map::Entry,
ffi::OsStr,
path::{Path, PathBuf},
str::FromStr,
sync::Arc,
time::Duration,
};
use collections::HashMap;
use editor::{Editor, EditorEvent, EditorMode, ExcerptRange, MultiBuffer};
use gpui::{
Entity, EventEmitter, FocusHandle, Focusable, Subscription, Task, WeakEntity, actions,
prelude::*,
};
use language::{Buffer, DiskState};
use project::{Project, WorktreeId};
use text::ToPoint;
use ui::prelude::*;
use ui_input::SingleLineInput;
use workspace::{Item, SplitDirection, Workspace};
use edit_prediction_context::{
EditPredictionContext, EditPredictionExcerptOptions, SnippetStyle, SyntaxIndex,
};
actions!(
dev,
[
/// Opens the language server protocol logs viewer.
OpenEditPredictionContext
]
);
pub fn init(cx: &mut App) {
cx.observe_new(move |workspace: &mut Workspace, _, _cx| {
workspace.register_action(
move |workspace, _: &OpenEditPredictionContext, window, cx| {
let workspace_entity = cx.entity();
let project = workspace.project();
let active_editor = workspace.active_item_as::<Editor>(cx);
workspace.split_item(
SplitDirection::Right,
Box::new(cx.new(|cx| {
EditPredictionTools::new(
&workspace_entity,
&project,
active_editor,
window,
cx,
)
})),
window,
cx,
);
},
);
})
.detach();
}
pub struct EditPredictionTools {
focus_handle: FocusHandle,
project: Entity<Project>,
last_context: Option<ContextState>,
max_bytes_input: Entity<SingleLineInput>,
min_bytes_input: Entity<SingleLineInput>,
cursor_context_ratio_input: Entity<SingleLineInput>,
// TODO move to project or provider?
syntax_index: Entity<SyntaxIndex>,
last_editor: WeakEntity<Editor>,
_active_editor_subscription: Option<Subscription>,
_edit_prediction_context_task: Task<()>,
}
struct ContextState {
context_editor: Entity<Editor>,
retrieval_duration: Duration,
}
impl EditPredictionTools {
pub fn new(
workspace: &Entity<Workspace>,
project: &Entity<Project>,
active_editor: Option<Entity<Editor>>,
window: &mut Window,
cx: &mut Context<Self>,
) -> Self {
cx.subscribe_in(workspace, window, |this, workspace, event, window, cx| {
if let workspace::Event::ActiveItemChanged = event {
if let Some(editor) = workspace.read(cx).active_item_as::<Editor>(cx) {
this._active_editor_subscription = Some(cx.subscribe_in(
&editor,
window,
|this, editor, event, window, cx| {
if let EditorEvent::SelectionsChanged { .. } = event {
this.update_context(editor, window, cx);
}
},
));
this.update_context(&editor, window, cx);
} else {
this._active_editor_subscription = None;
}
}
})
.detach();
let syntax_index = cx.new(|cx| SyntaxIndex::new(project, cx));
let number_input = |label: &'static str,
value: &'static str,
window: &mut Window,
cx: &mut Context<Self>|
-> Entity<SingleLineInput> {
let input = cx.new(|cx| {
let input = SingleLineInput::new(window, cx, "")
.label(label)
.label_min_width(px(64.));
input.set_text(value, window, cx);
input
});
cx.subscribe_in(
&input.read(cx).editor().clone(),
window,
|this, _, event, window, cx| {
if let EditorEvent::BufferEdited = event
&& let Some(editor) = this.last_editor.upgrade()
{
this.update_context(&editor, window, cx);
}
},
)
.detach();
input
};
let mut this = Self {
focus_handle: cx.focus_handle(),
project: project.clone(),
last_context: None,
max_bytes_input: number_input("Max Bytes", "512", window, cx),
min_bytes_input: number_input("Min Bytes", "128", window, cx),
cursor_context_ratio_input: number_input("Cursor Context Ratio", "0.5", window, cx),
syntax_index,
last_editor: WeakEntity::new_invalid(),
_active_editor_subscription: None,
_edit_prediction_context_task: Task::ready(()),
};
if let Some(editor) = active_editor {
this.update_context(&editor, window, cx);
}
this
}
fn update_context(
&mut self,
editor: &Entity<Editor>,
window: &mut Window,
cx: &mut Context<Self>,
) {
self.last_editor = editor.downgrade();
let editor = editor.read(cx);
let buffer = editor.buffer().clone();
let cursor_position = editor.selections.newest_anchor().start;
let Some(buffer) = buffer.read(cx).buffer_for_anchor(cursor_position, cx) else {
self.last_context.take();
return;
};
let current_buffer_snapshot = buffer.read(cx).snapshot();
let cursor_position = cursor_position
.text_anchor
.to_point(&current_buffer_snapshot);
let language = current_buffer_snapshot.language().cloned();
let Some(worktree_id) = self
.project
.read(cx)
.worktrees(cx)
.next()
.map(|worktree| worktree.read(cx).id())
else {
log::error!("Open a worktree to use edit prediction debug view");
self.last_context.take();
return;
};
self._edit_prediction_context_task = cx.spawn_in(window, {
let language_registry = self.project.read(cx).languages().clone();
async move |this, cx| {
cx.background_executor()
.timer(Duration::from_millis(50))
.await;
let Ok(task) = this.update(cx, |this, cx| {
fn number_input_value<T: FromStr + Default>(
input: &Entity<SingleLineInput>,
cx: &App,
) -> T {
input
.read(cx)
.editor()
.read(cx)
.text(cx)
.parse::<T>()
.unwrap_or_default()
}
let options = EditPredictionExcerptOptions {
max_bytes: number_input_value(&this.max_bytes_input, cx),
min_bytes: number_input_value(&this.min_bytes_input, cx),
target_before_cursor_over_total_bytes: number_input_value(
&this.cursor_context_ratio_input,
cx,
),
// TODO Display and add to options
include_parent_signatures: false,
};
EditPredictionContext::gather(
cursor_position,
current_buffer_snapshot,
options,
this.syntax_index.clone(),
cx,
)
}) else {
this.update(cx, |this, _cx| {
this.last_context.take();
})
.ok();
return;
};
let Some(context) = task.await else {
// TODO: Display message
this.update(cx, |this, _cx| {
this.last_context.take();
})
.ok();
return;
};
let mut languages = HashMap::default();
for snippet in context.snippets.iter() {
let lang_id = snippet.declaration.identifier().language_id;
if let Entry::Vacant(entry) = languages.entry(lang_id) {
// Most snippets are gonna be the same language,
// so we think it's fine to do this sequentially for now
entry.insert(language_registry.language_for_id(lang_id).await.ok());
}
}
this.update_in(cx, |this, window, cx| {
let context_editor = cx.new(|cx| {
let multibuffer = cx.new(|cx| {
let mut multibuffer = MultiBuffer::new(language::Capability::ReadOnly);
let excerpt_file = Arc::new(ExcerptMetadataFile {
title: PathBuf::from("Cursor Excerpt").into(),
worktree_id,
});
let excerpt_buffer = cx.new(|cx| {
let mut buffer = Buffer::local(context.excerpt_text.body, cx);
buffer.set_language(language, cx);
buffer.file_updated(excerpt_file, cx);
buffer
});
multibuffer.push_excerpts(
excerpt_buffer,
[ExcerptRange::new(text::Anchor::MIN..text::Anchor::MAX)],
cx,
);
for snippet in context.snippets {
let path = this
.project
.read(cx)
.path_for_entry(snippet.declaration.project_entry_id(), cx);
let snippet_file = Arc::new(ExcerptMetadataFile {
title: PathBuf::from(format!(
"{} (Score density: {})",
path.map(|p| p.path.to_string_lossy().to_string())
.unwrap_or_else(|| "".to_string()),
snippet.score_density(SnippetStyle::Declaration)
))
.into(),
worktree_id,
});
let excerpt_buffer = cx.new(|cx| {
let mut buffer =
Buffer::local(snippet.declaration.item_text().0, cx);
buffer.file_updated(snippet_file, cx);
if let Some(language) =
languages.get(&snippet.declaration.identifier().language_id)
{
buffer.set_language(language.clone(), cx);
}
buffer
});
multibuffer.push_excerpts(
excerpt_buffer,
[ExcerptRange::new(text::Anchor::MIN..text::Anchor::MAX)],
cx,
);
}
multibuffer
});
Editor::new(EditorMode::full(), multibuffer, None, window, cx)
});
this.last_context = Some(ContextState {
context_editor,
retrieval_duration: context.retrieval_duration,
});
cx.notify();
})
.ok();
}
});
}
}
impl Focusable for EditPredictionTools {
fn focus_handle(&self, _cx: &App) -> FocusHandle {
self.focus_handle.clone()
}
}
impl Item for EditPredictionTools {
type Event = ();
fn tab_content_text(&self, _detail: usize, _cx: &App) -> SharedString {
"Edit Prediction Context Debug View".into()
}
fn tab_icon(&self, _window: &Window, _cx: &App) -> Option<Icon> {
Some(Icon::new(IconName::ZedPredict))
}
}
impl EventEmitter<()> for EditPredictionTools {}
impl Render for EditPredictionTools {
fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
v_flex()
.size_full()
.bg(cx.theme().colors().editor_background)
.child(
h_flex()
.items_start()
.w_full()
.child(
v_flex()
.flex_1()
.p_4()
.gap_2()
.child(Headline::new("Excerpt Options").size(HeadlineSize::Small))
.child(
h_flex()
.gap_2()
.child(self.max_bytes_input.clone())
.child(self.min_bytes_input.clone())
.child(self.cursor_context_ratio_input.clone()),
),
)
.child(ui::Divider::vertical())
.when_some(self.last_context.as_ref(), |this, last_context| {
this.child(
v_flex()
.p_4()
.gap_2()
.min_w(px(160.))
.child(Headline::new("Stats").size(HeadlineSize::Small))
.child(
h_flex()
.gap_1()
.child(
Label::new("Time to retrieve")
.color(Color::Muted)
.size(LabelSize::Small),
)
.child(
Label::new(
if last_context.retrieval_duration.as_micros()
> 1000
{
format!(
"{} ms",
last_context.retrieval_duration.as_millis()
)
} else {
format!(
"{} µs",
last_context.retrieval_duration.as_micros()
)
},
)
.size(LabelSize::Small),
),
),
)
}),
)
.children(self.last_context.as_ref().map(|c| c.context_editor.clone()))
}
}
// Using same approach as commit view
struct ExcerptMetadataFile {
title: Arc<Path>,
worktree_id: WorktreeId,
}
impl language::File for ExcerptMetadataFile {
fn as_local(&self) -> Option<&dyn language::LocalFile> {
None
}
fn disk_state(&self) -> DiskState {
DiskState::New
}
fn path(&self) -> &Arc<Path> {
&self.title
}
fn full_path(&self, _: &App) -> PathBuf {
self.title.as_ref().into()
}
fn file_name<'a>(&'a self, _: &'a App) -> &'a OsStr {
self.title.file_name().unwrap()
}
fn worktree_id(&self, _: &App) -> WorktreeId {
self.worktree_id
}
fn to_proto(&self, _: &App) -> language::proto::File {
unimplemented!()
}
fn is_private(&self) -> bool {
false
}
}

View File

@@ -54,6 +54,17 @@ pub enum EditPredictionProvider {
Zed,
}
impl EditPredictionProvider {
pub fn is_zed(&self) -> bool {
match self {
EditPredictionProvider::Zed => true,
EditPredictionProvider::None
| EditPredictionProvider::Copilot
| EditPredictionProvider::Supermaven => false,
}
}
}
/// The contents of the edit prediction settings.
#[skip_serializing_none]
#[derive(Clone, Debug, Default, Serialize, Deserialize, JsonSchema, PartialEq)]

View File

@@ -52,7 +52,7 @@ debugger_tools.workspace = true
debugger_ui.workspace = true
diagnostics.workspace = true
editor.workspace = true
edit_prediction_tools.workspace = true
zeta2_tools.workspace = true
env_logger.workspace = true
extension.workspace = true
extension_host.workspace = true
@@ -163,6 +163,7 @@ workspace.workspace = true
zed_actions.workspace = true
zed_env_vars.workspace = true
zeta.workspace = true
zeta2.workspace = true
zlog.workspace = true
zlog_settings.workspace = true

View File

@@ -549,7 +549,7 @@ pub fn main() {
language_models::init(app_state.user_store.clone(), app_state.client.clone(), cx);
agent_settings::init(cx);
acp_tools::init(cx);
edit_prediction_tools::init(cx);
zeta2_tools::init(cx);
web_search::init(cx);
web_search_providers::init(app_state.client.clone(), cx);
snippet_provider::init(cx);

View File

@@ -203,21 +203,43 @@ fn assign_edit_prediction_provider(
}
}
let zeta = zeta::Zeta::register(worktree, client.clone(), user_store, cx);
if let Some(buffer) = &singleton_buffer
&& buffer.read(cx).file().is_some()
&& let Some(project) = editor.project()
{
zeta.update(cx, |zeta, cx| {
zeta.register_buffer(buffer, project, cx);
if std::env::var("ZED_ZETA2").is_ok() {
let zeta = zeta2::Zeta::global(client, &user_store, cx);
let provider = cx.new(|cx| {
zeta2::ZetaEditPredictionProvider::new(
editor.project(),
&client,
&user_store,
cx,
)
});
if let Some(buffer) = &singleton_buffer
&& buffer.read(cx).file().is_some()
&& let Some(project) = editor.project()
{
zeta.update(cx, |zeta, cx| {
zeta.register_buffer(buffer, project, cx);
});
}
editor.set_edit_prediction_provider(Some(provider), window, cx);
} else {
let zeta = zeta::Zeta::register(worktree, client.clone(), user_store, cx);
if let Some(buffer) = &singleton_buffer
&& buffer.read(cx).file().is_some()
&& let Some(project) = editor.project()
{
zeta.update(cx, |zeta, cx| {
zeta.register_buffer(buffer, project, cx);
});
}
let provider =
cx.new(|_| zeta::ZetaEditPredictionProvider::new(zeta, singleton_buffer));
editor.set_edit_prediction_provider(Some(provider), window, cx);
}
let provider =
cx.new(|_| zeta::ZetaEditPredictionProvider::new(zeta, singleton_buffer));
editor.set_edit_prediction_provider(Some(provider), window, cx);
}
}
}
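
In effect, the new provider is opt-in: when the ZED_ZETA2 environment variable is set to any value (the check is only std::env::var("ZED_ZETA2").is_ok(), so e.g. ZED_ZETA2=1 suffices), the editor builds a zeta2::ZetaEditPredictionProvider and registers the buffer with the global zeta2::Zeta; otherwise the existing zeta provider path is used unchanged.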

crates/zeta2/Cargo.toml (new file, 38 lines)
View File

@@ -0,0 +1,38 @@
[package]
name = "zeta2"
version = "0.1.0"
edition.workspace = true
publish.workspace = true
license = "GPL-3.0-or-later"
[lints]
workspace = true
[lib]
path = "src/zeta2.rs"
[dependencies]
anyhow.workspace = true
arrayvec.workspace = true
chrono.workspace = true
client.workspace = true
cloud_llm_client.workspace = true
edit_prediction.workspace = true
edit_prediction_context.workspace = true
futures.workspace = true
gpui.workspace = true
language.workspace = true
language_model.workspace = true
log.workspace = true
project.workspace = true
release_channel.workspace = true
serde_json.workspace = true
thiserror.workspace = true
util.workspace = true
uuid.workspace = true
workspace-hack.workspace = true
workspace.workspace = true
worktree.workspace = true
[dev-dependencies]
gpui = { workspace = true, features = ["test-support"] }

crates/zeta2/src/zeta2.rs (new file, 1213 lines)

File diff suppressed because it is too large.

View File

@@ -1,5 +1,5 @@
[package]
name = "edit_prediction_tools"
name = "zeta2_tools"
version = "0.1.0"
edition.workspace = true
publish.workspace = true
@@ -9,15 +9,19 @@ license = "GPL-3.0-or-later"
workspace = true
[lib]
path = "src/edit_prediction_tools.rs"
path = "src/zeta2_tools.rs"
[dependencies]
edit_prediction_context.workspace = true
chrono.workspace = true
client.workspace = true
collections.workspace = true
edit_prediction_context.workspace = true
editor.workspace = true
futures.workspace = true
gpui.workspace = true
language.workspace = true
log.workspace = true
markdown.workspace = true
project.workspace = true
serde.workspace = true
text.workspace = true
@@ -25,17 +29,17 @@ ui.workspace = true
ui_input.workspace = true
workspace-hack.workspace = true
workspace.workspace = true
zeta2.workspace = true
[dev-dependencies]
clap.workspace = true
futures.workspace = true
gpui = { workspace = true, features = ["test-support"] }
indoc.workspace = true
language = { workspace = true, features = ["test-support"] }
pretty_assertions.workspace = true
project = {workspace= true, features = ["test-support"]}
project = { workspace = true, features = ["test-support"] }
serde_json.workspace = true
settings = {workspace= true, features = ["test-support"]}
settings = { workspace = true, features = ["test-support"] }
text = { workspace = true, features = ["test-support"] }
util = { workspace = true, features = ["test-support"] }
zlog.workspace = true

View File

@@ -0,0 +1 @@
../../LICENSE-GPL

View File

@@ -0,0 +1,605 @@
use std::{
collections::hash_map::Entry,
ffi::OsStr,
path::{Path, PathBuf},
sync::Arc,
};
use chrono::TimeDelta;
use client::{Client, UserStore};
use collections::HashMap;
use editor::{Editor, EditorMode, ExcerptRange, MultiBuffer};
use futures::StreamExt as _;
use gpui::{
BorderStyle, EdgesRefinement, Entity, EventEmitter, FocusHandle, Focusable, Length,
StyleRefinement, Subscription, Task, TextStyleRefinement, UnderlineStyle, actions, prelude::*,
};
use language::{Buffer, DiskState};
use markdown::{HeadingLevelStyles, Markdown, MarkdownElement, MarkdownStyle};
use project::{Project, WorktreeId};
use ui::prelude::*;
use ui_input::SingleLineInput;
use workspace::{Item, SplitDirection, Workspace};
use zeta2::Zeta;
use edit_prediction_context::SnippetStyle;
actions!(
dev,
[
/// Opens the zeta2 edit prediction inspector.
OpenZeta2Inspector
]
);
pub fn init(cx: &mut App) {
cx.observe_new(move |workspace: &mut Workspace, _, _cx| {
workspace.register_action(move |workspace, _: &OpenZeta2Inspector, window, cx| {
let project = workspace.project();
workspace.split_item(
SplitDirection::Right,
Box::new(cx.new(|cx| {
EditPredictionTools::new(
&project,
workspace.client(),
workspace.user_store(),
window,
cx,
)
})),
window,
cx,
);
});
})
.detach();
}
pub struct EditPredictionTools {
focus_handle: FocusHandle,
project: Entity<Project>,
last_prediction: Option<Result<LastPredictionState, SharedString>>,
max_bytes_input: Entity<SingleLineInput>,
min_bytes_input: Entity<SingleLineInput>,
cursor_context_ratio_input: Entity<SingleLineInput>,
active_view: ActiveView,
_active_editor_subscription: Option<Subscription>,
_update_state_task: Task<()>,
_receive_task: Task<()>,
}
#[derive(PartialEq)]
enum ActiveView {
Context,
Inference,
}
struct LastPredictionState {
context_editor: Entity<Editor>,
retrieval_time: TimeDelta,
prompt_planning_time: TimeDelta,
inference_time: TimeDelta,
parsing_time: TimeDelta,
prompt_md: Entity<Markdown>,
model_response_md: Entity<Markdown>,
}
impl EditPredictionTools {
pub fn new(
project: &Entity<Project>,
client: &Arc<Client>,
user_store: &Entity<UserStore>,
window: &mut Window,
cx: &mut Context<Self>,
) -> Self {
let number_input = |label: &'static str,
value: &'static str,
window: &mut Window,
cx: &mut Context<Self>|
-> Entity<SingleLineInput> {
let input = cx.new(|cx| {
let input = SingleLineInput::new(window, cx, "")
.label(label)
.label_min_width(px(64.));
input.set_text(value, window, cx);
input
});
// todo!
// cx.subscribe_in(
// &input.read(cx).editor().clone(),
// window,
// |this, _, event, window, cx| {
// if let EditorEvent::BufferEdited = event
// && let Some(editor) = this.last_editor.upgrade()
// {
// this.update_context(&editor, window, cx);
// }
// },
// )
// .detach();
input
};
let zeta = Zeta::global(client, user_store, cx);
let mut request_rx = zeta.update(cx, |zeta, _cx| zeta.debug_info());
let receive_task = cx.spawn_in(window, async move |this, cx| {
while let Some(prediction_result) = request_rx.next().await {
this.update_in(cx, |this, window, cx| match prediction_result {
Ok(prediction) => {
this.update_last_prediction(prediction, window, cx);
}
Err(err) => {
this.last_prediction = Some(Err(err.into()));
cx.notify();
}
})
.ok();
}
});
Self {
focus_handle: cx.focus_handle(),
project: project.clone(),
last_prediction: None,
active_view: ActiveView::Context,
max_bytes_input: number_input("Max Bytes", "512", window, cx),
min_bytes_input: number_input("Min Bytes", "128", window, cx),
cursor_context_ratio_input: number_input("Cursor Context Ratio", "0.5", window, cx),
_active_editor_subscription: None,
_update_state_task: Task::ready(()),
_receive_task: receive_task,
}
}
fn update_last_prediction(
&mut self,
prediction: zeta2::PredictionDebugInfo,
window: &mut Window,
cx: &mut Context<Self>,
) {
let Some(worktree_id) = self
.project
.read(cx)
.worktrees(cx)
.next()
.map(|worktree| worktree.read(cx).id())
else {
log::error!("Open a worktree to use edit prediction debug view");
self.last_prediction.take();
return;
};
self._update_state_task = cx.spawn_in(window, {
let language_registry = self.project.read(cx).languages().clone();
async move |this, cx| {
// fn number_input_value<T: FromStr + Default>(
// input: &Entity<SingleLineInput>,
// cx: &App,
// ) -> T {
// input
// .read(cx)
// .editor()
// .read(cx)
// .text(cx)
// .parse::<T>()
// .unwrap_or_default()
// }
// let options = EditPredictionExcerptOptions {
// max_bytes: number_input_value(&this.max_bytes_input, cx),
// min_bytes: number_input_value(&this.min_bytes_input, cx),
// target_before_cursor_over_total_bytes: number_input_value(
// &this.cursor_context_ratio_input,
// cx,
// ),
// };
let mut languages = HashMap::default();
for lang_id in prediction
.context
.snippets
.iter()
.map(|snippet| snippet.declaration.identifier().language_id)
.chain(prediction.context.excerpt_text.language_id)
{
if let Entry::Vacant(entry) = languages.entry(lang_id) {
// Most snippets are gonna be the same language,
// so we think it's fine to do this sequentially for now
entry.insert(language_registry.language_for_id(lang_id).await.ok());
}
}
this.update_in(cx, |this, window, cx| {
let context_editor = cx.new(|cx| {
let multibuffer = cx.new(|cx| {
let mut multibuffer = MultiBuffer::new(language::Capability::ReadOnly);
let excerpt_file = Arc::new(ExcerptMetadataFile {
title: PathBuf::from("Cursor Excerpt").into(),
worktree_id,
});
let excerpt_buffer = cx.new(|cx| {
let mut buffer =
Buffer::local(prediction.context.excerpt_text.body, cx);
if let Some(language) = prediction
.context
.excerpt_text
.language_id
.as_ref()
.and_then(|id| languages.get(id))
{
buffer.set_language(language.clone(), cx);
}
buffer.file_updated(excerpt_file, cx);
buffer
});
multibuffer.push_excerpts(
excerpt_buffer,
[ExcerptRange::new(text::Anchor::MIN..text::Anchor::MAX)],
cx,
);
for snippet in &prediction.context.snippets {
let path = this
.project
.read(cx)
.path_for_entry(snippet.declaration.project_entry_id(), cx);
let snippet_file = Arc::new(ExcerptMetadataFile {
title: PathBuf::from(format!(
"{} (Score density: {})",
path.map(|p| p.path.to_string_lossy().to_string())
.unwrap_or_else(|| "".to_string()),
snippet.score_density(SnippetStyle::Declaration)
))
.into(),
worktree_id,
});
let excerpt_buffer = cx.new(|cx| {
let mut buffer =
Buffer::local(snippet.declaration.item_text().0, cx);
buffer.file_updated(snippet_file, cx);
if let Some(language) =
languages.get(&snippet.declaration.identifier().language_id)
{
buffer.set_language(language.clone(), cx);
}
buffer
});
multibuffer.push_excerpts(
excerpt_buffer,
[ExcerptRange::new(text::Anchor::MIN..text::Anchor::MAX)],
cx,
);
}
multibuffer
});
Editor::new(EditorMode::full(), multibuffer, None, window, cx)
});
this.last_prediction = Some(Ok(LastPredictionState {
context_editor,
prompt_md: cx.new(|cx| {
Markdown::new(prediction.request.prompt.into(), None, None, cx)
}),
model_response_md: cx.new(|cx| {
Markdown::new(prediction.request.model_response.into(), None, None, cx)
}),
retrieval_time: prediction.retrieval_time,
prompt_planning_time: prediction.request.prompt_planning_time,
inference_time: prediction.request.inference_time,
parsing_time: prediction.request.parsing_time,
}));
cx.notify();
})
.ok();
}
});
}
fn render_duration(name: &'static str, time: chrono::TimeDelta) -> Div {
h_flex()
.gap_1()
.child(Label::new(name).color(Color::Muted).size(LabelSize::Small))
.child(
Label::new(if time.num_microseconds().unwrap_or(0) > 1000 {
format!("{} ms", time.num_milliseconds())
} else {
format!("{} µs", time.num_microseconds().unwrap_or(0))
})
.size(LabelSize::Small),
)
}
}
impl Focusable for EditPredictionTools {
fn focus_handle(&self, _cx: &App) -> FocusHandle {
self.focus_handle.clone()
}
}
impl Item for EditPredictionTools {
type Event = ();
fn tab_content_text(&self, _detail: usize, _cx: &App) -> SharedString {
"Zeta2 Inspector".into()
}
}
impl EventEmitter<()> for EditPredictionTools {}
impl Render for EditPredictionTools {
fn render(&mut self, window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
v_flex()
.size_full()
.bg(cx.theme().colors().editor_background)
.child(
h_flex()
.items_start()
.w_full()
.child(
v_flex()
.flex_1()
.p_4()
.gap_2()
.child(Headline::new("Excerpt Options").size(HeadlineSize::Small))
.child(
h_flex()
.gap_2()
.child(self.max_bytes_input.clone())
.child(self.min_bytes_input.clone())
.child(self.cursor_context_ratio_input.clone()),
)
.child(div().flex_1())
.when(
self.last_prediction.as_ref().is_some_and(|r| r.is_ok()),
|this| {
this.child(
ui::ToggleButtonGroup::single_row(
"prediction",
[
ui::ToggleButtonSimple::new(
"Context",
cx.listener(|this, _, _, cx| {
this.active_view = ActiveView::Context;
cx.notify();
}),
)
.selected(self.active_view == ActiveView::Context),
ui::ToggleButtonSimple::new(
"Inference",
cx.listener(|this, _, _, cx| {
this.active_view = ActiveView::Inference;
cx.notify();
}),
)
.selected(self.active_view == ActiveView::Inference),
],
)
.style(ui::ToggleButtonGroupStyle::Outlined),
)
},
),
)
.child(ui::vertical_divider())
.when_some(
self.last_prediction.as_ref().and_then(|r| r.as_ref().ok()),
|this, last_prediction| {
this.child(
v_flex()
.p_4()
.gap_2()
.min_w(px(160.))
.child(Headline::new("Stats").size(HeadlineSize::Small))
.child(Self::render_duration(
"Context retrieval",
last_prediction.retrieval_time,
))
.child(Self::render_duration(
"Prompt planning",
last_prediction.prompt_planning_time,
))
.child(Self::render_duration(
"Inference",
last_prediction.inference_time,
))
.child(Self::render_duration(
"Parsing",
last_prediction.parsing_time,
)),
)
},
),
)
.children(self.last_prediction.as_ref().map(|result| {
match result {
Ok(state) => match &self.active_view {
ActiveView::Context => state.context_editor.clone().into_any_element(),
ActiveView::Inference => h_flex()
.items_start()
.w_full()
.gap_2()
.bg(cx.theme().colors().editor_background)
// todo! fix layout
.child(
v_flex()
.flex_1()
.p_4()
.gap_2()
.child(
ui::Headline::new("Prompt").size(ui::HeadlineSize::Small),
)
.child(MarkdownElement::new(
state.prompt_md.clone(),
markdown_style(window, cx),
)),
)
.child(ui::vertical_divider())
.child(
v_flex()
.flex_1()
.p_4()
.gap_2()
.child(
ui::Headline::new("Model Response")
.size(ui::HeadlineSize::Small),
)
.child(MarkdownElement::new(
state.model_response_md.clone(),
markdown_style(window, cx),
)),
)
.into_any(),
},
Err(err) => v_flex()
.p_4()
.gap_2()
.child(Label::new(err.clone()).buffer_font(cx))
.into_any(),
}
}))
}
}
// Mostly copied from agent-ui
fn markdown_style(window: &Window, cx: &App) -> MarkdownStyle {
let colors = cx.theme().colors();
let buffer_font_size = TextSize::Small.rems(cx);
let mut text_style = window.text_style();
let line_height = buffer_font_size * 1.75;
let font_size = TextSize::Small.rems(cx);
let text_color = colors.text;
text_style.refine(&TextStyleRefinement {
font_size: Some(font_size.into()),
line_height: Some(line_height.into()),
color: Some(text_color),
..Default::default()
});
MarkdownStyle {
base_text_style: text_style.clone(),
syntax: cx.theme().syntax().clone(),
selection_background_color: colors.element_selection_background,
code_block_overflow_x_scroll: true,
table_overflow_x_scroll: true,
heading_level_styles: Some(HeadingLevelStyles {
h1: Some(TextStyleRefinement {
font_size: Some(rems(1.15).into()),
..Default::default()
}),
h2: Some(TextStyleRefinement {
font_size: Some(rems(1.1).into()),
..Default::default()
}),
h3: Some(TextStyleRefinement {
font_size: Some(rems(1.05).into()),
..Default::default()
}),
h4: Some(TextStyleRefinement {
font_size: Some(rems(1.).into()),
..Default::default()
}),
h5: Some(TextStyleRefinement {
font_size: Some(rems(0.95).into()),
..Default::default()
}),
h6: Some(TextStyleRefinement {
font_size: Some(rems(0.875).into()),
..Default::default()
}),
}),
code_block: StyleRefinement {
padding: EdgesRefinement {
top: Some(DefiniteLength::Absolute(AbsoluteLength::Pixels(Pixels(8.)))),
left: Some(DefiniteLength::Absolute(AbsoluteLength::Pixels(Pixels(8.)))),
right: Some(DefiniteLength::Absolute(AbsoluteLength::Pixels(Pixels(8.)))),
bottom: Some(DefiniteLength::Absolute(AbsoluteLength::Pixels(Pixels(8.)))),
},
margin: EdgesRefinement {
top: Some(Length::Definite(Pixels(8.).into())),
left: Some(Length::Definite(Pixels(0.).into())),
right: Some(Length::Definite(Pixels(0.).into())),
bottom: Some(Length::Definite(Pixels(12.).into())),
},
border_style: Some(BorderStyle::Solid),
border_widths: EdgesRefinement {
top: Some(AbsoluteLength::Pixels(Pixels(1.))),
left: Some(AbsoluteLength::Pixels(Pixels(1.))),
right: Some(AbsoluteLength::Pixels(Pixels(1.))),
bottom: Some(AbsoluteLength::Pixels(Pixels(1.))),
},
border_color: Some(colors.border_variant),
background: Some(colors.editor_background.into()),
text: Some(TextStyleRefinement {
font_size: Some(buffer_font_size.into()),
..Default::default()
}),
..Default::default()
},
inline_code: TextStyleRefinement {
font_size: Some(buffer_font_size.into()),
background_color: Some(colors.editor_foreground.opacity(0.08)),
..Default::default()
},
link: TextStyleRefinement {
background_color: Some(colors.editor_foreground.opacity(0.025)),
underline: Some(UnderlineStyle {
color: Some(colors.text_accent.opacity(0.5)),
thickness: px(1.),
..Default::default()
}),
..Default::default()
},
..Default::default()
}
}
// Using same approach as commit view
struct ExcerptMetadataFile {
title: Arc<Path>,
worktree_id: WorktreeId,
}
impl language::File for ExcerptMetadataFile {
fn as_local(&self) -> Option<&dyn language::LocalFile> {
None
}
fn disk_state(&self) -> DiskState {
DiskState::New
}
fn path(&self) -> &Arc<Path> {
&self.title
}
fn full_path(&self, _: &App) -> PathBuf {
self.title.as_ref().into()
}
fn file_name<'a>(&'a self, _: &'a App) -> &'a OsStr {
self.title.file_name().unwrap()
}
fn worktree_id(&self, _: &App) -> WorktreeId {
self.worktree_id
}
fn to_proto(&self, _: &App) -> language::proto::File {
unimplemented!()
}
fn is_private(&self) -> bool {
false
}
}