Compare commits


13 Commits

Author SHA1 Message Date
Michael Sloan
4a4ee4fed7 Remove cli example 2025-09-17 18:13:51 -06:00
Michael Sloan
ea4bf46a36 Return 0 results when declaration count limit exceeded
Co-authored-by: Agus <agus@zed.dev>
2025-09-17 18:10:32 -06:00
Michael Sloan
05545abab6 Checkpoint
Co-authored-by: Agus <agus@zed.dev>
2025-09-17 18:10:32 -06:00
Michael Sloan
a85608566d Checkpoint
Co-authored-by: Agus Zubiaga <hi@aguz.me>
2025-09-17 17:07:30 -06:00
Michael Sloan
69af5261ea Renames + fixes
Co-authored-by: Agus <agus@zed.dev>
2025-09-17 15:02:57 -06:00
Michael Sloan
b9e2f61a38 Expand declaration ranges to line boundaries and truncate, store text for file declarations
Co-authored-by: Agus <agus@zed.dev>
2025-09-17 15:00:29 -06:00
Michael Sloan
38bbb497dd Rename definition->declaration
Co-authored-by: Agus Zubiaga <hi@aguz.me>
2025-09-17 12:57:49 -06:00
Agus Zubiaga
0cc7b4a93c Simple call site snippet test 2025-09-17 12:47:36 -03:00
Agus
cc32bfdfdf Checkpoint: Get score_snippets to compile
Co-Authored-By: Finn <finn@zed.dev>
2025-09-17 11:48:22 -03:00
Michael Sloan
50de8ddc28 Progress on porting scored_declaration.rs 2025-09-17 02:11:58 -06:00
Michael Sloan
f770011d7f Add WIP zeta2 request types 2025-09-17 01:56:48 -06:00
Michael Sloan
f2a6b57909 Copy in experimental cli / declaration scoring code
Co-authored-by: Oleksiy <oleksiy@zed.dev>
2025-09-17 01:55:06 -06:00
Michael Sloan
96b67ac70e Add text similarity metrics 2025-09-17 01:38:02 -06:00
11 changed files with 1354 additions and 320 deletions

Cargo.lock generated
View File

@@ -5140,17 +5140,23 @@ version = "0.1.0"
dependencies = [
"anyhow",
"arrayvec",
"clap",
"collections",
"futures 0.3.31",
"gpui",
"indoc",
"itertools 0.14.0",
"language",
"log",
"ordered-float 2.10.1",
"pretty_assertions",
"project",
"regex",
"serde",
"serde_json",
"settings",
"slotmap",
"strum 0.27.1",
"text",
"tree-sitter",
"util",

crates/edit_prediction_context/Cargo.toml
View File

@@ -15,17 +15,24 @@ path = "src/edit_prediction_context.rs"
anyhow.workspace = true
arrayvec.workspace = true
collections.workspace = true
futures.workspace = true
gpui.workspace = true
itertools.workspace = true
language.workspace = true
log.workspace = true
ordered-float.workspace = true
project.workspace = true
regex.workspace = true
serde.workspace = true
slotmap.workspace = true
strum.workspace = true
text.workspace = true
tree-sitter.workspace = true
util.workspace = true
workspace-hack.workspace = true
[dev-dependencies]
clap.workspace = true
futures.workspace = true
gpui = { workspace = true, features = ["test-support"] }
indoc.workspace = true

crates/edit_prediction_context/src/declaration.rs
View File

@@ -0,0 +1,193 @@
use language::LanguageId;
use project::ProjectEntryId;
use std::borrow::Cow;
use std::ops::Range;
use std::sync::Arc;
use text::{Bias, BufferId, Rope};
use crate::outline::OutlineDeclaration;
#[derive(Debug, Clone, Eq, PartialEq, Hash)]
pub struct Identifier {
pub name: Arc<str>,
pub language_id: LanguageId,
}
slotmap::new_key_type! {
pub struct DeclarationId;
}
#[derive(Debug, Clone)]
pub enum Declaration {
File {
project_entry_id: ProjectEntryId,
declaration: FileDeclaration,
},
Buffer {
project_entry_id: ProjectEntryId,
buffer_id: BufferId,
rope: Rope,
declaration: BufferDeclaration,
},
}
const ITEM_TEXT_TRUNCATION_LENGTH: usize = 1024;
impl Declaration {
pub fn identifier(&self) -> &Identifier {
match self {
Declaration::File { declaration, .. } => &declaration.identifier,
Declaration::Buffer { declaration, .. } => &declaration.identifier,
}
}
pub fn project_entry_id(&self) -> Option<ProjectEntryId> {
match self {
Declaration::File {
project_entry_id, ..
} => Some(*project_entry_id),
Declaration::Buffer {
project_entry_id, ..
} => Some(*project_entry_id),
}
}
pub fn item_text(&self) -> (Cow<'_, str>, bool) {
match self {
Declaration::File { declaration, .. } => (
declaration.text.as_ref().into(),
declaration.text_is_truncated,
),
Declaration::Buffer {
rope, declaration, ..
} => (
rope.chunks_in_range(declaration.item_range.clone())
.collect::<Cow<str>>(),
declaration.item_range_is_truncated,
),
}
}
pub fn signature_text(&self) -> (Cow<'_, str>, bool) {
match self {
Declaration::File { declaration, .. } => (
declaration.text[declaration.signature_range_in_text.clone()].into(),
declaration.signature_is_truncated,
),
Declaration::Buffer {
rope, declaration, ..
} => (
rope.chunks_in_range(declaration.signature_range.clone())
.collect::<Cow<str>>(),
declaration.signature_range_is_truncated,
),
}
}
}
fn expand_range_to_line_boundaries_and_truncate(
range: &Range<usize>,
limit: usize,
rope: &Rope,
) -> (Range<usize>, bool) {
let mut point_range = rope.offset_to_point(range.start)..rope.offset_to_point(range.end);
point_range.start.column = 0;
point_range.end.row += 1;
point_range.end.column = 0;
let mut item_range =
rope.point_to_offset(point_range.start)..rope.point_to_offset(point_range.end);
let is_truncated = item_range.len() > limit;
if is_truncated {
item_range.end = item_range.start + limit;
}
item_range.end = rope.clip_offset(item_range.end, Bias::Left);
(item_range, is_truncated)
}
#[derive(Debug, Clone)]
pub struct FileDeclaration {
pub parent: Option<DeclarationId>,
pub identifier: Identifier,
/// offset range of the declaration in the file, expanded to line boundaries and truncated
pub item_range_in_file: Range<usize>,
/// text of `item_range_in_file`
pub text: Arc<str>,
/// whether `text` was truncated
pub text_is_truncated: bool,
/// offset range of the signature within `text`
pub signature_range_in_text: Range<usize>,
/// whether `signature` was truncated
pub signature_is_truncated: bool,
}
impl FileDeclaration {
pub fn from_outline(declaration: OutlineDeclaration, rope: &Rope) -> FileDeclaration {
let (item_range_in_file, text_is_truncated) = expand_range_to_line_boundaries_and_truncate(
&declaration.item_range,
ITEM_TEXT_TRUNCATION_LENGTH,
rope,
);
// TODO: consider logging if unexpected
let signature_start = declaration
.signature_range
.start
.saturating_sub(item_range_in_file.start);
let mut signature_end = declaration
.signature_range
.end
.saturating_sub(item_range_in_file.start);
let signature_is_truncated = signature_end > item_range_in_file.len();
if signature_is_truncated {
signature_end = item_range_in_file.len();
}
FileDeclaration {
parent: None,
identifier: declaration.identifier,
signature_range_in_text: signature_start..signature_end,
signature_is_truncated,
text: rope
.chunks_in_range(item_range_in_file.clone())
.collect::<String>()
.into(),
text_is_truncated,
item_range_in_file,
}
}
}
#[derive(Debug, Clone)]
pub struct BufferDeclaration {
pub parent: Option<DeclarationId>,
pub identifier: Identifier,
pub item_range: Range<usize>,
pub item_range_is_truncated: bool,
pub signature_range: Range<usize>,
pub signature_range_is_truncated: bool,
}
impl BufferDeclaration {
pub fn from_outline(declaration: OutlineDeclaration, rope: &Rope) -> Self {
let (item_range, item_range_is_truncated) = expand_range_to_line_boundaries_and_truncate(
&declaration.item_range,
ITEM_TEXT_TRUNCATION_LENGTH,
rope,
);
let (signature_range, signature_range_is_truncated) =
expand_range_to_line_boundaries_and_truncate(
&declaration.signature_range,
ITEM_TEXT_TRUNCATION_LENGTH,
rope,
);
Self {
parent: None,
identifier: declaration.identifier,
item_range,
item_range_is_truncated,
signature_range,
signature_range_is_truncated,
}
}
}
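The expansion helper above first snaps the range to whole lines, then caps its length at `ITEM_TEXT_TRUNCATION_LENGTH`. A minimal sketch of the same logic over a plain `&str`, assuming byte offsets (the `Rope` version above additionally clips the truncated end to a valid offset; `expand_and_truncate` is a hypothetical name for illustration):

```rust
use std::ops::Range;

/// Expand `range` to whole lines, then truncate the result to `limit` bytes.
/// Simplified model of `expand_range_to_line_boundaries_and_truncate`.
fn expand_and_truncate(text: &str, range: Range<usize>, limit: usize) -> (Range<usize>, bool) {
    // Walk back to the start of the line containing `range.start`.
    let start = text[..range.start].rfind('\n').map_or(0, |i| i + 1);
    // Walk forward past the newline that ends the line containing `range.end`.
    let end = text[range.end..]
        .find('\n')
        .map_or(text.len(), |i| range.end + i + 1);
    let is_truncated = end - start > limit;
    let end = if is_truncated { start + limit } else { end };
    (start..end, is_truncated)
}

fn main() {
    let text = "fn add(a: i32, b: i32) -> i32 {\n    a + b\n}\n";
    // A range covering only `a + b` expands to the full line.
    let (range, truncated) = expand_and_truncate(text, 36..41, 1024);
    assert_eq!(&text[range], "    a + b\n");
    assert!(!truncated);
}
```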

crates/edit_prediction_context/src/declaration_scoring.rs
View File

@@ -0,0 +1,324 @@
use itertools::Itertools as _;
use language::BufferSnapshot;
use ordered_float::OrderedFloat;
use serde::Serialize;
use std::{collections::HashMap, ops::Range};
use strum::EnumIter;
use text::{OffsetRangeExt, Point, ToPoint};
use crate::{
Declaration, EditPredictionExcerpt, EditPredictionExcerptText, Identifier,
reference::{Reference, ReferenceRegion},
syntax_index::SyntaxIndexState,
text_similarity::{IdentifierOccurrences, jaccard_similarity, weighted_overlap_coefficient},
};
const MAX_IDENTIFIER_DECLARATION_COUNT: usize = 16;
// TODO:
//
// * Consider adding declaration_file_count
#[derive(Clone, Debug)]
pub struct ScoredSnippet {
pub identifier: Identifier,
pub declaration: Declaration,
pub score_components: ScoreInputs,
pub scores: Scores,
}
// TODO: Consider having "Concise" style corresponding to `concise_text`
#[derive(EnumIter, Clone, Copy, PartialEq, Eq, Hash, Debug)]
pub enum SnippetStyle {
Signature,
Declaration,
}
impl ScoredSnippet {
/// Returns the score for this snippet with the specified style.
pub fn score(&self, style: SnippetStyle) -> f32 {
match style {
SnippetStyle::Signature => self.scores.signature,
SnippetStyle::Declaration => self.scores.declaration,
}
}
pub fn size(&self, style: SnippetStyle) -> usize {
// TODO: how to handle truncation?
match &self.declaration {
Declaration::File { declaration, .. } => match style {
SnippetStyle::Signature => declaration.signature_range_in_text.len(),
SnippetStyle::Declaration => declaration.text.len(),
},
Declaration::Buffer { declaration, .. } => match style {
SnippetStyle::Signature => declaration.signature_range.len(),
SnippetStyle::Declaration => declaration.item_range.len(),
},
}
}
pub fn score_density(&self, style: SnippetStyle) -> f32 {
self.score(style) / (self.size(style)) as f32
}
}
pub fn scored_snippets(
index: &SyntaxIndexState,
excerpt: &EditPredictionExcerpt,
excerpt_text: &EditPredictionExcerptText,
identifier_to_references: HashMap<Identifier, Vec<Reference>>,
cursor_offset: usize,
current_buffer: &BufferSnapshot,
) -> Vec<ScoredSnippet> {
let containing_range_identifier_occurrences =
IdentifierOccurrences::within_string(&excerpt_text.body);
let cursor_point = cursor_offset.to_point(&current_buffer);
let start_point = Point::new(cursor_point.row.saturating_sub(2), 0);
let end_point = Point::new(cursor_point.row + 1, 0);
let adjacent_identifier_occurrences = IdentifierOccurrences::within_string(
&current_buffer
.text_for_range(start_point..end_point)
.collect::<String>(),
);
let mut snippets = identifier_to_references
.into_iter()
.flat_map(|(identifier, references)| {
let declarations =
index.declarations_for_identifier::<MAX_IDENTIFIER_DECLARATION_COUNT>(&identifier);
let declaration_count = declarations.len();
declarations
.iter()
.filter_map(|declaration| match declaration {
Declaration::Buffer {
buffer_id,
declaration: buffer_declaration,
..
} => {
let is_same_file = buffer_id == &current_buffer.remote_id();
if is_same_file {
range_intersection(
&buffer_declaration.item_range.to_offset(&current_buffer),
&excerpt.range,
)
.is_none()
.then(|| {
let declaration_line = buffer_declaration
.item_range
.start
.to_point(current_buffer)
.row;
(
true,
(cursor_point.row as i32 - declaration_line as i32).abs()
as u32,
declaration,
)
})
} else {
Some((false, 0, declaration))
}
}
Declaration::File { .. } => {
// We can assume that a file declaration is in a different file,
// because the current one must be open
Some((false, 0, declaration))
}
})
.sorted_by_key(|&(_, distance, _)| distance)
.enumerate()
.map(
|(
declaration_line_distance_rank,
(is_same_file, declaration_line_distance, declaration),
)| {
let same_file_declaration_count = index.file_declaration_count(declaration);
score_snippet(
&identifier,
&references,
declaration.clone(),
is_same_file,
declaration_line_distance,
declaration_line_distance_rank,
same_file_declaration_count,
declaration_count,
&containing_range_identifier_occurrences,
&adjacent_identifier_occurrences,
cursor_point,
current_buffer,
)
},
)
.collect::<Vec<_>>()
})
.flatten()
.collect::<Vec<_>>();
snippets.sort_unstable_by_key(|snippet| {
OrderedFloat(
snippet
.score_density(SnippetStyle::Declaration)
.max(snippet.score_density(SnippetStyle::Signature)),
)
});
snippets
}
fn range_intersection<T: Ord + Clone>(a: &Range<T>, b: &Range<T>) -> Option<Range<T>> {
let start = a.start.clone().max(b.start.clone());
let end = a.end.clone().min(b.end.clone());
if start < end {
Some(Range { start, end })
} else {
None
}
}
fn score_snippet(
identifier: &Identifier,
references: &[Reference],
declaration: Declaration,
is_same_file: bool,
declaration_line_distance: u32,
declaration_line_distance_rank: usize,
same_file_declaration_count: usize,
declaration_count: usize,
containing_range_identifier_occurrences: &IdentifierOccurrences,
adjacent_identifier_occurrences: &IdentifierOccurrences,
cursor: Point,
current_buffer: &BufferSnapshot,
) -> Option<ScoredSnippet> {
let is_referenced_nearby = references
.iter()
.any(|r| r.region == ReferenceRegion::Nearby);
let is_referenced_in_breadcrumb = references
.iter()
.any(|r| r.region == ReferenceRegion::Breadcrumb);
let reference_count = references.len();
let reference_line_distance = references
.iter()
.map(|r| {
let reference_line = r.range.start.to_point(current_buffer).row as i32;
(cursor.row as i32 - reference_line).abs() as u32
})
.min()
.unwrap();
let item_source_occurrences = IdentifierOccurrences::within_string(&declaration.item_text().0);
let item_signature_occurrences =
IdentifierOccurrences::within_string(&declaration.signature_text().0);
let containing_range_vs_item_jaccard = jaccard_similarity(
containing_range_identifier_occurrences,
&item_source_occurrences,
);
let containing_range_vs_signature_jaccard = jaccard_similarity(
containing_range_identifier_occurrences,
&item_signature_occurrences,
);
let adjacent_vs_item_jaccard =
jaccard_similarity(adjacent_identifier_occurrences, &item_source_occurrences);
let adjacent_vs_signature_jaccard =
jaccard_similarity(adjacent_identifier_occurrences, &item_signature_occurrences);
let containing_range_vs_item_weighted_overlap = weighted_overlap_coefficient(
containing_range_identifier_occurrences,
&item_source_occurrences,
);
let containing_range_vs_signature_weighted_overlap = weighted_overlap_coefficient(
containing_range_identifier_occurrences,
&item_signature_occurrences,
);
let adjacent_vs_item_weighted_overlap =
weighted_overlap_coefficient(adjacent_identifier_occurrences, &item_source_occurrences);
let adjacent_vs_signature_weighted_overlap =
weighted_overlap_coefficient(adjacent_identifier_occurrences, &item_signature_occurrences);
let score_components = ScoreInputs {
is_same_file,
is_referenced_nearby,
is_referenced_in_breadcrumb,
reference_line_distance,
declaration_line_distance,
declaration_line_distance_rank,
reference_count,
same_file_declaration_count,
declaration_count,
containing_range_vs_item_jaccard,
containing_range_vs_signature_jaccard,
adjacent_vs_item_jaccard,
adjacent_vs_signature_jaccard,
containing_range_vs_item_weighted_overlap,
containing_range_vs_signature_weighted_overlap,
adjacent_vs_item_weighted_overlap,
adjacent_vs_signature_weighted_overlap,
};
Some(ScoredSnippet {
identifier: identifier.clone(),
declaration,
scores: score_components.score(),
score_components,
})
}
#[derive(Clone, Debug, Serialize)]
pub struct ScoreInputs {
pub is_same_file: bool,
pub is_referenced_nearby: bool,
pub is_referenced_in_breadcrumb: bool,
pub reference_count: usize,
pub same_file_declaration_count: usize,
pub declaration_count: usize,
pub reference_line_distance: u32,
pub declaration_line_distance: u32,
pub declaration_line_distance_rank: usize,
pub containing_range_vs_item_jaccard: f32,
pub containing_range_vs_signature_jaccard: f32,
pub adjacent_vs_item_jaccard: f32,
pub adjacent_vs_signature_jaccard: f32,
pub containing_range_vs_item_weighted_overlap: f32,
pub containing_range_vs_signature_weighted_overlap: f32,
pub adjacent_vs_item_weighted_overlap: f32,
pub adjacent_vs_signature_weighted_overlap: f32,
}
#[derive(Clone, Debug, Serialize)]
pub struct Scores {
pub signature: f32,
pub declaration: f32,
}
impl ScoreInputs {
fn score(&self) -> Scores {
// Score related to how likely this is the correct declaration, range 0 to 1
let accuracy_score = if self.is_same_file {
// TODO: use declaration_line_distance_rank
1.0 / self.same_file_declaration_count as f32
} else {
1.0 / self.declaration_count as f32
};
// Score related to the distance between the reference and cursor, range 0 to 1
let distance_score = if self.is_referenced_nearby {
1.0 / (1.0 + self.reference_line_distance as f32 / 10.0).powf(2.0)
} else {
// same score as ~4 lines away; rationale is to not overly penalize references from parent signatures
0.5
};
// For now, instead of a linear combination, the scores are just multiplied together.
let combined_score = 10.0 * accuracy_score * distance_score;
Scores {
signature: combined_score * self.containing_range_vs_signature_weighted_overlap,
// declaration score gets boosted both by being multiplied by 2 and by there being more
// weighted overlap.
declaration: 2.0 * combined_score * self.containing_range_vs_item_weighted_overlap,
}
}
}
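To make the combination concrete, here is a worked example with illustrative numbers (not taken from the source): a same-file declaration that is one of two candidates in its file, referenced 5 lines from the cursor, with hypothetical weighted overlaps of 0.3 (signature) and 0.4 (item):

```rust
fn main() {
    // Hypothetical inputs, for illustration only.
    let same_file_declaration_count = 2_usize;
    let reference_line_distance = 5_u32;

    // accuracy_score: 1 / candidate count = 0.5
    let accuracy_score = 1.0 / same_file_declaration_count as f32;

    // distance_score: 1 / (1 + d/10)^2 = 1 / 1.5^2 ~= 0.444
    let distance_score =
        1.0 / (1.0 + reference_line_distance as f32 / 10.0).powf(2.0);

    let combined_score = 10.0 * accuracy_score * distance_score; // ~2.22

    let signature = combined_score * 0.3; // ~0.67
    let declaration = 2.0 * combined_score * 0.4; // ~1.78, boosted by the 2x factor
    println!("signature: {signature:.2}, declaration: {declaration:.2}");
}
```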

crates/edit_prediction_context/src/edit_prediction_context.rs
View File

@@ -1,8 +1,220 @@
mod declaration;
mod declaration_scoring;
mod excerpt;
mod outline;
mod reference;
mod tree_sitter_index;
mod syntax_index;
mod text_similarity;
pub use declaration::{BufferDeclaration, Declaration, FileDeclaration, Identifier};
pub use excerpt::{EditPredictionExcerpt, EditPredictionExcerptOptions, EditPredictionExcerptText};
use gpui::{App, AppContext as _, Entity, Task};
use language::BufferSnapshot;
pub use reference::references_in_excerpt;
pub use tree_sitter_index::{BufferDeclaration, Declaration, FileDeclaration, TreeSitterIndex};
pub use syntax_index::SyntaxIndex;
use text::{Point, ToOffset as _};
use crate::declaration_scoring::{ScoredSnippet, scored_snippets};
pub struct EditPredictionContext {
pub excerpt: EditPredictionExcerpt,
pub excerpt_text: EditPredictionExcerptText,
pub snippets: Vec<ScoredSnippet>,
}
impl EditPredictionContext {
pub fn gather(
cursor_point: Point,
buffer: BufferSnapshot,
excerpt_options: EditPredictionExcerptOptions,
syntax_index: Entity<SyntaxIndex>,
cx: &mut App,
) -> Task<Self> {
let index_state = syntax_index.read_with(cx, |index, _cx| index.state().clone());
cx.background_spawn(async move {
let index_state = index_state.lock().await;
let excerpt =
EditPredictionExcerpt::select_from_buffer(cursor_point, &buffer, &excerpt_options)
.unwrap();
let excerpt_text = excerpt.text(&buffer);
let references = references_in_excerpt(&excerpt, &excerpt_text, &buffer);
let cursor_offset = cursor_point.to_offset(&buffer);
let snippets = scored_snippets(
&index_state,
&excerpt,
&excerpt_text,
references,
cursor_offset,
&buffer,
);
Self {
excerpt,
excerpt_text,
snippets,
}
})
}
}
#[cfg(test)]
mod tests {
use super::*;
use std::sync::Arc;
use gpui::{Entity, TestAppContext};
use indoc::indoc;
use language::{Language, LanguageConfig, LanguageId, LanguageMatcher, tree_sitter_rust};
use project::{FakeFs, Project};
use serde_json::json;
use settings::SettingsStore;
use util::path;
use crate::{EditPredictionExcerptOptions, SyntaxIndex};
#[gpui::test]
async fn test_call_site(cx: &mut TestAppContext) {
let (project, index, _rust_lang_id) = init_test(cx).await;
let buffer = project
.update(cx, |project, cx| {
let project_path = project.find_project_path("c.rs", cx).unwrap();
project.open_buffer(project_path, cx)
})
.await
.unwrap();
cx.run_until_parked();
// first process_data call site
let cursor_point = language::Point::new(8, 21);
let buffer_snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot());
let context = cx
.update(|cx| {
EditPredictionContext::gather(
cursor_point,
buffer_snapshot,
EditPredictionExcerptOptions {
max_bytes: 40,
min_bytes: 10,
target_before_cursor_over_total_bytes: 0.5,
include_parent_signatures: false,
},
index,
cx,
)
})
.await;
assert_eq!(context.snippets.len(), 1);
assert_eq!(context.snippets[0].identifier.name.as_ref(), "process_data");
drop(buffer);
}
async fn init_test(
cx: &mut TestAppContext,
) -> (Entity<Project>, Entity<SyntaxIndex>, LanguageId) {
cx.update(|cx| {
let settings_store = SettingsStore::test(cx);
cx.set_global(settings_store);
language::init(cx);
Project::init_settings(cx);
});
let fs = FakeFs::new(cx.executor());
fs.insert_tree(
path!("/root"),
json!({
"a.rs": indoc! {r#"
fn main() {
let x = 1;
let y = 2;
let z = add(x, y);
println!("Result: {}", z);
}
fn add(a: i32, b: i32) -> i32 {
a + b
}
"#},
"b.rs": indoc! {"
pub struct Config {
pub name: String,
pub value: i32,
}
impl Config {
pub fn new(name: String, value: i32) -> Self {
Config { name, value }
}
}
"},
"c.rs": indoc! {r#"
use std::collections::HashMap;
fn main() {
let args: Vec<String> = std::env::args().collect();
let data: Vec<i32> = args[1..]
.iter()
.filter_map(|s| s.parse().ok())
.collect();
let result = process_data(data);
println!("{:?}", result);
}
fn process_data(data: Vec<i32>) -> HashMap<i32, usize> {
let mut counts = HashMap::new();
for value in data {
*counts.entry(value).or_insert(0) += 1;
}
counts
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_process_data() {
let data = vec![1, 2, 2, 3];
let result = process_data(data);
assert_eq!(result.get(&2), Some(&2));
}
}
"#}
}),
)
.await;
let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
let language_registry = project.read_with(cx, |project, _| project.languages().clone());
let lang = rust_lang();
let lang_id = lang.id();
language_registry.add(Arc::new(lang));
let index = cx.new(|cx| SyntaxIndex::new(&project, cx));
cx.run_until_parked();
(project, index, lang_id)
}
fn rust_lang() -> Language {
Language::new(
LanguageConfig {
name: "Rust".into(),
matcher: LanguageMatcher {
path_suffixes: vec!["rs".to_string()],
..Default::default()
},
..Default::default()
},
Some(tree_sitter_rust::LANGUAGE.into()),
)
.with_highlights_query(include_str!("../../languages/src/rust/highlights.scm"))
.unwrap()
.with_outline_query(include_str!("../../languages/src/rust/outline.scm"))
.unwrap()
}
}

crates/edit_prediction_context/src/excerpt.rs
View File

@@ -31,7 +31,7 @@ pub struct EditPredictionExcerptOptions {
pub include_parent_signatures: bool,
}
#[derive(Clone)]
#[derive(Debug, Clone)]
pub struct EditPredictionExcerpt {
pub range: Range<usize>,
pub parent_signature_ranges: Vec<Range<usize>>,

crates/edit_prediction_context/src/outline.rs
View File

@@ -1,5 +1,7 @@
use language::{BufferSnapshot, LanguageId, SyntaxMapMatches};
use std::{cmp::Reverse, ops::Range, sync::Arc};
use language::{BufferSnapshot, SyntaxMapMatches};
use std::{cmp::Reverse, ops::Range};
use crate::declaration::Identifier;
// TODO:
//
@@ -18,12 +20,6 @@ pub struct OutlineDeclaration {
pub signature_range: Range<usize>,
}
#[derive(Debug, Clone, Eq, PartialEq, Hash)]
pub struct Identifier {
pub name: Arc<str>,
pub language_id: LanguageId,
}
pub fn declarations_in_buffer(buffer: &BufferSnapshot) -> Vec<OutlineDeclaration> {
declarations_overlapping_range(0..buffer.len(), buffer)
}

crates/edit_prediction_context/src/reference.rs
View File

@@ -3,8 +3,8 @@ use std::collections::HashMap;
use std::ops::Range;
use crate::{
declaration::Identifier,
excerpt::{EditPredictionExcerpt, EditPredictionExcerptText},
outline::Identifier,
};
#[derive(Debug)]

crates/edit_prediction_context/src/syntax_index.rs
View File

@@ -1,20 +1,26 @@
use std::sync::Arc;
use collections::{HashMap, HashSet};
use futures::lock::Mutex;
use gpui::{App, AppContext as _, Context, Entity, Task, WeakEntity};
use language::{Buffer, BufferEvent, BufferSnapshot};
use language::{Buffer, BufferEvent};
use project::buffer_store::{BufferStore, BufferStoreEvent};
use project::worktree_store::{WorktreeStore, WorktreeStoreEvent};
use project::{PathChange, Project, ProjectEntryId, ProjectPath};
use slotmap::SlotMap;
use std::ops::Range;
use std::sync::Arc;
use text::Anchor;
use text::BufferId;
use util::{ResultExt as _, debug_panic, some_or_debug_panic};
use crate::outline::{Identifier, OutlineDeclaration, declarations_in_buffer};
use crate::declaration::{
BufferDeclaration, Declaration, DeclarationId, FileDeclaration, Identifier,
};
use crate::outline::declarations_in_buffer;
// TODO:
//
// * Skip for remote projects
//
// * Consider making SyntaxIndex not an Entity.
// Potential future improvements:
//
@@ -34,17 +40,19 @@ use crate::outline::{Identifier, OutlineDeclaration, declarations_in_buffer};
// * Concurrent slotmap
//
// * Use queue for parsing
//
slotmap::new_key_type! {
pub struct DeclarationId;
pub struct SyntaxIndex {
state: Arc<Mutex<SyntaxIndexState>>,
project: WeakEntity<Project>,
}
pub struct TreeSitterIndex {
#[derive(Default)]
pub struct SyntaxIndexState {
declarations: SlotMap<DeclarationId, Declaration>,
identifiers: HashMap<Identifier, HashSet<DeclarationId>>,
files: HashMap<ProjectEntryId, FileState>,
buffers: HashMap<WeakEntity<Buffer>, BufferState>,
project: WeakEntity<Project>,
buffers: HashMap<BufferId, BufferState>,
}
#[derive(Debug, Default)]
@@ -59,52 +67,11 @@ struct BufferState {
task: Option<Task<()>>,
}
#[derive(Debug, Clone)]
pub enum Declaration {
File {
project_entry_id: ProjectEntryId,
declaration: FileDeclaration,
},
Buffer {
buffer: WeakEntity<Buffer>,
declaration: BufferDeclaration,
},
}
impl Declaration {
fn identifier(&self) -> &Identifier {
match self {
Declaration::File { declaration, .. } => &declaration.identifier,
Declaration::Buffer { declaration, .. } => &declaration.identifier,
}
}
}
#[derive(Debug, Clone)]
pub struct FileDeclaration {
pub parent: Option<DeclarationId>,
pub identifier: Identifier,
pub item_range: Range<usize>,
pub signature_range: Range<usize>,
pub signature_text: Arc<str>,
}
#[derive(Debug, Clone)]
pub struct BufferDeclaration {
pub parent: Option<DeclarationId>,
pub identifier: Identifier,
pub item_range: Range<Anchor>,
pub signature_range: Range<Anchor>,
}
impl TreeSitterIndex {
impl SyntaxIndex {
pub fn new(project: &Entity<Project>, cx: &mut Context<Self>) -> Self {
let mut this = Self {
declarations: SlotMap::with_key(),
identifiers: HashMap::default(),
project: project.downgrade(),
files: HashMap::default(),
buffers: HashMap::default(),
state: Arc::new(Mutex::new(SyntaxIndexState::default())),
};
let worktree_store = project.read(cx).worktree_store();
@@ -139,73 +106,6 @@ impl TreeSitterIndex {
this
}
pub fn declaration(&self, id: DeclarationId) -> Option<&Declaration> {
self.declarations.get(id)
}
pub fn declarations_for_identifier<const N: usize>(
&self,
identifier: Identifier,
cx: &App,
) -> Vec<Declaration> {
// make sure to not have a large stack allocation
assert!(N < 32);
let Some(declaration_ids) = self.identifiers.get(&identifier) else {
return vec![];
};
let mut result = Vec::with_capacity(N);
let mut included_buffer_entry_ids = arrayvec::ArrayVec::<_, N>::new();
let mut file_declarations = Vec::new();
for declaration_id in declaration_ids {
let declaration = self.declarations.get(*declaration_id);
let Some(declaration) = some_or_debug_panic(declaration) else {
continue;
};
match declaration {
Declaration::Buffer { buffer, .. } => {
if let Ok(Some(entry_id)) = buffer.read_with(cx, |buffer, cx| {
project::File::from_dyn(buffer.file()).and_then(|f| f.project_entry_id(cx))
}) {
included_buffer_entry_ids.push(entry_id);
result.push(declaration.clone());
if result.len() == N {
return result;
}
}
}
Declaration::File {
project_entry_id, ..
} => {
if !included_buffer_entry_ids.contains(project_entry_id) {
file_declarations.push(declaration.clone());
}
}
}
}
for declaration in file_declarations {
match declaration {
Declaration::File {
project_entry_id, ..
} => {
if !included_buffer_entry_ids.contains(&project_entry_id) {
result.push(declaration);
if result.len() == N {
return result;
}
}
}
Declaration::Buffer { .. } => {}
}
}
result
}
fn handle_worktree_store_event(
&mut self,
_worktree_store: Entity<WorktreeStore>,
@@ -215,21 +115,33 @@ impl TreeSitterIndex {
use WorktreeStoreEvent::*;
match event {
WorktreeUpdatedEntries(worktree_id, updated_entries_set) => {
for (path, entry_id, path_change) in updated_entries_set.iter() {
if let PathChange::Removed = path_change {
self.files.remove(entry_id);
} else {
let project_path = ProjectPath {
worktree_id: *worktree_id,
path: path.clone(),
};
self.update_file(*entry_id, project_path, cx);
let state = Arc::downgrade(&self.state);
let worktree_id = *worktree_id;
let updated_entries_set = updated_entries_set.clone();
cx.spawn(async move |this, cx| {
let Some(state) = state.upgrade() else { return };
for (path, entry_id, path_change) in updated_entries_set.iter() {
if let PathChange::Removed = path_change {
state.lock().await.files.remove(entry_id);
} else {
let project_path = ProjectPath {
worktree_id,
path: path.clone(),
};
this.update(cx, |this, cx| {
this.update_file(*entry_id, project_path, cx);
})
.ok();
}
}
}
})
.detach();
}
WorktreeDeletedEntry(_worktree_id, project_entry_id) => {
// TODO: Is this needed?
self.files.remove(project_entry_id);
let project_entry_id = *project_entry_id;
self.with_state(cx, move |state| {
state.files.remove(&project_entry_id);
})
}
_ => {}
}
@@ -251,15 +163,42 @@ impl TreeSitterIndex {
}
}
pub fn state(&self) -> &Arc<Mutex<SyntaxIndexState>> {
&self.state
}
fn with_state(&self, cx: &mut App, f: impl FnOnce(&mut SyntaxIndexState) + Send + 'static) {
if let Some(mut state) = self.state.try_lock() {
f(&mut state);
return;
}
let state = Arc::downgrade(&self.state);
cx.background_spawn(async move {
let Some(state) = state.upgrade() else {
return None;
};
let mut state = state.lock().await;
Some(f(&mut state))
})
.detach();
}
fn register_buffer(&mut self, buffer: &Entity<Buffer>, cx: &mut Context<Self>) {
self.buffers
.insert(buffer.downgrade(), BufferState::default());
let weak_buf = buffer.downgrade();
cx.observe_release(buffer, move |this, _buffer, _cx| {
this.buffers.remove(&weak_buf);
let buffer_id = buffer.read(cx).remote_id();
cx.observe_release(buffer, move |this, _buffer, cx| {
this.with_state(cx, move |state| {
if let Some(buffer_state) = state.buffers.remove(&buffer_id) {
SyntaxIndexState::remove_buffer_declarations(
&buffer_state.declarations,
&mut state.declarations,
&mut state.identifiers,
);
}
})
})
.detach();
cx.subscribe(buffer, Self::handle_buffer_event).detach();
self.update_buffer(buffer.clone(), cx);
}
@@ -275,10 +214,19 @@ impl TreeSitterIndex {
}
}
fn update_buffer(&mut self, buffer: Entity<Buffer>, cx: &Context<Self>) {
let mut parse_status = buffer.read(cx).parse_status();
fn update_buffer(&mut self, buffer_entity: Entity<Buffer>, cx: &mut Context<Self>) {
let buffer = buffer_entity.read(cx);
let Some(project_entry_id) =
project::File::from_dyn(buffer.file()).and_then(|f| f.project_entry_id(cx))
else {
return;
};
let buffer_id = buffer.remote_id();
let mut parse_status = buffer.parse_status();
let snapshot_task = cx.spawn({
let weak_buffer = buffer.downgrade();
let weak_buffer = buffer_entity.downgrade();
async move |_, cx| {
while *parse_status.borrow() != language::ParseStatus::Idle {
parse_status.changed().await?;
@@ -289,75 +237,77 @@ impl TreeSitterIndex {
let parse_task = cx.background_spawn(async move {
let snapshot = snapshot_task.await?;
let rope = snapshot.text.as_rope().clone();
anyhow::Ok(
anyhow::Ok((
declarations_in_buffer(&snapshot)
.into_iter()
.map(|item| {
(
item.parent_index,
BufferDeclaration::from_outline(item, &snapshot),
BufferDeclaration::from_outline(item, &rope),
)
})
.collect::<Vec<_>>(),
)
rope,
))
});
let task = cx.spawn({
let weak_buffer = buffer.downgrade();
async move |this, cx| {
let Ok(declarations) = parse_task.await else {
let Ok((declarations, rope)) = parse_task.await else {
return;
};
this.update(cx, |this, _cx| {
let buffer_state = this
.buffers
.entry(weak_buffer.clone())
.or_insert_with(Default::default);
this.update(cx, move |this, cx| {
this.with_state(cx, move |state| {
let buffer_state = state
.buffers
.entry(buffer_id)
.or_insert_with(Default::default);
for old_declaration_id in &buffer_state.declarations {
let Some(declaration) = this.declarations.remove(*old_declaration_id)
else {
debug_panic!("declaration not found");
continue;
};
if let Some(identifier_declarations) =
this.identifiers.get_mut(declaration.identifier())
{
identifier_declarations.remove(old_declaration_id);
SyntaxIndexState::remove_buffer_declarations(
&buffer_state.declarations,
&mut state.declarations,
&mut state.identifiers,
);
let mut new_ids = Vec::with_capacity(declarations.len());
state.declarations.reserve(declarations.len());
for (parent_index, mut declaration) in declarations {
declaration.parent = parent_index
.and_then(|ix| some_or_debug_panic(new_ids.get(ix).copied()));
let identifier = declaration.identifier.clone();
let declaration_id = state.declarations.insert(Declaration::Buffer {
rope: rope.clone(),
buffer_id,
declaration,
project_entry_id,
});
new_ids.push(declaration_id);
state
.identifiers
.entry(identifier)
.or_default()
.insert(declaration_id);
}
}
let mut new_ids = Vec::with_capacity(declarations.len());
this.declarations.reserve(declarations.len());
for (parent_index, mut declaration) in declarations {
declaration.parent = parent_index
.and_then(|ix| some_or_debug_panic(new_ids.get(ix).copied()));
let identifier = declaration.identifier.clone();
let declaration_id = this.declarations.insert(Declaration::Buffer {
buffer: weak_buffer.clone(),
declaration,
});
new_ids.push(declaration_id);
this.identifiers
.entry(identifier)
.or_default()
.insert(declaration_id);
}
buffer_state.declarations = new_ids;
buffer_state.declarations = new_ids;
});
})
.ok();
}
});
self.buffers
.entry(buffer.downgrade())
.or_insert_with(Default::default)
.task = Some(task);
self.with_state(cx, move |state| {
state
.buffers
.entry(buffer_id)
.or_insert_with(Default::default)
.task = Some(task)
});
}
fn update_file(
@@ -401,14 +351,10 @@ impl TreeSitterIndex {
let parse_task = cx.background_spawn(async move {
let snapshot = snapshot_task.await?;
let rope = snapshot.as_rope();
let declarations = declarations_in_buffer(&snapshot)
.into_iter()
.map(|item| {
(
item.parent_index,
FileDeclaration::from_outline(item, &snapshot),
)
})
.map(|item| (item.parent_index, FileDeclaration::from_outline(item, rope)))
.collect::<Vec<_>>();
anyhow::Ok(declarations)
});
@@ -419,84 +365,160 @@ impl TreeSitterIndex {
let Ok(declarations) = parse_task.await else {
return;
};
this.update(cx, |this, _cx| {
let file_state = this.files.entry(entry_id).or_insert_with(Default::default);
this.update(cx, |this, cx| {
this.with_state(cx, move |state| {
let file_state =
state.files.entry(entry_id).or_insert_with(Default::default);
for old_declaration_id in &file_state.declarations {
let Some(declaration) = this.declarations.remove(*old_declaration_id)
else {
debug_panic!("declaration not found");
continue;
};
if let Some(identifier_declarations) =
this.identifiers.get_mut(declaration.identifier())
{
identifier_declarations.remove(old_declaration_id);
for old_declaration_id in &file_state.declarations {
let Some(declaration) = state.declarations.remove(*old_declaration_id)
else {
debug_panic!("declaration not found");
continue;
};
if let Some(identifier_declarations) =
state.identifiers.get_mut(declaration.identifier())
{
identifier_declarations.remove(old_declaration_id);
}
}
}
let mut new_ids = Vec::with_capacity(declarations.len());
this.declarations.reserve(declarations.len());
let mut new_ids = Vec::with_capacity(declarations.len());
state.declarations.reserve(declarations.len());
for (parent_index, mut declaration) in declarations {
declaration.parent = parent_index
.and_then(|ix| some_or_debug_panic(new_ids.get(ix).copied()));
for (parent_index, mut declaration) in declarations {
declaration.parent = parent_index
.and_then(|ix| some_or_debug_panic(new_ids.get(ix).copied()));
let identifier = declaration.identifier.clone();
let declaration_id = this.declarations.insert(Declaration::File {
project_entry_id: entry_id,
declaration,
});
new_ids.push(declaration_id);
let identifier = declaration.identifier.clone();
let declaration_id = state.declarations.insert(Declaration::File {
project_entry_id: entry_id,
declaration,
});
new_ids.push(declaration_id);
this.identifiers
.entry(identifier)
.or_default()
.insert(declaration_id);
}
state
.identifiers
.entry(identifier)
.or_default()
.insert(declaration_id);
}
file_state.declarations = new_ids;
file_state.declarations = new_ids;
});
})
.ok();
}
});
self.files
.entry(entry_id)
.or_insert_with(Default::default)
.task = Some(task);
self.with_state(cx, move |state| {
state
.files
.entry(entry_id)
.or_insert_with(Default::default)
.task = Some(task);
});
}
}
impl BufferDeclaration {
pub fn from_outline(declaration: OutlineDeclaration, snapshot: &BufferSnapshot) -> Self {
// use of anchor_before is a guess that the proper behavior is to expand to include
// insertions immediately before the declaration, but not for insertions immediately after
Self {
parent: None,
identifier: declaration.identifier,
item_range: snapshot.anchor_before(declaration.item_range.start)
..snapshot.anchor_before(declaration.item_range.end),
signature_range: snapshot.anchor_before(declaration.signature_range.start)
..snapshot.anchor_before(declaration.signature_range.end),
impl SyntaxIndexState {
pub fn declaration(&self, id: DeclarationId) -> Option<&Declaration> {
self.declarations.get(id)
}
/// Returns declarations for the identifier. If the limit is exceeded, returns an empty vector.
///
/// TODO: Consider doing some pre-ranking and instead truncating when N is exceeded.
pub fn declarations_for_identifier<const N: usize>(
&self,
identifier: &Identifier,
) -> Vec<Declaration> {
// make sure to not have a large stack allocation
assert!(N < 32);
let Some(declaration_ids) = self.identifiers.get(&identifier) else {
return vec![];
};
let mut result = Vec::with_capacity(N);
let mut included_buffer_entry_ids = arrayvec::ArrayVec::<_, N>::new();
let mut file_declarations = Vec::new();
for declaration_id in declaration_ids {
let declaration = self.declarations.get(*declaration_id);
let Some(declaration) = some_or_debug_panic(declaration) else {
continue;
};
match declaration {
Declaration::Buffer {
project_entry_id, ..
} => {
included_buffer_entry_ids.push(*project_entry_id);
result.push(declaration.clone());
if result.len() == N {
return Vec::new();
}
}
Declaration::File {
project_entry_id, ..
} => {
if !included_buffer_entry_ids.contains(&project_entry_id) {
file_declarations.push(declaration.clone());
}
}
}
}
for declaration in file_declarations {
match declaration {
Declaration::File {
project_entry_id, ..
} => {
if !included_buffer_entry_ids.contains(&project_entry_id) {
result.push(declaration);
if result.len() == N {
return Vec::new();
}
}
}
Declaration::Buffer { .. } => {}
}
}
result
}
pub fn file_declaration_count(&self, declaration: &Declaration) -> usize {
match declaration {
Declaration::File {
project_entry_id, ..
} => self
.files
.get(project_entry_id)
.map(|file_state| file_state.declarations.len())
.unwrap_or_default(),
Declaration::Buffer { buffer_id, .. } => self
.buffers
.get(buffer_id)
.map(|buffer_state| buffer_state.declarations.len())
.unwrap_or_default(),
}
}
}
impl FileDeclaration {
pub fn from_outline(
declaration: OutlineDeclaration,
snapshot: &BufferSnapshot,
) -> FileDeclaration {
FileDeclaration {
parent: None,
identifier: declaration.identifier,
item_range: declaration.item_range,
signature_text: snapshot
.text_for_range(declaration.signature_range.clone())
.collect::<String>()
.into(),
signature_range: declaration.signature_range,
fn remove_buffer_declarations(
old_declaration_ids: &[DeclarationId],
declarations: &mut SlotMap<DeclarationId, Declaration>,
identifiers: &mut HashMap<Identifier, HashSet<DeclarationId>>,
) {
for old_declaration_id in old_declaration_ids {
let Some(declaration) = declarations.remove(*old_declaration_id) else {
debug_panic!("declaration not found");
continue;
};
if let Some(identifier_declarations) = identifiers.get_mut(declaration.identifier()) {
identifier_declarations.remove(old_declaration_id);
}
}
}
}
@@ -506,17 +528,16 @@ mod tests {
use super::*;
use std::{path::Path, sync::Arc};
use futures::channel::oneshot;
use gpui::TestAppContext;
use indoc::indoc;
use language::{Language, LanguageConfig, LanguageId, LanguageMatcher, tree_sitter_rust};
use project::{FakeFs, Project, ProjectItem};
use project::{FakeFs, Project};
use serde_json::json;
use settings::SettingsStore;
use text::OffsetRangeExt as _;
use util::path;
use crate::tree_sitter_index::TreeSitterIndex;
use crate::syntax_index::SyntaxIndex;
#[gpui::test]
async fn test_unopen_indexed_files(cx: &mut TestAppContext) {
@@ -526,17 +547,19 @@ mod tests {
language_id: rust_lang_id,
};
index.read_with(cx, |index, cx| {
let decls = index.declarations_for_identifier::<8>(main.clone(), cx);
let index_state = index.read_with(cx, |index, _cx| index.state().clone());
let index_state = index_state.lock().await;
cx.update(|cx| {
let decls = index_state.declarations_for_identifier::<8>(&main);
assert_eq!(decls.len(), 2);
let decl = expect_file_decl("c.rs", &decls[0], &project, cx);
assert_eq!(decl.identifier, main.clone());
assert_eq!(decl.item_range, 32..279);
assert_eq!(decl.item_range_in_file, 32..280);
let decl = expect_file_decl("a.rs", &decls[1], &project, cx);
assert_eq!(decl.identifier, main);
assert_eq!(decl.item_range, 0..97);
assert_eq!(decl.item_range_in_file, 0..98);
});
}
@@ -548,15 +571,17 @@ mod tests {
language_id: rust_lang_id,
};
index.read_with(cx, |index, cx| {
let decls = index.declarations_for_identifier::<8>(test_process_data.clone(), cx);
let index_state = index.read_with(cx, |index, _cx| index.state().clone());
let index_state = index_state.lock().await;
cx.update(|cx| {
let decls = index_state.declarations_for_identifier::<8>(&test_process_data);
assert_eq!(decls.len(), 1);
let decl = expect_file_decl("c.rs", &decls[0], &project, cx);
assert_eq!(decl.identifier, test_process_data);
let parent_id = decl.parent.unwrap();
let parent = index.declaration(parent_id).unwrap();
let parent = index_state.declaration(parent_id).unwrap();
let parent_decl = expect_file_decl("c.rs", &parent, &project, cx);
assert_eq!(
parent_decl.identifier,
@@ -587,16 +612,18 @@ mod tests {
cx.run_until_parked();
index.read_with(cx, |index, cx| {
let decls = index.declarations_for_identifier::<8>(test_process_data.clone(), cx);
let index_state = index.read_with(cx, |index, _cx| index.state().clone());
let index_state = index_state.lock().await;
cx.update(|cx| {
let decls = index_state.declarations_for_identifier::<8>(&test_process_data);
assert_eq!(decls.len(), 1);
let decl = expect_buffer_decl("c.rs", &decls[0], cx);
let decl = expect_buffer_decl("c.rs", &decls[0], &project, cx);
assert_eq!(decl.identifier, test_process_data);
let parent_id = decl.parent.unwrap();
let parent = index.declaration(parent_id).unwrap();
let parent_decl = expect_buffer_decl("c.rs", &parent, cx);
let parent = index_state.declaration(parent_id).unwrap();
let parent_decl = expect_buffer_decl("c.rs", &parent, &project, cx);
assert_eq!(
parent_decl.identifier,
Identifier {
@@ -614,16 +641,13 @@ mod tests {
async fn test_declarations_limit(cx: &mut TestAppContext) {
let (_, index, rust_lang_id) = init_test(cx).await;
index.read_with(cx, |index, cx| {
let decls = index.declarations_for_identifier::<1>(
Identifier {
name: "main".into(),
language_id: rust_lang_id,
},
cx,
);
assert_eq!(decls.len(), 1);
let index_state = index.read_with(cx, |index, _cx| index.state().clone());
let index_state = index_state.lock().await;
let decls = index_state.declarations_for_identifier::<1>(&Identifier {
name: "main".into(),
language_id: rust_lang_id,
});
assert_eq!(decls.len(), 0);
}
#[gpui::test]
@@ -645,31 +669,31 @@ mod tests {
cx.run_until_parked();
index.read_with(cx, |index, cx| {
let decls = index.declarations_for_identifier::<8>(main.clone(), cx);
assert_eq!(decls.len(), 2);
let decl = expect_buffer_decl("c.rs", &decls[0], cx);
assert_eq!(decl.identifier, main);
assert_eq!(decl.item_range.to_offset(&buffer.read(cx)), 32..279);
let index_state_arc = index.read_with(cx, |index, _cx| index.state().clone());
{
let index_state = index_state_arc.lock().await;
expect_file_decl("a.rs", &decls[1], &project, cx);
});
cx.update(|cx| {
let decls = index_state.declarations_for_identifier::<8>(&main);
assert_eq!(decls.len(), 2);
let decl = expect_buffer_decl("c.rs", &decls[0], &project, cx);
assert_eq!(decl.identifier, main);
assert_eq!(decl.item_range.to_offset(&buffer.read(cx)), 32..279);
expect_file_decl("a.rs", &decls[1], &project, cx);
});
}
// Drop the buffer and wait for release
let (release_tx, release_rx) = oneshot::channel();
cx.update(|cx| {
cx.observe_release(&buffer, |_, _| {
release_tx.send(()).ok();
})
.detach();
cx.update(|_| {
drop(buffer);
});
drop(buffer);
cx.run_until_parked();
release_rx.await.ok();
cx.run_until_parked();
index.read_with(cx, |index, cx| {
let decls = index.declarations_for_identifier::<8>(main, cx);
let index_state = index_state_arc.lock().await;
cx.update(|cx| {
let decls = index_state.declarations_for_identifier::<8>(&main);
assert_eq!(decls.len(), 2);
expect_file_decl("c.rs", &decls[0], &project, cx);
expect_file_decl("a.rs", &decls[1], &project, cx);
@@ -679,24 +703,20 @@ mod tests {
fn expect_buffer_decl<'a>(
path: &str,
declaration: &'a Declaration,
project: &Entity<Project>,
cx: &App,
) -> &'a BufferDeclaration {
if let Declaration::Buffer {
declaration,
buffer,
project_entry_id,
..
} = declaration
{
assert_eq!(
buffer
.upgrade()
.unwrap()
.read(cx)
.project_path(cx)
.unwrap()
.path
.as_ref(),
Path::new(path),
);
let project_path = project
.read(cx)
.path_for_entry(*project_entry_id, cx)
.unwrap();
assert_eq!(project_path.path.as_ref(), Path::new(path),);
declaration
} else {
panic!("Expected a buffer declaration, found {:?}", declaration);
@@ -731,7 +751,7 @@ mod tests {
async fn init_test(
cx: &mut TestAppContext,
) -> (Entity<Project>, Entity<TreeSitterIndex>, LanguageId) {
) -> (Entity<Project>, Entity<SyntaxIndex>, LanguageId) {
cx.update(|cx| {
let settings_store = SettingsStore::test(cx);
cx.set_global(settings_store);
@@ -809,7 +829,7 @@ mod tests {
let lang_id = lang.id();
language_registry.add(Arc::new(lang));
let index = cx.new(|cx| TreeSitterIndex::new(&project, cx));
let index = cx.new(|cx| SyntaxIndex::new(&project, cx));
cx.run_until_parked();
(project, index, lang_id)

crates/edit_prediction_context/src/text_similarity.rs
View File

@@ -0,0 +1,241 @@
use regex::Regex;
use std::{collections::HashMap, sync::LazyLock};
use crate::reference::Reference;
// TODO: Consider implementing sliding window similarity matching like
// https://github.com/sourcegraph/cody-public-snapshot/blob/8e20ac6c1460c08b0db581c0204658112a246eda/vscode/src/completions/context/retrievers/jaccard-similarity/bestJaccardMatch.ts
//
// That implementation could actually be more efficient - no need to track words in the window that
// are not in the query.
static IDENTIFIER_REGEX: LazyLock<Regex> = LazyLock::new(|| Regex::new(r"\b\w+\b").unwrap());
#[derive(Debug)]
pub struct IdentifierOccurrences {
identifier_to_count: HashMap<String, usize>,
total_count: usize,
}
impl IdentifierOccurrences {
pub fn within_string(code: &str) -> Self {
Self::from_iterator(IDENTIFIER_REGEX.find_iter(code).map(|mat| mat.as_str()))
}
#[allow(dead_code)]
pub fn within_references(references: &[Reference]) -> Self {
Self::from_iterator(
references
.iter()
.map(|reference| reference.identifier.name.as_ref()),
)
}
pub fn from_iterator<'a>(identifier_iterator: impl Iterator<Item = &'a str>) -> Self {
let mut identifier_to_count = HashMap::new();
let mut total_count = 0;
for identifier in identifier_iterator {
// TODO: Score matches that match case higher?
//
// TODO: Also include unsplit identifier?
for identifier_part in split_identifier(identifier) {
identifier_to_count
.entry(identifier_part.to_lowercase())
.and_modify(|count| *count += 1)
.or_insert(1);
total_count += 1;
}
}
IdentifierOccurrences {
identifier_to_count,
total_count,
}
}
}
// Splits camelcase / snakecase / kebabcase / pascalcase
//
// TODO: Make this more efficient / elegant.
fn split_identifier<'a>(identifier: &'a str) -> Vec<&'a str> {
let mut parts = Vec::new();
let mut start = 0;
let chars: Vec<char> = identifier.chars().collect();
if chars.is_empty() {
return parts;
}
let mut i = 0;
while i < chars.len() {
let ch = chars[i];
// Handle explicit delimiters (underscore and hyphen)
if ch == '_' || ch == '-' {
if i > start {
parts.push(&identifier[start..i]);
}
start = i + 1;
i += 1;
continue;
}
// Handle camelCase and PascalCase transitions
if i > 0 && i < chars.len() {
let prev_char = chars[i - 1];
// Transition from lowercase/digit to uppercase
if (prev_char.is_lowercase() || prev_char.is_ascii_digit()) && ch.is_uppercase() {
parts.push(&identifier[start..i]);
start = i;
}
// Handle sequences like "XMLParser" -> ["XML", "Parser"]
else if i + 1 < chars.len()
&& ch.is_uppercase()
&& chars[i + 1].is_lowercase()
&& prev_char.is_uppercase()
{
parts.push(&identifier[start..i]);
start = i;
}
}
i += 1;
}
// Add the last part if there's any remaining
if start < identifier.len() {
parts.push(&identifier[start..]);
}
// Filter out empty strings
parts.into_iter().filter(|s| !s.is_empty()).collect()
}
pub fn jaccard_similarity<'a>(
mut set_a: &'a IdentifierOccurrences,
mut set_b: &'a IdentifierOccurrences,
) -> f32 {
if set_a.identifier_to_count.len() > set_b.identifier_to_count.len() {
std::mem::swap(&mut set_a, &mut set_b);
}
let intersection = set_a
.identifier_to_count
.keys()
.filter(|key| set_b.identifier_to_count.contains_key(*key))
.count();
let union = set_a.identifier_to_count.len() + set_b.identifier_to_count.len() - intersection;
intersection as f32 / union as f32
}
// TODO
#[allow(dead_code)]
pub fn overlap_coefficient<'a>(
mut set_a: &'a IdentifierOccurrences,
mut set_b: &'a IdentifierOccurrences,
) -> f32 {
if set_a.identifier_to_count.len() > set_b.identifier_to_count.len() {
std::mem::swap(&mut set_a, &mut set_b);
}
let intersection = set_a
.identifier_to_count
.keys()
.filter(|key| set_b.identifier_to_count.contains_key(*key))
.count();
intersection as f32 / set_a.identifier_to_count.len() as f32
}
// TODO
#[allow(dead_code)]
pub fn weighted_jaccard_similarity<'a>(
mut set_a: &'a IdentifierOccurrences,
mut set_b: &'a IdentifierOccurrences,
) -> f32 {
if set_a.identifier_to_count.len() > set_b.identifier_to_count.len() {
std::mem::swap(&mut set_a, &mut set_b);
}
let mut numerator = 0;
let mut denominator_a = 0;
let mut used_count_b = 0;
for (symbol, count_a) in set_a.identifier_to_count.iter() {
let count_b = set_b.identifier_to_count.get(symbol).unwrap_or(&0);
numerator += count_a.min(count_b);
denominator_a += count_a.max(count_b);
used_count_b += count_b;
}
let denominator = denominator_a + (set_b.total_count - used_count_b);
if denominator == 0 {
0.0
} else {
numerator as f32 / denominator as f32
}
}
pub fn weighted_overlap_coefficient<'a>(
mut set_a: &'a IdentifierOccurrences,
mut set_b: &'a IdentifierOccurrences,
) -> f32 {
if set_a.identifier_to_count.len() > set_b.identifier_to_count.len() {
std::mem::swap(&mut set_a, &mut set_b);
}
let mut numerator = 0;
for (symbol, count_a) in set_a.identifier_to_count.iter() {
let count_b = set_b.identifier_to_count.get(symbol).unwrap_or(&0);
numerator += count_a.min(count_b);
}
let denominator = set_a.total_count.min(set_b.total_count);
if denominator == 0 {
0.0
} else {
numerator as f32 / denominator as f32
}
}
#[cfg(test)]
mod test {
use super::*;
#[test]
fn test_split_identifier() {
assert_eq!(split_identifier("snake_case"), vec!["snake", "case"]);
assert_eq!(split_identifier("kebab-case"), vec!["kebab", "case"]);
assert_eq!(split_identifier("PascalCase"), vec!["Pascal", "Case"]);
assert_eq!(split_identifier("camelCase"), vec!["camel", "Case"]);
assert_eq!(split_identifier("XMLParser"), vec!["XML", "Parser"]);
}
#[test]
fn test_similarity_functions() {
// 10 identifier parts, 8 unique
// Repeats: 2 "outline", 2 "items"
let set_a = IdentifierOccurrences::within_string(
"let mut outline_items = query_outline_items(&language, &tree, &source);",
);
// 14 identifier parts, 11 unique
// Repeats: 2 "outline", 2 "language", 2 "tree"
let set_b = IdentifierOccurrences::within_string(
"pub fn query_outline_items(language: &Language, tree: &Tree, source: &str) -> Vec<OutlineItem> {",
);
// 6 overlaps: "outline", "items", "query", "language", "tree", "source"
// 7 non-overlaps: "let", "mut", "pub", "fn", "vec", "item", "str"
assert_eq!(jaccard_similarity(&set_a, &set_b), 6.0 / (6.0 + 7.0));
// Numerator is one more than before due to both having 2 "outline".
// Denominator is the same except for 3 more due to the non-overlapping duplicates
assert_eq!(
weighted_jaccard_similarity(&set_a, &set_b),
7.0 / (7.0 + 7.0 + 3.0)
);
// Numerator is the same as jaccard_similarity. Denominator is the size of the smaller set, 8.
assert_eq!(overlap_coefficient(&set_a, &set_b), 6.0 / 8.0);
// Numerator is the same as weighted_jaccard_similarity. Denominator is the total weight of
// the smaller set, 10.
assert_eq!(weighted_overlap_coefficient(&set_a, &set_b), 7.0 / 10.0);
}
}
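In formula form, with $a_s$, $b_s$ the per-part counts from `identifier_to_count`, $A$, $B$ their key sets, and `total_count` $= \sum_s a_s$, the four metrics above are:

$$\mathrm{jaccard}(A,B) = \frac{|A \cap B|}{|A \cup B|} \qquad \mathrm{overlap}(A,B) = \frac{|A \cap B|}{\min(|A|,\,|B|)}$$

$$\mathrm{weighted\_jaccard}(a,b) = \frac{\sum_s \min(a_s, b_s)}{\sum_s \max(a_s, b_s)} \qquad \mathrm{weighted\_overlap}(a,b) = \frac{\sum_s \min(a_s, b_s)}{\min\left(\sum_s a_s,\ \sum_s b_s\right)}$$

The swap at the top of each function only ensures the smaller map is iterated; all four results are symmetric in their arguments.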

View File

@@ -0,0 +1,35 @@
// To discuss: What to send to the new endpoint? Thinking it'd make sense to put `prompt.rs` from
// `zeta_context.rs` in cloud.
//
// * Run excerpt selection at several different sizes, send the largest size with offsets within for
// the smaller sizes.
//
// * Longer event history.
//
// * Many more snippets than could fit in model context - allows ranking experimentation.
use std::ops::Range;
use std::sync::Arc;

pub struct Zeta2Request {
pub event_history: Vec<Event>,
pub excerpt: String,
pub excerpt_subsets: Vec<Zeta2ExcerptSubset>,
/// Within `excerpt`
pub cursor_position: usize,
pub signatures: Vec<String>,
pub retrieved_declarations: Vec<ReferencedDeclaration>,
}
pub struct Zeta2ExcerptSubset {
/// Within `excerpt` text.
pub excerpt_range: Range<usize>,
/// Within `signatures`.
pub parent_signatures: Vec<usize>,
}
pub struct ReferencedDeclaration {
pub text: Arc<str>,
/// Range within `text`
pub signature_range: Range<usize>,
/// Indices within `signatures`.
pub parent_signatures: Vec<usize>,
// A bunch of score metrics
}
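A small sketch of how the index-based cross-references in these types are meant to line up, with hypothetical data (the `Event` type is elided in this WIP file):

```rust
use std::ops::Range;

fn main() {
    // Signatures are deduplicated once per request; excerpt subsets and
    // retrieved declarations point into this vector by index.
    let signatures: Vec<String> = vec![
        "impl Config".to_string(), // index 0
        "mod tests".to_string(),   // index 1
    ];

    // An excerpt subset nested under `impl Config`.
    let excerpt_range: Range<usize> = 10..42;
    let parent_signatures: Vec<usize> = vec![0];

    for &ix in &parent_signatures {
        println!("{excerpt_range:?} is nested under {:?}", signatures[ix]);
    }
}
```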