Checkpoint
This commit is contained in:
4
Cargo.lock
generated
4
Cargo.lock
generated
@@ -6675,13 +6675,13 @@ version = "0.1.0"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"futures 0.3.28",
|
||||
"project",
|
||||
"ignore",
|
||||
"indicatif",
|
||||
"reqwest 0.12.5",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"tokenizers",
|
||||
"tokio",
|
||||
"walkdir",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
||||
@@ -12,10 +12,10 @@ path = "src/miner.rs"
|
||||
[dependencies]
|
||||
anyhow.workspace = true
|
||||
futures.workspace = true
|
||||
project.workspace = true
|
||||
ignore.workspace = true
|
||||
indicatif = "0.17.8"
|
||||
reqwest = { version = "0.12.5", features = ["json", "stream"] }
|
||||
serde.workspace = true
|
||||
serde_json.workspace = true
|
||||
tokenizers = { version = "0.19.1", features = ["http"] }
|
||||
tokio.workspace = true
|
||||
walkdir = "2.5.0"
|
||||
|
||||
@@ -1,22 +1,17 @@
|
||||
use anyhow::Result;
|
||||
use futures::{Stream, StreamExt};
|
||||
use anyhow::{anyhow, Result};
|
||||
use futures::StreamExt;
|
||||
use indicatif::{MultiProgress, ProgressBar, ProgressStyle};
|
||||
use reqwest::Client;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::{HashMap, VecDeque};
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::sync::Arc;
|
||||
use serde::Serialize;
|
||||
use std::{
|
||||
collections::{BTreeMap, HashMap, VecDeque},
|
||||
path::{Path, PathBuf},
|
||||
sync::Arc,
|
||||
};
|
||||
use tokenizers::tokenizer::Tokenizer;
|
||||
use tokenizers::FromPretrainedParameters;
|
||||
use tokio::sync::mpsc;
|
||||
use tokio::sync::Mutex;
|
||||
use walkdir::WalkDir;
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
struct ChatCompletionRequest {
|
||||
model: String,
|
||||
messages: Vec<Message>,
|
||||
stream: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
struct Message {
|
||||
@@ -24,71 +19,54 @@ struct Message {
|
||||
content: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct ChatCompletionChunk {
|
||||
choices: Vec<Choice>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct Choice {
|
||||
delta: Delta,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct Delta {
|
||||
content: Option<String>,
|
||||
}
|
||||
|
||||
pub struct GroqClient {
|
||||
pub struct OllamaClient {
|
||||
client: Client,
|
||||
api_key: String,
|
||||
base_url: String,
|
||||
}
|
||||
|
||||
impl GroqClient {
|
||||
pub fn new(api_key: String) -> Self {
|
||||
impl OllamaClient {
|
||||
pub fn new(base_url: String) -> Self {
|
||||
Self {
|
||||
client: Client::new(),
|
||||
api_key,
|
||||
base_url,
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn stream_completion(
|
||||
async fn stream_completion(
|
||||
&self,
|
||||
model: String,
|
||||
messages: Vec<Message>,
|
||||
) -> Result<mpsc::Receiver<String>> {
|
||||
let (tx, rx) = mpsc::channel(100);
|
||||
|
||||
let request = ChatCompletionRequest {
|
||||
model,
|
||||
messages,
|
||||
stream: true,
|
||||
};
|
||||
let request = serde_json::json!({
|
||||
"model": model,
|
||||
"messages": messages,
|
||||
"stream": true,
|
||||
});
|
||||
|
||||
let response = self
|
||||
.client
|
||||
.post("https://api.groq.com/openai/v1/chat/completions")
|
||||
.header("Authorization", format!("Bearer {}", self.api_key))
|
||||
.post(format!("{}/api/chat", self.base_url))
|
||||
.json(&request)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
if !response.status().is_success() {
|
||||
return Err(anyhow!(
|
||||
"error streaming completion: {:?}",
|
||||
response.text().await?
|
||||
));
|
||||
}
|
||||
|
||||
tokio::spawn(async move {
|
||||
let mut stream = response.bytes_stream();
|
||||
while let Some(chunk) = stream.next().await {
|
||||
if let Ok(chunk) = chunk {
|
||||
if let Ok(text) = String::from_utf8(chunk.to_vec()) {
|
||||
for line in text.lines() {
|
||||
if line.starts_with("data: ") && line != "data: [DONE]" {
|
||||
if let Ok(mut chunk) =
|
||||
serde_json::from_str::<ChatCompletionChunk>(&line[6..])
|
||||
{
|
||||
if let Some(content) =
|
||||
chunk.choices.pop().unwrap().delta.content
|
||||
{
|
||||
let _ = tx.send(content).await;
|
||||
}
|
||||
}
|
||||
if let Ok(response) = serde_json::from_str::<serde_json::Value>(&text) {
|
||||
if let Some(content) = response["message"]["content"].as_str() {
|
||||
let _ = tx.send(content.to_string()).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -100,15 +78,8 @@ impl GroqClient {
|
||||
}
|
||||
}
|
||||
|
||||
// Below we perform a bottom up traversal over each worktree in the project.
|
||||
// We push an entry for each file and directory into a queue.
|
||||
// We then read from that queue with N workers in a tokio thread pool.
|
||||
// For each file, we perform a summarization with a prompt.
|
||||
// For each directory, we combine the summaries of all its files with a prompt.
|
||||
// If a file is too big, truncate it at the max tokens of mixtral model, 32k.
|
||||
// Use the tokenizers crate to estimate token counts
|
||||
|
||||
const MAX_TOKENS: usize = 32_000;
|
||||
const CHUNK_SIZE: usize = 16_000;
|
||||
const OVERLAP: usize = 2_000;
|
||||
|
||||
#[derive(Debug)]
|
||||
enum Entry {
|
||||
@@ -116,11 +87,11 @@ enum Entry {
|
||||
Directory(PathBuf),
|
||||
}
|
||||
|
||||
async fn summarize_project(root: &Path, num_workers: usize) -> Result<String> {
|
||||
async fn summarize_project(root: &Path, num_workers: usize) -> Result<BTreeMap<PathBuf, String>> {
|
||||
let tokenizer = Tokenizer::from_pretrained(
|
||||
"mistralai/Mixtral-8x7B-v0.1",
|
||||
"Qwen/Qwen2-0.5B",
|
||||
Some(FromPretrainedParameters {
|
||||
revision: String::new(),
|
||||
revision: "main".into(),
|
||||
user_agent: HashMap::default(),
|
||||
auth_token: Some(
|
||||
std::env::var("HUGGINGFACE_API_TOKEN").expect("HUGGINGFACE_API_TOKEN not set"),
|
||||
@@ -128,24 +99,44 @@ async fn summarize_project(root: &Path, num_workers: usize) -> Result<String> {
|
||||
}),
|
||||
)
|
||||
.unwrap();
|
||||
let client = Arc::new(GroqClient::new(std::env::var("GROQ_API_KEY")?));
|
||||
let client = Arc::new(OllamaClient::new("http://localhost:11434".into()));
|
||||
let queue = Arc::new(Mutex::new(VecDeque::new()));
|
||||
|
||||
let multi_progress = Arc::new(MultiProgress::new());
|
||||
let overall_progress = multi_progress.add(ProgressBar::new_spinner());
|
||||
overall_progress.set_style(
|
||||
ProgressStyle::default_spinner()
|
||||
.template("{spinner:.green} {msg}")
|
||||
.unwrap(),
|
||||
);
|
||||
overall_progress.set_message("Summarizing project...");
|
||||
|
||||
// Populate the queue with files and directories
|
||||
for entry in WalkDir::new(root)
|
||||
.min_depth(1)
|
||||
.into_iter()
|
||||
.filter_map(|e| e.ok())
|
||||
{
|
||||
let path = entry.path().to_owned();
|
||||
if entry.file_type().is_dir() {
|
||||
queue.lock().await.push_back(Entry::Directory(path));
|
||||
} else {
|
||||
queue.lock().await.push_back(Entry::File(path));
|
||||
let mut walker = ignore::WalkBuilder::new(root)
|
||||
.hidden(true)
|
||||
.ignore(true)
|
||||
.build();
|
||||
while let Some(entry) = walker.next() {
|
||||
if let Ok(entry) = entry {
|
||||
let path = entry.path().to_owned();
|
||||
if entry.file_type().map_or(false, |ft| ft.is_dir()) {
|
||||
queue.lock().await.push_back(Entry::Directory(path));
|
||||
} else {
|
||||
queue.lock().await.push_back(Entry::File(path));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let summaries = Arc::new(Mutex::new(HashMap::new()));
|
||||
let total_entries = queue.lock().await.len();
|
||||
let progress_bar = multi_progress.add(ProgressBar::new(total_entries as u64));
|
||||
progress_bar.set_style(
|
||||
ProgressStyle::default_bar()
|
||||
.template("[{elapsed_precise}] {bar:40.cyan/blue} {pos}/{len} {msg}")
|
||||
.unwrap()
|
||||
.progress_chars("##-"),
|
||||
);
|
||||
|
||||
let summaries = Arc::new(Mutex::new(BTreeMap::new()));
|
||||
|
||||
let workers: Vec<_> = (0..num_workers)
|
||||
.map(|_| {
|
||||
@@ -153,41 +144,67 @@ async fn summarize_project(root: &Path, num_workers: usize) -> Result<String> {
|
||||
let client = Arc::clone(&client);
|
||||
let summaries = Arc::clone(&summaries);
|
||||
let tokenizer = tokenizer.clone();
|
||||
let progress_bar = progress_bar.clone();
|
||||
|
||||
tokio::spawn(async move {
|
||||
while let Some(entry) = queue.lock().await.pop_front() {
|
||||
loop {
|
||||
let mut queue_lock = queue.lock().await;
|
||||
let Some(entry) = queue_lock.pop_front() else {
|
||||
break;
|
||||
};
|
||||
|
||||
match entry {
|
||||
Entry::File(path) => {
|
||||
let content = tokio::fs::read_to_string(&path).await?;
|
||||
let truncated_content =
|
||||
truncate_to_max_tokens(&content, &tokenizer, MAX_TOKENS);
|
||||
let summary = summarize_file(&client, &truncated_content).await?;
|
||||
drop(queue_lock);
|
||||
let summary = async {
|
||||
let content = tokio::fs::read_to_string(&path).await?;
|
||||
let chunks =
|
||||
split_into_chunks(&content, &tokenizer, CHUNK_SIZE, OVERLAP);
|
||||
let chunk_summaries = summarize_chunks(&client, &chunks).await?;
|
||||
combine_summaries(&client, &chunk_summaries, true).await
|
||||
};
|
||||
|
||||
let summary = summary
|
||||
.await
|
||||
.unwrap_or_else(|_| "path could not be summarized".into());
|
||||
summaries.lock().await.insert(path, summary);
|
||||
progress_bar.inc(1);
|
||||
}
|
||||
Entry::Directory(path) => {
|
||||
let mut dir_summaries = Vec::new();
|
||||
let mut all_children_summarized = true;
|
||||
for entry in path.read_dir()? {
|
||||
let dir_walker = ignore::WalkBuilder::new(&path)
|
||||
.hidden(true)
|
||||
.ignore(true)
|
||||
.max_depth(Some(1))
|
||||
.build();
|
||||
for entry in dir_walker {
|
||||
if let Ok(entry) = entry {
|
||||
if let Some(summary) = summaries.lock().await.get(&entry.path())
|
||||
{
|
||||
dir_summaries.push(summary.clone());
|
||||
} else {
|
||||
all_children_summarized = false;
|
||||
break;
|
||||
if entry.path() != path {
|
||||
if let Some(summary) =
|
||||
summaries.lock().await.get(entry.path())
|
||||
{
|
||||
dir_summaries.push(summary.clone());
|
||||
} else {
|
||||
all_children_summarized = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if all_children_summarized {
|
||||
drop(queue_lock);
|
||||
let combined_summary =
|
||||
combine_summaries(&client, &dir_summaries).await?;
|
||||
combine_summaries(&client, &dir_summaries, false).await?;
|
||||
summaries.lock().await.insert(path, combined_summary);
|
||||
progress_bar.inc(1);
|
||||
} else {
|
||||
queue.lock().await.push_back(Entry::Directory(path));
|
||||
queue_lock.push_back(Entry::Directory(path));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok::<_, anyhow::Error>(())
|
||||
})
|
||||
})
|
||||
@@ -197,28 +214,73 @@ async fn summarize_project(root: &Path, num_workers: usize) -> Result<String> {
|
||||
worker.await??;
|
||||
}
|
||||
|
||||
let summaries = summaries.lock().await;
|
||||
Ok(summaries.get(root).cloned().unwrap_or_default())
|
||||
progress_bar.finish_with_message("Summarization complete");
|
||||
overall_progress.finish_with_message("Project summarization finished");
|
||||
|
||||
Ok(Arc::try_unwrap(summaries).unwrap().into_inner())
|
||||
}
|
||||
|
||||
fn truncate_to_max_tokens(content: &str, tokenizer: &Tokenizer, max_tokens: usize) -> String {
|
||||
let encoding = tokenizer.encode(content, false).unwrap();
|
||||
if encoding.get_ids().len() <= max_tokens {
|
||||
content.to_string()
|
||||
} else {
|
||||
tokenizer
|
||||
.decode(&encoding.get_ids()[..max_tokens], false)
|
||||
.unwrap()
|
||||
fn split_into_chunks(
|
||||
content: &str,
|
||||
tokenizer: &Tokenizer,
|
||||
chunk_size: usize,
|
||||
overlap: usize,
|
||||
) -> Vec<String> {
|
||||
let mut chunks = Vec::new();
|
||||
let lines: Vec<&str> = content.lines().collect();
|
||||
let mut current_chunk = String::new();
|
||||
let mut current_tokens = 0;
|
||||
|
||||
for line in lines {
|
||||
let line_tokens = tokenizer.encode(line, false).unwrap().get_ids().len();
|
||||
if current_tokens + line_tokens > chunk_size {
|
||||
chunks.push(current_chunk.clone());
|
||||
current_chunk.clear();
|
||||
current_tokens = 0;
|
||||
}
|
||||
current_chunk.push_str(line);
|
||||
current_chunk.push('\n');
|
||||
current_tokens += line_tokens;
|
||||
}
|
||||
|
||||
if !current_chunk.is_empty() {
|
||||
chunks.push(current_chunk);
|
||||
}
|
||||
|
||||
// Add overlap
|
||||
for i in 1..chunks.len() {
|
||||
let overlap_text = chunks[i - 1]
|
||||
.lines()
|
||||
.rev()
|
||||
.take(overlap)
|
||||
.collect::<Vec<_>>()
|
||||
.into_iter()
|
||||
.rev()
|
||||
.collect::<Vec<_>>()
|
||||
.join("\n");
|
||||
chunks[i] = format!("{}\n{}", overlap_text, chunks[i]);
|
||||
}
|
||||
|
||||
chunks
|
||||
}
|
||||
|
||||
async fn summarize_file(client: &GroqClient, content: &str) -> Result<String> {
|
||||
async fn summarize_chunks(client: &OllamaClient, chunks: &[String]) -> Result<Vec<String>> {
|
||||
let mut chunk_summaries = Vec::new();
|
||||
|
||||
for chunk in chunks {
|
||||
let summary = summarize_file(client, chunk).await?;
|
||||
chunk_summaries.push(summary);
|
||||
}
|
||||
|
||||
Ok(chunk_summaries)
|
||||
}
|
||||
|
||||
async fn summarize_file(client: &OllamaClient, content: &str) -> Result<String> {
|
||||
let messages = vec![
|
||||
Message {
|
||||
role: "system".to_string(),
|
||||
content:
|
||||
"You are a code summarization assistant. Provide a brief summary of the given code."
|
||||
.to_string(),
|
||||
"You are a code summarization assistant. Provide a brief summary of the given code chunk, focusing on its main functionality and purpose.".to_string(),
|
||||
},
|
||||
Message {
|
||||
role: "user".to_string(),
|
||||
@@ -227,7 +289,7 @@ async fn summarize_file(client: &GroqClient, content: &str) -> Result<String> {
|
||||
];
|
||||
|
||||
let mut receiver = client
|
||||
.stream_completion("mixtral-8x7b-32768".to_string(), messages)
|
||||
.stream_completion("qwen2:0.5b".to_string(), messages)
|
||||
.await?;
|
||||
|
||||
let mut summary = String::new();
|
||||
@@ -238,12 +300,22 @@ async fn summarize_file(client: &GroqClient, content: &str) -> Result<String> {
|
||||
Ok(summary)
|
||||
}
|
||||
|
||||
async fn combine_summaries(client: &GroqClient, summaries: &[String]) -> Result<String> {
|
||||
async fn combine_summaries(
|
||||
client: &OllamaClient,
|
||||
summaries: &[String],
|
||||
is_chunk: bool,
|
||||
) -> Result<String> {
|
||||
let combined_content = summaries.join("\n\n");
|
||||
let prompt = if is_chunk {
|
||||
"You are a code summarization assistant. Combine the given summaries into a single, coherent summary that captures the overall functionality and structure of the code. Ensure that the final summary is comprehensive and reflects the content as if it was summarized from a single, complete file."
|
||||
} else {
|
||||
"You are a code summarization assistant. Combine the given summaries of different files or directories into a single, coherent summary that captures the overall structure and functionality of the project or directory. Focus on the relationships between different components and the high-level architecture."
|
||||
};
|
||||
|
||||
let messages = vec![
|
||||
Message {
|
||||
role: "system".to_string(),
|
||||
content: "You are a code summarization assistant. Combine the given summaries into a single, coherent summary.".to_string(),
|
||||
content: prompt.to_string(),
|
||||
},
|
||||
Message {
|
||||
role: "user".to_string(),
|
||||
@@ -252,7 +324,7 @@ async fn combine_summaries(client: &GroqClient, summaries: &[String]) -> Result<
|
||||
];
|
||||
|
||||
let mut receiver = client
|
||||
.stream_completion("mixtral-8x7b-32768".to_string(), messages)
|
||||
.stream_completion("qwen2:0.5b".to_string(), messages)
|
||||
.await?;
|
||||
|
||||
let mut combined_summary = String::new();
|
||||
@@ -280,7 +352,97 @@ async fn main() -> Result<()> {
|
||||
|
||||
println!("Summarizing project at: {}", project_path.display());
|
||||
let summary = summarize_project(project_path, 16).await?;
|
||||
println!("Project Summary:\n{}", summary);
|
||||
println!("Project Summary:\n{:?}", summary);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// #[derive(Debug, Serialize)]
|
||||
// struct ChatCompletionRequest {
|
||||
// model: String,
|
||||
// messages: Vec<Message>,
|
||||
// stream: bool,
|
||||
// }
|
||||
//
|
||||
// #[derive(Debug, Deserialize)]
|
||||
// struct ChatCompletionChunk {
|
||||
// choices: Vec<Choice>,
|
||||
// }
|
||||
|
||||
// #[derive(Debug, Deserialize)]
|
||||
// struct Choice {
|
||||
// delta: Delta,
|
||||
// }
|
||||
|
||||
// #[derive(Debug, Deserialize)]
|
||||
// struct Delta {
|
||||
// content: Option<String>,
|
||||
// }
|
||||
|
||||
// pub struct GroqClient {
|
||||
// client: Client,
|
||||
// api_key: String,
|
||||
// }
|
||||
|
||||
// impl GroqClient {
|
||||
// pub fn new(api_key: String) -> Self {
|
||||
// Self {
|
||||
// client: Client::new(),
|
||||
// api_key,
|
||||
// }
|
||||
// }
|
||||
|
||||
// async fn stream_completion(
|
||||
// &self,
|
||||
// model: String,
|
||||
// messages: Vec<Message>,
|
||||
// ) -> Result<mpsc::Receiver<String>> {
|
||||
// let (tx, rx) = mpsc::channel(100);
|
||||
|
||||
// let request = ChatCompletionRequest {
|
||||
// model,
|
||||
// messages,
|
||||
// stream: true,
|
||||
// };
|
||||
|
||||
// let response = self
|
||||
// .client
|
||||
// .post("https://api.groq.com/openai/v1/chat/completions")
|
||||
// .header("Authorization", format!("Bearer {}", self.api_key))
|
||||
// .json(&request)
|
||||
// .send()
|
||||
// .await?;
|
||||
|
||||
// if !response.status().is_success() {
|
||||
// return Err(anyhow!(
|
||||
// "error streaming completion: {:?}",
|
||||
// response.text().await?
|
||||
// ));
|
||||
// }
|
||||
|
||||
// tokio::spawn(async move {
|
||||
// let mut stream = response.bytes_stream();
|
||||
// while let Some(chunk) = stream.next().await {
|
||||
// if let Ok(chunk) = chunk {
|
||||
// if let Ok(text) = String::from_utf8(chunk.to_vec()) {
|
||||
// for line in text.lines() {
|
||||
// if line.starts_with("data: ") && line != "data: [DONE]" {
|
||||
// if let Ok(mut chunk) =
|
||||
// serde_json::from_str::<ChatCompletionChunk>(&line[6..])
|
||||
// {
|
||||
// if let Some(content) =
|
||||
// chunk.choices.pop().and_then(|choice| choice.delta.content)
|
||||
// {
|
||||
// let _ = tx.send(content).await;
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// });
|
||||
|
||||
// Ok(rx)
|
||||
// }
|
||||
// }
|
||||
|
||||
Reference in New Issue
Block a user