Compare commits
32 Commits
codex
...
test-drive
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c240f876b1 | ||
|
|
b076ff99ef | ||
|
|
36c173e3e2 | ||
|
|
8573b3a84b | ||
|
|
5e70235794 | ||
|
|
9e9192f6a3 | ||
|
|
936972d9b0 | ||
|
|
e9533423db | ||
|
|
ba480295c1 | ||
|
|
9106f4495b | ||
|
|
1feb1296fe | ||
|
|
582a247922 | ||
|
|
c2881a4537 | ||
|
|
b4744750da | ||
|
|
6edc255158 | ||
|
|
a96a1b1339 | ||
|
|
73cee468ed | ||
|
|
1f06615da2 | ||
|
|
c1773f7281 | ||
|
|
6c6b1ba3bc | ||
|
|
a23d9328ce | ||
|
|
5796a2663b | ||
|
|
447eb8e1c9 | ||
|
|
e434117018 | ||
|
|
36271b79b3 | ||
|
|
41644a53cc | ||
|
|
08a9c4af09 | ||
|
|
3187f28405 | ||
|
|
101f3b100f | ||
|
|
39c8b7bf5f | ||
|
|
08b41252f6 | ||
|
|
152bbca238 |
33
Cargo.lock
generated
33
Cargo.lock
generated
@@ -107,6 +107,39 @@ dependencies = [
|
||||
"zstd",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "agent2"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"assistant_tool",
|
||||
"assistant_tools",
|
||||
"chrono",
|
||||
"client",
|
||||
"collections",
|
||||
"ctor",
|
||||
"env_logger 0.11.8",
|
||||
"fs",
|
||||
"futures 0.3.31",
|
||||
"gpui",
|
||||
"gpui_tokio",
|
||||
"handlebars 4.5.0",
|
||||
"language_model",
|
||||
"language_models",
|
||||
"parking_lot",
|
||||
"project",
|
||||
"reqwest_client",
|
||||
"rust-embed",
|
||||
"schemars",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"settings",
|
||||
"smol",
|
||||
"thiserror 2.0.12",
|
||||
"util",
|
||||
"worktree",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "agent_settings"
|
||||
version = "0.1.0"
|
||||
|
||||
@@ -4,6 +4,7 @@ members = [
|
||||
"crates/activity_indicator",
|
||||
"crates/agent_ui",
|
||||
"crates/agent",
|
||||
"crates/agent2",
|
||||
"crates/agent_settings",
|
||||
"crates/anthropic",
|
||||
"crates/askpass",
|
||||
|
||||
@@ -111,7 +111,7 @@ mod tests {
|
||||
use assistant_tool::ToolRegistry;
|
||||
use collections::IndexMap;
|
||||
use gpui::SharedString;
|
||||
use gpui::{AppContext, TestAppContext};
|
||||
use gpui::TestAppContext;
|
||||
use http_client::FakeHttpClient;
|
||||
use project::Project;
|
||||
use settings::{Settings, SettingsStore};
|
||||
|
||||
49
crates/agent2/Cargo.toml
Normal file
49
crates/agent2/Cargo.toml
Normal file
@@ -0,0 +1,49 @@
|
||||
[package]
|
||||
name = "agent2"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
license = "GPL-3.0-or-later"
|
||||
publish = false
|
||||
|
||||
[lib]
|
||||
path = "src/agent2.rs"
|
||||
|
||||
[lints]
|
||||
workspace = true
|
||||
|
||||
[dependencies]
|
||||
anyhow.workspace = true
|
||||
assistant_tool.workspace = true
|
||||
assistant_tools.workspace = true
|
||||
chrono.workspace = true
|
||||
collections.workspace = true
|
||||
fs.workspace = true
|
||||
futures.workspace = true
|
||||
gpui.workspace = true
|
||||
handlebars = { workspace = true, features = ["rust-embed"] }
|
||||
language_model.workspace = true
|
||||
language_models.workspace = true
|
||||
parking_lot.workspace = true
|
||||
project.workspace = true
|
||||
rust-embed.workspace = true
|
||||
schemars.workspace = true
|
||||
serde.workspace = true
|
||||
serde_json.workspace = true
|
||||
settings.workspace = true
|
||||
smol.workspace = true
|
||||
thiserror.workspace = true
|
||||
util.workspace = true
|
||||
worktree.workspace = true
|
||||
|
||||
[dev-dependencies]
|
||||
ctor.workspace = true
|
||||
client = { workspace = true, "features" = ["test-support"] }
|
||||
env_logger.workspace = true
|
||||
fs = { workspace = true, "features" = ["test-support"] }
|
||||
gpui = { workspace = true, "features" = ["test-support"] }
|
||||
gpui_tokio.workspace = true
|
||||
language_model = { workspace = true, "features" = ["test-support"] }
|
||||
project = { workspace = true, "features" = ["test-support"] }
|
||||
reqwest_client.workspace = true
|
||||
settings = { workspace = true, "features" = ["test-support"] }
|
||||
worktree = { workspace = true, "features" = ["test-support"] }
|
||||
1
crates/agent2/LICENSE-GPL
Symbolic link
1
crates/agent2/LICENSE-GPL
Symbolic link
@@ -0,0 +1 @@
|
||||
../../LICENSE-GPL
|
||||
6
crates/agent2/src/agent2.rs
Normal file
6
crates/agent2/src/agent2.rs
Normal file
@@ -0,0 +1,6 @@
|
||||
mod prompts;
|
||||
mod templates;
|
||||
mod thread;
|
||||
mod tools;
|
||||
|
||||
pub use thread::*;
|
||||
29
crates/agent2/src/prompts.rs
Normal file
29
crates/agent2/src/prompts.rs
Normal file
@@ -0,0 +1,29 @@
|
||||
use crate::{
|
||||
templates::{BaseTemplate, Template, Templates, WorktreeData},
|
||||
thread::Prompt,
|
||||
};
|
||||
use anyhow::Result;
|
||||
use gpui::{App, Entity};
|
||||
use project::Project;
|
||||
|
||||
struct BasePrompt {
|
||||
project: Entity<Project>,
|
||||
}
|
||||
|
||||
impl Prompt for BasePrompt {
|
||||
fn render(&self, templates: &Templates, cx: &App) -> Result<String> {
|
||||
BaseTemplate {
|
||||
os: std::env::consts::OS.to_string(),
|
||||
shell: util::get_system_shell(),
|
||||
worktrees: self
|
||||
.project
|
||||
.read(cx)
|
||||
.worktrees(cx)
|
||||
.map(|worktree| WorktreeData {
|
||||
root_name: worktree.read(cx).root_name().to_string(),
|
||||
})
|
||||
.collect(),
|
||||
}
|
||||
.render(templates)
|
||||
}
|
||||
}
|
||||
57
crates/agent2/src/templates.rs
Normal file
57
crates/agent2/src/templates.rs
Normal file
@@ -0,0 +1,57 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use anyhow::Result;
|
||||
use handlebars::Handlebars;
|
||||
use rust_embed::RustEmbed;
|
||||
use serde::Serialize;
|
||||
|
||||
#[derive(RustEmbed)]
|
||||
#[folder = "src/templates"]
|
||||
#[include = "*.hbs"]
|
||||
struct Assets;
|
||||
|
||||
pub struct Templates(Handlebars<'static>);
|
||||
|
||||
impl Templates {
|
||||
pub fn new() -> Arc<Self> {
|
||||
let mut handlebars = Handlebars::new();
|
||||
handlebars.register_embed_templates::<Assets>().unwrap();
|
||||
Arc::new(Self(handlebars))
|
||||
}
|
||||
}
|
||||
|
||||
pub trait Template: Sized {
|
||||
const TEMPLATE_NAME: &'static str;
|
||||
|
||||
fn render(&self, templates: &Templates) -> Result<String>
|
||||
where
|
||||
Self: Serialize + Sized,
|
||||
{
|
||||
Ok(templates.0.render(Self::TEMPLATE_NAME, self)?)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
pub struct BaseTemplate {
|
||||
pub os: String,
|
||||
pub shell: String,
|
||||
pub worktrees: Vec<WorktreeData>,
|
||||
}
|
||||
|
||||
impl Template for BaseTemplate {
|
||||
const TEMPLATE_NAME: &'static str = "base.hbs";
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
pub struct WorktreeData {
|
||||
pub root_name: String,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
pub struct GlobTemplate {
|
||||
pub project_roots: String,
|
||||
}
|
||||
|
||||
impl Template for GlobTemplate {
|
||||
const TEMPLATE_NAME: &'static str = "glob.hbs";
|
||||
}
|
||||
56
crates/agent2/src/templates/base.hbs
Normal file
56
crates/agent2/src/templates/base.hbs
Normal file
@@ -0,0 +1,56 @@
|
||||
You are a highly skilled software engineer with extensive knowledge in many programming languages, frameworks, design patterns, and best practices.
|
||||
|
||||
## Communication
|
||||
|
||||
1. Be conversational but professional.
|
||||
2. Refer to the USER in the second person and yourself in the first person.
|
||||
3. Format your responses in markdown. Use backticks to format file, directory, function, and class names.
|
||||
4. NEVER lie or make things up.
|
||||
5. Refrain from apologizing all the time when results are unexpected. Instead, just try your best to proceed or explain the circumstances to the user without apologizing.
|
||||
|
||||
## Tool Use
|
||||
|
||||
1. Make sure to adhere to the tools schema.
|
||||
2. Provide every required argument.
|
||||
3. DO NOT use tools to access items that are already available in the context section.
|
||||
4. Use only the tools that are currently available.
|
||||
5. DO NOT use a tool that is not available just because it appears in the conversation. This means the user turned it off.
|
||||
|
||||
## Searching and Reading
|
||||
|
||||
If you are unsure how to fulfill the user's request, gather more information with tool calls and/or clarifying questions.
|
||||
|
||||
If appropriate, use tool calls to explore the current project, which contains the following root directories:
|
||||
|
||||
{{#each worktrees}}
|
||||
- `{{root_name}}`
|
||||
{{/each}}
|
||||
|
||||
- When providing paths to tools, the path should always begin with a path that starts with a project root directory listed above.
|
||||
- When looking for symbols in the project, prefer the `grep` tool.
|
||||
- As you learn about the structure of the project, use that information to scope `grep` searches to targeted subtrees of the project.
|
||||
- Bias towards not asking the user for help if you can find the answer yourself.
|
||||
|
||||
## Fixing Diagnostics
|
||||
|
||||
1. Make 1-2 attempts at fixing diagnostics, then defer to the user.
|
||||
2. Never simplify code you've written just to solve diagnostics. Complete, mostly correct code is more valuable than perfect code that doesn't solve the problem.
|
||||
|
||||
## Debugging
|
||||
|
||||
When debugging, only make code changes if you are certain that you can solve the problem.
|
||||
Otherwise, follow debugging best practices:
|
||||
1. Address the root cause instead of the symptoms.
|
||||
2. Add descriptive logging statements and error messages to track variable and code state.
|
||||
3. Add test functions and statements to isolate the problem.
|
||||
|
||||
## Calling External APIs
|
||||
|
||||
1. Unless explicitly requested by the user, use the best suited external APIs and packages to solve the task. There is no need to ask the user for permission.
|
||||
2. When selecting which version of an API or package to use, choose one that is compatible with the user's dependency management file. If no such file exists or if the package is not present, use the latest version that is in your training data.
|
||||
3. If an external API requires an API Key, be sure to point this out to the user. Adhere to best security practices (e.g. DO NOT hardcode an API key in a place where it can be exposed)
|
||||
|
||||
## System Information
|
||||
|
||||
Operating System: {{os}}
|
||||
Default Shell: {{shell}}
|
||||
8
crates/agent2/src/templates/glob.hbs
Normal file
8
crates/agent2/src/templates/glob.hbs
Normal file
@@ -0,0 +1,8 @@
|
||||
Find paths on disk with glob patterns.
|
||||
|
||||
Assume that all glob patterns are matched in a project directory with the following entries.
|
||||
|
||||
{{project_roots}}
|
||||
|
||||
When searching with patterns that begin with literal path components, e.g. `foo/bar/**/*.rs`, be
|
||||
sure to anchor them with one of the directories listed above.
|
||||
420
crates/agent2/src/thread.rs
Normal file
420
crates/agent2/src/thread.rs
Normal file
@@ -0,0 +1,420 @@
|
||||
use crate::templates::Templates;
|
||||
use anyhow::{anyhow, Result};
|
||||
use futures::{channel::mpsc, future};
|
||||
use gpui::{App, Context, SharedString, Task};
|
||||
use language_model::{
|
||||
CompletionIntent, CompletionMode, LanguageModel, LanguageModelCompletionError,
|
||||
LanguageModelCompletionEvent, LanguageModelRequest, LanguageModelRequestMessage,
|
||||
LanguageModelRequestTool, LanguageModelToolResult, LanguageModelToolResultContent,
|
||||
LanguageModelToolSchemaFormat, LanguageModelToolUse, MessageContent, Role, StopReason,
|
||||
};
|
||||
use schemars::{JsonSchema, Schema};
|
||||
use serde::Deserialize;
|
||||
use smol::stream::StreamExt;
|
||||
use std::{collections::BTreeMap, sync::Arc};
|
||||
use util::ResultExt;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct AgentMessage {
|
||||
pub role: Role,
|
||||
pub content: Vec<MessageContent>,
|
||||
}
|
||||
|
||||
pub type AgentResponseEvent = LanguageModelCompletionEvent;
|
||||
|
||||
pub trait Prompt {
|
||||
fn render(&self, prompts: &Templates, cx: &App) -> Result<String>;
|
||||
}
|
||||
|
||||
pub struct Thread {
|
||||
messages: Vec<AgentMessage>,
|
||||
completion_mode: CompletionMode,
|
||||
/// Holds the task that handles agent interaction until the end of the turn.
|
||||
/// Survives across multiple requests as the model performs tool calls and
|
||||
/// we run tools, report their results.
|
||||
running_turn: Option<Task<()>>,
|
||||
system_prompts: Vec<Arc<dyn Prompt>>,
|
||||
tools: BTreeMap<SharedString, Arc<dyn AgentToolErased>>,
|
||||
templates: Arc<Templates>,
|
||||
// project: Entity<Project>,
|
||||
// action_log: Entity<ActionLog>,
|
||||
}
|
||||
|
||||
impl Thread {
|
||||
pub fn new(templates: Arc<Templates>) -> Self {
|
||||
Self {
|
||||
messages: Vec::new(),
|
||||
completion_mode: CompletionMode::Normal,
|
||||
system_prompts: Vec::new(),
|
||||
running_turn: None,
|
||||
tools: BTreeMap::default(),
|
||||
templates,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn set_mode(&mut self, mode: CompletionMode) {
|
||||
self.completion_mode = mode;
|
||||
}
|
||||
|
||||
pub fn messages(&self) -> &[AgentMessage] {
|
||||
&self.messages
|
||||
}
|
||||
|
||||
pub fn add_tool(&mut self, tool: impl AgentTool) {
|
||||
self.tools.insert(tool.name(), tool.erase());
|
||||
}
|
||||
|
||||
pub fn remove_tool(&mut self, name: &str) -> bool {
|
||||
self.tools.remove(name).is_some()
|
||||
}
|
||||
|
||||
/// Sending a message results in the model streaming a response, which could include tool calls.
|
||||
/// After calling tools, the model will stops and waits for any outstanding tool calls to be completed and their results sent.
|
||||
/// The returned channel will report all the occurrences in which the model stops before erroring or ending its turn.
|
||||
pub fn send(
|
||||
&mut self,
|
||||
model: Arc<dyn LanguageModel>,
|
||||
content: impl Into<MessageContent>,
|
||||
cx: &mut Context<Self>,
|
||||
) -> mpsc::UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>> {
|
||||
cx.notify();
|
||||
let (events_tx, events_rx) =
|
||||
mpsc::unbounded::<Result<AgentResponseEvent, LanguageModelCompletionError>>();
|
||||
|
||||
let system_message = self.build_system_message(cx);
|
||||
self.messages.extend(system_message);
|
||||
|
||||
self.messages.push(AgentMessage {
|
||||
role: Role::User,
|
||||
content: vec![content.into()],
|
||||
});
|
||||
self.running_turn = Some(cx.spawn(async move |thread, cx| {
|
||||
let turn_result = async {
|
||||
// Perform one request, then keep looping if the model makes tool calls.
|
||||
let mut completion_intent = CompletionIntent::UserPrompt;
|
||||
loop {
|
||||
let request = thread.update(cx, |thread, cx| {
|
||||
thread.build_completion_request(completion_intent, cx)
|
||||
})?;
|
||||
|
||||
// println!(
|
||||
// "request: {}",
|
||||
// serde_json::to_string_pretty(&request).unwrap()
|
||||
// );
|
||||
|
||||
// Stream events, appending to messages and collecting up tool uses.
|
||||
let mut events = model.stream_completion(request, cx).await?;
|
||||
let mut tool_uses = Vec::new();
|
||||
while let Some(event) = events.next().await {
|
||||
match event {
|
||||
Ok(event) => {
|
||||
thread
|
||||
.update(cx, |thread, cx| {
|
||||
tool_uses.extend(thread.handle_streamed_completion_event(
|
||||
event,
|
||||
events_tx.clone(),
|
||||
cx,
|
||||
));
|
||||
})
|
||||
.ok();
|
||||
}
|
||||
Err(error) => {
|
||||
events_tx.unbounded_send(Err(error)).ok();
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// If there are no tool uses, the turn is done.
|
||||
if tool_uses.is_empty() {
|
||||
break;
|
||||
}
|
||||
|
||||
// If there are tool uses, wait for their results to be
|
||||
// computed, then send them together in a single message on
|
||||
// the next loop iteration.
|
||||
let tool_results = future::join_all(tool_uses).await;
|
||||
thread
|
||||
.update(cx, |thread, _cx| {
|
||||
thread.messages.push(AgentMessage {
|
||||
role: Role::User,
|
||||
content: tool_results.into_iter().map(Into::into).collect(),
|
||||
});
|
||||
})
|
||||
.ok();
|
||||
completion_intent = CompletionIntent::ToolResults;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
.await;
|
||||
|
||||
if let Err(error) = turn_result {
|
||||
events_tx.unbounded_send(Err(error)).ok();
|
||||
}
|
||||
}));
|
||||
events_rx
|
||||
}
|
||||
|
||||
pub fn build_system_message(&mut self, cx: &App) -> Option<AgentMessage> {
|
||||
let mut system_message = AgentMessage {
|
||||
role: Role::System,
|
||||
content: Vec::new(),
|
||||
};
|
||||
|
||||
for prompt in &self.system_prompts {
|
||||
if let Some(rendered_prompt) = prompt.render(&self.templates, cx).log_err() {
|
||||
system_message
|
||||
.content
|
||||
.push(MessageContent::Text(rendered_prompt));
|
||||
}
|
||||
}
|
||||
|
||||
(!system_message.content.is_empty()).then_some(system_message)
|
||||
}
|
||||
|
||||
/// A helper method that's called on every streamed completion event.
|
||||
/// Returns an optional tool result task, which the main agentic loop in
|
||||
/// send will send back to the model when it resolves.
|
||||
fn handle_streamed_completion_event(
|
||||
&mut self,
|
||||
event: LanguageModelCompletionEvent,
|
||||
events_tx: mpsc::UnboundedSender<Result<AgentResponseEvent, LanguageModelCompletionError>>,
|
||||
cx: &mut Context<Self>,
|
||||
) -> Option<Task<LanguageModelToolResult>> {
|
||||
use LanguageModelCompletionEvent::*;
|
||||
events_tx.unbounded_send(Ok(event.clone())).ok();
|
||||
|
||||
match event {
|
||||
Text(new_text) => self.handle_text_event(new_text, cx),
|
||||
Thinking { text, signature } => {
|
||||
todo!()
|
||||
}
|
||||
ToolUse(tool_use) => {
|
||||
return self.handle_tool_use_event(tool_use, cx);
|
||||
}
|
||||
StartMessage { role, .. } => {
|
||||
self.messages.push(AgentMessage {
|
||||
role,
|
||||
content: Vec::new(),
|
||||
});
|
||||
}
|
||||
UsageUpdate(_) => {}
|
||||
Stop(stop_reason) => self.handle_stop_event(stop_reason),
|
||||
StatusUpdate(_completion_request_status) => {}
|
||||
RedactedThinking { data } => todo!(),
|
||||
ToolUseJsonParseError {
|
||||
id,
|
||||
tool_name,
|
||||
raw_input,
|
||||
json_parse_error,
|
||||
} => todo!(),
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
|
||||
fn handle_stop_event(&mut self, stop_reason: StopReason) {
|
||||
match stop_reason {
|
||||
StopReason::EndTurn | StopReason::ToolUse => {}
|
||||
StopReason::MaxTokens => todo!(),
|
||||
StopReason::Refusal => todo!(),
|
||||
}
|
||||
}
|
||||
|
||||
fn handle_text_event(&mut self, new_text: String, cx: &mut Context<Self>) {
|
||||
let last_message = self.last_assistant_message();
|
||||
if let Some(MessageContent::Text(text)) = last_message.content.last_mut() {
|
||||
text.push_str(&new_text);
|
||||
} else {
|
||||
last_message.content.push(MessageContent::Text(new_text));
|
||||
}
|
||||
|
||||
cx.notify();
|
||||
}
|
||||
|
||||
fn handle_tool_use_event(
|
||||
&mut self,
|
||||
tool_use: LanguageModelToolUse,
|
||||
cx: &mut Context<Self>,
|
||||
) -> Option<Task<LanguageModelToolResult>> {
|
||||
cx.notify();
|
||||
|
||||
let last_message = self.last_assistant_message();
|
||||
|
||||
// Ensure the last message ends in the current tool use
|
||||
let push_new_tool_use = last_message.content.last_mut().map_or(true, |content| {
|
||||
if let MessageContent::ToolUse(last_tool_use) = content {
|
||||
if last_tool_use.id == tool_use.id {
|
||||
*last_tool_use = tool_use.clone();
|
||||
false
|
||||
} else {
|
||||
true
|
||||
}
|
||||
} else {
|
||||
true
|
||||
}
|
||||
});
|
||||
if push_new_tool_use {
|
||||
last_message.content.push(tool_use.clone().into());
|
||||
}
|
||||
|
||||
if !tool_use.is_input_complete {
|
||||
return None;
|
||||
}
|
||||
|
||||
if let Some(tool) = self.tools.get(tool_use.name.as_ref()) {
|
||||
let pending_tool_result = tool.clone().run(tool_use.input, cx);
|
||||
|
||||
Some(cx.foreground_executor().spawn(async move {
|
||||
match pending_tool_result.await {
|
||||
Ok(tool_output) => LanguageModelToolResult {
|
||||
tool_use_id: tool_use.id,
|
||||
tool_name: tool_use.name,
|
||||
is_error: false,
|
||||
content: LanguageModelToolResultContent::Text(Arc::from(tool_output)),
|
||||
output: None,
|
||||
},
|
||||
Err(error) => LanguageModelToolResult {
|
||||
tool_use_id: tool_use.id,
|
||||
tool_name: tool_use.name,
|
||||
is_error: true,
|
||||
content: LanguageModelToolResultContent::Text(Arc::from(error.to_string())),
|
||||
output: None,
|
||||
},
|
||||
}
|
||||
}))
|
||||
} else {
|
||||
Some(Task::ready(LanguageModelToolResult {
|
||||
content: LanguageModelToolResultContent::Text(Arc::from(format!(
|
||||
"No tool named {} exists",
|
||||
tool_use.name
|
||||
))),
|
||||
tool_use_id: tool_use.id,
|
||||
tool_name: tool_use.name,
|
||||
is_error: true,
|
||||
output: None,
|
||||
}))
|
||||
}
|
||||
}
|
||||
|
||||
/// Guarantees the last message is from the assistant and returns a mutable reference.
|
||||
fn last_assistant_message(&mut self) -> &mut AgentMessage {
|
||||
if self
|
||||
.messages
|
||||
.last()
|
||||
.map_or(true, |m| m.role != Role::Assistant)
|
||||
{
|
||||
self.messages.push(AgentMessage {
|
||||
role: Role::Assistant,
|
||||
content: Vec::new(),
|
||||
});
|
||||
}
|
||||
self.messages.last_mut().unwrap()
|
||||
}
|
||||
|
||||
fn build_completion_request(
|
||||
&self,
|
||||
completion_intent: CompletionIntent,
|
||||
cx: &mut App,
|
||||
) -> LanguageModelRequest {
|
||||
LanguageModelRequest {
|
||||
thread_id: None,
|
||||
prompt_id: None,
|
||||
intent: Some(completion_intent),
|
||||
mode: Some(self.completion_mode),
|
||||
messages: self.build_request_messages(),
|
||||
tools: self
|
||||
.tools
|
||||
.values()
|
||||
.filter_map(|tool| {
|
||||
Some(LanguageModelRequestTool {
|
||||
name: tool.name().to_string(),
|
||||
description: tool.description(cx).to_string(),
|
||||
input_schema: tool
|
||||
.input_schema(LanguageModelToolSchemaFormat::JsonSchema)
|
||||
.log_err()?,
|
||||
})
|
||||
})
|
||||
.collect(),
|
||||
tool_choice: None,
|
||||
stop: Vec::new(),
|
||||
temperature: None,
|
||||
}
|
||||
}
|
||||
|
||||
fn build_request_messages(&self) -> Vec<LanguageModelRequestMessage> {
|
||||
self.messages
|
||||
.iter()
|
||||
.map(|message| LanguageModelRequestMessage {
|
||||
role: message.role,
|
||||
content: message.content.clone(),
|
||||
cache: false,
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
}
|
||||
|
||||
pub trait AgentTool
|
||||
where
|
||||
Self: 'static + Sized,
|
||||
{
|
||||
type Input: for<'de> Deserialize<'de> + JsonSchema;
|
||||
|
||||
fn name(&self) -> SharedString;
|
||||
fn description(&self, _cx: &mut App) -> SharedString {
|
||||
let schema = schemars::schema_for!(Self::Input);
|
||||
SharedString::new(
|
||||
schema
|
||||
.get("description")
|
||||
.and_then(|description| description.as_str())
|
||||
.unwrap_or_default(),
|
||||
)
|
||||
}
|
||||
|
||||
/// Returns the JSON schema that describes the tool's input.
|
||||
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Schema {
|
||||
assistant_tools::root_schema_for::<Self::Input>(format)
|
||||
}
|
||||
|
||||
/// Runs the tool with the provided input.
|
||||
fn run(self: Arc<Self>, input: Self::Input, cx: &mut App) -> Task<Result<String>>;
|
||||
|
||||
fn erase(self) -> Arc<dyn AgentToolErased> {
|
||||
Arc::new(Erased(Arc::new(self)))
|
||||
}
|
||||
}
|
||||
|
||||
pub struct Erased<T>(T);
|
||||
|
||||
pub trait AgentToolErased {
|
||||
fn name(&self) -> SharedString;
|
||||
fn description(&self, cx: &mut App) -> SharedString;
|
||||
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value>;
|
||||
fn run(self: Arc<Self>, input: serde_json::Value, cx: &mut App) -> Task<Result<String>>;
|
||||
}
|
||||
|
||||
impl<T> AgentToolErased for Erased<Arc<T>>
|
||||
where
|
||||
T: AgentTool,
|
||||
{
|
||||
fn name(&self) -> SharedString {
|
||||
self.0.name()
|
||||
}
|
||||
|
||||
fn description(&self, cx: &mut App) -> SharedString {
|
||||
self.0.description(cx)
|
||||
}
|
||||
|
||||
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
|
||||
Ok(serde_json::to_value(self.0.input_schema(format))?)
|
||||
}
|
||||
|
||||
fn run(self: Arc<Self>, input: serde_json::Value, cx: &mut App) -> Task<Result<String>> {
|
||||
let parsed_input: Result<T::Input> = serde_json::from_value(input).map_err(Into::into);
|
||||
match parsed_input {
|
||||
Ok(input) => self.0.clone().run(input, cx),
|
||||
Err(error) => Task::ready(Err(anyhow!(error))),
|
||||
}
|
||||
}
|
||||
}
|
||||
254
crates/agent2/src/thread/tests.rs
Normal file
254
crates/agent2/src/thread/tests.rs
Normal file
@@ -0,0 +1,254 @@
|
||||
use super::*;
|
||||
use client::{proto::language_server_prompt_request, Client, UserStore};
|
||||
use fs::FakeFs;
|
||||
use gpui::{AppContext, Entity, TestAppContext};
|
||||
use language_model::{
|
||||
LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
|
||||
LanguageModelRegistry, MessageContent, StopReason,
|
||||
};
|
||||
use reqwest_client::ReqwestClient;
|
||||
use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use smol::stream::StreamExt;
|
||||
use std::{sync::Arc, time::Duration};
|
||||
|
||||
mod test_tools;
|
||||
use test_tools::*;
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_echo(cx: &mut TestAppContext) {
|
||||
let AgentTest { model, agent, .. } = setup(cx).await;
|
||||
|
||||
let events = agent
|
||||
.update(cx, |agent, cx| {
|
||||
agent.send(model.clone(), "Testing: Reply with 'Hello'", cx)
|
||||
})
|
||||
.collect()
|
||||
.await;
|
||||
agent.update(cx, |agent, _cx| {
|
||||
assert_eq!(
|
||||
agent.messages.last().unwrap().content,
|
||||
vec![MessageContent::Text("Hello".to_string())]
|
||||
);
|
||||
});
|
||||
assert_eq!(stop_events(events), vec![StopReason::EndTurn]);
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_basic_tool_calls(cx: &mut TestAppContext) {
|
||||
let AgentTest { model, agent, .. } = setup(cx).await;
|
||||
|
||||
// Test a tool call that's likely to complete *before* streaming stops.
|
||||
let events = agent
|
||||
.update(cx, |agent, cx| {
|
||||
agent.add_tool(EchoTool);
|
||||
agent.send(
|
||||
model.clone(),
|
||||
"Now test the echo tool with 'Hello'. Does it work? Say 'Yes' or 'No'.",
|
||||
cx,
|
||||
)
|
||||
})
|
||||
.collect()
|
||||
.await;
|
||||
assert_eq!(
|
||||
stop_events(events),
|
||||
vec![StopReason::ToolUse, StopReason::EndTurn]
|
||||
);
|
||||
|
||||
// Test a tool calls that's likely to complete *after* streaming stops.
|
||||
let events = agent
|
||||
.update(cx, |agent, cx| {
|
||||
agent.remove_tool(&AgentTool::name(&EchoTool));
|
||||
agent.add_tool(DelayTool);
|
||||
agent.send(
|
||||
model.clone(),
|
||||
"Now call the delay tool with 200ms. When the timer goes off, then you echo the output of the tool.",
|
||||
cx,
|
||||
)
|
||||
})
|
||||
.collect()
|
||||
.await;
|
||||
assert_eq!(
|
||||
stop_events(events),
|
||||
vec![StopReason::ToolUse, StopReason::EndTurn]
|
||||
);
|
||||
agent.update(cx, |agent, _cx| {
|
||||
assert!(agent
|
||||
.messages
|
||||
.last()
|
||||
.unwrap()
|
||||
.content
|
||||
.iter()
|
||||
.any(|content| {
|
||||
if let MessageContent::Text(text) = content {
|
||||
text.contains("Ding")
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}));
|
||||
});
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_streaming_tool_calls(cx: &mut TestAppContext) {
|
||||
let AgentTest { model, agent, .. } = setup(cx).await;
|
||||
|
||||
// Test a tool call that's likely to complete *before* streaming stops.
|
||||
let mut events = agent.update(cx, |agent, cx| {
|
||||
agent.add_tool(WordListTool);
|
||||
agent.send(model.clone(), "Test the word_list tool.", cx)
|
||||
});
|
||||
|
||||
let mut saw_partial_tool_use = false;
|
||||
while let Some(event) = events.next().await {
|
||||
if let Ok(LanguageModelCompletionEvent::ToolUse(tool_use_event)) = event {
|
||||
agent.update(cx, |agent, _cx| {
|
||||
// Look for a tool use in the agent's last message
|
||||
let last_content = agent.messages().last().unwrap().content.last().unwrap();
|
||||
if let MessageContent::ToolUse(last_tool_use) = last_content {
|
||||
assert_eq!(last_tool_use.name.as_ref(), "word_list");
|
||||
if tool_use_event.is_input_complete {
|
||||
last_tool_use
|
||||
.input
|
||||
.get("a")
|
||||
.expect("'a' has streamed because input is now complete");
|
||||
last_tool_use
|
||||
.input
|
||||
.get("g")
|
||||
.expect("'g' has streamed because input is now complete");
|
||||
} else {
|
||||
if !last_tool_use.is_input_complete
|
||||
&& last_tool_use.input.get("g").is_none()
|
||||
{
|
||||
saw_partial_tool_use = true;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
panic!("last content should be a tool use");
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
assert!(
|
||||
saw_partial_tool_use,
|
||||
"should see at least one partially streamed tool use in the history"
|
||||
);
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_concurrent_tool_calls(cx: &mut TestAppContext) {
|
||||
let AgentTest { model, agent, .. } = setup(cx).await;
|
||||
|
||||
// Test concurrent tool calls with different delay times
|
||||
let events = agent
|
||||
.update(cx, |agent, cx| {
|
||||
agent.add_tool(DelayTool);
|
||||
agent.send(
|
||||
model.clone(),
|
||||
"Call the delay tool twice in the same message. Once with 100ms. Once with 300ms. When both timers are complete, describe the outputs.",
|
||||
cx,
|
||||
)
|
||||
})
|
||||
.collect()
|
||||
.await;
|
||||
|
||||
let stop_reasons = stop_events(events);
|
||||
if stop_reasons.len() == 2 {
|
||||
assert_eq!(stop_reasons, vec![StopReason::ToolUse, StopReason::EndTurn]);
|
||||
} else if stop_reasons.len() == 3 {
|
||||
assert_eq!(
|
||||
stop_reasons,
|
||||
vec![
|
||||
StopReason::ToolUse,
|
||||
StopReason::ToolUse,
|
||||
StopReason::EndTurn
|
||||
]
|
||||
);
|
||||
} else {
|
||||
panic!("Expected either 1 or 2 tool uses followed by end turn");
|
||||
}
|
||||
|
||||
agent.update(cx, |agent, _cx| {
|
||||
let last_message = agent.messages.last().unwrap();
|
||||
let text = last_message
|
||||
.content
|
||||
.iter()
|
||||
.filter_map(|content| {
|
||||
if let MessageContent::Text(text) = content {
|
||||
Some(text.as_str())
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.collect::<String>();
|
||||
|
||||
assert!(text.contains("Ding"));
|
||||
});
|
||||
}
|
||||
|
||||
/// Filters out the stop events for asserting against in tests
|
||||
fn stop_events(
|
||||
result_events: Vec<Result<AgentResponseEvent, LanguageModelCompletionError>>,
|
||||
) -> Vec<StopReason> {
|
||||
result_events
|
||||
.into_iter()
|
||||
.filter_map(|event| match event.unwrap() {
|
||||
LanguageModelCompletionEvent::Stop(stop_reason) => Some(stop_reason),
|
||||
_ => None,
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
struct AgentTest {
|
||||
model: Arc<dyn LanguageModel>,
|
||||
agent: Entity<Thread>,
|
||||
}
|
||||
|
||||
async fn setup(cx: &mut TestAppContext) -> AgentTest {
|
||||
cx.executor().allow_parking();
|
||||
cx.update(settings::init);
|
||||
let fs = FakeFs::new(cx.executor().clone());
|
||||
// let project = Project::test(fs.clone(), [], cx).await;
|
||||
// let action_log = cx.new(|_| ActionLog::new(project.clone()));
|
||||
let templates = Templates::new();
|
||||
let agent = cx.new(|_| Thread::new(templates));
|
||||
|
||||
let model = cx
|
||||
.update(|cx| {
|
||||
gpui_tokio::init(cx);
|
||||
let http_client = ReqwestClient::user_agent("agent tests").unwrap();
|
||||
cx.set_http_client(Arc::new(http_client));
|
||||
|
||||
client::init_settings(cx);
|
||||
let client = Client::production(cx);
|
||||
let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
|
||||
language_model::init(client.clone(), cx);
|
||||
language_models::init(user_store.clone(), client.clone(), cx);
|
||||
|
||||
let models = LanguageModelRegistry::read_global(cx);
|
||||
let model = models
|
||||
.available_models(cx)
|
||||
.find(|model| model.id().0 == "claude-3-7-sonnet-latest")
|
||||
.unwrap();
|
||||
|
||||
let provider = models.provider(&model.provider_id()).unwrap();
|
||||
let authenticated = provider.authenticate(cx);
|
||||
|
||||
cx.spawn(async move |cx| {
|
||||
authenticated.await.unwrap();
|
||||
model
|
||||
})
|
||||
})
|
||||
.await;
|
||||
|
||||
AgentTest { model, agent }
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
#[ctor::ctor]
|
||||
fn init_logger() {
|
||||
if std::env::var("RUST_LOG").is_ok() {
|
||||
env_logger::init();
|
||||
}
|
||||
}
|
||||
83
crates/agent2/src/thread/tests/test_tools.rs
Normal file
83
crates/agent2/src/thread/tests/test_tools.rs
Normal file
@@ -0,0 +1,83 @@
|
||||
use super::*;
|
||||
|
||||
/// A tool that echoes its input
|
||||
#[derive(JsonSchema, Serialize, Deserialize)]
|
||||
pub struct EchoToolInput {
|
||||
/// The text to echo.
|
||||
text: String,
|
||||
}
|
||||
|
||||
pub struct EchoTool;
|
||||
|
||||
impl AgentTool for EchoTool {
|
||||
type Input = EchoToolInput;
|
||||
|
||||
fn name(&self) -> SharedString {
|
||||
"echo".into()
|
||||
}
|
||||
|
||||
fn run(self: Arc<Self>, input: Self::Input, _cx: &mut App) -> Task<Result<String>> {
|
||||
Task::ready(Ok(input.text))
|
||||
}
|
||||
}
|
||||
|
||||
/// A tool that waits for a specified delay
|
||||
#[derive(JsonSchema, Serialize, Deserialize)]
|
||||
pub struct DelayToolInput {
|
||||
/// The delay in milliseconds.
|
||||
ms: u64,
|
||||
}
|
||||
|
||||
pub struct DelayTool;
|
||||
|
||||
impl AgentTool for DelayTool {
|
||||
type Input = DelayToolInput;
|
||||
|
||||
fn name(&self) -> SharedString {
|
||||
"delay".into()
|
||||
}
|
||||
|
||||
fn run(self: Arc<Self>, input: Self::Input, cx: &mut App) -> Task<Result<String>>
|
||||
where
|
||||
Self: Sized,
|
||||
{
|
||||
cx.foreground_executor().spawn(async move {
|
||||
smol::Timer::after(Duration::from_millis(input.ms)).await;
|
||||
Ok("Ding".to_string())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// A tool that takes an object with map from letters to random words starting with that letter.
|
||||
/// All fiealds are required! Pass a word for every letter!
|
||||
#[derive(JsonSchema, Serialize, Deserialize)]
|
||||
pub struct WordListInput {
|
||||
/// Provide a random word that starts with A.
|
||||
a: Option<String>,
|
||||
/// Provide a random word that starts with B.
|
||||
b: Option<String>,
|
||||
/// Provide a random word that starts with C.
|
||||
c: Option<String>,
|
||||
/// Provide a random word that starts with D.
|
||||
d: Option<String>,
|
||||
/// Provide a random word that starts with E.
|
||||
e: Option<String>,
|
||||
/// Provide a random word that starts with F.
|
||||
f: Option<String>,
|
||||
/// Provide a random word that starts with G.
|
||||
g: Option<String>,
|
||||
}
|
||||
|
||||
pub struct WordListTool;
|
||||
|
||||
impl AgentTool for WordListTool {
|
||||
type Input = WordListInput;
|
||||
|
||||
fn name(&self) -> SharedString {
|
||||
"word_list".into()
|
||||
}
|
||||
|
||||
fn run(self: Arc<Self>, _input: Self::Input, _cx: &mut App) -> Task<Result<String>> {
|
||||
Task::ready(Ok("ok".to_string()))
|
||||
}
|
||||
}
|
||||
1
crates/agent2/src/tools.rs
Normal file
1
crates/agent2/src/tools.rs
Normal file
@@ -0,0 +1 @@
|
||||
mod glob;
|
||||
76
crates/agent2/src/tools/glob.rs
Normal file
76
crates/agent2/src/tools/glob.rs
Normal file
@@ -0,0 +1,76 @@
|
||||
use anyhow::{anyhow, Result};
|
||||
use gpui::{App, AppContext, Entity, SharedString, Task};
|
||||
use project::Project;
|
||||
use schemars::JsonSchema;
|
||||
use serde::Deserialize;
|
||||
use std::{path::PathBuf, sync::Arc};
|
||||
use util::paths::PathMatcher;
|
||||
use worktree::Snapshot as WorktreeSnapshot;
|
||||
|
||||
use crate::{
|
||||
templates::{GlobTemplate, Template, Templates},
|
||||
thread::AgentTool,
|
||||
};
|
||||
|
||||
// Description is dynamic, see `fn description` below
|
||||
#[derive(Deserialize, JsonSchema)]
|
||||
struct GlobInput {
|
||||
/// A POSIX glob pattern
|
||||
glob: SharedString,
|
||||
}
|
||||
|
||||
struct GlobTool {
|
||||
project: Entity<Project>,
|
||||
templates: Arc<Templates>,
|
||||
}
|
||||
|
||||
impl AgentTool for GlobTool {
|
||||
type Input = GlobInput;
|
||||
|
||||
fn name(&self) -> SharedString {
|
||||
"glob".into()
|
||||
}
|
||||
|
||||
fn description(&self, cx: &mut App) -> SharedString {
|
||||
let project_roots = self
|
||||
.project
|
||||
.read(cx)
|
||||
.worktrees(cx)
|
||||
.map(|worktree| worktree.read(cx).root_name().into())
|
||||
.collect::<Vec<String>>()
|
||||
.join("\n");
|
||||
|
||||
GlobTemplate { project_roots }
|
||||
.render(&self.templates)
|
||||
.expect("template failed to render")
|
||||
.into()
|
||||
}
|
||||
|
||||
fn run(self: Arc<Self>, input: Self::Input, cx: &mut App) -> Task<Result<String>> {
|
||||
let path_matcher = match PathMatcher::new([&input.glob]) {
|
||||
Ok(matcher) => matcher,
|
||||
Err(error) => return Task::ready(Err(anyhow!(error))),
|
||||
};
|
||||
|
||||
let snapshots: Vec<WorktreeSnapshot> = self
|
||||
.project
|
||||
.read(cx)
|
||||
.worktrees(cx)
|
||||
.map(|worktree| worktree.read(cx).snapshot())
|
||||
.collect();
|
||||
|
||||
cx.background_spawn(async move {
|
||||
let paths = snapshots.iter().flat_map(|snapshot| {
|
||||
let root_name = PathBuf::from(snapshot.root_name());
|
||||
snapshot
|
||||
.entries(false, 0)
|
||||
.map(move |entry| root_name.join(&entry.path))
|
||||
.filter(|path| path_matcher.is_match(&path))
|
||||
});
|
||||
let output = paths
|
||||
.map(|path| format!("{}\n", path.display()))
|
||||
.collect::<String>();
|
||||
Ok(output)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -46,6 +46,7 @@ pub use find_path_tool::FindPathToolInput;
|
||||
pub use grep_tool::{GrepTool, GrepToolInput};
|
||||
pub use open_tool::OpenTool;
|
||||
pub use read_file_tool::{ReadFileTool, ReadFileToolInput};
|
||||
pub use schema::root_schema_for;
|
||||
pub use terminal_tool::TerminalTool;
|
||||
|
||||
pub fn init(http_client: Arc<HttpClientWithUrl>, cx: &mut App) {
|
||||
|
||||
@@ -0,0 +1,328 @@
|
||||
use crate::commit::get_messages;
|
||||
use crate::{GitRemote, Oid};
|
||||
use anyhow::{Context as _, Result, anyhow};
|
||||
use collections::{HashMap, HashSet};
|
||||
use futures::AsyncWriteExt;
|
||||
use gpui::SharedString;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::process::Stdio;
|
||||
use std::{ops::Range, path::Path};
|
||||
use text::Rope;
|
||||
use time::OffsetDateTime;
|
||||
use time::UtcOffset;
|
||||
use time::macros::format_description;
|
||||
|
||||
pub use git2 as libgit;
|
||||
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct Blame {
|
||||
pub entries: Vec<BlameEntry>,
|
||||
pub messages: HashMap<Oid, String>,
|
||||
pub remote_url: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Default)]
|
||||
pub struct ParsedCommitMessage {
|
||||
pub message: SharedString,
|
||||
pub permalink: Option<url::Url>,
|
||||
pub pull_request: Option<crate::hosting_provider::PullRequest>,
|
||||
pub remote: Option<GitRemote>,
|
||||
}
|
||||
|
||||
impl Blame {
|
||||
pub async fn for_path(
|
||||
git_binary: &Path,
|
||||
working_directory: &Path,
|
||||
path: &Path,
|
||||
content: &Rope,
|
||||
remote_url: Option<String>,
|
||||
) -> Result<Self> {
|
||||
let output = run_git_blame(git_binary, working_directory, path, content).await?;
|
||||
let mut entries = parse_git_blame(&output)?;
|
||||
entries.sort_unstable_by(|a, b| a.range.start.cmp(&b.range.start));
|
||||
|
||||
let mut unique_shas = HashSet::default();
|
||||
|
||||
for entry in entries.iter_mut() {
|
||||
unique_shas.insert(entry.sha);
|
||||
}
|
||||
|
||||
let shas = unique_shas.into_iter().collect::<Vec<_>>();
|
||||
let messages = get_messages(working_directory, &shas)
|
||||
.await
|
||||
.context("failed to get commit messages")?;
|
||||
|
||||
Ok(Self {
|
||||
entries,
|
||||
messages,
|
||||
remote_url,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
const GIT_BLAME_NO_COMMIT_ERROR: &str = "fatal: no such ref: HEAD";
|
||||
const GIT_BLAME_NO_PATH: &str = "fatal: no such path";
|
||||
|
||||
#[derive(Serialize, Deserialize, Default, Debug, Clone, PartialEq, Eq)]
|
||||
pub struct BlameEntry {
|
||||
pub sha: Oid,
|
||||
|
||||
pub range: Range<u32>,
|
||||
|
||||
pub original_line_number: u32,
|
||||
|
||||
pub author: Option<String>,
|
||||
pub author_mail: Option<String>,
|
||||
pub author_time: Option<i64>,
|
||||
pub author_tz: Option<String>,
|
||||
|
||||
pub committer_name: Option<String>,
|
||||
pub committer_email: Option<String>,
|
||||
pub committer_time: Option<i64>,
|
||||
pub committer_tz: Option<String>,
|
||||
|
||||
pub summary: Option<String>,
|
||||
|
||||
pub previous: Option<String>,
|
||||
pub filename: String,
|
||||
}
|
||||
|
||||
impl BlameEntry {
|
||||
// Returns a BlameEntry by parsing the first line of a `git blame --incremental`
|
||||
// entry. The line MUST have this format:
|
||||
//
|
||||
// <40-byte-hex-sha1> <sourceline> <resultline> <num-lines>
|
||||
fn new_from_blame_line(line: &str) -> Result<BlameEntry> {
|
||||
let mut parts = line.split_whitespace();
|
||||
|
||||
let sha = parts
|
||||
.next()
|
||||
.and_then(|line| line.parse::<Oid>().ok())
|
||||
.ok_or_else(|| anyhow!("failed to parse sha"))?;
|
||||
|
||||
let original_line_number = parts
|
||||
.next()
|
||||
.and_then(|line| line.parse::<u32>().ok())
|
||||
.ok_or_else(|| anyhow!("Failed to parse original line number"))?;
|
||||
let final_line_number = parts
|
||||
.next()
|
||||
.and_then(|line| line.parse::<u32>().ok())
|
||||
.ok_or_else(|| anyhow!("Failed to parse final line number"))?;
|
||||
|
||||
let line_count = parts
|
||||
.next()
|
||||
.and_then(|line| line.parse::<u32>().ok())
|
||||
.ok_or_else(|| anyhow!("Failed to parse final line number"))?;
|
||||
|
||||
let start_line = final_line_number.saturating_sub(1);
|
||||
let end_line = start_line + line_count;
|
||||
let range = start_line..end_line;
|
||||
|
||||
Ok(Self {
|
||||
sha,
|
||||
range,
|
||||
original_line_number,
|
||||
..Default::default()
|
||||
})
|
||||
}
|
||||
|
||||
pub fn author_offset_date_time(&self) -> Result<time::OffsetDateTime> {
|
||||
if let (Some(author_time), Some(author_tz)) = (self.author_time, &self.author_tz) {
|
||||
let format = format_description!("[offset_hour][offset_minute]");
|
||||
let offset = UtcOffset::parse(author_tz, &format)?;
|
||||
let date_time_utc = OffsetDateTime::from_unix_timestamp(author_time)?;
|
||||
|
||||
Ok(date_time_utc.to_offset(offset))
|
||||
} else {
|
||||
// Directly return current time in UTC if there's no committer time or timezone
|
||||
Ok(time::OffsetDateTime::now_utc())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// parse_git_blame parses the output of `git blame --incremental`, which returns
|
||||
// all the blame-entries for a given path incrementally, as it finds them.
|
||||
//
|
||||
// Each entry *always* starts with:
|
||||
//
|
||||
// <40-byte-hex-sha1> <sourceline> <resultline> <num-lines>
|
||||
//
|
||||
// Each entry *always* ends with:
|
||||
//
|
||||
// filename <whitespace-quoted-filename-goes-here>
|
||||
//
|
||||
// Line numbers are 1-indexed.
|
||||
//
|
||||
// A `git blame --incremental` entry looks like this:
|
||||
//
|
||||
// 6ad46b5257ba16d12c5ca9f0d4900320959df7f4 2 2 1
|
||||
// author Joe Schmoe
|
||||
// author-mail <joe.schmoe@example.com>
|
||||
// author-time 1709741400
|
||||
// author-tz +0100
|
||||
// committer Joe Schmoe
|
||||
// committer-mail <joe.schmoe@example.com>
|
||||
// committer-time 1709741400
|
||||
// committer-tz +0100
|
||||
// summary Joe's cool commit
|
||||
// previous 486c2409237a2c627230589e567024a96751d475 index.js
|
||||
// filename index.js
|
||||
//
|
||||
// If the entry has the same SHA as an entry that was already printed then no
|
||||
// signature information is printed:
|
||||
//
|
||||
// 6ad46b5257ba16d12c5ca9f0d4900320959df7f4 3 4 1
|
||||
// previous 486c2409237a2c627230589e567024a96751d475 index.js
|
||||
// filename index.js
|
||||
//
|
||||
// More about `--incremental` output: https://mirrors.edge.kernel.org/pub/software/scm/git/docs/git-blame.html
|
||||
fn parse_git_blame(output: &str) -> Result<Vec<BlameEntry>> {
|
||||
let mut entries: Vec<BlameEntry> = Vec::new();
|
||||
let mut index: HashMap<Oid, usize> = HashMap::default();
|
||||
|
||||
let mut current_entry: Option<BlameEntry> = None;
|
||||
|
||||
for line in output.lines() {
|
||||
let mut done = false;
|
||||
|
||||
match &mut current_entry {
|
||||
None => {
|
||||
let mut new_entry = BlameEntry::new_from_blame_line(line)?;
|
||||
|
||||
if let Some(existing_entry) = index
|
||||
.get(&new_entry.sha)
|
||||
.and_then(|slot| entries.get(*slot))
|
||||
{
|
||||
new_entry.author.clone_from(&existing_entry.author);
|
||||
new_entry
|
||||
.author_mail
|
||||
.clone_from(&existing_entry.author_mail);
|
||||
new_entry.author_time = existing_entry.author_time;
|
||||
new_entry.author_tz.clone_from(&existing_entry.author_tz);
|
||||
new_entry
|
||||
.committer_name
|
||||
.clone_from(&existing_entry.committer_name);
|
||||
new_entry
|
||||
.committer_email
|
||||
.clone_from(&existing_entry.committer_email);
|
||||
new_entry.committer_time = existing_entry.committer_time;
|
||||
new_entry
|
||||
.committer_tz
|
||||
.clone_from(&existing_entry.committer_tz);
|
||||
new_entry.summary.clone_from(&existing_entry.summary);
|
||||
}
|
||||
|
||||
current_entry.replace(new_entry);
|
||||
}
|
||||
Some(entry) => {
|
||||
let Some((key, value)) = line.split_once(' ') else {
|
||||
continue;
|
||||
};
|
||||
let is_committed = !entry.sha.is_zero();
|
||||
match key {
|
||||
"filename" => {
|
||||
entry.filename = value.into();
|
||||
done = true;
|
||||
}
|
||||
"previous" => entry.previous = Some(value.into()),
|
||||
|
||||
"summary" if is_committed => entry.summary = Some(value.into()),
|
||||
"author" if is_committed => entry.author = Some(value.into()),
|
||||
"author-mail" if is_committed => entry.author_mail = Some(value.into()),
|
||||
"author-time" if is_committed => {
|
||||
entry.author_time = Some(value.parse::<i64>()?)
|
||||
}
|
||||
"author-tz" if is_committed => entry.author_tz = Some(value.into()),
|
||||
|
||||
"committer" if is_committed => entry.committer_name = Some(value.into()),
|
||||
"committer-mail" if is_committed => entry.committer_email = Some(value.into()),
|
||||
"committer-time" if is_committed => {
|
||||
entry.committer_time = Some(value.parse::<i64>()?)
|
||||
}
|
||||
"committer-tz" if is_committed => entry.committer_tz = Some(value.into()),
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
if done {
|
||||
if let Some(entry) = current_entry.take() {
|
||||
index.insert(entry.sha, entries.len());
|
||||
|
||||
// We only want annotations that have a commit.
|
||||
if !entry.sha.is_zero() {
|
||||
entries.push(entry);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(entries)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::path::PathBuf;
|
||||
|
||||
use super::BlameEntry;
|
||||
use super::parse_git_blame;
|
||||
|
||||
fn read_test_data(filename: &str) -> String {
|
||||
let mut path = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
|
||||
path.push("test_data");
|
||||
path.push(filename);
|
||||
|
||||
std::fs::read_to_string(&path)
|
||||
.unwrap_or_else(|_| panic!("Could not read test data at {:?}. Is it generated?", path))
|
||||
}
|
||||
|
||||
fn assert_eq_golden(entries: &Vec<BlameEntry>, golden_filename: &str) {
|
||||
let mut path = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
|
||||
path.push("test_data");
|
||||
path.push("golden");
|
||||
path.push(format!("{}.json", golden_filename));
|
||||
|
||||
let mut have_json =
|
||||
serde_json::to_string_pretty(&entries).expect("could not serialize entries to JSON");
|
||||
// We always want to save with a trailing newline.
|
||||
have_json.push('\n');
|
||||
|
||||
let update = std::env::var("UPDATE_GOLDEN")
|
||||
.map(|val| val.eq_ignore_ascii_case("true"))
|
||||
.unwrap_or(false);
|
||||
|
||||
if update {
|
||||
std::fs::create_dir_all(path.parent().unwrap())
|
||||
.expect("could not create golden test data directory");
|
||||
std::fs::write(&path, have_json).expect("could not write out golden data");
|
||||
} else {
|
||||
let want_json =
|
||||
std::fs::read_to_string(&path).unwrap_or_else(|_| {
|
||||
panic!("could not read golden test data file at {:?}. Did you run the test with UPDATE_GOLDEN=true before?", path);
|
||||
}).replace("\r\n", "\n");
|
||||
|
||||
pretty_assertions::assert_eq!(have_json, want_json, "wrong blame entries");
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_git_blame_not_committed() {
|
||||
let output = read_test_data("blame_incremental_not_committed");
|
||||
let entries = parse_git_blame(&output).unwrap();
|
||||
assert_eq_golden(&entries, "blame_incremental_not_committed");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_git_blame_simple() {
|
||||
let output = read_test_data("blame_incremental_simple");
|
||||
let entries = parse_git_blame(&output).unwrap();
|
||||
assert_eq_golden(&entries, "blame_incremental_simple");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_git_blame_complex() {
|
||||
let output = read_test_data("blame_incremental_complex");
|
||||
let entries = parse_git_blame(&output).unwrap();
|
||||
assert_eq_golden(&entries, "blame_incremental_complex");
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,374 @@
|
||||
use crate::commit::get_messages;
|
||||
use crate::{GitRemote, Oid};
|
||||
use anyhow::{Context as _, Result, anyhow};
|
||||
use collections::{HashMap, HashSet};
|
||||
use futures::AsyncWriteExt;
|
||||
use gpui::SharedString;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::process::Stdio;
|
||||
use std::{ops::Range, path::Path};
|
||||
use text::Rope;
|
||||
use time::OffsetDateTime;
|
||||
use time::UtcOffset;
|
||||
use time::macros::format_description;
|
||||
|
||||
pub use git2 as libgit;
|
||||
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct Blame {
|
||||
pub entries: Vec<BlameEntry>,
|
||||
pub messages: HashMap<Oid, String>,
|
||||
pub remote_url: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Default)]
|
||||
pub struct ParsedCommitMessage {
|
||||
pub message: SharedString,
|
||||
pub permalink: Option<url::Url>,
|
||||
pub pull_request: Option<crate::hosting_provider::PullRequest>,
|
||||
pub remote: Option<GitRemote>,
|
||||
}
|
||||
|
||||
impl Blame {
|
||||
pub async fn for_path(
|
||||
git_binary: &Path,
|
||||
working_directory: &Path,
|
||||
path: &Path,
|
||||
content: &Rope,
|
||||
remote_url: Option<String>,
|
||||
) -> Result<Self> {
|
||||
let output = run_git_blame(git_binary, working_directory, path, content).await?;
|
||||
let mut entries = parse_git_blame(&output)?;
|
||||
entries.sort_unstable_by(|a, b| a.range.start.cmp(&b.range.start));
|
||||
|
||||
let mut unique_shas = HashSet::default();
|
||||
|
||||
for entry in entries.iter_mut() {
|
||||
unique_shas.insert(entry.sha);
|
||||
}
|
||||
|
||||
let shas = unique_shas.into_iter().collect::<Vec<_>>();
|
||||
let messages = get_messages(working_directory, &shas)
|
||||
.await
|
||||
.context("failed to get commit messages")?;
|
||||
|
||||
Ok(Self {
|
||||
entries,
|
||||
messages,
|
||||
remote_url,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
const GIT_BLAME_NO_COMMIT_ERROR: &str = "fatal: no such ref: HEAD";
|
||||
const GIT_BLAME_NO_PATH: &str = "fatal: no such path";
|
||||
|
||||
async fn run_git_blame(
|
||||
git_binary: &Path,
|
||||
working_directory: &Path,
|
||||
path: &Path,
|
||||
contents: &Rope,
|
||||
) -> Result<String> {
|
||||
let mut child = util::command::new_smol_command(git_binary)
|
||||
.current_dir(working_directory)
|
||||
.arg("blame")
|
||||
.arg("--incremental")
|
||||
.arg("--contents")
|
||||
.arg("-")
|
||||
.arg(path.as_os_str())
|
||||
.stdin(Stdio::piped())
|
||||
.stdout(Stdio::piped())
|
||||
.stderr(Stdio::piped())
|
||||
.spawn()
|
||||
.map_err(|e| anyhow!("Failed to start git blame process: {}", e))?;
|
||||
|
||||
let stdin = child
|
||||
.stdin
|
||||
.as_mut()
|
||||
.context("failed to get pipe to stdin of git blame command")?;
|
||||
|
||||
for chunk in contents.chunks() {
|
||||
stdin.write_all(chunk.as_bytes()).await?;
|
||||
}
|
||||
stdin.flush().await?;
|
||||
|
||||
let output = child
|
||||
.output()
|
||||
.await
|
||||
.map_err(|e| anyhow!("Failed to read git blame output: {}", e))?;
|
||||
|
||||
if !output.status.success() {
|
||||
let stderr = String::from_utf8_lossy(&output.stderr);
|
||||
let trimmed = stderr.trim();
|
||||
if trimmed == GIT_BLAME_NO_COMMIT_ERROR || trimmed.contains(GIT_BLAME_NO_PATH) {
|
||||
return Ok(String::new());
|
||||
}
|
||||
return Err(anyhow!("git blame process failed: {}", stderr));
|
||||
}
|
||||
|
||||
Ok(String::from_utf8(output.stdout)?)
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Default, Debug, Clone, PartialEq, Eq)]
|
||||
pub struct BlameEntry {
|
||||
pub sha: Oid,
|
||||
|
||||
pub range: Range<u32>,
|
||||
|
||||
pub original_line_number: u32,
|
||||
|
||||
pub author: Option<String>,
|
||||
pub author_mail: Option<String>,
|
||||
pub author_time: Option<i64>,
|
||||
pub author_tz: Option<String>,
|
||||
|
||||
pub committer_name: Option<String>,
|
||||
pub committer_email: Option<String>,
|
||||
pub committer_time: Option<i64>,
|
||||
pub committer_tz: Option<String>,
|
||||
|
||||
pub summary: Option<String>,
|
||||
|
||||
pub previous: Option<String>,
|
||||
pub filename: String,
|
||||
}
|
||||
|
||||
impl BlameEntry {
|
||||
// Returns a BlameEntry by parsing the first line of a `git blame --incremental`
|
||||
// entry. The line MUST have this format:
|
||||
//
|
||||
// <40-byte-hex-sha1> <sourceline> <resultline> <num-lines>
|
||||
fn new_from_blame_line(line: &str) -> Result<BlameEntry> {
|
||||
let mut parts = line.split_whitespace();
|
||||
|
||||
let sha = parts
|
||||
.next()
|
||||
.and_then(|line| line.parse::<Oid>().ok())
|
||||
.ok_or_else(|| anyhow!("failed to parse sha"))?;
|
||||
|
||||
let original_line_number = parts
|
||||
.next()
|
||||
.and_then(|line| line.parse::<u32>().ok())
|
||||
.ok_or_else(|| anyhow!("Failed to parse original line number"))?;
|
||||
let final_line_number = parts
|
||||
.next()
|
||||
.and_then(|line| line.parse::<u32>().ok())
|
||||
.ok_or_else(|| anyhow!("Failed to parse final line number"))?;
|
||||
|
||||
let line_count = parts
|
||||
.next()
|
||||
.and_then(|line| line.parse::<u32>().ok())
|
||||
.ok_or_else(|| anyhow!("Failed to parse final line number"))?;
|
||||
|
||||
let start_line = final_line_number.saturating_sub(1);
|
||||
let end_line = start_line + line_count;
|
||||
let range = start_line..end_line;
|
||||
|
||||
Ok(Self {
|
||||
sha,
|
||||
range,
|
||||
original_line_number,
|
||||
..Default::default()
|
||||
})
|
||||
}
|
||||
|
||||
pub fn author_offset_date_time(&self) -> Result<time::OffsetDateTime> {
|
||||
if let (Some(author_time), Some(author_tz)) = (self.author_time, &self.author_tz) {
|
||||
let format = format_description!("[offset_hour][offset_minute]");
|
||||
let offset = UtcOffset::parse(author_tz, &format)?;
|
||||
let date_time_utc = OffsetDateTime::from_unix_timestamp(author_time)?;
|
||||
|
||||
Ok(date_time_utc.to_offset(offset))
|
||||
} else {
|
||||
// Directly return current time in UTC if there's no committer time or timezone
|
||||
Ok(time::OffsetDateTime::now_utc())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// parse_git_blame parses the output of `git blame --incremental`, which returns
|
||||
// all the blame-entries for a given path incrementally, as it finds them.
|
||||
//
|
||||
// Each entry *always* starts with:
|
||||
//
|
||||
// <40-byte-hex-sha1> <sourceline> <resultline> <num-lines>
|
||||
//
|
||||
// Each entry *always* ends with:
|
||||
//
|
||||
// filename <whitespace-quoted-filename-goes-here>
|
||||
//
|
||||
// Line numbers are 1-indexed.
|
||||
//
|
||||
// A `git blame --incremental` entry looks like this:
|
||||
//
|
||||
// 6ad46b5257ba16d12c5ca9f0d4900320959df7f4 2 2 1
|
||||
// author Joe Schmoe
|
||||
// author-mail <joe.schmoe@example.com>
|
||||
// author-time 1709741400
|
||||
// author-tz +0100
|
||||
// committer Joe Schmoe
|
||||
// committer-mail <joe.schmoe@example.com>
|
||||
// committer-time 1709741400
|
||||
// committer-tz +0100
|
||||
// summary Joe's cool commit
|
||||
// previous 486c2409237a2c627230589e567024a96751d475 index.js
|
||||
// filename index.js
|
||||
//
|
||||
// If the entry has the same SHA as an entry that was already printed then no
|
||||
// signature information is printed:
|
||||
//
|
||||
// 6ad46b5257ba16d12c5ca9f0d4900320959df7f4 3 4 1
|
||||
// previous 486c2409237a2c627230589e567024a96751d475 index.js
|
||||
// filename index.js
|
||||
//
|
||||
// More about `--incremental` output: https://mirrors.edge.kernel.org/pub/software/scm/git/docs/git-blame.html
|
||||
fn parse_git_blame(output: &str) -> Result<Vec<BlameEntry>> {
|
||||
let mut entries: Vec<BlameEntry> = Vec::new();
|
||||
let mut index: HashMap<Oid, usize> = HashMap::default();
|
||||
|
||||
let mut current_entry: Option<BlameEntry> = None;
|
||||
|
||||
for line in output.lines() {
|
||||
let mut done = false;
|
||||
|
||||
match &mut current_entry {
|
||||
None => {
|
||||
let mut new_entry = BlameEntry::new_from_blame_line(line)?;
|
||||
|
||||
if let Some(existing_entry) = index
|
||||
.get(&new_entry.sha)
|
||||
.and_then(|slot| entries.get(*slot))
|
||||
{
|
||||
new_entry.author.clone_from(&existing_entry.author);
|
||||
new_entry
|
||||
.author_mail
|
||||
.clone_from(&existing_entry.author_mail);
|
||||
new_entry.author_time = existing_entry.author_time;
|
||||
new_entry.author_tz.clone_from(&existing_entry.author_tz);
|
||||
new_entry
|
||||
.committer_name
|
||||
.clone_from(&existing_entry.committer_name);
|
||||
new_entry
|
||||
.committer_email
|
||||
.clone_from(&existing_entry.committer_email);
|
||||
new_entry.committer_time = existing_entry.committer_time;
|
||||
new_entry
|
||||
.committer_tz
|
||||
.clone_from(&existing_entry.committer_tz);
|
||||
new_entry.summary.clone_from(&existing_entry.summary);
|
||||
}
|
||||
|
||||
current_entry.replace(new_entry);
|
||||
}
|
||||
Some(entry) => {
|
||||
let Some((key, value)) = line.split_once(' ') else {
|
||||
continue;
|
||||
};
|
||||
let is_committed = !entry.sha.is_zero();
|
||||
match key {
|
||||
"filename" => {
|
||||
entry.filename = value.into();
|
||||
done = true;
|
||||
}
|
||||
"previous" => entry.previous = Some(value.into()),
|
||||
|
||||
"summary" if is_committed => entry.summary = Some(value.into()),
|
||||
"author" if is_committed => entry.author = Some(value.into()),
|
||||
"author-mail" if is_committed => entry.author_mail = Some(value.into()),
|
||||
"author-time" if is_committed => {
|
||||
entry.author_time = Some(value.parse::<i64>()?)
|
||||
}
|
||||
"author-tz" if is_committed => entry.author_tz = Some(value.into()),
|
||||
|
||||
"committer" if is_committed => entry.committer_name = Some(value.into()),
|
||||
"committer-mail" if is_committed => entry.committer_email = Some(value.into()),
|
||||
"committer-time" if is_committed => {
|
||||
entry.committer_time = Some(value.parse::<i64>()?)
|
||||
}
|
||||
"committer-tz" if is_committed => entry.committer_tz = Some(value.into()),
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
if done {
|
||||
if let Some(entry) = current_entry.take() {
|
||||
index.insert(entry.sha, entries.len());
|
||||
|
||||
// We only want annotations that have a commit.
|
||||
if !entry.sha.is_zero() {
|
||||
entries.push(entry);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(entries)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::path::PathBuf;
|
||||
|
||||
use super::BlameEntry;
|
||||
use super::parse_git_blame;
|
||||
|
||||
fn read_test_data(filename: &str) -> String {
|
||||
let mut path = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
|
||||
path.push("test_data");
|
||||
path.push(filename);
|
||||
|
||||
std::fs::read_to_string(&path)
|
||||
.unwrap_or_else(|_| panic!("Could not read test data at {:?}. Is it generated?", path))
|
||||
}
|
||||
|
||||
fn assert_eq_golden(entries: &Vec<BlameEntry>, golden_filename: &str) {
|
||||
let mut path = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
|
||||
path.push("test_data");
|
||||
path.push("golden");
|
||||
path.push(format!("{}.json", golden_filename));
|
||||
|
||||
let mut have_json =
|
||||
serde_json::to_string_pretty(&entries).expect("could not serialize entries to JSON");
|
||||
// We always want to save with a trailing newline.
|
||||
have_json.push('\n');
|
||||
|
||||
let update = std::env::var("UPDATE_GOLDEN")
|
||||
.map(|val| val.eq_ignore_ascii_case("true"))
|
||||
.unwrap_or(false);
|
||||
|
||||
if update {
|
||||
std::fs::create_dir_all(path.parent().unwrap())
|
||||
.expect("could not create golden test data directory");
|
||||
std::fs::write(&path, have_json).expect("could not write out golden data");
|
||||
} else {
|
||||
let want_json =
|
||||
std::fs::read_to_string(&path).unwrap_or_else(|_| {
|
||||
panic!("could not read golden test data file at {:?}. Did you run the test with UPDATE_GOLDEN=true before?", path);
|
||||
}).replace("\r\n", "\n");
|
||||
|
||||
pretty_assertions::assert_eq!(have_json, want_json, "wrong blame entries");
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_git_blame_not_committed() {
|
||||
let output = read_test_data("blame_incremental_not_committed");
|
||||
let entries = parse_git_blame(&output).unwrap();
|
||||
assert_eq_golden(&entries, "blame_incremental_not_committed");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_git_blame_simple() {
|
||||
let output = read_test_data("blame_incremental_simple");
|
||||
let entries = parse_git_blame(&output).unwrap();
|
||||
assert_eq_golden(&entries, "blame_incremental_simple");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_git_blame_complex() {
|
||||
let output = read_test_data("blame_incremental_complex");
|
||||
let entries = parse_git_blame(&output).unwrap();
|
||||
assert_eq_golden(&entries, "blame_incremental_complex");
|
||||
}
|
||||
}
|
||||
@@ -314,7 +314,7 @@ impl Tool for GrepTool {
|
||||
mod tests {
|
||||
use super::*;
|
||||
use assistant_tool::Tool;
|
||||
use gpui::{AppContext, TestAppContext, UpdateGlobal};
|
||||
use gpui::{TestAppContext, UpdateGlobal};
|
||||
use language::{Language, LanguageConfig, LanguageMatcher};
|
||||
use language_model::fake_provider::FakeLanguageModel;
|
||||
use project::{FakeFs, Project, WorktreeSettings};
|
||||
|
||||
@@ -226,7 +226,7 @@ impl Tool for ListDirectoryTool {
|
||||
mod tests {
|
||||
use super::*;
|
||||
use assistant_tool::Tool;
|
||||
use gpui::{AppContext, TestAppContext, UpdateGlobal};
|
||||
use gpui::{TestAppContext, UpdateGlobal};
|
||||
use indoc::indoc;
|
||||
use language_model::fake_provider::FakeLanguageModel;
|
||||
use project::{FakeFs, Project, WorktreeSettings};
|
||||
|
||||
@@ -289,7 +289,7 @@ impl Tool for ReadFileTool {
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use super::*;
|
||||
use gpui::{AppContext, TestAppContext, UpdateGlobal};
|
||||
use gpui::{TestAppContext, UpdateGlobal};
|
||||
use language::{Language, LanguageConfig, LanguageMatcher};
|
||||
use language_model::fake_provider::FakeLanguageModel;
|
||||
use project::{FakeFs, Project, WorktreeSettings};
|
||||
|
||||
@@ -22,7 +22,7 @@ fn schema_to_json(
|
||||
Ok(value)
|
||||
}
|
||||
|
||||
fn root_schema_for<T: JsonSchema>(format: LanguageModelToolSchemaFormat) -> Schema {
|
||||
pub fn root_schema_for<T: JsonSchema>(format: LanguageModelToolSchemaFormat) -> Schema {
|
||||
let mut generator = match format {
|
||||
LanguageModelToolSchemaFormat::JsonSchema => SchemaSettings::draft07().into_generator(),
|
||||
// TODO: Gemini docs mention using a subset of OpenAPI 3, so this may benefit from using
|
||||
|
||||
47
crates/assistant_tools/src/templates/edit_agent.hbs
Normal file
47
crates/assistant_tools/src/templates/edit_agent.hbs
Normal file
@@ -0,0 +1,47 @@
|
||||
You are an expert text editor. Taking the following file as an input:
|
||||
|
||||
```{{path}}
|
||||
{{file_content}}
|
||||
```
|
||||
|
||||
Produce a series of edits following the given user instructions:
|
||||
|
||||
<user_instructions>
|
||||
{{instructions}}
|
||||
</user_instructions>
|
||||
|
||||
Your response must be a series of edits in the following format:
|
||||
|
||||
<edits>
|
||||
<old_text>
|
||||
OLD TEXT 1 HERE
|
||||
</old_text>
|
||||
<new_text>
|
||||
NEW TEXT 1 HERE
|
||||
</new_text>
|
||||
|
||||
<old_text>
|
||||
OLD TEXT 2 HERE
|
||||
</old_text>
|
||||
<new_text>
|
||||
NEW TEXT 2 HERE
|
||||
</new_text>
|
||||
|
||||
<old_text>
|
||||
OLD TEXT 3 HERE
|
||||
</old_text>
|
||||
<new_text>
|
||||
NEW TEXT 3 HERE
|
||||
</new_text>
|
||||
</edits>
|
||||
|
||||
Rules for editing:
|
||||
|
||||
- `old_text` represents full lines (including indentation) in the input file that will be replaced with `new_text`
|
||||
- It is crucial that `old_text` is unique and unambiguous.
|
||||
- Always include enough context around the lines you want to replace in `old_text` such that it's impossible to mistake them for other lines.
|
||||
- If you want to replace all occurrences, repeat the same `old_text`/`new_text` pair multiple times and I will apply them sequentially, one occurrence at a time.
|
||||
- Don't explain why you made a change, just report the edits.
|
||||
- Make sure you follow the instructions carefully and thoroughly, avoid doing *less* or *more* than instructed.
|
||||
|
||||
<edits>
|
||||
@@ -6,10 +6,7 @@ use debugger_ui::debugger_panel::DebugPanel;
|
||||
use extension::ExtensionHostProxy;
|
||||
use fs::{FakeFs, Fs as _, RemoveOptions};
|
||||
use futures::StreamExt as _;
|
||||
use gpui::{
|
||||
AppContext as _, BackgroundExecutor, SemanticVersion, TestAppContext, UpdateGlobal as _,
|
||||
VisualContext,
|
||||
};
|
||||
use gpui::{BackgroundExecutor, SemanticVersion, TestAppContext, UpdateGlobal as _, VisualContext};
|
||||
use http_client::BlockedHttpClient;
|
||||
use language::{
|
||||
FakeLspAdapter, Language, LanguageConfig, LanguageMatcher, LanguageRegistry,
|
||||
|
||||
@@ -1307,7 +1307,7 @@ pub mod tests {
|
||||
use crate::scroll::ScrollAmount;
|
||||
use crate::{ExcerptRange, scroll::Autoscroll, test::editor_lsp_test_context::rust_lang};
|
||||
use futures::StreamExt;
|
||||
use gpui::{AppContext as _, Context, SemanticVersion, TestAppContext, WindowHandle};
|
||||
use gpui::{Context, SemanticVersion, TestAppContext, WindowHandle};
|
||||
use itertools::Itertools as _;
|
||||
use language::{Capability, FakeLspAdapter, language_settings::AllLanguageSettingsContent};
|
||||
use language::{Language, LanguageConfig, LanguageMatcher};
|
||||
|
||||
@@ -626,7 +626,7 @@ mod jsx_tag_autoclose_tests {
|
||||
};
|
||||
|
||||
use super::*;
|
||||
use gpui::{AppContext as _, TestAppContext};
|
||||
use gpui::TestAppContext;
|
||||
use language::language_settings::JsxTagAutoCloseSettings;
|
||||
use languages::language;
|
||||
use multi_buffer::ExcerptRange;
|
||||
|
||||
3
crates/eval/src/examples/edit.rs
Normal file
3
crates/eval/src/examples/edit.rs
Normal file
@@ -0,0 +1,3 @@
|
||||
pub mod basic;
|
||||
|
||||
pub use basic::*;
|
||||
104
crates/eval/src/examples/edit/basic.rs
Normal file
104
crates/eval/src/examples/edit/basic.rs
Normal file
@@ -0,0 +1,104 @@
|
||||
use std::{collections::HashSet, path::Path, sync::Arc};
|
||||
|
||||
use anyhow::Result;
|
||||
use assistant_tools::{CreateFileToolInput, EditFileToolInput, ReadFileToolInput};
|
||||
use async_trait::async_trait;
|
||||
use buffer_diff::DiffHunkStatus;
|
||||
use collections::HashMap;
|
||||
|
||||
use crate::example::{
|
||||
Example, ExampleContext, ExampleMetadata, FileEditHunk, FileEdits, JudgeAssertion,
|
||||
LanguageServer,
|
||||
};
|
||||
|
||||
pub struct EditBasic;
|
||||
|
||||
#[async_trait(?Send)]
|
||||
impl Example for EditBasic {
|
||||
fn meta(&self) -> ExampleMetadata {
|
||||
ExampleMetadata {
|
||||
name: "edit_basic".to_string(),
|
||||
url: "https://github.com/zed-industries/zed.git".to_string(),
|
||||
revision: "58604fba86ebbffaa01f7c6834253e33bcd38c0f".to_string(),
|
||||
language_server: None,
|
||||
max_assertions: None,
|
||||
}
|
||||
}
|
||||
|
||||
async fn conversation(&self, cx: &mut ExampleContext) -> Result<()> {
|
||||
cx.push_user_message(format!(
|
||||
r#"
|
||||
Read the `crates/git/src/blame.rs` file and delete `run_git_blame`. Just that
|
||||
one function, not its usages.
|
||||
|
||||
IMPORTANT: You are only allowed to use the `read_file` and `edit_file` tools!
|
||||
"#
|
||||
));
|
||||
|
||||
let response = cx.run_to_end().await?;
|
||||
// let expected_edits = HashMap::from_iter([(
|
||||
// Arc::from(Path::new("crates/git/src/blame.rs")),
|
||||
// FileEdits {
|
||||
// hunks: vec![
|
||||
// FileEditHunk {
|
||||
// base_text: " unique_shas.insert(entry.sha);\n".into(),
|
||||
// text: " unique_shas.insert(entry.git_sha);\n".into(),
|
||||
// status: DiffHunkStatus::modified_none(),
|
||||
// },
|
||||
// FileEditHunk {
|
||||
// base_text: " pub sha: Oid,\n".into(),
|
||||
// text: " pub git_sha: Oid,\n".into(),
|
||||
// status: DiffHunkStatus::modified_none(),
|
||||
// },
|
||||
// FileEditHunk {
|
||||
// base_text: " let sha = parts\n".into(),
|
||||
// text: " let git_sha = parts\n".into(),
|
||||
// status: DiffHunkStatus::modified_none(),
|
||||
// },
|
||||
// FileEditHunk {
|
||||
// base_text:
|
||||
// " .ok_or_else(|| anyhow!(\"failed to parse sha\"))?;\n"
|
||||
// .into(),
|
||||
// text:
|
||||
// " .ok_or_else(|| anyhow!(\"failed to parse git_sha\"))?;\n"
|
||||
// .into(),
|
||||
// status: DiffHunkStatus::modified_none(),
|
||||
// },
|
||||
// FileEditHunk {
|
||||
// base_text: " sha,\n".into(),
|
||||
// text: " git_sha,\n".into(),
|
||||
// status: DiffHunkStatus::modified_none(),
|
||||
// },
|
||||
// FileEditHunk {
|
||||
// base_text: " .get(&new_entry.sha)\n".into(),
|
||||
// text: " .get(&new_entry.git_sha)\n".into(),
|
||||
// status: DiffHunkStatus::modified_none(),
|
||||
// },
|
||||
// FileEditHunk {
|
||||
// base_text: " let is_committed = !entry.sha.is_zero();\n"
|
||||
// .into(),
|
||||
// text: " let is_committed = !entry.git_sha.is_zero();\n"
|
||||
// .into(),
|
||||
// status: DiffHunkStatus::modified_none(),
|
||||
// },
|
||||
// FileEditHunk {
|
||||
// base_text: " index.insert(entry.sha, entries.len());\n"
|
||||
// .into(),
|
||||
// text: " index.insert(entry.git_sha, entries.len());\n"
|
||||
// .into(),
|
||||
// status: DiffHunkStatus::modified_none(),
|
||||
// },
|
||||
// FileEditHunk {
|
||||
// base_text: " if !entry.sha.is_zero() {\n".into(),
|
||||
// text: " if !entry.git_sha.is_zero() {\n".into(),
|
||||
// status: DiffHunkStatus::modified_none(),
|
||||
// },
|
||||
// ],
|
||||
// },
|
||||
// )]);
|
||||
// let actual_edits = cx.edits();
|
||||
// cx.assert_eq(&actual_edits, &expected_edits, "edits don't match")?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
@@ -8,7 +8,7 @@ use collections::{BTreeMap, HashSet};
|
||||
use extension::ExtensionHostProxy;
|
||||
use fs::{FakeFs, Fs, RealFs};
|
||||
use futures::{AsyncReadExt, StreamExt, io::BufReader};
|
||||
use gpui::{AppContext as _, SemanticVersion, TestAppContext};
|
||||
use gpui::{SemanticVersion, TestAppContext};
|
||||
use http_client::{FakeHttpClient, Response};
|
||||
use language::{BinaryStatus, LanguageMatcher, LanguageRegistry};
|
||||
use lsp::LanguageServerName;
|
||||
|
||||
@@ -1069,7 +1069,7 @@ impl App {
|
||||
}
|
||||
}
|
||||
|
||||
/// Obtains a reference to the executor, which can be used to spawn futures.
|
||||
/// Obtains a reference to the background executor, which can be used to spawn futures.
|
||||
pub fn background_executor(&self) -> &BackgroundExecutor {
|
||||
&self.background_executor
|
||||
}
|
||||
|
||||
@@ -178,7 +178,14 @@ impl TestAppContext {
|
||||
&self.foreground_executor
|
||||
}
|
||||
|
||||
fn new<T: 'static>(&mut self, build_entity: impl FnOnce(&mut Context<T>) -> T) -> Entity<T> {
|
||||
/// Builds an entity that is owned by the application.
|
||||
///
|
||||
/// The given function will be invoked with a [`Context`] and must return an object representing the entity. An
|
||||
/// [`Entity`] handle will be returned, which can be used to access the entity in a context.
|
||||
pub fn new<T: 'static>(
|
||||
&mut self,
|
||||
build_entity: impl FnOnce(&mut Context<T>) -> T,
|
||||
) -> Entity<T> {
|
||||
let mut cx = self.app.borrow_mut();
|
||||
cx.new(build_entity)
|
||||
}
|
||||
|
||||
@@ -95,6 +95,13 @@ where
|
||||
.spawn(self.log_tracked_err(*location))
|
||||
.detach();
|
||||
}
|
||||
|
||||
/// Convert a Task<Result<T, E>> to a Task<()> that logs all errors.
|
||||
pub fn log_err_in_task(self, cx: &App) -> Task<Option<T>> {
|
||||
let location = core::panic::Location::caller();
|
||||
cx.foreground_executor()
|
||||
.spawn(async move { self.log_tracked_err(*location).await })
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> Future for Task<T> {
|
||||
|
||||
@@ -91,6 +91,7 @@ pub enum LanguageModelCompletionEvent {
|
||||
},
|
||||
StartMessage {
|
||||
message_id: String,
|
||||
role: Role,
|
||||
},
|
||||
UsageUpdate(TokenUsage),
|
||||
}
|
||||
@@ -500,7 +501,7 @@ pub trait LanguageModel: Send + Sync {
|
||||
|
||||
if let Some(first_event) = events.next().await {
|
||||
match first_event {
|
||||
Ok(LanguageModelCompletionEvent::StartMessage { message_id: id }) => {
|
||||
Ok(LanguageModelCompletionEvent::StartMessage { message_id: id, .. }) => {
|
||||
message_id = Some(id.clone());
|
||||
}
|
||||
Ok(LanguageModelCompletionEvent::Text(text)) => {
|
||||
|
||||
@@ -12,7 +12,7 @@ use gpui::{
|
||||
use image::codecs::png::PngEncoder;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use util::ResultExt;
|
||||
use zed_llm_client::{CompletionIntent, CompletionMode};
|
||||
pub use zed_llm_client::{CompletionIntent, CompletionMode};
|
||||
|
||||
#[derive(Clone, PartialEq, Eq, Serialize, Deserialize, Hash)]
|
||||
pub struct LanguageModelImage {
|
||||
@@ -344,6 +344,24 @@ impl From<&str> for MessageContent {
|
||||
}
|
||||
}
|
||||
|
||||
impl From<LanguageModelToolUse> for MessageContent {
|
||||
fn from(value: LanguageModelToolUse) -> Self {
|
||||
MessageContent::ToolUse(value)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<LanguageModelImage> for MessageContent {
|
||||
fn from(value: LanguageModelImage) -> Self {
|
||||
MessageContent::Image(value)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<LanguageModelToolResult> for MessageContent {
|
||||
fn from(value: LanguageModelToolResult) -> Self {
|
||||
MessageContent::ToolResult(value)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Serialize, Deserialize, Debug, PartialEq, Hash)]
|
||||
pub struct LanguageModelRequestMessage {
|
||||
pub role: Role,
|
||||
|
||||
@@ -36,6 +36,29 @@ impl Role {
|
||||
}
|
||||
}
|
||||
|
||||
impl From<anthropic::Role> for Role {
|
||||
fn from(role: anthropic::Role) -> Self {
|
||||
match role {
|
||||
anthropic::Role::User => Role::User,
|
||||
anthropic::Role::Assistant => Role::Assistant,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFrom<Role> for anthropic::Role {
|
||||
type Error = anyhow::Error;
|
||||
|
||||
fn try_from(role: Role) -> Result<Self, Self::Error> {
|
||||
match role {
|
||||
Role::User => Ok(anthropic::Role::User),
|
||||
Role::Assistant => Ok(anthropic::Role::Assistant),
|
||||
Role::System => Err(anyhow::anyhow!(
|
||||
"System role is not supported in anthropic API"
|
||||
)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Display for Role {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
|
||||
@@ -833,6 +833,7 @@ impl AnthropicEventMapper {
|
||||
))),
|
||||
Ok(LanguageModelCompletionEvent::StartMessage {
|
||||
message_id: message.id,
|
||||
role: message.role.into(),
|
||||
}),
|
||||
]
|
||||
}
|
||||
|
||||
@@ -4,7 +4,7 @@ use crate::lsp_log::LogMenuItem;
|
||||
|
||||
use super::*;
|
||||
use futures::StreamExt;
|
||||
use gpui::{AppContext as _, SemanticVersion, TestAppContext, VisualTestContext};
|
||||
use gpui::{SemanticVersion, TestAppContext, VisualTestContext};
|
||||
use language::{FakeLspAdapter, Language, LanguageConfig, LanguageMatcher, tree_sitter_rust};
|
||||
use lsp::LanguageServerName;
|
||||
use lsp_log::LogKind;
|
||||
|
||||
@@ -18,7 +18,7 @@ pub(super) fn bash_task_context() -> ContextProviderWithTasks {
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use gpui::{AppContext as _, BorrowAppContext, Context, TestAppContext};
|
||||
use gpui::{BorrowAppContext, Context, TestAppContext};
|
||||
use language::{AutoindentMode, Buffer, language_settings::AllLanguageSettings};
|
||||
use settings::SettingsStore;
|
||||
use std::num::NonZeroU32;
|
||||
|
||||
@@ -347,7 +347,7 @@ async fn get_cached_server_binary(container_dir: PathBuf) -> Option<LanguageServ
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use gpui::{AppContext as _, BorrowAppContext, TestAppContext};
|
||||
use gpui::{BorrowAppContext, TestAppContext};
|
||||
use language::{AutoindentMode, Buffer, language_settings::AllLanguageSettings};
|
||||
use settings::SettingsStore;
|
||||
use std::num::NonZeroU32;
|
||||
|
||||
@@ -167,7 +167,7 @@ async fn get_cached_server_binary(
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use gpui::{AppContext as _, TestAppContext};
|
||||
use gpui::TestAppContext;
|
||||
use unindent::Unindent;
|
||||
|
||||
#[gpui::test]
|
||||
|
||||
@@ -1286,7 +1286,7 @@ impl LspAdapter for PyLspAdapter {
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use gpui::{AppContext as _, BorrowAppContext, Context, TestAppContext};
|
||||
use gpui::{BorrowAppContext, Context, TestAppContext};
|
||||
use language::{AutoindentMode, Buffer, language_settings::AllLanguageSettings};
|
||||
use settings::SettingsStore;
|
||||
use std::num::NonZeroU32;
|
||||
|
||||
@@ -1010,7 +1010,7 @@ async fn handle_symlink(src_dir: PathBuf, dest_dir: PathBuf) -> Result<()> {
|
||||
mod tests {
|
||||
use std::path::Path;
|
||||
|
||||
use gpui::{AppContext as _, BackgroundExecutor, TestAppContext};
|
||||
use gpui::{BackgroundExecutor, TestAppContext};
|
||||
use language::language_settings;
|
||||
use project::{FakeFs, Project};
|
||||
use serde_json::json;
|
||||
|
||||
@@ -593,7 +593,7 @@ mod tests {
|
||||
project_settings::ProjectSettings,
|
||||
};
|
||||
use context_server::test::create_fake_transport;
|
||||
use gpui::{AppContext, TestAppContext, UpdateGlobal as _};
|
||||
use gpui::{TestAppContext, UpdateGlobal as _};
|
||||
use serde_json::json;
|
||||
use std::{cell::RefCell, rc::Rc};
|
||||
use util::path;
|
||||
|
||||
Reference in New Issue
Block a user