Compare commits
3 Commits
ex
...
assertion-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
878312a812 | ||
|
|
6dc61cefb0 | ||
|
|
88008c940b |
@@ -1,16 +1,19 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::BTreeSet;
|
||||
use std::fmt::Write;
|
||||
use std::fmt::{self};
|
||||
|
||||
use crate::example::AssertionGroupId;
|
||||
|
||||
#[derive(Default, Debug, Serialize, Deserialize, Clone)]
|
||||
pub struct AssertionsReport {
|
||||
pub ran: Vec<RanAssertion>,
|
||||
pub max: Option<usize>,
|
||||
pub ran: Vec<Assertion>,
|
||||
pub groups: BTreeSet<AssertionGroupId>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||
pub struct RanAssertion {
|
||||
pub id: String,
|
||||
pub struct Assertion {
|
||||
pub group_id: AssertionGroupId,
|
||||
pub result: Result<RanAssertionResult, String>,
|
||||
}
|
||||
|
||||
@@ -21,19 +24,12 @@ pub struct RanAssertionResult {
|
||||
}
|
||||
|
||||
impl AssertionsReport {
|
||||
pub fn new(max: Option<usize>) -> Self {
|
||||
AssertionsReport {
|
||||
ran: Vec::new(),
|
||||
max,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.ran.is_empty()
|
||||
}
|
||||
|
||||
pub fn total_count(&self) -> usize {
|
||||
self.run_count().max(self.max.unwrap_or(0))
|
||||
self.run_count()
|
||||
}
|
||||
|
||||
pub fn run_count(&self) -> usize {
|
||||
@@ -91,7 +87,7 @@ pub fn display_error_row(f: &mut String, round: usize, error: String) -> fmt::Re
|
||||
)
|
||||
}
|
||||
|
||||
pub fn display_table_row(f: &mut String, round: usize, assertion: &RanAssertion) -> fmt::Result {
|
||||
pub fn display_table_row(f: &mut String, round: usize, assertion: &Assertion) -> fmt::Result {
|
||||
let result = match &assertion.result {
|
||||
Ok(result) if result.passed => "\x1b[32m✔︎ Passed\x1b[0m",
|
||||
Ok(_) => "\x1b[31m✗ Failed\x1b[0m",
|
||||
@@ -102,7 +98,7 @@ pub fn display_table_row(f: &mut String, round: usize, assertion: &RanAssertion)
|
||||
f,
|
||||
"│ {:^ROUND_WIDTH$} │ {:<ASSERTIONS_WIDTH$} │ {:>RESULTS_WIDTH$} │",
|
||||
round,
|
||||
truncate(&assertion.id, ASSERTIONS_WIDTH),
|
||||
truncate(&assertion.group_id.to_string(), ASSERTIONS_WIDTH),
|
||||
result
|
||||
)
|
||||
}
|
||||
|
||||
@@ -6,7 +6,9 @@ mod instance;
|
||||
mod tool_metrics;
|
||||
|
||||
use assertions::display_error_row;
|
||||
use example::ExampleMetadata;
|
||||
use instance::{ExampleInstance, JudgeOutput, RunOutput, run_git};
|
||||
use serde::Serialize;
|
||||
pub(crate) use tool_metrics::*;
|
||||
|
||||
use ::fs::RealFs;
|
||||
@@ -619,7 +621,7 @@ pub fn git_branch_for_path(repo_path: &Path) -> String {
|
||||
}
|
||||
|
||||
async fn judge_example(
|
||||
example: ExampleInstance,
|
||||
instance: ExampleInstance,
|
||||
model: Arc<dyn LanguageModel>,
|
||||
zed_commit_sha: &str,
|
||||
zed_branch_name: &str,
|
||||
@@ -628,35 +630,41 @@ async fn judge_example(
|
||||
enable_telemetry: bool,
|
||||
cx: &AsyncApp,
|
||||
) -> JudgeOutput {
|
||||
let judge_output = example.judge(model.clone(), &run_output, cx).await;
|
||||
let judge_output = instance.judge(model.clone(), &run_output, cx).await;
|
||||
|
||||
let evaluated_example = EvaluatedExample {
|
||||
example: instance.example.meta(),
|
||||
run: RunMetadata {
|
||||
zed_commit_sha: zed_commit_sha.to_string(),
|
||||
zed_branch_name: zed_branch_name.to_string(),
|
||||
run_id: run_id.to_string(),
|
||||
},
|
||||
run_output: run_output.clone(),
|
||||
judge_output: judge_output.clone(),
|
||||
};
|
||||
|
||||
if enable_telemetry {
|
||||
telemetry::event!(
|
||||
"Agent Example Evaluated",
|
||||
zed_commit_sha = zed_commit_sha,
|
||||
zed_branch_name = zed_branch_name,
|
||||
run_id = run_id,
|
||||
example_name = example.name.clone(),
|
||||
example_repetition = example.repetition,
|
||||
diff_evaluation = judge_output.diff.clone(),
|
||||
thread_evaluation = judge_output.thread.clone(),
|
||||
tool_metrics = run_output.tool_metrics,
|
||||
response_count = run_output.response_count,
|
||||
token_usage = run_output.token_usage,
|
||||
model = model.telemetry_id(),
|
||||
model_provider = model.provider_id().to_string(),
|
||||
repository_url = example.repo_url(),
|
||||
repository_revision = example.revision(),
|
||||
diagnostic_summary_before = run_output.diagnostic_summary_before,
|
||||
diagnostic_summary_after = run_output.diagnostic_summary_after,
|
||||
diagnostics_before = run_output.diagnostics_before,
|
||||
diagnostics_after = run_output.diagnostics_after,
|
||||
);
|
||||
telemetry::event!("Agent Example Evaluated", evaluated_example);
|
||||
}
|
||||
|
||||
judge_output
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct EvaluatedExample {
|
||||
example: ExampleMetadata,
|
||||
run: RunMetadata,
|
||||
run_output: RunOutput,
|
||||
judge_output: JudgeOutput,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct RunMetadata {
|
||||
zed_commit_sha: String,
|
||||
zed_branch_name: String,
|
||||
run_id: String,
|
||||
}
|
||||
|
||||
const HEADER_WIDTH: usize = 65;
|
||||
|
||||
fn print_h1(header: &str) {
|
||||
|
||||
@@ -8,7 +8,7 @@ use std::{
|
||||
|
||||
use crate::{
|
||||
ToolMetrics,
|
||||
assertions::{AssertionsReport, RanAssertion, RanAssertionResult},
|
||||
assertions::{Assertion, AssertionsReport, RanAssertionResult},
|
||||
};
|
||||
use agent::ThreadEvent;
|
||||
use anyhow::{Result, anyhow};
|
||||
@@ -18,6 +18,7 @@ use collections::HashMap;
|
||||
use futures::{FutureExt as _, StreamExt, channel::mpsc, select_biased};
|
||||
use gpui::{AppContext, AsyncApp, Entity};
|
||||
use language_model::{LanguageModel, Role, StopReason};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
pub const THREAD_EVENT_TIMEOUT: Duration = Duration::from_secs(60 * 2);
|
||||
|
||||
@@ -35,20 +36,19 @@ pub trait Example {
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct JudgeAssertion {
|
||||
pub id: String,
|
||||
pub group_id: String,
|
||||
pub description: String,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
#[derive(Clone, Debug, Serialize)]
|
||||
pub struct ExampleMetadata {
|
||||
pub name: String,
|
||||
pub url: String,
|
||||
pub revision: String,
|
||||
pub language_server: Option<LanguageServer>,
|
||||
pub max_assertions: Option<usize>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
#[derive(Clone, Debug, Serialize)]
|
||||
pub struct LanguageServer {
|
||||
pub file_extension: String,
|
||||
pub allow_preexisting_diagnostics: bool,
|
||||
@@ -82,7 +82,6 @@ impl fmt::Display for FailedAssertion {
|
||||
impl Error for FailedAssertion {}
|
||||
|
||||
pub struct ExampleContext {
|
||||
meta: ExampleMetadata,
|
||||
log_prefix: String,
|
||||
agent_thread: Entity<agent::Thread>,
|
||||
app: AsyncApp,
|
||||
@@ -93,16 +92,14 @@ pub struct ExampleContext {
|
||||
|
||||
impl ExampleContext {
|
||||
pub fn new(
|
||||
meta: ExampleMetadata,
|
||||
log_prefix: String,
|
||||
agent_thread: Entity<agent::Thread>,
|
||||
model: Arc<dyn LanguageModel>,
|
||||
app: AsyncApp,
|
||||
) -> Self {
|
||||
let assertions = AssertionsReport::new(meta.max_assertions);
|
||||
let assertions = AssertionsReport::default();
|
||||
|
||||
Self {
|
||||
meta,
|
||||
log_prefix,
|
||||
agent_thread,
|
||||
assertions,
|
||||
@@ -120,60 +117,20 @@ impl ExampleContext {
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
pub fn assert(&mut self, expected: bool, message: impl ToString) -> Result<()> {
|
||||
let message = message.to_string();
|
||||
self.log_assertion(
|
||||
if expected {
|
||||
Ok(())
|
||||
} else {
|
||||
Err(anyhow::Error::from(FailedAssertion(message.clone())))
|
||||
},
|
||||
message,
|
||||
)
|
||||
pub fn assertion(&mut self, key: impl Into<String>) -> AssertionGroupId {
|
||||
let group_id = AssertionGroupId(key.into());
|
||||
self.assertions.groups.insert(group_id.clone());
|
||||
group_id
|
||||
}
|
||||
|
||||
pub fn assert_some<T>(&mut self, option: Option<T>, message: impl ToString) -> Result<T> {
|
||||
let message = message.to_string();
|
||||
self.log_assertion(
|
||||
match option {
|
||||
Some(value) => Ok(value),
|
||||
None => Err(anyhow::Error::from(FailedAssertion(message.clone()))),
|
||||
},
|
||||
message,
|
||||
)
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub fn assert_eq<T: PartialEq + Debug>(
|
||||
fn log_assertion<T>(
|
||||
&mut self,
|
||||
left: T,
|
||||
right: T,
|
||||
message: impl ToString,
|
||||
) -> Result<()> {
|
||||
let message = message.to_string();
|
||||
self.log_assertion(
|
||||
if left == right {
|
||||
Ok(())
|
||||
} else {
|
||||
println!("{}{:#?} != {:#?}", self.log_prefix, left, right);
|
||||
Err(anyhow::Error::from(FailedAssertion(message.clone())))
|
||||
},
|
||||
message,
|
||||
)
|
||||
}
|
||||
|
||||
fn log_assertion<T>(&mut self, result: Result<T>, message: String) -> Result<T> {
|
||||
if let Some(max) = self.meta.max_assertions {
|
||||
if self.assertions.run_count() > max {
|
||||
return Err(anyhow!(
|
||||
"More assertions were run than the stated max_assertions of {}",
|
||||
max
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
self.assertions.ran.push(RanAssertion {
|
||||
id: message.clone(),
|
||||
group_id: AssertionGroupId,
|
||||
result: Result<T>,
|
||||
message: String,
|
||||
) -> Result<T> {
|
||||
self.assertions.ran.push(Assertion {
|
||||
group_id,
|
||||
result: Ok(RanAssertionResult {
|
||||
analysis: None,
|
||||
passed: result.is_ok(),
|
||||
@@ -355,6 +312,73 @@ impl ExampleContext {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, Clone, Ord, PartialOrd, PartialEq, Eq)]
|
||||
pub struct AssertionGroupId(pub String);
|
||||
|
||||
impl AssertionGroupId {
|
||||
pub fn assert(
|
||||
&self,
|
||||
expected: bool,
|
||||
message: impl ToString,
|
||||
cx: &mut ExampleContext,
|
||||
) -> Result<()> {
|
||||
let message = message.to_string();
|
||||
cx.log_assertion(
|
||||
self.clone(),
|
||||
if expected {
|
||||
Ok(())
|
||||
} else {
|
||||
Err(anyhow::Error::from(FailedAssertion(message.clone())))
|
||||
},
|
||||
message,
|
||||
)
|
||||
}
|
||||
|
||||
pub fn assert_some<T>(
|
||||
&self,
|
||||
option: Option<T>,
|
||||
message: impl ToString,
|
||||
cx: &mut ExampleContext,
|
||||
) -> Result<T> {
|
||||
let message = message.to_string();
|
||||
cx.log_assertion(
|
||||
self.clone(),
|
||||
match option {
|
||||
Some(value) => Ok(value),
|
||||
None => Err(anyhow::Error::from(FailedAssertion(message.clone()))),
|
||||
},
|
||||
message,
|
||||
)
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub fn assert_eq<T: PartialEq + Debug>(
|
||||
&self,
|
||||
left: T,
|
||||
right: T,
|
||||
message: impl ToString,
|
||||
cx: &mut ExampleContext,
|
||||
) -> Result<()> {
|
||||
let message = message.to_string();
|
||||
cx.log_assertion(
|
||||
self.clone(),
|
||||
if left == right {
|
||||
Ok(())
|
||||
} else {
|
||||
println!("{}{:#?} != {:#?}", cx.log_prefix, left, right);
|
||||
Err(anyhow::Error::from(FailedAssertion(message.clone())))
|
||||
},
|
||||
message,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Display for AssertionGroupId {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
fmt::Display::fmt(&self.0, f)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct Response {
|
||||
messages: Vec<Message>,
|
||||
@@ -367,6 +391,7 @@ impl Response {
|
||||
|
||||
pub fn expect_tool(
|
||||
&self,
|
||||
group_id: AssertionGroupId,
|
||||
tool_name: &'static str,
|
||||
cx: &mut ExampleContext,
|
||||
) -> Result<&ToolUse> {
|
||||
@@ -375,7 +400,7 @@ impl Response {
|
||||
.iter()
|
||||
.find(|tool_use| tool_use.name == tool_name)
|
||||
});
|
||||
cx.assert_some(result, format!("called `{}`", tool_name))
|
||||
group_id.assert_some(result, format!("called `{}`", tool_name), cx)
|
||||
}
|
||||
|
||||
pub fn tool_uses(&self) -> impl Iterator<Item = &ToolUse> {
|
||||
|
||||
@@ -19,11 +19,15 @@ impl Example for AddArgToTraitMethod {
|
||||
file_extension: "rs".to_string(),
|
||||
allow_preexisting_diagnostics: false,
|
||||
}),
|
||||
max_assertions: None,
|
||||
}
|
||||
}
|
||||
|
||||
async fn conversation(&self, cx: &mut ExampleContext) -> Result<()> {
|
||||
let read_before_edit = cx.assertion("read_before_edit");
|
||||
let added_any_param = cx.assertion("added_any_param");
|
||||
let added_unused_param = cx.assertion("added_unused_param");
|
||||
let added_used_param_to_batch_tool = cx.assertion("added_used_param_to_batch_tool");
|
||||
|
||||
const FILENAME: &str = "assistant_tool.rs";
|
||||
cx.push_user_message(format!(
|
||||
r#"
|
||||
@@ -52,14 +56,16 @@ impl Example for AddArgToTraitMethod {
|
||||
}
|
||||
"edit_file" => {
|
||||
if let Ok(input) = tool_use.parse_input::<EditFileToolInput>() {
|
||||
cx.assert(
|
||||
read_files.contains(input.path.to_str().unwrap()),
|
||||
format!(
|
||||
"Read before edit: {}",
|
||||
&input.path.file_stem().unwrap().to_str().unwrap()
|
||||
),
|
||||
)
|
||||
.ok();
|
||||
read_before_edit
|
||||
.assert(
|
||||
read_files.contains(input.path.to_str().unwrap()),
|
||||
format!(
|
||||
"Read before edit: {}",
|
||||
&input.path.file_stem().unwrap().to_str().unwrap()
|
||||
),
|
||||
cx,
|
||||
)
|
||||
.ok();
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
@@ -106,10 +112,16 @@ impl Example for AddArgToTraitMethod {
|
||||
edits.has_added_line(" window: Option<gpui::AnyWindowHandle>,\n")
|
||||
});
|
||||
|
||||
cx.assert(ignored || uningored, format!("Argument: {}", tool_name))
|
||||
added_any_param
|
||||
.assert(
|
||||
ignored || uningored,
|
||||
format!("Argument: {}", tool_name),
|
||||
cx,
|
||||
)
|
||||
.ok();
|
||||
|
||||
cx.assert(ignored, format!("`_` prefix: {}", tool_name))
|
||||
added_unused_param
|
||||
.assert(ignored, format!("`_` prefix: {}", tool_name), cx)
|
||||
.ok();
|
||||
}
|
||||
|
||||
@@ -117,13 +129,15 @@ impl Example for AddArgToTraitMethod {
|
||||
|
||||
let batch_tool_edits = edits.get(Path::new("crates/assistant_tools/src/batch_tool.rs"));
|
||||
|
||||
cx.assert(
|
||||
batch_tool_edits.map_or(false, |edits| {
|
||||
edits.has_added_line(" window: Option<gpui::AnyWindowHandle>,\n")
|
||||
}),
|
||||
"Argument: batch_tool",
|
||||
)
|
||||
.ok();
|
||||
added_used_param_to_batch_tool
|
||||
.assert(
|
||||
batch_tool_edits.map_or(false, |edits| {
|
||||
edits.has_added_line(" window: Option<gpui::AnyWindowHandle>,\n")
|
||||
}),
|
||||
"Argument: batch_tool",
|
||||
cx,
|
||||
)
|
||||
.ok();
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@@ -131,13 +145,13 @@ impl Example for AddArgToTraitMethod {
|
||||
fn diff_assertions(&self) -> Vec<JudgeAssertion> {
|
||||
vec![
|
||||
JudgeAssertion {
|
||||
id: "batch tool passes window to each".to_string(),
|
||||
group_id: "batch tool passes window to each".to_string(),
|
||||
description:
|
||||
"batch_tool is modified to pass a clone of the window to each tool it calls."
|
||||
.to_string(),
|
||||
},
|
||||
JudgeAssertion {
|
||||
id: "tool tests updated".to_string(),
|
||||
group_id: "tool tests updated".to_string(),
|
||||
description:
|
||||
"tool tests are updated to pass the new `window` argument (`None` is ok)."
|
||||
.to_string(),
|
||||
|
||||
@@ -15,11 +15,14 @@ impl Example for FileSearchExample {
|
||||
url: "https://github.com/zed-industries/zed.git".to_string(),
|
||||
revision: "03ecb88fe30794873f191ddb728f597935b3101c".to_string(),
|
||||
language_server: None,
|
||||
max_assertions: Some(4),
|
||||
}
|
||||
}
|
||||
|
||||
async fn conversation(&self, cx: &mut ExampleContext) -> Result<()> {
|
||||
let ends_with_filename = cx.assertion("ends_with_filename");
|
||||
let correct_glob = cx.assertion("correct_glob");
|
||||
let used_path_search = cx.assertion("used_path_search");
|
||||
|
||||
const FILENAME: &str = "find_replace_file_tool.rs";
|
||||
cx.push_user_message(format!(
|
||||
r#"
|
||||
@@ -32,13 +35,14 @@ impl Example for FileSearchExample {
|
||||
));
|
||||
|
||||
let response = cx.run_turn().await?;
|
||||
let tool_use = response.expect_tool("path_search", cx)?;
|
||||
let tool_use = response.expect_tool(used_path_search, "path_search", cx)?;
|
||||
let input = tool_use.parse_input::<PathSearchToolInput>()?;
|
||||
|
||||
let glob = input.glob;
|
||||
cx.assert(
|
||||
ends_with_filename.assert(
|
||||
glob.ends_with(FILENAME),
|
||||
format!("glob ends with `{FILENAME}`"),
|
||||
cx,
|
||||
)?;
|
||||
|
||||
let without_filename = glob.replace(FILENAME, "");
|
||||
@@ -46,7 +50,7 @@ impl Example for FileSearchExample {
|
||||
.unwrap()
|
||||
.is_match(&without_filename);
|
||||
|
||||
cx.assert(matches, "glob starts with either `**` or `zed`")?;
|
||||
correct_glob.assert(matches, "glob starts with either `**` or `zed`", cx)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -37,14 +37,14 @@ struct DeclarativeExample {
|
||||
impl DeclarativeExample {
|
||||
pub fn load(example_path: &Path) -> Result<Self> {
|
||||
let name = Self::name_from_path(example_path);
|
||||
let base: ExampleToml = toml::from_str(&fs::read_to_string(&example_path)?)?;
|
||||
let toml: ExampleToml = toml::from_str(&fs::read_to_string(&example_path)?)?;
|
||||
|
||||
let language_server = if base.require_lsp {
|
||||
let language_server = if toml.require_lsp {
|
||||
Some(crate::example::LanguageServer {
|
||||
file_extension: base
|
||||
file_extension: toml
|
||||
.language_extension
|
||||
.expect("Language extension is required when require_lsp = true"),
|
||||
allow_preexisting_diagnostics: base.allow_preexisting_diagnostics,
|
||||
allow_preexisting_diagnostics: toml.allow_preexisting_diagnostics,
|
||||
})
|
||||
} else {
|
||||
None
|
||||
@@ -52,24 +52,29 @@ impl DeclarativeExample {
|
||||
|
||||
let metadata = ExampleMetadata {
|
||||
name,
|
||||
url: base.url,
|
||||
revision: base.revision,
|
||||
url: toml.url,
|
||||
revision: toml.revision,
|
||||
language_server,
|
||||
max_assertions: None,
|
||||
};
|
||||
|
||||
Ok(DeclarativeExample {
|
||||
metadata,
|
||||
prompt: base.prompt,
|
||||
thread_assertions: base
|
||||
prompt: toml.prompt,
|
||||
thread_assertions: toml
|
||||
.thread_assertions
|
||||
.into_iter()
|
||||
.map(|(id, description)| JudgeAssertion { id, description })
|
||||
.map(|(id, description)| JudgeAssertion {
|
||||
group_id: id,
|
||||
description,
|
||||
})
|
||||
.collect(),
|
||||
diff_assertions: base
|
||||
diff_assertions: toml
|
||||
.diff_assertions
|
||||
.into_iter()
|
||||
.map(|(id, description)| JudgeAssertion { id, description })
|
||||
.map(|(id, description)| JudgeAssertion {
|
||||
group_id: id,
|
||||
description,
|
||||
})
|
||||
.collect(),
|
||||
})
|
||||
}
|
||||
|
||||
@@ -29,15 +29,15 @@ use util::ResultExt as _;
|
||||
use util::command::new_smol_command;
|
||||
use util::markdown::MarkdownString;
|
||||
|
||||
use crate::assertions::{AssertionsReport, RanAssertion, RanAssertionResult};
|
||||
use crate::example::{Example, ExampleContext, FailedAssertion, JudgeAssertion};
|
||||
use crate::assertions::{Assertion, AssertionsReport, RanAssertionResult};
|
||||
use crate::example::{AssertionGroupId, Example, ExampleContext, FailedAssertion, JudgeAssertion};
|
||||
use crate::{AgentAppState, ToolMetrics};
|
||||
|
||||
pub const ZED_REPO_URL: &str = "https://github.com/zed-industries/zed.git";
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct ExampleInstance {
|
||||
pub thread: Rc<dyn Example>,
|
||||
pub example: Rc<dyn Example>,
|
||||
pub name: String,
|
||||
pub run_directory: PathBuf,
|
||||
pub log_prefix: String,
|
||||
@@ -100,7 +100,7 @@ impl ExampleInstance {
|
||||
|
||||
Self {
|
||||
name,
|
||||
thread,
|
||||
example: thread,
|
||||
log_prefix: String::new(),
|
||||
run_directory,
|
||||
repetition,
|
||||
@@ -110,11 +110,7 @@ impl ExampleInstance {
|
||||
}
|
||||
|
||||
pub fn repo_url(&self) -> String {
|
||||
self.thread.meta().url
|
||||
}
|
||||
|
||||
pub fn revision(&self) -> String {
|
||||
self.thread.meta().revision
|
||||
self.example.meta().url
|
||||
}
|
||||
|
||||
pub fn worktree_name(&self) -> String {
|
||||
@@ -132,7 +128,7 @@ impl ExampleInstance {
|
||||
|
||||
/// Set up the example by checking out the specified Git revision
|
||||
pub async fn fetch(&mut self) -> Result<()> {
|
||||
let meta = self.thread.meta();
|
||||
let meta = self.example.meta();
|
||||
|
||||
let revision_exists = run_git(
|
||||
&self.repo_path,
|
||||
@@ -155,7 +151,7 @@ impl ExampleInstance {
|
||||
/// Set up the example by checking out the specified Git revision
|
||||
pub async fn setup(&mut self) -> Result<()> {
|
||||
let worktree_path = self.worktree_path();
|
||||
let meta = self.thread.meta();
|
||||
let meta = self.example.meta();
|
||||
if worktree_path.is_dir() {
|
||||
println!("{}Resetting existing worktree", self.log_prefix);
|
||||
|
||||
@@ -194,7 +190,7 @@ impl ExampleInstance {
|
||||
pub fn worktree_path(&self) -> PathBuf {
|
||||
self.worktrees_dir
|
||||
.join(self.worktree_name())
|
||||
.join(self.thread.meta().repo_name())
|
||||
.join(self.example.meta().repo_name())
|
||||
}
|
||||
|
||||
pub fn run(
|
||||
@@ -220,7 +216,7 @@ impl ExampleInstance {
|
||||
let tools = cx.new(|_| ToolWorkingSet::default());
|
||||
let thread_store =
|
||||
ThreadStore::load(project.clone(), tools, app_state.prompt_builder.clone(), cx);
|
||||
let meta = self.thread.meta();
|
||||
let meta = self.example.meta();
|
||||
let this = self.clone();
|
||||
|
||||
cx.spawn(async move |cx| {
|
||||
@@ -353,8 +349,8 @@ impl ExampleInstance {
|
||||
});
|
||||
})?;
|
||||
|
||||
let mut example_cx = ExampleContext::new(meta.clone(), this.log_prefix.clone(), thread.clone(), model.clone(), cx.clone());
|
||||
let result = this.thread.conversation(&mut example_cx).await;
|
||||
let mut example_cx = ExampleContext::new(this.log_prefix.clone(), thread.clone(), model.clone(), cx.clone());
|
||||
let result = this.example.conversation(&mut example_cx).await;
|
||||
|
||||
if let Err(err) = result {
|
||||
if !err.is::<FailedAssertion>() {
|
||||
@@ -428,7 +424,7 @@ impl ExampleInstance {
|
||||
let worktree_path = self.worktree_path();
|
||||
run_git(&worktree_path, &["add", "."]).await?;
|
||||
let mut diff_args = vec!["diff", "--staged"];
|
||||
if self.thread.meta().url == ZED_REPO_URL {
|
||||
if self.example.meta().url == ZED_REPO_URL {
|
||||
diff_args.push(":(exclude).rules");
|
||||
}
|
||||
run_git(&worktree_path, &diff_args).await
|
||||
@@ -469,7 +465,7 @@ impl ExampleInstance {
|
||||
run_output: &RunOutput,
|
||||
cx: &AsyncApp,
|
||||
) -> (String, AssertionsReport) {
|
||||
let diff_assertions = self.thread.diff_assertions();
|
||||
let diff_assertions = self.example.diff_assertions();
|
||||
|
||||
if diff_assertions.is_empty() {
|
||||
return (
|
||||
@@ -516,7 +512,7 @@ impl ExampleInstance {
|
||||
run_output: &RunOutput,
|
||||
cx: &AsyncApp,
|
||||
) -> (String, AssertionsReport) {
|
||||
let thread_assertions = self.thread.thread_assertions();
|
||||
let thread_assertions = self.example.thread_assertions();
|
||||
|
||||
if thread_assertions.is_empty() {
|
||||
return (
|
||||
@@ -591,15 +587,15 @@ impl ExampleInstance {
|
||||
};
|
||||
|
||||
if result.is_ok() {
|
||||
println!("{}✅ {}", log_prefix, assertion.id);
|
||||
println!("{}✅ {}", log_prefix, assertion.group_id);
|
||||
} else {
|
||||
println!("{}❌ {}", log_prefix, assertion.id);
|
||||
println!("{}❌ {}", log_prefix, assertion.group_id);
|
||||
}
|
||||
|
||||
(
|
||||
response,
|
||||
RanAssertion {
|
||||
id: assertion.id,
|
||||
Assertion {
|
||||
group_id: AssertionGroupId(assertion.group_id),
|
||||
result,
|
||||
},
|
||||
)
|
||||
@@ -610,7 +606,7 @@ impl ExampleInstance {
|
||||
let mut report = AssertionsReport::default();
|
||||
|
||||
for (response, assertion) in future::join_all(assertions).await {
|
||||
writeln!(&mut responses, "# {}", assertion.id).unwrap();
|
||||
writeln!(&mut responses, "# {}", assertion.group_id).unwrap();
|
||||
writeln!(&mut responses, "{}\n\n", response).unwrap();
|
||||
report.ran.push(assertion);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user