Files
zed/crates/eval/src/example.rs
Max Brunsfeld 03f9cf4414 Represent relative paths using a dedicated, separator-agnostic type (#38744)
Closes https://github.com/zed-industries/zed/issues/38690
Closes #37353

### Background

On Windows, paths are normally separated by `\`, unlike macOS and Linux,
where they are separated by `/`. When editing code in a project that
uses a different path style than your local system (e.g. remoting from
Windows to Linux, using WSL, or collaboration between Windows and Unix
users), the correct separator for a path may differ from the "native"
separator.

Previously, to work around this, Zed converted paths' separators in
numerous places. This was applied to both absolute and relative paths,
leading to incorrect conversions in some cases.

### Solution

Many code paths in Zed use paths that are *relative* to either a
worktree root or a git repository. This PR introduces a dedicated type
for these paths called `RelPath`, which stores the path in the same way
regardless of host platform, and offers `Path`-like manipulation APIs.
RelPath supports *displaying* the path using either separator, so that
we can display paths in a style that is determined at runtime based on
the current project.

The representation of absolute paths is left untouched, for now.
Absolute paths are different from relative paths because (except in
contexts where we know that the path refers to the local filesystem)
they should generally be treated as opaque strings. Currently we use a
mix of types for these paths (std::path::Path, String, SanitizedPath).

Release Notes:

- N/A

---------

Co-authored-by: Cole Miller <cole@zed.dev>
Co-authored-by: Piotr Osiewicz <24362066+osiewicz@users.noreply.github.com>
Co-authored-by: Peter Tripp <petertripp@gmail.com>
Co-authored-by: Smit Barmase <heysmitbarmase@gmail.com>
Co-authored-by: Lukas Wirth <me@lukaswirth.dev>
2025-09-24 18:57:33 -04:00

559 lines
17 KiB
Rust

use std::{
error::Error,
fmt::{self, Debug},
sync::{Arc, Mutex},
time::Duration,
};
use crate::{
ToolMetrics,
assertions::{AssertionsReport, RanAssertion, RanAssertionResult},
};
use agent::{ContextLoadResult, Thread, ThreadEvent};
use agent_settings::AgentProfileId;
use anyhow::{Result, anyhow};
use async_trait::async_trait;
use buffer_diff::DiffHunkStatus;
use cloud_llm_client::CompletionIntent;
use collections::HashMap;
use futures::{FutureExt as _, StreamExt, channel::mpsc, select_biased};
use gpui::{App, AppContext, AsyncApp, Entity};
use language_model::{LanguageModel, Role, StopReason};
use util::rel_path::RelPath;
/// Longest time (two minutes) that `ExampleContext::run_turns` waits without
/// receiving any thread event before it reports the agentic loop as stalled.
pub const THREAD_EVENT_TIMEOUT: Duration = Duration::from_secs(60 * 2);
/// A single eval example: its metadata, the conversation that drives it, and
/// optional judge assertions.
#[async_trait(?Send)]
pub trait Example {
    /// Returns the metadata describing this example.
    fn meta(&self) -> ExampleMetadata;

    /// Runs the example's conversation against the given context.
    ///
    /// # Errors
    /// Returns whatever error the implementation produces.
    async fn conversation(&self, cx: &mut ExampleContext) -> Result<()>;

    /// Judge assertions associated with the diff; none by default.
    fn diff_assertions(&self) -> Vec<JudgeAssertion> {
        vec![]
    }

    /// Judge assertions associated with the thread; none by default.
    fn thread_assertions(&self) -> Vec<JudgeAssertion> {
        vec![]
    }
}
/// An assertion returned by `Example::diff_assertions` and
/// `Example::thread_assertions`.
///
/// NOTE(review): presumably evaluated by a judge model outside this file — confirm.
#[derive(Clone, Debug)]
pub struct JudgeAssertion {
/// Identifier of the assertion.
pub id: String,
/// Text describing what the assertion checks.
pub description: String,
}
/// Static description of an eval example, as returned by `Example::meta`.
#[derive(Clone, Debug)]
pub struct ExampleMetadata {
/// Name of the example.
pub name: String,
/// Repository URL; `repo_name` derives the repository name from its last
/// `/`-separated segment.
pub url: String,
/// Repository revision for the example (presumably a commit or ref to check
/// out — confirm against the runner).
pub revision: String,
/// Optional language-server configuration for the example.
pub language_server: Option<LanguageServer>,
/// Optional cap on the number of assertions; passed to `AssertionsReport::new`
/// and checked by `ExampleContext::log_assertion`.
pub max_assertions: Option<usize>,
/// Agent profile for the example (consumer is not visible in this file).
pub profile_id: AgentProfileId,
/// Optional JSON of an existing thread (presumably used to seed the
/// conversation — TODO confirm).
pub existing_thread_json: Option<String>,
/// Optional turn limit (presumably applied by the runner — TODO confirm).
pub max_turns: Option<u32>,
}
/// Language-server settings attached to an `ExampleMetadata`.
#[derive(Clone, Debug)]
pub struct LanguageServer {
/// File extension associated with the language server (presumably used to
/// pick a file that triggers it — TODO confirm).
pub file_extension: String,
/// Whether diagnostics that exist before the example runs are tolerated
/// (inferred from the name; the check itself is outside this file).
pub allow_preexisting_diagnostics: bool,
}
impl ExampleMetadata {
    /// Derives the repository name from `url`.
    ///
    /// The name is the last `/`-separated segment of the URL, ignoring any
    /// trailing slashes, with a single trailing `.git` suffix removed. Returns
    /// an empty string when the URL has no non-empty final segment.
    pub fn repo_name(&self) -> String {
        // Ignore trailing slashes so `https://host/org/repo/` still yields
        // `repo` instead of an empty name.
        let last_segment = self
            .url
            .trim_end_matches('/')
            .rsplit('/')
            .next()
            .unwrap_or("");

        // Remove exactly one `.git` suffix; `trim_end_matches` would strip it
        // repeatedly and mangle names that themselves end in `.git`.
        last_segment
            .strip_suffix(".git")
            .unwrap_or(last_segment)
            .into()
    }
}
/// Error value for an assertion that did not hold; wraps the assertion message.
pub struct FailedAssertion(pub String);

