Compare commits

...

12 Commits
main...miner

Author               SHA1        Message                                                    Date
Antonio Scandurra    34b728539f  WIP: start on accumulating exports/imports                 2024-06-27 18:57:03 +02:00
Antonio Scandurra    9ef2d85fa8  WIP                                                        2024-06-27 15:32:08 +02:00
Antonio Scandurra    4d5a70ccbf  Checkpoint                                                 2024-06-27 13:45:49 +02:00
Antonio Scandurra    d4992ecab4  Checkpoint                                                 2024-06-27 13:34:11 +02:00
Antonio Scandurra    49be47d322  Checkpoint                                                 2024-06-27 11:39:23 +02:00
Antonio Scandurra    f18e9b073b  Checkpoint                                                 2024-06-27 10:40:22 +02:00
Antonio Scandurra    df829e50ea  WIP                                                        2024-06-26 21:08:52 +02:00
Antonio Scandurra    330bb4c1ce  WIP                                                        2024-06-26 18:02:01 +02:00
Antonio Scandurra    65d47587c8  Checkpoint                                                 2024-06-26 17:54:02 +02:00
Nathan Sobo          aceb5581b3  Try aider                                                  2024-06-24 14:30:14 -06:00
Nathan Sobo (aider)  24c8bad8de  Added a new crate 'miner' to the Cargo.toml file.          2024-06-24 14:26:47 -06:00
Nathan Sobo (aider)  8f6ea25a95  Added the 'semantic_mining' crate to the Cargo.toml file.  2024-06-24 14:25:10 -06:00
7 changed files with 1436 additions and 49 deletions

1
.gitignore vendored

@@ -28,3 +28,4 @@ DerivedData/
.vscode
.wrangler
.flatpak-builder
.aider*

657
Cargo.lock generated

File diff suppressed because it is too large

2
Cargo.toml

@@ -58,6 +58,7 @@ members = [
"crates/markdown_preview",
"crates/media",
"crates/menu",
"crates/miner",
"crates/multi_buffer",
"crates/node_runtime",
"crates/notifications",
@@ -240,6 +241,7 @@ task = { path = "crates/task" }
tasks_ui = { path = "crates/tasks_ui" }
search = { path = "crates/search" }
semantic_index = { path = "crates/semantic_index" }
miner = { path = "crates/miner" }
semantic_version = { path = "crates/semantic_version" }
settings = { path = "crates/settings" }
snippet = { path = "crates/snippet" }

24
crates/miner/Cargo.toml Normal file

@@ -0,0 +1,24 @@
[package]
name = "miner"
version = "0.1.0"
edition = "2021"
publish = false
license = "GPL-3.0-or-later"

[[bin]]
name = "miner"
path = "src/miner.rs"

[dependencies]
anyhow.workspace = true
futures.workspace = true
heed.workspace = true
ignore.workspace = true
indicatif = "0.17.8"
reqwest = { version = "0.12.5", features = ["json", "stream"] }
serde.workspace = true
serde_json.workspace = true
tree-sitter.workspace = true
tree-sitter-rust.workspace = true
tokenizers = { version = "0.19.1", features = ["http"] }
tokio.workspace = true

794
crates/miner/src/miner.rs Normal file

@@ -0,0 +1,794 @@
use anyhow::{anyhow, Result};
use futures::StreamExt;
use heed::{
types::{SerdeJson, Str},
Database as HeedDatabase, EnvOpenOptions, RwTxn,
};
use indicatif::{MultiProgress, ProgressBar, ProgressStyle};
use reqwest::Client;
use serde::{Deserialize, Serialize};
use std::{
collections::{BTreeMap, HashMap, VecDeque},
path::{Path, PathBuf},
sync::Arc,
time::SystemTime,
};
use tokenizers::tokenizer::Tokenizer;
use tokenizers::FromPretrainedParameters;
use tokio::sync::mpsc;
use tokio::sync::Mutex;
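
// A single chat message in the JSON payload sent to either backend.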
#[derive(Debug, Serialize)]
struct Message {
role: String,
content: String,
}
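
// Minimal client for a local Ollama server's /api/chat endpoint.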
pub struct OllamaClient {
client: Client,
base_url: String,
}
impl OllamaClient {
pub fn new(base_url: String) -> Self {
Self {
client: Client::new(),
base_url,
}
}
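
// Streams a chat completion from Ollama. Each streamed chunk is a JSON
// object whose message.content is forwarded over the returned channel;
// chunks that fail to parse as a complete JSON value are silently skipped.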
async fn stream_completion(
&self,
model: String,
messages: Vec<Message>,
) -> Result<mpsc::Receiver<String>> {
let (tx, rx) = mpsc::channel(100);
let request = serde_json::json!({
"model": model,
"messages": messages,
"stream": true,
});
let response = self
.client
.post(format!("{}/api/chat", self.base_url))
.json(&request)
.send()
.await?;
if !response.status().is_success() {
return Err(anyhow!(
"error streaming completion: {:?}",
response.text().await?
));
}
tokio::spawn(async move {
let mut stream = response.bytes_stream();
while let Some(chunk) = stream.next().await {
if let Ok(chunk) = chunk {
if let Ok(text) = String::from_utf8(chunk.to_vec()) {
if let Ok(response) = serde_json::from_str::<serde_json::Value>(&text) {
if let Some(content) = response["message"]["content"].as_str() {
let _ = tx.send(content.to_string()).await;
}
}
}
}
}
});
Ok(rx)
}
}
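
// Client for an OpenAI-compatible chat completions endpoint (here a
// Hugging Face Inference Endpoint) that streams server-sent events.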
pub struct HuggingFaceClient {
client: Client,
endpoint: String,
api_key: String,
}
impl HuggingFaceClient {
pub fn new(endpoint: String, api_key: String) -> Self {
Self {
client: Client::new(),
endpoint,
api_key,
}
}
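
// Streams a chat completion over SSE. Each "data:" line carries a JSON chunk
// whose choices[0].delta.content is forwarded; a "[DONE]" line marks the end.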
async fn stream_completion(
&self,
model: String,
messages: Vec<Message>,
) -> Result<mpsc::Receiver<String>> {
let (tx, rx) = mpsc::channel(100);
let request = serde_json::json!({
"model": model,
"messages": messages,
"stream": true,
"max_tokens": 2048
});
let response = self
.client
.post(&self.endpoint)
.header("Authorization", format!("Bearer {}", self.api_key))
.header("Content-Type", "application/json")
.json(&request)
.send()
.await?;
if !response.status().is_success() {
return Err(anyhow!(
"error streaming completion: {:?}",
response.text().await?
));
}
tokio::spawn(async move {
let mut stream = response.bytes_stream();
while let Some(chunk) = stream.next().await {
if let Ok(chunk) = chunk {
if let Ok(text) = String::from_utf8(chunk.to_vec()) {
for line in text.lines() {
if line.starts_with("data:") {
// Trim whitespace after "data:" so the "[DONE]" sentinel matches below.
let json_str = line.trim_start_matches("data:").trim();
if json_str == "[DONE]" {
break;
}
if let Ok(response) =
serde_json::from_str::<serde_json::Value>(json_str)
{
if let Some(content) =
response["choices"][0]["delta"]["content"].as_str()
{
let _ = tx.send(content.to_string()).await;
}
}
}
}
}
}
}
});
Ok(rx)
}
}
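
// Token budget per chunk when splitting large files for summarization, and
// the overlap carried between consecutive chunks (see split_into_chunks).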
const CHUNK_SIZE: usize = 5000;
const OVERLAP: usize = 2_000;
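
// A unit of work in the summarization queue.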
#[derive(Debug)]
enum Entry {
File(PathBuf),
Directory(PathBuf),
}
#[derive(Debug, Serialize, Deserialize)]
struct CachedSummary {
summary: String,
mtime: SystemTime,
}
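
// Handle to a heed (LMDB) database; all transactions are funneled through an
// mpsc channel so they execute serially on one task.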
#[derive(Clone)]
struct Database {
tx: mpsc::Sender<Box<dyn FnOnce(&HeedDatabase<Str, SerdeJson<CachedSummary>>, RwTxn) + Send>>,
}
impl Database {
async fn new(db_path: &Path, root: &Path) -> Result<Self> {
std::fs::create_dir_all(&db_path)?;
let env = unsafe {
EnvOpenOptions::new()
.map_size(1024 * 1024 * 1024)
.max_dbs(3000)
.open(db_path)?
};
let mut wtxn = env.write_txn()?;
let db_name = format!("summaries_{}", root.to_string_lossy());
let db: HeedDatabase<Str, SerdeJson<CachedSummary>> =
env.create_database(&mut wtxn, Some(&db_name))?;
wtxn.commit()?;
let (tx, mut rx) = mpsc::channel::<
Box<dyn FnOnce(&HeedDatabase<Str, SerdeJson<CachedSummary>>, RwTxn) + Send>,
>(100);
tokio::spawn(async move {
while let Some(f) = rx.recv().await {
let wtxn = env.write_txn().unwrap();
f(&db, wtxn);
}
});
Ok(Self { tx })
}
async fn transact<F, T>(&self, f: F) -> Result<T>
where
F: FnOnce(&HeedDatabase<Str, SerdeJson<CachedSummary>>, RwTxn) -> Result<T>
+ Send
+ 'static,
T: 'static + Send,
{
let (tx, rx) = tokio::sync::oneshot::channel();
self.tx
.send(Box::new(move |db, txn| {
let result = f(db, txn);
let _ = tx.send(result);
}))
.await
.map_err(|_| anyhow!("database closed"))?;
Ok(rx.await.map_err(|_| anyhow!("transaction failed"))??)
}
}
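
// Walks `root` and produces a summary for every file and directory. File
// summaries come from tree-sitter (Rust sources) or the LLM backend
// (everything else) and are cached by path, invalidated by mtime; directory
// summaries combine the summaries of their direct children.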
async fn summarize_project(
db_path: &Path,
root: &Path,
num_workers: usize,
) -> Result<BTreeMap<PathBuf, String>> {
let database = Database::new(db_path, root).await?;
let tokenizer = Tokenizer::from_pretrained(
"mistralai/Mistral-7B-Instruct-v0.1",
Some(FromPretrainedParameters {
revision: "main".into(),
user_agent: HashMap::default(),
auth_token: Some(
std::env::var("HUGGINGFACE_API_TOKEN").expect("HUGGINGFACE_API_TOKEN not set"),
),
}),
)
.unwrap();
let client = Arc::new(HuggingFaceClient::new(
"https://c0es55wrh8muqy3g.us-east-1.aws.endpoints.huggingface.cloud/v1/chat/completions"
.into(),
std::env::var("HUGGINGFACE_API_TOKEN").expect("HUGGINGFACE_API_TOKEN not set"),
));
let queue = Arc::new(Mutex::new(VecDeque::new()));
let multi_progress = Arc::new(MultiProgress::new());
let overall_progress = multi_progress.add(ProgressBar::new_spinner());
overall_progress.set_style(
ProgressStyle::default_spinner()
.template("{spinner:.green} {msg}")
.unwrap(),
);
overall_progress.set_message("Summarizing project...");
// Populate the queue with files and directories
let mut walker = ignore::WalkBuilder::new(root)
.hidden(true)
.ignore(true)
.build();
while let Some(entry) = walker.next() {
if let Ok(entry) = entry {
let path = entry.path().to_owned();
if entry.file_type().map_or(false, |ft| ft.is_dir()) {
queue.lock().await.push_back(Entry::Directory(path));
} else {
queue.lock().await.push_back(Entry::File(path));
}
}
}
let total_entries = queue.lock().await.len();
let progress_bar = multi_progress.add(ProgressBar::new(total_entries as u64));
progress_bar.set_style(
ProgressStyle::default_bar()
.template("[{elapsed_precise}] {bar:40.cyan/blue} {pos}/{len} {msg}")
.unwrap()
.progress_chars("##-"),
);
let summaries = Arc::new(Mutex::new(BTreeMap::new()));
let paths_loaded_from_cache = Arc::new(Mutex::new(BTreeMap::new()));
let rust_language = tree_sitter_rust::language();
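
// Spawn workers that drain the shared queue. A directory can only be
// summarized once all of its children have summaries, so directories with
// pending children are pushed back onto the queue.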
let workers: Vec<_> = (0..num_workers)
.map(|_| {
let queue = Arc::clone(&queue);
let client = Arc::clone(&client);
let summaries = Arc::clone(&summaries);
let tokenizer = tokenizer.clone();
let progress_bar = progress_bar.clone();
let database = database.clone();
let paths_loaded_from_cache = Arc::clone(&paths_loaded_from_cache);
let mut parser = tree_sitter::Parser::new();
parser.set_language(&rust_language).unwrap();
let rust_language = rust_language.clone();
tokio::spawn(async move {
loop {
let mut queue_lock = queue.lock().await;
let Some(entry) = queue_lock.pop_front() else {
break;
};
match entry {
Entry::File(path) => {
drop(queue_lock);
let summary = async {
let mtime = tokio::fs::metadata(&path).await?.modified()?;
let key = path.to_string_lossy().to_string();
let cached = database
.transact({
let key = key.clone();
move |db, txn| Ok(db.get(&txn, &key)?)
})
.await?;
if let Some(cached) = cached {
if cached.mtime == mtime {
paths_loaded_from_cache
.lock()
.await
.insert(path.clone(), true);
return Ok(cached.summary);
}
}
progress_bar.set_message(format!("Summarizing {}", path.display()));
let content = tokio::fs::read_to_string(&path)
.await
.unwrap_or_else(|_| "binary file".into());
let mut summary = String::new();
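// Rust sources are summarized structurally instead of via the model:
// tree-sitter queries extract exports (rust_exports.scm) and imports
// (rust_imports.scm), which are counted and listed.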
if path.extension().map_or(false, |ext| ext == "rs") {
let tree = parser.parse(&content, None).unwrap();
let root_node = tree.root_node();
let export_query = tree_sitter::Query::new(
&rust_language,
include_str!("./rust_exports.scm"),
)
.unwrap();
let mut export_cursor = tree_sitter::QueryCursor::new();
let mut exports = Vec::new();
for m in export_cursor.matches(
&export_query,
root_node,
content.as_bytes(),
) {
let mut current_level = 0;
let mut current_export = String::new();
for c in m.captures {
let export = content[c.node.byte_range()].to_string();
let indent = " ".repeat(current_level);
if current_level == 0 {
current_export = format!("{}{}", indent, export);
} else {
current_export
.push_str(&format!("\n{}{}", indent, export));
}
current_level += 1;
}
exports.push(current_export);
}
let import_query = tree_sitter::Query::new(
&rust_language,
include_str!("./rust_imports.scm"),
)
.unwrap();
let mut import_cursor = tree_sitter::QueryCursor::new();
let imports: Vec<_> = import_cursor
.matches(&import_query, root_node, content.as_bytes())
.flat_map(|m| m.captures)
.map(|c| content[c.node.byte_range()].to_string())
.collect();
summary.push_str("Summary: Rust file containing ");
if !exports.is_empty() {
summary.push_str(&format!("{} exports", exports.len()));
if !imports.is_empty() {
summary.push_str(" and ");
}
}
if !imports.is_empty() {
summary.push_str(&format!("{} imports", imports.len()));
}
summary.push('.');
if !exports.is_empty() {
summary.push_str("\nExports:\n");
summary.push_str(&exports.join("\n"));
}
if !imports.is_empty() {
summary.push_str("\nImports: ");
summary.push_str(&imports.join(", "));
}
println!("{}", summary);
} else {
let chunks = split_into_chunks(
&content, &tokenizer, CHUNK_SIZE, OVERLAP,
);
let chunk_summaries =
summarize_chunks(&client, &chunks).await?;
summary =
combine_summaries(&client, &chunk_summaries, true).await?;
}
let cached_summary = CachedSummary {
summary: summary.clone(),
mtime,
};
database
.transact(move |db, mut txn| {
db.put(&mut txn, &key, &cached_summary)?;
txn.commit()?;
Ok(())
})
.await?;
anyhow::Ok(summary)
};
let summary = summary.await.unwrap_or_else(|error| {
format!("path could not be summarized: {error:?}")
});
summaries.lock().await.insert(path, summary);
progress_bar.inc(1);
}
Entry::Directory(path) => {
let mut dir_summaries = Vec::new();
let mut all_children_summarized = true;
let mut all_children_from_cache = true;
let dir_walker = ignore::WalkBuilder::new(&path)
.hidden(true)
.ignore(true)
.max_depth(Some(1))
.build();
for entry in dir_walker {
if let Ok(entry) = entry {
if entry.path() != path {
if let Some(summary) =
summaries.lock().await.get(entry.path())
{
dir_summaries.push(summary.clone());
if !paths_loaded_from_cache
.lock()
.await
.get(entry.path())
.unwrap_or(&false)
{
all_children_from_cache = false;
}
} else {
all_children_summarized = false;
break;
}
}
}
}
if all_children_summarized {
drop(queue_lock);
let combined_summary = async {
let key = path.to_string_lossy().to_string();
let mtime = tokio::fs::metadata(&path).await?.modified()?;
if all_children_from_cache {
if let Some(cached) = database
.transact({
let key = key.clone();
move |db, txn| Ok(db.get(&txn, &key)?)
})
.await?
{
paths_loaded_from_cache
.lock()
.await
.insert(path.clone(), true);
return Ok(cached.summary);
}
}
progress_bar
.set_message(format!("Summarizing {}", path.display()));
let combined_summary =
combine_summaries(&client, &dir_summaries, false).await?;
let cached_summary = CachedSummary {
summary: combined_summary.clone(),
mtime,
};
database
.transact(move |db, mut txn| {
db.put(&mut txn, &key, &cached_summary)?;
txn.commit()?;
Ok(())
})
.await?;
anyhow::Ok(combined_summary)
};
let combined_summary = combined_summary
.await
.unwrap_or_else(|_| "could not combine summaries".into());
summaries.lock().await.insert(path, combined_summary);
progress_bar.inc(1);
} else {
queue_lock.push_back(Entry::Directory(path));
}
}
}
}
Ok::<_, anyhow::Error>(())
})
})
.collect();
for worker in workers {
worker.await??;
}
// Remove deleted entries from the database
database
.transact(|db, mut txn| {
let mut paths_to_delete = Vec::new();
for item in db.iter(&txn)? {
let (path, _) = item?;
let path = PathBuf::from(path);
if !path.exists() {
paths_to_delete.push(path);
}
}
for path in paths_to_delete {
db.delete(&mut txn, &path.to_string_lossy())?;
}
txn.commit()?;
Ok(())
})
.await?;
progress_bar.finish_with_message("Summarization complete");
overall_progress.finish_with_message("Project summarization finished");
Ok(Arc::try_unwrap(summaries).unwrap().into_inner())
}
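
// Splits `content` into roughly `chunk_size`-token chunks on line boundaries,
// measuring tokens with `tokenizer`. Each chunk after the first is prefixed
// with the trailing `overlap` lines (not tokens) of its predecessor.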
fn split_into_chunks(
content: &str,
tokenizer: &Tokenizer,
chunk_size: usize,
overlap: usize,
) -> Vec<String> {
let mut chunks = Vec::new();
let lines: Vec<&str> = content.lines().collect();
let mut current_chunk = String::new();
let mut current_tokens = 0;
for line in lines {
let line_tokens = tokenizer.encode(line, false).unwrap().get_ids().len();
// Avoid emitting an empty chunk when the very first line exceeds chunk_size.
if current_tokens + line_tokens > chunk_size && !current_chunk.is_empty() {
chunks.push(current_chunk.clone());
current_chunk.clear();
current_tokens = 0;
}
current_chunk.push_str(line);
current_chunk.push('\n');
current_tokens += line_tokens;
}
if !current_chunk.is_empty() {
chunks.push(current_chunk);
}
// Add overlap
for i in 1..chunks.len() {
let overlap_text = chunks[i - 1]
.lines()
.rev()
.take(overlap)
.collect::<Vec<_>>()
.into_iter()
.rev()
.collect::<Vec<_>>()
.join("\n");
chunks[i] = format!("{}\n{}", overlap_text, chunks[i]);
}
chunks
}
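
// Summarizes each chunk sequentially with the remote model.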
async fn summarize_chunks(client: &HuggingFaceClient, chunks: &[String]) -> Result<Vec<String>> {
let mut chunk_summaries = Vec::new();
for chunk in chunks {
let summary = summarize_file(client, chunk).await?;
chunk_summaries.push(summary);
}
Ok(chunk_summaries)
}
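
// Asks the model for a terse "Summary: ..." of a single file or chunk.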
async fn summarize_file(client: &HuggingFaceClient, content: &str) -> Result<String> {
let messages = vec![Message {
role: "user".to_string(),
content: format!(
"You are a code summarization assistant. \
Provide a brief summary of the given file, \
focusing on its main functionality and purpose. \
Be terse and start your response directly with \"Summary: \".\n\
File:\n{}",
content
),
}];
let mut receiver = client
.stream_completion("tgi".to_string(), messages)
.await?;
let mut summary = String::new();
while let Some(content) = receiver.recv().await {
summary.push_str(&content);
}
Ok(summary)
}
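
// Merges several summaries into one. `is_chunk` selects the prompt: chunks of
// a single file versus summaries of separate files or directories.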
async fn combine_summaries(
client: &HuggingFaceClient,
summaries: &[String],
is_chunk: bool,
) -> Result<String> {
let combined_content = summaries.join("\n## Summary\n");
let prompt = if is_chunk {
concat!(
"You are a code summarization assistant. ",
"Combine the given summaries into a single, coherent summary ",
"that captures the overall functionality and structure of the code. ",
"Ensure that the final summary is comprehensive and reflects ",
"the content as if it was summarized from a single, complete file. ",
"Be terse and start your response with \"Summary: \""
)
} else {
concat!(
"You are a code summarization assistant. ",
"Combine the given summaries of different files or directories ",
"into a single, coherent summary that captures the overall ",
"structure and functionality of the project or directory. ",
"Focus on the relationships between different components ",
"and the high-level architecture. ",
"Be terse and start your response with \"Summary: \""
)
};
let messages = vec![Message {
role: "user".to_string(),
content: format!("{}\n# Summaries\n{}", prompt, combined_content),
}];
let mut receiver = client
.stream_completion("tgi".to_string(), messages)
.await?;
let mut combined_summary = String::new();
while let Some(content) = receiver.recv().await {
combined_summary.push_str(&content);
}
Ok(combined_summary)
}
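
// CLI entry point: miner <project_path> [db_path] [num_workers] [--read=path].
// With --read, prints the summaries of the immediate children of that path.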
#[tokio::main]
async fn main() -> Result<()> {
let args: Vec<String> = std::env::args().collect();
if args.len() < 2 {
eprintln!(
"Usage: {} <project_path> [db_path] [num_workers] [--read=path]",
args[0]
);
std::process::exit(1);
}
let project_path = Path::new(&args[1]);
if !project_path.exists() || !project_path.is_dir() {
eprintln!("Error: The provided project path does not exist or is not a directory.");
std::process::exit(1);
}
let db_path = if args.len() >= 3 && !args[2].starts_with("--") {
PathBuf::from(&args[2])
} else {
std::env::current_dir()?.join("project_summaries")
};
let num_workers = if args.len() >= 4 && !args[3].starts_with("--") {
args[3].parse().unwrap_or(8)
} else {
8
};
println!("Summarizing project at: {}", project_path.display());
println!("Using database at: {}", db_path.display());
println!("Number of workers: {}", num_workers);
let summaries = summarize_project(&db_path, project_path, num_workers).await?;
println!("Finished summarization");
// Check if --read flag is provided
if let Some(read_path) = args.iter().find(|arg| arg.starts_with("--read=")) {
let path = Path::new(&read_path[7..]);
let full_path = project_path.join(path);
for (child_path, summary) in summaries.iter() {
if child_path.parent() == Some(&full_path) {
println!("<path>{}</path>", child_path.to_string_lossy());
println!("<summary>{}</summary>", summary);
println!();
}
}
} else {
dbg!(summaries);
}
Ok(())
}
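
// An earlier Groq-backed client, kept commented out for reference.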
// #[derive(Debug, Serialize)]
// struct ChatCompletionRequest {
// model: String,
// messages: Vec<Message>,
// stream: bool,
// }
//
// #[derive(Debug, Deserialize)]
// struct ChatCompletionChunk {
// choices: Vec<Choice>,
// }
// #[derive(Debug, Deserialize)]
// struct Choice {
// delta: Delta,
// }
// #[derive(Debug, Deserialize)]
// struct Delta {
// content: Option<String>,
// }
// pub struct GroqClient {
// client: Client,
// api_key: String,
// }
// impl GroqClient {
// pub fn new(api_key: String) -> Self {
// Self {
// client: Client::new(),
// api_key,
// }
// }
// async fn stream_completion(
// &self,
// model: String,
// messages: Vec<Message>,
// ) -> Result<mpsc::Receiver<String>> {
// let (tx, rx) = mpsc::channel(100);
// let request = ChatCompletionRequest {
// model,
// messages,
// stream: true,
// };
// let response = self
// .client
// .post("https://api.groq.com/openai/v1/chat/completions")
// .header("Authorization", format!("Bearer {}", self.api_key))
// .json(&request)
// .send

6
crates/miner/src/rust_exports.scm Normal file

@@ -0,0 +1,6 @@
(mod_item name: (identifier) @export)
(struct_item name: (type_identifier) @export)
(impl_item type: (type_identifier) @export)
(enum_item name: (type_identifier) @export)
(function_item name: (identifier) @export)
(trait_item name: (type_identifier) @export)

1
crates/miner/src/rust_imports.scm Normal file

@@ -0,0 +1 @@
(use_declaration) @import