Checkpoint

This commit is contained in:
Antonio Scandurra
2024-06-27 10:40:22 +02:00
parent df829e50ea
commit f18e9b073b
3 changed files with 275 additions and 113 deletions

4
Cargo.lock generated
View File

@@ -6675,13 +6675,13 @@ version = "0.1.0"
dependencies = [
"anyhow",
"futures 0.3.28",
"project",
"ignore",
"indicatif",
"reqwest 0.12.5",
"serde",
"serde_json",
"tokenizers",
"tokio",
"walkdir",
]
[[package]]

View File

@@ -12,10 +12,10 @@ path = "src/miner.rs"
[dependencies]
anyhow.workspace = true
futures.workspace = true
project.workspace = true
ignore.workspace = true
indicatif = "0.17.8"
reqwest = { version = "0.12.5", features = ["json", "stream"] }
serde.workspace = true
serde_json.workspace = true
tokenizers = { version = "0.19.1", features = ["http"] }
tokio.workspace = true
walkdir = "2.5.0"

View File

@@ -1,22 +1,17 @@
use anyhow::Result;
use futures::{Stream, StreamExt};
use anyhow::{anyhow, Result};
use futures::StreamExt;
use indicatif::{MultiProgress, ProgressBar, ProgressStyle};
use reqwest::Client;
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, VecDeque};
use std::path::{Path, PathBuf};
use std::sync::Arc;
use serde::Serialize;
use std::{
collections::{BTreeMap, HashMap, VecDeque},
path::{Path, PathBuf},
sync::Arc,
};
use tokenizers::tokenizer::Tokenizer;
use tokenizers::FromPretrainedParameters;
use tokio::sync::mpsc;
use tokio::sync::Mutex;
use walkdir::WalkDir;
#[derive(Debug, Serialize)]
struct ChatCompletionRequest {
model: String,
messages: Vec<Message>,
stream: bool,
}
#[derive(Debug, Serialize)]
struct Message {
@@ -24,71 +19,54 @@ struct Message {
content: String,
}
#[derive(Debug, Deserialize)]
struct ChatCompletionChunk {
choices: Vec<Choice>,
}
#[derive(Debug, Deserialize)]
struct Choice {
delta: Delta,
}
#[derive(Debug, Deserialize)]
struct Delta {
content: Option<String>,
}
pub struct GroqClient {
pub struct OllamaClient {
client: Client,
api_key: String,
base_url: String,
}
impl GroqClient {
pub fn new(api_key: String) -> Self {
impl OllamaClient {
pub fn new(base_url: String) -> Self {
Self {
client: Client::new(),
api_key,
base_url,
}
}
pub async fn stream_completion(
async fn stream_completion(
&self,
model: String,
messages: Vec<Message>,
) -> Result<mpsc::Receiver<String>> {
let (tx, rx) = mpsc::channel(100);
let request = ChatCompletionRequest {
model,
messages,
stream: true,
};
let request = serde_json::json!({
"model": model,
"messages": messages,
"stream": true,
});
let response = self
.client
.post("https://api.groq.com/openai/v1/chat/completions")
.header("Authorization", format!("Bearer {}", self.api_key))
.post(format!("{}/api/chat", self.base_url))
.json(&request)
.send()
.await?;
if !response.status().is_success() {
return Err(anyhow!(
"error streaming completion: {:?}",
response.text().await?
));
}
tokio::spawn(async move {
let mut stream = response.bytes_stream();
while let Some(chunk) = stream.next().await {
if let Ok(chunk) = chunk {
if let Ok(text) = String::from_utf8(chunk.to_vec()) {
for line in text.lines() {
if line.starts_with("data: ") && line != "data: [DONE]" {
if let Ok(mut chunk) =
serde_json::from_str::<ChatCompletionChunk>(&line[6..])
{
if let Some(content) =
chunk.choices.pop().unwrap().delta.content
{
let _ = tx.send(content).await;
}
}
if let Ok(response) = serde_json::from_str::<serde_json::Value>(&text) {
if let Some(content) = response["message"]["content"].as_str() {
let _ = tx.send(content.to_string()).await;
}
}
}
@@ -100,15 +78,8 @@ impl GroqClient {
}
}
// Below we perform a bottom up traversal over each worktree in the project.
// We push an entry for each file and directory into a queue.
// We then read from that queue with N workers in a tokio thread pool.
// For each file, we perform a summarization with a prompt.
// For each directory, we combine the summaries of all its files with a prompt.
// If a file is too big, truncate it at the max tokens of mixtral model, 32k.
// Use the tokenizers crate to estimate token counts
const MAX_TOKENS: usize = 32_000;
const CHUNK_SIZE: usize = 16_000;
const OVERLAP: usize = 2_000;
#[derive(Debug)]
enum Entry {
@@ -116,11 +87,11 @@ enum Entry {
Directory(PathBuf),
}
async fn summarize_project(root: &Path, num_workers: usize) -> Result<String> {
async fn summarize_project(root: &Path, num_workers: usize) -> Result<BTreeMap<PathBuf, String>> {
let tokenizer = Tokenizer::from_pretrained(
"mistralai/Mixtral-8x7B-v0.1",
"Qwen/Qwen2-0.5B",
Some(FromPretrainedParameters {
revision: String::new(),
revision: "main".into(),
user_agent: HashMap::default(),
auth_token: Some(
std::env::var("HUGGINGFACE_API_TOKEN").expect("HUGGINGFACE_API_TOKEN not set"),
@@ -128,24 +99,44 @@ async fn summarize_project(root: &Path, num_workers: usize) -> Result<String> {
}),
)
.unwrap();
let client = Arc::new(GroqClient::new(std::env::var("GROQ_API_KEY")?));
let client = Arc::new(OllamaClient::new("http://localhost:11434".into()));
let queue = Arc::new(Mutex::new(VecDeque::new()));
let multi_progress = Arc::new(MultiProgress::new());
let overall_progress = multi_progress.add(ProgressBar::new_spinner());
overall_progress.set_style(
ProgressStyle::default_spinner()
.template("{spinner:.green} {msg}")
.unwrap(),
);
overall_progress.set_message("Summarizing project...");
// Populate the queue with files and directories
for entry in WalkDir::new(root)
.min_depth(1)
.into_iter()
.filter_map(|e| e.ok())
{
let path = entry.path().to_owned();
if entry.file_type().is_dir() {
queue.lock().await.push_back(Entry::Directory(path));
} else {
queue.lock().await.push_back(Entry::File(path));
let mut walker = ignore::WalkBuilder::new(root)
.hidden(true)
.ignore(true)
.build();
while let Some(entry) = walker.next() {
if let Ok(entry) = entry {
let path = entry.path().to_owned();
if entry.file_type().map_or(false, |ft| ft.is_dir()) {
queue.lock().await.push_back(Entry::Directory(path));
} else {
queue.lock().await.push_back(Entry::File(path));
}
}
}
let summaries = Arc::new(Mutex::new(HashMap::new()));
let total_entries = queue.lock().await.len();
let progress_bar = multi_progress.add(ProgressBar::new(total_entries as u64));
progress_bar.set_style(
ProgressStyle::default_bar()
.template("[{elapsed_precise}] {bar:40.cyan/blue} {pos}/{len} {msg}")
.unwrap()
.progress_chars("##-"),
);
let summaries = Arc::new(Mutex::new(BTreeMap::new()));
let workers: Vec<_> = (0..num_workers)
.map(|_| {
@@ -153,41 +144,67 @@ async fn summarize_project(root: &Path, num_workers: usize) -> Result<String> {
let client = Arc::clone(&client);
let summaries = Arc::clone(&summaries);
let tokenizer = tokenizer.clone();
let progress_bar = progress_bar.clone();
tokio::spawn(async move {
while let Some(entry) = queue.lock().await.pop_front() {
loop {
let mut queue_lock = queue.lock().await;
let Some(entry) = queue_lock.pop_front() else {
break;
};
match entry {
Entry::File(path) => {
let content = tokio::fs::read_to_string(&path).await?;
let truncated_content =
truncate_to_max_tokens(&content, &tokenizer, MAX_TOKENS);
let summary = summarize_file(&client, &truncated_content).await?;
drop(queue_lock);
let summary = async {
let content = tokio::fs::read_to_string(&path).await?;
let chunks =
split_into_chunks(&content, &tokenizer, CHUNK_SIZE, OVERLAP);
let chunk_summaries = summarize_chunks(&client, &chunks).await?;
combine_summaries(&client, &chunk_summaries, true).await
};
let summary = summary
.await
.unwrap_or_else(|_| "path could not be summarized".into());
summaries.lock().await.insert(path, summary);
progress_bar.inc(1);
}
Entry::Directory(path) => {
let mut dir_summaries = Vec::new();
let mut all_children_summarized = true;
for entry in path.read_dir()? {
let dir_walker = ignore::WalkBuilder::new(&path)
.hidden(true)
.ignore(true)
.max_depth(Some(1))
.build();
for entry in dir_walker {
if let Ok(entry) = entry {
if let Some(summary) = summaries.lock().await.get(&entry.path())
{
dir_summaries.push(summary.clone());
} else {
all_children_summarized = false;
break;
if entry.path() != path {
if let Some(summary) =
summaries.lock().await.get(entry.path())
{
dir_summaries.push(summary.clone());
} else {
all_children_summarized = false;
break;
}
}
}
}
if all_children_summarized {
drop(queue_lock);
let combined_summary =
combine_summaries(&client, &dir_summaries).await?;
combine_summaries(&client, &dir_summaries, false).await?;
summaries.lock().await.insert(path, combined_summary);
progress_bar.inc(1);
} else {
queue.lock().await.push_back(Entry::Directory(path));
queue_lock.push_back(Entry::Directory(path));
}
}
}
}
Ok::<_, anyhow::Error>(())
})
})
@@ -197,28 +214,73 @@ async fn summarize_project(root: &Path, num_workers: usize) -> Result<String> {
worker.await??;
}
let summaries = summaries.lock().await;
Ok(summaries.get(root).cloned().unwrap_or_default())
progress_bar.finish_with_message("Summarization complete");
overall_progress.finish_with_message("Project summarization finished");
Ok(Arc::try_unwrap(summaries).unwrap().into_inner())
}
fn truncate_to_max_tokens(content: &str, tokenizer: &Tokenizer, max_tokens: usize) -> String {
let encoding = tokenizer.encode(content, false).unwrap();
if encoding.get_ids().len() <= max_tokens {
content.to_string()
} else {
tokenizer
.decode(&encoding.get_ids()[..max_tokens], false)
.unwrap()
fn split_into_chunks(
content: &str,
tokenizer: &Tokenizer,
chunk_size: usize,
overlap: usize,
) -> Vec<String> {
let mut chunks = Vec::new();
let lines: Vec<&str> = content.lines().collect();
let mut current_chunk = String::new();
let mut current_tokens = 0;
for line in lines {
let line_tokens = tokenizer.encode(line, false).unwrap().get_ids().len();
if current_tokens + line_tokens > chunk_size {
chunks.push(current_chunk.clone());
current_chunk.clear();
current_tokens = 0;
}
current_chunk.push_str(line);
current_chunk.push('\n');
current_tokens += line_tokens;
}
if !current_chunk.is_empty() {
chunks.push(current_chunk);
}
// Add overlap
for i in 1..chunks.len() {
let overlap_text = chunks[i - 1]
.lines()
.rev()
.take(overlap)
.collect::<Vec<_>>()
.into_iter()
.rev()
.collect::<Vec<_>>()
.join("\n");
chunks[i] = format!("{}\n{}", overlap_text, chunks[i]);
}
chunks
}
async fn summarize_file(client: &GroqClient, content: &str) -> Result<String> {
async fn summarize_chunks(client: &OllamaClient, chunks: &[String]) -> Result<Vec<String>> {
let mut chunk_summaries = Vec::new();
for chunk in chunks {
let summary = summarize_file(client, chunk).await?;
chunk_summaries.push(summary);
}
Ok(chunk_summaries)
}
async fn summarize_file(client: &OllamaClient, content: &str) -> Result<String> {
let messages = vec![
Message {
role: "system".to_string(),
content:
"You are a code summarization assistant. Provide a brief summary of the given code."
.to_string(),
"You are a code summarization assistant. Provide a brief summary of the given code chunk, focusing on its main functionality and purpose.".to_string(),
},
Message {
role: "user".to_string(),
@@ -227,7 +289,7 @@ async fn summarize_file(client: &GroqClient, content: &str) -> Result<String> {
];
let mut receiver = client
.stream_completion("mixtral-8x7b-32768".to_string(), messages)
.stream_completion("qwen2:0.5b".to_string(), messages)
.await?;
let mut summary = String::new();
@@ -238,12 +300,22 @@ async fn summarize_file(client: &GroqClient, content: &str) -> Result<String> {
Ok(summary)
}
async fn combine_summaries(client: &GroqClient, summaries: &[String]) -> Result<String> {
async fn combine_summaries(
client: &OllamaClient,
summaries: &[String],
is_chunk: bool,
) -> Result<String> {
let combined_content = summaries.join("\n\n");
let prompt = if is_chunk {
"You are a code summarization assistant. Combine the given summaries into a single, coherent summary that captures the overall functionality and structure of the code. Ensure that the final summary is comprehensive and reflects the content as if it was summarized from a single, complete file."
} else {
"You are a code summarization assistant. Combine the given summaries of different files or directories into a single, coherent summary that captures the overall structure and functionality of the project or directory. Focus on the relationships between different components and the high-level architecture."
};
let messages = vec![
Message {
role: "system".to_string(),
content: "You are a code summarization assistant. Combine the given summaries into a single, coherent summary.".to_string(),
content: prompt.to_string(),
},
Message {
role: "user".to_string(),
@@ -252,7 +324,7 @@ async fn combine_summaries(client: &GroqClient, summaries: &[String]) -> Result<
];
let mut receiver = client
.stream_completion("mixtral-8x7b-32768".to_string(), messages)
.stream_completion("qwen2:0.5b".to_string(), messages)
.await?;
let mut combined_summary = String::new();
@@ -280,7 +352,97 @@ async fn main() -> Result<()> {
println!("Summarizing project at: {}", project_path.display());
let summary = summarize_project(project_path, 16).await?;
println!("Project Summary:\n{}", summary);
println!("Project Summary:\n{:?}", summary);
Ok(())
}
// #[derive(Debug, Serialize)]
// struct ChatCompletionRequest {
// model: String,
// messages: Vec<Message>,
// stream: bool,
// }
//
// #[derive(Debug, Deserialize)]
// struct ChatCompletionChunk {
// choices: Vec<Choice>,
// }
// #[derive(Debug, Deserialize)]
// struct Choice {
// delta: Delta,
// }
// #[derive(Debug, Deserialize)]
// struct Delta {
// content: Option<String>,
// }
// pub struct GroqClient {
// client: Client,
// api_key: String,
// }
// impl GroqClient {
// pub fn new(api_key: String) -> Self {
// Self {
// client: Client::new(),
// api_key,
// }
// }
// async fn stream_completion(
// &self,
// model: String,
// messages: Vec<Message>,
// ) -> Result<mpsc::Receiver<String>> {
// let (tx, rx) = mpsc::channel(100);
// let request = ChatCompletionRequest {
// model,
// messages,
// stream: true,
// };
// let response = self
// .client
// .post("https://api.groq.com/openai/v1/chat/completions")
// .header("Authorization", format!("Bearer {}", self.api_key))
// .json(&request)
// .send()
// .await?;
// if !response.status().is_success() {
// return Err(anyhow!(
// "error streaming completion: {:?}",
// response.text().await?
// ));
// }
// tokio::spawn(async move {
// let mut stream = response.bytes_stream();
// while let Some(chunk) = stream.next().await {
// if let Ok(chunk) = chunk {
// if let Ok(text) = String::from_utf8(chunk.to_vec()) {
// for line in text.lines() {
// if line.starts_with("data: ") && line != "data: [DONE]" {
// if let Ok(mut chunk) =
// serde_json::from_str::<ChatCompletionChunk>(&line[6..])
// {
// if let Some(content) =
// chunk.choices.pop().and_then(|choice| choice.delta.content)
// {
// let _ = tx.send(content).await;
// }
// }
// }
// }
// }
// }
// }
// });
// Ok(rx)
// }
// }