Compare commits
28 Commits
fix-git-ht ... tfidf-inde
| Author | SHA1 | Date |
|---|---|---|
| | c45755c088 | |
| | b8cb6a1059 | |
| | fbd8b2b587 | |
| | 48ac888be3 | |
| | 2dc70d64cd | |
| | ab4b2bd204 | |
| | 2e1ee2bcc8 | |
| | 7c8d982caf | |
| | 966dbd30f6 | |
| | db1dc47ddb | |
| | a1cb4ec947 | |
| | 671872c47b | |
| | 4f4497d0e3 | |
| | 5606768679 | |
| | 6cc04f71f5 | |
| | 9cfa2933dd | |
| | ca8f9c7476 | |
| | ce70cd00b6 | |
| | a74f1766f0 | |
| | a85d773fe2 | |
| | 83d96cf369 | |
| | acf7ad3d83 | |
| | 4f9f2e52f6 | |
| | ef2f236355 | |
| | 5510bb6715 | |
| | c1cfd6dc4b | |
| | 36ab101011 | |
| | 91e2e815a6 | |
21 Cargo.lock (generated)
@@ -9492,6 +9492,16 @@ dependencies = [
"walkdir",
]

[[package]]
name = "rust-stemmers"
version = "1.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e46a2036019fdb888131db7a4c847a1063a7493f971ed94ea82c67eada63ca54"
dependencies = [
"serde",
"serde_derive",
]

[[package]]
name = "rust_decimal"
version = "1.36.0"

@@ -9969,11 +9979,13 @@ dependencies = [
"open_ai",
"parking_lot",
"project",
"rust-stemmers",
"serde",
"serde_json",
"settings",
"sha2",
"smol",
"stop-words",
"tempfile",
"theme",
"tree-sitter",

@@ -10850,6 +10862,15 @@ version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a2eb9349b6444b326872e140eb1cf5e7c522154d69e7a0ffb0fb81c06b37543f"

[[package]]
name = "stop-words"
version = "0.8.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8500024d809de02ecbf998472b7bed3c4fca380df2be68917f6a473bdb28ddcc"
dependencies = [
"serde_json",
]

[[package]]
name = "story"
version = "0.1.0"
@@ -399,8 +399,9 @@ rsa = "0.9.6"
runtimelib = { version = "0.15", default-features = false, features = [
"async-dispatcher-runtime",
] }
rustc-demangle = "0.1.23"
rust-embed = { version = "8.4", features = ["include-exclude"] }
rust-stemmers = "1.2"
rustc-demangle = "0.1.23"
rustls = "0.20.3"
rustls-native-certs = "0.8.0"
schemars = { version = "0.8", features = ["impl_json_schema"] }

@@ -422,6 +423,7 @@ simplelog = "0.12.2"
smallvec = { version = "1.6", features = ["union"] }
smol = "1.2"
sqlformat = "0.2"
stop-words = "0.8"
strsim = "0.11"
strum = { version = "0.25.0", features = ["derive"] }
subtle = "2.5.0"
@@ -37,7 +37,7 @@ use language_model::{
pub(crate) use model_selector::*;
pub use prompts::PromptBuilder;
use prompts::PromptLoadingParams;
use semantic_index::{CloudEmbeddingProvider, SemanticDb};
use semantic_index::{CloudEmbeddingProvider, SemanticDb, SEMANTIC_INDEX_DB_VERSION};
use serde::{Deserialize, Serialize};
use settings::{update_settings_file, Settings, SettingsStore};
use slash_command::{

@@ -215,7 +215,8 @@ pub fn init(
async move {
let embedding_provider = CloudEmbeddingProvider::new(client.clone());
let semantic_index = SemanticDb::new(
paths::embeddings_dir().join("semantic-index-db.0.mdb"),
paths::embeddings_dir()
.join(format!("semantic-index-db.{SEMANTIC_INDEX_DB_VERSION}.mdb")),
Arc::new(embedding_provider),
&mut cx,
)
@@ -1,6 +1,7 @@
use super::{
create_label_for_command, search_command::add_search_result_section, SlashCommand,
SlashCommandOutput,
create_label_for_command,
search_command::{add_search_result_section, SearchStyle},
SlashCommand, SlashCommandOutput,
};
use crate::PromptBuilder;
use anyhow::{anyhow, Result};

@@ -70,7 +71,7 @@ impl SlashCommand for ProjectSlashCommand {

fn run(
self: Arc<Self>,
_arguments: &[String],
arguments: &[String],
_context_slash_command_output_sections: &[SlashCommandOutputSection<Anchor>],
context_buffer: language::BufferSnapshot,
workspace: WeakView<Workspace>,

@@ -80,6 +81,33 @@ impl SlashCommand for ProjectSlashCommand {
let model_registry = LanguageModelRegistry::read_global(cx);
let current_model = model_registry.active_model();
let prompt_builder = self.prompt_builder.clone();
let mut style = SearchStyle::Hybrid;
let mut arg_iter = arguments.iter();
if let Some(arg) = arg_iter.next() {
if arg == "--style" {
if let Some(style_value) = arg_iter.next() {
match style_value.as_str() {
"dense" => style = SearchStyle::Dense,
"sparse" => style = SearchStyle::Sparse,
"hybrid" => style = SearchStyle::Hybrid,
_ => {
return Task::ready(Err(anyhow::anyhow!(
"Invalid style parameter; should be 'dense', 'sparse', or 'hybrid'."
)))
}
}
} else {
return Task::ready(Err(anyhow::anyhow!(
"Missing value for --style parameter"
)));
}
}
}
let search_param = match style {
SearchStyle::Dense => 1.0,
SearchStyle::Sparse => 0.0,
SearchStyle::Hybrid => 0.7,
};

let Some(workspace) = workspace.upgrade() else {
return Task::ready(Err(anyhow::anyhow!("workspace was dropped")));

@@ -117,7 +145,7 @@ impl SlashCommand for ProjectSlashCommand {

let results = project_index
.read_with(&cx, |project_index, cx| {
project_index.search(search_queries.clone(), 25, cx)
project_index.search(search_queries.clone(), 25, search_param, cx)
})?
.await?;
@@ -24,13 +24,19 @@ impl FeatureFlag for SearchSlashCommandFeatureFlag {

pub(crate) struct SearchSlashCommand;

pub(crate) enum SearchStyle {
Dense,
Hybrid,
Sparse,
}

impl SlashCommand for SearchSlashCommand {
fn name(&self) -> String {
"search".into()
}

fn label(&self, cx: &AppContext) -> CodeLabel {
create_label_for_command("search", &["--n"], cx)
create_label_for_command("search", &["--n", "--style {*hybrid,dense,sparse}"], cx)
}

fn description(&self) -> String {

@@ -72,24 +78,58 @@ impl SlashCommand for SearchSlashCommand {
};

let mut limit = None;
let mut query = String::new();
for part in arguments {
if let Some(parameter) = part.strip_prefix("--") {
if let Ok(count) = parameter.parse::<usize>() {
limit = Some(count);
let mut query = Vec::new();
let mut arg_iter = arguments.iter();
let mut style = SearchStyle::Hybrid;

while let Some(arg) = arg_iter.next() {
if arg == "--n" {
if let Some(count) = arg_iter.next() {
if let Ok(parsed_count) = count.parse::<usize>() {
limit = Some(parsed_count);
continue;
} else {
return Task::ready(Err(anyhow::anyhow!(
"Invalid count for --n parameter; should be a positive integer."
)));
}
} else {
return Task::ready(Err(anyhow::anyhow!("Missing count for --n parameter")));
}
} else if arg == "--style" {
if let Some(style_value) = arg_iter.next() {
match style_value.as_str() {
"dense" => style = SearchStyle::Dense,
"sparse" => style = SearchStyle::Sparse,
"hybrid" => style = SearchStyle::Hybrid,
_ => {
return Task::ready(Err(anyhow::anyhow!(
"Invalid style parameter; should be 'dense', 'sparse', or 'hybrid'."
)))
}
}
continue;
} else {
return Task::ready(Err(anyhow::anyhow!(
"Missing value for --style parameter"
)));
}
}

query.push_str(part);
query.push(' ');
query.push(arg.clone());
}
query.pop();

let query = query.join(" ");

if query.is_empty() {
return Task::ready(Err(anyhow::anyhow!("missing search query")));
}

let search_param = match style {
SearchStyle::Dense => 1.0,
SearchStyle::Sparse => 0.0,
SearchStyle::Hybrid => 0.7,
};

let project = workspace.read(cx).project().clone();
let fs = project.read(cx).fs().clone();
let Some(project_index) =

@@ -99,12 +139,13 @@ impl SlashCommand for SearchSlashCommand {
};

cx.spawn(|cx| async move {
let results = project_index
let mut results = project_index
.read_with(&cx, |project_index, cx| {
project_index.search(vec![query.clone()], limit.unwrap_or(5), cx)
project_index.search(vec![query.clone()], limit.unwrap_or(5), search_param, cx)
})?
.await?;

results = dbg!(results);
let loaded_results = SemanticDb::load_results(results, &fs, &cx).await?;

let output = cx
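Given the argument parsing added above, an invocation would look something like `/search --n 10 --style sparse convert anchor to point` (hypothetical example, not taken from this diff): `--n` caps the result count, `--style` selects dense, sparse, or hybrid retrieval, and the remaining words are joined into the query string.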
@@ -36,8 +36,9 @@ use std::{
const CODESEARCH_NET_DIR: &'static str = "target/datasets/code-search-net";
const EVAL_REPOS_DIR: &'static str = "target/datasets/eval-repos";
const EVAL_DB_PATH: &'static str = "target/eval_db";
const SEARCH_RESULT_LIMIT: usize = 8;
const SKIP_EVAL_PATH: &'static str = ".skip_eval";
const DEFAULT_SEARCH_PARAM: f32 = 0.7;
const DEFAULT_SEARCH_RESULT_LIMIT: usize = 8;

#[derive(clap::Parser)]
#[command(author, version, about, long_about = None)]

@@ -50,6 +51,10 @@ struct Cli {
enum Commands {
Fetch {},
Run {
#[arg(long, default_value_t = DEFAULT_SEARCH_PARAM)]
search_param: f32,
#[arg(long, default_value_t = DEFAULT_SEARCH_RESULT_LIMIT)]
search_result_limit: usize,
#[arg(long)]
repo: Option<String>,
},

@@ -115,9 +120,16 @@ fn main() -> Result<()> {
})
.detach();
}
Commands::Run { repo } => {
Commands::Run {
search_param,
search_result_limit,
repo,
} => {
cx.spawn(|mut cx| async move {
if let Err(err) = run_evaluation(repo, &executor, &mut cx).await {
if let Err(err) =
run_evaluation(search_param, search_result_limit, repo, &executor, &mut cx)
.await
{
eprintln!("Error: {}", err);
exit(1);
}

@@ -249,6 +261,8 @@ struct Counts {
}

async fn run_evaluation(
search_param: f32,
search_result_limit: usize,
only_repo: Option<String>,
executor: &BackgroundExecutor,
cx: &mut AsyncAppContext,

@@ -359,6 +373,8 @@ async fn run_evaluation(
let repo = evaluation_project.repo.clone();
if let Err(err) = run_eval_project(
evaluation_project,
search_param,
search_result_limit,
&user_store,
repo_db_path,
&repo_dir,

@@ -403,6 +419,8 @@ async fn run_evaluation(
#[allow(clippy::too_many_arguments)]
async fn run_eval_project(
evaluation_project: EvaluationProject,
search_param: f32,
search_result_limit: usize,
user_store: &Model<UserStore>,
repo_db_path: PathBuf,
repo_dir: &Path,

@@ -438,7 +456,12 @@ async fn run_eval_project(
loop {
match cx.update(|cx| {
let project_index = project_index.read(cx);
project_index.search(vec![query.query.clone()], SEARCH_RESULT_LIMIT, cx)
project_index.search(
vec![query.query.clone()],
search_result_limit,
search_param,
cx,
)
}) {
Ok(task) => match task.await {
Ok(answer) => {
@@ -37,11 +37,13 @@ log.workspace = true
open_ai.workspace = true
parking_lot.workspace = true
project.workspace = true
rust-stemmers.workspace = true
serde.workspace = true
serde_json.workspace = true
settings.workspace = true
sha2.workspace = true
smol.workspace = true
stop-words.workspace = true
theme.workspace = true
tree-sitter.workspace = true
ui.workspace = true

@@ -98,7 +98,7 @@ fn main() {
.update(|cx| {
let project_index = project_index.read(cx);
let query = "converting an anchor to a point";
project_index.search(vec![query.into()], 4, cx)
project_index.search(vec![query.into()], 4, 0.7, cx)
})
.unwrap()
.await
@@ -2,6 +2,7 @@ use crate::{
chunking::{self, Chunk},
embedding::{Embedding, EmbeddingProvider, TextToEmbed},
indexing::{IndexingEntryHandle, IndexingEntrySet},
tfidf::{SimpleTokenizer, TermCounts, WorktreeTermStats},
};
use anyhow::{anyhow, Context as _, Result};
use collections::Bound;

@@ -17,15 +18,35 @@ use serde::{Deserialize, Serialize};
use smol::channel;
use std::{
cmp::Ordering,
collections::HashMap,
future::Future,
iter,
path::Path,
sync::Arc,
sync::{Arc, RwLock},
time::{Duration, SystemTime},
};
use util::ResultExt;
use worktree::Snapshot;

#[derive(Debug, Clone, Copy)]
pub struct EmbeddingIndexSettings {
pub scan_entries_bound: usize,
pub deleted_entries_bound: usize,
pub chunk_files_bound: usize,
pub embed_files_bound: usize,
}

impl Default for EmbeddingIndexSettings {
fn default() -> Self {
Self {
scan_entries_bound: 512,
deleted_entries_bound: 128,
chunk_files_bound: 2048,
embed_files_bound: 512,
}
}
}

pub struct EmbeddingIndex {
worktree: Model<Worktree>,
db_connection: heed::Env,

@@ -34,6 +55,9 @@ pub struct EmbeddingIndex {
language_registry: Arc<LanguageRegistry>,
embedding_provider: Arc<dyn EmbeddingProvider>,
entry_ids_being_indexed: Arc<IndexingEntrySet>,
pub worktree_corpus_stats: Arc<RwLock<WorktreeTermStats>>,
tokenizer: SimpleTokenizer,
settings: EmbeddingIndexSettings,
}

impl EmbeddingIndex {

@@ -41,7 +65,7 @@ impl EmbeddingIndex {
worktree: Model<Worktree>,
fs: Arc<dyn Fs>,
db_connection: heed::Env,
embedding_db: heed::Database<Str, SerdeBincode<EmbeddedFile>>,
db: heed::Database<Str, SerdeBincode<EmbeddedFile>>,
language_registry: Arc<LanguageRegistry>,
embedding_provider: Arc<dyn EmbeddingProvider>,
entry_ids_being_indexed: Arc<IndexingEntrySet>,

@@ -50,10 +74,17 @@ impl EmbeddingIndex {
worktree,
fs,
db_connection,
db: embedding_db,
db,
language_registry,
embedding_provider,
entry_ids_being_indexed,
worktree_corpus_stats: Arc::new(RwLock::new(WorktreeTermStats::new(
HashMap::new(),
0,
0,
))),
tokenizer: SimpleTokenizer::new(),
settings: EmbeddingIndexSettings::default(),
}
}

@@ -62,14 +93,20 @@ impl EmbeddingIndex {
}

pub fn index_entries_changed_on_disk(
&self,
&mut self,
cx: &AppContext,
) -> impl Future<Output = Result<()>> {
let worktree = self.worktree.read(cx).snapshot();
let worktree_abs_path = worktree.abs_path().clone();
let scan = self.scan_entries(worktree, cx);
let chunk = self.chunk_files(worktree_abs_path, scan.updated_entries, cx);
let embed = Self::embed_files(self.embedding_provider.clone(), chunk.files, cx);
let embed = Self::embed_files(
self.embedding_provider.clone(),
chunk.files,
self.tokenizer.clone(),
self.settings,
cx,
);
let persist = self.persist_embeddings(scan.deleted_entry_ranges, embed.files, cx);
async move {
futures::try_join!(scan.task, chunk.task, embed.task, persist)?;

@@ -78,7 +115,7 @@ impl EmbeddingIndex {
}

pub fn index_updated_entries(
&self,
&mut self,
updated_entries: UpdatedEntriesSet,
cx: &AppContext,
) -> impl Future<Output = Result<()>> {

@@ -86,7 +123,13 @@ impl EmbeddingIndex {
let worktree_abs_path = worktree.abs_path().clone();
let scan = self.scan_updated_entries(worktree, updated_entries.clone(), cx);
let chunk = self.chunk_files(worktree_abs_path, scan.updated_entries, cx);
let embed = Self::embed_files(self.embedding_provider.clone(), chunk.files, cx);
let embed = Self::embed_files(
self.embedding_provider.clone(),
chunk.files,
self.tokenizer.clone(),
self.settings,
cx,
);
let persist = self.persist_embeddings(scan.deleted_entry_ranges, embed.files, cx);
async move {
futures::try_join!(scan.task, chunk.task, embed.task, persist)?;

@@ -94,12 +137,16 @@ impl EmbeddingIndex {
}
}

fn scan_entries(&self, worktree: Snapshot, cx: &AppContext) -> ScanEntries {
let (updated_entries_tx, updated_entries_rx) = channel::bounded(512);
let (deleted_entry_ranges_tx, deleted_entry_ranges_rx) = channel::bounded(128);
fn scan_entries(&mut self, worktree: Snapshot, cx: &AppContext) -> ScanEntries {
let (updated_entries_tx, updated_entries_rx) =
channel::bounded(self.settings.scan_entries_bound);
let (deleted_entry_ranges_tx, deleted_entry_ranges_rx) =
channel::bounded(self.settings.deleted_entries_bound);
let db_connection = self.db_connection.clone();
let db = self.db;
let entries_being_indexed = self.entry_ids_being_indexed.clone();
let worktree_corpus_stats = self.worktree_corpus_stats.clone();

let task = cx.background_executor().spawn(async move {
let txn = db_connection
.read_txn()

@@ -113,8 +160,15 @@ impl EmbeddingIndex {
let mut deletion_range: Option<(Bound<&str>, Bound<&str>)> = None;
for entry in worktree.files(false, 0) {
log::trace!("scanning for embedding index: {:?}", &entry.path);

let entry_db_key = db_key_for_path(&entry.path);
if let Some(embedded_file) = db.get(&txn, &entry_db_key)? {
// initialize worktree_corpus_stats from embedded files in database
update_corpus_stats(worktree_corpus_stats.clone(), |stats| {
for chunk in &embedded_file.chunks {
stats.add_counts(&chunk.term_frequencies);
}
});
}

let mut saved_mtime = None;
while let Some(db_entry) = db_entries.peek() {

@@ -152,6 +206,7 @@ impl EmbeddingIndex {
}

if entry.mtime != saved_mtime {
// corpus stats will be adjusted later, while these are chunked/embedded
let handle = entries_being_indexed.insert(entry.id);
updated_entries_tx.send((entry.clone(), handle)).await?;
}

@@ -180,8 +235,10 @@ impl EmbeddingIndex {
updated_entries: UpdatedEntriesSet,
cx: &AppContext,
) -> ScanEntries {
let (updated_entries_tx, updated_entries_rx) = channel::bounded(512);
let (deleted_entry_ranges_tx, deleted_entry_ranges_rx) = channel::bounded(128);
let (updated_entries_tx, updated_entries_rx) =
channel::bounded(self.settings.scan_entries_bound);
let (deleted_entry_ranges_tx, deleted_entry_ranges_rx) =
channel::bounded(self.settings.deleted_entries_bound);
let entries_being_indexed = self.entry_ids_being_indexed.clone();
let task = cx.background_executor().spawn(async move {
for (path, entry_id, status) in updated_entries.iter() {

@@ -191,6 +248,7 @@ impl EmbeddingIndex {
| project::PathChange::AddedOrUpdated => {
if let Some(entry) = worktree.entry_for_id(*entry_id) {
if entry.is_file() {
// corpus stats will be adjusted later, while these are chunked/embedded
let handle = entries_being_indexed.insert(entry.id);
updated_entries_tx.send((entry.clone(), handle)).await?;
}

@@ -226,7 +284,8 @@ impl EmbeddingIndex {
) -> ChunkFiles {
let language_registry = self.language_registry.clone();
let fs = self.fs.clone();
let (chunked_files_tx, chunked_files_rx) = channel::bounded(2048);
let (chunked_files_tx, chunked_files_rx) =
channel::bounded(self.settings.chunk_files_bound);
let task = cx.spawn(|cx| async move {
cx.background_executor()
.scoped(|cx| {
@@ -272,10 +331,11 @@ impl EmbeddingIndex {
pub fn embed_files(
embedding_provider: Arc<dyn EmbeddingProvider>,
chunked_files: channel::Receiver<ChunkedFile>,
tokenizer: SimpleTokenizer,
settings: EmbeddingIndexSettings,
cx: &AppContext,
) -> EmbedFiles {
let embedding_provider = embedding_provider.clone();
let (embedded_files_tx, embedded_files_rx) = channel::bounded(512);
let (embedded_files_tx, embedded_files_rx) = channel::bounded(settings.embed_files_bound);
let task = cx.background_executor().spawn(async move {
let mut chunked_file_batches =
chunked_files.chunks_timeout(512, Duration::from_secs(2));

@@ -284,7 +344,6 @@ impl EmbeddingIndex {
// Flatten out to a vec of chunks that we can subdivide into batch sized pieces
// Once those are done, reassemble them back into the files in which they belong
// If any embeddings fail for a file, the entire file is discarded

let chunks: Vec<TextToEmbed> = chunked_files
.iter()
.flat_map(|file| {

@@ -312,11 +371,10 @@ impl EmbeddingIndex {

embeddings.extend(iter::repeat(None).take(embedding_batch.len()));
}

let mut embeddings = embeddings.into_iter();
for chunked_file in chunked_files {
let mut embedded_file = EmbeddedFile {
path: chunked_file.path,
path: chunked_file.path.clone(),
mtime: chunked_file.mtime,
chunks: Vec::new(),
};

@@ -326,9 +384,11 @@ impl EmbeddingIndex {
chunked_file.chunks.into_iter().zip(embeddings.by_ref())
{
if let Some(embedding) = embedding {
let chunk_text = &chunked_file.text[chunk.range.clone()];
let chunk_counts = TermCounts::from_text(chunk_text, &tokenizer);
embedded_file
.chunks
.push(EmbeddedChunk { chunk, embedding });
.push(EmbeddedChunk { chunk, embedding, term_frequencies: chunk_counts });
} else {
embedded_all_chunks = false;
}

@@ -358,6 +418,7 @@ impl EmbeddingIndex {
) -> Task<Result<()>> {
let db_connection = self.db_connection.clone();
let db = self.db;
let worktree_corpus_stats = self.worktree_corpus_stats.clone();

cx.background_executor().spawn(async move {
loop {

@@ -370,6 +431,15 @@ impl EmbeddingIndex {
let end = deletion_range.1.as_ref().map(|end| end.as_str());
log::debug!("deleting embeddings in range {:?}", &(start, end));
db.delete_range(&mut txn, &(start, end))?;
for (_, embedded_file) in db.range(&txn, &(start, end))?.flatten() {
if let None = update_corpus_stats(worktree_corpus_stats.clone(), |stats| {
for chunk in &embedded_file.chunks {
stats.remove_counts(&chunk.term_frequencies);
}
}) {
log::error!("Failed to acquire write lock for worktree_corpus_stats; corpus stats will be outdated");
}
}
txn.commit()?;
}
},

@@ -380,6 +450,13 @@ impl EmbeddingIndex {
let key = db_key_for_path(&file.path);
db.put(&mut txn, &key, &file)?;
txn.commit()?;
if let None = update_corpus_stats(worktree_corpus_stats.clone(), |stats| {
for chunk in &file.chunks {
stats.add_counts(&chunk.term_frequencies);
}
}) {
log::error!("Failed to acquire write lock for worktree_corpus_stats; corpus stats will be outdated");
}
}
},
complete => break,

@@ -399,7 +476,13 @@ impl EmbeddingIndex {
.context("failed to create read transaction")?;
let result = db
.iter(&tx)?
.map(|entry| Ok(entry?.1.path.clone()))
.filter_map(|entry| {
if let Ok((_, file)) = entry {
Some(Ok(file.path.clone()))
} else {
None
}
})
.collect::<Result<Vec<Arc<Path>>>>();
drop(tx);
result

@@ -417,11 +500,10 @@ impl EmbeddingIndex {
let tx = connection
.read_txn()
.context("failed to create read transaction")?;
Ok(db
.get(&tx, &db_key_for_path(&path))?
.ok_or_else(|| anyhow!("no such path"))?
.chunks
.clone())
match db.get(&tx, &db_key_for_path(&path))? {
Some(file) => Ok(file.chunks.clone()),
None => Err(anyhow!("no such path")),
}
})
}
}

@@ -450,7 +532,7 @@ pub struct EmbedFiles {
pub task: Task<Result<()>>,
}

#[derive(Debug, Serialize, Deserialize)]
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct EmbeddedFile {
pub path: Arc<Path>,
pub mtime: Option<SystemTime>,

@@ -461,8 +543,18 @@ pub struct EmbeddedFile {
pub struct EmbeddedChunk {
pub chunk: Chunk,
pub embedding: Embedding,
pub term_frequencies: TermCounts,
}

fn db_key_for_path(path: &Arc<Path>) -> String {
path.to_string_lossy().replace('/', "\0")
}

fn update_corpus_stats<F, R>(stats: Arc<RwLock<WorktreeTermStats>>, f: F) -> Option<R>
where
F: FnOnce(&mut WorktreeTermStats) -> R,
{
stats.write().map(|mut guard| f(&mut guard)).map_err(|_| {
log::error!("Failed to acquire write lock for worktree_corpus_stats; corpus stats will be outdated");
}).ok()
}
@@ -1,10 +1,10 @@
use crate::{
embedding::{EmbeddingProvider, TextToEmbed},
summary_index::FileSummary,
tfidf::{Bm25Parameters, Bm25Scorer, SimpleTokenizer},
worktree_index::{WorktreeIndex, WorktreeIndexHandle},
};
use anyhow::{anyhow, Context, Result};
use collections::HashMap;
use fs::Fs;
use futures::{stream::StreamExt, FutureExt};
use gpui::{

@@ -17,6 +17,7 @@ use serde::{Deserialize, Serialize};
use smol::channel;
use std::{
cmp::Ordering,
collections::HashMap,
future::Future,
num::NonZeroUsize,
ops::{Range, RangeInclusive},

@@ -232,8 +233,16 @@ impl ProjectIndex {
&self,
queries: Vec<String>,
limit: usize,
mixing_param: f32,
cx: &AppContext,
) -> Task<Result<Vec<SearchResult>>> {
if !(0.0..=1.0).contains(&mixing_param) {
return cx.spawn(|_| async move {
Err(anyhow!(
"Semantic search mixing param must be between 0 and 1"
))
});
}
let (chunks_tx, chunks_rx) = channel::bounded(1024);
let mut worktree_scan_tasks = Vec::new();
for worktree_index in self.worktree_indices.values() {

@@ -241,9 +250,11 @@ impl ProjectIndex {
let chunks_tx = chunks_tx.clone();
worktree_scan_tasks.push(cx.spawn(|cx| async move {
let index = match worktree_index {
WorktreeIndexHandle::Loading { index } => {
index.clone().await.map_err(|error| anyhow!(error))?
}
WorktreeIndexHandle::Loading { index } => index
.clone()
.await
.map_err(|error| anyhow!(error))
.context("loading worktree index failure")?,
WorktreeIndexHandle::Loaded { index } => index.clone(),
};

@@ -252,22 +263,32 @@ impl ProjectIndex {
let worktree_id = index.worktree().read(cx).id();
let db_connection = index.db_connection().clone();
let db = *index.embedding_index().db();
let worktree_corpus_stats =
index.embedding_index().worktree_corpus_stats.clone();
cx.background_executor().spawn(async move {
let txn = db_connection
.read_txn()
.context("failed to create read transaction")?;
let db_entries = db.iter(&txn).context("failed to iterate database")?;
for db_entry in db_entries {
let (_key, db_embedded_file) = db_entry?;
for chunk in db_embedded_file.chunks {
let (_key, db_embedded_file) =
db_entry.context("failed to read embedded file")?;
for chunk in &db_embedded_file.chunks {
chunks_tx
.send((worktree_id, db_embedded_file.path.clone(), chunk))
.await?;
.send((
worktree_id,
worktree_corpus_stats.clone(),
db_embedded_file.path.clone(),
chunk.clone(),
))
.await
.context("failed to send chunks")?;
}
}
anyhow::Ok(())
})
})?
})
.context("read_with WorktreeIndex failed")?
.await
}));
}

@@ -275,21 +296,43 @@ impl ProjectIndex {

let project = self.project.clone();
let embedding_provider = self.embedding_provider.clone();
let bm25_params = Bm25Parameters::default();
cx.spawn(|cx| async move {
log::info!("Searching for {queries:?}");
// BM-25: Tokenize query
#[cfg(debug_assertions)]
let bm25_query_start = std::time::Instant::now();
let tokenizer = SimpleTokenizer::new();
let terms_by_query: Vec<HashMap<Arc<str>, u32>> = queries
.iter()
.map(|query| {
tokenizer.tokenize_and_stem(query).into_iter().fold(
HashMap::new(),
|mut acc, term| {
*acc.entry(term).or_insert(0) += 1;
acc
},
)
})
.collect();
#[cfg(debug_assertions)]
let bm25_query_end = std::time::Instant::now();

// Similarity search: Embed query
#[cfg(debug_assertions)]
let embedding_query_start = std::time::Instant::now();
log::info!("Searching for {queries:?}");
let queries: Vec<TextToEmbed> = queries
.iter()
.map(|s| TextToEmbed::new(s.as_str()))
.collect();

let query_embeddings = embedding_provider.embed(&queries[..]).await?;
if query_embeddings.len() != queries.len() {
return Err(anyhow!(
"The number of query embeddings does not match the number of queries"
));
}
#[cfg(debug_assertions)]
let embedding_query_end = std::time::Instant::now();

let mut results_by_worker = Vec::new();
for _ in 0..cx.background_executor().num_cpus() {
@@ -302,28 +345,58 @@ impl ProjectIndex {
.scoped(|cx| {
for results in results_by_worker.iter_mut() {
cx.spawn(async {
while let Ok((worktree_id, path, chunk)) = chunks_rx.recv().await {
let (score, query_index) =
chunk.embedding.similarity(&query_embeddings);
while let Ok((worktree_id, worktree_corpus_stats, path, chunk)) =
chunks_rx.recv().await
{
// iterate over every (query_embedding, query_term) and compute its hybrid retrieval score for this chunk
// RetScore(chunk, query_embedding, query_term, m) = m * Sim(query_embedding, chunk) + (1 - m) * Bm25(query_term, chunk)]
let hybrid_scores: Vec<f32> = query_embeddings
.iter()
.zip(terms_by_query.iter())
.map(|(query_embedding, query_term)| {
let (embedding_score, _) =
query_embedding.similarity(&[chunk.embedding.clone()]);
let bm25_score = {
let corpus_stats =
worktree_corpus_stats.read().unwrap();
let score = corpus_stats.calculate_bm25_score(
query_term,
&chunk.term_frequencies.0,
bm25_params.k1,
bm25_params.b,
);
// quick hack to bound the score for long queries
score / query_term.values().sum::<u32>() as f32
};
mixing_param * embedding_score
+ (1. - mixing_param) * bm25_score
})
.collect();

let ix = match results.binary_search_by(|probe| {
score.partial_cmp(&probe.score).unwrap_or(Ordering::Equal)
}) {
Ok(ix) | Err(ix) => ix,
};
if ix < limit {
results.insert(
ix,
WorktreeSearchResult {
worktree_id,
path: path.clone(),
range: chunk.chunk.range.clone(),
query_index,
score,
},
);
if results.len() > limit {
results.pop();
for (query_index, score) in hybrid_scores.into_iter().enumerate() {
if score != 0.0 {
let ix = match results.binary_search_by(|probe| {
score
.partial_cmp(&probe.score)
.unwrap_or(Ordering::Equal)
}) {
Ok(ix) | Err(ix) => ix,
};
if ix < limit {
results.insert(
ix,
WorktreeSearchResult {
worktree_id,
path: path.clone(),
range: chunk.chunk.range.clone(),
query_index,
score,
},
);
if results.len() > limit {
results.pop();
}
}
}
}
}

@@ -362,7 +435,9 @@ impl ProjectIndex {
search_results.len(),
search_elapsed
);
let embedding_query_elapsed = embedding_query_start.elapsed();
let bm25_query_elapsed = bm25_query_start - bm25_query_end;
log::debug!("tokenizing query took {:?}", bm25_query_elapsed);
let embedding_query_elapsed = embedding_query_start - embedding_query_end;
log::debug!("embedding query took {:?}", embedding_query_elapsed);
}
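For context (not part of this diff): the hybrid retrieval score computed above can be read as a single mixing formula, sketched below as a standalone Rust function with illustrative numbers. The names mirror the variables in the code above; the function itself and the example values are hypothetical.

fn hybrid_score(mixing_param: f32, embedding_score: f32, bm25_score: f32) -> f32 {
    // mixing_param = 1.0 is pure dense (embedding) search,
    // 0.0 is pure sparse (BM25), and 0.7 is the hybrid default used above.
    mixing_param * embedding_score + (1.0 - mixing_param) * bm25_score
}

fn main() {
    // Example values chosen purely for illustration.
    println!("{}", hybrid_score(1.0, 0.82, 1.4)); // dense only: 0.82
    println!("{}", hybrid_score(0.0, 0.82, 1.4)); // sparse only: 1.4
    println!("{}", hybrid_score(0.7, 0.82, 1.4)); // 0.7 * 0.82 + 0.3 * 1.4 ≈ 0.994
}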
@@ -6,6 +6,7 @@ mod project_index;
mod project_index_debug_view;
mod summary_backlog;
mod summary_index;
mod tfidf;
mod worktree_index;

use anyhow::{Context as _, Result};

@@ -28,6 +29,8 @@ pub use project_index::{LoadedSearchResult, ProjectIndex, SearchResult, Status};
pub use project_index_debug_view::ProjectIndexDebugView;
pub use summary_index::FileSummary;

pub const SEMANTIC_INDEX_DB_VERSION: usize = 1;

pub struct SemanticDb {
embedding_provider: Arc<dyn EmbeddingProvider>,
db_connection: Option<heed::Env>,

@@ -268,7 +271,7 @@ mod tests {
use super::*;
use anyhow::anyhow;
use chunking::Chunk;
use embedding_index::{ChunkedFile, EmbeddingIndex};
use embedding_index::{ChunkedFile, EmbeddingIndex, EmbeddingIndexSettings};
use feature_flags::FeatureFlagAppExt;
use fs::FakeFs;
use futures::{future::BoxFuture, FutureExt};

@@ -280,6 +283,7 @@ mod tests {
use settings::SettingsStore;
use smol::{channel, stream::StreamExt};
use std::{future, path::Path, sync::Arc};
use tfidf::SimpleTokenizer;

fn init_test(cx: &mut TestAppContext) {
env_logger::try_init().ok();

@@ -398,7 +402,7 @@ mod tests {
.update(|cx| {
let project_index = project_index.read(cx);
let query = "garbage in, garbage out";
project_index.search(vec![query.into()], 4, cx)
project_index.search(vec![query.into()], 4, 0.7, cx)
})
.await
.unwrap();

@@ -487,8 +491,18 @@ mod tests {
.unwrap();
chunked_files_tx.close();

let embed_files_task =
cx.update(|cx| EmbeddingIndex::embed_files(provider.clone(), chunked_files_rx, cx));
let tokenizer = SimpleTokenizer::new();
let embedding_settings = EmbeddingIndexSettings::default();

let embed_files_task = cx.update(|cx| {
EmbeddingIndex::embed_files(
provider.clone(),
chunked_files_rx,
tokenizer,
embedding_settings,
cx,
)
});
embed_files_task.task.await.unwrap();

let mut embedded_files_rx = embed_files_task.files;
184 crates/semantic_index/src/tfidf.rs (new file)
@@ -0,0 +1,184 @@
use rust_stemmers::{Algorithm, Stemmer};
use serde::{Deserialize, Serialize};
use std::{collections::HashMap, sync::Arc};
use stop_words;

#[derive(Clone)]
pub struct SimpleTokenizer {
stemmer: Arc<Stemmer>,
stopwords: Vec<String>,
}

impl SimpleTokenizer {
pub fn new() -> Self {
Self {
// TODO: handle non-English
stemmer: Arc::new(Stemmer::create(Algorithm::English)),
stopwords: stop_words::get(stop_words::LANGUAGE::English),
}
}

pub fn tokenize_and_stem(&self, text: &str) -> Vec<Arc<str>> {
// Split on whitespace and punctuation
text.split(|c: char| c.is_whitespace() || c.is_ascii_punctuation())
.flat_map(|word| {
word.chars().fold(vec![String::new()], |mut acc, c| {
// Split CamelCaps and camelCase
if c.is_uppercase()
&& !acc.is_empty()
&& acc
.last()
.and_then(|s| s.chars().last())
.map_or(false, |last_char| last_char.is_lowercase())
{
acc.push(String::new());
}
acc.last_mut()
.unwrap_or(&mut String::new())
.push(c.to_lowercase().next().unwrap());
acc
})
})
.filter(|s| !s.is_empty() && !self.stopwords.contains(s))
.map(|word| {
// Stem each word and convert to Arc<str>
let stemmed = self.stemmer.stem(&word).to_string();
Arc::from(stemmed)
})
.collect()
}
}

pub struct Bm25Parameters {
pub k1: f32,
pub b: f32,
}

impl Default for Bm25Parameters {
fn default() -> Self {
Self { k1: 1.2, b: 0.75 }
}
}
pub trait Bm25Scorer {
fn total_chunks(&self) -> u32;
fn avg_chunk_length(&self) -> f32;
fn term_frequency(&self, term: &Arc<str>, chunk_term_counts: &HashMap<Arc<str>, u32>) -> u32;
fn document_frequency(&self, term: &Arc<str>) -> u32;

fn calculate_bm25_score(
&self,
query_terms: &HashMap<Arc<str>, u32>,
chunk_terms: &HashMap<Arc<str>, u32>,
k1: f32,
b: f32,
) -> f32 {
// average doc length, current doc length, total docs
let avg_dl = self.avg_chunk_length();
let dl = chunk_terms.values().sum::<u32>() as f32;
let dn = self.total_chunks() as f32;

query_terms
.iter()
.map(|(term, &query_tf)| {
let tf = self.term_frequency(term, chunk_terms) as f32;
let df = self.document_frequency(term) as f32;
let idf = ((dn - df + 0.5) / (df + 0.5) + 1.0).ln();
let numerator = tf * (k1 + 1.0);
let denominator = tf + k1 * (1.0 - b + b * dl / avg_dl);

idf * (numerator / denominator)
})
.sum()
}
}
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct TermCounts(pub HashMap<Arc<str>, u32>);

impl TermCounts {
pub fn from_text(text: &str, tokenizer: &SimpleTokenizer) -> Self {
let tokens = tokenizer.tokenize_and_stem(text);
let mut terms = HashMap::new();

for token in tokens {
*terms.entry(token).or_insert(0) += 1;
}

TermCounts(terms)
}
}

/// Represents the term frequency statistics for a single worktree.
///
/// This struct contains information about chunks, term statistics,
/// and the total length of all chunks in the worktree.
#[derive(Debug)]
pub struct WorktreeTermStats {
/// A map of terms to their counts across all chunks in this worktree.
term_counts: HashMap<Arc<str>, u32>,
/// The total length of all chunks in this worktree.
total_length: u32,
/// The total number of chunks tracked in this worktree.
total_chunks: u32,
}

impl WorktreeTermStats {
pub fn new(term_counts: HashMap<Arc<str>, u32>, total_length: u32, total_chunks: u32) -> Self {
Self {
term_counts,
total_length,
total_chunks,
}
}

pub fn add_counts(&mut self, chunk_counts: &TermCounts) {
let mut chunk_length = 0;
for (term, &freq) in &chunk_counts.0 {
let counts = self.term_counts.entry(term.clone()).or_insert(0);
*counts += freq;
chunk_length += freq;
}
self.total_length += chunk_length;
self.total_chunks += 1;
}

pub fn remove_counts(&mut self, chunk_counts: &TermCounts) -> () {
debug_assert!(chunk_counts.0.len() <= self.term_counts.len());
debug_assert!(chunk_counts
.0
.keys()
.all(|k| self.term_counts.contains_key(k)));

let mut chunk_length = 0;
for (term, &freq) in &chunk_counts.0 {
if let Some(stats) = self.term_counts.get_mut(term) {
*stats -= freq;
chunk_length += 0;
}
}
self.total_length -= chunk_length;
self.total_chunks -= 1;
}
}

impl Bm25Scorer for WorktreeTermStats {
fn total_chunks(&self) -> u32 {
self.total_chunks
}

fn avg_chunk_length(&self) -> f32 {
if self.total_chunks == 0 {
0.0
} else {
self.total_length as f32 / self.total_chunks as f32
}
}

fn term_frequency(&self, term: &Arc<str>, chunk_term_counts: &HashMap<Arc<str>, u32>) -> u32 {
*chunk_term_counts.get(term).unwrap_or(&0)
}

fn document_frequency(&self, term: &Arc<str>) -> u32 {
*self.term_counts.get(term).unwrap_or(&0)
}
}
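A minimal usage sketch (not part of this diff), assuming the tfidf module above is in scope as crate::tfidf: it tokenizes a chunk and a query the same way embed_files and ProjectIndex::search do, folds the chunk's counts into the worktree corpus stats, and scores the chunk with BM25.

use std::collections::HashMap;

use crate::tfidf::{Bm25Parameters, Bm25Scorer, SimpleTokenizer, TermCounts, WorktreeTermStats};

fn score_chunk_against_query(chunk_text: &str, query: &str) -> f32 {
    let tokenizer = SimpleTokenizer::new();

    // Per-chunk term frequencies, as embed_files computes for each embedded chunk.
    let chunk_counts = TermCounts::from_text(chunk_text, &tokenizer);

    // Corpus statistics start empty and accumulate one chunk at a time.
    let mut stats = WorktreeTermStats::new(HashMap::new(), 0, 0);
    stats.add_counts(&chunk_counts);

    // Query terms are tokenized and counted the same way as chunk text.
    let query_terms = TermCounts::from_text(query, &tokenizer);

    // calculate_bm25_score comes from the Bm25Scorer trait implemented above.
    let params = Bm25Parameters::default(); // k1 = 1.2, b = 0.75
    stats.calculate_bm25_score(&query_terms.0, &chunk_counts.0, params.k1, params.b)
}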
@@ -46,14 +46,14 @@ impl WorktreeIndex {
embedding_provider: Arc<dyn EmbeddingProvider>,
cx: &mut AppContext,
) -> Task<Result<Model<Self>>> {
let worktree_for_index = worktree.clone();
let worktree_for_embedding = worktree.clone();
let worktree_for_summary = worktree.clone();
let worktree_abs_path = worktree.read(cx).abs_path();
let embedding_fs = Arc::clone(&fs);
let summary_fs = fs;
cx.spawn(|mut cx| async move {
let entries_being_indexed = Arc::new(IndexingEntrySet::new(status_tx));
let (embedding_index, summary_index) = cx
let (embedding_index, summary_index): (EmbeddingIndex, SummaryIndex) = cx
.background_executor()
.spawn({
let entries_being_indexed = Arc::clone(&entries_being_indexed);

@@ -63,9 +63,8 @@ impl WorktreeIndex {
let embedding_index = {
let db_name = worktree_abs_path.to_string_lossy();
let db = db_connection.create_database(&mut txn, Some(&db_name))?;

EmbeddingIndex::new(
worktree_for_index,
worktree_for_embedding,
embedding_fs,
db_connection.clone(),
db,

@@ -101,7 +100,7 @@ impl WorktreeIndex {
)
};
txn.commit()?;
anyhow::Ok((embedding_index, summary_index))
Ok::<_, anyhow::Error>((embedding_index, summary_index))
}
})
.await?;