Compare commits

...

13 Commits

Author SHA1 Message Date
Michael Benfield
e8d44ce267 comments 2025-12-03 10:09:22 -08:00
Michael Benfield
08d89254d1 select better and use tools for inline assistant 2025-12-03 09:51:16 -08:00
Mikayla Maki
d7cc949e61 fix license 2025-12-02 17:13:52 -08:00
Mikayla Maki
a3f9dffb04 Clippy 2025-12-02 17:13:52 -08:00
Mikayla Maki
78b32840c3 Feature gate eval 2025-12-02 17:13:52 -08:00
Mikayla Maki
a636b59d72 Clean up PR for merging 2025-12-02 17:13:52 -08:00
Mikayla Maki
983c3a02a8 Clean up dbg 2025-12-02 17:13:52 -08:00
Mikayla Maki
5bc7b11a6a Fix a miscompilation 2025-12-02 17:13:52 -08:00
Michael Benfield
2a5e8c62a7 Add eval utils to workspace and inline assistant test
Co-authored-by: Mikayla Maki <mikayla.c.maki@gmail.com>
2025-12-02 17:13:52 -08:00
Michael Benfield
4efeef564c evals uses eval_utils 2025-12-02 17:13:52 -08:00
Michael Benfield
514cb933d1 eval_utils crate 2025-12-02 17:13:52 -08:00
Michael Benfield
06f65b29ab working on thread_store thing 2025-12-02 17:13:52 -08:00
Michael Benfield
cd9242c544 Progress towards an eval for inline assistants.
Co-authored-by: Mikayla Maki <mikayla.c.maki@gmail.com>
2025-12-02 17:13:52 -08:00
23 changed files with 1178 additions and 148 deletions

14
Cargo.lock generated
View File

