Compare commits

...

3 Commits

Author SHA1 Message Date
Nathan Sobo
878312a812 Fix warnings
Co-authored-by: Max Brunsfeld <maxbrunsfeld@gmail.com>
2025-04-23 15:05:54 -06:00
Nathan Sobo
6dc61cefb0 Send serialized EvaluatedExample structs to Snowflake
Co-authored-by: Max Brunsfeld <maxbrunsfeld@gmail.com>
2025-04-23 14:53:31 -06:00
Nathan Sobo
88008c940b Register assertion groups in programmatic tests
This will let us track how many assertions were even run out of the
total we may expect.

Co-authored-by: Max Brunsfeld <maxbrunsfeld@gmail.com>
2025-04-23 14:44:06 -06:00
7 changed files with 205 additions and 157 deletions

View File

@@ -1,16 +1,19 @@
use serde::{Deserialize, Serialize};
use std::collections::BTreeSet;
use std::fmt::Write;
use std::fmt::{self};
use crate::example::AssertionGroupId;
#[derive(Default, Debug, Serialize, Deserialize, Clone)]
pub struct AssertionsReport {
pub ran: Vec<RanAssertion>,
pub max: Option<usize>,
pub ran: Vec<Assertion>,
pub groups: BTreeSet<AssertionGroupId>,
}
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct RanAssertion {
pub id: String,
pub struct Assertion {
pub group_id: AssertionGroupId,
pub result: Result<RanAssertionResult, String>,
}
@@ -21,19 +24,12 @@ pub struct RanAssertionResult {
}
impl AssertionsReport {
pub fn new(max: Option<usize>) -> Self {
AssertionsReport {
ran: Vec::new(),
max,
}
}
pub fn is_empty(&self) -> bool {
self.ran.is_empty()
}
pub fn total_count(&self) -> usize {
self.run_count().max(self.max.unwrap_or(0))
self.run_count()
}
pub fn run_count(&self) -> usize {
@@ -91,7 +87,7 @@ pub fn display_error_row(f: &mut String, round: usize, error: String) -> fmt::Re
)
}
pub fn display_table_row(f: &mut String, round: usize, assertion: &RanAssertion) -> fmt::Result {
pub fn display_table_row(f: &mut String, round: usize, assertion: &Assertion) -> fmt::Result {
let result = match &assertion.result {
Ok(result) if result.passed => "\x1b[32m✔ Passed\x1b[0m",
Ok(_) => "\x1b[31m✗ Failed\x1b[0m",
@@ -102,7 +98,7 @@ pub fn display_table_row(f: &mut String, round: usize, assertion: &RanAssertion)
f,
"│ {:^ROUND_WIDTH$} │ {:<ASSERTIONS_WIDTH$} │ {:>RESULTS_WIDTH$} │",
round,
truncate(&assertion.id, ASSERTIONS_WIDTH),
truncate(&assertion.group_id.to_string(), ASSERTIONS_WIDTH),
result
)
}

View File

@@ -6,7 +6,9 @@ mod instance;
mod tool_metrics;
use assertions::display_error_row;
use example::ExampleMetadata;
use instance::{ExampleInstance, JudgeOutput, RunOutput, run_git};
use serde::Serialize;
pub(crate) use tool_metrics::*;
use ::fs::RealFs;
@@ -619,7 +621,7 @@ pub fn git_branch_for_path(repo_path: &Path) -> String {
}
async fn judge_example(
example: ExampleInstance,
instance: ExampleInstance,
model: Arc<dyn LanguageModel>,
zed_commit_sha: &str,
zed_branch_name: &str,
@@ -628,35 +630,41 @@ async fn judge_example(
enable_telemetry: bool,
cx: &AsyncApp,
) -> JudgeOutput {
let judge_output = example.judge(model.clone(), &run_output, cx).await;
let judge_output = instance.judge(model.clone(), &run_output, cx).await;
let evaluated_example = EvaluatedExample {
example: instance.example.meta(),
run: RunMetadata {
zed_commit_sha: zed_commit_sha.to_string(),
zed_branch_name: zed_branch_name.to_string(),
run_id: run_id.to_string(),
},
run_output: run_output.clone(),
judge_output: judge_output.clone(),
};
if enable_telemetry {
telemetry::event!(
"Agent Example Evaluated",
zed_commit_sha = zed_commit_sha,
zed_branch_name = zed_branch_name,
run_id = run_id,
example_name = example.name.clone(),
example_repetition = example.repetition,
diff_evaluation = judge_output.diff.clone(),
thread_evaluation = judge_output.thread.clone(),
tool_metrics = run_output.tool_metrics,
response_count = run_output.response_count,
token_usage = run_output.token_usage,
model = model.telemetry_id(),
model_provider = model.provider_id().to_string(),
repository_url = example.repo_url(),
repository_revision = example.revision(),
diagnostic_summary_before = run_output.diagnostic_summary_before,
diagnostic_summary_after = run_output.diagnostic_summary_after,
diagnostics_before = run_output.diagnostics_before,
diagnostics_after = run_output.diagnostics_after,
);
telemetry::event!("Agent Example Evaluated", evaluated_example);
}
judge_output
}
#[derive(Serialize)]
struct EvaluatedExample {
example: ExampleMetadata,
run: RunMetadata,
run_output: RunOutput,
judge_output: JudgeOutput,
}
#[derive(Serialize)]
struct RunMetadata {
zed_commit_sha: String,
zed_branch_name: String,
run_id: String,
}
const HEADER_WIDTH: usize = 65;
fn print_h1(header: &str) {

View File

@@ -8,7 +8,7 @@ use std::{
use crate::{
ToolMetrics,
assertions::{AssertionsReport, RanAssertion, RanAssertionResult},
assertions::{Assertion, AssertionsReport, RanAssertionResult},
};
use agent::ThreadEvent;
use anyhow::{Result, anyhow};
@@ -18,6 +18,7 @@ use collections::HashMap;
use futures::{FutureExt as _, StreamExt, channel::mpsc, select_biased};
use gpui::{AppContext, AsyncApp, Entity};
use language_model::{LanguageModel, Role, StopReason};
use serde::{Deserialize, Serialize};
pub const THREAD_EVENT_TIMEOUT: Duration = Duration::from_secs(60 * 2);
@@ -35,20 +36,19 @@ pub trait Example {
#[derive(Clone, Debug)]
pub struct JudgeAssertion {
pub id: String,
pub group_id: String,
pub description: String,
}
#[derive(Clone, Debug)]
#[derive(Clone, Debug, Serialize)]
pub struct ExampleMetadata {
pub name: String,
pub url: String,
pub revision: String,
pub language_server: Option<LanguageServer>,
pub max_assertions: Option<usize>,
}
#[derive(Clone, Debug)]
#[derive(Clone, Debug, Serialize)]
pub struct LanguageServer {
pub file_extension: String,
pub allow_preexisting_diagnostics: bool,
@@ -82,7 +82,6 @@ impl fmt::Display for FailedAssertion {
impl Error for FailedAssertion {}
pub struct ExampleContext {
meta: ExampleMetadata,
log_prefix: String,
agent_thread: Entity<agent::Thread>,
app: AsyncApp,
@@ -93,16 +92,14 @@ pub struct ExampleContext {
impl ExampleContext {
pub fn new(
meta: ExampleMetadata,
log_prefix: String,
agent_thread: Entity<agent::Thread>,
model: Arc<dyn LanguageModel>,
app: AsyncApp,
) -> Self {
let assertions = AssertionsReport::new(meta.max_assertions);
let assertions = AssertionsReport::default();
Self {
meta,
log_prefix,
agent_thread,
assertions,
@@ -120,60 +117,20 @@ impl ExampleContext {
.unwrap();
}
pub fn assert(&mut self, expected: bool, message: impl ToString) -> Result<()> {
let message = message.to_string();
self.log_assertion(
if expected {
Ok(())
} else {
Err(anyhow::Error::from(FailedAssertion(message.clone())))
},
message,
)
pub fn assertion(&mut self, key: impl Into<String>) -> AssertionGroupId {
let group_id = AssertionGroupId(key.into());
self.assertions.groups.insert(group_id.clone());
group_id
}
pub fn assert_some<T>(&mut self, option: Option<T>, message: impl ToString) -> Result<T> {
let message = message.to_string();
self.log_assertion(
match option {
Some(value) => Ok(value),
None => Err(anyhow::Error::from(FailedAssertion(message.clone()))),
},
message,
)
}
#[allow(dead_code)]
pub fn assert_eq<T: PartialEq + Debug>(
fn log_assertion<T>(
&mut self,
left: T,
right: T,
message: impl ToString,
) -> Result<()> {
let message = message.to_string();
self.log_assertion(
if left == right {
Ok(())
} else {
println!("{}{:#?} != {:#?}", self.log_prefix, left, right);
Err(anyhow::Error::from(FailedAssertion(message.clone())))
},
message,
)
}
fn log_assertion<T>(&mut self, result: Result<T>, message: String) -> Result<T> {
if let Some(max) = self.meta.max_assertions {
if self.assertions.run_count() > max {
return Err(anyhow!(
"More assertions were run than the stated max_assertions of {}",
max
));
}
}
self.assertions.ran.push(RanAssertion {
id: message.clone(),
group_id: AssertionGroupId,
result: Result<T>,
message: String,
) -> Result<T> {
self.assertions.ran.push(Assertion {
group_id,
result: Ok(RanAssertionResult {
analysis: None,
passed: result.is_ok(),
@@ -355,6 +312,73 @@ impl ExampleContext {
}
}
#[derive(Debug, Serialize, Deserialize, Clone, Ord, PartialOrd, PartialEq, Eq)]
pub struct AssertionGroupId(pub String);
impl AssertionGroupId {
pub fn assert(
&self,
expected: bool,
message: impl ToString,
cx: &mut ExampleContext,
) -> Result<()> {
let message = message.to_string();
cx.log_assertion(
self.clone(),
if expected {
Ok(())
} else {
Err(anyhow::Error::from(FailedAssertion(message.clone())))
},
message,
)
}
pub fn assert_some<T>(
&self,
option: Option<T>,
message: impl ToString,
cx: &mut ExampleContext,
) -> Result<T> {
let message = message.to_string();
cx.log_assertion(
self.clone(),
match option {
Some(value) => Ok(value),
None => Err(anyhow::Error::from(FailedAssertion(message.clone()))),
},
message,
)
}
#[allow(dead_code)]
pub fn assert_eq<T: PartialEq + Debug>(
&self,
left: T,
right: T,
message: impl ToString,
cx: &mut ExampleContext,
) -> Result<()> {
let message = message.to_string();
cx.log_assertion(
self.clone(),
if left == right {
Ok(())
} else {
println!("{}{:#?} != {:#?}", cx.log_prefix, left, right);
Err(anyhow::Error::from(FailedAssertion(message.clone())))
},
message,
)
}
}
impl fmt::Display for AssertionGroupId {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
fmt::Display::fmt(&self.0, f)
}
}
#[derive(Debug)]
pub struct Response {
messages: Vec<Message>,
@@ -367,6 +391,7 @@ impl Response {
pub fn expect_tool(
&self,
group_id: AssertionGroupId,
tool_name: &'static str,
cx: &mut ExampleContext,
) -> Result<&ToolUse> {
@@ -375,7 +400,7 @@ impl Response {
.iter()
.find(|tool_use| tool_use.name == tool_name)
});
cx.assert_some(result, format!("called `{}`", tool_name))
group_id.assert_some(result, format!("called `{}`", tool_name), cx)
}
pub fn tool_uses(&self) -> impl Iterator<Item = &ToolUse> {

View File

@@ -19,11 +19,15 @@ impl Example for AddArgToTraitMethod {
file_extension: "rs".to_string(),
allow_preexisting_diagnostics: false,
}),
max_assertions: None,
}
}
async fn conversation(&self, cx: &mut ExampleContext) -> Result<()> {
let read_before_edit = cx.assertion("read_before_edit");
let added_any_param = cx.assertion("added_any_param");
let added_unused_param = cx.assertion("added_unused_param");
let added_used_param_to_batch_tool = cx.assertion("added_used_param_to_batch_tool");
const FILENAME: &str = "assistant_tool.rs";
cx.push_user_message(format!(
r#"
@@ -52,14 +56,16 @@ impl Example for AddArgToTraitMethod {
}
"edit_file" => {
if let Ok(input) = tool_use.parse_input::<EditFileToolInput>() {
cx.assert(
read_files.contains(input.path.to_str().unwrap()),
format!(
"Read before edit: {}",
&input.path.file_stem().unwrap().to_str().unwrap()
),
)
.ok();
read_before_edit
.assert(
read_files.contains(input.path.to_str().unwrap()),
format!(
"Read before edit: {}",
&input.path.file_stem().unwrap().to_str().unwrap()
),
cx,
)
.ok();
}
}
_ => {}
@@ -106,10 +112,16 @@ impl Example for AddArgToTraitMethod {
edits.has_added_line(" window: Option<gpui::AnyWindowHandle>,\n")
});
cx.assert(ignored || uningored, format!("Argument: {}", tool_name))
added_any_param
.assert(
ignored || uningored,
format!("Argument: {}", tool_name),
cx,
)
.ok();
cx.assert(ignored, format!("`_` prefix: {}", tool_name))
added_unused_param
.assert(ignored, format!("`_` prefix: {}", tool_name), cx)
.ok();
}
@@ -117,13 +129,15 @@ impl Example for AddArgToTraitMethod {
let batch_tool_edits = edits.get(Path::new("crates/assistant_tools/src/batch_tool.rs"));
cx.assert(
batch_tool_edits.map_or(false, |edits| {
edits.has_added_line(" window: Option<gpui::AnyWindowHandle>,\n")
}),
"Argument: batch_tool",
)
.ok();
added_used_param_to_batch_tool
.assert(
batch_tool_edits.map_or(false, |edits| {
edits.has_added_line(" window: Option<gpui::AnyWindowHandle>,\n")
}),
"Argument: batch_tool",
cx,
)
.ok();
Ok(())
}
@@ -131,13 +145,13 @@ impl Example for AddArgToTraitMethod {
fn diff_assertions(&self) -> Vec<JudgeAssertion> {
vec![
JudgeAssertion {
id: "batch tool passes window to each".to_string(),
group_id: "batch tool passes window to each".to_string(),
description:
"batch_tool is modified to pass a clone of the window to each tool it calls."
.to_string(),
},
JudgeAssertion {
id: "tool tests updated".to_string(),
group_id: "tool tests updated".to_string(),
description:
"tool tests are updated to pass the new `window` argument (`None` is ok)."
.to_string(),

View File

@@ -15,11 +15,14 @@ impl Example for FileSearchExample {
url: "https://github.com/zed-industries/zed.git".to_string(),
revision: "03ecb88fe30794873f191ddb728f597935b3101c".to_string(),
language_server: None,
max_assertions: Some(4),
}
}
async fn conversation(&self, cx: &mut ExampleContext) -> Result<()> {
let ends_with_filename = cx.assertion("ends_with_filename");
let correct_glob = cx.assertion("correct_glob");
let used_path_search = cx.assertion("used_path_search");
const FILENAME: &str = "find_replace_file_tool.rs";
cx.push_user_message(format!(
r#"
@@ -32,13 +35,14 @@ impl Example for FileSearchExample {
));
let response = cx.run_turn().await?;
let tool_use = response.expect_tool("path_search", cx)?;
let tool_use = response.expect_tool(used_path_search, "path_search", cx)?;
let input = tool_use.parse_input::<PathSearchToolInput>()?;
let glob = input.glob;
cx.assert(
ends_with_filename.assert(
glob.ends_with(FILENAME),
format!("glob ends with `{FILENAME}`"),
cx,
)?;
let without_filename = glob.replace(FILENAME, "");
@@ -46,7 +50,7 @@ impl Example for FileSearchExample {
.unwrap()
.is_match(&without_filename);
cx.assert(matches, "glob starts with either `**` or `zed`")?;
correct_glob.assert(matches, "glob starts with either `**` or `zed`", cx)?;
Ok(())
}

View File

@@ -37,14 +37,14 @@ struct DeclarativeExample {
impl DeclarativeExample {
pub fn load(example_path: &Path) -> Result<Self> {
let name = Self::name_from_path(example_path);
let base: ExampleToml = toml::from_str(&fs::read_to_string(&example_path)?)?;
let toml: ExampleToml = toml::from_str(&fs::read_to_string(&example_path)?)?;
let language_server = if base.require_lsp {
let language_server = if toml.require_lsp {
Some(crate::example::LanguageServer {
file_extension: base
file_extension: toml
.language_extension
.expect("Language extension is required when require_lsp = true"),
allow_preexisting_diagnostics: base.allow_preexisting_diagnostics,
allow_preexisting_diagnostics: toml.allow_preexisting_diagnostics,
})
} else {
None
@@ -52,24 +52,29 @@ impl DeclarativeExample {
let metadata = ExampleMetadata {
name,
url: base.url,
revision: base.revision,
url: toml.url,
revision: toml.revision,
language_server,
max_assertions: None,
};
Ok(DeclarativeExample {
metadata,
prompt: base.prompt,
thread_assertions: base
prompt: toml.prompt,
thread_assertions: toml
.thread_assertions
.into_iter()
.map(|(id, description)| JudgeAssertion { id, description })
.map(|(id, description)| JudgeAssertion {
group_id: id,
description,
})
.collect(),
diff_assertions: base
diff_assertions: toml
.diff_assertions
.into_iter()
.map(|(id, description)| JudgeAssertion { id, description })
.map(|(id, description)| JudgeAssertion {
group_id: id,
description,
})
.collect(),
})
}

View File

@@ -29,15 +29,15 @@ use util::ResultExt as _;
use util::command::new_smol_command;
use util::markdown::MarkdownString;
use crate::assertions::{AssertionsReport, RanAssertion, RanAssertionResult};
use crate::example::{Example, ExampleContext, FailedAssertion, JudgeAssertion};
use crate::assertions::{Assertion, AssertionsReport, RanAssertionResult};
use crate::example::{AssertionGroupId, Example, ExampleContext, FailedAssertion, JudgeAssertion};
use crate::{AgentAppState, ToolMetrics};
pub const ZED_REPO_URL: &str = "https://github.com/zed-industries/zed.git";
#[derive(Clone)]
pub struct ExampleInstance {
pub thread: Rc<dyn Example>,
pub example: Rc<dyn Example>,
pub name: String,
pub run_directory: PathBuf,
pub log_prefix: String,
@@ -100,7 +100,7 @@ impl ExampleInstance {
Self {
name,
thread,
example: thread,
log_prefix: String::new(),
run_directory,
repetition,
@@ -110,11 +110,7 @@ impl ExampleInstance {
}
pub fn repo_url(&self) -> String {
self.thread.meta().url
}
pub fn revision(&self) -> String {
self.thread.meta().revision
self.example.meta().url
}
pub fn worktree_name(&self) -> String {
@@ -132,7 +128,7 @@ impl ExampleInstance {
/// Set up the example by checking out the specified Git revision
pub async fn fetch(&mut self) -> Result<()> {
let meta = self.thread.meta();
let meta = self.example.meta();
let revision_exists = run_git(
&self.repo_path,
@@ -155,7 +151,7 @@ impl ExampleInstance {
/// Set up the example by checking out the specified Git revision
pub async fn setup(&mut self) -> Result<()> {
let worktree_path = self.worktree_path();
let meta = self.thread.meta();
let meta = self.example.meta();
if worktree_path.is_dir() {
println!("{}Resetting existing worktree", self.log_prefix);
@@ -194,7 +190,7 @@ impl ExampleInstance {
pub fn worktree_path(&self) -> PathBuf {
self.worktrees_dir
.join(self.worktree_name())
.join(self.thread.meta().repo_name())
.join(self.example.meta().repo_name())
}
pub fn run(
@@ -220,7 +216,7 @@ impl ExampleInstance {
let tools = cx.new(|_| ToolWorkingSet::default());
let thread_store =
ThreadStore::load(project.clone(), tools, app_state.prompt_builder.clone(), cx);
let meta = self.thread.meta();
let meta = self.example.meta();
let this = self.clone();
cx.spawn(async move |cx| {
@@ -353,8 +349,8 @@ impl ExampleInstance {
});
})?;
let mut example_cx = ExampleContext::new(meta.clone(), this.log_prefix.clone(), thread.clone(), model.clone(), cx.clone());
let result = this.thread.conversation(&mut example_cx).await;
let mut example_cx = ExampleContext::new(this.log_prefix.clone(), thread.clone(), model.clone(), cx.clone());
let result = this.example.conversation(&mut example_cx).await;
if let Err(err) = result {
if !err.is::<FailedAssertion>() {
@@ -428,7 +424,7 @@ impl ExampleInstance {
let worktree_path = self.worktree_path();
run_git(&worktree_path, &["add", "."]).await?;
let mut diff_args = vec!["diff", "--staged"];
if self.thread.meta().url == ZED_REPO_URL {
if self.example.meta().url == ZED_REPO_URL {
diff_args.push(":(exclude).rules");
}
run_git(&worktree_path, &diff_args).await
@@ -469,7 +465,7 @@ impl ExampleInstance {
run_output: &RunOutput,
cx: &AsyncApp,
) -> (String, AssertionsReport) {
let diff_assertions = self.thread.diff_assertions();
let diff_assertions = self.example.diff_assertions();
if diff_assertions.is_empty() {
return (
@@ -516,7 +512,7 @@ impl ExampleInstance {
run_output: &RunOutput,
cx: &AsyncApp,
) -> (String, AssertionsReport) {
let thread_assertions = self.thread.thread_assertions();
let thread_assertions = self.example.thread_assertions();
if thread_assertions.is_empty() {
return (
@@ -591,15 +587,15 @@ impl ExampleInstance {
};
if result.is_ok() {
println!("{}{}", log_prefix, assertion.id);
println!("{}{}", log_prefix, assertion.group_id);
} else {
println!("{}{}", log_prefix, assertion.id);
println!("{}{}", log_prefix, assertion.group_id);
}
(
response,
RanAssertion {
id: assertion.id,
Assertion {
group_id: AssertionGroupId(assertion.group_id),
result,
},
)
@@ -610,7 +606,7 @@ impl ExampleInstance {
let mut report = AssertionsReport::default();
for (response, assertion) in future::join_all(assertions).await {
writeln!(&mut responses, "# {}", assertion.id).unwrap();
writeln!(&mut responses, "# {}", assertion.group_id).unwrap();
writeln!(&mut responses, "{}\n\n", response).unwrap();
report.ran.push(assertion);
}