Compare commits

...

28 Commits

Author SHA1 Message Date
Jason Mancuso
c45755c088 Merge remote-tracking branch 'origin/main' into tfidf-indexer 2024-10-09 12:39:16 -04:00
Jason Mancuso
b8cb6a1059 Make search settings configurable in code search eval 2024-10-09 12:38:25 -04:00
Jason Mancuso
fbd8b2b587 Quick attempt to normalize bm25 score by query length 2024-10-09 11:56:16 -04:00
Jason Mancuso
48ac888be3 Add stopword removal to tokenizer 2024-10-09 11:54:51 -04:00
Jason Mancuso
2dc70d64cd Add temporary dbg to illustrate current hybrid score calibration issue 2024-10-09 10:46:58 -04:00
Jason Mancuso
ab4b2bd204 Fix bm25 calculation to avoid NaNs 2024-10-08 12:22:04 -04:00
Jason Mancuso
2e1ee2bcc8 Fix bug in tokenizer leading to empty results 2024-10-08 12:21:40 -04:00
Jason Mancuso
7c8d982caf Add some error contexts 2024-10-08 11:05:42 -04:00
Jason Mancuso
966dbd30f6 Alphabetize Cargo.toml 2024-10-08 11:05:05 -04:00
Jason Mancuso
db1dc47ddb Merge remote-tracking branch 'origin/main' into tfidf-indexer 2024-10-08 11:02:31 -04:00
Jason Mancuso
a1cb4ec947 Filter out search results with score=0 2024-10-08 11:01:03 -04:00
Jason Mancuso
671872c47b Update term frequency accounting only during db transactions 2024-10-08 11:00:22 -04:00
Jason Mancuso
4f4497d0e3 Version semantic db lmdb 2024-10-08 10:53:13 -04:00
Jason Mancuso
5606768679 Simplify WorktreeTermStats; rearrange EmbeddingIndex term frequency accounting 2024-10-07 18:48:33 -04:00
Jason Mancuso
6cc04f71f5 Add args to /search and /project to configure search style 2024-10-03 20:11:58 -04:00
Jason Mancuso
9cfa2933dd Audit and fix lifecycle of WorktreeTermStats in EmbeddingIndex 2024-09-30 16:10:51 -04:00
Jason Mancuso
ca8f9c7476 Fix hybrid-retrieval scoring in ProjectIndex::search 2024-09-27 19:59:28 -04:00
Jason Mancuso
ce70cd00b6 return some debug logs that accidentally got deleted 2024-09-27 19:58:55 -04:00
Jason Mancuso
a74f1766f0 Note to self 2024-09-27 19:33:37 -04:00
Jason Mancuso
a85d773fe2 Clean up some dead code 2024-09-27 19:22:43 -04:00
Jason Mancuso
83d96cf369 Refactor tf-idf utilities again and integrate elsewhere 2024-09-27 18:42:34 -04:00
Jason Mancuso
acf7ad3d83 Initialize and store a CorpusTermFrequency per-worktree in memory 2024-09-27 16:30:50 -04:00
Jason Mancuso
4f9f2e52f6 Revert change to batching logic in EmbeddingIndex; don't store tfidf metadata in DB 2024-09-27 12:37:54 -04:00
Jason Mancuso
ef2f236355 Add Bm25Calculator, use it in ProjectIndex::search 2024-09-26 13:16:59 -04:00
Jason Mancuso
5510bb6715 Add tfidf metadata key const 2024-09-26 09:27:57 -04:00
Jason Mancuso
c1cfd6dc4b Some cleanup 2024-09-25 19:57:44 -04:00
Jason Mancuso
36ab101011 Integrate tfidf counts into EmbeddingIndex and replace with homebaked tokenizer 2024-09-25 18:19:50 -04:00
Jason Mancuso
91e2e815a6 Add TfIdfIndexer WIP 2024-09-24 15:21:16 -04:00
13 changed files with 575 additions and 93 deletions

21
Cargo.lock generated
View File