@@ -159,6 +159,7 @@ dependencies = [
"derive_more 0.99.20",
"editor",
"env_logger 0.11.8",
"eval_utils",
"fs",
"futures 0.3.31",
"git",
@@ -328,6 +329,7 @@ dependencies = [
"buffer_diff",
"chrono",
"client",
"clock",
"cloud_llm_client",
"collections",
"command_palette_hooks",
@@ -335,6 +337,7 @@ dependencies = [
"context_server",
"db",
"editor",
"eval_utils",
"extension",
"extension_host",
"feature_flags",
@@ -343,6 +346,7 @@ dependencies = [
"futures 0.3.31",
"fuzzy",
"gpui",
"gpui_tokio",
"html_to_markdown",
"http_client",
"image",
@@ -370,6 +374,7 @@ dependencies = [
"proto",
"rand 0.9.2",
"release_channel",
"reqwest_client",
"rope",
"rules_library",
"schemars",
@@ -5776,6 +5781,15 @@ dependencies = [
"watch",
]
[[package]]
name = "eval_utils"
version = "0.1.0"
dependencies = [
"gpui",
"serde",
"smol",
]
[[package]]
name = "event-listener"
version = "2.5.3"

View File

@@ -59,6 +59,7 @@ members = [
"crates/zeta2_tools",
"crates/editor",
"crates/eval",
"crates/eval_utils",
"crates/explorer_command_injector",
"crates/extension",
"crates/extension_api",
@@ -288,6 +289,7 @@ deepseek = { path = "crates/deepseek" }
derive_refineable = { path = "crates/refineable/derive_refineable" }
diagnostics = { path = "crates/diagnostics" }
editor = { path = "crates/editor" }
eval_utils = { path = "crates/eval_utils" }
extension = { path = "crates/extension" }
extension_host = { path = "crates/extension_host" }
extensions_ui = { path = "crates/extensions_ui" }

View File

@@ -0,0 +1,42 @@
{{#if language_name}}
Here's a file of {{language_name}} that the user is going to ask you to make an edit to.
{{else}}
Here's a file of text that the user is going to ask you to make an edit to.
{{/if}}
The section you'll need to rewrite is marked with <rewrite_this></rewrite_this> tags.
<document>
{{{document_content}}}
</document>
{{#if is_truncated}}
The context around the relevant section has been truncated (possibly in the middle of a line) for brevity.
{{/if}}
{{#if rewrite_section}}
And here's the section to rewrite based on that prompt again for reference:
<rewrite_this>
{{{rewrite_section}}}
</rewrite_this>
{{#if diagnostic_errors}}
Below are the diagnostic errors visible to the user. If the user requests problems to be fixed, use this information, but do not try to fix these errors if the user hasn't asked you to.
{{#each diagnostic_errors}}
<diagnostic_error>
<line_number>{{line_number}}</line_number>
<error_message>{{error_message}}</error_message>
<code_content>{{code_content}}</code_content>
</diagnostic_error>
{{/each}}
{{/if}}
{{/if}}
Only make changes that are necessary to fulfill the prompt, leave everything else as-is. All surrounding {{content_type}} will be preserved.
Start at the indentation level in the original file in the rewritten {{content_type}}.
You must use one of the provided tools to make the rewrite or to provide an explanation as to why the user's request cannot be fulfilled.

View File

@@ -83,6 +83,7 @@ ctor.workspace = true
db = { workspace = true, "features" = ["test-support"] }
editor = { workspace = true, "features" = ["test-support"] }
env_logger.workspace = true
eval_utils.workspace = true
fs = { workspace = true, "features" = ["test-support"] }
git = { workspace = true, "features" = ["test-support"] }
gpui = { workspace = true, "features" = ["test-support"] }

View File

@@ -4,7 +4,6 @@ use crate::{
};
use Role::*;
use client::{Client, UserStore};
use collections::HashMap;
use fs::FakeFs;
use futures::{FutureExt, future::LocalBoxFuture};
use gpui::{AppContext, TestAppContext, Timer};
@@ -20,9 +19,7 @@ use rand::prelude::*;
use reqwest_client::ReqwestClient;
use serde_json::json;
use std::{
cmp::Reverse,
fmt::{self, Display},
io::Write as _,
path::Path,
str::FromStr,
sync::mpsc,
@@ -1316,101 +1313,41 @@ fn eval(
iterations: usize,
expected_pass_ratio: f32,
mismatched_tag_threshold: f32,
mut eval: EvalInput,
eval: EvalInput,
) {
let mut evaluated_count = 0;
let mut failed_count = 0;
report_progress(evaluated_count, failed_count, iterations);
let (tx, rx) = mpsc::channel();
// Cache the last message in the conversation, and run one instance of the eval so that
// all the next ones are cached.
eval.conversation.last_mut().unwrap().cache = true;
run_eval(eval.clone(), tx.clone());
let executor = gpui::background_executor();
let semaphore = Arc::new(smol::lock::Semaphore::new(32));
for _ in 1..iterations {
let eval = eval.clone();
let tx = tx.clone();
let semaphore = semaphore.clone();
executor
.spawn(async move {
let _guard = semaphore.acquire().await;
run_eval(eval, tx)
})
.detach();
}
drop(tx);
let mut failed_evals = HashMap::default();
let mut errored_evals = HashMap::default();
let mut eval_outputs = Vec::new();
let mut cumulative_parser_metrics = EditParserMetrics::default();
while let Ok(output) = rx.recv() {
match output {
Ok(output) => {
cumulative_parser_metrics += output.sample.edit_output.parser_metrics.clone();
eval_outputs.push(output.clone());
if output.assertion.score < 80 {
failed_count += 1;
failed_evals
.entry(output.sample.text_after.clone())
.or_insert(Vec::new())
.push(output);
}
}
Err(error) => {
failed_count += 1;
*errored_evals.entry(format!("{:?}", error)).or_insert(0) += 1;
}
}
evaluated_count += 1;
report_progress(evaluated_count, failed_count, iterations);
}
let actual_pass_ratio = (iterations - failed_count) as f32 / iterations as f32;
println!("Actual pass ratio: {}\n", actual_pass_ratio);
if actual_pass_ratio < expected_pass_ratio {
let mut errored_evals = errored_evals.into_iter().collect::<Vec<_>>();
errored_evals.sort_by_key(|(_, count)| Reverse(*count));
for (error, count) in errored_evals {
println!("Eval errored {} times. Error: {}", count, error);
}
let mut failed_evals = failed_evals.into_iter().collect::<Vec<_>>();
failed_evals.sort_by_key(|(_, evals)| Reverse(evals.len()));
for (_buffer_output, failed_evals) in failed_evals {
let eval_output = failed_evals.first().unwrap();
println!("Eval failed {} times", failed_evals.len());
println!("{}", eval_output);
}
panic!(
"Actual pass ratio: {}\nExpected pass ratio: {}",
actual_pass_ratio, expected_pass_ratio
);
}
let mismatched_tag_ratio =
cumulative_parser_metrics.mismatched_tags as f32 / cumulative_parser_metrics.tags as f32;
if mismatched_tag_ratio > mismatched_tag_threshold {
for eval_output in eval_outputs {
println!("{}", eval_output);
}
panic!("Too many mismatched tags: {:?}", cumulative_parser_metrics);
}
eval_utils::eval(
iterations,
expected_pass_ratio,
mismatched_tag_threshold,
Arc::new(move |tx| run_eval(eval.clone(), tx)),
);
}
fn run_eval(eval: EvalInput, tx: mpsc::Sender<Result<EvalOutput>>) {
fn run_eval(eval: EvalInput, tx: mpsc::Sender<eval_utils::EvalOutput>) {
let dispatcher = gpui::TestDispatcher::new(StdRng::from_os_rng());
let mut cx = TestAppContext::build(dispatcher, None);
let output = cx.executor().block_test(async {
let result = cx.executor().block_test(async {
let test = EditAgentTest::new(&mut cx).await;
test.eval(eval, &mut cx).await
});
let output = match result {
Ok(output) => eval_utils::EvalOutput {
data: output.to_string(),
mismatched_tags: output.sample.edit_output.parser_metrics.mismatched_tags,
tags: output.sample.edit_output.parser_metrics.tags,
outcome_kind: if output.assertion.score < 80 {
eval_utils::OutcomeKind::Failed
} else {
eval_utils::OutcomeKind::Passed
},
},
Err(e) => eval_utils::EvalOutput {
data: format!("{e:?}"),
mismatched_tags: 0,
tags: 0,
outcome_kind: eval_utils::OutcomeKind::Error,
},
};
tx.send(output).unwrap();
}
@@ -1439,22 +1376,6 @@ impl Display for EvalOutput {
}
}
fn report_progress(evaluated_count: usize, failed_count: usize, iterations: usize) {
let passed_count = evaluated_count - failed_count;
let passed_ratio = if evaluated_count == 0 {
0.0
} else {
passed_count as f64 / evaluated_count as f64
};
print!(
"\r\x1b[KEvaluated {}/{} ({:.2}% passed)",
evaluated_count,
iterations,
passed_ratio * 100.0
);
std::io::stdout().flush().unwrap();
}
struct EditAgentTest {
agent: EditAgent,
project: Entity<Project>,
@@ -1550,7 +1471,10 @@ impl EditAgentTest {
})
}
async fn eval(&self, eval: EvalInput, cx: &mut TestAppContext) -> Result<EvalOutput> {
async fn eval(&self, mut eval: EvalInput, cx: &mut TestAppContext) -> Result<EvalOutput> {
// Make sure the last message in the conversation is cached.
eval.conversation.last_mut().unwrap().cache = true;
let path = self
.project
.read_with(cx, |project, cx| {

View File

@@ -4,6 +4,7 @@ mod create_directory_tool;
mod delete_path_tool;
mod diagnostics_tool;
mod edit_file_tool;
mod failure_message_tool;
mod fetch_tool;
mod find_path_tool;
mod grep_tool;
@@ -12,6 +13,7 @@ mod move_path_tool;
mod now_tool;
mod open_tool;
mod read_file_tool;
mod rewrite_section_tool;
mod terminal_tool;
mod thinking_tool;
mod web_search_tool;
@@ -25,6 +27,7 @@ pub use create_directory_tool::*;
pub use delete_path_tool::*;
pub use diagnostics_tool::*;
pub use edit_file_tool::*;
pub use failure_message_tool::*;
pub use fetch_tool::*;
pub use find_path_tool::*;
pub use grep_tool::*;
@@ -33,6 +36,7 @@ pub use move_path_tool::*;
pub use now_tool::*;
pub use open_tool::*;
pub use read_file_tool::*;
pub use rewrite_section_tool::*;
pub use terminal_tool::*;
pub use thinking_tool::*;
pub use web_search_tool::*;

View File

@@ -0,0 +1,50 @@
//! This tool is intended for use with the inline assistant, not the agent panel.
use std::sync::Arc;
use agent_client_protocol as acp;
use anyhow::Result;
use gpui::{App, SharedString, Task};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use crate::{AgentTool, ToolCallEventStream};
/// Use this tool to provide a message to the user when you're unable to complete a task.
// NOTE(review): the `///` doc comments here appear to be surfaced to the model as the
// tool/parameter descriptions via the `JsonSchema` derive — confirm before rewording.
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct FailureMessageInput {
    /// A brief message to the user explaining why you're unable to fulfill the request.
    pub message: String,
}
// Marker type; all behavior lives in the `AgentTool` impl below.
pub struct FailureMessageTool;

impl AgentTool for FailureMessageTool {
    type Input = FailureMessageInput;
    type Output = String;

    // The name the model uses to invoke this tool.
    fn name() -> &'static str {
        "failure_message"
    }

    fn kind() -> acp::ToolKind {
        acp::ToolKind::Think
    }

    // No title is displayed for this tool.
    fn initial_title(
        &self,
        _input: Result<Self::Input, serde_json::Value>,
        _cx: &mut App,
    ) -> SharedString {
        "".into()
    }

    // Intentionally unimplemented: the inline assistant intercepts the
    // "failure_message" tool use directly (see `handle_tool_use` in
    // buffer_codegen), so this tool is registered only for its
    // name/description/input schema and `run` is never reached.
    fn run(
        self: Arc<Self>,
        _input: Self::Input,
        _event_stream: ToolCallEventStream,
        _cx: &mut App,
    ) -> Task<Result<String>> {
        unimplemented!()
    }
}

View File

@@ -0,0 +1,55 @@
//! This tool is intended for use with the inline assistant, not the agent panel.
use std::sync::Arc;
use agent_client_protocol as acp;
use anyhow::Result;
use gpui::{App, SharedString, Task};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use crate::{AgentTool, ToolCallEventStream};
/// Replaces text in <rewrite_this></rewrite_this> tags with your replacement_text.
// NOTE(review): the `///` doc comments here appear to be surfaced to the model as the
// tool/parameter descriptions via the `JsonSchema` derive — confirm before rewording.
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct RewriteSectionInput {
    /// A brief description of the edit you have made.
    ///
    /// This is optional - if the edit is simple or obvious, you should leave it empty.
    pub description: String,
    /// The text to replace the section with.
    pub replacement_text: String,
}
// Marker type; all behavior lives in the `AgentTool` impl below.
pub struct RewriteSectionTool;

impl AgentTool for RewriteSectionTool {
    type Input = RewriteSectionInput;
    type Output = String;

    // The name the model uses to invoke this tool.
    fn name() -> &'static str {
        "rewrite_section"
    }

    fn kind() -> acp::ToolKind {
        acp::ToolKind::Edit
    }

    // No title is displayed for this tool.
    fn initial_title(
        &self,
        _input: Result<Self::Input, serde_json::Value>,
        _cx: &mut App,
    ) -> SharedString {
        "".into()
    }

    // Intentionally unimplemented: the inline assistant intercepts the
    // "rewrite_section" tool use directly and applies `replacement_text`
    // itself (see `handle_tool_use` in buffer_codegen), so this tool is
    // registered only for its name/description/input schema.
    fn run(
        self: Arc<Self>,
        _input: Self::Input,
        _event_stream: ToolCallEventStream,
        _cx: &mut App,
    ) -> Task<Result<String>> {
        unimplemented!()
    }
}

View File

@@ -13,7 +13,8 @@ path = "src/agent_ui.rs"
doctest = false
[features]
test-support = ["gpui/test-support", "language/test-support"]
test-support = ["gpui/test-support", "language/test-support", "reqwest_client"]
unit-eval = []
[dependencies]
acp_thread.workspace = true
@@ -47,6 +48,7 @@ fs.workspace = true
futures.workspace = true
fuzzy.workspace = true
gpui.workspace = true
gpui_tokio.workspace = true
html_to_markdown.workspace = true
http_client.workspace = true
indoc.workspace = true
@@ -98,14 +100,17 @@ workspace.workspace = true
zed_actions.workspace = true
image.workspace = true
async-fs.workspace = true
reqwest_client = { workspace = true, optional = true }
[dev-dependencies]
acp_thread = { workspace = true, features = ["test-support"] }
agent = { workspace = true, features = ["test-support"] }
assistant_text_thread = { workspace = true, features = ["test-support"] }
buffer_diff = { workspace = true, features = ["test-support"] }
clock.workspace = true
db = { workspace = true, features = ["test-support"] }
editor = { workspace = true, features = ["test-support"] }
eval_utils.workspace = true
gpui = { workspace = true, "features" = ["test-support"] }
indoc.workspace = true
language = { workspace = true, "features" = ["test-support"] }
@@ -115,5 +120,6 @@ pretty_assertions.workspace = true
project = { workspace = true, features = ["test-support"] }
semver.workspace = true
rand.workspace = true
reqwest_client.workspace = true
tree-sitter-md.workspace = true
unindent.workspace = true

View File

@@ -2685,16 +2685,17 @@ impl rules_library::InlineAssistDelegate for PromptLibraryInlineAssist {
return;
};
let project = workspace.read(cx).project().downgrade();
let thread_store = panel.read(cx).thread_store().clone();
assistant.assist(
prompt_editor,
self.workspace.clone(),
project,
panel.read(cx).thread_store().clone(),
thread_store,
None,
initial_prompt,
window,
cx,
)
);
})
}

View File

@@ -7,6 +7,8 @@ mod buffer_codegen;
mod completion_provider;
mod context;
mod context_server_configuration;
#[cfg(test)]
mod evals;
mod inline_assistant;
mod inline_prompt_editor;
mod language_model_selector;

View File

@@ -1,10 +1,15 @@
use crate::{context::LoadedContext, inline_prompt_editor::CodegenStatus};
use agent::{
AgentTool as _, FailureMessageInput, FailureMessageTool, RewriteSectionInput,
RewriteSectionTool, SystemPromptTemplate, Template, Templates,
};
use agent_settings::AgentSettings;
use anyhow::{Context as _, Result};
use client::telemetry::Telemetry;
use cloud_llm_client::CompletionIntent;
use collections::HashSet;
use editor::{Anchor, AnchorRangeExt, MultiBuffer, MultiBufferSnapshot, ToOffset as _, ToPoint};
use feature_flags::{FeatureFlagAppExt as _, InlineAssistantV2FeatureFlag};
use futures::{
SinkExt, Stream, StreamExt, TryStreamExt as _,
channel::mpsc,
@@ -14,12 +19,14 @@ use futures::{
use gpui::{App, AppContext as _, Context, Entity, EventEmitter, Subscription, Task};
use language::{Buffer, IndentKind, Point, TransactionId, line_diff};
use language_model::{
LanguageModel, LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage,
LanguageModelTextStream, Role, report_assistant_event,
LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage,
LanguageModelRequestTool, LanguageModelTextStream, LanguageModelToolUse, Role,
report_assistant_event,
};
use multi_buffer::MultiBufferRow;
use parking_lot::Mutex;
use prompt_store::PromptBuilder;
use prompt_store::{ProjectContext, PromptBuilder};
use rope::Rope;
use smol::future::FutureExt;
use std::{
@@ -214,6 +221,13 @@ impl BufferCodegen {
pub fn last_equal_ranges<'a>(&self, cx: &'a App) -> &'a [Range<Anchor>] {
self.active_alternative().read(cx).last_equal_ranges()
}
    /// Returns the description reported by the model's tool use (if any)
    /// for the currently active alternative.
    pub fn tool_description<'a>(&self, cx: &'a App) -> Option<&'a str> {
        self.active_alternative()
            .read(cx)
            .tool_description
            .as_deref()
    }
}
impl EventEmitter<CodegenEvent> for BufferCodegen {}
@@ -238,6 +252,7 @@ pub struct CodegenAlternative {
elapsed_time: Option<f64>,
completion: Option<String>,
pub message_id: Option<String>,
pub tool_description: Option<String>,
}
impl EventEmitter<CodegenEvent> for CodegenAlternative {}
@@ -288,14 +303,15 @@ impl CodegenAlternative {
generation: Task::ready(()),
diff: Diff::default(),
telemetry,
_subscription: cx.subscribe(&buffer, Self::handle_buffer_event),
builder,
active,
active: false,
edits: Vec::new(),
line_operations: Vec::new(),
range,
elapsed_time: None,
completion: None,
tool_description: None,
_subscription: cx.subscribe(&buffer, Self::handle_buffer_event),
}
}
@@ -358,20 +374,135 @@ impl CodegenAlternative {
let api_key = model.api_key(cx);
let telemetry_id = model.telemetry_id();
let provider_id = model.provider_id();
let stream: LocalBoxFuture<Result<LanguageModelTextStream>> =
if user_prompt.trim().to_lowercase() == "delete" {
async { Ok(LanguageModelTextStream::default()) }.boxed_local()
} else {
let request = self.build_request(&model, user_prompt, context_task, cx)?;
cx.spawn(async move |_, cx| {
Ok(model.stream_completion_text(request.await, cx).await?)
if cx.has_flag::<InlineAssistantV2FeatureFlag>() {
// if false {
// let request = self.build_request_v2(&model, user_prompt, context_task, cx)?;
let request = self.build_request(&model, user_prompt, context_task, cx)?;
let tool_use = cx
.spawn(async move |_, cx| {
Ok(model.stream_completion_tool(request.await, cx).await?)
})
.boxed_local()
};
self.handle_stream(telemetry_id, provider_id.to_string(), api_key, stream, cx);
// .boxed_local();
;
self.handle_tool_use(telemetry_id, provider_id.to_string(), api_key, tool_use, cx);
} else {
let stream: LocalBoxFuture<Result<LanguageModelTextStream>> =
if user_prompt.trim().to_lowercase() == "delete" {
async { Ok(LanguageModelTextStream::default()) }.boxed_local()
} else {
let request = self.build_request(&model, user_prompt, context_task, cx)?;
cx.spawn(async move |_, cx| {
Ok(model.stream_completion_text(request.await, cx).await?)
})
.boxed_local()
};
self.handle_stream(telemetry_id, provider_id.to_string(), api_key, stream, cx);
}
Ok(())
}
/// Builds a tool-based completion request for the inline assistant (v2 flow).
///
/// Unlike `build_request`, the resulting request asks the model to respond via
/// one of two tools — `rewrite_section` or `failure_message` — instead of
/// streaming raw replacement text.
///
/// # Errors
/// Fails if the transformation range spans more than one underlying buffer,
/// or if generating the content prompt fails.
fn build_request_v2(
    &self,
    model: &Arc<dyn LanguageModel>,
    user_prompt: String,
    context_task: Shared<Task<Option<LoadedContext>>>,
    cx: &mut App,
) -> Result<Task<LanguageModelRequest>> {
    let buffer = self.buffer.read(cx).snapshot(cx);
    let language = buffer.language_at(self.range.start);
    // Plain text is treated as "no language" so the prompt template can branch.
    let language_name = if let Some(language) = language.as_ref() {
        if Arc::ptr_eq(language, &language::PLAIN_TEXT) {
            None
        } else {
            Some(language.name())
        }
    } else {
        None
    };
    let language_name = language_name.as_ref();

    // Resolve the multibuffer range to a single underlying buffer; the prompt
    // can only describe a rewrite within one buffer.
    let start = buffer.point_to_buffer_offset(self.range.start);
    let end = buffer.point_to_buffer_offset(self.range.end);
    let (buffer, range) = if let Some((start, end)) = start.zip(end) {
        let (start_buffer, start_buffer_offset) = start;
        let (end_buffer, end_buffer_offset) = end;
        if start_buffer.remote_id() == end_buffer.remote_id() {
            (start_buffer.clone(), start_buffer_offset..end_buffer_offset)
        } else {
            anyhow::bail!("invalid transformation range");
        }
    } else {
        anyhow::bail!("invalid transformation range");
    };

    let system_prompt = self
        .builder
        .generate_inline_transformation_prompt_v2(
            language_name,
            buffer,
            range.start.0..range.end.0,
        )
        .context("generating content prompt")?;

    let temperature = AgentSettings::temperature_for_model(model, cx);
    let tool_input_format = model.tool_input_format();

    Ok(cx.spawn(async move |_cx| {
        let mut messages = vec![LanguageModelRequestMessage {
            role: Role::System,
            content: vec![system_prompt.into()],
            cache: false,
            reasoning_details: None,
        }];

        let mut user_message = LanguageModelRequestMessage {
            role: Role::User,
            content: Vec::new(),
            cache: false,
            reasoning_details: None,
        };
        // Attached context (if any) precedes the user's prompt in the message.
        if let Some(context) = context_task.await {
            context.add_to_request_message(&mut user_message);
        }
        user_message.content.push(user_prompt.into());
        messages.push(user_message);

        // The model must answer through one of these tools.
        let tools = vec![
            LanguageModelRequestTool {
                name: RewriteSectionTool::name().to_string(),
                description: RewriteSectionTool::description().to_string(),
                input_schema: RewriteSectionTool::input_schema(tool_input_format).to_value(),
            },
            LanguageModelRequestTool {
                name: FailureMessageTool::name().to_string(),
                description: FailureMessageTool::description().to_string(),
                input_schema: FailureMessageTool::input_schema(tool_input_format).to_value(),
            },
        ];

        LanguageModelRequest {
            thread_id: None,
            prompt_id: None,
            intent: Some(CompletionIntent::InlineAssist),
            mode: None,
            tools,
            tool_choice: None,
            stop: Vec::new(),
            temperature,
            messages,
            thinking_allowed: false,
        }
    }))
}
fn build_request(
&self,
model: &Arc<dyn LanguageModel>,
@@ -379,6 +510,10 @@ impl CodegenAlternative {
context_task: Shared<Task<Option<LoadedContext>>>,
cx: &mut App,
) -> Result<Task<LanguageModelRequest>> {
if cx.has_flag::<InlineAssistantV2FeatureFlag>() {
return self.build_request_v2(model, user_prompt, context_task, cx);
}
let buffer = self.buffer.read(cx).snapshot(cx);
let language = buffer.language_at(self.range.start);
let language_name = if let Some(language) = language.as_ref() {
@@ -415,6 +550,7 @@ impl CodegenAlternative {
range.start.0..range.end.0,
)
.context("generating content prompt")?;
dbg!(&prompt);
let temperature = AgentSettings::temperature_for_model(model, cx);
@@ -508,8 +644,16 @@ impl CodegenAlternative {
let completion = Arc::new(Mutex::new(String::new()));
let completion_clone = completion.clone();
dbg!("AAA 0");
self.generation = cx.spawn(async move |codegen, cx| {
dbg!("AAA 1");
let stream = stream.await;
dbg!("AAA 2");
// use futures::stream::StreamExt;
// let all_chunks: Vec<Result<String, _>> = stream.unwrap().collect().await;
// dbg!(&all_chunks);
let token_usage = stream
.as_ref()
.ok()
@@ -518,16 +662,19 @@ impl CodegenAlternative {
.as_ref()
.ok()
.and_then(|stream| stream.message_id.clone());
dbg!("AAA 3");
let generate = async {
let model_telemetry_id = model_telemetry_id.clone();
let model_provider_id = model_provider_id.clone();
let (mut diff_tx, mut diff_rx) = mpsc::channel(1);
let executor = cx.background_executor().clone();
let message_id = message_id.clone();
dbg!("AAA 4");
let line_based_stream_diff: Task<anyhow::Result<()>> =
cx.background_spawn(async move {
let mut response_latency = None;
let request_start = Instant::now();
dbg!("AAA 5");
let diff = async {
let chunks = StripInvalidSpans::new(
stream?.stream.map_err(|error| error.into()),
@@ -541,11 +688,14 @@ impl CodegenAlternative {
let mut line_indent = None;
let mut first_line = true;
dbg!("AAA 6");
while let Some(chunk) = chunks.next().await {
dbg!("AAA 7");
if response_latency.is_none() {
response_latency = Some(request_start.elapsed());
}
let chunk = chunk?;
dbg!(&chunk);
completion_clone.lock().push_str(&chunk);
let mut lines = chunk.split('\n').peekable();
@@ -719,6 +869,7 @@ impl CodegenAlternative {
output_tokens = usage.output_tokens,
)
}
cx.emit(CodegenEvent::Finished);
cx.notify();
})
@@ -898,6 +1049,126 @@ impl CodegenAlternative {
.ok();
})
}
fn handle_tool_use(
&mut self,
_telemetry_id: String,
_provider_id: String,
_api_key: Option<String>,
tool_use: impl 'static
+ Future<
Output = Result<language_model::LanguageModelToolUse, LanguageModelCompletionError>,
>,
cx: &mut Context<Self>,
) {
self.diff = Diff::default();
self.status = CodegenStatus::Pending;
self.generation = cx.spawn(async move |codegen, cx| {
let tool_use = tool_use.await;
dbg!(&tool_use);
match tool_use {
Ok(tool_use) if tool_use.name.as_ref() == "rewrite_section" => {
eprintln!("Received tool use: {:?}", tool_use);
// Parse the input JSON into RewriteSectionInput
match serde_json::from_value::<RewriteSectionInput>(tool_use.input) {
Ok(input) => {
eprintln!("Description: {}", input.description);
eprintln!("Replacement text length: {}", input.replacement_text.len());
// Store the description if non-empty
let description = if !input.description.trim().is_empty() {
Some(input.description.clone())
} else {
None
};
// Apply the replacement text to the buffer and compute diff
let batch_diff_task = codegen
.update(cx, |this, cx| {
this.tool_description = description;
let range = this.range.clone();
this.apply_edits(
std::iter::once((range, input.replacement_text)),
cx,
);
this.reapply_batch_diff(cx)
})
.ok();
// Wait for the diff computation to complete
if let Some(diff_task) = batch_diff_task {
diff_task.await;
}
let _ = codegen.update(cx, |this, cx| {
this.status = CodegenStatus::Done;
cx.emit(CodegenEvent::Finished);
cx.notify();
});
return;
}
Err(e) => {
eprintln!("Failed to parse RewriteSectionInput: {:?}", e);
let _ = codegen.update(cx, |this, cx| {
this.status = CodegenStatus::Error(e.into());
cx.emit(CodegenEvent::Finished);
cx.notify();
});
return;
}
}
}
Ok(tool_use) if tool_use.name.as_ref() == "failure_message" => {
// Handle failure message tool use
match serde_json::from_value::<FailureMessageInput>(tool_use.input) {
Ok(input) => {
eprintln!("Failure message: {}", input.message);
let _ = codegen.update(cx, |this, cx| {
// Store the failure message as the tool description
this.tool_description = Some(input.message);
this.status = CodegenStatus::Done;
cx.emit(CodegenEvent::Finished);
cx.notify();
});
return;
}
Err(e) => {
eprintln!("Failed to parse FailureMessageInput: {:?}", e);
let _ = codegen.update(cx, |this, cx| {
this.status = CodegenStatus::Error(e.into());
cx.emit(CodegenEvent::Finished);
cx.notify();
});
return;
}
}
}
Ok(tool_use) => {
eprintln!("Unexpected tool {}", tool_use.name);
let _ = codegen.update(cx, |this, cx| {
this.status = CodegenStatus::Done;
cx.emit(CodegenEvent::Finished);
});
return;
}
Err(e) => {
eprintln!("Failed to get tool use: {:?}", e);
let _ = codegen.update(cx, |this, cx| {
this.status = CodegenStatus::Error(e.into());
cx.emit(CodegenEvent::Finished);
cx.notify();
});
return;
}
}
});
cx.notify();
}
}
#[derive(Copy, Clone, Debug)]

View File

@@ -0,0 +1,88 @@
use std::{str::FromStr, sync::Arc};
// use std::sync::Arc;
use crate::inline_assistant::test::run_inline_assistant_test;
use eval_utils::EvalOutput;
use gpui::TestAppContext;
use language_model::{LanguageModelRegistry, SelectedModel};
use rand::{SeedableRng as _, rngs::StdRng};
// Eval, not a normal unit test: drives the inline assistant against a real
// model, so it is ignored unless the `unit-eval` feature is enabled.
#[test]
#[cfg_attr(not(feature = "unit-eval"), ignore)]
fn eval_single_cursor_edit() {
    // Arguments: 1 iteration, required pass ratio 1.0, mismatched-tag threshold 0.0.
    eval_utils::eval(
        1,
        1.0,
        0.0,
        Arc::new(|tx| {
            run_eval(
                &EvalInput {
                    prompt: "Rename this variable to buffer_text".to_string(),
                    // `ˇ` marks the cursor position inside the fixture buffer.
                    text: indoc::indoc! {"
struct EvalExampleStruct {
text: Strˇing,
prompt: String,
}
"}
                    .to_string(),
                },
                tx,
                // Judge: the assistant must have renamed the field `text`
                // to `buffer_text`, leaving everything else untouched.
                &|_, output| {
                    EvalOutput::assert(
                        format!("Failed to rename variable, output: {}", output),
                        output
                            == indoc::indoc! {"
struct EvalExampleStruct {
buffer_text: String,
prompt: String,
}
"},
                    )
                },
            );
        }),
    );
}
// Inputs for a single inline-assistant eval case.
struct EvalInput {
    // Fixture buffer contents (`ˇ` appears to mark the cursor position —
    // confirm against `run_inline_assistant_test`).
    text: String,
    // The user's instruction to the inline assistant.
    prompt: String,
}
// Runs one inline-assistant eval on a fresh test app context and reports the
// judged outcome on `tx`.
fn run_eval(
    input: &EvalInput,
    tx: std::sync::mpsc::Sender<eval_utils::EvalOutput>,
    judge: &dyn Fn(&EvalInput, &str) -> eval_utils::EvalOutput,
) {
    // Each eval run gets its own randomized dispatcher and app context.
    let dispatcher = gpui::TestDispatcher::new(StdRng::from_os_rng());
    let mut cx = TestAppContext::build(dispatcher, None);
    let buffer_text = run_inline_assistant_test(
        input.text.clone(),
        input.prompt.clone(),
        |cx| {
            // Reconfigure to use a real model instead of the fake one.
            // The model can be overridden with ZED_AGENT_MODEL, formatted as
            // "provider/model-id".
            let model_name = std::env::var("ZED_AGENT_MODEL")
                .unwrap_or("anthropic/claude-sonnet-4-latest".into());
            let selected_model = SelectedModel::from_str(&model_name)
                .expect("Invalid model format. Use 'provider/model-id'");
            log::info!("Selected model: {selected_model:?}");
            cx.update(|_, cx| {
                LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
                    registry.select_inline_assistant_model(Some(&selected_model), cx);
                });
            });
        },
        |_cx| {
            log::info!("Waiting for actual response from the LLM...");
        },
        &mut cx,
    );
    // Judge the final buffer text; ignore send errors if the receiver is gone.
    let output = judge(input, &buffer_text);
    tx.send(output).ok();
}

View File

@@ -32,7 +32,7 @@ use editor::{
},
};
use fs::Fs;
use futures::FutureExt;
use futures::{FutureExt, channel::mpsc};
use gpui::{
App, Context, Entity, Focusable, Global, HighlightStyle, Subscription, Task, UpdateGlobal,
WeakEntity, Window, point,
@@ -102,6 +102,7 @@ pub struct InlineAssistant {
prompt_builder: Arc<PromptBuilder>,
telemetry: Arc<Telemetry>,
fs: Arc<dyn Fs>,
_inline_assistant_completions: Option<mpsc::UnboundedSender<anyhow::Result<InlineAssistId>>>,
}
impl Global for InlineAssistant {}
@@ -123,9 +124,18 @@ impl InlineAssistant {
prompt_builder,
telemetry,
fs,
_inline_assistant_completions: None,
}
}
#[cfg(any(test, feature = "test-support"))]
pub fn set_completion_receiver(
&mut self,
sender: mpsc::UnboundedSender<anyhow::Result<InlineAssistId>>,
) {
self._inline_assistant_completions = Some(sender);
}
pub fn register_workspace(
&mut self,
workspace: &Entity<Workspace>,
@@ -287,7 +297,7 @@ impl InlineAssistant {
action.prompt.clone(),
window,
cx,
)
);
})
}
InlineAssistTarget::Terminal(active_terminal) => {
@@ -301,8 +311,8 @@ impl InlineAssistant {
action.prompt.clone(),
window,
cx,
)
})
);
});
}
};
@@ -377,17 +387,9 @@ impl InlineAssistant {
let mut selections = Vec::<Selection<Point>>::new();
let mut newest_selection = None;
for mut selection in initial_selections {
if selection.end > selection.start {
selection.start.column = 0;
// If the selection ends at the start of the line, we don't want to include it.
if selection.end.column == 0 {
selection.end.row -= 1;
}
selection.end.column = snapshot
.buffer_snapshot()
.line_len(MultiBufferRow(selection.end.row));
} else if let Some(fold) =
snapshot.crease_for_buffer_row(MultiBufferRow(selection.end.row))
if selection.end == selection.start
&& let Some(fold) =
snapshot.crease_for_buffer_row(MultiBufferRow(selection.end.row))
{
selection.start = fold.range().start;
selection.end = fold.range().end;
@@ -414,6 +416,15 @@ impl InlineAssistant {
}
}
}
} else {
selection.start.column = 0;
// If the selection ends at the start of the line, we don't want to include it.
if selection.end.column == 0 && selection.start.row != selection.end.row {
selection.end.row -= 1;
}
selection.end.column = snapshot
.buffer_snapshot()
.line_len(MultiBufferRow(selection.end.row));
}
if let Some(prev_selection) = selections.last_mut()
@@ -534,14 +545,15 @@ impl InlineAssistant {
}
}
let [prompt_block_id, end_block_id] =
self.insert_assist_blocks(editor, &range, &prompt_editor, cx);
let [prompt_block_id, tool_description_block_id, end_block_id] =
self.insert_assist_blocks(&editor, &range, &prompt_editor, cx);
assists.push((
assist_id,
range.clone(),
prompt_editor,
prompt_block_id,
tool_description_block_id,
end_block_id,
));
}
@@ -560,7 +572,15 @@ impl InlineAssistant {
};
let mut assist_group = InlineAssistGroup::new();
for (assist_id, range, prompt_editor, prompt_block_id, end_block_id) in assists {
for (
assist_id,
range,
prompt_editor,
prompt_block_id,
tool_description_block_id,
end_block_id,
) in assists
{
let codegen = prompt_editor.read(cx).codegen().clone();
self.assists.insert(
@@ -571,6 +591,7 @@ impl InlineAssistant {
editor,
&prompt_editor,
prompt_block_id,
tool_description_block_id,
end_block_id,
range,
codegen,
@@ -598,13 +619,13 @@ impl InlineAssistant {
initial_prompt: Option<String>,
window: &mut Window,
cx: &mut App,
) {
) -> Option<InlineAssistId> {
let snapshot = editor.update(cx, |editor, cx| editor.snapshot(window, cx));
let Some((codegen_ranges, newest_selection)) =
self.codegen_ranges(editor, &snapshot, window, cx)
else {
return;
return None;
};
let assist_to_focus = self.batch_assist(
@@ -624,6 +645,8 @@ impl InlineAssistant {
if let Some(assist_id) = assist_to_focus {
self.focus_assist(assist_id, window, cx);
}
assist_to_focus
}
pub fn suggest_assist(
@@ -677,7 +700,7 @@ impl InlineAssistant {
range: &Range<Anchor>,
prompt_editor: &Entity<PromptEditor<BufferCodegen>>,
cx: &mut App,
) -> [CustomBlockId; 2] {
) -> [CustomBlockId; 3] {
let prompt_editor_height = prompt_editor.update(cx, |prompt_editor, cx| {
prompt_editor
.editor
@@ -691,6 +714,14 @@ impl InlineAssistant {
render: build_assist_editor_renderer(prompt_editor),
priority: 0,
},
// Placeholder for tool description - will be updated dynamically
BlockProperties {
style: BlockStyle::Flex,
placement: BlockPlacement::Below(range.end),
height: Some(0),
render: Arc::new(|_cx| div().into_any_element()),
priority: 0,
},
BlockProperties {
style: BlockStyle::Sticky,
placement: BlockPlacement::Below(range.end),
@@ -709,7 +740,7 @@ impl InlineAssistant {
editor.update(cx, |editor, cx| {
let block_ids = editor.insert_blocks(assist_blocks, None, cx);
[block_ids[0], block_ids[1]]
[block_ids[0], block_ids[1], block_ids[2]]
})
}
@@ -1101,6 +1132,9 @@ impl InlineAssistant {
let mut to_remove = decorations.removed_line_block_ids;
to_remove.insert(decorations.prompt_block_id);
to_remove.insert(decorations.end_block_id);
if let Some(tool_description_block_id) = decorations.tool_description_block_id {
to_remove.insert(tool_description_block_id);
}
editor.remove_blocks(to_remove, None, cx);
});
@@ -1421,8 +1455,60 @@ impl InlineAssistant {
let old_snapshot = codegen.snapshot(cx);
let old_buffer = codegen.old_buffer(cx);
let deleted_row_ranges = codegen.diff(cx).deleted_row_ranges.clone();
let tool_description = codegen.tool_description(cx).map(|s| s.to_string());
editor.update(cx, |editor, cx| {
// Update tool description block
if let Some(description) = tool_description {
if let Some(block_id) = decorations.tool_description_block_id {
editor.remove_blocks(HashSet::from_iter([block_id]), None, cx);
let new_block_id = editor.insert_blocks(
[BlockProperties {
style: BlockStyle::Flex,
placement: BlockPlacement::Below(assist.range.end),
height: Some(1),
render: Arc::new({
let description = description.clone();
move |cx| {
div()
.w_full()
.py_1()
.px_2()
.bg(cx.theme().colors().editor_background)
.border_y_1()
.border_color(cx.theme().status().info_border)
.child(
Label::new(description.clone())
.color(Color::Muted)
.size(LabelSize::Small),
)
.into_any_element()
}
}),
priority: 0,
}],
None,
cx,
);
decorations.tool_description_block_id = new_block_id.into_iter().next();
}
} else if let Some(block_id) = decorations.tool_description_block_id {
// Hide the block if there's no description
editor.remove_blocks(HashSet::from_iter([block_id]), None, cx);
let new_block_id = editor.insert_blocks(
[BlockProperties {
style: BlockStyle::Flex,
placement: BlockPlacement::Below(assist.range.end),
height: Some(0),
render: Arc::new(|_cx| div().into_any_element()),
priority: 0,
}],
None,
cx,
);
decorations.tool_description_block_id = new_block_id.into_iter().next();
}
let old_blocks = mem::take(&mut decorations.removed_line_block_ids);
editor.remove_blocks(old_blocks, None, cx);
@@ -1674,6 +1760,7 @@ impl InlineAssist {
editor: &Entity<Editor>,
prompt_editor: &Entity<PromptEditor<BufferCodegen>>,
prompt_block_id: CustomBlockId,
tool_description_block_id: CustomBlockId,
end_block_id: CustomBlockId,
range: Range<Anchor>,
codegen: Entity<BufferCodegen>,
@@ -1688,7 +1775,8 @@ impl InlineAssist {
decorations: Some(InlineAssistDecorations {
prompt_block_id,
prompt_editor: prompt_editor.clone(),
removed_line_block_ids: HashSet::default(),
removed_line_block_ids: Default::default(),
tool_description_block_id: Some(tool_description_block_id),
end_block_id,
}),
range,
@@ -1740,6 +1828,16 @@ impl InlineAssist {
&& assist.decorations.is_none()
&& let Some(workspace) = assist.workspace.upgrade()
{
#[cfg(any(test, feature = "test-support"))]
if let Some(sender) = &mut this._inline_assistant_completions {
sender
.unbounded_send(Err(anyhow::anyhow!(
"Inline assistant error: {}",
error
)))
.ok();
}
let error = format!("Inline assistant error: {}", error);
workspace.update(cx, |workspace, cx| {
struct InlineAssistantError;
@@ -1750,6 +1848,11 @@ impl InlineAssist {
workspace.show_toast(Toast::new(id, error), cx);
})
} else {
#[cfg(any(test, feature = "test-support"))]
if let Some(sender) = &mut this._inline_assistant_completions {
sender.unbounded_send(Ok(assist_id)).ok();
}
}
if assist.decorations.is_none() {
@@ -1777,6 +1880,7 @@ struct InlineAssistDecorations {
prompt_block_id: CustomBlockId,
prompt_editor: Entity<PromptEditor<BufferCodegen>>,
removed_line_block_ids: HashSet<CustomBlockId>,
tool_description_block_id: Option<CustomBlockId>,
end_block_id: CustomBlockId,
}
@@ -1943,3 +2047,159 @@ fn merge_ranges(ranges: &mut Vec<Range<Anchor>>, buffer: &MultiBufferSnapshot) {
}
}
}
#[cfg(any(test, feature = "test-support"))]
pub mod test {
    //! Harness for exercising the inline assistant end-to-end against a
    //! headless test workspace.

    use std::sync::Arc;

    use agent::HistoryStore;
    use assistant_text_thread::TextThreadStore;
    use client::{Client, UserStore};
    use editor::{Editor, MultiBuffer, MultiBufferOffset};
    use fs::FakeFs;
    use futures::channel::mpsc;
    use gpui::{AppContext, TestAppContext, UpdateGlobal as _};
    use language::Buffer;
    use language_model::LanguageModelRegistry;
    use project::Project;
    use prompt_store::PromptBuilder;
    use smol::stream::StreamExt as _;
    use util::test::marked_text_ranges;
    use workspace::Workspace;

    use crate::InlineAssistant;

    /// Runs one complete inline-assistant session and returns the resulting
    /// buffer text.
    ///
    /// * `base_buffer` — marked text; the markers become the editor's
    ///   selections (see `marked_text_ranges`).
    /// * `prompt` — the prompt submitted to the assistant.
    /// * `setup` — runs before the assist is started (e.g. to install a fake
    ///   language model).
    /// * `test` — runs after the first `run_until_parked` (e.g. to stream a
    ///   fake completion to the assistant).
    ///
    /// Blocks until the assistant reports completion (or failure) through the
    /// channel installed via `set_completion_receiver`.
    pub fn run_inline_assistant_test<SetupF, TestF>(
        base_buffer: String,
        prompt: String,
        setup: SetupF,
        test: TestF,
        cx: &mut TestAppContext,
    ) -> String
    where
        SetupF: FnOnce(&mut gpui::VisualTestContext),
        TestF: FnOnce(&mut gpui::VisualTestContext),
    {
        let fs = FakeFs::new(cx.executor());
        let app_state = cx.update(|cx| workspace::AppState::test(cx));
        let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
        // NOTE(review): this constructs a real HTTP client and a production
        // `Client`; presumably no network traffic is actually issued when the
        // fake model is installed — confirm.
        let http = Arc::new(reqwest_client::ReqwestClient::user_agent("agent tests").unwrap());
        let client = cx.update(|cx| {
            cx.set_http_client(http);
            Client::production(cx)
        });
        let mut inline_assistant =
            InlineAssistant::new(fs.clone(), prompt_builder, client.telemetry().clone());
        // Channel over which the assistant reports per-assist completion.
        let (tx, mut completion_rx) = mpsc::unbounded();
        inline_assistant.set_completion_receiver(tx);
        // Initialize settings, client state, and global singletons.
        cx.update(|cx| {
            gpui_tokio::init(cx);
            settings::init(cx);
            client::init(&client, cx);
            workspace::init(app_state.clone(), cx);
            let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
            language_model::init(client.clone(), cx);
            language_models::init(user_store, client.clone(), cx);
            cx.set_global(inline_assistant);
        });
        let project = cx
            .executor()
            .block_test(async { Project::test(fs.clone(), [], cx).await });
        // Create workspace with window
        let (workspace, cx) = cx.add_window_view(|window, cx| {
            window.activate_window();
            Workspace::new(None, project.clone(), app_state.clone(), window, cx)
        });
        setup(cx);
        let (_editor, buffer) = cx.update(|window, cx| {
            let buffer = cx.new(|cx| Buffer::local("", cx));
            let multibuffer = cx.new(|cx| MultiBuffer::singleton(buffer.clone(), cx));
            let editor = cx.new(|cx| Editor::for_multibuffer(multibuffer, None, window, cx));
            editor.update(cx, |editor, cx| {
                // Apply the marked text: set the buffer contents and select
                // the marked ranges.
                let (unmarked_text, selection_ranges) = marked_text_ranges(&base_buffer, true);
                editor.set_text(unmarked_text, window, cx);
                editor.change_selections(Default::default(), window, cx, |s| {
                    s.select_ranges(
                        selection_ranges.into_iter().map(|range| {
                            MultiBufferOffset(range.start)..MultiBufferOffset(range.end)
                        }),
                    )
                })
            });
            let text_thread_store = cx.new(|cx| TextThreadStore::fake(project.clone(), cx));
            let history_store = cx.new(|cx| HistoryStore::new(text_thread_store, cx));
            // Add editor to workspace
            workspace.update(cx, |workspace, cx| {
                workspace.add_item_to_active_pane(Box::new(editor.clone()), None, true, window, cx);
            });
            // Kick off an assist with the provided prompt and start it.
            InlineAssistant::update_global(cx, |inline_assistant, cx| {
                let assist_id = inline_assistant
                    .assist(
                        &editor,
                        workspace.downgrade(),
                        project.downgrade(),
                        history_store, // history (thread) store
                        None,          // prompt_store
                        Some(prompt),
                        window,
                        cx,
                    )
                    .unwrap();
                inline_assistant.start_assist(assist_id, window, cx);
            });
            (editor, buffer)
        });
        cx.run_until_parked();
        test(cx);
        // Wait for the assistant to report success or failure.
        cx.executor()
            .block_test(async { completion_rx.next().await });
        buffer.read_with(cx, |buffer, _| buffer.text())
    }

    /// Convenience wrapper: installs the fake model registry, runs the
    /// assistant, and streams `llm_output` as the model's entire reply.
    #[allow(unused)]
    fn test_inline_assistant(
        base_buffer: &'static str,
        llm_output: &'static str, // TODO: support multiple chunks (Vec<&'static str>)
        cx: &mut TestAppContext,
    ) -> String {
        run_inline_assistant_test(
            base_buffer.to_string(),
            "Prompt doesn't matter because we're using a fake model".to_string(),
            |cx| {
                cx.update(|_, cx| LanguageModelRegistry::test(cx));
            },
            |cx| {
                let fake_model = cx.update(|_, cx| {
                    LanguageModelRegistry::global(cx)
                        .update(cx, |registry, _| registry.fake_model())
                });
                let fake = fake_model.as_fake();
                fake.send_last_completion_stream_text_chunk(llm_output.to_string());
                fake.end_last_completion_stream();
                // Run again to process the model's response
                cx.run_until_parked();
            },
            cx,
        )
    }
}

View File

@@ -0,0 +1,18 @@
[package]
name = "eval_utils"
version = "0.1.0"
edition.workspace = true
publish.workspace = true
license = "GPL-3.0-or-later"
[lints]
workspace = true
[lib]
path = "src/eval_utils.rs"
doctest = false
[dependencies]
gpui.workspace = true
serde.workspace = true
smol.workspace = true

View File

@@ -0,0 +1 @@
LICENSE-GPL

View File

@@ -0,0 +1,3 @@
# eval_utils
Utilities for running evaluations ("evals") of agents.

View File

@@ -0,0 +1,148 @@
//! Utilities for evaluation and benchmarking.
use std::{
collections::HashMap,
io::Write as _,
sync::{Arc, mpsc},
};
/// Overwrites the current terminal line with a running pass-rate summary.
fn report_progress(evaluated_count: usize, failed_count: usize, iterations: usize) {
    // Fraction of completed runs that passed; defined as 0 before any run.
    let passed_ratio = match evaluated_count {
        0 => 0.0,
        completed => (completed - failed_count) as f64 / completed as f64,
    };
    // "\r\x1b[K" returns the cursor to column 0 and clears the line so each
    // update replaces the previous one in place.
    print!(
        "\r\x1b[KEvaluated {}/{} ({:.2}% passed)",
        evaluated_count,
        iterations,
        passed_ratio * 100.0
    );
    std::io::stdout().flush().unwrap();
}
/// The result category of a single eval run.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum OutcomeKind {
    /// The run completed and its check held.
    Passed,
    /// The run completed but its check did not hold.
    Failed,
    /// The run could not be carried out to completion.
    Error,
}
/// The result of a single eval run, reported over the channel passed to the
/// eval closure.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct EvalOutput {
    /// Human-readable detail: failure diagnostics, or the error message for
    /// errored runs.
    pub data: String,
    /// Number of mismatched tags this run produced.
    pub mismatched_tags: usize,
    /// Total number of tags this run produced.
    pub tags: usize,
    /// Whether the run passed, failed, or errored out.
    pub outcome_kind: OutcomeKind,
}
impl EvalOutput {
    // TODO! Evaluate this API in relation to the original evals
    /// Builds an output for a simple pass/fail assertion.
    ///
    /// On a pass, `failure_data` is discarded and an empty-data `Passed`
    /// outcome is returned; on a fail, `failure_data` becomes the output's
    /// `data` with a `Failed` outcome. Tag counts are always zero.
    pub fn assert(failure_data: String, pass: bool) -> Self {
        let (data, outcome_kind) = if pass {
            (String::new(), OutcomeKind::Passed)
        } else {
            (failure_data, OutcomeKind::Failed)
        };
        EvalOutput {
            data,
            mismatched_tags: 0,
            tags: 0,
            outcome_kind,
        }
    }
}
/// Runs `evalf` `iterations` times, at most 32 concurrently, collecting an
/// `EvalOutput` from each run over a channel and printing a running summary.
///
/// # Panics
///
/// Panics if the overall pass ratio falls below `expected_pass_ratio`, or if
/// the ratio of mismatched tags (summed over all completed runs) exceeds
/// `mismatched_tag_threshold`.
pub fn eval(
    iterations: usize,
    expected_pass_ratio: f32,
    mismatched_tag_threshold: f32,
    evalf: Arc<dyn Fn(mpsc::Sender<EvalOutput>) + Send + Sync>,
) {
    let mut evaluated_count = 0;
    let mut failed_count = 0;
    report_progress(evaluated_count, failed_count, iterations);
    let (tx, rx) = mpsc::channel();
    let executor = gpui::background_executor();
    // Bound the number of concurrently running evals.
    let semaphore = Arc::new(smol::lock::Semaphore::new(32));
    // Warm the cache once
    // NOTE(review): this synchronous warm-up call counts as the first of the
    // `iterations` runs; the loop below spawns only `iterations - 1` more.
    evalf(tx.clone());
    for _ in 1..iterations {
        let tx = tx.clone();
        let semaphore = semaphore.clone();
        let evalf = evalf.clone();
        executor
            .spawn(async move {
                let _guard = semaphore.acquire().await;
                evalf(tx);
            })
            .detach();
    }
    // Drop our sender so `rx` disconnects once every spawned eval finishes.
    drop(tx);
    let mut failed_evals = Vec::new();
    let mut errored_evals = HashMap::new();
    let mut eval_outputs = Vec::new();
    let mut cumulative_mismatched_tags = 0usize;
    let mut cumulative_tags = 0usize;
    while let Ok(output) = rx.recv() {
        // Tag statistics accumulate only for runs that actually completed
        // (passed or failed), not for errored runs.
        if matches!(
            output.outcome_kind,
            OutcomeKind::Passed | OutcomeKind::Failed
        ) {
            cumulative_mismatched_tags += output.mismatched_tags;
            cumulative_tags += output.tags;
            eval_outputs.push(output.clone());
        }
        match output.outcome_kind {
            OutcomeKind::Passed => {}
            OutcomeKind::Failed => {
                failed_count += 1;
                failed_evals.push(output);
            }
            OutcomeKind::Error => {
                // Errors count as failures; identical error strings are
                // grouped and counted.
                failed_count += 1;
                *errored_evals.entry(output.data).or_insert(0) += 1;
            }
        }
        evaluated_count += 1;
        report_progress(evaluated_count, failed_count, iterations);
    }
    let actual_pass_ratio = (iterations - failed_count) as f32 / iterations as f32;
    println!("Actual pass ratio: {}\n", actual_pass_ratio);
    if actual_pass_ratio < expected_pass_ratio {
        for (error, count) in errored_evals {
            println!("Eval errored {} times. Error: {}", count, error);
        }
        for failed in failed_evals {
            println!("Eval failed");
            println!("{}", failed.data);
        }
        panic!(
            "Actual pass ratio: {}\nExpected pass ratio: {}",
            actual_pass_ratio, expected_pass_ratio
        );
    }
    // NOTE(review): when `cumulative_tags` is 0 this is 0/0 == NaN, and the
    // comparison below is then false, so the threshold check is skipped —
    // confirm that is intended.
    let mismatched_tag_ratio = cumulative_mismatched_tags as f32 / cumulative_tags as f32;
    if mismatched_tag_ratio > mismatched_tag_threshold {
        for eval_output in eval_outputs {
            println!("{}", eval_output.data);
        }
        panic!("Too many mismatched tags: {:?}", cumulative_mismatched_tags);
    }
}

View File

@@ -17,3 +17,9 @@ pub struct PanicFeatureFlag;
impl FeatureFlag for PanicFeatureFlag {
    const NAME: &'static str = "panic";
}

/// Feature flag gating the v2 inline assistant.
pub struct InlineAssistantV2FeatureFlag;

impl FeatureFlag for InlineAssistantV2FeatureFlag {
    const NAME: &'static str = "inline-assistant-v2";
}

View File

@@ -408,6 +408,7 @@ impl FakeHttpClient {
}
pub fn with_404_response() -> Arc<HttpClientWithUrl> {
log::warn!("Using fake HTTP client with 404 response");
Self::create(|_| async move {
Ok(Response::builder()
.status(404)
@@ -417,6 +418,7 @@ impl FakeHttpClient {
}
pub fn with_200_response() -> Arc<HttpClientWithUrl> {
log::warn!("Using fake HTTP client with 200 response");
Self::create(|_| async move {
Ok(Response::builder()
.status(200)

View File

@@ -670,6 +670,7 @@ pub trait LanguageModel: Send + Sync {
.chain(events.filter_map({
let last_token_usage = last_token_usage.clone();
move |result| {
dbg!(&result);
let last_token_usage = last_token_usage.clone();
async move {
match result {
@@ -707,6 +708,40 @@ pub trait LanguageModel: Send + Sync {
.boxed()
}
/// Default implementation: runs a normal completion and resolves with the
/// first tool use whose input has fully streamed in.
///
/// Errors if the underlying stream fails, or if it ends without ever
/// yielding a complete tool use.
fn stream_completion_tool(
    &self,
    request: LanguageModelRequest,
    cx: &AsyncApp,
) -> BoxFuture<'static, Result<LanguageModelToolUse, LanguageModelCompletionError>> {
    let completion = self.stream_completion(request, cx);
    async move {
        let mut stream = completion.await?.fuse();
        while let Some(event) = stream.next().await {
            // `?` propagates stream errors immediately; every event other
            // than a tool use with complete input is ignored.
            if let LanguageModelCompletionEvent::ToolUse(tool_use) = event? {
                if tool_use.is_input_complete {
                    return Ok(tool_use);
                }
            }
        }
        // Stream ended without a complete tool use
        Err(LanguageModelCompletionError::Other(anyhow::anyhow!(
            "Stream ended without receiving a complete tool use"
        )))
    }
    .boxed()
}
fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
None
}

View File

@@ -135,6 +135,11 @@ impl LanguageModelRegistry {
fake_provider
}
/// Returns the registry's current default model — in tests, presumably the
/// model registered by the fake provider (confirm against `fake_provider`).
///
/// # Panics
///
/// Panics if no default model has been configured yet.
#[cfg(any(test, feature = "test-support"))]
pub fn fake_model(&self) -> Arc<dyn LanguageModel> {
    self.default_model
        .as_ref()
        .expect("fake_model called before a default model was configured")
        .model
        .clone()
}
pub fn register_provider<T: LanguageModelProvider + LanguageModelProviderState>(
&mut self,
provider: Arc<T>,

View File

@@ -94,6 +94,16 @@ pub struct ContentPromptContext {
pub diagnostic_errors: Vec<ContentPromptDiagnosticContext>,
}
/// Template context for the v2 inline-transformation prompt, rendered via the
/// `content_prompt_v2` handlebars template.
#[derive(Serialize)]
pub struct ContentPromptContextV2 {
    /// Either "text" or "code", depending on the buffer's language.
    pub content_type: String,
    /// Name of the buffer's language, if known.
    pub language_name: Option<String>,
    /// True when surrounding context was truncated to fit the size budget.
    pub is_truncated: bool,
    /// The document text with the selection wrapped in insert/rewrite markers.
    pub document_content: String,
    /// The selected text on its own; `None` for insertions.
    pub rewrite_section: Option<String>,
    /// Diagnostics overlapping the selection.
    pub diagnostic_errors: Vec<ContentPromptDiagnosticContext>,
}
#[derive(Serialize)]
pub struct TerminalAssistantPromptContext {
pub os: String,
@@ -276,6 +286,88 @@ impl PromptBuilder {
Ok(())
}
/// Renders the v2 inline-transformation prompt for `range` within `buffer`.
///
/// The document is emitted with the selection wrapped in markers the model is
/// expected to respect: `<insert_here></insert_here>` when `range` is empty
/// (an insertion), `<rewrite_this>…</rewrite_this>` otherwise. Context on
/// each side of the selection is truncated to a fixed byte budget.
///
/// # Errors
///
/// Returns a `RenderError` if the `content_prompt_v2` template fails to
/// render.
pub fn generate_inline_transformation_prompt_v2(
    &self,
    language_name: Option<&LanguageName>,
    buffer: BufferSnapshot,
    range: Range<usize>,
) -> Result<String, RenderError> {
    // Markdown and plain text get "text" framing; everything else is "code".
    let content_type = match language_name.as_ref().map(|l| l.as_ref()) {
        None | Some("Markdown" | "Plain Text") => "text",
        Some(_) => "code",
    };

    // Maximum bytes of surrounding context kept on each side of the selection.
    const MAX_CTX: usize = 50000;
    let is_insert = range.is_empty();
    let mut is_truncated = false;

    // Clamp the context before the selection to at most MAX_CTX bytes,
    // snapping the cut point to a valid offset in the buffer. The subtraction
    // cannot underflow: it only runs when range.start > MAX_CTX.
    let before_range = 0..range.start;
    let truncated_before = if before_range.len() > MAX_CTX {
        is_truncated = true;
        let start = buffer.clip_offset(range.start - MAX_CTX, text::Bias::Right);
        start..range.start
    } else {
        before_range
    };

    // Same clamping for the context after the selection.
    let after_range = range.end..buffer.len();
    let truncated_after = if after_range.len() > MAX_CTX {
        is_truncated = true;
        let end = buffer.clip_offset(range.end + MAX_CTX, text::Bias::Left);
        range.end..end
    } else {
        after_range
    };

    // Assemble the document with the selection wrapped in the appropriate
    // markers. `String: Extend<&str>` appends each chunk without extra
    // intermediate allocations.
    let mut document_content = String::new();
    document_content.extend(buffer.text_for_range(truncated_before));
    if is_insert {
        document_content.push_str("<insert_here></insert_here>");
    } else {
        document_content.push_str("<rewrite_this>\n");
        document_content.extend(buffer.text_for_range(range.clone()));
        document_content.push_str("\n</rewrite_this>");
    }
    document_content.extend(buffer.text_for_range(truncated_after));

    // For rewrites, also pass the selected text on its own.
    let rewrite_section =
        (!is_insert).then(|| buffer.text_for_range(range.clone()).collect::<String>());

    // Include diagnostics overlapping the selection so the model can address
    // them (rows are 0-based; report 1-based line numbers).
    let diagnostic_errors: Vec<ContentPromptDiagnosticContext> = buffer
        .diagnostics_in_range::<_, Point>(range, false)
        .map(|entry| {
            let start = entry.range.start;
            ContentPromptDiagnosticContext {
                line_number: (start.row + 1) as usize,
                error_message: entry.diagnostic.message.clone(),
                code_content: buffer.text_for_range(entry.range).collect(),
            }
        })
        .collect();

    let context = ContentPromptContextV2 {
        content_type: content_type.to_string(),
        language_name: language_name.map(|s| s.to_string()),
        is_truncated,
        document_content,
        rewrite_section,
        diagnostic_errors,
    };
    self.handlebars.lock().render("content_prompt_v2", &context)
}
pub fn generate_inline_transformation_prompt(
&self,
user_prompt: String,