From f18e9b073b51593df841bc95377eef4e69d883af Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Thu, 27 Jun 2024 10:40:22 +0200 Subject: [PATCH] Checkpoint --- Cargo.lock | 4 +- crates/miner/Cargo.toml | 4 +- crates/miner/src/miner.rs | 380 +++++++++++++++++++++++++++----------- 3 files changed, 275 insertions(+), 113 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index fd3329f3c7..1b0af31a72 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -6675,13 +6675,13 @@ version = "0.1.0" dependencies = [ "anyhow", "futures 0.3.28", - "project", + "ignore", + "indicatif", "reqwest 0.12.5", "serde", "serde_json", "tokenizers", "tokio", - "walkdir", ] [[package]] diff --git a/crates/miner/Cargo.toml b/crates/miner/Cargo.toml index 901675fa1c..9bc6ebe68f 100644 --- a/crates/miner/Cargo.toml +++ b/crates/miner/Cargo.toml @@ -12,10 +12,10 @@ path = "src/miner.rs" [dependencies] anyhow.workspace = true futures.workspace = true -project.workspace = true +ignore.workspace = true +indicatif = "0.17.8" reqwest = { version = "0.12.5", features = ["json", "stream"] } serde.workspace = true serde_json.workspace = true tokenizers = { version = "0.19.1", features = ["http"] } tokio.workspace = true -walkdir = "2.5.0" diff --git a/crates/miner/src/miner.rs b/crates/miner/src/miner.rs index 2e00a65029..5eba0d8180 100644 --- a/crates/miner/src/miner.rs +++ b/crates/miner/src/miner.rs @@ -1,22 +1,17 @@ -use anyhow::Result; -use futures::{Stream, StreamExt}; +use anyhow::{anyhow, Result}; +use futures::StreamExt; +use indicatif::{MultiProgress, ProgressBar, ProgressStyle}; use reqwest::Client; -use serde::{Deserialize, Serialize}; -use std::collections::{HashMap, VecDeque}; -use std::path::{Path, PathBuf}; -use std::sync::Arc; +use serde::Serialize; +use std::{ + collections::{BTreeMap, HashMap, VecDeque}, + path::{Path, PathBuf}, + sync::Arc, +}; use tokenizers::tokenizer::Tokenizer; use tokenizers::FromPretrainedParameters; use tokio::sync::mpsc; use tokio::sync::Mutex; -use walkdir::WalkDir; - -#[derive(Debug, Serialize)] -struct ChatCompletionRequest { - model: String, - messages: Vec, - stream: bool, -} #[derive(Debug, Serialize)] struct Message { @@ -24,71 +19,54 @@ struct Message { content: String, } -#[derive(Debug, Deserialize)] -struct ChatCompletionChunk { - choices: Vec, -} - -#[derive(Debug, Deserialize)] -struct Choice { - delta: Delta, -} - -#[derive(Debug, Deserialize)] -struct Delta { - content: Option, -} - -pub struct GroqClient { +pub struct OllamaClient { client: Client, - api_key: String, + base_url: String, } -impl GroqClient { - pub fn new(api_key: String) -> Self { +impl OllamaClient { + pub fn new(base_url: String) -> Self { Self { client: Client::new(), - api_key, + base_url, } } - pub async fn stream_completion( + async fn stream_completion( &self, model: String, messages: Vec, ) -> Result> { let (tx, rx) = mpsc::channel(100); - let request = ChatCompletionRequest { - model, - messages, - stream: true, - }; + let request = serde_json::json!({ + "model": model, + "messages": messages, + "stream": true, + }); let response = self .client - .post("https://api.groq.com/openai/v1/chat/completions") - .header("Authorization", format!("Bearer {}", self.api_key)) + .post(format!("{}/api/chat", self.base_url)) .json(&request) .send() .await?; + if !response.status().is_success() { + return Err(anyhow!( + "error streaming completion: {:?}", + response.text().await? + )); + } + tokio::spawn(async move { let mut stream = response.bytes_stream(); while let Some(chunk) = stream.next().await { if let Ok(chunk) = chunk { if let Ok(text) = String::from_utf8(chunk.to_vec()) { - for line in text.lines() { - if line.starts_with("data: ") && line != "data: [DONE]" { - if let Ok(mut chunk) = - serde_json::from_str::(&line[6..]) - { - if let Some(content) = - chunk.choices.pop().unwrap().delta.content - { - let _ = tx.send(content).await; - } - } + if let Ok(response) = serde_json::from_str::(&text) { + if let Some(content) = response["message"]["content"].as_str() { + let _ = tx.send(content.to_string()).await; } } } @@ -100,15 +78,8 @@ impl GroqClient { } } -// Below we perform a bottom up traversal over each worktree in the project. -// We push an entry for each file and directory into a queue. -// We then read from that queue with N workers in a tokio thread pool. -// For each file, we perform a summarization with a prompt. -// For each directory, we combine the summaries of all its files with a prompt. -// If a file is too big, truncate it at the max tokens of mixtral model, 32k. -// Use the tokenizers crate to estimate token counts - -const MAX_TOKENS: usize = 32_000; +const CHUNK_SIZE: usize = 16_000; +const OVERLAP: usize = 2_000; #[derive(Debug)] enum Entry { @@ -116,11 +87,11 @@ enum Entry { Directory(PathBuf), } -async fn summarize_project(root: &Path, num_workers: usize) -> Result { +async fn summarize_project(root: &Path, num_workers: usize) -> Result> { let tokenizer = Tokenizer::from_pretrained( - "mistralai/Mixtral-8x7B-v0.1", + "Qwen/Qwen2-0.5B", Some(FromPretrainedParameters { - revision: String::new(), + revision: "main".into(), user_agent: HashMap::default(), auth_token: Some( std::env::var("HUGGINGFACE_API_TOKEN").expect("HUGGINGFACE_API_TOKEN not set"), @@ -128,24 +99,44 @@ async fn summarize_project(root: &Path, num_workers: usize) -> Result { }), ) .unwrap(); - let client = Arc::new(GroqClient::new(std::env::var("GROQ_API_KEY")?)); + let client = Arc::new(OllamaClient::new("http://localhost:11434".into())); let queue = Arc::new(Mutex::new(VecDeque::new())); + let multi_progress = Arc::new(MultiProgress::new()); + let overall_progress = multi_progress.add(ProgressBar::new_spinner()); + overall_progress.set_style( + ProgressStyle::default_spinner() + .template("{spinner:.green} {msg}") + .unwrap(), + ); + overall_progress.set_message("Summarizing project..."); + // Populate the queue with files and directories - for entry in WalkDir::new(root) - .min_depth(1) - .into_iter() - .filter_map(|e| e.ok()) - { - let path = entry.path().to_owned(); - if entry.file_type().is_dir() { - queue.lock().await.push_back(Entry::Directory(path)); - } else { - queue.lock().await.push_back(Entry::File(path)); + let mut walker = ignore::WalkBuilder::new(root) + .hidden(true) + .ignore(true) + .build(); + while let Some(entry) = walker.next() { + if let Ok(entry) = entry { + let path = entry.path().to_owned(); + if entry.file_type().map_or(false, |ft| ft.is_dir()) { + queue.lock().await.push_back(Entry::Directory(path)); + } else { + queue.lock().await.push_back(Entry::File(path)); + } } } - let summaries = Arc::new(Mutex::new(HashMap::new())); + let total_entries = queue.lock().await.len(); + let progress_bar = multi_progress.add(ProgressBar::new(total_entries as u64)); + progress_bar.set_style( + ProgressStyle::default_bar() + .template("[{elapsed_precise}] {bar:40.cyan/blue} {pos}/{len} {msg}") + .unwrap() + .progress_chars("##-"), + ); + + let summaries = Arc::new(Mutex::new(BTreeMap::new())); let workers: Vec<_> = (0..num_workers) .map(|_| { @@ -153,41 +144,67 @@ async fn summarize_project(root: &Path, num_workers: usize) -> Result { let client = Arc::clone(&client); let summaries = Arc::clone(&summaries); let tokenizer = tokenizer.clone(); + let progress_bar = progress_bar.clone(); tokio::spawn(async move { - while let Some(entry) = queue.lock().await.pop_front() { + loop { + let mut queue_lock = queue.lock().await; + let Some(entry) = queue_lock.pop_front() else { + break; + }; + match entry { Entry::File(path) => { - let content = tokio::fs::read_to_string(&path).await?; - let truncated_content = - truncate_to_max_tokens(&content, &tokenizer, MAX_TOKENS); - let summary = summarize_file(&client, &truncated_content).await?; + drop(queue_lock); + let summary = async { + let content = tokio::fs::read_to_string(&path).await?; + let chunks = + split_into_chunks(&content, &tokenizer, CHUNK_SIZE, OVERLAP); + let chunk_summaries = summarize_chunks(&client, &chunks).await?; + combine_summaries(&client, &chunk_summaries, true).await + }; + + let summary = summary + .await + .unwrap_or_else(|_| "path could not be summarized".into()); summaries.lock().await.insert(path, summary); + progress_bar.inc(1); } Entry::Directory(path) => { let mut dir_summaries = Vec::new(); let mut all_children_summarized = true; - for entry in path.read_dir()? { + let dir_walker = ignore::WalkBuilder::new(&path) + .hidden(true) + .ignore(true) + .max_depth(Some(1)) + .build(); + for entry in dir_walker { if let Ok(entry) = entry { - if let Some(summary) = summaries.lock().await.get(&entry.path()) - { - dir_summaries.push(summary.clone()); - } else { - all_children_summarized = false; - break; + if entry.path() != path { + if let Some(summary) = + summaries.lock().await.get(entry.path()) + { + dir_summaries.push(summary.clone()); + } else { + all_children_summarized = false; + break; + } } } } if all_children_summarized { + drop(queue_lock); let combined_summary = - combine_summaries(&client, &dir_summaries).await?; + combine_summaries(&client, &dir_summaries, false).await?; summaries.lock().await.insert(path, combined_summary); + progress_bar.inc(1); } else { - queue.lock().await.push_back(Entry::Directory(path)); + queue_lock.push_back(Entry::Directory(path)); } } } } + Ok::<_, anyhow::Error>(()) }) }) @@ -197,28 +214,73 @@ async fn summarize_project(root: &Path, num_workers: usize) -> Result { worker.await??; } - let summaries = summaries.lock().await; - Ok(summaries.get(root).cloned().unwrap_or_default()) + progress_bar.finish_with_message("Summarization complete"); + overall_progress.finish_with_message("Project summarization finished"); + + Ok(Arc::try_unwrap(summaries).unwrap().into_inner()) } -fn truncate_to_max_tokens(content: &str, tokenizer: &Tokenizer, max_tokens: usize) -> String { - let encoding = tokenizer.encode(content, false).unwrap(); - if encoding.get_ids().len() <= max_tokens { - content.to_string() - } else { - tokenizer - .decode(&encoding.get_ids()[..max_tokens], false) - .unwrap() +fn split_into_chunks( + content: &str, + tokenizer: &Tokenizer, + chunk_size: usize, + overlap: usize, +) -> Vec { + let mut chunks = Vec::new(); + let lines: Vec<&str> = content.lines().collect(); + let mut current_chunk = String::new(); + let mut current_tokens = 0; + + for line in lines { + let line_tokens = tokenizer.encode(line, false).unwrap().get_ids().len(); + if current_tokens + line_tokens > chunk_size { + chunks.push(current_chunk.clone()); + current_chunk.clear(); + current_tokens = 0; + } + current_chunk.push_str(line); + current_chunk.push('\n'); + current_tokens += line_tokens; } + + if !current_chunk.is_empty() { + chunks.push(current_chunk); + } + + // Add overlap + for i in 1..chunks.len() { + let overlap_text = chunks[i - 1] + .lines() + .rev() + .take(overlap) + .collect::>() + .into_iter() + .rev() + .collect::>() + .join("\n"); + chunks[i] = format!("{}\n{}", overlap_text, chunks[i]); + } + + chunks } -async fn summarize_file(client: &GroqClient, content: &str) -> Result { +async fn summarize_chunks(client: &OllamaClient, chunks: &[String]) -> Result> { + let mut chunk_summaries = Vec::new(); + + for chunk in chunks { + let summary = summarize_file(client, chunk).await?; + chunk_summaries.push(summary); + } + + Ok(chunk_summaries) +} + +async fn summarize_file(client: &OllamaClient, content: &str) -> Result { let messages = vec![ Message { role: "system".to_string(), content: - "You are a code summarization assistant. Provide a brief summary of the given code." - .to_string(), + "You are a code summarization assistant. Provide a brief summary of the given code chunk, focusing on its main functionality and purpose.".to_string(), }, Message { role: "user".to_string(), @@ -227,7 +289,7 @@ async fn summarize_file(client: &GroqClient, content: &str) -> Result { ]; let mut receiver = client - .stream_completion("mixtral-8x7b-32768".to_string(), messages) + .stream_completion("qwen2:0.5b".to_string(), messages) .await?; let mut summary = String::new(); @@ -238,12 +300,22 @@ async fn summarize_file(client: &GroqClient, content: &str) -> Result { Ok(summary) } -async fn combine_summaries(client: &GroqClient, summaries: &[String]) -> Result { +async fn combine_summaries( + client: &OllamaClient, + summaries: &[String], + is_chunk: bool, +) -> Result { let combined_content = summaries.join("\n\n"); + let prompt = if is_chunk { + "You are a code summarization assistant. Combine the given summaries into a single, coherent summary that captures the overall functionality and structure of the code. Ensure that the final summary is comprehensive and reflects the content as if it was summarized from a single, complete file." + } else { + "You are a code summarization assistant. Combine the given summaries of different files or directories into a single, coherent summary that captures the overall structure and functionality of the project or directory. Focus on the relationships between different components and the high-level architecture." + }; + let messages = vec![ Message { role: "system".to_string(), - content: "You are a code summarization assistant. Combine the given summaries into a single, coherent summary.".to_string(), + content: prompt.to_string(), }, Message { role: "user".to_string(), @@ -252,7 +324,7 @@ async fn combine_summaries(client: &GroqClient, summaries: &[String]) -> Result< ]; let mut receiver = client - .stream_completion("mixtral-8x7b-32768".to_string(), messages) + .stream_completion("qwen2:0.5b".to_string(), messages) .await?; let mut combined_summary = String::new(); @@ -280,7 +352,97 @@ async fn main() -> Result<()> { println!("Summarizing project at: {}", project_path.display()); let summary = summarize_project(project_path, 16).await?; - println!("Project Summary:\n{}", summary); + println!("Project Summary:\n{:?}", summary); Ok(()) } + +// #[derive(Debug, Serialize)] +// struct ChatCompletionRequest { +// model: String, +// messages: Vec, +// stream: bool, +// } +// +// #[derive(Debug, Deserialize)] +// struct ChatCompletionChunk { +// choices: Vec, +// } + +// #[derive(Debug, Deserialize)] +// struct Choice { +// delta: Delta, +// } + +// #[derive(Debug, Deserialize)] +// struct Delta { +// content: Option, +// } + +// pub struct GroqClient { +// client: Client, +// api_key: String, +// } + +// impl GroqClient { +// pub fn new(api_key: String) -> Self { +// Self { +// client: Client::new(), +// api_key, +// } +// } + +// async fn stream_completion( +// &self, +// model: String, +// messages: Vec, +// ) -> Result> { +// let (tx, rx) = mpsc::channel(100); + +// let request = ChatCompletionRequest { +// model, +// messages, +// stream: true, +// }; + +// let response = self +// .client +// .post("https://api.groq.com/openai/v1/chat/completions") +// .header("Authorization", format!("Bearer {}", self.api_key)) +// .json(&request) +// .send() +// .await?; + +// if !response.status().is_success() { +// return Err(anyhow!( +// "error streaming completion: {:?}", +// response.text().await? +// )); +// } + +// tokio::spawn(async move { +// let mut stream = response.bytes_stream(); +// while let Some(chunk) = stream.next().await { +// if let Ok(chunk) = chunk { +// if let Ok(text) = String::from_utf8(chunk.to_vec()) { +// for line in text.lines() { +// if line.starts_with("data: ") && line != "data: [DONE]" { +// if let Ok(mut chunk) = +// serde_json::from_str::(&line[6..]) +// { +// if let Some(content) = +// chunk.choices.pop().and_then(|choice| choice.delta.content) +// { +// let _ = tx.send(content).await; +// } +// } +// } +// } +// } +// } +// } +// }); + +// Ok(rx) +// } +// }