Compare commits


1 Commit

Author: Antonio Scandurra
SHA1: 3ca8c7f5a7
Message: WIP: Start on reworking eval types and nomenclature
Date: 2025-04-22 18:26:11 +02:00
2 changed files with 518 additions and 488 deletions


@@ -4,6 +4,7 @@ mod tool_metrics;
pub(crate) use example::*;
use parking_lot::Mutex;
use serde::{Deserialize, Serialize};
pub(crate) use tool_metrics::*;
use ::fs::RealFs;
@@ -16,11 +17,13 @@ use futures::future;
use gpui::http_client::{Uri, read_proxy_from_env};
use gpui::{App, AppContext, Application, AsyncApp, Entity, SemanticVersion, UpdateGlobal};
use gpui_tokio::Tokio;
use language::LanguageRegistry;
use language_model::{ConfiguredModel, LanguageModel, LanguageModelRegistry};
use language::{Diagnostic, LanguageRegistry};
use language_model::{
ConfiguredModel, LanguageModel, LanguageModelRegistry, LanguageModelRequest, TokenUsage,
};
use node_runtime::{NodeBinaryOptions, NodeRuntime};
use project::Project;
use project::project_settings::ProjectSettings;
use project::{DiagnosticSummary, Project};
use prompt_store::PromptBuilder;
use release_channel::AppVersion;
use reqwest_client::ReqwestClient;
@@ -271,25 +274,7 @@ fn main() {
let Some(mut example) = examples.lock().pop_front() else {
break;
};
let result = async {
example.setup().await?;
let run_output = cx
.update(|cx| example.run(model.clone(), app_state.clone(), cx))?
.await?;
let judge_output = judge_example(
example.clone(),
model.clone(),
&zed_commit_sha,
&zed_branch_name,
&run_id,
&run_output,
enable_telemetry,
cx,
)
.await;
anyhow::Ok((run_output, judge_output))
}
.await;
let result = example.evaluate(model.clone(), app_state.clone(), cx).await;
results
.lock()
.entry(example.name.clone())
@@ -414,19 +399,21 @@ fn list_all_examples(examples_dir: &Path) -> Result<Vec<PathBuf>> {
Ok(result_paths)
}
/// Subset of `workspace::AppState` needed by `HeadlessAssistant`, with additional fields.
pub struct AgentAppState {
/// GPUI application state for the eval binary.
pub struct EvalAppState {
pub languages: Arc<LanguageRegistry>,
pub client: Arc<Client>,
pub user_store: Entity<UserStore>,
pub fs: Arc<dyn fs::Fs>,
pub node_runtime: NodeRuntime,
// Additional fields not present in `workspace::AppState`.
pub commit_sha: String,
pub branch_name: String,
pub run_id: String,
pub enable_telemetry: bool,
pub prompt_builder: Arc<PromptBuilder>,
}
pub fn init(cx: &mut App) -> Arc<AgentAppState> {
pub fn init(cx: &mut App) -> Arc<EvalAppState> {
release_channel::init(SemanticVersion::default(), cx);
gpui_tokio::init(cx);
@@ -521,7 +508,7 @@ pub fn init(cx: &mut App) -> Arc<AgentAppState> {
})
.unwrap();
Arc::new(AgentAppState {
Arc::new(EvalAppState {
languages,
client,
user_store,
@@ -569,55 +556,69 @@ pub fn git_branch_for_path(repo_path: &Path) -> String {
}
}
async fn judge_example(
example: Example,
model: Arc<dyn LanguageModel>,
zed_commit_sha: &str,
zed_branch_name: &str,
run_id: &str,
run_output: &RunOutput,
enable_telemetry: bool,
cx: &AsyncApp,
) -> Result<JudgeOutput> {
let judge_output = example.judge(model.clone(), &run_output, cx).await;
let diff_evaluation;
let thread_evaluation;
if let Ok(output) = judge_output.as_ref() {
diff_evaluation = Some(output.diff.clone());
thread_evaluation = Some(output.thread.clone());
} else {
diff_evaluation = None;
thread_evaluation = None;
}
if enable_telemetry {
telemetry::event!(
"Agent Example Evaluated",
zed_commit_sha = zed_commit_sha,
zed_branch_name = zed_branch_name,
run_id = run_id,
example_name = example.name.clone(),
example_repetition = example.repetition,
diff_evaluation = diff_evaluation,
thread_evaluation = thread_evaluation,
tool_metrics = run_output.tool_metrics,
response_count = run_output.response_count,
token_usage = run_output.token_usage,
model = model.telemetry_id(),
model_provider = model.provider_id().to_string(),
repository_url = example.base.url.clone(),
repository_revision = example.base.revision.clone(),
diagnostic_summary_before = run_output.diagnostic_summary_before,
diagnostic_summary_after = run_output.diagnostic_summary_after,
diagnostics_before = run_output.diagnostics_before,
diagnostics_after = run_output.diagnostics_after,
);
}
judge_output
#[derive(Debug, Default, Serialize, Clone)]
pub struct Sample {
pub repository_diff: String,
pub ran_diagnostics_check: bool,
pub diagnostic_summary_before: DiagnosticSummary,
pub diagnostic_summary_after: DiagnosticSummary,
pub diagnostics_before: Option<String>,
pub diagnostics_after: Option<String>,
pub response_count: usize,
pub token_usage: TokenUsage,
pub tool_metrics: ToolMetrics,
pub last_request: LanguageModelRequest,
pub error: Option<SamplingError>,
}
#[derive(Debug, Serialize, Clone)]
struct SamplingError {
message: String,
full_stack: String,
}
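Sampling failures are captured as data on the `Sample` rather than propagated, so a failed run still yields a serializable record. A minimal sketch of the conversion, using a hypothetical `From` impl that is not part of this commit; `full_stack` is only populated when `anyhow` backtraces are enabled (`RUST_BACKTRACE=1`):

```rust
// Hypothetical helper, not in this commit: flatten an anyhow::Error
// into the serializable SamplingError defined above.
impl From<&anyhow::Error> for SamplingError {
    fn from(error: &anyhow::Error) -> Self {
        SamplingError {
            // Top-level message only; causes are not chained here.
            message: error.to_string(),
            // Empty unless backtraces are enabled via RUST_BACKTRACE=1.
            full_stack: error.backtrace().to_string(),
        }
    }
}
```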
#[derive(Clone, Serialize)]
struct Evaluation {
metadata: EvaluationMetadata,
sample: Sample,
diff_evaluation: DiffEvaluation,
thread_evaluation: ThreadEvaluation,
error: Option<String>,
}
#[derive(Clone, Serialize)]
struct EvaluationMetadata {
zed_commit_sha: String,
zed_branch_name: String,
run_id: String,
example_name: String,
example_repetition: usize,
model: String,
model_provider: String,
repository_url: String,
repository_revision: String,
}
#[derive(Clone, Default, Serialize, Deserialize)]
struct DiffEvaluation {
assertions: Vec<EvaluatedAssertion>,
}
#[derive(Clone, Default, Serialize, Deserialize)]
struct ThreadEvaluation {
assertions: Vec<EvaluatedAssertion>,
}
#[derive(Clone, Serialize, Deserialize)]
struct EvaluatedAssertion {
assertion: Assertion,
passed: bool,
analysis: String,
}
#[derive(Clone, Serialize, Deserialize)]
struct Assertion(String);
fn print_header(header: &str) {
println!("\n========================================");
println!("{:^40}", header);
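The new types above describe one self-contained record per run: metadata, the sampled behavior, and both judge verdicts. A minimal sketch of persisting such a record, assuming `serde_json` is available in the crate (the structs already derive `Serialize`); `write_evaluation` is a hypothetical helper, not part of this commit:

```rust
use std::path::Path;

// Hypothetical helper: serialize one Evaluation record to disk.
fn write_evaluation(evaluation: &Evaluation, path: &Path) -> anyhow::Result<()> {
    // Every field of Evaluation derives Serialize, so the whole
    // record round-trips through serde in a single call.
    let json = serde_json::to_string_pretty(evaluation)?;
    std::fs::write(path, json)?;
    Ok(())
}
```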


@@ -1,4 +1,6 @@
use crate::{AgentAppState, ToolMetrics};
use crate::{
DiffEvaluation, EvalAppState, Evaluation, Sample, SamplingError, ThreadEvaluation, ToolMetrics,
};
use agent::{ThreadEvent, ThreadStore};
use anyhow::{Context as _, Result, anyhow};
use assistant_tool::ToolWorkingSet;
@@ -81,45 +83,6 @@ pub struct Example {
worktrees_dir: PathBuf,
}
#[derive(Debug, Serialize, Clone)]
pub struct RunOutput {
pub repository_diff: String,
pub ran_diagnostics_check: bool,
pub diagnostic_summary_before: DiagnosticSummary,
pub diagnostic_summary_after: DiagnosticSummary,
pub diagnostics_before: Option<String>,
pub diagnostics_after: Option<String>,
pub response_count: usize,
pub token_usage: TokenUsage,
pub tool_metrics: ToolMetrics,
pub last_request: LanguageModelRequest,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct JudgeDiffInput {
pub repository_diff: String,
pub criteria: String,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct JudgeThreadInput {
pub messages: String,
pub criteria: String,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct JudgeResponse {
pub analysis: String,
pub passing_criteria: u32,
pub total_criteria: u32,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct JudgeOutput {
pub thread: JudgeResponse,
pub diff: JudgeResponse,
}
impl Example {
/// Load an example from a directory containing base.toml, prompt.md, and criteria.md
pub fn load_from_directory(
@@ -186,6 +149,81 @@ impl Example {
);
}
pub async fn evaluate(
&self,
model: Arc<dyn LanguageModel>,
app_state: Arc<EvalAppState>,
cx: &AsyncApp,
) -> Result<Evaluation> {
let evaluation = async {
self.setup().await?;
let sample = cx
.update(|cx| self.sample(model.clone(), app_state.clone(), cx))?
// `sample` returns a Task<Sample>, so awaiting yields the Sample directly.
.await;
self.judge(model, sample, cx).await
}
.await;
// let judge_output = example.judge(model.clone(), &run_output, cx).await;
// if enable_telemetry {
// telemetry::event!(
// "Agent Example Evaluated",
// zed_commit_sha = zed_commit_sha,
// zed_branch_name = zed_branch_name,
// run_id = run_id,
// example_name = example.name.clone(),
// example_repetition = example.repetition,
// diff_evaluation = diff_evaluation,
// thread_evaluation = thread_evaluation,
// tool_metrics = run_output.tool_metrics,
// response_count = run_output.response_count,
// token_usage = run_output.token_usage,
// model = model.telemetry_id(),
// model_provider = model.provider_id().to_string(),
// repository_url = example.base.url.clone(),
// repository_revision = example.base.revision.clone(),
// diagnostic_summary_before = run_output.diagnostic_summary_before,
// diagnostic_summary_after = run_output.diagnostic_summary_after,
// diagnostics_before = run_output.diagnostics_before,
// diagnostics_after = run_output.diagnostics_after,
// error = run_output.error,
// );
evaluation
}
pub async fn judge(
&self,
model: Arc<dyn LanguageModel>,
sample: Sample,
cx: &AsyncApp,
) -> Result<Evaluation> {
let mut output_file = File::create(self.run_directory_path().join("judge.md"))
.expect("failed to create judge.md");
println!("{}Running judge", self.log_prefix);
let thread_task = self.judge_thread(model.clone(), &sample, cx);
let diff_task = self.judge_diff(model.clone(), &sample, cx);
// Both judges are independent model requests, so run them concurrently.
let (diff_evaluation, thread_evaluation) = futures::join!(diff_task, thread_task);
let diff_evaluation = diff_evaluation?;
let thread_evaluation = thread_evaluation?;
writeln!(
&mut output_file,
"# Judgment\n\n## Thread\n\n{}\n\n## Diff\n\n{}",
// Assumes serde_json is available; the evaluation types derive Serialize.
serde_json::to_string_pretty(&thread_evaluation).unwrap_or_default(),
serde_json::to_string_pretty(&diff_evaluation).unwrap_or_default(),
)
.log_err();
Ok(Evaluation {
metadata: todo!(),
sample,
diff_evaluation,
thread_evaluation,
// No sampling or judging error if we got this far.
error: None,
})
}
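Overlapping the two judge calls with `futures::join!` halves the wall-clock wait versus awaiting them back to back. A minimal self-contained sketch of the pattern; `judge_a` and `judge_b` are hypothetical stand-ins for the two judge futures:

```rust
use futures::join;

// Hypothetical stand-ins for the two judge model requests.
async fn judge_a() -> anyhow::Result<u32> { Ok(1) }
async fn judge_b() -> anyhow::Result<u32> { Ok(2) }

async fn judge_both() -> anyhow::Result<(u32, u32)> {
    // join! polls both futures on the current task, so both requests
    // are in flight at once; errors surface after both complete.
    let (a, b) = join!(judge_a(), judge_b());
    Ok((a?, b?))
}
```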
/// Set up the example by checking out the specified Git revision
pub async fn fetch(&mut self) -> Result<()> {
let revision_exists = run_git(
@@ -247,12 +285,12 @@ impl Example {
Ok(())
}
pub fn run(
pub fn sample(
&self,
model: Arc<dyn LanguageModel>,
app_state: Arc<AgentAppState>,
app_state: Arc<EvalAppState>,
cx: &mut App,
) -> Task<Result<RunOutput>> {
) -> Task<Sample> {
let project = Project::local(
app_state.client.clone(),
app_state.node_runtime.clone(),
@@ -272,401 +310,392 @@ impl Example {
ThreadStore::load(project.clone(), tools, app_state.prompt_builder.clone(), cx);
let this = self.clone();
// Placeholder while the sampling body below is ported; `_cx` avoids an unused-variable warning.
let task = cx.spawn(async move |_cx| anyhow::Ok(()));
// cx.spawn(async move |cx| {
// let worktree = worktree.await?;
// // Wait for worktree scan to finish before choosing a file to open.
// worktree
// .update(cx, |worktree, _cx| {
// worktree.as_local().unwrap().scan_complete()
// })?
// .await;
// let lsp = if this.base.require_lsp {
// let language_extension = this.base.language_extension.as_deref().context(
// "language_extension field is required in base.toml when `require_lsp == true`",
// )?;
// // Open a file that matches the language to cause LSP to start.
// let language_file = worktree.read_with(cx, |worktree, _cx| {
// worktree
// .files(false, 0)
// .find_map(|e| {
// if e.path.clone().extension().and_then(|ext| ext.to_str())
// == Some(language_extension)
// {
// Some(ProjectPath {
// worktree_id: worktree.id(),
// path: e.path.clone(),
// })
// } else {
// None
// }
// })
// .context("Failed to find a file for example language")
// })??;
// let open_language_file_buffer_task = project.update(cx, |project, cx| {
// project.open_buffer(language_file.clone(), cx)
// })?;
// let language_file_buffer = open_language_file_buffer_task.await?;
// let lsp_open_handle = project.update(cx, |project, cx| {
// project.register_buffer_with_language_servers(&language_file_buffer, cx)
// })?;
// wait_for_lang_server(&project, &language_file_buffer, this.log_prefix.clone(), cx).await?;
// Some((lsp_open_handle, language_file_buffer))
// } else {
// None
// };
// let diagnostic_summary_before = project.read_with(cx, |project, cx| {
// project.diagnostic_summary(false, cx)
// })?;
// let diagnostics_before = query_lsp_diagnostics(project.clone(), cx).await?;
// if diagnostics_before.is_some() && !this.base.allow_preexisting_diagnostics {
// return Err(anyhow!("Example has pre-existing diagnostics. If you want to run this example regardless, set `allow_preexisting_diagnostics` to `true` in `base.toml`"));
// }
// if std::env::var("ZED_EVAL_SETUP_ONLY").is_ok() {
// return Err(anyhow!("Setup only mode"));
// }
// let example_output_dir = this.run_directory_path();
// let last_diff_file_path = example_output_dir.join("last.diff");
// // Write an empty "last.diff" so that it can be opened in Zed for convenient view of the
// // history using undo/redo.
// std::fs::write(&last_diff_file_path, "")?;
// let thread_store = thread_store.await?;
// let thread =
// thread_store.update(cx, |thread_store, cx| thread_store.create_thread(cx))?;
// let last_request = Rc::new(RefCell::new(None));
// thread.update(cx, |thread, _cx| {
// let mut request_count = 0;
// let last_request = Rc::clone(&last_request);
// let previous_diff = Rc::new(RefCell::new("".to_string()));
// let example_output_dir = example_output_dir.clone();
// let last_diff_file_path = last_diff_file_path.clone();
// let this = this.clone();
// thread.set_request_callback(move |request, response_events| {
// *last_request.borrow_mut() = Some(request.clone());
// request_count += 1;
// let messages_file_path = example_output_dir.join(format!("{request_count}.messages.md"));
// let diff_file_path = example_output_dir.join(format!("{request_count}.diff"));
// let last_messages_file_path = example_output_dir.join("last.messages.md");
// let request_markdown = RequestMarkdown::new(request);
// let response_events_markdown = response_events_to_markdown(response_events);
// let messages = format!("{}\n\n{}", request_markdown.messages, response_events_markdown);
// fs::write(&messages_file_path, messages.clone()).expect("failed to write messages file");
// fs::write(&last_messages_file_path, messages).expect("failed to write last messages file");
// let diff_result = smol::block_on(this.repository_diff());
// match diff_result {
// Ok(diff) => {
// if diff != previous_diff.borrow().clone() {
// fs::write(&diff_file_path, &diff).expect("failed to write diff file");
// fs::write(&last_diff_file_path, &diff).expect("failed to write last diff file");
// *previous_diff.borrow_mut() = diff;
// }
// }
// Err(err) => {
// let error_message = format!("{err:?}");
// fs::write(&diff_file_path, &error_message).expect("failed to write diff error to file");
// fs::write(&last_diff_file_path, &error_message).expect("failed to write last diff file");
// }
// }
// if request_count == 1 {
// let tools_file_path = example_output_dir.join("tools.md");
// fs::write(tools_file_path, request_markdown.tools).expect("failed to write tools file");
// }
// });
// })?;
// let tool_metrics = Arc::new(Mutex::new(ToolMetrics::default()));
// let (thread_event_tx, mut thread_event_rx) = mpsc::unbounded();
// let subscription = cx.subscribe(&thread, move |_thread, event: &ThreadEvent, _cx| {
// thread_event_tx.unbounded_send(event.clone()).log_err();
// });
// let event_handler_task = cx.spawn({
// let log_prefix = this.log_prefix.clone();
// let tool_metrics = tool_metrics.clone();
// let thread = thread.downgrade();
// async move |cx| {
// loop {
// let event = select_biased! {
// event = thread_event_rx.next() => event,
// _ = cx.background_executor().timer(THREAD_EVENT_TIMEOUT).fuse() => {
// return Err(anyhow!("Agentic loop stalled - waited {:?} without any events", THREAD_EVENT_TIMEOUT));
// }
// };
// let Some(event) = event else {
// return Err(anyhow!("ThreadEvent channel ended early"));
// };
// match event {
// ThreadEvent::Stopped(reason) => match reason {
// Ok(StopReason::EndTurn) => {
// return Ok(());
// }
// Ok(StopReason::MaxTokens) => {
// return Err(anyhow!("Exceeded maximum tokens"));
// }
// Ok(StopReason::ToolUse) => {
// if std::env::var("ZED_EVAL_DEBUG").is_ok() {
// println!("{}StopReason: Tool use", log_prefix);
// }
// }
// Err(error) => {
// return Err(anyhow!(error.clone()));
// }
// },
// ThreadEvent::ShowError(thread_error) => {
// break Err(anyhow!(thread_error.clone()));
// }
// ThreadEvent::StreamedAssistantText(_, _)| ThreadEvent::StreamedAssistantThinking(_, _) | ThreadEvent::UsePendingTools { .. } => {
// }
// ThreadEvent::ToolFinished {
// tool_use_id,
// pending_tool_use,
// ..
// } => {
// thread.update(cx, |thread, _cx| {
// if let Some(tool_use) = pending_tool_use {
// let mut tool_metrics = tool_metrics.lock().unwrap();
// if let Some(tool_result) = thread.tool_result(&tool_use_id) {
// let message = if tool_result.is_error {
// format!("TOOL FAILED: {}", tool_use.name)
// } else {
// format!("TOOL FINISHED: {}", tool_use.name)
// };
// println!("{log_prefix}{message}");
// tool_metrics.insert(tool_result.tool_name.clone(), !tool_result.is_error);
// } else {
// let message = format!("TOOL FINISHED WITHOUT RESULT: {}", tool_use.name);
// println!("{log_prefix}{message}");
// tool_metrics.insert(tool_use.name.clone(), true);
// }
// }
// })?;
// }
// ThreadEvent::ToolConfirmationNeeded => {
// panic!("{}Bug: Tool confirmation should not be required in eval", log_prefix);
// },
// ThreadEvent::StreamedToolUse { .. } |
// ThreadEvent::StreamedCompletion |
// ThreadEvent::MessageAdded(_) |
// ThreadEvent::MessageEdited(_) |
// ThreadEvent::MessageDeleted(_) |
// ThreadEvent::SummaryChanged |
// ThreadEvent::SummaryGenerated |
// ThreadEvent::CheckpointChanged |
// ThreadEvent::ReceivedTextChunk |
// ThreadEvent::UsageUpdated(_) => {
// if std::env::var("ZED_EVAL_DEBUG").is_ok() {
// println!("{}Event: {:#?}", log_prefix, event);
// }
// }
// }
// }
// }
// });
// thread.update(cx, |thread, cx| {
// let context = vec![];
// thread.insert_user_message(this.prompt.clone(), context, None, cx);
// thread.send_to_model(model, cx);
// })?;
// event_handler_task.await?;
// println!("{}Stopped", this.log_prefix);
// if let Some((_, language_file_buffer)) = lsp.as_ref() {
// wait_for_lang_server(&project, &language_file_buffer, this.log_prefix.clone(), cx).await?;
// }
// println!("{}Getting repository diff", this.log_prefix);
// let repository_diff = this.repository_diff().await?;
// std::fs::write(last_diff_file_path, &repository_diff)?;
// println!("{}Getting diagnostics", this.log_prefix);
// let diagnostic_summary_after = project.read_with(cx, |project, cx| {
// project.diagnostic_summary(false, cx)
// })?;
// let diagnostics_after = cx
// .update(move |cx| {
// cx.spawn(async move |cx| query_lsp_diagnostics(project, cx).await)
// })?
// .await?;
// println!("{}Got diagnostics", this.log_prefix);
// let Some(last_request) = last_request.borrow_mut().take() else {
// return Err(anyhow!("No requests ran."));
// };
// drop(subscription);
// drop(lsp);
// if let Some(diagnostics_before) = &diagnostics_before {
// fs::write(example_output_dir.join("diagnostics_before.txt"), diagnostics_before)?;
// }
// if let Some(diagnostics_after) = &diagnostics_after {
// fs::write(example_output_dir.join("diagnostics_after.txt"), diagnostics_after)?;
// }
// thread.update(cx, |thread, _cx| {
// let response_count = thread
// .messages()
// .filter(|message| message.role == language_model::Role::Assistant)
// .count();
// Sample {
// repository_diff,
// ran_diagnostics_check: this.base.require_lsp,
// diagnostic_summary_before,
// diagnostic_summary_after,
// diagnostics_before,
// diagnostics_after,
// response_count,
// token_usage: thread.cumulative_token_usage(),
// tool_metrics: tool_metrics.lock().unwrap().clone(),
// last_request,
// }
// })
// })
cx.spawn(async move |cx| {
let worktree = worktree.await?;
// Wait for worktree scan to finish before choosing a file to open.
worktree
.update(cx, |worktree, _cx| {
worktree.as_local().unwrap().scan_complete()
})?
.await;
let lsp = if this.base.require_lsp {
let language_extension = this.base.language_extension.as_deref().context(
"language_extension field is required in base.toml when `require_lsp == true`",
)?;
// Open a file that matches the language to cause LSP to start.
let language_file = worktree.read_with(cx, |worktree, _cx| {
worktree
.files(false, 0)
.find_map(|e| {
if e.path.clone().extension().and_then(|ext| ext.to_str())
== Some(language_extension)
{
Some(ProjectPath {
worktree_id: worktree.id(),
path: e.path.clone(),
})
} else {
None
}
})
.context("Failed to find a file for example language")
})??;
let open_language_file_buffer_task = project.update(cx, |project, cx| {
project.open_buffer(language_file.clone(), cx)
})?;
let language_file_buffer = open_language_file_buffer_task.await?;
let lsp_open_handle = project.update(cx, |project, cx| {
project.register_buffer_with_language_servers(&language_file_buffer, cx)
})?;
wait_for_lang_server(&project, &language_file_buffer, this.log_prefix.clone(), cx).await?;
Some((lsp_open_handle, language_file_buffer))
} else {
None
};
let diagnostic_summary_before = project.read_with(cx, |project, cx| {
project.diagnostic_summary(false, cx)
})?;
let diagnostics_before = query_lsp_diagnostics(project.clone(), cx).await?;
if diagnostics_before.is_some() && !this.base.allow_preexisting_diagnostics {
return Err(anyhow!("Example has pre-existing diagnostics. If you want to run this example regardless, set `allow_preexisting_diagnostics` to `true` in `base.toml`"));
let result = task.await;
Sample {
repository_diff: todo!(),
ran_diagnostics_check: todo!(),
diagnostic_summary_before: todo!(),
diagnostic_summary_after: todo!(),
diagnostics_before: todo!(),
diagnostics_after: todo!(),
response_count: todo!(),
token_usage: todo!(),
tool_metrics: todo!(),
last_request: todo!(),
error: result.err().map(|error| SamplingError {
message: error.to_string(),
full_stack: error.backtrace().to_string(),
}),
}
if std::env::var("ZED_EVAL_SETUP_ONLY").is_ok() {
return Err(anyhow!("Setup only mode"));
}
let example_output_dir = this.run_directory_path();
let last_diff_file_path = example_output_dir.join("last.diff");
// Write an empty "last.diff" so that it can be opened in Zed for convenient view of the
// history using undo/redo.
std::fs::write(&last_diff_file_path, "")?;
let thread_store = thread_store.await?;
let thread =
thread_store.update(cx, |thread_store, cx| thread_store.create_thread(cx))?;
let last_request = Rc::new(RefCell::new(None));
thread.update(cx, |thread, _cx| {
let mut request_count = 0;
let last_request = Rc::clone(&last_request);
let previous_diff = Rc::new(RefCell::new("".to_string()));
let example_output_dir = example_output_dir.clone();
let last_diff_file_path = last_diff_file_path.clone();
let this = this.clone();
thread.set_request_callback(move |request, response_events| {
*last_request.borrow_mut() = Some(request.clone());
request_count += 1;
let messages_file_path = example_output_dir.join(format!("{request_count}.messages.md"));
let diff_file_path = example_output_dir.join(format!("{request_count}.diff"));
let last_messages_file_path = example_output_dir.join("last.messages.md");
let request_markdown = RequestMarkdown::new(request);
let response_events_markdown = response_events_to_markdown(response_events);
let messages = format!("{}\n\n{}", request_markdown.messages, response_events_markdown);
fs::write(&messages_file_path, messages.clone()).expect("failed to write messages file");
fs::write(&last_messages_file_path, messages).expect("failed to write last messages file");
let diff_result = smol::block_on(this.repository_diff());
match diff_result {
Ok(diff) => {
if diff != previous_diff.borrow().clone() {
fs::write(&diff_file_path, &diff).expect("failed to write diff file");
fs::write(&last_diff_file_path, &diff).expect("failed to write last diff file");
*previous_diff.borrow_mut() = diff;
}
}
Err(err) => {
let error_message = format!("{err:?}");
fs::write(&diff_file_path, &error_message).expect("failed to write diff error to file");
fs::write(&last_diff_file_path, &error_message).expect("failed to write last diff file");
}
}
if request_count == 1 {
let tools_file_path = example_output_dir.join("tools.md");
fs::write(tools_file_path, request_markdown.tools).expect("failed to write tools file");
}
});
})?;
let tool_metrics = Arc::new(Mutex::new(ToolMetrics::default()));
let (thread_event_tx, mut thread_event_rx) = mpsc::unbounded();
let subscription = cx.subscribe(&thread, move |_thread, event: &ThreadEvent, _cx| {
thread_event_tx.unbounded_send(event.clone()).log_err();
});
let event_handler_task = cx.spawn({
let log_prefix = this.log_prefix.clone();
let tool_metrics = tool_metrics.clone();
let thread = thread.downgrade();
async move |cx| {
loop {
let event = select_biased! {
event = thread_event_rx.next() => event,
_ = cx.background_executor().timer(THREAD_EVENT_TIMEOUT).fuse() => {
return Err(anyhow!("Agentic loop stalled - waited {:?} without any events", THREAD_EVENT_TIMEOUT));
}
};
let Some(event) = event else {
return Err(anyhow!("ThreadEvent channel ended early"));
};
match event {
ThreadEvent::Stopped(reason) => match reason {
Ok(StopReason::EndTurn) => {
return Ok(());
}
Ok(StopReason::MaxTokens) => {
return Err(anyhow!("Exceeded maximum tokens"));
}
Ok(StopReason::ToolUse) => {
if std::env::var("ZED_EVAL_DEBUG").is_ok() {
println!("{}StopReason: Tool use", log_prefix);
}
}
Err(error) => {
return Err(anyhow!(error.clone()));
}
},
ThreadEvent::ShowError(thread_error) => {
break Err(anyhow!(thread_error.clone()));
}
ThreadEvent::StreamedAssistantText(_, _)| ThreadEvent::StreamedAssistantThinking(_, _) | ThreadEvent::UsePendingTools { .. } => {
}
ThreadEvent::ToolFinished {
tool_use_id,
pending_tool_use,
..
} => {
thread.update(cx, |thread, _cx| {
if let Some(tool_use) = pending_tool_use {
let mut tool_metrics = tool_metrics.lock().unwrap();
if let Some(tool_result) = thread.tool_result(&tool_use_id) {
let message = if tool_result.is_error {
format!("TOOL FAILED: {}", tool_use.name)
} else {
format!("TOOL FINISHED: {}", tool_use.name)
};
println!("{log_prefix}{message}");
tool_metrics.insert(tool_result.tool_name.clone(), !tool_result.is_error);
} else {
let message = format!("TOOL FINISHED WITHOUT RESULT: {}", tool_use.name);
println!("{log_prefix}{message}");
tool_metrics.insert(tool_use.name.clone(), true);
}
}
})?;
}
ThreadEvent::ToolConfirmationNeeded => {
panic!("{}Bug: Tool confirmation should not be required in eval", log_prefix);
},
ThreadEvent::StreamedToolUse { .. } |
ThreadEvent::StreamedCompletion |
ThreadEvent::MessageAdded(_) |
ThreadEvent::MessageEdited(_) |
ThreadEvent::MessageDeleted(_) |
ThreadEvent::SummaryChanged |
ThreadEvent::SummaryGenerated |
ThreadEvent::CheckpointChanged |
ThreadEvent::ReceivedTextChunk |
ThreadEvent::UsageUpdated(_) => {
if std::env::var("ZED_EVAL_DEBUG").is_ok() {
println!("{}Event: {:#?}", log_prefix, event);
}
}
}
}
}
});
thread.update(cx, |thread, cx| {
let context = vec![];
thread.insert_user_message(this.prompt.clone(), context, None, cx);
thread.send_to_model(model, cx);
})?;
event_handler_task.await?;
println!("{}Stopped", this.log_prefix);
if let Some((_, language_file_buffer)) = lsp.as_ref() {
wait_for_lang_server(&project, &language_file_buffer, this.log_prefix.clone(), cx).await?;
}
println!("{}Getting repository diff", this.log_prefix);
let repository_diff = this.repository_diff().await?;
std::fs::write(last_diff_file_path, &repository_diff)?;
println!("{}Getting diagnostics", this.log_prefix);
let diagnostic_summary_after = project.read_with(cx, |project, cx| {
project.diagnostic_summary(false, cx)
})?;
let diagnostics_after = cx
.update(move |cx| {
cx.spawn(async move |cx| query_lsp_diagnostics(project, cx).await)
})?
.await?;
println!("{}Got diagnostics", this.log_prefix);
let Some(last_request) = last_request.borrow_mut().take() else {
return Err(anyhow!("No requests ran."));
};
drop(subscription);
drop(lsp);
if let Some(diagnostics_before) = &diagnostics_before {
fs::write(example_output_dir.join("diagnostics_before.txt"), diagnostics_before)?;
}
if let Some(diagnostics_after) = &diagnostics_after {
fs::write(example_output_dir.join("diagnostics_after.txt"), diagnostics_after)?;
}
thread.update(cx, |thread, _cx| {
let response_count = thread
.messages()
.filter(|message| message.role == language_model::Role::Assistant)
.count();
RunOutput {
repository_diff,
ran_diagnostics_check: this.base.require_lsp,
diagnostic_summary_before,
diagnostic_summary_after,
diagnostics_before,
diagnostics_after,
response_count,
token_usage: thread.cumulative_token_usage(),
tool_metrics: tool_metrics.lock().unwrap().clone(),
last_request,
}
})
})
}
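The event loop retained above (commented out while `sample` is reworked) guards against a stalled agentic loop by racing each incoming event against a timer. A minimal sketch of that pattern, assuming `smol` for the timer where the real code uses GPUI's background executor:

```rust
use std::time::Duration;
use futures::{FutureExt as _, StreamExt as _, channel::mpsc, select_biased};

// Return the next event, or an error if none arrives within `timeout`.
async fn next_event_or_timeout(
    events: &mut mpsc::UnboundedReceiver<String>,
    timeout: Duration,
) -> anyhow::Result<String> {
    select_biased! {
        // UnboundedReceiver is a FusedStream, so `next()` is safe inside select_biased!.
        event = events.next() => event.ok_or_else(|| anyhow::anyhow!("event channel ended early")),
        _ = smol::Timer::after(timeout).fuse() => {
            Err(anyhow::anyhow!("agentic loop stalled: no events for {timeout:?}"))
        }
    }
}
```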
async fn judge_diff(
&self,
model: Arc<dyn LanguageModel>,
run_output: &RunOutput,
run_output: &Sample,
cx: &AsyncApp,
) -> Result<(String, JudgeResponse)> {
let judge_diff_prompt = include_str!("judge_diff_prompt.hbs");
let judge_diff_prompt_name = "judge_diff_prompt";
let mut hbs = Handlebars::new();
hbs.register_template_string(judge_diff_prompt_name, judge_diff_prompt)?;
) -> Result<DiffEvaluation> {
todo!()
// let judge_diff_prompt = include_str!("judge_diff_prompt.hbs");
// let judge_diff_prompt_name = "judge_diff_prompt";
// let mut hbs = Handlebars::new();
// hbs.register_template_string(judge_diff_prompt_name, judge_diff_prompt)?;
let diff_prompt = hbs.render(
judge_diff_prompt_name,
&JudgeDiffInput {
repository_diff: run_output.repository_diff.clone(),
criteria: self.diff_criteria.clone(),
},
)?;
// let diff_prompt = hbs.render(
// judge_diff_prompt_name,
// &JudgeDiffInput {
// repository_diff: run_output.repository_diff.clone(),
// criteria: self.diff_criteria.clone(),
// },
// )?;
let request = LanguageModelRequest {
thread_id: None,
prompt_id: None,
messages: vec![LanguageModelRequestMessage {
role: Role::User,
content: vec![MessageContent::Text(diff_prompt)],
cache: false,
}],
temperature: None,
tools: Vec::new(),
stop: Vec::new(),
};
// let request = LanguageModelRequest {
// thread_id: None,
// prompt_id: None,
// messages: vec![LanguageModelRequestMessage {
// role: Role::User,
// content: vec![MessageContent::Text(diff_prompt)],
// cache: false,
// }],
// temperature: None,
// tools: Vec::new(),
// stop: Vec::new(),
// };
let diff_response = send_language_model_request(model, request, cx).await?;
let diff_output = JudgeResponse::parse(&diff_response)?;
// let diff_response = send_language_model_request(model, request, cx).await?;
// let diff_output = JudgeResponse::parse(&diff_response)?;
println!(
"{}Judge - Diff score: {}%",
self.log_prefix,
diff_output.score()
);
// println!(
// "{}Judge - Diff score: {}%",
// self.log_prefix,
// diff_output.score()
// );
Ok((diff_response, diff_output))
// Ok((diff_response, diff_output))
}
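Both judges render a Handlebars template against a serializable input before sending the result to the model (see the commented-out body above). A minimal sketch of that rendering step, with a hypothetical inline template standing in for `judge_diff_prompt.hbs`:

```rust
use handlebars::Handlebars;
use serde::Serialize;

#[derive(Serialize)]
struct JudgeDiffInput {
    repository_diff: String,
    criteria: String,
}

fn render_judge_prompt(input: &JudgeDiffInput) -> anyhow::Result<String> {
    let mut hbs = Handlebars::new();
    // Inline template for illustration; the real prompt lives in judge_diff_prompt.hbs.
    hbs.register_template_string(
        "judge_diff",
        "Evaluate this diff against the criteria.\n\nCriteria:\n{{criteria}}\n\nDiff:\n{{repository_diff}}",
    )?;
    Ok(hbs.render("judge_diff", input)?)
}
```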
async fn judge_thread(
&self,
model: Arc<dyn LanguageModel>,
run_output: &RunOutput,
run_output: &Sample,
cx: &AsyncApp,
) -> Result<(String, JudgeResponse)> {
let judge_thread_prompt = include_str!("judge_thread_prompt.hbs");
let judge_thread_prompt_name = "judge_thread_prompt";
let mut hbs = Handlebars::new();
hbs.register_template_string(judge_thread_prompt_name, judge_thread_prompt)?;
) -> Result<ThreadEvaluation> {
todo!()
// let judge_thread_prompt = include_str!("judge_thread_prompt.hbs");
// let judge_thread_prompt_name = "judge_thread_prompt";
// let mut hbs = Handlebars::new();
// hbs.register_template_string(judge_thread_prompt_name, judge_thread_prompt)?;
let request_markdown = RequestMarkdown::new(&run_output.last_request);
let thread_prompt = hbs.render(
judge_thread_prompt_name,
&JudgeThreadInput {
messages: request_markdown.messages,
criteria: self.thread_criteria.clone(),
},
)?;
// let request_markdown = RequestMarkdown::new(&run_output.last_request);
// let thread_prompt = hbs.render(
// judge_thread_prompt_name,
// &JudgeThreadInput {
// messages: request_markdown.messages,
// criteria: self.thread_criteria.clone(),
// },
// )?;
let request = LanguageModelRequest {
thread_id: None,
prompt_id: None,
messages: vec![LanguageModelRequestMessage {
role: Role::User,
content: vec![MessageContent::Text(thread_prompt)],
cache: false,
}],
temperature: None,
tools: Vec::new(),
stop: Vec::new(),
};
// let request = LanguageModelRequest {
// thread_id: None,
// prompt_id: None,
// messages: vec![LanguageModelRequestMessage {
// role: Role::User,
// content: vec![MessageContent::Text(thread_prompt)],
// cache: false,
// }],
// temperature: None,
// tools: Vec::new(),
// stop: Vec::new(),
// };
let thread_response = send_language_model_request(model, request, cx).await?;
let thread_output = JudgeResponse::parse(&thread_response)?;
// let thread_response = send_language_model_request(model, request, cx).await?;
// let thread_output = JudgeResponse::parse(&thread_response)?;
println!(
"{}Judge - Thread score: {}%",
self.log_prefix,
thread_output.score()
);
// println!(
// "{}Judge - Thread score: {}%",
// self.log_prefix,
// thread_output.score()
// );
Ok((thread_response, thread_output))
}
pub async fn judge(
&self,
model: Arc<dyn LanguageModel>,
run_output: &RunOutput,
cx: &AsyncApp,
) -> Result<JudgeOutput> {
let mut output_file = File::create(self.run_directory_path().join("judge.md"))
.expect("failed to create judge.md");
println!("{}Running judge", self.log_prefix);
let diff_task = self.judge_diff(model.clone(), &run_output, cx);
let thread_task = self.judge_thread(model.clone(), &run_output, cx);
let (diff_result, thread_result) = futures::join!(diff_task, thread_task);
let (diff_response, diff_output) = diff_result?;
let (thread_response, thread_output) = thread_result?;
writeln!(
&mut output_file,
"# Judgment\n\n## Thread\n\n{thread_response}\n\n## Diff\n\n{diff_response}",
)
.log_err();
Ok(JudgeOutput {
thread: thread_output,
diff: diff_output,
})
// Ok((thread_response, thread_output))
}
async fn repository_diff(&self) -> Result<String> {