impl fmt::Debug for FailedAssertion {
    /// Writes the message prefixed with `Assertion failure: `.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("Assertion failure: ")?;
        f.write_str(&self.0)
    }
}

impl fmt::Display for FailedAssertion {
    /// Writes the bare assertion message.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(&self.0)
    }
}

impl Error for FailedAssertion {}
/// State shared by an example's conversation: the agent thread being driven,
/// the model it talks to, and the assertion/tool bookkeeping for the run.
pub struct ExampleContext {
// Metadata of the running example; `max_assertions` is read from it.
meta: ExampleMetadata,
// Prefix prepended to every line this context prints.
log_prefix: String,
// The agent thread that messages are pushed to and turns are run on.
agent_thread: Entity<agent::Thread>,
// App handle used to read and update entities.
app: AsyncApp,
// Model handed to the thread when a turn is started.
model: Arc<dyn LanguageModel>,
/// Report of every assertion run so far.
pub assertions: AssertionsReport,
/// Per-tool success/failure metrics, updated when a tool finishes.
pub tool_metrics: Arc<Mutex<ToolMetrics>>,
}
impl ExampleContext {
/// Creates a context for one example run.
///
/// The assertion report is initialised with `meta.max_assertions`, and the
/// tool metrics start out at their default value.
pub fn new(
    meta: ExampleMetadata,
    log_prefix: String,
    agent_thread: Entity<Thread>,
    model: Arc<dyn LanguageModel>,
    app: AsyncApp,
) -> Self {
    // Read the cap before `meta` is moved into the struct below.
    let max_assertions = meta.max_assertions;
    Self {
        assertions: AssertionsReport::new(max_assertions),
        tool_metrics: Arc::new(Mutex::new(ToolMetrics::default())),
        meta,
        log_prefix,
        agent_thread,
        model,
        app,
    }
}
/// Inserts `text` into the agent thread as a user message with a default
/// context-load result and no additional arguments.
///
/// # Panics
/// Panics if updating the thread entity fails.
pub fn push_user_message(&mut self, text: impl ToString) {
    let update_result = self.app.update_entity(&self.agent_thread, |thread, cx| {
        let user_text = text.to_string();
        thread.insert_user_message(
            user_text,
            ContextLoadResult::default(),
            None,
            Vec::new(),
            cx,
        );
    });
    update_result.unwrap();
}
pub fn assert(&mut self, expected: bool, message: impl ToString) -> Result<()> {
let message = message.to_string();
self.log_assertion(
if expected {
Ok(())
} else {
Err(anyhow::Error::from(FailedAssertion(message.clone())))
},
message,
)
}
pub fn assert_some<T>(&mut self, option: Option<T>, message: impl ToString) -> Result<T> {
let message = message.to_string();
self.log_assertion(
match option {
Some(value) => Ok(value),
None => Err(anyhow::Error::from(FailedAssertion(message.clone()))),
},
message,
)
}
/// Records an assertion that passes when `left == right`.
///
/// On mismatch, a `pretty_assertions` comparison of the two values is printed
/// (prefixed with the log prefix) before the failure is recorded.
///
/// # Errors
/// Returns a `FailedAssertion` carrying `message` when the values differ, or
/// any error raised while recording the assertion.
#[allow(dead_code)]
pub fn assert_eq<T: PartialEq + Debug>(
    &mut self,
    left: T,
    right: T,
    message: impl ToString,
) -> Result<()> {
    let message = message.to_string();
    let outcome = if left == right {
        Ok(())
    } else {
        let comparison = pretty_assertions::Comparison::new(&left, &right);
        println!("{}{}", self.log_prefix, comparison);
        Err(anyhow::Error::from(FailedAssertion(message.clone())))
    };
    self.log_assertion(outcome, message)
}
/// Records the outcome of one assertion in the report, logs it, and returns
/// `result` unchanged.
///
/// # Errors
/// When `meta.max_assertions` is set and that many assertions have already
/// been run, returns an error without recording or logging anything.
/// Otherwise returns `result` as given.
fn log_assertion<T>(&mut self, result: Result<T>, message: String) -> Result<T> {
    if let Some(max) = self.meta.max_assertions {
        // The check runs *before* this assertion is recorded, so the count of
        // prior assertions must be strictly below the cap; `<=` would let
        // `max + 1` assertions through. Assumes `run_count()` reports the
        // number of assertions recorded so far.
        anyhow::ensure!(
            self.assertions.run_count() < max,
            "More assertions were run than the stated max_assertions of {max}"
        );
    }

    let passed = result.is_ok();

    // Mark the outcome so passes and failures can be told apart in the log;
    // the markers match the ones `run_turns` prints for tool results.
    let marker = if passed { "✔︎ " } else { "✖︎ " };
    println!("{}{}{}", self.log_prefix, marker, message);

    self.assertions.ran.push(RanAssertion {
        id: message,
        result: Ok(RanAssertionResult {
            analysis: None,
            passed,
        }),
    });

    result
}
/// Runs the agent with the largest possible turn budget (`u32::MAX`).
///
/// # Errors
/// Propagates any error from `run_turns`.
pub async fn run_to_end(&mut self) -> Result<Response> {
    let turn_budget = u32::MAX;
    self.run_turns(turn_budget).await
}
/// Runs the agent for a single turn.
///
/// # Errors
/// Propagates any error from `run_turns`.
pub async fn run_turn(&mut self) -> Result<Response> {
    let turn_budget = 1;
    self.run_turns(turn_budget).await
}
/// Sends the thread to the model with a budget of `iterations` turns and waits
/// until the thread finishes, then returns the messages added during this call.
///
/// Thread events are forwarded from a subscription callback to the wait loop
/// through a channel: `Ok(())` signals progress, `Err` aborts the run, and
/// closing the channel signals completion.
///
/// # Errors
/// Fails when the thread reports an error, stops because of max tokens or a
/// refusal, stops with an error, or when no event arrives within
/// `THREAD_EVENT_TIMEOUT`.
///
/// # Panics
/// Panics if the thread requests tool confirmation, or if the tool-metrics
/// mutex is poisoned.
pub async fn run_turns(&mut self, iterations: u32) -> Result<Response> {
// NOTE(review): every send below uses `try_send(..).ok()`, so a message is
// silently dropped if this capacity-1 channel is full. That looks able to drop
// an `Err`, in which case the run would only end via the stall timeout — confirm.
let (mut tx, mut rx) = mpsc::channel(1);
let tool_metrics = self.tool_metrics.clone();
let log_prefix = self.log_prefix.clone();
// The binding lives until the end of this function (presumably dropping it
// would cancel the subscription).
let _subscription = self.app.subscribe(
&self.agent_thread,
move |thread, event: &ThreadEvent, cx| match event {
ThreadEvent::ShowError(thread_error) => {
tx.try_send(Err(anyhow!(thread_error.clone()))).ok();
}
ThreadEvent::Stopped(reason) => match reason {
// The model ended its turn: the run is complete.
Ok(StopReason::EndTurn) => {
tx.close_channel();
}
// Stopped for a tool call: only finished once the turn budget is used up.
Ok(StopReason::ToolUse) => {
if thread.read(cx).remaining_turns() == 0 {
tx.close_channel();
}
}
Ok(StopReason::MaxTokens) => {
tx.try_send(Err(anyhow!("Exceeded maximum tokens"))).ok();
}
Ok(StopReason::Refusal) => {
tx.try_send(Err(anyhow!("Model refused to generate content")))
.ok();
}
Err(err) => {
tx.try_send(Err(anyhow!(err.clone()))).ok();
}
},
// These events are ignored: nothing is sent, so they do not count as progress.
ThreadEvent::NewRequest
| ThreadEvent::StreamedAssistantText(_, _)
| ThreadEvent::StreamedAssistantThinking(_, _)
| ThreadEvent::UsePendingTools { .. }
| ThreadEvent::CompletionCanceled => {}
ThreadEvent::ToolUseLimitReached => {}
ThreadEvent::ToolFinished {
tool_use_id,
pending_tool_use,
..
} => {
// Log the tool outcome and record it in the shared tool metrics.
thread.update(cx, |thread, _cx| {
if let Some(tool_use) = pending_tool_use {
let mut tool_metrics = tool_metrics.lock().unwrap();
if let Some(tool_result) = thread.tool_result(tool_use_id) {
let message = if tool_result.is_error {
format!("✖︎ {}", tool_use.name)
} else {
format!("✔︎ {}", tool_use.name)
};
println!("{log_prefix}{message}");
tool_metrics
.insert(tool_result.tool_name.clone(), !tool_result.is_error);
} else {
// A finished tool without a result is logged and recorded as a success.
let message =
format!("TOOL FINISHED WITHOUT RESULT: {}", tool_use.name);
println!("{log_prefix}{message}");
tool_metrics.insert(tool_use.name.clone(), true);
}
}
});
}
ThreadEvent::InvalidToolInput { .. } => {
println!("{log_prefix} invalid tool input");
}
ThreadEvent::MissingToolUse {
tool_use_id: _,
ui_text,
} => {
println!("{log_prefix} {ui_text}");
}
ThreadEvent::ToolConfirmationNeeded => {
panic!(
"{}Bug: Tool confirmation should not be required in eval",
log_prefix
);
}
// Progress events: send `Ok(())` to the wait loop, which then starts a fresh
// stall timer on its next iteration.
ThreadEvent::StreamedCompletion
| ThreadEvent::MessageAdded(_)
| ThreadEvent::MessageEdited(_)
| ThreadEvent::MessageDeleted(_)
| ThreadEvent::SummaryChanged
| ThreadEvent::SummaryGenerated
| ThreadEvent::ProfileChanged
| ThreadEvent::ReceivedTextChunk
| ThreadEvent::StreamedToolUse { .. }
| ThreadEvent::CheckpointChanged
| ThreadEvent::CancelEditing => {
tx.try_send(Ok(())).ok();
// Setting the `ZED_EVAL_DEBUG` environment variable dumps these events.
if std::env::var("ZED_EVAL_DEBUG").is_ok() {
println!("{}Event: {:#?}", log_prefix, event);
}
}
},
);
let model = self.model.clone();
// Remember how many messages existed so only new ones are returned below.
let message_count_before = self.app.update_entity(&self.agent_thread, |thread, cx| {
thread.set_remaining_turns(iterations);
thread.send_to_model(model, CompletionIntent::UserPrompt, None, cx);
thread.messages().len()
})?;
// Wait for completion: a closed channel ends the loop, an `Err` item is
// propagated, and the timer created on each iteration bounds the wait for
// the next event.
loop {
select_biased! {
result = rx.next() => {
if let Some(result) = result {
result?;
} else {
break;
}
}
_ = self.app.background_executor().timer(THREAD_EVENT_TIMEOUT).fuse() => {
anyhow::bail!("Agentic loop stalled - waited {THREAD_EVENT_TIMEOUT:?} without any events");
}
}
}
// Collect the messages added during this call, with their tool uses.
let messages = self.app.read_entity(&self.agent_thread, |thread, cx| {
let mut messages = Vec::new();
for message in thread.messages().skip(message_count_before) {
messages.push(Message {
_role: message.role,
text: message.to_message_content(),
tool_use: thread
.tool_uses_for_message(message.id, cx)
.into_iter()
.map(|tool_use| ToolUse {
name: tool_use.name.to_string(),
value: tool_use.input,
})
.collect(),
});
}
messages
})?;
let response = Response::new(messages);
Ok(response)
}
/// Collects the diff hunks of every buffer changed according to the thread's
/// action log, keyed by the path of the buffer's file.
///
/// # Panics
/// Panics if the thread cannot be read or if a changed buffer has no file.
pub fn edits(&self) -> HashMap<Arc<RelPath>, FileEdits> {
    let edits_by_path = self.agent_thread.read_with(&self.app, |thread, cx| {
        let action_log = thread.action_log().read(cx);
        let mut edits_by_path = HashMap::default();

        for (buffer, diff) in action_log.changed_buffers(cx) {
            let snapshot = buffer.read(cx).snapshot();
            let file = snapshot.file().unwrap();
            let diff = diff.read(cx);
            let base_text = diff.base_text().text();

            let mut hunks = Vec::new();
            for hunk in diff.hunks(&snapshot, cx) {
                hunks.push(FileEditHunk {
                    // Text the hunk replaced, sliced out of the diff base.
                    base_text: base_text[hunk.diff_base_byte_range.clone()].to_string(),
                    // Current buffer text covered by the hunk.
                    text: snapshot
                        .text_for_range(hunk.range.clone())
                        .collect::<String>(),
                    status: hunk.status(),
                });
            }

            edits_by_path.insert(file.path().clone(), FileEdits { hunks });
        }

        edits_by_path
    });
    edits_by_path.unwrap()
}
pub fn agent_thread(&self) -> Entity<Thread> {
self.agent_thread.clone()
}
}
// Lets an `ExampleContext` be used wherever an `AppContext` is expected.
// Every method forwards to the wrapped `AsyncApp` (`self.app`) unchanged.
impl AppContext for ExampleContext {
// Results are fallible, matching what the wrapped `AsyncApp` calls return.
type Result<T> = anyhow::Result<T>;
// Forwards entity creation to the wrapped app.
fn new<T: 'static>(
&mut self,
build_entity: impl FnOnce(&mut gpui::Context<T>) -> T,
) -> Self::Result<Entity<T>> {
self.app.new(build_entity)
}
// Forwards entity reservation to the wrapped app.
fn reserve_entity<T: 'static>(&mut self) -> Self::Result<gpui::Reservation<T>> {
self.app.reserve_entity()
}
// Forwards insertion of an entity into a reservation to the wrapped app.
fn insert_entity<T: 'static>(
&mut self,
reservation: gpui::Reservation<T>,
build_entity: impl FnOnce(&mut gpui::Context<T>) -> T,
) -> Self::Result<Entity<T>> {
self.app.insert_entity(reservation, build_entity)
}
// Forwards an entity update to the wrapped app.
fn update_entity<T, R>(
&mut self,
handle: &Entity<T>,
update: impl FnOnce(&mut T, &mut gpui::Context<T>) -> R,
) -> Self::Result<R>
where
T: 'static,
{
self.app.update_entity(handle, update)
}
// Forwards mutable borrowing of an entity to the wrapped app.
fn as_mut<'a, T>(&'a mut self, handle: &Entity<T>) -> Self::Result<gpui::GpuiBorrow<'a, T>>
where
T: 'static,
{
self.app.as_mut(handle)
}
// Forwards an entity read to the wrapped app.
fn read_entity<T, R>(
&self,
handle: &Entity<T>,
read: impl FnOnce(&T, &App) -> R,
) -> Self::Result<R>
where
T: 'static,
{
self.app.read_entity(handle, read)
}
// Forwards a window update to the wrapped app.
fn update_window<T, F>(&mut self, window: gpui::AnyWindowHandle, f: F) -> Result<T>
where
F: FnOnce(gpui::AnyView, &mut gpui::Window, &mut App) -> T,
{
self.app.update_window(window, f)
}
// Forwards a window read to the wrapped app.
fn read_window<T, R>(
&self,
window: &gpui::WindowHandle<T>,
read: impl FnOnce(Entity<T>, &App) -> R,
) -> Result<R>
where
T: 'static,
{
self.app.read_window(window, read)
}
// Forwards spawning of a background future to the wrapped app.
fn background_spawn<R>(
&self,
future: impl std::future::Future<Output = R> + Send + 'static,
) -> gpui::Task<R>
where
R: Send + 'static,
{
self.app.background_spawn(future)
}
// Forwards a global read to the wrapped app.
fn read_global<G, R>(&self, callback: impl FnOnce(&G, &App) -> R) -> Self::Result<R>
where
G: gpui::Global,
{
self.app.read_global(callback)
}
}
/// The messages produced by one call to `ExampleContext::run_turns`.
#[derive(Debug)]
pub struct Response {
// Messages added to the thread during the run, in thread order.
messages: Vec<Message>,
}
impl Response {
pub fn new(messages: Vec<Message>) -> Self {
Self { messages }
}
pub fn expect_tool(
&self,
tool_name: &'static str,
cx: &mut ExampleContext,
) -> Result<&ToolUse> {
let result = self.find_tool_call(tool_name);
cx.assert_some(result, format!("called `{}`", tool_name))
}
pub fn find_tool_call(&self, tool_name: &str) -> Option<&ToolUse> {
self.messages.iter().rev().find_map(|msg| {
msg.tool_use
.iter()
.find(|tool_use| tool_use.name == tool_name)
})
}
#[allow(dead_code)]
pub fn tool_uses(&self) -> impl Iterator<Item = &ToolUse> {
self.messages.iter().flat_map(|msg| &msg.tool_use)
}
pub fn texts(&self) -> impl Iterator<Item = String> {
self.messages.iter().map(|message| message.text.clone())
}
}
/// One thread message captured in a `Response`.
#[derive(Debug)]
pub struct Message {
// Role of the message author; stored but not read in this file.
_role: Role,
// Message content as produced by the thread message's `to_message_content()`.
text: String,
// Tool uses the thread reports for this message.
tool_use: Vec<ToolUse>,
}
/// A tool invocation captured from a thread message.
#[derive(Debug)]
pub struct ToolUse {
/// Name of the invoked tool.
pub name: String,
// Raw JSON input of the invocation; decode it with `parse_input`.
value: serde_json::Value,
}
impl ToolUse {
pub fn parse_input<Input>(&self) -> Result<Input>
where
Input: for<'de> serde::Deserialize<'de>,
{
serde_json::from_value::<Input>(self.value.clone()).map_err(|err| anyhow!(err))
}
}
/// The diff hunks recorded for a single file by `ExampleContext::edits`.
#[derive(Debug, Eq, PartialEq)]
pub struct FileEdits {
/// Hunks of the file's diff, in the order the diff yields them.
pub hunks: Vec<FileEditHunk>,
}
/// One diff hunk of an edited file.
#[derive(Debug, Eq, PartialEq)]
pub struct FileEditHunk {
/// Slice of the diff's base text covered by the hunk.
pub base_text: String,
/// Current buffer text covered by the hunk.
pub text: String,
/// Status reported by the diff for this hunk.
pub status: DiffHunkStatus,
}
impl FileEdits {
    /// Reports whether any hunk is a pure addition whose text contains `line`.
    ///
    /// A hunk counts when its status equals `DiffHunkStatus::added_none()`,
    /// its base text is empty, and `line` occurs as a substring of its text.
    pub fn has_added_line(&self, line: &str) -> bool {
        for hunk in &self.hunks {
            let is_pure_addition =
                hunk.status == DiffHunkStatus::added_none() && hunk.base_text.is_empty();
            if is_pure_addition && hunk.text.contains(line) {
                return true;
            }
        }
        false
    }
}