@@ -9492,6 +9492,16 @@ dependencies = [
"walkdir",
]
[[package]]
name = "rust-stemmers"
version = "1.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e46a2036019fdb888131db7a4c847a1063a7493f971ed94ea82c67eada63ca54"
dependencies = [
"serde",
"serde_derive",
]
[[package]]
name = "rust_decimal"
version = "1.36.0"
@@ -9969,11 +9979,13 @@ dependencies = [
"open_ai",
"parking_lot",
"project",
"rust-stemmers",
"serde",
"serde_json",
"settings",
"sha2",
"smol",
"stop-words",
"tempfile",
"theme",
"tree-sitter",
@@ -10850,6 +10862,15 @@ version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a2eb9349b6444b326872e140eb1cf5e7c522154d69e7a0ffb0fb81c06b37543f"
[[package]]
name = "stop-words"
version = "0.8.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8500024d809de02ecbf998472b7bed3c4fca380df2be68917f6a473bdb28ddcc"
dependencies = [
"serde_json",
]
[[package]]
name = "story"
version = "0.1.0"

View File

@@ -399,8 +399,9 @@ rsa = "0.9.6"
runtimelib = { version = "0.15", default-features = false, features = [
"async-dispatcher-runtime",
] }
rustc-demangle = "0.1.23"
rust-embed = { version = "8.4", features = ["include-exclude"] }
rust-stemmers = "1.2"
rustc-demangle = "0.1.23"
rustls = "0.20.3"
rustls-native-certs = "0.8.0"
schemars = { version = "0.8", features = ["impl_json_schema"] }
@@ -422,6 +423,7 @@ simplelog = "0.12.2"
smallvec = { version = "1.6", features = ["union"] }
smol = "1.2"
sqlformat = "0.2"
stop-words = "0.8"
strsim = "0.11"
strum = { version = "0.25.0", features = ["derive"] }
subtle = "2.5.0"

View File

@@ -37,7 +37,7 @@ use language_model::{
pub(crate) use model_selector::*;
pub use prompts::PromptBuilder;
use prompts::PromptLoadingParams;
use semantic_index::{CloudEmbeddingProvider, SemanticDb};
use semantic_index::{CloudEmbeddingProvider, SemanticDb, SEMANTIC_INDEX_DB_VERSION};
use serde::{Deserialize, Serialize};
use settings::{update_settings_file, Settings, SettingsStore};
use slash_command::{
@@ -215,7 +215,8 @@ pub fn init(
async move {
let embedding_provider = CloudEmbeddingProvider::new(client.clone());
let semantic_index = SemanticDb::new(
paths::embeddings_dir().join("semantic-index-db.0.mdb"),
paths::embeddings_dir()
.join(format!("semantic-index-db.{SEMANTIC_INDEX_DB_VERSION}.mdb")),
Arc::new(embedding_provider),
&mut cx,
)

View File

@@ -1,6 +1,7 @@
use super::{
create_label_for_command, search_command::add_search_result_section, SlashCommand,
SlashCommandOutput,
create_label_for_command,
search_command::{add_search_result_section, SearchStyle},
SlashCommand, SlashCommandOutput,
};
use crate::PromptBuilder;
use anyhow::{anyhow, Result};
@@ -70,7 +71,7 @@ impl SlashCommand for ProjectSlashCommand {
fn run(
self: Arc<Self>,
_arguments: &[String],
arguments: &[String],
_context_slash_command_output_sections: &[SlashCommandOutputSection<Anchor>],
context_buffer: language::BufferSnapshot,
workspace: WeakView<Workspace>,
@@ -80,6 +81,33 @@ impl SlashCommand for ProjectSlashCommand {
let model_registry = LanguageModelRegistry::read_global(cx);
let current_model = model_registry.active_model();
let prompt_builder = self.prompt_builder.clone();
let mut style = SearchStyle::Hybrid;
let mut arg_iter = arguments.iter();
if let Some(arg) = arg_iter.next() {
if arg == "--style" {
if let Some(style_value) = arg_iter.next() {
match style_value.as_str() {
"dense" => style = SearchStyle::Dense,
"sparse" => style = SearchStyle::Sparse,
"hybrid" => style = SearchStyle::Hybrid,
_ => {
return Task::ready(Err(anyhow::anyhow!(
"Invalid style parameter; should be 'dense', 'sparse', or 'hybrid'."
)))
}
}
} else {
return Task::ready(Err(anyhow::anyhow!(
"Missing value for --style parameter"
)));
}
}
}
let search_param = match style {
SearchStyle::Dense => 1.0,
SearchStyle::Sparse => 0.0,
SearchStyle::Hybrid => 0.7,
};
let Some(workspace) = workspace.upgrade() else {
return Task::ready(Err(anyhow::anyhow!("workspace was dropped")));
@@ -117,7 +145,7 @@ impl SlashCommand for ProjectSlashCommand {
let results = project_index
.read_with(&cx, |project_index, cx| {
project_index.search(search_queries.clone(), 25, cx)
project_index.search(search_queries.clone(), 25, search_param, cx)
})?
.await?;

View File

@@ -24,13 +24,19 @@ impl FeatureFlag for SearchSlashCommandFeatureFlag {
pub(crate) struct SearchSlashCommand;
pub(crate) enum SearchStyle {
Dense,
Hybrid,
Sparse,
}
impl SlashCommand for SearchSlashCommand {
fn name(&self) -> String {
"search".into()
}
fn label(&self, cx: &AppContext) -> CodeLabel {
create_label_for_command("search", &["--n"], cx)
create_label_for_command("search", &["--n", "--style {*hybrid,dense,sparse}"], cx)
}
fn description(&self) -> String {
@@ -72,24 +78,58 @@ impl SlashCommand for SearchSlashCommand {
};
let mut limit = None;
let mut query = String::new();
for part in arguments {
if let Some(parameter) = part.strip_prefix("--") {
if let Ok(count) = parameter.parse::<usize>() {
limit = Some(count);
let mut query = Vec::new();
let mut arg_iter = arguments.iter();
let mut style = SearchStyle::Hybrid;
while let Some(arg) = arg_iter.next() {
if arg == "--n" {
if let Some(count) = arg_iter.next() {
if let Ok(parsed_count) = count.parse::<usize>() {
limit = Some(parsed_count);
continue;
} else {
return Task::ready(Err(anyhow::anyhow!(
"Invalid count for --n parameter; should be a positive integer."
)));
}
} else {
return Task::ready(Err(anyhow::anyhow!("Missing count for --n parameter")));
}
} else if arg == "--style" {
if let Some(style_value) = arg_iter.next() {
match style_value.as_str() {
"dense" => style = SearchStyle::Dense,
"sparse" => style = SearchStyle::Sparse,
"hybrid" => style = SearchStyle::Hybrid,
_ => {
return Task::ready(Err(anyhow::anyhow!(
"Invalid style parameter; should be 'dense', 'sparse', or 'hybrid'."
)))
}
}
continue;
} else {
return Task::ready(Err(anyhow::anyhow!(
"Missing value for --style parameter"
)));
}
}
query.push_str(part);
query.push(' ');
query.push(arg.clone());
}
query.pop();
let query = query.join(" ");
if query.is_empty() {
return Task::ready(Err(anyhow::anyhow!("missing search query")));
}
let search_param = match style {
SearchStyle::Dense => 1.0,
SearchStyle::Sparse => 0.0,
SearchStyle::Hybrid => 0.7,
};
let project = workspace.read(cx).project().clone();
let fs = project.read(cx).fs().clone();
let Some(project_index) =
@@ -99,12 +139,13 @@ impl SlashCommand for SearchSlashCommand {
};
cx.spawn(|cx| async move {
let results = project_index
let mut results = project_index
.read_with(&cx, |project_index, cx| {
project_index.search(vec![query.clone()], limit.unwrap_or(5), cx)
project_index.search(vec![query.clone()], limit.unwrap_or(5), search_param, cx)
})?
.await?;
results = dbg!(results);
let loaded_results = SemanticDb::load_results(results, &fs, &cx).await?;
let output = cx

View File

@@ -36,8 +36,9 @@ use std::{
const CODESEARCH_NET_DIR: &'static str = "target/datasets/code-search-net";
const EVAL_REPOS_DIR: &'static str = "target/datasets/eval-repos";
const EVAL_DB_PATH: &'static str = "target/eval_db";
const SEARCH_RESULT_LIMIT: usize = 8;
const SKIP_EVAL_PATH: &'static str = ".skip_eval";
const DEFAULT_SEARCH_PARAM: f32 = 0.7;
const DEFAULT_SEARCH_RESULT_LIMIT: usize = 8;
#[derive(clap::Parser)]
#[command(author, version, about, long_about = None)]
@@ -50,6 +51,10 @@ struct Cli {
enum Commands {
Fetch {},
Run {
#[arg(long, default_value_t = DEFAULT_SEARCH_PARAM)]
search_param: f32,
#[arg(long, default_value_t = DEFAULT_SEARCH_RESULT_LIMIT)]
search_result_limit: usize,
#[arg(long)]
repo: Option<String>,
},
@@ -115,9 +120,16 @@ fn main() -> Result<()> {
})
.detach();
}
Commands::Run { repo } => {
Commands::Run {
search_param,
search_result_limit,
repo,
} => {
cx.spawn(|mut cx| async move {
if let Err(err) = run_evaluation(repo, &executor, &mut cx).await {
if let Err(err) =
run_evaluation(search_param, search_result_limit, repo, &executor, &mut cx)
.await
{
eprintln!("Error: {}", err);
exit(1);
}
@@ -249,6 +261,8 @@ struct Counts {
}
async fn run_evaluation(
search_param: f32,
search_result_limit: usize,
only_repo: Option<String>,
executor: &BackgroundExecutor,
cx: &mut AsyncAppContext,
@@ -359,6 +373,8 @@ async fn run_evaluation(
let repo = evaluation_project.repo.clone();
if let Err(err) = run_eval_project(
evaluation_project,
search_param,
search_result_limit,
&user_store,
repo_db_path,
&repo_dir,
@@ -403,6 +419,8 @@ async fn run_evaluation(
#[allow(clippy::too_many_arguments)]
async fn run_eval_project(
evaluation_project: EvaluationProject,
search_param: f32,
search_result_limit: usize,
user_store: &Model<UserStore>,
repo_db_path: PathBuf,
repo_dir: &Path,
@@ -438,7 +456,12 @@ async fn run_eval_project(
loop {
match cx.update(|cx| {
let project_index = project_index.read(cx);
project_index.search(vec![query.query.clone()], SEARCH_RESULT_LIMIT, cx)
project_index.search(
vec![query.query.clone()],
search_result_limit,
search_param,
cx,
)
}) {
Ok(task) => match task.await {
Ok(answer) => {

View File

@@ -37,11 +37,13 @@ log.workspace = true
open_ai.workspace = true
parking_lot.workspace = true
project.workspace = true
rust-stemmers.workspace = true
serde.workspace = true
serde_json.workspace = true
settings.workspace = true
sha2.workspace = true
smol.workspace = true
stop-words.workspace = true
theme.workspace = true
tree-sitter.workspace = true
ui.workspace = true

View File

@@ -98,7 +98,7 @@ fn main() {
.update(|cx| {
let project_index = project_index.read(cx);
let query = "converting an anchor to a point";
project_index.search(vec![query.into()], 4, cx)
project_index.search(vec![query.into()], 4, 0.7, cx)
})
.unwrap()
.await

View File

@@ -2,6 +2,7 @@ use crate::{
chunking::{self, Chunk},
embedding::{Embedding, EmbeddingProvider, TextToEmbed},
indexing::{IndexingEntryHandle, IndexingEntrySet},
tfidf::{SimpleTokenizer, TermCounts, WorktreeTermStats},
};
use anyhow::{anyhow, Context as _, Result};
use collections::Bound;
@@ -17,15 +18,35 @@ use serde::{Deserialize, Serialize};
use smol::channel;
use std::{
cmp::Ordering,
collections::HashMap,
future::Future,
iter,
path::Path,
sync::Arc,
sync::{Arc, RwLock},
time::{Duration, SystemTime},
};
use util::ResultExt;
use worktree::Snapshot;
/// Channel capacities for the embedding index pipeline stages.
///
/// Each bound caps one of the `channel::bounded` queues linking the
/// scan → chunk → embed → persist stages, providing backpressure.
#[derive(Debug, Clone, Copy)]
pub struct EmbeddingIndexSettings {
    /// Capacity of the updated-entries channel filled by `scan_entries`.
    pub scan_entries_bound: usize,
    /// Capacity of the deleted-entry-ranges channel filled by `scan_entries`.
    pub deleted_entries_bound: usize,
    /// Capacity of the chunked-files channel filled by `chunk_files`.
    pub chunk_files_bound: usize,
    /// Capacity of the embedded-files channel filled by `embed_files`.
    pub embed_files_bound: usize,
}
impl Default for EmbeddingIndexSettings {
    fn default() -> Self {
        // Defaults match the bounds that were previously hard-coded at the
        // channel construction sites (512 / 128 / 2048 / 512).
        Self {
            scan_entries_bound: 512,
            deleted_entries_bound: 128,
            chunk_files_bound: 2048,
            embed_files_bound: 512,
        }
    }
}
pub struct EmbeddingIndex {
worktree: Model<Worktree>,
db_connection: heed::Env,
@@ -34,6 +55,9 @@ pub struct EmbeddingIndex {
language_registry: Arc<LanguageRegistry>,
embedding_provider: Arc<dyn EmbeddingProvider>,
entry_ids_being_indexed: Arc<IndexingEntrySet>,
pub worktree_corpus_stats: Arc<RwLock<WorktreeTermStats>>,
tokenizer: SimpleTokenizer,
settings: EmbeddingIndexSettings,
}
impl EmbeddingIndex {
@@ -41,7 +65,7 @@ impl EmbeddingIndex {
worktree: Model<Worktree>,
fs: Arc<dyn Fs>,
db_connection: heed::Env,
embedding_db: heed::Database<Str, SerdeBincode<EmbeddedFile>>,
db: heed::Database<Str, SerdeBincode<EmbeddedFile>>,
language_registry: Arc<LanguageRegistry>,
embedding_provider: Arc<dyn EmbeddingProvider>,
entry_ids_being_indexed: Arc<IndexingEntrySet>,
@@ -50,10 +74,17 @@ impl EmbeddingIndex {
worktree,
fs,
db_connection,
db: embedding_db,
db,
language_registry,
embedding_provider,
entry_ids_being_indexed,
worktree_corpus_stats: Arc::new(RwLock::new(WorktreeTermStats::new(
HashMap::new(),
0,
0,
))),
tokenizer: SimpleTokenizer::new(),
settings: EmbeddingIndexSettings::default(),
}
}
@@ -62,14 +93,20 @@ impl EmbeddingIndex {
}
pub fn index_entries_changed_on_disk(
&self,
&mut self,
cx: &AppContext,
) -> impl Future<Output = Result<()>> {
let worktree = self.worktree.read(cx).snapshot();
let worktree_abs_path = worktree.abs_path().clone();
let scan = self.scan_entries(worktree, cx);
let chunk = self.chunk_files(worktree_abs_path, scan.updated_entries, cx);
let embed = Self::embed_files(self.embedding_provider.clone(), chunk.files, cx);
let embed = Self::embed_files(
self.embedding_provider.clone(),
chunk.files,
self.tokenizer.clone(),
self.settings,
cx,
);
let persist = self.persist_embeddings(scan.deleted_entry_ranges, embed.files, cx);
async move {
futures::try_join!(scan.task, chunk.task, embed.task, persist)?;
@@ -78,7 +115,7 @@ impl EmbeddingIndex {
}
pub fn index_updated_entries(
&self,
&mut self,
updated_entries: UpdatedEntriesSet,
cx: &AppContext,
) -> impl Future<Output = Result<()>> {
@@ -86,7 +123,13 @@ impl EmbeddingIndex {
let worktree_abs_path = worktree.abs_path().clone();
let scan = self.scan_updated_entries(worktree, updated_entries.clone(), cx);
let chunk = self.chunk_files(worktree_abs_path, scan.updated_entries, cx);
let embed = Self::embed_files(self.embedding_provider.clone(), chunk.files, cx);
let embed = Self::embed_files(
self.embedding_provider.clone(),
chunk.files,
self.tokenizer.clone(),
self.settings,
cx,
);
let persist = self.persist_embeddings(scan.deleted_entry_ranges, embed.files, cx);
async move {
futures::try_join!(scan.task, chunk.task, embed.task, persist)?;
@@ -94,12 +137,16 @@ impl EmbeddingIndex {
}
}
fn scan_entries(&self, worktree: Snapshot, cx: &AppContext) -> ScanEntries {
let (updated_entries_tx, updated_entries_rx) = channel::bounded(512);
let (deleted_entry_ranges_tx, deleted_entry_ranges_rx) = channel::bounded(128);
fn scan_entries(&mut self, worktree: Snapshot, cx: &AppContext) -> ScanEntries {
let (updated_entries_tx, updated_entries_rx) =
channel::bounded(self.settings.scan_entries_bound);
let (deleted_entry_ranges_tx, deleted_entry_ranges_rx) =
channel::bounded(self.settings.deleted_entries_bound);
let db_connection = self.db_connection.clone();
let db = self.db;
let entries_being_indexed = self.entry_ids_being_indexed.clone();
let worktree_corpus_stats = self.worktree_corpus_stats.clone();
let task = cx.background_executor().spawn(async move {
let txn = db_connection
.read_txn()
@@ -113,8 +160,15 @@ impl EmbeddingIndex {
let mut deletion_range: Option<(Bound<&str>, Bound<&str>)> = None;
for entry in worktree.files(false, 0) {
log::trace!("scanning for embedding index: {:?}", &entry.path);
let entry_db_key = db_key_for_path(&entry.path);
if let Some(embedded_file) = db.get(&txn, &entry_db_key)? {
// initialize worktree_corpus_stats from embedded files in database
update_corpus_stats(worktree_corpus_stats.clone(), |stats| {
for chunk in &embedded_file.chunks {
stats.add_counts(&chunk.term_frequencies);
}
});
}
let mut saved_mtime = None;
while let Some(db_entry) = db_entries.peek() {
@@ -152,6 +206,7 @@ impl EmbeddingIndex {
}
if entry.mtime != saved_mtime {
// corpus stats will be adjusted later, while these are chunked/embedded
let handle = entries_being_indexed.insert(entry.id);
updated_entries_tx.send((entry.clone(), handle)).await?;
}
@@ -180,8 +235,10 @@ impl EmbeddingIndex {
updated_entries: UpdatedEntriesSet,
cx: &AppContext,
) -> ScanEntries {
let (updated_entries_tx, updated_entries_rx) = channel::bounded(512);
let (deleted_entry_ranges_tx, deleted_entry_ranges_rx) = channel::bounded(128);
let (updated_entries_tx, updated_entries_rx) =
channel::bounded(self.settings.scan_entries_bound);
let (deleted_entry_ranges_tx, deleted_entry_ranges_rx) =
channel::bounded(self.settings.deleted_entries_bound);
let entries_being_indexed = self.entry_ids_being_indexed.clone();
let task = cx.background_executor().spawn(async move {
for (path, entry_id, status) in updated_entries.iter() {
@@ -191,6 +248,7 @@ impl EmbeddingIndex {
| project::PathChange::AddedOrUpdated => {
if let Some(entry) = worktree.entry_for_id(*entry_id) {
if entry.is_file() {
// corpus stats will be adjusted later, while these are chunked/embedded
let handle = entries_being_indexed.insert(entry.id);
updated_entries_tx.send((entry.clone(), handle)).await?;
}
@@ -226,7 +284,8 @@ impl EmbeddingIndex {
) -> ChunkFiles {
let language_registry = self.language_registry.clone();
let fs = self.fs.clone();
let (chunked_files_tx, chunked_files_rx) = channel::bounded(2048);
let (chunked_files_tx, chunked_files_rx) =
channel::bounded(self.settings.chunk_files_bound);
let task = cx.spawn(|cx| async move {
cx.background_executor()
.scoped(|cx| {
@@ -272,10 +331,11 @@ impl EmbeddingIndex {
pub fn embed_files(
embedding_provider: Arc<dyn EmbeddingProvider>,
chunked_files: channel::Receiver<ChunkedFile>,
tokenizer: SimpleTokenizer,
settings: EmbeddingIndexSettings,
cx: &AppContext,
) -> EmbedFiles {
let embedding_provider = embedding_provider.clone();
let (embedded_files_tx, embedded_files_rx) = channel::bounded(512);
let (embedded_files_tx, embedded_files_rx) = channel::bounded(settings.embed_files_bound);
let task = cx.background_executor().spawn(async move {
let mut chunked_file_batches =
chunked_files.chunks_timeout(512, Duration::from_secs(2));
@@ -284,7 +344,6 @@ impl EmbeddingIndex {
// Flatten out to a vec of chunks that we can subdivide into batch sized pieces
// Once those are done, reassemble them back into the files in which they belong
// If any embeddings fail for a file, the entire file is discarded
let chunks: Vec<TextToEmbed> = chunked_files
.iter()
.flat_map(|file| {
@@ -312,11 +371,10 @@ impl EmbeddingIndex {
embeddings.extend(iter::repeat(None).take(embedding_batch.len()));
}
let mut embeddings = embeddings.into_iter();
for chunked_file in chunked_files {
let mut embedded_file = EmbeddedFile {
path: chunked_file.path,
path: chunked_file.path.clone(),
mtime: chunked_file.mtime,
chunks: Vec::new(),
};
@@ -326,9 +384,11 @@ impl EmbeddingIndex {
chunked_file.chunks.into_iter().zip(embeddings.by_ref())
{
if let Some(embedding) = embedding {
let chunk_text = &chunked_file.text[chunk.range.clone()];
let chunk_counts = TermCounts::from_text(chunk_text, &tokenizer);
embedded_file
.chunks
.push(EmbeddedChunk { chunk, embedding });
.push(EmbeddedChunk { chunk, embedding, term_frequencies: chunk_counts });
} else {
embedded_all_chunks = false;
}
@@ -358,6 +418,7 @@ impl EmbeddingIndex {
) -> Task<Result<()>> {
let db_connection = self.db_connection.clone();
let db = self.db;
let worktree_corpus_stats = self.worktree_corpus_stats.clone();
cx.background_executor().spawn(async move {
loop {
@@ -370,6 +431,15 @@ impl EmbeddingIndex {
let end = deletion_range.1.as_ref().map(|end| end.as_str());
log::debug!("deleting embeddings in range {:?}", &(start, end));
db.delete_range(&mut txn, &(start, end))?;
for (_, embedded_file) in db.range(&txn, &(start, end))?.flatten() {
if let None = update_corpus_stats(worktree_corpus_stats.clone(), |stats| {
for chunk in &embedded_file.chunks {
stats.remove_counts(&chunk.term_frequencies);
}
}) {
log::error!("Failed to acquire write lock for worktree_corpus_stats; corpus stats will be outdated");
}
}
txn.commit()?;
}
},
@@ -380,6 +450,13 @@ impl EmbeddingIndex {
let key = db_key_for_path(&file.path);
db.put(&mut txn, &key, &file)?;
txn.commit()?;
if let None = update_corpus_stats(worktree_corpus_stats.clone(), |stats| {
for chunk in &file.chunks {
stats.add_counts(&chunk.term_frequencies);
}
}) {
log::error!("Failed to acquire write lock for worktree_corpus_stats; corpus stats will be outdated");
}
}
},
complete => break,
@@ -399,7 +476,13 @@ impl EmbeddingIndex {
.context("failed to create read transaction")?;
let result = db
.iter(&tx)?
.map(|entry| Ok(entry?.1.path.clone()))
.filter_map(|entry| {
if let Ok((_, file)) = entry {
Some(Ok(file.path.clone()))
} else {
None
}
})
.collect::<Result<Vec<Arc<Path>>>>();
drop(tx);
result
@@ -417,11 +500,10 @@ impl EmbeddingIndex {
let tx = connection
.read_txn()
.context("failed to create read transaction")?;
Ok(db
.get(&tx, &db_key_for_path(&path))?
.ok_or_else(|| anyhow!("no such path"))?
.chunks
.clone())
match db.get(&tx, &db_key_for_path(&path))? {
Some(file) => Ok(file.chunks.clone()),
None => Err(anyhow!("no such path")),
}
})
}
}
@@ -450,7 +532,7 @@ pub struct EmbedFiles {
pub task: Task<Result<()>>,
}
#[derive(Debug, Serialize, Deserialize)]
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct EmbeddedFile {
pub path: Arc<Path>,
pub mtime: Option<SystemTime>,
@@ -461,8 +543,18 @@ pub struct EmbeddedFile {
pub struct EmbeddedChunk {
pub chunk: Chunk,
pub embedding: Embedding,
pub term_frequencies: TermCounts,
}
/// Builds the LMDB key for a worktree-relative path.
///
/// Replaces `/` with NUL when forming the key — presumably so a directory's
/// entries stay contiguous for the range operations used elsewhere (see the
/// `delete_range` callers); confirm before relying on that ordering.
fn db_key_for_path(path: &Arc<Path>) -> String {
    let lossy = path.to_string_lossy();
    lossy
        .chars()
        .map(|c| if c == '/' { '\0' } else { c })
        .collect()
}
/// Runs `f` with exclusive (write-locked) access to the shared corpus stats.
///
/// Returns `Some` with `f`'s result, or `None` — after logging — when the
/// `RwLock` has been poisoned by a panicking writer.
fn update_corpus_stats<F, R>(stats: Arc<RwLock<WorktreeTermStats>>, f: F) -> Option<R>
where
    F: FnOnce(&mut WorktreeTermStats) -> R,
{
    match stats.write() {
        Ok(mut guard) => Some(f(&mut guard)),
        Err(_) => {
            log::error!("Failed to acquire write lock for worktree_corpus_stats; corpus stats will be outdated");
            None
        }
    }
}

View File

@@ -1,10 +1,10 @@
use crate::{
embedding::{EmbeddingProvider, TextToEmbed},
summary_index::FileSummary,
tfidf::{Bm25Parameters, Bm25Scorer, SimpleTokenizer},
worktree_index::{WorktreeIndex, WorktreeIndexHandle},
};
use anyhow::{anyhow, Context, Result};
use collections::HashMap;
use fs::Fs;
use futures::{stream::StreamExt, FutureExt};
use gpui::{
@@ -17,6 +17,7 @@ use serde::{Deserialize, Serialize};
use smol::channel;
use std::{
cmp::Ordering,
collections::HashMap,
future::Future,
num::NonZeroUsize,
ops::{Range, RangeInclusive},
@@ -232,8 +233,16 @@ impl ProjectIndex {
&self,
queries: Vec<String>,
limit: usize,
mixing_param: f32,
cx: &AppContext,
) -> Task<Result<Vec<SearchResult>>> {
if !(0.0..=1.0).contains(&mixing_param) {
return cx.spawn(|_| async move {
Err(anyhow!(
"Semantic search mixing param must be between 0 and 1"
))
});
}
let (chunks_tx, chunks_rx) = channel::bounded(1024);
let mut worktree_scan_tasks = Vec::new();
for worktree_index in self.worktree_indices.values() {
@@ -241,9 +250,11 @@ impl ProjectIndex {
let chunks_tx = chunks_tx.clone();
worktree_scan_tasks.push(cx.spawn(|cx| async move {
let index = match worktree_index {
WorktreeIndexHandle::Loading { index } => {
index.clone().await.map_err(|error| anyhow!(error))?
}
WorktreeIndexHandle::Loading { index } => index
.clone()
.await
.map_err(|error| anyhow!(error))
.context("loading worktree index failure")?,
WorktreeIndexHandle::Loaded { index } => index.clone(),
};
@@ -252,22 +263,32 @@ impl ProjectIndex {
let worktree_id = index.worktree().read(cx).id();
let db_connection = index.db_connection().clone();
let db = *index.embedding_index().db();
let worktree_corpus_stats =
index.embedding_index().worktree_corpus_stats.clone();
cx.background_executor().spawn(async move {
let txn = db_connection
.read_txn()
.context("failed to create read transaction")?;
let db_entries = db.iter(&txn).context("failed to iterate database")?;
for db_entry in db_entries {
let (_key, db_embedded_file) = db_entry?;
for chunk in db_embedded_file.chunks {
let (_key, db_embedded_file) =
db_entry.context("failed to read embedded file")?;
for chunk in &db_embedded_file.chunks {
chunks_tx
.send((worktree_id, db_embedded_file.path.clone(), chunk))
.await?;
.send((
worktree_id,
worktree_corpus_stats.clone(),
db_embedded_file.path.clone(),
chunk.clone(),
))
.await
.context("failed to send chunks")?;
}
}
anyhow::Ok(())
})
})?
})
.context("read_with WorktreeIndex failed")?
.await
}));
}
@@ -275,21 +296,43 @@ impl ProjectIndex {
let project = self.project.clone();
let embedding_provider = self.embedding_provider.clone();
let bm25_params = Bm25Parameters::default();
cx.spawn(|cx| async move {
log::info!("Searching for {queries:?}");
// BM-25: Tokenize query
#[cfg(debug_assertions)]
let bm25_query_start = std::time::Instant::now();
let tokenizer = SimpleTokenizer::new();
let terms_by_query: Vec<HashMap<Arc<str>, u32>> = queries
.iter()
.map(|query| {
tokenizer.tokenize_and_stem(query).into_iter().fold(
HashMap::new(),
|mut acc, term| {
*acc.entry(term).or_insert(0) += 1;
acc
},
)
})
.collect();
#[cfg(debug_assertions)]
let bm25_query_end = std::time::Instant::now();
// Similarity search: Embed query
#[cfg(debug_assertions)]
let embedding_query_start = std::time::Instant::now();
log::info!("Searching for {queries:?}");
let queries: Vec<TextToEmbed> = queries
.iter()
.map(|s| TextToEmbed::new(s.as_str()))
.collect();
let query_embeddings = embedding_provider.embed(&queries[..]).await?;
if query_embeddings.len() != queries.len() {
return Err(anyhow!(
"The number of query embeddings does not match the number of queries"
));
}
#[cfg(debug_assertions)]
let embedding_query_end = std::time::Instant::now();
let mut results_by_worker = Vec::new();
for _ in 0..cx.background_executor().num_cpus() {
@@ -302,28 +345,58 @@ impl ProjectIndex {
.scoped(|cx| {
for results in results_by_worker.iter_mut() {
cx.spawn(async {
while let Ok((worktree_id, path, chunk)) = chunks_rx.recv().await {
let (score, query_index) =
chunk.embedding.similarity(&query_embeddings);
while let Ok((worktree_id, worktree_corpus_stats, path, chunk)) =
chunks_rx.recv().await
{
// iterate over every (query_embedding, query_term) and compute its hybrid retrieval score for this chunk
// RetScore(chunk, query_embedding, query_term, m) = m * Sim(query_embedding, chunk) + (1 - m) * Bm25(query_term, chunk)]
let hybrid_scores: Vec<f32> = query_embeddings
.iter()
.zip(terms_by_query.iter())
.map(|(query_embedding, query_term)| {
let (embedding_score, _) =
query_embedding.similarity(&[chunk.embedding.clone()]);
let bm25_score = {
let corpus_stats =
worktree_corpus_stats.read().unwrap();
let score = corpus_stats.calculate_bm25_score(
query_term,
&chunk.term_frequencies.0,
bm25_params.k1,
bm25_params.b,
);
// quick hack to bound the score for long queries
score / query_term.values().sum::<u32>() as f32
};
mixing_param * embedding_score
+ (1. - mixing_param) * bm25_score
})
.collect();
let ix = match results.binary_search_by(|probe| {
score.partial_cmp(&probe.score).unwrap_or(Ordering::Equal)
}) {
Ok(ix) | Err(ix) => ix,
};
if ix < limit {
results.insert(
ix,
WorktreeSearchResult {
worktree_id,
path: path.clone(),
range: chunk.chunk.range.clone(),
query_index,
score,
},
);
if results.len() > limit {
results.pop();
for (query_index, score) in hybrid_scores.into_iter().enumerate() {
if score != 0.0 {
let ix = match results.binary_search_by(|probe| {
score
.partial_cmp(&probe.score)
.unwrap_or(Ordering::Equal)
}) {
Ok(ix) | Err(ix) => ix,
};
if ix < limit {
results.insert(
ix,
WorktreeSearchResult {
worktree_id,
path: path.clone(),
range: chunk.chunk.range.clone(),
query_index,
score,
},
);
if results.len() > limit {
results.pop();
}
}
}
}
}
@@ -362,7 +435,9 @@ impl ProjectIndex {
search_results.len(),
search_elapsed
);
let embedding_query_elapsed = embedding_query_start.elapsed();
let bm25_query_elapsed = bm25_query_start - bm25_query_end;
log::debug!("tokenizing query took {:?}", bm25_query_elapsed);
let embedding_query_elapsed = embedding_query_start - embedding_query_end;
log::debug!("embedding query took {:?}", embedding_query_elapsed);
}

View File

@@ -6,6 +6,7 @@ mod project_index;
mod project_index_debug_view;
mod summary_backlog;
mod summary_index;
mod tfidf;
mod worktree_index;
use anyhow::{Context as _, Result};
@@ -28,6 +29,8 @@ pub use project_index::{LoadedSearchResult, ProjectIndex, SearchResult, Status};
pub use project_index_debug_view::ProjectIndexDebugView;
pub use summary_index::FileSummary;
pub const SEMANTIC_INDEX_DB_VERSION: usize = 1;
pub struct SemanticDb {
embedding_provider: Arc<dyn EmbeddingProvider>,
db_connection: Option<heed::Env>,
@@ -268,7 +271,7 @@ mod tests {
use super::*;
use anyhow::anyhow;
use chunking::Chunk;
use embedding_index::{ChunkedFile, EmbeddingIndex};
use embedding_index::{ChunkedFile, EmbeddingIndex, EmbeddingIndexSettings};
use feature_flags::FeatureFlagAppExt;
use fs::FakeFs;
use futures::{future::BoxFuture, FutureExt};
@@ -280,6 +283,7 @@ mod tests {
use settings::SettingsStore;
use smol::{channel, stream::StreamExt};
use std::{future, path::Path, sync::Arc};
use tfidf::SimpleTokenizer;
fn init_test(cx: &mut TestAppContext) {
env_logger::try_init().ok();
@@ -398,7 +402,7 @@ mod tests {
.update(|cx| {
let project_index = project_index.read(cx);
let query = "garbage in, garbage out";
project_index.search(vec![query.into()], 4, cx)
project_index.search(vec![query.into()], 4, 0.7, cx)
})
.await
.unwrap();
@@ -487,8 +491,18 @@ mod tests {
.unwrap();
chunked_files_tx.close();
let embed_files_task =
cx.update(|cx| EmbeddingIndex::embed_files(provider.clone(), chunked_files_rx, cx));
let tokenizer = SimpleTokenizer::new();
let embedding_settings = EmbeddingIndexSettings::default();
let embed_files_task = cx.update(|cx| {
EmbeddingIndex::embed_files(
provider.clone(),
chunked_files_rx,
tokenizer,
embedding_settings,
cx,
)
});
embed_files_task.task.await.unwrap();
let mut embedded_files_rx = embed_files_task.files;

View File

@@ -0,0 +1,184 @@
use rust_stemmers::{Algorithm, Stemmer};
use serde::{Deserialize, Serialize};
use std::{
    collections::{HashMap, HashSet},
    sync::Arc,
};
use stop_words;
#[derive(Clone)]
pub struct SimpleTokenizer {
stemmer: Arc<Stemmer>,
stopwords: Vec<String>,
}
impl SimpleTokenizer {
pub fn new() -> Self {
Self {
// TODO: handle non-English
stemmer: Arc::new(Stemmer::create(Algorithm::English)),
stopwords: stop_words::get(stop_words::LANGUAGE::English),
}
}
pub fn tokenize_and_stem(&self, text: &str) -> Vec<Arc<str>> {
// Split on whitespace and punctuation
text.split(|c: char| c.is_whitespace() || c.is_ascii_punctuation())
.flat_map(|word| {
word.chars().fold(vec![String::new()], |mut acc, c| {
// Split CamelCaps and camelCase
if c.is_uppercase()
&& !acc.is_empty()
&& acc
.last()
.and_then(|s| s.chars().last())
.map_or(false, |last_char| last_char.is_lowercase())
{
acc.push(String::new());
}
acc.last_mut()
.unwrap_or(&mut String::new())
.push(c.to_lowercase().next().unwrap());
acc
})
})
.filter(|s| !s.is_empty() && !self.stopwords.contains(s))
.map(|word| {
// Stem each word and convert to Arc<str>
let stemmed = self.stemmer.stem(&word).to_string();
Arc::from(stemmed)
})
.collect()
}
}
/// Tuning parameters for the BM25 ranking function.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct Bm25Parameters {
    /// Term-frequency saturation: larger values let repeated terms keep
    /// contributing to the score for longer before the gain flattens out.
    pub k1: f32,
    /// Length-normalization strength in [0, 1]: 0 disables normalization,
    /// 1 fully normalizes term frequency by chunk length.
    pub b: f32,
}

impl Default for Bm25Parameters {
    fn default() -> Self {
        // Conventional defaults from the BM25 literature (Okapi BM25).
        Self { k1: 1.2, b: 0.75 }
    }
}
/// Corpus-statistics provider used to score chunks with BM25.
pub trait Bm25Scorer {
    /// Total number of chunks in the corpus.
    fn total_chunks(&self) -> u32;
    /// Mean chunk length, in terms, across the corpus.
    fn avg_chunk_length(&self) -> f32;
    /// Occurrences of `term` in the chunk described by `chunk_term_counts`.
    fn term_frequency(&self, term: &Arc<str>, chunk_term_counts: &HashMap<Arc<str>, u32>) -> u32;
    /// Corpus-wide frequency of `term`, used as the document frequency.
    fn document_frequency(&self, term: &Arc<str>) -> u32;

    /// Scores one chunk against a query using BM25.
    ///
    /// `query_terms` and `chunk_terms` map terms to frequencies; `k1` and `b`
    /// are the usual BM25 tuning parameters. Returns 0.0 for an empty corpus
    /// instead of propagating NaN.
    fn calculate_bm25_score(
        &self,
        query_terms: &HashMap<Arc<str>, u32>,
        chunk_terms: &HashMap<Arc<str>, u32>,
        k1: f32,
        b: f32,
    ) -> f32 {
        // average doc length, current doc length, total docs
        let avg_dl = self.avg_chunk_length();
        let dl = chunk_terms.values().sum::<u32>() as f32;
        let dn = self.total_chunks() as f32;
        // With no chunks indexed, `dl / avg_dl` below divides by zero and the
        // NaN poisons the whole sum; there is nothing to rank, so bail out.
        if avg_dl <= 0.0 || dn <= 0.0 {
            return 0.0;
        }
        query_terms
            .keys()
            .map(|term| {
                let tf = self.term_frequency(term, chunk_terms) as f32;
                let df = self.document_frequency(term) as f32;
                // IDF with a +1 inside the log so it never goes negative for
                // very common terms (and never takes ln of 0).
                let idf = ((dn - df + 0.5) / (df + 0.5) + 1.0).ln();
                let numerator = tf * (k1 + 1.0);
                let denominator = tf + k1 * (1.0 - b + b * dl / avg_dl);
                idf * (numerator / denominator)
            })
            .sum()
    }
}
/// Frequency map of stemmed terms for a single chunk of text.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct TermCounts(pub HashMap<Arc<str>, u32>);

impl TermCounts {
    /// Tokenizes `text` with `tokenizer` and tallies how often each
    /// resulting term occurs.
    pub fn from_text(text: &str, tokenizer: &SimpleTokenizer) -> Self {
        let mut counts: HashMap<Arc<str>, u32> = HashMap::new();
        for term in tokenizer.tokenize_and_stem(text) {
            *counts.entry(term).or_default() += 1;
        }
        TermCounts(counts)
    }
}
/// Represents the term frequency statistics for a single worktree.
///
/// This struct contains information about chunks, term statistics,
/// and the total length of all chunks in the worktree.
#[derive(Debug)]
pub struct WorktreeTermStats {
    /// A map of terms to their counts across all chunks in this worktree.
    term_counts: HashMap<Arc<str>, u32>,
    /// The total length of all chunks in this worktree.
    total_length: u32,
    /// The total number of chunks tracked in this worktree.
    total_chunks: u32,
}

impl WorktreeTermStats {
    pub fn new(term_counts: HashMap<Arc<str>, u32>, total_length: u32, total_chunks: u32) -> Self {
        Self {
            term_counts,
            total_length,
            total_chunks,
        }
    }

    /// Folds one chunk's term counts into the worktree-wide totals.
    pub fn add_counts(&mut self, chunk_counts: &TermCounts) {
        let mut chunk_length = 0;
        for (term, &freq) in &chunk_counts.0 {
            *self.term_counts.entry(term.clone()).or_insert(0) += freq;
            chunk_length += freq;
        }
        self.total_length += chunk_length;
        self.total_chunks += 1;
    }

    /// Reverses a prior [`WorktreeTermStats::add_counts`] for the same chunk.
    pub fn remove_counts(&mut self, chunk_counts: &TermCounts) {
        debug_assert!(chunk_counts.0.len() <= self.term_counts.len());
        debug_assert!(chunk_counts
            .0
            .keys()
            .all(|k| self.term_counts.contains_key(k)));
        let mut chunk_length = 0;
        for (term, &freq) in &chunk_counts.0 {
            if let Some(count) = self.term_counts.get_mut(term) {
                // `saturating_sub` guards against a release-mode underflow
                // panic/wrap if accounting ever drifts (the debug_asserts
                // above only catch it in debug builds).
                *count = count.saturating_sub(freq);
                // BUG FIX: this previously added 0, so `total_length` never
                // shrank when chunks were removed, permanently inflating the
                // average chunk length used by BM25.
                chunk_length += freq;
                // Drop terms that no remaining chunk contains so the map and
                // the document frequencies stay tight.
                if *count == 0 {
                    self.term_counts.remove(term);
                }
            }
        }
        self.total_length = self.total_length.saturating_sub(chunk_length);
        self.total_chunks = self.total_chunks.saturating_sub(1);
    }
}
impl Bm25Scorer for WorktreeTermStats {
    /// Number of chunks currently tracked in this worktree.
    fn total_chunks(&self) -> u32 {
        self.total_chunks
    }

    /// Mean chunk length in terms; defined as 0.0 when no chunks are tracked.
    fn avg_chunk_length(&self) -> f32 {
        match self.total_chunks {
            0 => 0.0,
            n => self.total_length as f32 / n as f32,
        }
    }

    /// Occurrences of `term` in the given chunk (0 when absent).
    fn term_frequency(&self, term: &Arc<str>, chunk_term_counts: &HashMap<Arc<str>, u32>) -> u32 {
        chunk_term_counts.get(term).copied().unwrap_or(0)
    }

    /// Worktree-wide count for `term` (0 when absent).
    fn document_frequency(&self, term: &Arc<str>) -> u32 {
        self.term_counts.get(term).copied().unwrap_or(0)
    }
}

View File

@@ -46,14 +46,14 @@ impl WorktreeIndex {
embedding_provider: Arc<dyn EmbeddingProvider>,
cx: &mut AppContext,
) -> Task<Result<Model<Self>>> {
let worktree_for_index = worktree.clone();
let worktree_for_embedding = worktree.clone();
let worktree_for_summary = worktree.clone();
let worktree_abs_path = worktree.read(cx).abs_path();
let embedding_fs = Arc::clone(&fs);
let summary_fs = fs;
cx.spawn(|mut cx| async move {
let entries_being_indexed = Arc::new(IndexingEntrySet::new(status_tx));
let (embedding_index, summary_index) = cx
let (embedding_index, summary_index): (EmbeddingIndex, SummaryIndex) = cx
.background_executor()
.spawn({
let entries_being_indexed = Arc::clone(&entries_being_indexed);
@@ -63,9 +63,8 @@ impl WorktreeIndex {
let embedding_index = {
let db_name = worktree_abs_path.to_string_lossy();
let db = db_connection.create_database(&mut txn, Some(&db_name))?;
EmbeddingIndex::new(
worktree_for_index,
worktree_for_embedding,
embedding_fs,
db_connection.clone(),
db,
@@ -101,7 +100,7 @@ impl WorktreeIndex {
)
};
txn.commit()?;
anyhow::Ok((embedding_index, summary_index))
Ok::<_, anyhow::Error>((embedding_index, summary_index))
}
})
.await?;