Checkpoint: Get score_snippets to compile

Co-Authored-By: Finn <finn@zed.dev>
This commit is contained in:
Agus
2025-09-17 11:47:47 -03:00
committed by Agus Zubiaga
parent 50de8ddc28
commit cc32bfdfdf
2 changed files with 196 additions and 84 deletions

View File

@@ -1,14 +1,18 @@
use collections::HashSet;
use gpui::{App, Entity};
use itertools::Itertools as _;
use language::BufferSnapshot;
use project::ProjectEntryId;
use serde::Serialize;
use std::collections::HashMap;
use std::path::Path;
use std::sync::Arc;
use std::{collections::HashMap, ops::Range};
use strum::EnumIter;
use tree_sitter::StreamingIterator;
use text::{OffsetRangeExt, Point, ToPoint};
use crate::{
Declaration, EditPredictionExcerpt, EditPredictionExcerptText, outline::Identifier,
reference::Reference, text_similarity::IdentifierOccurrences,
Declaration, EditPredictionExcerpt, EditPredictionExcerptText, TreeSitterIndex,
outline::Identifier,
reference::{Reference, ReferenceRegion},
text_similarity::{IdentifierOccurrences, jaccard_similarity, weighted_overlap_coefficient},
};
#[derive(Clone, Debug)]
@@ -46,23 +50,29 @@ impl ScoredSnippet {
}
fn scored_snippets(
index: Entity<TreeSitterIndex>,
excerpt: &EditPredictionExcerpt,
excerpt_text: &EditPredictionExcerptText,
references: Vec<Reference>,
cursor_offset: usize,
current_buffer: &BufferSnapshot,
cx: &App,
) -> Vec<ScoredSnippet> {
let excerpt_occurrences = IdentifierOccurrences::within_string(&excerpt_text.body);
let containing_range_identifier_occurrences =
IdentifierOccurrences::within_string(&excerpt_text.body);
let cursor_point = cursor_offset.to_point(&current_buffer);
/* todo!
if let Some(cursor_within_excerpt) = cursor_offset.checked_sub(excerpt.range.start) {
} else {
};
let start_point = Point::new(cursor.row.saturating_sub(2), 0);
let end_point = Point::new(cursor.row + 1, 0);
// todo! ask michael why we needed this
// if let Some(cursor_within_excerpt) = cursor_offset.checked_sub(excerpt.range.start) {
// } else {
// };
let start_point = Point::new(cursor_point.row.saturating_sub(2), 0);
let end_point = Point::new(cursor_point.row + 1, 0);
let adjacent_identifier_occurrences = IdentifierOccurrences::within_string(
&source[offset_from_point(source, start_point)..offset_from_point(source, end_point)],
&current_buffer
.text_for_range(start_point..end_point)
.collect::<String>(),
);
*/
let mut identifier_to_references: HashMap<Identifier, Vec<Reference>> = HashMap::new();
for reference in references {
@@ -75,74 +85,102 @@ fn scored_snippets(
identifier_to_references
.into_iter()
.flat_map(|(identifier, references)| {
let Some(definitions) = index
.identifier_to_definitions
.get(&(identifier.clone(), language.name.clone()))
else {
return Vec::new();
};
let definitions = index
.read(cx)
// todo! pick a limit
.declarations_for_identifier::<16>(&identifier, cx);
let definition_count = definitions.len();
let definition_file_count = definitions.keys().len();
let total_file_count = definitions
.iter()
.filter_map(|definition| definition.project_entry_id(cx))
.collect::<HashSet<ProjectEntryId>>()
.len();
definitions
.iter_all()
.flat_map(|(definition_file, file_definitions)| {
let same_file_definition_count = file_definitions.len();
let is_same_file = reference_file == definition_file.as_ref();
file_definitions
.iter()
.filter(|definition| {
!is_same_file
|| !range_intersection(&definition.item_range, &excerpt_range)
.is_some()
})
.filter_map(|definition| {
let definition_line_distance = if is_same_file {
.iter()
.filter_map(|definition| match definition {
Declaration::Buffer {
declaration,
buffer,
} => {
let is_same_file = buffer
.read_with(cx, |buffer, _| buffer.remote_id())
.is_ok_and(|buffer_id| buffer_id == current_buffer.remote_id());
if is_same_file {
range_intersection(
&declaration.item_range.to_offset(&current_buffer),
&excerpt.range,
)
.is_none()
.then(|| {
let definition_line =
point_from_offset(source, definition.item_range.start).row;
(cursor.row as i32 - definition_line as i32).abs() as u32
} else {
0
};
Some((definition_line_distance, definition))
})
.sorted_by_key(|&(distance, _)| distance)
.enumerate()
.map(
|(
definition_line_distance_rank,
(definition_line_distance, definition),
)| {
score_snippet(
&identifier,
&references,
definition_file.clone(),
definition.clone(),
is_same_file,
definition_line_distance,
definition_line_distance_rank,
same_file_definition_count,
definition_count,
definition_file_count,
&containing_range_identifier_occurrences,
&adjacent_identifier_occurrences,
cursor,
declaration.item_range.start.to_point(current_buffer).row;
(
true,
(cursor_point.row as i32 - definition_line as i32).abs() as u32,
definition,
)
},
)
.collect::<Vec<_>>()
})
} else {
Some((false, 0, definition))
}
}
Declaration::File { .. } => {
// We can assume that a file declaration is in a different file,
// because the current one must be open
Some((false, 0, definition))
}
})
.sorted_by_key(|&(_, distance, _)| distance)
.enumerate()
.map(
|(
definition_line_distance_rank,
(is_same_file, definition_line_distance, definition),
)| {
let same_file_definition_count =
index.read(cx).file_declaration_count(definition);
score_snippet(
&identifier,
&references,
definition.clone(),
is_same_file,
definition_line_distance,
definition_line_distance_rank,
same_file_definition_count,
definition_count,
total_file_count,
&containing_range_identifier_occurrences,
&adjacent_identifier_occurrences,
cursor_point,
current_buffer,
cx,
)
},
)
.collect::<Vec<_>>()
})
.flatten()
.collect::<Vec<_>>()
}
// todo! replace with existing util?
/// Computes the overlap between two half-open ranges.
///
/// Returns `Some(start..end)` covering the elements contained in both `a` and `b`,
/// or `None` when the overlap would be empty (this includes ranges that merely
/// touch, since the resulting start would not be strictly below the end).
fn range_intersection<T: Ord + Clone>(a: &Range<T>, b: &Range<T>) -> Option<Range<T>> {
    // The overlap begins at the later of the two starts and stops at the earlier end.
    let lower = std::cmp::max(a.start.clone(), b.start.clone());
    let upper = std::cmp::min(a.end.clone(), b.end.clone());
    (lower < upper).then(|| lower..upper)
}
fn score_snippet(
identifier: &Identifier,
references: &[Reference],
definition_file: Arc<Path>,
definition: OutlineItem,
definition: Declaration,
is_same_file: bool,
definition_line_distance: u32,
definition_line_distance_rank: usize,
@@ -152,28 +190,28 @@ fn score_snippet(
containing_range_identifier_occurrences: &IdentifierOccurrences,
adjacent_identifier_occurrences: &IdentifierOccurrences,
cursor: Point,
current_buffer: &BufferSnapshot,
cx: &App,
) -> Option<ScoredSnippet> {
let is_referenced_nearby = references
.iter()
.any(|r| r.reference_region == ReferenceRegion::Nearby);
.any(|r| r.region == ReferenceRegion::Nearby);
let is_referenced_in_breadcrumb = references
.iter()
.any(|r| r.reference_region == ReferenceRegion::Breadcrumb);
.any(|r| r.region == ReferenceRegion::Breadcrumb);
let reference_count = references.len();
let reference_line_distance = references
.iter()
.map(|r| {
let reference_line = point_from_offset(reference_source, r.range.start).row as i32;
let reference_line = r.range.start.to_point(current_buffer).row as i32;
(cursor.row as i32 - reference_line).abs() as u32
})
.min()
.unwrap();
let definition_source = index.path_to_source.get(&definition_file).unwrap();
let item_source_occurrences =
IdentifierOccurrences::within_string(definition.item(&definition_source));
let item_source_occurrences = IdentifierOccurrences::within_string(&definition.item_text(cx));
let item_signature_occurrences =
IdentifierOccurrences::within_string(definition.signature(&definition_source));
IdentifierOccurrences::within_string(&definition.signature_text(cx));
let containing_range_vs_item_jaccard = jaccard_similarity(
containing_range_identifier_occurrences,
&item_source_occurrences,
@@ -223,7 +261,6 @@ fn score_snippet(
Some(ScoredSnippet {
identifier: identifier.clone(),
declaration_file: definition_file,
declaration: definition,
scores: score_components.score(),
score_components,
@@ -238,6 +275,7 @@ pub struct ScoreInputs {
pub reference_count: usize,
pub same_file_definition_count: usize,
pub definition_count: usize,
// todo! do we need this?
pub definition_file_count: usize,
pub reference_line_distance: u32,
pub definition_line_distance: u32,

View File

@@ -78,6 +78,57 @@ impl Declaration {
Declaration::Buffer { declaration, .. } => &declaration.identifier,
}
}
/// Returns the project entry id associated with this declaration, if any.
///
/// File declarations carry their entry id directly. For buffer declarations the
/// id is looked up through the buffer's file; `None` is returned when the buffer
/// cannot be read, when its file is not a `project::File`, or when that file
/// reports no project entry id.
pub fn project_entry_id(&self, cx: &App) -> Option<ProjectEntryId> {
    let buffer = match self {
        Declaration::File {
            project_entry_id, ..
        } => return Some(*project_entry_id),
        Declaration::Buffer { buffer, .. } => buffer,
    };

    let entry_id = buffer.read_with(cx, |buffer, _cx| {
        let file = project::File::from_dyn(buffer.file())?;
        file.project_entry_id(cx)
    });
    // A failed read is treated the same as "no entry id".
    entry_id.ok().flatten()
}
// todo! pick best return type
/// Returns the text of the declared item.
///
/// File declarations return a clone of their stored `declaration_text`. Buffer
/// declarations collect the text covered by `item_range` from the buffer; if the
/// buffer cannot be read, the default (empty) string is returned instead.
pub fn item_text(&self, cx: &App) -> Arc<str> {
    let (buffer, declaration) = match self {
        Declaration::File { declaration, .. } => return declaration.declaration_text.clone(),
        Declaration::Buffer {
            buffer,
            declaration,
        } => (buffer, declaration),
    };

    buffer
        .read_with(cx, |buffer, _cx| {
            let text: String = buffer
                .text_for_range(declaration.item_range.clone())
                .collect();
            Arc::<str>::from(text)
        })
        .unwrap_or_default()
}
// todo! pick best return type
/// Returns the text of the declaration's signature.
///
/// File declarations return a clone of their stored `signature_text`. Buffer
/// declarations collect the text covered by `signature_range` from the buffer;
/// if the buffer cannot be read, the default (empty) string is returned instead.
pub fn signature_text(&self, cx: &App) -> Arc<str> {
    if let Declaration::File { declaration, .. } = self {
        return declaration.signature_text.clone();
    }
    let Declaration::Buffer {
        buffer,
        declaration,
    } = self
    else {
        // Unreachable today (the `File` case returned above); mirrors the
        // fallback used when the buffer cannot be read.
        return Arc::default();
    };

    let signature = buffer.read_with(cx, |buffer, _cx| {
        let text: String = buffer
            .text_for_range(declaration.signature_range.clone())
            .collect();
        Arc::<str>::from(text)
    });
    signature.unwrap_or_default()
}
}
#[derive(Debug, Clone)]
@@ -86,7 +137,9 @@ pub struct FileDeclaration {
pub identifier: Identifier,
pub item_range: Range<usize>,
pub signature_range: Range<usize>,
// todo! should we just store a range with the declaration text?
pub signature_text: Arc<str>,
pub declaration_text: Arc<str>,
}
#[derive(Debug, Clone)]
@@ -145,7 +198,7 @@ impl TreeSitterIndex {
pub fn declarations_for_identifier<const N: usize>(
&self,
identifier: Identifier,
identifier: &Identifier,
cx: &App,
) -> Vec<Declaration> {
// make sure to not have a large stack allocation
@@ -206,6 +259,23 @@ impl TreeSitterIndex {
result
}
/// Returns how many declarations are recorded for the file or buffer that
/// `declaration` comes from.
///
/// File declarations are looked up in `self.files` by project entry id, buffer
/// declarations in `self.buffers` by buffer handle. Yields `0` when no state is
/// recorded for that key.
pub fn file_declaration_count(&self, declaration: &Declaration) -> usize {
    match declaration {
        Declaration::File {
            project_entry_id, ..
        } => {
            let file_state = self.files.get(project_entry_id);
            file_state.map_or(0, |state| state.declarations.len())
        }
        Declaration::Buffer { buffer, .. } => {
            let buffer_state = self.buffers.get(buffer);
            buffer_state.map_or(0, |state| state.declarations.len())
        }
    }
}
fn handle_worktree_store_event(
&mut self,
_worktree_store: Entity<WorktreeStore>,
@@ -491,12 +561,16 @@ impl FileDeclaration {
FileDeclaration {
parent: None,
identifier: declaration.identifier,
item_range: declaration.item_range,
signature_text: snapshot
.text_for_range(declaration.signature_range.clone())
.collect::<String>()
.into(),
signature_range: declaration.signature_range,
declaration_text: snapshot
.text_for_range(declaration.item_range.clone())
.collect::<String>()
.into(),
item_range: declaration.item_range,
}
}
}
@@ -527,7 +601,7 @@ mod tests {
};
index.read_with(cx, |index, cx| {
let decls = index.declarations_for_identifier::<8>(main.clone(), cx);
let decls = index.declarations_for_identifier::<8>(&main, cx);
assert_eq!(decls.len(), 2);
let decl = expect_file_decl("c.rs", &decls[0], &project, cx);
@@ -549,7 +623,7 @@ mod tests {
};
index.read_with(cx, |index, cx| {
let decls = index.declarations_for_identifier::<8>(test_process_data.clone(), cx);
let decls = index.declarations_for_identifier::<8>(&test_process_data, cx);
assert_eq!(decls.len(), 1);
let decl = expect_file_decl("c.rs", &decls[0], &project, cx);
@@ -588,7 +662,7 @@ mod tests {
cx.run_until_parked();
index.read_with(cx, |index, cx| {
let decls = index.declarations_for_identifier::<8>(test_process_data.clone(), cx);
let decls = index.declarations_for_identifier::<8>(&test_process_data, cx);
assert_eq!(decls.len(), 1);
let decl = expect_buffer_decl("c.rs", &decls[0], cx);
@@ -616,7 +690,7 @@ mod tests {
index.read_with(cx, |index, cx| {
let decls = index.declarations_for_identifier::<1>(
Identifier {
&Identifier {
name: "main".into(),
language_id: rust_lang_id,
},
@@ -646,7 +720,7 @@ mod tests {
cx.run_until_parked();
index.read_with(cx, |index, cx| {
let decls = index.declarations_for_identifier::<8>(main.clone(), cx);
let decls = index.declarations_for_identifier::<8>(&main, cx);
assert_eq!(decls.len(), 2);
let decl = expect_buffer_decl("c.rs", &decls[0], cx);
assert_eq!(decl.identifier, main);
@@ -669,7 +743,7 @@ mod tests {
cx.run_until_parked();
index.read_with(cx, |index, cx| {
let decls = index.declarations_for_identifier::<8>(main, cx);
let decls = index.declarations_for_identifier::<8>(&main, cx);
assert_eq!(decls.len(), 2);
expect_file_decl("c.rs", &decls[0], &project, cx);
expect_file_decl("a.rs", &decls[1], &project, cx);