Compare commits
13 Commits
diff-perf-
...
zeta2-cont
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
4a4ee4fed7 | ||
|
|
ea4bf46a36 | ||
|
|
05545abab6 | ||
|
|
a85608566d | ||
|
|
69af5261ea | ||
|
|
b9e2f61a38 | ||
|
|
38bbb497dd | ||
|
|
0cc7b4a93c | ||
|
|
cc32bfdfdf | ||
|
|
50de8ddc28 | ||
|
|
f770011d7f | ||
|
|
f2a6b57909 | ||
|
|
96b67ac70e |
6
Cargo.lock
generated
6
Cargo.lock
generated
@@ -5140,17 +5140,23 @@ version = "0.1.0"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"arrayvec",
|
||||
"clap",
|
||||
"collections",
|
||||
"futures 0.3.31",
|
||||
"gpui",
|
||||
"indoc",
|
||||
"itertools 0.14.0",
|
||||
"language",
|
||||
"log",
|
||||
"ordered-float 2.10.1",
|
||||
"pretty_assertions",
|
||||
"project",
|
||||
"regex",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"settings",
|
||||
"slotmap",
|
||||
"strum 0.27.1",
|
||||
"text",
|
||||
"tree-sitter",
|
||||
"util",
|
||||
|
||||
@@ -15,17 +15,24 @@ path = "src/edit_prediction_context.rs"
|
||||
anyhow.workspace = true
|
||||
arrayvec.workspace = true
|
||||
collections.workspace = true
|
||||
futures.workspace = true
|
||||
gpui.workspace = true
|
||||
itertools.workspace = true
|
||||
language.workspace = true
|
||||
log.workspace = true
|
||||
ordered-float.workspace = true
|
||||
project.workspace = true
|
||||
regex.workspace = true
|
||||
serde.workspace = true
|
||||
slotmap.workspace = true
|
||||
strum.workspace = true
|
||||
text.workspace = true
|
||||
tree-sitter.workspace = true
|
||||
util.workspace = true
|
||||
workspace-hack.workspace = true
|
||||
|
||||
[dev-dependencies]
|
||||
clap.workspace = true
|
||||
futures.workspace = true
|
||||
gpui = { workspace = true, features = ["test-support"] }
|
||||
indoc.workspace = true
|
||||
|
||||
193
crates/edit_prediction_context/src/declaration.rs
Normal file
193
crates/edit_prediction_context/src/declaration.rs
Normal file
@@ -0,0 +1,193 @@
|
||||
use language::LanguageId;
|
||||
use project::ProjectEntryId;
|
||||
use std::borrow::Cow;
|
||||
use std::ops::Range;
|
||||
use std::sync::Arc;
|
||||
use text::{Bias, BufferId, Rope};
|
||||
|
||||
use crate::outline::OutlineDeclaration;
|
||||
|
||||
#[derive(Debug, Clone, Eq, PartialEq, Hash)]
|
||||
pub struct Identifier {
|
||||
pub name: Arc<str>,
|
||||
pub language_id: LanguageId,
|
||||
}
|
||||
|
||||
slotmap::new_key_type! {
|
||||
pub struct DeclarationId;
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum Declaration {
|
||||
File {
|
||||
project_entry_id: ProjectEntryId,
|
||||
declaration: FileDeclaration,
|
||||
},
|
||||
Buffer {
|
||||
project_entry_id: ProjectEntryId,
|
||||
buffer_id: BufferId,
|
||||
rope: Rope,
|
||||
declaration: BufferDeclaration,
|
||||
},
|
||||
}
|
||||
|
||||
const ITEM_TEXT_TRUNCATION_LENGTH: usize = 1024;
|
||||
|
||||
impl Declaration {
|
||||
pub fn identifier(&self) -> &Identifier {
|
||||
match self {
|
||||
Declaration::File { declaration, .. } => &declaration.identifier,
|
||||
Declaration::Buffer { declaration, .. } => &declaration.identifier,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn project_entry_id(&self) -> Option<ProjectEntryId> {
|
||||
match self {
|
||||
Declaration::File {
|
||||
project_entry_id, ..
|
||||
} => Some(*project_entry_id),
|
||||
Declaration::Buffer {
|
||||
project_entry_id, ..
|
||||
} => Some(*project_entry_id),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn item_text(&self) -> (Cow<'_, str>, bool) {
|
||||
match self {
|
||||
Declaration::File { declaration, .. } => (
|
||||
declaration.text.as_ref().into(),
|
||||
declaration.text_is_truncated,
|
||||
),
|
||||
Declaration::Buffer {
|
||||
rope, declaration, ..
|
||||
} => (
|
||||
rope.chunks_in_range(declaration.item_range.clone())
|
||||
.collect::<Cow<str>>(),
|
||||
declaration.item_range_is_truncated,
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn signature_text(&self) -> (Cow<'_, str>, bool) {
|
||||
match self {
|
||||
Declaration::File { declaration, .. } => (
|
||||
declaration.text[declaration.signature_range_in_text.clone()].into(),
|
||||
declaration.signature_is_truncated,
|
||||
),
|
||||
Declaration::Buffer {
|
||||
rope, declaration, ..
|
||||
} => (
|
||||
rope.chunks_in_range(declaration.signature_range.clone())
|
||||
.collect::<Cow<str>>(),
|
||||
declaration.signature_range_is_truncated,
|
||||
),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn expand_range_to_line_boundaries_and_truncate(
|
||||
range: &Range<usize>,
|
||||
limit: usize,
|
||||
rope: &Rope,
|
||||
) -> (Range<usize>, bool) {
|
||||
let mut point_range = rope.offset_to_point(range.start)..rope.offset_to_point(range.end);
|
||||
point_range.start.column = 0;
|
||||
point_range.end.row += 1;
|
||||
point_range.end.column = 0;
|
||||
|
||||
let mut item_range =
|
||||
rope.point_to_offset(point_range.start)..rope.point_to_offset(point_range.end);
|
||||
let is_truncated = item_range.len() > limit;
|
||||
if is_truncated {
|
||||
item_range.end = item_range.start + limit;
|
||||
}
|
||||
item_range.end = rope.clip_offset(item_range.end, Bias::Left);
|
||||
(item_range, is_truncated)
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct FileDeclaration {
|
||||
pub parent: Option<DeclarationId>,
|
||||
pub identifier: Identifier,
|
||||
/// offset range of the declaration in the file, expanded to line boundaries and truncated
|
||||
pub item_range_in_file: Range<usize>,
|
||||
/// text of `item_range_in_file`
|
||||
pub text: Arc<str>,
|
||||
/// whether `text` was truncated
|
||||
pub text_is_truncated: bool,
|
||||
/// offset range of the signature within `text`
|
||||
pub signature_range_in_text: Range<usize>,
|
||||
/// whether `signature` was truncated
|
||||
pub signature_is_truncated: bool,
|
||||
}
|
||||
|
||||
impl FileDeclaration {
|
||||
pub fn from_outline(declaration: OutlineDeclaration, rope: &Rope) -> FileDeclaration {
|
||||
let (item_range_in_file, text_is_truncated) = expand_range_to_line_boundaries_and_truncate(
|
||||
&declaration.item_range,
|
||||
ITEM_TEXT_TRUNCATION_LENGTH,
|
||||
rope,
|
||||
);
|
||||
|
||||
// TODO: consider logging if unexpected
|
||||
let signature_start = declaration
|
||||
.signature_range
|
||||
.start
|
||||
.saturating_sub(item_range_in_file.start);
|
||||
let mut signature_end = declaration
|
||||
.signature_range
|
||||
.end
|
||||
.saturating_sub(item_range_in_file.start);
|
||||
let signature_is_truncated = signature_end > item_range_in_file.len();
|
||||
if signature_is_truncated {
|
||||
signature_end = item_range_in_file.len();
|
||||
}
|
||||
|
||||
FileDeclaration {
|
||||
parent: None,
|
||||
identifier: declaration.identifier,
|
||||
signature_range_in_text: signature_start..signature_end,
|
||||
signature_is_truncated,
|
||||
text: rope
|
||||
.chunks_in_range(item_range_in_file.clone())
|
||||
.collect::<String>()
|
||||
.into(),
|
||||
text_is_truncated,
|
||||
item_range_in_file,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct BufferDeclaration {
|
||||
pub parent: Option<DeclarationId>,
|
||||
pub identifier: Identifier,
|
||||
pub item_range: Range<usize>,
|
||||
pub item_range_is_truncated: bool,
|
||||
pub signature_range: Range<usize>,
|
||||
pub signature_range_is_truncated: bool,
|
||||
}
|
||||
|
||||
impl BufferDeclaration {
|
||||
pub fn from_outline(declaration: OutlineDeclaration, rope: &Rope) -> Self {
|
||||
let (item_range, item_range_is_truncated) = expand_range_to_line_boundaries_and_truncate(
|
||||
&declaration.item_range,
|
||||
ITEM_TEXT_TRUNCATION_LENGTH,
|
||||
rope,
|
||||
);
|
||||
let (signature_range, signature_range_is_truncated) =
|
||||
expand_range_to_line_boundaries_and_truncate(
|
||||
&declaration.signature_range,
|
||||
ITEM_TEXT_TRUNCATION_LENGTH,
|
||||
rope,
|
||||
);
|
||||
Self {
|
||||
parent: None,
|
||||
identifier: declaration.identifier,
|
||||
item_range,
|
||||
item_range_is_truncated,
|
||||
signature_range,
|
||||
signature_range_is_truncated,
|
||||
}
|
||||
}
|
||||
}
|
||||
324
crates/edit_prediction_context/src/declaration_scoring.rs
Normal file
324
crates/edit_prediction_context/src/declaration_scoring.rs
Normal file
@@ -0,0 +1,324 @@
|
||||
use itertools::Itertools as _;
|
||||
use language::BufferSnapshot;
|
||||
use ordered_float::OrderedFloat;
|
||||
use serde::Serialize;
|
||||
use std::{collections::HashMap, ops::Range};
|
||||
use strum::EnumIter;
|
||||
use text::{OffsetRangeExt, Point, ToPoint};
|
||||
|
||||
use crate::{
|
||||
Declaration, EditPredictionExcerpt, EditPredictionExcerptText, Identifier,
|
||||
reference::{Reference, ReferenceRegion},
|
||||
syntax_index::SyntaxIndexState,
|
||||
text_similarity::{IdentifierOccurrences, jaccard_similarity, weighted_overlap_coefficient},
|
||||
};
|
||||
|
||||
const MAX_IDENTIFIER_DECLARATION_COUNT: usize = 16;
|
||||
|
||||
// TODO:
|
||||
//
|
||||
// * Consider adding declaration_file_count
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct ScoredSnippet {
|
||||
pub identifier: Identifier,
|
||||
pub declaration: Declaration,
|
||||
pub score_components: ScoreInputs,
|
||||
pub scores: Scores,
|
||||
}
|
||||
|
||||
// TODO: Consider having "Concise" style corresponding to `concise_text`
|
||||
#[derive(EnumIter, Clone, Copy, PartialEq, Eq, Hash, Debug)]
|
||||
pub enum SnippetStyle {
|
||||
Signature,
|
||||
Declaration,
|
||||
}
|
||||
|
||||
impl ScoredSnippet {
|
||||
/// Returns the score for this snippet with the specified style.
|
||||
pub fn score(&self, style: SnippetStyle) -> f32 {
|
||||
match style {
|
||||
SnippetStyle::Signature => self.scores.signature,
|
||||
SnippetStyle::Declaration => self.scores.declaration,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn size(&self, style: SnippetStyle) -> usize {
|
||||
// TODO: how to handle truncation?
|
||||
match &self.declaration {
|
||||
Declaration::File { declaration, .. } => match style {
|
||||
SnippetStyle::Signature => declaration.signature_range_in_text.len(),
|
||||
SnippetStyle::Declaration => declaration.text.len(),
|
||||
},
|
||||
Declaration::Buffer { declaration, .. } => match style {
|
||||
SnippetStyle::Signature => declaration.signature_range.len(),
|
||||
SnippetStyle::Declaration => declaration.item_range.len(),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
pub fn score_density(&self, style: SnippetStyle) -> f32 {
|
||||
self.score(style) / (self.size(style)) as f32
|
||||
}
|
||||
}
|
||||
|
||||
pub fn scored_snippets(
|
||||
index: &SyntaxIndexState,
|
||||
excerpt: &EditPredictionExcerpt,
|
||||
excerpt_text: &EditPredictionExcerptText,
|
||||
identifier_to_references: HashMap<Identifier, Vec<Reference>>,
|
||||
cursor_offset: usize,
|
||||
current_buffer: &BufferSnapshot,
|
||||
) -> Vec<ScoredSnippet> {
|
||||
let containing_range_identifier_occurrences =
|
||||
IdentifierOccurrences::within_string(&excerpt_text.body);
|
||||
let cursor_point = cursor_offset.to_point(¤t_buffer);
|
||||
|
||||
let start_point = Point::new(cursor_point.row.saturating_sub(2), 0);
|
||||
let end_point = Point::new(cursor_point.row + 1, 0);
|
||||
let adjacent_identifier_occurrences = IdentifierOccurrences::within_string(
|
||||
¤t_buffer
|
||||
.text_for_range(start_point..end_point)
|
||||
.collect::<String>(),
|
||||
);
|
||||
|
||||
let mut snippets = identifier_to_references
|
||||
.into_iter()
|
||||
.flat_map(|(identifier, references)| {
|
||||
let declarations =
|
||||
index.declarations_for_identifier::<MAX_IDENTIFIER_DECLARATION_COUNT>(&identifier);
|
||||
let declaration_count = declarations.len();
|
||||
|
||||
declarations
|
||||
.iter()
|
||||
.filter_map(|declaration| match declaration {
|
||||
Declaration::Buffer {
|
||||
buffer_id,
|
||||
declaration: buffer_declaration,
|
||||
..
|
||||
} => {
|
||||
let is_same_file = buffer_id == ¤t_buffer.remote_id();
|
||||
|
||||
if is_same_file {
|
||||
range_intersection(
|
||||
&buffer_declaration.item_range.to_offset(¤t_buffer),
|
||||
&excerpt.range,
|
||||
)
|
||||
.is_none()
|
||||
.then(|| {
|
||||
let declaration_line = buffer_declaration
|
||||
.item_range
|
||||
.start
|
||||
.to_point(current_buffer)
|
||||
.row;
|
||||
(
|
||||
true,
|
||||
(cursor_point.row as i32 - declaration_line as i32).abs()
|
||||
as u32,
|
||||
declaration,
|
||||
)
|
||||
})
|
||||
} else {
|
||||
Some((false, 0, declaration))
|
||||
}
|
||||
}
|
||||
Declaration::File { .. } => {
|
||||
// We can assume that a file declaration is in a different file,
|
||||
// because the current one must be open
|
||||
Some((false, 0, declaration))
|
||||
}
|
||||
})
|
||||
.sorted_by_key(|&(_, distance, _)| distance)
|
||||
.enumerate()
|
||||
.map(
|
||||
|(
|
||||
declaration_line_distance_rank,
|
||||
(is_same_file, declaration_line_distance, declaration),
|
||||
)| {
|
||||
let same_file_declaration_count = index.file_declaration_count(declaration);
|
||||
|
||||
score_snippet(
|
||||
&identifier,
|
||||
&references,
|
||||
declaration.clone(),
|
||||
is_same_file,
|
||||
declaration_line_distance,
|
||||
declaration_line_distance_rank,
|
||||
same_file_declaration_count,
|
||||
declaration_count,
|
||||
&containing_range_identifier_occurrences,
|
||||
&adjacent_identifier_occurrences,
|
||||
cursor_point,
|
||||
current_buffer,
|
||||
)
|
||||
},
|
||||
)
|
||||
.collect::<Vec<_>>()
|
||||
})
|
||||
.flatten()
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
snippets.sort_unstable_by_key(|snippet| {
|
||||
OrderedFloat(
|
||||
snippet
|
||||
.score_density(SnippetStyle::Declaration)
|
||||
.max(snippet.score_density(SnippetStyle::Signature)),
|
||||
)
|
||||
});
|
||||
|
||||
snippets
|
||||
}
|
||||
|
||||
fn range_intersection<T: Ord + Clone>(a: &Range<T>, b: &Range<T>) -> Option<Range<T>> {
|
||||
let start = a.start.clone().max(b.start.clone());
|
||||
let end = a.end.clone().min(b.end.clone());
|
||||
if start < end {
|
||||
Some(Range { start, end })
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
fn score_snippet(
|
||||
identifier: &Identifier,
|
||||
references: &[Reference],
|
||||
declaration: Declaration,
|
||||
is_same_file: bool,
|
||||
declaration_line_distance: u32,
|
||||
declaration_line_distance_rank: usize,
|
||||
same_file_declaration_count: usize,
|
||||
declaration_count: usize,
|
||||
containing_range_identifier_occurrences: &IdentifierOccurrences,
|
||||
adjacent_identifier_occurrences: &IdentifierOccurrences,
|
||||
cursor: Point,
|
||||
current_buffer: &BufferSnapshot,
|
||||
) -> Option<ScoredSnippet> {
|
||||
let is_referenced_nearby = references
|
||||
.iter()
|
||||
.any(|r| r.region == ReferenceRegion::Nearby);
|
||||
let is_referenced_in_breadcrumb = references
|
||||
.iter()
|
||||
.any(|r| r.region == ReferenceRegion::Breadcrumb);
|
||||
let reference_count = references.len();
|
||||
let reference_line_distance = references
|
||||
.iter()
|
||||
.map(|r| {
|
||||
let reference_line = r.range.start.to_point(current_buffer).row as i32;
|
||||
(cursor.row as i32 - reference_line).abs() as u32
|
||||
})
|
||||
.min()
|
||||
.unwrap();
|
||||
|
||||
let item_source_occurrences = IdentifierOccurrences::within_string(&declaration.item_text().0);
|
||||
let item_signature_occurrences =
|
||||
IdentifierOccurrences::within_string(&declaration.signature_text().0);
|
||||
let containing_range_vs_item_jaccard = jaccard_similarity(
|
||||
containing_range_identifier_occurrences,
|
||||
&item_source_occurrences,
|
||||
);
|
||||
let containing_range_vs_signature_jaccard = jaccard_similarity(
|
||||
containing_range_identifier_occurrences,
|
||||
&item_signature_occurrences,
|
||||
);
|
||||
let adjacent_vs_item_jaccard =
|
||||
jaccard_similarity(adjacent_identifier_occurrences, &item_source_occurrences);
|
||||
let adjacent_vs_signature_jaccard =
|
||||
jaccard_similarity(adjacent_identifier_occurrences, &item_signature_occurrences);
|
||||
|
||||
let containing_range_vs_item_weighted_overlap = weighted_overlap_coefficient(
|
||||
containing_range_identifier_occurrences,
|
||||
&item_source_occurrences,
|
||||
);
|
||||
let containing_range_vs_signature_weighted_overlap = weighted_overlap_coefficient(
|
||||
containing_range_identifier_occurrences,
|
||||
&item_signature_occurrences,
|
||||
);
|
||||
let adjacent_vs_item_weighted_overlap =
|
||||
weighted_overlap_coefficient(adjacent_identifier_occurrences, &item_source_occurrences);
|
||||
let adjacent_vs_signature_weighted_overlap =
|
||||
weighted_overlap_coefficient(adjacent_identifier_occurrences, &item_signature_occurrences);
|
||||
|
||||
let score_components = ScoreInputs {
|
||||
is_same_file,
|
||||
is_referenced_nearby,
|
||||
is_referenced_in_breadcrumb,
|
||||
reference_line_distance,
|
||||
declaration_line_distance,
|
||||
declaration_line_distance_rank,
|
||||
reference_count,
|
||||
same_file_declaration_count,
|
||||
declaration_count,
|
||||
containing_range_vs_item_jaccard,
|
||||
containing_range_vs_signature_jaccard,
|
||||
adjacent_vs_item_jaccard,
|
||||
adjacent_vs_signature_jaccard,
|
||||
containing_range_vs_item_weighted_overlap,
|
||||
containing_range_vs_signature_weighted_overlap,
|
||||
adjacent_vs_item_weighted_overlap,
|
||||
adjacent_vs_signature_weighted_overlap,
|
||||
};
|
||||
|
||||
Some(ScoredSnippet {
|
||||
identifier: identifier.clone(),
|
||||
declaration: declaration,
|
||||
scores: score_components.score(),
|
||||
score_components,
|
||||
})
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize)]
|
||||
pub struct ScoreInputs {
|
||||
pub is_same_file: bool,
|
||||
pub is_referenced_nearby: bool,
|
||||
pub is_referenced_in_breadcrumb: bool,
|
||||
pub reference_count: usize,
|
||||
pub same_file_declaration_count: usize,
|
||||
pub declaration_count: usize,
|
||||
pub reference_line_distance: u32,
|
||||
pub declaration_line_distance: u32,
|
||||
pub declaration_line_distance_rank: usize,
|
||||
pub containing_range_vs_item_jaccard: f32,
|
||||
pub containing_range_vs_signature_jaccard: f32,
|
||||
pub adjacent_vs_item_jaccard: f32,
|
||||
pub adjacent_vs_signature_jaccard: f32,
|
||||
pub containing_range_vs_item_weighted_overlap: f32,
|
||||
pub containing_range_vs_signature_weighted_overlap: f32,
|
||||
pub adjacent_vs_item_weighted_overlap: f32,
|
||||
pub adjacent_vs_signature_weighted_overlap: f32,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize)]
|
||||
pub struct Scores {
|
||||
pub signature: f32,
|
||||
pub declaration: f32,
|
||||
}
|
||||
|
||||
impl ScoreInputs {
|
||||
fn score(&self) -> Scores {
|
||||
// Score related to how likely this is the correct declaration, range 0 to 1
|
||||
let accuracy_score = if self.is_same_file {
|
||||
// TODO: use declaration_line_distance_rank
|
||||
1.0 / self.same_file_declaration_count as f32
|
||||
} else {
|
||||
1.0 / self.declaration_count as f32
|
||||
};
|
||||
|
||||
// Score related to the distance between the reference and cursor, range 0 to 1
|
||||
let distance_score = if self.is_referenced_nearby {
|
||||
1.0 / (1.0 + self.reference_line_distance as f32 / 10.0).powf(2.0)
|
||||
} else {
|
||||
// same score as ~14 lines away, rationale is to not overly penalize references from parent signatures
|
||||
0.5
|
||||
};
|
||||
|
||||
// For now instead of linear combination, the scores are just multiplied together.
|
||||
let combined_score = 10.0 * accuracy_score * distance_score;
|
||||
|
||||
Scores {
|
||||
signature: combined_score * self.containing_range_vs_signature_weighted_overlap,
|
||||
// declaration score gets boosted both by being multipled by 2 and by there being more
|
||||
// weighted overlap.
|
||||
declaration: 2.0 * combined_score * self.containing_range_vs_item_weighted_overlap,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,8 +1,220 @@
|
||||
mod declaration;
|
||||
mod declaration_scoring;
|
||||
mod excerpt;
|
||||
mod outline;
|
||||
mod reference;
|
||||
mod tree_sitter_index;
|
||||
mod syntax_index;
|
||||
mod text_similarity;
|
||||
|
||||
pub use declaration::{BufferDeclaration, Declaration, FileDeclaration, Identifier};
|
||||
pub use excerpt::{EditPredictionExcerpt, EditPredictionExcerptOptions, EditPredictionExcerptText};
|
||||
use gpui::{App, AppContext as _, Entity, Task};
|
||||
use language::BufferSnapshot;
|
||||
pub use reference::references_in_excerpt;
|
||||
pub use tree_sitter_index::{BufferDeclaration, Declaration, FileDeclaration, TreeSitterIndex};
|
||||
pub use syntax_index::SyntaxIndex;
|
||||
use text::{Point, ToOffset as _};
|
||||
|
||||
use crate::declaration_scoring::{ScoredSnippet, scored_snippets};
|
||||
|
||||
pub struct EditPredictionContext {
|
||||
pub excerpt: EditPredictionExcerpt,
|
||||
pub excerpt_text: EditPredictionExcerptText,
|
||||
pub snippets: Vec<ScoredSnippet>,
|
||||
}
|
||||
|
||||
impl EditPredictionContext {
|
||||
pub fn gather(
|
||||
cursor_point: Point,
|
||||
buffer: BufferSnapshot,
|
||||
excerpt_options: EditPredictionExcerptOptions,
|
||||
syntax_index: Entity<SyntaxIndex>,
|
||||
cx: &mut App,
|
||||
) -> Task<Self> {
|
||||
let index_state = syntax_index.read_with(cx, |index, _cx| index.state().clone());
|
||||
cx.background_spawn(async move {
|
||||
let index_state = index_state.lock().await;
|
||||
|
||||
let excerpt =
|
||||
EditPredictionExcerpt::select_from_buffer(cursor_point, &buffer, &excerpt_options)
|
||||
.unwrap();
|
||||
let excerpt_text = excerpt.text(&buffer);
|
||||
let references = references_in_excerpt(&excerpt, &excerpt_text, &buffer);
|
||||
let cursor_offset = cursor_point.to_offset(&buffer);
|
||||
|
||||
let snippets = scored_snippets(
|
||||
&index_state,
|
||||
&excerpt,
|
||||
&excerpt_text,
|
||||
references,
|
||||
cursor_offset,
|
||||
&buffer,
|
||||
);
|
||||
|
||||
Self {
|
||||
excerpt,
|
||||
excerpt_text,
|
||||
snippets,
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use std::sync::Arc;
|
||||
|
||||
use gpui::{Entity, TestAppContext};
|
||||
use indoc::indoc;
|
||||
use language::{Language, LanguageConfig, LanguageId, LanguageMatcher, tree_sitter_rust};
|
||||
use project::{FakeFs, Project};
|
||||
use serde_json::json;
|
||||
use settings::SettingsStore;
|
||||
use util::path;
|
||||
|
||||
use crate::{EditPredictionExcerptOptions, SyntaxIndex};
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_call_site(cx: &mut TestAppContext) {
|
||||
let (project, index, _rust_lang_id) = init_test(cx).await;
|
||||
|
||||
let buffer = project
|
||||
.update(cx, |project, cx| {
|
||||
let project_path = project.find_project_path("c.rs", cx).unwrap();
|
||||
project.open_buffer(project_path, cx)
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
cx.run_until_parked();
|
||||
|
||||
// first process_data call site
|
||||
let cursor_point = language::Point::new(8, 21);
|
||||
let buffer_snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot());
|
||||
|
||||
let context = cx
|
||||
.update(|cx| {
|
||||
EditPredictionContext::gather(
|
||||
cursor_point,
|
||||
buffer_snapshot,
|
||||
EditPredictionExcerptOptions {
|
||||
max_bytes: 40,
|
||||
min_bytes: 10,
|
||||
target_before_cursor_over_total_bytes: 0.5,
|
||||
include_parent_signatures: false,
|
||||
},
|
||||
index,
|
||||
cx,
|
||||
)
|
||||
})
|
||||
.await;
|
||||
|
||||
assert_eq!(context.snippets.len(), 1);
|
||||
assert_eq!(context.snippets[0].identifier.name.as_ref(), "process_data");
|
||||
drop(buffer);
|
||||
}
|
||||
|
||||
async fn init_test(
|
||||
cx: &mut TestAppContext,
|
||||
) -> (Entity<Project>, Entity<SyntaxIndex>, LanguageId) {
|
||||
cx.update(|cx| {
|
||||
let settings_store = SettingsStore::test(cx);
|
||||
cx.set_global(settings_store);
|
||||
language::init(cx);
|
||||
Project::init_settings(cx);
|
||||
});
|
||||
|
||||
let fs = FakeFs::new(cx.executor());
|
||||
fs.insert_tree(
|
||||
path!("/root"),
|
||||
json!({
|
||||
"a.rs": indoc! {r#"
|
||||
fn main() {
|
||||
let x = 1;
|
||||
let y = 2;
|
||||
let z = add(x, y);
|
||||
println!("Result: {}", z);
|
||||
}
|
||||
|
||||
fn add(a: i32, b: i32) -> i32 {
|
||||
a + b
|
||||
}
|
||||
"#},
|
||||
"b.rs": indoc! {"
|
||||
pub struct Config {
|
||||
pub name: String,
|
||||
pub value: i32,
|
||||
}
|
||||
|
||||
impl Config {
|
||||
pub fn new(name: String, value: i32) -> Self {
|
||||
Config { name, value }
|
||||
}
|
||||
}
|
||||
"},
|
||||
"c.rs": indoc! {r#"
|
||||
use std::collections::HashMap;
|
||||
|
||||
fn main() {
|
||||
let args: Vec<String> = std::env::args().collect();
|
||||
let data: Vec<i32> = args[1..]
|
||||
.iter()
|
||||
.filter_map(|s| s.parse().ok())
|
||||
.collect();
|
||||
let result = process_data(data);
|
||||
println!("{:?}", result);
|
||||
}
|
||||
|
||||
fn process_data(data: Vec<i32>) -> HashMap<i32, usize> {
|
||||
let mut counts = HashMap::new();
|
||||
for value in data {
|
||||
*counts.entry(value).or_insert(0) += 1;
|
||||
}
|
||||
counts
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_process_data() {
|
||||
let data = vec![1, 2, 2, 3];
|
||||
let result = process_data(data);
|
||||
assert_eq!(result.get(&2), Some(&2));
|
||||
}
|
||||
}
|
||||
"#}
|
||||
}),
|
||||
)
|
||||
.await;
|
||||
let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
|
||||
let language_registry = project.read_with(cx, |project, _| project.languages().clone());
|
||||
let lang = rust_lang();
|
||||
let lang_id = lang.id();
|
||||
language_registry.add(Arc::new(lang));
|
||||
|
||||
let index = cx.new(|cx| SyntaxIndex::new(&project, cx));
|
||||
cx.run_until_parked();
|
||||
|
||||
(project, index, lang_id)
|
||||
}
|
||||
|
||||
fn rust_lang() -> Language {
|
||||
Language::new(
|
||||
LanguageConfig {
|
||||
name: "Rust".into(),
|
||||
matcher: LanguageMatcher {
|
||||
path_suffixes: vec!["rs".to_string()],
|
||||
..Default::default()
|
||||
},
|
||||
..Default::default()
|
||||
},
|
||||
Some(tree_sitter_rust::LANGUAGE.into()),
|
||||
)
|
||||
.with_highlights_query(include_str!("../../languages/src/rust/highlights.scm"))
|
||||
.unwrap()
|
||||
.with_outline_query(include_str!("../../languages/src/rust/outline.scm"))
|
||||
.unwrap()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -31,7 +31,7 @@ pub struct EditPredictionExcerptOptions {
|
||||
pub include_parent_signatures: bool,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct EditPredictionExcerpt {
|
||||
pub range: Range<usize>,
|
||||
pub parent_signature_ranges: Vec<Range<usize>>,
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
use language::{BufferSnapshot, LanguageId, SyntaxMapMatches};
|
||||
use std::{cmp::Reverse, ops::Range, sync::Arc};
|
||||
use language::{BufferSnapshot, SyntaxMapMatches};
|
||||
use std::{cmp::Reverse, ops::Range};
|
||||
|
||||
use crate::declaration::Identifier;
|
||||
|
||||
// TODO:
|
||||
//
|
||||
@@ -18,12 +20,6 @@ pub struct OutlineDeclaration {
|
||||
pub signature_range: Range<usize>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Eq, PartialEq, Hash)]
|
||||
pub struct Identifier {
|
||||
pub name: Arc<str>,
|
||||
pub language_id: LanguageId,
|
||||
}
|
||||
|
||||
pub fn declarations_in_buffer(buffer: &BufferSnapshot) -> Vec<OutlineDeclaration> {
|
||||
declarations_overlapping_range(0..buffer.len(), buffer)
|
||||
}
|
||||
|
||||
@@ -3,8 +3,8 @@ use std::collections::HashMap;
|
||||
use std::ops::Range;
|
||||
|
||||
use crate::{
|
||||
declaration::Identifier,
|
||||
excerpt::{EditPredictionExcerpt, EditPredictionExcerptText},
|
||||
outline::Identifier,
|
||||
};
|
||||
|
||||
#[derive(Debug)]
|
||||
|
||||
@@ -1,20 +1,26 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use collections::{HashMap, HashSet};
|
||||
use futures::lock::Mutex;
|
||||
use gpui::{App, AppContext as _, Context, Entity, Task, WeakEntity};
|
||||
use language::{Buffer, BufferEvent, BufferSnapshot};
|
||||
use language::{Buffer, BufferEvent};
|
||||
use project::buffer_store::{BufferStore, BufferStoreEvent};
|
||||
use project::worktree_store::{WorktreeStore, WorktreeStoreEvent};
|
||||
use project::{PathChange, Project, ProjectEntryId, ProjectPath};
|
||||
use slotmap::SlotMap;
|
||||
use std::ops::Range;
|
||||
use std::sync::Arc;
|
||||
use text::Anchor;
|
||||
use text::BufferId;
|
||||
use util::{ResultExt as _, debug_panic, some_or_debug_panic};
|
||||
|
||||
use crate::outline::{Identifier, OutlineDeclaration, declarations_in_buffer};
|
||||
use crate::declaration::{
|
||||
BufferDeclaration, Declaration, DeclarationId, FileDeclaration, Identifier,
|
||||
};
|
||||
use crate::outline::declarations_in_buffer;
|
||||
|
||||
// TODO:
|
||||
//
|
||||
// * Skip for remote projects
|
||||
//
|
||||
// * Consider making SyntaxIndex not an Entity.
|
||||
|
||||
// Potential future improvements:
|
||||
//
|
||||
@@ -34,17 +40,19 @@ use crate::outline::{Identifier, OutlineDeclaration, declarations_in_buffer};
|
||||
// * Concurrent slotmap
|
||||
//
|
||||
// * Use queue for parsing
|
||||
//
|
||||
|
||||
slotmap::new_key_type! {
|
||||
pub struct DeclarationId;
|
||||
pub struct SyntaxIndex {
|
||||
state: Arc<Mutex<SyntaxIndexState>>,
|
||||
project: WeakEntity<Project>,
|
||||
}
|
||||
|
||||
pub struct TreeSitterIndex {
|
||||
#[derive(Default)]
|
||||
pub struct SyntaxIndexState {
|
||||
declarations: SlotMap<DeclarationId, Declaration>,
|
||||
identifiers: HashMap<Identifier, HashSet<DeclarationId>>,
|
||||
files: HashMap<ProjectEntryId, FileState>,
|
||||
buffers: HashMap<WeakEntity<Buffer>, BufferState>,
|
||||
project: WeakEntity<Project>,
|
||||
buffers: HashMap<BufferId, BufferState>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Default)]
|
||||
@@ -59,52 +67,11 @@ struct BufferState {
|
||||
task: Option<Task<()>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum Declaration {
|
||||
File {
|
||||
project_entry_id: ProjectEntryId,
|
||||
declaration: FileDeclaration,
|
||||
},
|
||||
Buffer {
|
||||
buffer: WeakEntity<Buffer>,
|
||||
declaration: BufferDeclaration,
|
||||
},
|
||||
}
|
||||
|
||||
impl Declaration {
|
||||
fn identifier(&self) -> &Identifier {
|
||||
match self {
|
||||
Declaration::File { declaration, .. } => &declaration.identifier,
|
||||
Declaration::Buffer { declaration, .. } => &declaration.identifier,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct FileDeclaration {
|
||||
pub parent: Option<DeclarationId>,
|
||||
pub identifier: Identifier,
|
||||
pub item_range: Range<usize>,
|
||||
pub signature_range: Range<usize>,
|
||||
pub signature_text: Arc<str>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct BufferDeclaration {
|
||||
pub parent: Option<DeclarationId>,
|
||||
pub identifier: Identifier,
|
||||
pub item_range: Range<Anchor>,
|
||||
pub signature_range: Range<Anchor>,
|
||||
}
|
||||
|
||||
impl TreeSitterIndex {
|
||||
impl SyntaxIndex {
|
||||
pub fn new(project: &Entity<Project>, cx: &mut Context<Self>) -> Self {
|
||||
let mut this = Self {
|
||||
declarations: SlotMap::with_key(),
|
||||
identifiers: HashMap::default(),
|
||||
project: project.downgrade(),
|
||||
files: HashMap::default(),
|
||||
buffers: HashMap::default(),
|
||||
state: Arc::new(Mutex::new(SyntaxIndexState::default())),
|
||||
};
|
||||
|
||||
let worktree_store = project.read(cx).worktree_store();
|
||||
@@ -139,73 +106,6 @@ impl TreeSitterIndex {
|
||||
this
|
||||
}
|
||||
|
||||
pub fn declaration(&self, id: DeclarationId) -> Option<&Declaration> {
|
||||
self.declarations.get(id)
|
||||
}
|
||||
|
||||
pub fn declarations_for_identifier<const N: usize>(
|
||||
&self,
|
||||
identifier: Identifier,
|
||||
cx: &App,
|
||||
) -> Vec<Declaration> {
|
||||
// make sure to not have a large stack allocation
|
||||
assert!(N < 32);
|
||||
|
||||
let Some(declaration_ids) = self.identifiers.get(&identifier) else {
|
||||
return vec![];
|
||||
};
|
||||
|
||||
let mut result = Vec::with_capacity(N);
|
||||
let mut included_buffer_entry_ids = arrayvec::ArrayVec::<_, N>::new();
|
||||
let mut file_declarations = Vec::new();
|
||||
|
||||
for declaration_id in declaration_ids {
|
||||
let declaration = self.declarations.get(*declaration_id);
|
||||
let Some(declaration) = some_or_debug_panic(declaration) else {
|
||||
continue;
|
||||
};
|
||||
match declaration {
|
||||
Declaration::Buffer { buffer, .. } => {
|
||||
if let Ok(Some(entry_id)) = buffer.read_with(cx, |buffer, cx| {
|
||||
project::File::from_dyn(buffer.file()).and_then(|f| f.project_entry_id(cx))
|
||||
}) {
|
||||
included_buffer_entry_ids.push(entry_id);
|
||||
result.push(declaration.clone());
|
||||
if result.len() == N {
|
||||
return result;
|
||||
}
|
||||
}
|
||||
}
|
||||
Declaration::File {
|
||||
project_entry_id, ..
|
||||
} => {
|
||||
if !included_buffer_entry_ids.contains(project_entry_id) {
|
||||
file_declarations.push(declaration.clone());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for declaration in file_declarations {
|
||||
match declaration {
|
||||
Declaration::File {
|
||||
project_entry_id, ..
|
||||
} => {
|
||||
if !included_buffer_entry_ids.contains(&project_entry_id) {
|
||||
result.push(declaration);
|
||||
|
||||
if result.len() == N {
|
||||
return result;
|
||||
}
|
||||
}
|
||||
}
|
||||
Declaration::Buffer { .. } => {}
|
||||
}
|
||||
}
|
||||
|
||||
result
|
||||
}
|
||||
|
||||
fn handle_worktree_store_event(
|
||||
&mut self,
|
||||
_worktree_store: Entity<WorktreeStore>,
|
||||
@@ -215,21 +115,33 @@ impl TreeSitterIndex {
|
||||
use WorktreeStoreEvent::*;
|
||||
match event {
|
||||
WorktreeUpdatedEntries(worktree_id, updated_entries_set) => {
|
||||
for (path, entry_id, path_change) in updated_entries_set.iter() {
|
||||
if let PathChange::Removed = path_change {
|
||||
self.files.remove(entry_id);
|
||||
} else {
|
||||
let project_path = ProjectPath {
|
||||
worktree_id: *worktree_id,
|
||||
path: path.clone(),
|
||||
};
|
||||
self.update_file(*entry_id, project_path, cx);
|
||||
let state = Arc::downgrade(&self.state);
|
||||
let worktree_id = *worktree_id;
|
||||
let updated_entries_set = updated_entries_set.clone();
|
||||
cx.spawn(async move |this, cx| {
|
||||
let Some(state) = state.upgrade() else { return };
|
||||
for (path, entry_id, path_change) in updated_entries_set.iter() {
|
||||
if let PathChange::Removed = path_change {
|
||||
state.lock().await.files.remove(entry_id);
|
||||
} else {
|
||||
let project_path = ProjectPath {
|
||||
worktree_id,
|
||||
path: path.clone(),
|
||||
};
|
||||
this.update(cx, |this, cx| {
|
||||
this.update_file(*entry_id, project_path, cx);
|
||||
})
|
||||
.ok();
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
.detach();
|
||||
}
|
||||
WorktreeDeletedEntry(_worktree_id, project_entry_id) => {
|
||||
// TODO: Is this needed?
|
||||
self.files.remove(project_entry_id);
|
||||
let project_entry_id = *project_entry_id;
|
||||
self.with_state(cx, move |state| {
|
||||
state.files.remove(&project_entry_id);
|
||||
})
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
@@ -251,15 +163,42 @@ impl TreeSitterIndex {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn state(&self) -> &Arc<Mutex<SyntaxIndexState>> {
|
||||
&self.state
|
||||
}
|
||||
|
||||
fn with_state(&self, cx: &mut App, f: impl FnOnce(&mut SyntaxIndexState) + Send + 'static) {
|
||||
if let Some(mut state) = self.state.try_lock() {
|
||||
f(&mut state);
|
||||
return;
|
||||
}
|
||||
let state = Arc::downgrade(&self.state);
|
||||
cx.background_spawn(async move {
|
||||
let Some(state) = state.upgrade() else {
|
||||
return None;
|
||||
};
|
||||
let mut state = state.lock().await;
|
||||
Some(f(&mut state))
|
||||
})
|
||||
.detach();
|
||||
}
|
||||
|
||||
fn register_buffer(&mut self, buffer: &Entity<Buffer>, cx: &mut Context<Self>) {
|
||||
self.buffers
|
||||
.insert(buffer.downgrade(), BufferState::default());
|
||||
let weak_buf = buffer.downgrade();
|
||||
cx.observe_release(buffer, move |this, _buffer, _cx| {
|
||||
this.buffers.remove(&weak_buf);
|
||||
let buffer_id = buffer.read(cx).remote_id();
|
||||
cx.observe_release(buffer, move |this, _buffer, cx| {
|
||||
this.with_state(cx, move |state| {
|
||||
if let Some(buffer_state) = state.buffers.remove(&buffer_id) {
|
||||
SyntaxIndexState::remove_buffer_declarations(
|
||||
&buffer_state.declarations,
|
||||
&mut state.declarations,
|
||||
&mut state.identifiers,
|
||||
);
|
||||
}
|
||||
})
|
||||
})
|
||||
.detach();
|
||||
cx.subscribe(buffer, Self::handle_buffer_event).detach();
|
||||
|
||||
self.update_buffer(buffer.clone(), cx);
|
||||
}
|
||||
|
||||
@@ -275,10 +214,19 @@ impl TreeSitterIndex {
|
||||
}
|
||||
}
|
||||
|
||||
fn update_buffer(&mut self, buffer: Entity<Buffer>, cx: &Context<Self>) {
|
||||
let mut parse_status = buffer.read(cx).parse_status();
|
||||
fn update_buffer(&mut self, buffer_entity: Entity<Buffer>, cx: &mut Context<Self>) {
|
||||
let buffer = buffer_entity.read(cx);
|
||||
|
||||
let Some(project_entry_id) =
|
||||
project::File::from_dyn(buffer.file()).and_then(|f| f.project_entry_id(cx))
|
||||
else {
|
||||
return;
|
||||
};
|
||||
let buffer_id = buffer.remote_id();
|
||||
|
||||
let mut parse_status = buffer.parse_status();
|
||||
let snapshot_task = cx.spawn({
|
||||
let weak_buffer = buffer.downgrade();
|
||||
let weak_buffer = buffer_entity.downgrade();
|
||||
async move |_, cx| {
|
||||
while *parse_status.borrow() != language::ParseStatus::Idle {
|
||||
parse_status.changed().await?;
|
||||
@@ -289,75 +237,77 @@ impl TreeSitterIndex {
|
||||
|
||||
let parse_task = cx.background_spawn(async move {
|
||||
let snapshot = snapshot_task.await?;
|
||||
let rope = snapshot.text.as_rope().clone();
|
||||
|
||||
anyhow::Ok(
|
||||
anyhow::Ok((
|
||||
declarations_in_buffer(&snapshot)
|
||||
.into_iter()
|
||||
.map(|item| {
|
||||
(
|
||||
item.parent_index,
|
||||
BufferDeclaration::from_outline(item, &snapshot),
|
||||
BufferDeclaration::from_outline(item, &rope),
|
||||
)
|
||||
})
|
||||
.collect::<Vec<_>>(),
|
||||
)
|
||||
rope,
|
||||
))
|
||||
});
|
||||
|
||||
let task = cx.spawn({
|
||||
let weak_buffer = buffer.downgrade();
|
||||
async move |this, cx| {
|
||||
let Ok(declarations) = parse_task.await else {
|
||||
let Ok((declarations, rope)) = parse_task.await else {
|
||||
return;
|
||||
};
|
||||
|
||||
this.update(cx, |this, _cx| {
|
||||
let buffer_state = this
|
||||
.buffers
|
||||
.entry(weak_buffer.clone())
|
||||
.or_insert_with(Default::default);
|
||||
this.update(cx, move |this, cx| {
|
||||
this.with_state(cx, move |state| {
|
||||
let buffer_state = state
|
||||
.buffers
|
||||
.entry(buffer_id)
|
||||
.or_insert_with(Default::default);
|
||||
|
||||
for old_declaration_id in &buffer_state.declarations {
|
||||
let Some(declaration) = this.declarations.remove(*old_declaration_id)
|
||||
else {
|
||||
debug_panic!("declaration not found");
|
||||
continue;
|
||||
};
|
||||
if let Some(identifier_declarations) =
|
||||
this.identifiers.get_mut(declaration.identifier())
|
||||
{
|
||||
identifier_declarations.remove(old_declaration_id);
|
||||
SyntaxIndexState::remove_buffer_declarations(
|
||||
&buffer_state.declarations,
|
||||
&mut state.declarations,
|
||||
&mut state.identifiers,
|
||||
);
|
||||
|
||||
let mut new_ids = Vec::with_capacity(declarations.len());
|
||||
state.declarations.reserve(declarations.len());
|
||||
for (parent_index, mut declaration) in declarations {
|
||||
declaration.parent = parent_index
|
||||
.and_then(|ix| some_or_debug_panic(new_ids.get(ix).copied()));
|
||||
|
||||
let identifier = declaration.identifier.clone();
|
||||
let declaration_id = state.declarations.insert(Declaration::Buffer {
|
||||
rope: rope.clone(),
|
||||
buffer_id,
|
||||
declaration,
|
||||
project_entry_id,
|
||||
});
|
||||
new_ids.push(declaration_id);
|
||||
|
||||
state
|
||||
.identifiers
|
||||
.entry(identifier)
|
||||
.or_default()
|
||||
.insert(declaration_id);
|
||||
}
|
||||
}
|
||||
|
||||
let mut new_ids = Vec::with_capacity(declarations.len());
|
||||
this.declarations.reserve(declarations.len());
|
||||
for (parent_index, mut declaration) in declarations {
|
||||
declaration.parent = parent_index
|
||||
.and_then(|ix| some_or_debug_panic(new_ids.get(ix).copied()));
|
||||
|
||||
let identifier = declaration.identifier.clone();
|
||||
let declaration_id = this.declarations.insert(Declaration::Buffer {
|
||||
buffer: weak_buffer.clone(),
|
||||
declaration,
|
||||
});
|
||||
new_ids.push(declaration_id);
|
||||
|
||||
this.identifiers
|
||||
.entry(identifier)
|
||||
.or_default()
|
||||
.insert(declaration_id);
|
||||
}
|
||||
|
||||
buffer_state.declarations = new_ids;
|
||||
buffer_state.declarations = new_ids;
|
||||
});
|
||||
})
|
||||
.ok();
|
||||
}
|
||||
});
|
||||
|
||||
self.buffers
|
||||
.entry(buffer.downgrade())
|
||||
.or_insert_with(Default::default)
|
||||
.task = Some(task);
|
||||
self.with_state(cx, move |state| {
|
||||
state
|
||||
.buffers
|
||||
.entry(buffer_id)
|
||||
.or_insert_with(Default::default)
|
||||
.task = Some(task)
|
||||
});
|
||||
}
|
||||
|
||||
fn update_file(
|
||||
@@ -401,14 +351,10 @@ impl TreeSitterIndex {
|
||||
|
||||
let parse_task = cx.background_spawn(async move {
|
||||
let snapshot = snapshot_task.await?;
|
||||
let rope = snapshot.as_rope();
|
||||
let declarations = declarations_in_buffer(&snapshot)
|
||||
.into_iter()
|
||||
.map(|item| {
|
||||
(
|
||||
item.parent_index,
|
||||
FileDeclaration::from_outline(item, &snapshot),
|
||||
)
|
||||
})
|
||||
.map(|item| (item.parent_index, FileDeclaration::from_outline(item, rope)))
|
||||
.collect::<Vec<_>>();
|
||||
anyhow::Ok(declarations)
|
||||
});
|
||||
@@ -419,84 +365,160 @@ impl TreeSitterIndex {
|
||||
let Ok(declarations) = parse_task.await else {
|
||||
return;
|
||||
};
|
||||
this.update(cx, |this, _cx| {
|
||||
let file_state = this.files.entry(entry_id).or_insert_with(Default::default);
|
||||
this.update(cx, |this, cx| {
|
||||
this.with_state(cx, move |state| {
|
||||
let file_state =
|
||||
state.files.entry(entry_id).or_insert_with(Default::default);
|
||||
|
||||
for old_declaration_id in &file_state.declarations {
|
||||
let Some(declaration) = this.declarations.remove(*old_declaration_id)
|
||||
else {
|
||||
debug_panic!("declaration not found");
|
||||
continue;
|
||||
};
|
||||
if let Some(identifier_declarations) =
|
||||
this.identifiers.get_mut(declaration.identifier())
|
||||
{
|
||||
identifier_declarations.remove(old_declaration_id);
|
||||
for old_declaration_id in &file_state.declarations {
|
||||
let Some(declaration) = state.declarations.remove(*old_declaration_id)
|
||||
else {
|
||||
debug_panic!("declaration not found");
|
||||
continue;
|
||||
};
|
||||
if let Some(identifier_declarations) =
|
||||
state.identifiers.get_mut(declaration.identifier())
|
||||
{
|
||||
identifier_declarations.remove(old_declaration_id);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let mut new_ids = Vec::with_capacity(declarations.len());
|
||||
this.declarations.reserve(declarations.len());
|
||||
let mut new_ids = Vec::with_capacity(declarations.len());
|
||||
state.declarations.reserve(declarations.len());
|
||||
|
||||
for (parent_index, mut declaration) in declarations {
|
||||
declaration.parent = parent_index
|
||||
.and_then(|ix| some_or_debug_panic(new_ids.get(ix).copied()));
|
||||
for (parent_index, mut declaration) in declarations {
|
||||
declaration.parent = parent_index
|
||||
.and_then(|ix| some_or_debug_panic(new_ids.get(ix).copied()));
|
||||
|
||||
let identifier = declaration.identifier.clone();
|
||||
let declaration_id = this.declarations.insert(Declaration::File {
|
||||
project_entry_id: entry_id,
|
||||
declaration,
|
||||
});
|
||||
new_ids.push(declaration_id);
|
||||
let identifier = declaration.identifier.clone();
|
||||
let declaration_id = state.declarations.insert(Declaration::File {
|
||||
project_entry_id: entry_id,
|
||||
declaration,
|
||||
});
|
||||
new_ids.push(declaration_id);
|
||||
|
||||
this.identifiers
|
||||
.entry(identifier)
|
||||
.or_default()
|
||||
.insert(declaration_id);
|
||||
}
|
||||
state
|
||||
.identifiers
|
||||
.entry(identifier)
|
||||
.or_default()
|
||||
.insert(declaration_id);
|
||||
}
|
||||
|
||||
file_state.declarations = new_ids;
|
||||
file_state.declarations = new_ids;
|
||||
});
|
||||
})
|
||||
.ok();
|
||||
}
|
||||
});
|
||||
|
||||
self.files
|
||||
.entry(entry_id)
|
||||
.or_insert_with(Default::default)
|
||||
.task = Some(task);
|
||||
self.with_state(cx, move |state| {
|
||||
state
|
||||
.files
|
||||
.entry(entry_id)
|
||||
.or_insert_with(Default::default)
|
||||
.task = Some(task);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
impl BufferDeclaration {
|
||||
pub fn from_outline(declaration: OutlineDeclaration, snapshot: &BufferSnapshot) -> Self {
|
||||
// use of anchor_before is a guess that the proper behavior is to expand to include
|
||||
// insertions immediately before the declaration, but not for insertions immediately after
|
||||
Self {
|
||||
parent: None,
|
||||
identifier: declaration.identifier,
|
||||
item_range: snapshot.anchor_before(declaration.item_range.start)
|
||||
..snapshot.anchor_before(declaration.item_range.end),
|
||||
signature_range: snapshot.anchor_before(declaration.signature_range.start)
|
||||
..snapshot.anchor_before(declaration.signature_range.end),
|
||||
impl SyntaxIndexState {
|
||||
pub fn declaration(&self, id: DeclarationId) -> Option<&Declaration> {
|
||||
self.declarations.get(id)
|
||||
}
|
||||
|
||||
/// Returns declarations for the identifier. If the limit is exceeded, returns an empty vector.
|
||||
///
|
||||
/// TODO: Consider doing some pre-ranking and instead truncating when N is exceeded.
|
||||
pub fn declarations_for_identifier<const N: usize>(
|
||||
&self,
|
||||
identifier: &Identifier,
|
||||
) -> Vec<Declaration> {
|
||||
// make sure to not have a large stack allocation
|
||||
assert!(N < 32);
|
||||
|
||||
let Some(declaration_ids) = self.identifiers.get(&identifier) else {
|
||||
return vec![];
|
||||
};
|
||||
|
||||
let mut result = Vec::with_capacity(N);
|
||||
let mut included_buffer_entry_ids = arrayvec::ArrayVec::<_, N>::new();
|
||||
let mut file_declarations = Vec::new();
|
||||
|
||||
for declaration_id in declaration_ids {
|
||||
let declaration = self.declarations.get(*declaration_id);
|
||||
let Some(declaration) = some_or_debug_panic(declaration) else {
|
||||
continue;
|
||||
};
|
||||
match declaration {
|
||||
Declaration::Buffer {
|
||||
project_entry_id, ..
|
||||
} => {
|
||||
included_buffer_entry_ids.push(*project_entry_id);
|
||||
result.push(declaration.clone());
|
||||
if result.len() == N {
|
||||
return Vec::new();
|
||||
}
|
||||
}
|
||||
Declaration::File {
|
||||
project_entry_id, ..
|
||||
} => {
|
||||
if !included_buffer_entry_ids.contains(&project_entry_id) {
|
||||
file_declarations.push(declaration.clone());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for declaration in file_declarations {
|
||||
match declaration {
|
||||
Declaration::File {
|
||||
project_entry_id, ..
|
||||
} => {
|
||||
if !included_buffer_entry_ids.contains(&project_entry_id) {
|
||||
result.push(declaration);
|
||||
|
||||
if result.len() == N {
|
||||
return Vec::new();
|
||||
}
|
||||
}
|
||||
}
|
||||
Declaration::Buffer { .. } => {}
|
||||
}
|
||||
}
|
||||
|
||||
result
|
||||
}
|
||||
|
||||
pub fn file_declaration_count(&self, declaration: &Declaration) -> usize {
|
||||
match declaration {
|
||||
Declaration::File {
|
||||
project_entry_id, ..
|
||||
} => self
|
||||
.files
|
||||
.get(project_entry_id)
|
||||
.map(|file_state| file_state.declarations.len())
|
||||
.unwrap_or_default(),
|
||||
Declaration::Buffer { buffer_id, .. } => self
|
||||
.buffers
|
||||
.get(buffer_id)
|
||||
.map(|buffer_state| buffer_state.declarations.len())
|
||||
.unwrap_or_default(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl FileDeclaration {
|
||||
pub fn from_outline(
|
||||
declaration: OutlineDeclaration,
|
||||
snapshot: &BufferSnapshot,
|
||||
) -> FileDeclaration {
|
||||
FileDeclaration {
|
||||
parent: None,
|
||||
identifier: declaration.identifier,
|
||||
item_range: declaration.item_range,
|
||||
signature_text: snapshot
|
||||
.text_for_range(declaration.signature_range.clone())
|
||||
.collect::<String>()
|
||||
.into(),
|
||||
signature_range: declaration.signature_range,
|
||||
fn remove_buffer_declarations(
|
||||
old_declaration_ids: &[DeclarationId],
|
||||
declarations: &mut SlotMap<DeclarationId, Declaration>,
|
||||
identifiers: &mut HashMap<Identifier, HashSet<DeclarationId>>,
|
||||
) {
|
||||
for old_declaration_id in old_declaration_ids {
|
||||
let Some(declaration) = declarations.remove(*old_declaration_id) else {
|
||||
debug_panic!("declaration not found");
|
||||
continue;
|
||||
};
|
||||
if let Some(identifier_declarations) = identifiers.get_mut(declaration.identifier()) {
|
||||
identifier_declarations.remove(old_declaration_id);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -506,17 +528,16 @@ mod tests {
|
||||
use super::*;
|
||||
use std::{path::Path, sync::Arc};
|
||||
|
||||
use futures::channel::oneshot;
|
||||
use gpui::TestAppContext;
|
||||
use indoc::indoc;
|
||||
use language::{Language, LanguageConfig, LanguageId, LanguageMatcher, tree_sitter_rust};
|
||||
use project::{FakeFs, Project, ProjectItem};
|
||||
use project::{FakeFs, Project};
|
||||
use serde_json::json;
|
||||
use settings::SettingsStore;
|
||||
use text::OffsetRangeExt as _;
|
||||
use util::path;
|
||||
|
||||
use crate::tree_sitter_index::TreeSitterIndex;
|
||||
use crate::syntax_index::SyntaxIndex;
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_unopen_indexed_files(cx: &mut TestAppContext) {
|
||||
@@ -526,17 +547,19 @@ mod tests {
|
||||
language_id: rust_lang_id,
|
||||
};
|
||||
|
||||
index.read_with(cx, |index, cx| {
|
||||
let decls = index.declarations_for_identifier::<8>(main.clone(), cx);
|
||||
let index_state = index.read_with(cx, |index, _cx| index.state().clone());
|
||||
let index_state = index_state.lock().await;
|
||||
cx.update(|cx| {
|
||||
let decls = index_state.declarations_for_identifier::<8>(&main);
|
||||
assert_eq!(decls.len(), 2);
|
||||
|
||||
let decl = expect_file_decl("c.rs", &decls[0], &project, cx);
|
||||
assert_eq!(decl.identifier, main.clone());
|
||||
assert_eq!(decl.item_range, 32..279);
|
||||
assert_eq!(decl.item_range_in_file, 32..280);
|
||||
|
||||
let decl = expect_file_decl("a.rs", &decls[1], &project, cx);
|
||||
assert_eq!(decl.identifier, main);
|
||||
assert_eq!(decl.item_range, 0..97);
|
||||
assert_eq!(decl.item_range_in_file, 0..98);
|
||||
});
|
||||
}
|
||||
|
||||
@@ -548,15 +571,17 @@ mod tests {
|
||||
language_id: rust_lang_id,
|
||||
};
|
||||
|
||||
index.read_with(cx, |index, cx| {
|
||||
let decls = index.declarations_for_identifier::<8>(test_process_data.clone(), cx);
|
||||
let index_state = index.read_with(cx, |index, _cx| index.state().clone());
|
||||
let index_state = index_state.lock().await;
|
||||
cx.update(|cx| {
|
||||
let decls = index_state.declarations_for_identifier::<8>(&test_process_data);
|
||||
assert_eq!(decls.len(), 1);
|
||||
|
||||
let decl = expect_file_decl("c.rs", &decls[0], &project, cx);
|
||||
assert_eq!(decl.identifier, test_process_data);
|
||||
|
||||
let parent_id = decl.parent.unwrap();
|
||||
let parent = index.declaration(parent_id).unwrap();
|
||||
let parent = index_state.declaration(parent_id).unwrap();
|
||||
let parent_decl = expect_file_decl("c.rs", &parent, &project, cx);
|
||||
assert_eq!(
|
||||
parent_decl.identifier,
|
||||
@@ -587,16 +612,18 @@ mod tests {
|
||||
|
||||
cx.run_until_parked();
|
||||
|
||||
index.read_with(cx, |index, cx| {
|
||||
let decls = index.declarations_for_identifier::<8>(test_process_data.clone(), cx);
|
||||
let index_state = index.read_with(cx, |index, _cx| index.state().clone());
|
||||
let index_state = index_state.lock().await;
|
||||
cx.update(|cx| {
|
||||
let decls = index_state.declarations_for_identifier::<8>(&test_process_data);
|
||||
assert_eq!(decls.len(), 1);
|
||||
|
||||
let decl = expect_buffer_decl("c.rs", &decls[0], cx);
|
||||
let decl = expect_buffer_decl("c.rs", &decls[0], &project, cx);
|
||||
assert_eq!(decl.identifier, test_process_data);
|
||||
|
||||
let parent_id = decl.parent.unwrap();
|
||||
let parent = index.declaration(parent_id).unwrap();
|
||||
let parent_decl = expect_buffer_decl("c.rs", &parent, cx);
|
||||
let parent = index_state.declaration(parent_id).unwrap();
|
||||
let parent_decl = expect_buffer_decl("c.rs", &parent, &project, cx);
|
||||
assert_eq!(
|
||||
parent_decl.identifier,
|
||||
Identifier {
|
||||
@@ -614,16 +641,13 @@ mod tests {
|
||||
async fn test_declarations_limt(cx: &mut TestAppContext) {
|
||||
let (_, index, rust_lang_id) = init_test(cx).await;
|
||||
|
||||
index.read_with(cx, |index, cx| {
|
||||
let decls = index.declarations_for_identifier::<1>(
|
||||
Identifier {
|
||||
name: "main".into(),
|
||||
language_id: rust_lang_id,
|
||||
},
|
||||
cx,
|
||||
);
|
||||
assert_eq!(decls.len(), 1);
|
||||
let index_state = index.read_with(cx, |index, _cx| index.state().clone());
|
||||
let index_state = index_state.lock().await;
|
||||
let decls = index_state.declarations_for_identifier::<1>(&Identifier {
|
||||
name: "main".into(),
|
||||
language_id: rust_lang_id,
|
||||
});
|
||||
assert_eq!(decls.len(), 0);
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
@@ -645,31 +669,31 @@ mod tests {
|
||||
|
||||
cx.run_until_parked();
|
||||
|
||||
index.read_with(cx, |index, cx| {
|
||||
let decls = index.declarations_for_identifier::<8>(main.clone(), cx);
|
||||
assert_eq!(decls.len(), 2);
|
||||
let decl = expect_buffer_decl("c.rs", &decls[0], cx);
|
||||
assert_eq!(decl.identifier, main);
|
||||
assert_eq!(decl.item_range.to_offset(&buffer.read(cx)), 32..279);
|
||||
let index_state_arc = index.read_with(cx, |index, _cx| index.state().clone());
|
||||
{
|
||||
let index_state = index_state_arc.lock().await;
|
||||
|
||||
expect_file_decl("a.rs", &decls[1], &project, cx);
|
||||
});
|
||||
cx.update(|cx| {
|
||||
let decls = index_state.declarations_for_identifier::<8>(&main);
|
||||
assert_eq!(decls.len(), 2);
|
||||
let decl = expect_buffer_decl("c.rs", &decls[0], &project, cx);
|
||||
assert_eq!(decl.identifier, main);
|
||||
assert_eq!(decl.item_range.to_offset(&buffer.read(cx)), 32..279);
|
||||
|
||||
expect_file_decl("a.rs", &decls[1], &project, cx);
|
||||
});
|
||||
}
|
||||
|
||||
// Drop the buffer and wait for release
|
||||
let (release_tx, release_rx) = oneshot::channel();
|
||||
cx.update(|cx| {
|
||||
cx.observe_release(&buffer, |_, _| {
|
||||
release_tx.send(()).ok();
|
||||
})
|
||||
.detach();
|
||||
cx.update(|_| {
|
||||
drop(buffer);
|
||||
});
|
||||
drop(buffer);
|
||||
cx.run_until_parked();
|
||||
release_rx.await.ok();
|
||||
cx.run_until_parked();
|
||||
|
||||
index.read_with(cx, |index, cx| {
|
||||
let decls = index.declarations_for_identifier::<8>(main, cx);
|
||||
let index_state = index_state_arc.lock().await;
|
||||
|
||||
cx.update(|cx| {
|
||||
let decls = index_state.declarations_for_identifier::<8>(&main);
|
||||
assert_eq!(decls.len(), 2);
|
||||
expect_file_decl("c.rs", &decls[0], &project, cx);
|
||||
expect_file_decl("a.rs", &decls[1], &project, cx);
|
||||
@@ -679,24 +703,20 @@ mod tests {
|
||||
fn expect_buffer_decl<'a>(
|
||||
path: &str,
|
||||
declaration: &'a Declaration,
|
||||
project: &Entity<Project>,
|
||||
cx: &App,
|
||||
) -> &'a BufferDeclaration {
|
||||
if let Declaration::Buffer {
|
||||
declaration,
|
||||
buffer,
|
||||
project_entry_id,
|
||||
..
|
||||
} = declaration
|
||||
{
|
||||
assert_eq!(
|
||||
buffer
|
||||
.upgrade()
|
||||
.unwrap()
|
||||
.read(cx)
|
||||
.project_path(cx)
|
||||
.unwrap()
|
||||
.path
|
||||
.as_ref(),
|
||||
Path::new(path),
|
||||
);
|
||||
let project_path = project
|
||||
.read(cx)
|
||||
.path_for_entry(*project_entry_id, cx)
|
||||
.unwrap();
|
||||
assert_eq!(project_path.path.as_ref(), Path::new(path),);
|
||||
declaration
|
||||
} else {
|
||||
panic!("Expected a buffer declaration, found {:?}", declaration);
|
||||
@@ -731,7 +751,7 @@ mod tests {
|
||||
|
||||
async fn init_test(
|
||||
cx: &mut TestAppContext,
|
||||
) -> (Entity<Project>, Entity<TreeSitterIndex>, LanguageId) {
|
||||
) -> (Entity<Project>, Entity<SyntaxIndex>, LanguageId) {
|
||||
cx.update(|cx| {
|
||||
let settings_store = SettingsStore::test(cx);
|
||||
cx.set_global(settings_store);
|
||||
@@ -809,7 +829,7 @@ mod tests {
|
||||
let lang_id = lang.id();
|
||||
language_registry.add(Arc::new(lang));
|
||||
|
||||
let index = cx.new(|cx| TreeSitterIndex::new(&project, cx));
|
||||
let index = cx.new(|cx| SyntaxIndex::new(&project, cx));
|
||||
cx.run_until_parked();
|
||||
|
||||
(project, index, lang_id)
|
||||
241
crates/edit_prediction_context/src/text_similarity.rs
Normal file
241
crates/edit_prediction_context/src/text_similarity.rs
Normal file
@@ -0,0 +1,241 @@
|
||||
use regex::Regex;
|
||||
use std::{collections::HashMap, sync::LazyLock};
|
||||
|
||||
use crate::reference::Reference;
|
||||
|
||||
// TODO: Consider implementing sliding window similarity matching like
|
||||
// https://github.com/sourcegraph/cody-public-snapshot/blob/8e20ac6c1460c08b0db581c0204658112a246eda/vscode/src/completions/context/retrievers/jaccard-similarity/bestJaccardMatch.ts
|
||||
//
|
||||
// That implementation could actually be more efficient - no need to track words in the window that
|
||||
// are not in the query.
|
||||
|
||||
static IDENTIFIER_REGEX: LazyLock<Regex> = LazyLock::new(|| Regex::new(r"\b\w+\b").unwrap());
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct IdentifierOccurrences {
|
||||
identifier_to_count: HashMap<String, usize>,
|
||||
total_count: usize,
|
||||
}
|
||||
|
||||
impl IdentifierOccurrences {
|
||||
pub fn within_string(code: &str) -> Self {
|
||||
Self::from_iterator(IDENTIFIER_REGEX.find_iter(code).map(|mat| mat.as_str()))
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub fn within_references(references: &[Reference]) -> Self {
|
||||
Self::from_iterator(
|
||||
references
|
||||
.iter()
|
||||
.map(|reference| reference.identifier.name.as_ref()),
|
||||
)
|
||||
}
|
||||
|
||||
pub fn from_iterator<'a>(identifier_iterator: impl Iterator<Item = &'a str>) -> Self {
|
||||
let mut identifier_to_count = HashMap::new();
|
||||
let mut total_count = 0;
|
||||
for identifier in identifier_iterator {
|
||||
// TODO: Score matches that match case higher?
|
||||
//
|
||||
// TODO: Also include unsplit identifier?
|
||||
for identifier_part in split_identifier(identifier) {
|
||||
identifier_to_count
|
||||
.entry(identifier_part.to_lowercase())
|
||||
.and_modify(|count| *count += 1)
|
||||
.or_insert(1);
|
||||
total_count += 1;
|
||||
}
|
||||
}
|
||||
IdentifierOccurrences {
|
||||
identifier_to_count,
|
||||
total_count,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Splits camelcase / snakecase / kebabcase / pascalcase
|
||||
//
|
||||
// TODO: Make this more efficient / elegant.
|
||||
fn split_identifier<'a>(identifier: &'a str) -> Vec<&'a str> {
|
||||
let mut parts = Vec::new();
|
||||
let mut start = 0;
|
||||
let chars: Vec<char> = identifier.chars().collect();
|
||||
|
||||
if chars.is_empty() {
|
||||
return parts;
|
||||
}
|
||||
|
||||
let mut i = 0;
|
||||
while i < chars.len() {
|
||||
let ch = chars[i];
|
||||
|
||||
// Handle explicit delimiters (underscore and hyphen)
|
||||
if ch == '_' || ch == '-' {
|
||||
if i > start {
|
||||
parts.push(&identifier[start..i]);
|
||||
}
|
||||
start = i + 1;
|
||||
i += 1;
|
||||
continue;
|
||||
}
|
||||
|
||||
// Handle camelCase and PascalCase transitions
|
||||
if i > 0 && i < chars.len() {
|
||||
let prev_char = chars[i - 1];
|
||||
|
||||
// Transition from lowercase/digit to uppercase
|
||||
if (prev_char.is_lowercase() || prev_char.is_ascii_digit()) && ch.is_uppercase() {
|
||||
parts.push(&identifier[start..i]);
|
||||
start = i;
|
||||
}
|
||||
// Handle sequences like "XMLParser" -> ["XML", "Parser"]
|
||||
else if i + 1 < chars.len()
|
||||
&& ch.is_uppercase()
|
||||
&& chars[i + 1].is_lowercase()
|
||||
&& prev_char.is_uppercase()
|
||||
{
|
||||
parts.push(&identifier[start..i]);
|
||||
start = i;
|
||||
}
|
||||
}
|
||||
|
||||
i += 1;
|
||||
}
|
||||
|
||||
// Add the last part if there's any remaining
|
||||
if start < identifier.len() {
|
||||
parts.push(&identifier[start..]);
|
||||
}
|
||||
|
||||
// Filter out empty strings
|
||||
parts.into_iter().filter(|s| !s.is_empty()).collect()
|
||||
}
|
||||
|
||||
pub fn jaccard_similarity<'a>(
|
||||
mut set_a: &'a IdentifierOccurrences,
|
||||
mut set_b: &'a IdentifierOccurrences,
|
||||
) -> f32 {
|
||||
if set_a.identifier_to_count.len() > set_b.identifier_to_count.len() {
|
||||
std::mem::swap(&mut set_a, &mut set_b);
|
||||
}
|
||||
let intersection = set_a
|
||||
.identifier_to_count
|
||||
.keys()
|
||||
.filter(|key| set_b.identifier_to_count.contains_key(*key))
|
||||
.count();
|
||||
let union = set_a.identifier_to_count.len() + set_b.identifier_to_count.len() - intersection;
|
||||
intersection as f32 / union as f32
|
||||
}
|
||||
|
||||
// TODO
|
||||
#[allow(dead_code)]
|
||||
pub fn overlap_coefficient<'a>(
|
||||
mut set_a: &'a IdentifierOccurrences,
|
||||
mut set_b: &'a IdentifierOccurrences,
|
||||
) -> f32 {
|
||||
if set_a.identifier_to_count.len() > set_b.identifier_to_count.len() {
|
||||
std::mem::swap(&mut set_a, &mut set_b);
|
||||
}
|
||||
let intersection = set_a
|
||||
.identifier_to_count
|
||||
.keys()
|
||||
.filter(|key| set_b.identifier_to_count.contains_key(*key))
|
||||
.count();
|
||||
intersection as f32 / set_a.identifier_to_count.len() as f32
|
||||
}
|
||||
|
||||
// TODO
|
||||
#[allow(dead_code)]
|
||||
pub fn weighted_jaccard_similarity<'a>(
|
||||
mut set_a: &'a IdentifierOccurrences,
|
||||
mut set_b: &'a IdentifierOccurrences,
|
||||
) -> f32 {
|
||||
if set_a.identifier_to_count.len() > set_b.identifier_to_count.len() {
|
||||
std::mem::swap(&mut set_a, &mut set_b);
|
||||
}
|
||||
|
||||
let mut numerator = 0;
|
||||
let mut denominator_a = 0;
|
||||
let mut used_count_b = 0;
|
||||
for (symbol, count_a) in set_a.identifier_to_count.iter() {
|
||||
let count_b = set_b.identifier_to_count.get(symbol).unwrap_or(&0);
|
||||
numerator += count_a.min(count_b);
|
||||
denominator_a += count_a.max(count_b);
|
||||
used_count_b += count_b;
|
||||
}
|
||||
|
||||
let denominator = denominator_a + (set_b.total_count - used_count_b);
|
||||
if denominator == 0 {
|
||||
0.0
|
||||
} else {
|
||||
numerator as f32 / denominator as f32
|
||||
}
|
||||
}
|
||||
|
||||
pub fn weighted_overlap_coefficient<'a>(
|
||||
mut set_a: &'a IdentifierOccurrences,
|
||||
mut set_b: &'a IdentifierOccurrences,
|
||||
) -> f32 {
|
||||
if set_a.identifier_to_count.len() > set_b.identifier_to_count.len() {
|
||||
std::mem::swap(&mut set_a, &mut set_b);
|
||||
}
|
||||
|
||||
let mut numerator = 0;
|
||||
for (symbol, count_a) in set_a.identifier_to_count.iter() {
|
||||
let count_b = set_b.identifier_to_count.get(symbol).unwrap_or(&0);
|
||||
numerator += count_a.min(count_b);
|
||||
}
|
||||
|
||||
let denominator = set_a.total_count.min(set_b.total_count);
|
||||
if denominator == 0 {
|
||||
0.0
|
||||
} else {
|
||||
numerator as f32 / denominator as f32
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn test_split_identifier() {
        // Each naming convention should decompose into its word parts.
        let cases = [
            ("snake_case", vec!["snake", "case"]),
            ("kebab-case", vec!["kebab", "case"]),
            ("PascalCase", vec!["Pascal", "Case"]),
            ("camelCase", vec!["camel", "Case"]),
            ("XMLParser", vec!["XML", "Parser"]),
        ];
        for (identifier, expected) in cases {
            assert_eq!(split_identifier(identifier), expected);
        }
    }

    #[test]
    fn test_similarity_functions() {
        // Occurrence set A: 10 identifier parts in total, 8 distinct
        // ("outline" and "items" each appear twice).
        let occurrences_a = IdentifierOccurrences::within_string(
            "let mut outline_items = query_outline_items(&language, &tree, &source);",
        );
        // Occurrence set B: 14 identifier parts in total, 11 distinct
        // ("outline", "language", and "tree" each appear twice).
        let occurrences_b = IdentifierOccurrences::within_string(
            "pub fn query_outline_items(language: &Language, tree: &Tree, source: &str) -> Vec<OutlineItem> {",
        );

        // Shared distinct parts (6): "outline", "items", "query", "language", "tree", "source".
        // Unshared distinct parts (7): "let", "mut", "pub", "fn", "vec", "item", "str".
        assert_eq!(
            jaccard_similarity(&occurrences_a, &occurrences_b),
            6.0 / (6.0 + 7.0)
        );

        // Weighted variant: the numerator gains one for the "outline" that both
        // sets contain twice, and the denominator gains three for the
        // non-overlapping duplicates ("items", "language", "tree").
        assert_eq!(
            weighted_jaccard_similarity(&occurrences_a, &occurrences_b),
            7.0 / (7.0 + 7.0 + 3.0)
        );

        // Overlap coefficient: numerator as in jaccard_similarity, denominator
        // is the distinct-part count of the smaller set (8).
        assert_eq!(overlap_coefficient(&occurrences_a, &occurrences_b), 6.0 / 8.0);

        // Weighted overlap coefficient: numerator as in
        // weighted_jaccard_similarity, denominator is the total part count of
        // the smaller set (10).
        assert_eq!(
            weighted_overlap_coefficient(&occurrences_a, &occurrences_b),
            7.0 / 10.0
        );
    }
}
|
||||
35
crates/edit_prediction_context/src/wip_requests.rs
Normal file
35
crates/edit_prediction_context/src/wip_requests.rs
Normal file
@@ -0,0 +1,35 @@
|
||||
// To discuss: What to send to the new endpoint? Thinking it'd make sense to put `prompt.rs` from
|
||||
// `zeta_context.rs` in cloud.
|
||||
//
|
||||
// * Run excerpt selection at several different sizes, send the largest size with offsets within for
|
||||
// the smaller sizes.
|
||||
//
|
||||
// * Longer event history.
|
||||
//
|
||||
// * Many more snippets than could fit in model context - allows ranking experimentation.
|
||||
|
||||
/// Draft payload for the proposed zeta2 prediction endpoint (see the
/// discussion notes at the top of this file). Still a sketch —
/// NOTE(review): `Event`, `Vec`'s element types, and `ReferencedDeclaration`'s
/// `Range`/`Arc` have no `use` items in this file yet, so it presumably does
/// not compile as-is; confirm intended imports before landing.
pub struct Zeta2Request {
    /// Edit-history events to send alongside the excerpt.
    /// NOTE(review): `Event` is not defined or imported here — TODO confirm its source.
    pub event_history: Vec<Event>,
    /// The excerpt text; `cursor_position` and each `Zeta2ExcerptSubset` index into it.
    pub excerpt: String,
    /// Smaller candidate excerpts, given as subranges of `excerpt` (per the
    /// note above about running excerpt selection at several sizes).
    pub excerpt_subsets: Vec<Zeta2ExcerptSubset>,
    /// Within `excerpt`
    pub cursor_position: usize,
    /// Signature texts referenced by index from `excerpt_subsets` and
    /// `retrieved_declarations`.
    pub signatures: Vec<String>,
    /// Retrieved declaration snippets — deliberately more than fit in model
    /// context, to allow ranking experimentation (see note above).
    pub retrieved_declarations: Vec<ReferencedDeclaration>,
}
|
||||
|
||||
/// A smaller candidate excerpt expressed as offsets within the largest
/// excerpt sent in `Zeta2Request`, plus the indices of its enclosing
/// signatures.
pub struct Zeta2ExcerptSubset {
    /// Within `excerpt` text.
    pub excerpt_range: Range<usize>,
    /// Within `signatures`.
    pub parent_signatures: Vec<usize>,
}
|
||||
|
||||
/// A declaration retrieved as additional context for the request, carrying
/// its full text, the location of its signature within that text, and the
/// indices of its enclosing signatures.
pub struct ReferencedDeclaration {
    /// Full declaration text.
    pub text: Arc<str>,
    /// Range within `text`
    pub signature_range: Range<usize>,
    /// Indices within `signatures`.
    pub parent_signatures: Vec<usize>,
    // A bunch of score metrics
    // NOTE(review): score fields are still to be designed — placeholder only.
}
|
||||
Reference in New Issue
Block a user