Compare commits

...

32 Commits

Author SHA1 Message Date
Richard Feldman
d97b1fcd7b wip
Co-authored-by: Danilo Leal <daniloleal09@gmail.com>
Co-authored-by: Agus Zubiaga <hi@aguz.me>
2025-04-16 15:22:19 -04:00
Richard Feldman
1ca4a011c6 Provide a window handle to Tool::run
Co-authored-by: Danilo Leal <daniloleal09@gmail.com>
Co-authored-by: Agus Zubiaga <hi@aguz.me>
2025-04-16 14:56:52 -04:00
Bennet Bo Fenner
fa4abaf56e Remove unused dependency 2025-04-16 19:02:14 +02:00
Bennet Bo Fenner
f9c729a7b1 Remove OpenAI search provider for now 2025-04-16 18:48:41 +02:00
Bennet Bo Fenner
3c8b7caa2e Only register tool when feature flag is set 2025-04-16 18:45:00 +02:00
Bennet Bo Fenner
4fc96e9453 Wire up web search cloud provider
Co-authored-by: Marshall Bowers <git@maxdeviant.com>
2025-04-16 18:12:55 +02:00
Bennet Bo Fenner
3f7abfbfe4 Merge branch 'main' into websearch-tool 2025-04-16 17:31:48 +02:00
Bennet Bo Fenner
79cf3f5e93 cleanup 2025-04-16 16:28:11 +02:00
Bennet Bo Fenner
d27eec8c58 format 2025-04-16 16:28:05 +02:00
Bennet Bo Fenner
3467b80595 Remove unused dependencies 2025-04-16 07:50:43 +02:00
Bennet Bo Fenner
d6196d72c1 cleanup 2025-04-15 20:03:48 +02:00
Bennet Bo Fenner
303036e333 More 2025-04-15 17:14:27 +02:00
Bennet Bo Fenner
d0634bbf2b Fix merge 2025-04-15 17:13:57 +02:00
Bennet Bo Fenner
714a60e9e9 Merge branch 'main' into websearch-tool 2025-04-15 17:13:05 +02:00
Bennet Bo Fenner
b09eb4b683 Cleanup 2025-04-15 16:26:56 +02:00
Bennet Bo Fenner
8afca164cf Use symlinks for LICENSES 2025-04-15 15:36:58 +02:00
Bennet Bo Fenner
bf2284019a Use symlinks for LICENSES 2025-04-15 15:35:57 +02:00
Bennet Bo Fenner
706be9bc06 Fix merge 2025-04-15 15:33:33 +02:00
Bennet Bo Fenner
b65aeedfef Merge branch 'main' into websearch-tool 2025-04-15 15:33:24 +02:00
Danilo Leal
01ccb732be Adjust the UI 2025-04-14 18:25:05 -03:00
Bennet Bo Fenner
3313348769 WIP UI 2025-04-11 20:18:18 -06:00
Bennet Bo Fenner
bf052b56a4 Complete API
Co-authored-by: Antonio Scandurra <me@as-cii.com>
2025-04-11 19:36:40 -06:00
Bennet Bo Fenner
0b17b59305 WIP custom UI
Co-authored-by: Antonio Scandurra <me@as-cii.com>
2025-04-11 19:24:06 -06:00
Bennet Bo Fenner
e63a14721e Merge remote-tracking branch 'origin/custom-tool-cards' into websearch-tool 2025-04-11 18:43:54 -06:00
Bennet Bo Fenner
931891414c Do not require confirmation 2025-04-11 16:13:48 -06:00
Bennet Bo Fenner
9292ab94b6 Get it working 2025-04-11 16:11:20 -06:00
Bennet Bo Fenner
5fb549699b WIP 2025-04-11 14:56:19 -06:00
Bennet Bo Fenner
88ea23d7da WIP 2025-04-11 13:15:30 -06:00
Bennet Bo Fenner
df509bbc20 add web search tool 2025-04-11 12:25:44 -06:00
Antonio Scandurra
38fcadf948 Merge remote-tracking branch 'origin/main' into custom-tool-cards 2025-04-09 16:42:30 -06:00
Antonio Scandurra
e5cbac1373 Checkpoint 2025-04-09 15:21:36 -06:00
Antonio Scandurra
53375434cf Lay the groundwork to support rendering custom tool cards 2025-04-09 08:17:35 -06:00
47 changed files with 1116 additions and 112 deletions

42
Cargo.lock generated
View File

@@ -702,8 +702,11 @@ version = "0.1.0"
dependencies = [
"anyhow",
"assistant_tool",
"buffer_diff",
"chrono",
"collections",
"editor",
"feature_flags",
"futures 0.3.31",
"gpui",
"html_to_markdown",
@@ -711,6 +714,7 @@ dependencies = [
"itertools 0.14.0",
"language",
"language_model",
"multi_buffer",
"open",
"project",
"rand 0.8.5",
@@ -721,9 +725,11 @@ dependencies = [
"ui",
"unindent",
"util",
"web_search",
"workspace",
"workspace-hack",
"worktree",
"zed_llm_client",
]
[[package]]
@@ -16609,6 +16615,36 @@ dependencies = [
"wasm-bindgen",
]
[[package]]
name = "web_search"
version = "0.1.0"
dependencies = [
"anyhow",
"collections",
"gpui",
"serde",
"workspace-hack",
"zed_llm_client",
]
[[package]]
name = "web_search_providers"
version = "0.1.0"
dependencies = [
"anyhow",
"client",
"feature_flags",
"futures 0.3.31",
"gpui",
"http_client",
"language_model",
"serde",
"serde_json",
"web_search",
"workspace-hack",
"zed_llm_client",
]
[[package]]
name = "webpki-root-certs"
version = "0.26.8"
@@ -18287,6 +18323,8 @@ dependencies = [
"uuid",
"vim",
"vim_mode_setting",
"web_search",
"web_search_providers",
"welcome",
"windows 0.61.1",
"winresource",
@@ -18351,9 +18389,9 @@ dependencies = [
[[package]]
name = "zed_llm_client"
version = "0.4.2"
version = "0.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1d28a5d6bdb0f40acf5261c39cabbf65a13b55ba4b86d9beb5b8b1c484373f1a"
checksum = "57a5e1b5b3ace3fb55292a4c14036723bb8a01fac4aeaa3c2b63b51228412f94"
dependencies = [
"serde",
"serde_json",

View File

@@ -165,6 +165,8 @@ members = [
"crates/util_macros",
"crates/vim",
"crates/vim_mode_setting",
"crates/web_search",
"crates/web_search_providers",
"crates/welcome",
"crates/workspace",
"crates/worktree",
@@ -370,6 +372,8 @@ util = { path = "crates/util" }
util_macros = { path = "crates/util_macros" }
vim = { path = "crates/vim" }
vim_mode_setting = { path = "crates/vim_mode_setting" }
web_search = { path = "crates/web_search" }
web_search_providers = { path = "crates/web_search_providers" }
welcome = { path = "crates/welcome" }
workspace = { path = "crates/workspace" }
worktree = { path = "crates/worktree" }
@@ -601,7 +605,7 @@ wasmtime-wasi = "29"
which = "6.0.0"
wit-component = "0.221"
workspace-hack = "0.1.0"
zed_llm_client = "0.4.2"
zed_llm_client = "0.5.0"
zstd = "0.11"
metal = "0.29"

View File

@@ -652,7 +652,8 @@
"path_search": true,
"read_file": true,
"regex_search": true,
"thinking": true
"thinking": true,
"web_search": true
}
},
"write": {
@@ -678,7 +679,8 @@
"regex_search": true,
"rename": true,
"symbol_info": true,
"thinking": true
"thinking": true,
"web_search": true
}
}
},

View File

@@ -5,11 +5,12 @@ use crate::thread::{
ThreadEvent, ThreadFeedback,
};
use crate::thread_store::{RulesLoadingError, ThreadStore};
use crate::tool_use::{PendingToolUseStatus, ToolUse, ToolUseStatus};
use crate::tool_use::{PendingToolUseStatus, ToolUse};
use crate::ui::{AddedContext, AgentNotification, AgentNotificationEvent, ContextPill};
use crate::{AssistantPanel, OpenActiveThreadAsMarkdown};
use anyhow::Context as _;
use assistant_settings::{AssistantSettings, NotifyWhenAgentWaiting};
use assistant_tool::ToolUseStatus;
use collections::{HashMap, HashSet};
use editor::scroll::Autoscroll;
use editor::{Editor, EditorElement, EditorStyle, MultiBuffer};
@@ -766,10 +767,11 @@ impl ActiveThread {
self.thread.read(cx).summary_or_default()
}
pub fn cancel_last_completion(&mut self, cx: &mut App) -> bool {
pub fn cancel_last_completion(&mut self, window: &mut Window, cx: &mut App) -> bool {
self.last_error.take();
self.thread
.update(cx, |thread, cx| thread.cancel_last_completion(cx))
self.thread.update(cx, |thread, cx| {
thread.cancel_last_completion(Some(window.window_handle()), cx)
})
}
pub fn last_error(&self) -> Option<ThreadError> {
@@ -943,8 +945,8 @@ impl ActiveThread {
&tool_use.input,
self.thread
.read(cx)
.tool_result(&tool_use.id)
.map(|result| result.content.clone().into())
.output_for_tool(&tool_use.id)
.map(|output| output.clone().into())
.unwrap_or("".into()),
cx,
);
@@ -1142,7 +1144,7 @@ impl ActiveThread {
fn confirm_editing_message(
&mut self,
_: &menu::Confirm,
_: &mut Window,
window: &mut Window,
cx: &mut Context<Self>,
) {
let Some((message_id, state)) = self.editing_message.take() else {
@@ -1171,7 +1173,12 @@ impl ActiveThread {
}
self.thread.update(cx, |thread, cx| {
thread.send_to_model(model.model, RequestKind::Chat, cx)
thread.send_to_model(
model.model,
RequestKind::Chat,
Some(window.window_handle()),
cx,
)
});
cx.notify();
}
@@ -2279,12 +2286,15 @@ impl ActiveThread {
window: &mut Window,
cx: &mut Context<Self>,
) -> impl IntoElement + use<> {
if let Some(card) = self.thread.read(cx).card_for_tool(&tool_use.id) {
return card.render(&tool_use.status, window, cx);
}
let is_open = self
.expanded_tool_uses
.get(&tool_use.id)
.copied()
.unwrap_or_default();
let is_status_finished = matches!(&tool_use.status, ToolUseStatus::Finished(_));
let fs = self
@@ -2375,6 +2385,7 @@ impl ActiveThread {
open_markdown_link(text, workspace.clone(), window, cx);
}
})
.into_any_element()
}),
)),
),
@@ -2431,6 +2442,7 @@ impl ActiveThread {
open_markdown_link(text, workspace.clone(), window, cx);
}
})
.into_any_element()
})),
),
),
@@ -2761,7 +2773,7 @@ impl ActiveThread {
)
})
}
})
}).into_any_element()
}
fn render_rules_item(&self, cx: &Context<Self>) -> AnyElement {
@@ -2825,7 +2837,7 @@ impl ActiveThread {
&mut self,
tool_use_id: LanguageModelToolUseId,
_: &ClickEvent,
_window: &mut Window,
window: &mut Window,
cx: &mut Context<Self>,
) {
if let Some(PendingToolUseStatus::NeedsConfirmation(c)) = self
@@ -2841,6 +2853,7 @@ impl ActiveThread {
c.input.clone(),
&c.messages,
c.tool.clone(),
Some(window.window_handle()),
cx,
);
});
@@ -2852,11 +2865,12 @@ impl ActiveThread {
tool_use_id: LanguageModelToolUseId,
tool_name: Arc<str>,
_: &ClickEvent,
_window: &mut Window,
window: &mut Window,
cx: &mut Context<Self>,
) {
let window_handle = window.window_handle();
self.thread.update(cx, |thread, cx| {
thread.deny_tool_use(tool_use_id, tool_name, cx);
thread.deny_tool_use(tool_use_id, tool_name, Some(window_handle), cx);
});
}

View File

@@ -337,14 +337,9 @@ impl AssistantPanel {
&self.thread_store
}
fn cancel(
&mut self,
_: &editor::actions::Cancel,
_window: &mut Window,
cx: &mut Context<Self>,
) {
fn cancel(&mut self, _: &editor::actions::Cancel, window: &mut Window, cx: &mut Context<Self>) {
self.thread
.update(cx, |thread, cx| thread.cancel_last_completion(cx));
.update(cx, |thread, cx| thread.cancel_last_completion(window, cx));
}
fn new_thread(&mut self, action: &NewThread, window: &mut Window, cx: &mut Context<Self>) {

View File

@@ -263,6 +263,7 @@ impl MessageEditor {
let context_store = self.context_store.clone();
let git_store = self.project.read(cx).git_store().clone();
let checkpoint = git_store.update(cx, |git_store, cx| git_store.checkpoint(cx));
let window_handle = window.window_handle();
cx.spawn(async move |this, cx| {
let checkpoint = checkpoint.await.ok();
@@ -297,7 +298,7 @@ impl MessageEditor {
// Send to model after summaries are done
thread
.update(cx, |thread, cx| {
thread.send_to_model(model, request_kind, cx);
thread.send_to_model(model, request_kind, Some(window_handle), cx);
})
.log_err();
})
@@ -305,9 +306,9 @@ impl MessageEditor {
}
fn stop_current_and_send_new_message(&mut self, window: &mut Window, cx: &mut Context<Self>) {
let cancelled = self
.thread
.update(cx, |thread, cx| thread.cancel_last_completion(cx));
let cancelled = self.thread.update(cx, |thread, cx| {
thread.cancel_last_completion(Some(window.window_handle()), cx)
});
if cancelled {
self.set_editor_is_expanded(false, cx);

View File

@@ -6,14 +6,16 @@ use std::time::Instant;
use anyhow::{Context as _, Result, anyhow};
use assistant_settings::AssistantSettings;
use assistant_tool::{ActionLog, Tool, ToolWorkingSet};
use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolWorkingSet};
use chrono::{DateTime, Utc};
use collections::{BTreeMap, HashMap};
use feature_flags::{self, FeatureFlagAppExt};
use futures::future::Shared;
use futures::{FutureExt, StreamExt as _};
use git::repository::DiffType;
use gpui::{App, AppContext, Context, Entity, EventEmitter, SharedString, Task, WeakEntity};
use gpui::{
AnyWindowHandle, App, AppContext, Context, Entity, EventEmitter, SharedString, Task, WeakEntity,
};
use language_model::{
ConfiguredModel, LanguageModel, LanguageModelCompletionEvent, LanguageModelId,
LanguageModelKnownError, LanguageModelRegistry, LanguageModelRequest,
@@ -631,6 +633,14 @@ impl Thread {
self.tool_use.tool_result(id)
}
pub fn output_for_tool(&self, id: &LanguageModelToolUseId) -> Option<&Arc<str>> {
Some(&self.tool_use.tool_result(id)?.content)
}
pub fn card_for_tool(&self, id: &LanguageModelToolUseId) -> Option<AnyToolCard> {
self.tool_use.tool_result_card(id).cloned()
}
pub fn message_has_tool_results(&self, message_id: MessageId) -> bool {
self.tool_use.message_has_tool_results(message_id)
}
@@ -839,6 +849,7 @@ impl Thread {
&mut self,
model: Arc<dyn LanguageModel>,
request_kind: RequestKind,
window: Option<AnyWindowHandle>,
cx: &mut Context<Self>,
) {
let mut request = self.to_completion_request(request_kind, cx);
@@ -865,7 +876,7 @@ impl Thread {
};
}
self.stream_completion(request, model, cx);
self.stream_completion(request, model, window, cx);
}
pub fn used_tools_since_last_user_message(&self) -> bool {
@@ -1011,6 +1022,7 @@ impl Thread {
&mut self,
request: LanguageModelRequest,
model: Arc<dyn LanguageModel>,
window: Option<AnyWindowHandle>,
cx: &mut Context<Self>,
) {
let pending_completion_id = post_inc(&mut self.completion_count);
@@ -1138,7 +1150,7 @@ impl Thread {
match result.as_ref() {
Ok(stop_reason) => match stop_reason {
StopReason::ToolUse => {
let tool_uses = thread.use_pending_tools(cx);
let tool_uses = thread.use_pending_tools(window, cx);
cx.emit(ThreadEvent::UsePendingTools { tool_uses });
}
StopReason::EndTurn => {}
@@ -1183,7 +1195,7 @@ impl Thread {
}));
}
thread.cancel_last_completion(cx);
thread.cancel_last_completion(window, cx);
}
}
cx.emit(ThreadEvent::Stopped(result.map_err(Arc::new)));
@@ -1349,7 +1361,11 @@ impl Thread {
)
}
pub fn use_pending_tools(&mut self, cx: &mut Context<Self>) -> Vec<PendingToolUse> {
pub fn use_pending_tools(
&mut self,
window: Option<AnyWindowHandle>,
cx: &mut Context<Self>,
) -> Vec<PendingToolUse> {
self.auto_capture_telemetry(cx);
let request = self.to_completion_request(RequestKind::Chat, cx);
let messages = Arc::new(request.messages);
@@ -1381,6 +1397,7 @@ impl Thread {
tool_use.input.clone(),
&messages,
tool,
window,
cx,
);
}
@@ -1397,9 +1414,10 @@ impl Thread {
input: serde_json::Value,
messages: &[LanguageModelRequestMessage],
tool: Arc<dyn Tool>,
window: Option<AnyWindowHandle>,
cx: &mut Context<Thread>,
) {
let task = self.spawn_tool_use(tool_use_id.clone(), messages, input, tool, cx);
let task = self.spawn_tool_use(tool_use_id.clone(), messages, input, tool, window, cx);
self.tool_use
.run_pending_tool(tool_use_id, ui_text.into(), task);
}
@@ -1410,6 +1428,7 @@ impl Thread {
messages: &[LanguageModelRequestMessage],
input: serde_json::Value,
tool: Arc<dyn Tool>,
window: Option<AnyWindowHandle>,
cx: &mut Context<Thread>,
) -> Task<()> {
let tool_name: Arc<str> = tool.name().into();
@@ -1422,10 +1441,17 @@ impl Thread {
messages,
self.project.clone(),
self.action_log.clone(),
window,
cx,
)
};
// Store the card separately if it exists
if let Some(card) = tool_result.card.clone() {
self.tool_use
.insert_tool_result_card(tool_use_id.clone(), card);
}
cx.spawn({
async move |thread: WeakEntity<Thread>, cx| {
let output = tool_result.output.await;
@@ -1438,7 +1464,7 @@ impl Thread {
output,
cx,
);
thread.tool_finished(tool_use_id, pending_tool_use, false, cx);
thread.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
})
.ok();
}
@@ -1450,6 +1476,7 @@ impl Thread {
tool_use_id: LanguageModelToolUseId,
pending_tool_use: Option<PendingToolUse>,
canceled: bool,
window: Option<AnyWindowHandle>,
cx: &mut Context<Self>,
) {
if self.all_tools_finished() {
@@ -1457,7 +1484,7 @@ impl Thread {
if let Some(ConfiguredModel { model, .. }) = model_registry.default_model() {
self.attach_tool_results(cx);
if !canceled {
self.send_to_model(model, RequestKind::Chat, cx);
self.send_to_model(model, RequestKind::Chat, window, cx);
}
}
}
@@ -1484,7 +1511,11 @@ impl Thread {
/// Cancels the last pending completion, if there are any pending.
///
/// Returns whether a completion was canceled.
pub fn cancel_last_completion(&mut self, cx: &mut Context<Self>) -> bool {
pub fn cancel_last_completion(
&mut self,
window: Option<AnyWindowHandle>,
cx: &mut Context<Self>,
) -> bool {
let canceled = if self.pending_completions.pop().is_some() {
true
} else {
@@ -1495,6 +1526,7 @@ impl Thread {
pending_tool_use.id.clone(),
Some(pending_tool_use),
true,
window,
cx,
);
}
@@ -1918,6 +1950,7 @@ impl Thread {
&mut self,
tool_use_id: LanguageModelToolUseId,
tool_name: Arc<str>,
window: Option<AnyWindowHandle>,
cx: &mut Context<Self>,
) {
let err = Err(anyhow::anyhow!(
@@ -1926,7 +1959,7 @@ impl Thread {
self.tool_use
.insert_tool_output(tool_use_id.clone(), tool_name, err, cx);
self.tool_finished(tool_use_id.clone(), None, true, cx);
self.tool_finished(tool_use_id.clone(), None, true, window, cx);
}
}

View File

@@ -1,7 +1,7 @@
use std::sync::Arc;
use anyhow::Result;
use assistant_tool::{Tool, ToolWorkingSet};
use assistant_tool::{AnyToolCard, Tool, ToolUseStatus, ToolWorkingSet};
use collections::HashMap;
use futures::FutureExt as _;
use futures::future::Shared;
@@ -27,26 +27,7 @@ pub struct ToolUse {
pub needs_confirmation: bool,
}
#[derive(Debug, Clone)]
pub enum ToolUseStatus {
NeedsConfirmation,
Pending,
Running,
Finished(SharedString),
Error(SharedString),
}
impl ToolUseStatus {
pub fn text(&self) -> SharedString {
match self {
ToolUseStatus::NeedsConfirmation => "".into(),
ToolUseStatus::Pending => "".into(),
ToolUseStatus::Running => "".into(),
ToolUseStatus::Finished(out) => out.clone(),
ToolUseStatus::Error(out) => out.clone(),
}
}
}
pub const USING_TOOL_MARKER: &str = "<using_tool>";
pub struct ToolUseState {
tools: Entity<ToolWorkingSet>,
@@ -54,10 +35,9 @@ pub struct ToolUseState {
tool_uses_by_user_message: HashMap<MessageId, Vec<LanguageModelToolUseId>>,
tool_results: HashMap<LanguageModelToolUseId, LanguageModelToolResult>,
pending_tool_uses_by_id: HashMap<LanguageModelToolUseId, PendingToolUse>,
tool_result_cards: HashMap<LanguageModelToolUseId, AnyToolCard>,
}
pub const USING_TOOL_MARKER: &str = "<using_tool>";
impl ToolUseState {
pub fn new(tools: Entity<ToolWorkingSet>) -> Self {
Self {
@@ -66,6 +46,7 @@ impl ToolUseState {
tool_uses_by_user_message: HashMap::default(),
tool_results: HashMap::default(),
pending_tool_uses_by_id: HashMap::default(),
tool_result_cards: HashMap::default(),
}
}
@@ -257,6 +238,18 @@ impl ToolUseState {
self.tool_results.get(tool_use_id)
}
pub fn tool_result_card(&self, tool_use_id: &LanguageModelToolUseId) -> Option<&AnyToolCard> {
self.tool_result_cards.get(tool_use_id)
}
pub fn insert_tool_result_card(
&mut self,
tool_use_id: LanguageModelToolUseId,
card: AnyToolCard,
) {
self.tool_result_cards.insert(tool_use_id, card);
}
pub fn request_tool_use(
&mut self,
assistant_message_id: MessageId,

View File

@@ -9,6 +9,11 @@ use std::fmt::Formatter;
use std::sync::Arc;
use anyhow::Result;
use gpui::AnyElement;
use gpui::AnyWindowHandle;
use gpui::Context;
use gpui::IntoElement;
use gpui::Window;
use gpui::{App, Entity, SharedString, Task};
use icons::IconName;
use language_model::LanguageModelRequestMessage;
@@ -24,16 +29,87 @@ pub fn init(cx: &mut App) {
ToolRegistry::default_global(cx);
}
/// The result of running a tool
#[derive(Debug, Clone)]
pub enum ToolUseStatus {
NeedsConfirmation,
Pending,
Running,
Finished(SharedString),
Error(SharedString),
}
impl ToolUseStatus {
pub fn text(&self) -> SharedString {
match self {
ToolUseStatus::NeedsConfirmation => "".into(),
ToolUseStatus::Pending => "".into(),
ToolUseStatus::Running => "".into(),
ToolUseStatus::Finished(out) => out.clone(),
ToolUseStatus::Error(out) => out.clone(),
}
}
}
/// The result of running a tool, containing both the asynchronous output
/// and an optional card view that can be rendered immediately.
pub struct ToolResult {
/// The asynchronous task that will eventually resolve to the tool's output
pub output: Task<Result<String>>,
/// An optional view to present the output of the tool.
pub card: Option<AnyToolCard>,
}
pub trait ToolCard: 'static + Sized {
fn render(
&mut self,
status: &ToolUseStatus,
window: &mut Window,
cx: &mut Context<Self>,
) -> impl IntoElement;
}
#[derive(Clone)]
pub struct AnyToolCard {
entity: gpui::AnyEntity,
render: fn(
entity: gpui::AnyEntity,
status: &ToolUseStatus,
window: &mut Window,
cx: &mut App,
) -> AnyElement,
}
impl<T: ToolCard> From<Entity<T>> for AnyToolCard {
fn from(entity: Entity<T>) -> Self {
fn downcast_render<T: ToolCard>(
entity: gpui::AnyEntity,
status: &ToolUseStatus,
window: &mut Window,
cx: &mut App,
) -> AnyElement {
let entity = entity.downcast::<T>().unwrap();
entity.update(cx, |entity, cx| {
entity.render(status, window, cx).into_any_element()
})
}
Self {
entity: entity.into(),
render: downcast_render::<T>,
}
}
}
impl AnyToolCard {
pub fn render(&self, status: &ToolUseStatus, window: &mut Window, cx: &mut App) -> AnyElement {
(self.render)(self.entity.clone(), status, window, cx)
}
}
impl From<Task<Result<String>>> for ToolResult {
/// Convert from a task to a ToolResult
/// Convert from a task to a ToolResult with no card
fn from(output: Task<Result<String>>) -> Self {
Self { output }
Self { output, card: None }
}
}
@@ -80,6 +156,7 @@ pub trait Tool: 'static + Send + Sync {
messages: &[LanguageModelRequestMessage],
project: Entity<Project>,
action_log: Entity<ActionLog>,
window: Option<AnyWindowHandle>,
cx: &mut App,
) -> ToolResult;
}

View File

@@ -14,8 +14,11 @@ path = "src/assistant_tools.rs"
[dependencies]
anyhow.workspace = true
assistant_tool.workspace = true
buffer_diff.workspace = true
chrono.workspace = true
collections.workspace = true
editor.workspace = true
feature_flags.workspace = true
futures.workspace = true
gpui.workspace = true
html_to_markdown.workspace = true
@@ -23,6 +26,8 @@ http_client.workspace = true
itertools.workspace = true
language.workspace = true
language_model.workspace = true
multi_buffer.workspace = true
open = { workspace = true }
project.workspace = true
regex.workspace = true
schemars.workspace = true
@@ -30,9 +35,11 @@ serde.workspace = true
serde_json.workspace = true
ui.workspace = true
util.workspace = true
worktree.workspace = true
open = { workspace = true }
web_search.workspace = true
workspace-hack.workspace = true
worktree.workspace = true
zed_llm_client.workspace = true
[dev-dependencies]
collections = { workspace = true, features = ["test-support"] }

View File

@@ -22,14 +22,17 @@ mod schema;
mod symbol_info_tool;
mod terminal_tool;
mod thinking_tool;
mod web_search_tool;
use std::sync::Arc;
use assistant_tool::ToolRegistry;
use copy_path_tool::CopyPathTool;
use feature_flags::FeatureFlagAppExt;
use gpui::App;
use http_client::HttpClientWithUrl;
use move_path_tool::MovePathTool;
use web_search_tool::WebSearchTool;
use crate::batch_tool::BatchTool;
use crate::code_action_tool::CodeActionTool;
@@ -56,28 +59,39 @@ pub fn init(http_client: Arc<HttpClientWithUrl>, cx: &mut App) {
assistant_tool::init(cx);
let registry = ToolRegistry::global(cx);
registry.register_tool(TerminalTool);
registry.register_tool(BatchTool);
registry.register_tool(CreateDirectoryTool);
registry.register_tool(CreateFileTool);
registry.register_tool(CopyPathTool);
registry.register_tool(DeletePathTool);
registry.register_tool(FindReplaceFileTool);
registry.register_tool(SymbolInfoTool);
registry.register_tool(CodeActionTool);
registry.register_tool(MovePathTool);
registry.register_tool(DiagnosticsTool);
registry.register_tool(ListDirectoryTool);
registry.register_tool(NowTool);
registry.register_tool(OpenTool);
registry.register_tool(CodeSymbolsTool);
registry.register_tool(ContentsTool);
registry.register_tool(CopyPathTool);
registry.register_tool(CreateDirectoryTool);
registry.register_tool(CreateFileTool);
registry.register_tool(DeletePathTool);
registry.register_tool(DiagnosticsTool);
registry.register_tool(FetchTool::new(http_client));
registry.register_tool(FindReplaceFileTool);
registry.register_tool(ListDirectoryTool);
registry.register_tool(MovePathTool);
registry.register_tool(NowTool);
registry.register_tool(OpenTool);
registry.register_tool(PathSearchTool);
registry.register_tool(ReadFileTool);
registry.register_tool(RegexSearchTool);
registry.register_tool(RenameTool);
registry.register_tool(SymbolInfoTool);
registry.register_tool(TerminalTool);
registry.register_tool(ThinkingTool);
registry.register_tool(FetchTool::new(http_client));
cx.observe_flag::<feature_flags::ZedProWebSearchTool, _>({
move |is_enabled, cx| {
if is_enabled {
ToolRegistry::global(cx).register_tool(WebSearchTool);
} else {
ToolRegistry::global(cx).unregister_tool(WebSearchTool);
}
}
})
.detach();
}
#[cfg(test)]

View File

@@ -2,7 +2,7 @@ use crate::schema::json_schema_for;
use anyhow::{Result, anyhow};
use assistant_tool::{ActionLog, Tool, ToolResult, ToolWorkingSet};
use futures::future::join_all;
use gpui::{App, AppContext, Entity, Task};
use gpui::{AnyWindowHandle, App, AppContext, Entity, Task};
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
use project::Project;
use schemars::JsonSchema;
@@ -218,6 +218,7 @@ impl Tool for BatchTool {
messages: &[LanguageModelRequestMessage],
project: Entity<Project>,
action_log: Entity<ActionLog>,
window: Option<AnyWindowHandle>,
cx: &mut App,
) -> ToolResult {
let input = match serde_json::from_value::<BatchToolInput>(input) {
@@ -258,7 +259,16 @@ impl Tool for BatchTool {
let action_log = action_log.clone();
let messages = messages.clone();
let tool_result = cx
.update(|cx| tool.run(invocation.input, &messages, project, action_log, cx))
.update(|cx| {
tool.run(
invocation.input,
&messages,
project,
action_log,
window.clone(),
cx,
)
})
.map_err(|err| anyhow!("Failed to start tool '{}': {}", tool_name, err))?;
tasks.push(tool_result.output);

View File

@@ -1,6 +1,6 @@
use anyhow::{Context as _, Result, anyhow};
use assistant_tool::{ActionLog, Tool, ToolResult};
use gpui::{App, Entity, Task};
use gpui::{AnyWindowHandle, App, Entity, Task};
use language::{self, Anchor, Buffer, ToPointUtf16};
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
use project::{self, LspAction, Project};
@@ -140,6 +140,7 @@ impl Tool for CodeActionTool {
_messages: &[LanguageModelRequestMessage],
project: Entity<Project>,
action_log: Entity<ActionLog>,
_window: Option<AnyWindowHandle>,
cx: &mut App,
) -> ToolResult {
let input = match serde_json::from_value::<CodeActionToolInput>(input) {

View File

@@ -6,7 +6,7 @@ use crate::schema::json_schema_for;
use anyhow::{Result, anyhow};
use assistant_tool::{ActionLog, Tool, ToolResult};
use collections::IndexMap;
use gpui::{App, AsyncApp, Entity, Task};
use gpui::{AnyWindowHandle, App, AsyncApp, Entity, Task};
use language::{OutlineItem, ParseStatus, Point};
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
use project::{Project, Symbol};
@@ -128,6 +128,7 @@ impl Tool for CodeSymbolsTool {
_messages: &[LanguageModelRequestMessage],
project: Entity<Project>,
action_log: Entity<ActionLog>,
_window: Option<AnyWindowHandle>,
cx: &mut App,
) -> ToolResult {
let input = match serde_json::from_value::<CodeSymbolsInput>(input) {

View File

@@ -3,7 +3,7 @@ use std::sync::Arc;
use crate::{code_symbols_tool::file_outline, schema::json_schema_for};
use anyhow::{Result, anyhow};
use assistant_tool::{ActionLog, Tool, ToolResult};
use gpui::{App, Entity, Task};
use gpui::{AnyWindowHandle, App, Entity, Task};
use itertools::Itertools;
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
use project::Project;
@@ -102,6 +102,7 @@ impl Tool for ContentsTool {
_messages: &[LanguageModelRequestMessage],
project: Entity<Project>,
action_log: Entity<ActionLog>,
_window: Option<AnyWindowHandle>,
cx: &mut App,
) -> ToolResult {
let input = match serde_json::from_value::<ContentsToolInput>(input) {

View File

@@ -1,6 +1,7 @@
use crate::schema::json_schema_for;
use anyhow::{Result, anyhow};
use assistant_tool::{ActionLog, Tool, ToolResult};
use gpui::AnyWindowHandle;
use gpui::{App, AppContext, Entity, Task};
use language_model::LanguageModelRequestMessage;
use language_model::LanguageModelToolSchemaFormat;
@@ -76,6 +77,7 @@ impl Tool for CopyPathTool {
_messages: &[LanguageModelRequestMessage],
project: Entity<Project>,
_action_log: Entity<ActionLog>,
_window: Option<AnyWindowHandle>,
cx: &mut App,
) -> ToolResult {
let input = match serde_json::from_value::<CopyPathToolInput>(input) {

View File

@@ -1,6 +1,7 @@
use crate::schema::json_schema_for;
use anyhow::{Result, anyhow};
use assistant_tool::{ActionLog, Tool, ToolResult};
use gpui::AnyWindowHandle;
use gpui::{App, Entity, Task};
use language_model::LanguageModelRequestMessage;
use language_model::LanguageModelToolSchemaFormat;
@@ -67,6 +68,7 @@ impl Tool for CreateDirectoryTool {
_messages: &[LanguageModelRequestMessage],
project: Entity<Project>,
_action_log: Entity<ActionLog>,
_window: Option<AnyWindowHandle>,
cx: &mut App,
) -> ToolResult {
let input = match serde_json::from_value::<CreateDirectoryToolInput>(input) {

View File

@@ -1,6 +1,7 @@
use crate::schema::json_schema_for;
use anyhow::{Result, anyhow};
use assistant_tool::{ActionLog, Tool, ToolResult};
use gpui::AnyWindowHandle;
use gpui::{App, Entity, Task};
use language_model::LanguageModelRequestMessage;
use language_model::LanguageModelToolSchemaFormat;
@@ -72,6 +73,7 @@ impl Tool for CreateFileTool {
_messages: &[LanguageModelRequestMessage],
project: Entity<Project>,
action_log: Entity<ActionLog>,
_window: Option<AnyWindowHandle>,
cx: &mut App,
) -> ToolResult {
let input = match serde_json::from_value::<CreateFileToolInput>(input) {

View File

@@ -2,7 +2,7 @@ use crate::schema::json_schema_for;
use anyhow::{Result, anyhow};
use assistant_tool::{ActionLog, Tool, ToolResult};
use futures::{SinkExt, StreamExt, channel::mpsc};
use gpui::{App, AppContext, Entity, Task};
use gpui::{AnyWindowHandle, App, AppContext, Entity, Task};
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
use project::{Project, ProjectPath};
use schemars::JsonSchema;
@@ -62,6 +62,7 @@ impl Tool for DeletePathTool {
_messages: &[LanguageModelRequestMessage],
project: Entity<Project>,
action_log: Entity<ActionLog>,
_window: Option<AnyWindowHandle>,
cx: &mut App,
) -> ToolResult {
let path_str = match serde_json::from_value::<DeletePathToolInput>(input) {

View File

@@ -1,7 +1,7 @@
use crate::schema::json_schema_for;
use anyhow::{Result, anyhow};
use assistant_tool::{ActionLog, Tool, ToolResult};
use gpui::{App, Entity, Task};
use gpui::{AnyWindowHandle, App, Entity, Task};
use language::{DiagnosticSeverity, OffsetRangeExt};
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
use project::Project;
@@ -82,6 +82,7 @@ impl Tool for DiagnosticsTool {
_messages: &[LanguageModelRequestMessage],
project: Entity<Project>,
action_log: Entity<ActionLog>,
_window: Option<AnyWindowHandle>,
cx: &mut App,
) -> ToolResult {
match serde_json::from_value::<DiagnosticsToolInput>(input)

View File

@@ -6,7 +6,7 @@ use crate::schema::json_schema_for;
use anyhow::{Context as _, Result, anyhow, bail};
use assistant_tool::{ActionLog, Tool, ToolResult};
use futures::AsyncReadExt as _;
use gpui::{App, AppContext as _, Entity, Task};
use gpui::{AnyWindowHandle, App, AppContext as _, Entity, Task};
use html_to_markdown::{TagHandler, convert_html_to_markdown, markdown};
use http_client::{AsyncBody, HttpClientWithUrl};
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
@@ -145,6 +145,7 @@ impl Tool for FetchTool {
_messages: &[LanguageModelRequestMessage],
_project: Entity<Project>,
_action_log: Entity<ActionLog>,
_window: Option<AnyWindowHandle>,
cx: &mut App,
) -> ToolResult {
let input = match serde_json::from_value::<FetchToolInput>(input) {

View File

@@ -1,13 +1,25 @@
use crate::{replace::replace_with_flexible_indent, schema::json_schema_for};
use anyhow::{Context as _, Result, anyhow};
use assistant_tool::{ActionLog, Tool, ToolResult};
use gpui::{App, AppContext, AsyncApp, Entity, Task};
use assistant_tool::{ActionLog, Tool, ToolCard, ToolResult, ToolUseStatus};
use buffer_diff::{BufferDiff, BufferDiffSnapshot};
use editor::{Editor, MultiBuffer, PathKey};
use gpui::{
AnyWindowHandle, App, AppContext, AsyncApp, Context, Entity, IntoElement, Task, Window,
};
use language::{
self, Anchor, Buffer, Capability, LanguageRegistry, LineEnding, OffsetRangeExt as _, Rope,
TextBuffer,
};
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
use project::Project;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use std::{path::PathBuf, sync::Arc};
use ui::IconName;
use std::{
path::{Path, PathBuf},
sync::Arc,
};
use ui::{Tooltip, prelude::*};
use util::ResultExt;
use crate::replace::replace_exact;
@@ -132,6 +144,274 @@ pub struct FindReplaceFileToolInput {
pub replace: String,
}
pub struct FindReplaceFileToolCard {
path: PathBuf,
description: String,
editor: Entity<Editor>,
multibuffer: Entity<MultiBuffer>,
project: Entity<Project>,
diff_task: Option<Task<Result<()>>>,
}
impl FindReplaceFileToolCard {
fn new(
path: PathBuf,
description: String,
project: Entity<Project>,
window: &mut Window,
cx: &mut App,
) -> Self {
let multibuffer = cx.new(|_| MultiBuffer::new(Capability::ReadOnly));
let editor = cx.new(|cx| {
let mut editor =
Editor::for_multibuffer(multibuffer.clone(), Some(project.clone()), window, cx);
editor.disable_inline_diagnostics();
editor.set_expand_all_diff_hunks(cx);
editor
});
Self {
path,
description,
project,
editor,
multibuffer,
diff_task: None,
}
}
fn set_diff(
&mut self,
path: Arc<Path>,
old_text: String,
new_text: String,
cx: &mut Context<Self>,
) {
let language_registry = self.project.read(cx).languages().clone();
self.diff_task = Some(cx.spawn(async move |this, cx| {
let buffer = build_buffer(new_text, path.clone(), &language_registry, cx).await?;
let buffer_diff = build_buffer_diff(old_text, &buffer, &language_registry, cx).await?;
this.update(cx, |this, cx| {
this.multibuffer.update(cx, |multibuffer, cx| {
let snapshot = buffer.read(cx).snapshot();
let diff = buffer_diff.read(cx);
let diff_hunk_ranges = diff
.hunks_intersecting_range(Anchor::MIN..Anchor::MAX, &snapshot, cx)
.map(|diff_hunk| diff_hunk.buffer_range.to_point(&snapshot))
.collect::<Vec<_>>();
let _is_newly_added = multibuffer.set_excerpts_for_path(
PathKey::for_buffer(&buffer, cx),
buffer,
diff_hunk_ranges,
editor::DEFAULT_MULTIBUFFER_CONTEXT,
cx,
);
multibuffer.add_diff(buffer_diff, cx);
});
cx.notify();
})
}));
}
}
async fn build_buffer(
mut text: String,
path: Arc<Path>,
language_registry: &Arc<language::LanguageRegistry>,
cx: &mut AsyncApp,
) -> Result<Entity<Buffer>> {
let line_ending = LineEnding::detect(&text);
LineEnding::normalize(&mut text);
let text = Rope::from(text);
let language = cx
.update(|_cx| language_registry.language_for_file_path(&path))?
.await
.ok();
let buffer = cx.new(|cx| {
let buffer = TextBuffer::new_normalized(
0,
cx.entity_id().as_non_zero_u64().into(),
line_ending,
text,
);
let mut buffer = Buffer::build(buffer, None, Capability::ReadWrite);
buffer.set_language(language, cx);
buffer
})?;
Ok(buffer)
}
async fn build_buffer_diff(
mut old_text: String,
buffer: &Entity<Buffer>,
language_registry: &Arc<LanguageRegistry>,
cx: &mut AsyncApp,
) -> Result<Entity<BufferDiff>> {
LineEnding::normalize(&mut old_text);
let buffer = cx.update(|cx| buffer.read(cx).snapshot())?;
let base_buffer = cx
.update(|cx| {
Buffer::build_snapshot(
old_text.clone().into(),
buffer.language().cloned(),
Some(language_registry.clone()),
cx,
)
})?
.await;
let diff_snapshot = cx
.update(|cx| {
BufferDiffSnapshot::new_with_base_buffer(
buffer.text.clone(),
Some(old_text.into()),
base_buffer,
cx,
)
})?
.await;
cx.new(|cx| {
let mut diff = BufferDiff::new(&buffer.text, cx);
diff.set_snapshot(diff_snapshot, &buffer.text, cx);
diff
})
}
impl ToolCard for FindReplaceFileToolCard {
fn render(
&mut self,
status: &ToolUseStatus,
_window: &mut Window,
cx: &mut Context<Self>,
) -> impl IntoElement {
let header = h_flex()
.id("tool-label-container")
.gap_1p5()
.max_w_full()
.overflow_x_scroll()
.child(
Icon::new(IconName::Pencil)
.size(IconSize::XSmall)
.color(Color::Muted),
)
.child(Label::new("Edit ").size(LabelSize::Small))
.child(
div()
.size(px(3.))
.rounded_full()
.bg(cx.theme().colors().text),
)
.child(Label::new(self.path.display().to_string()).size(LabelSize::Small))
.into_any_element();
let header2 = h_flex()
.id("code-block-header-label")
.w_full()
.max_w_full()
.px_1()
.gap_0p5()
.cursor_pointer()
.rounded_sm()
.hover(|item| item.bg(cx.theme().colors().element_hover.opacity(0.5)))
.tooltip(Tooltip::text("Jump to File"));
// todo!
// .child(
// h_flex()
// .gap_0p5()
// .children(
// file_icons::FileIcons::get_icon(&path_range.path, cx)
// .map(Icon::from_path)
// .map(|icon| icon.color(Color::Muted).size(IconSize::XSmall)),
// )
// .child(content)
// .child(
// Icon::new(IconName::ArrowUpRight)
// .size(IconSize::XSmall)
// .color(Color::Ignored),
// ),
// )
// .on_click({
// let path_range = path_range.clone();
// move |_, window, cx| {
// workspace
// .update(cx, {
// |workspace, cx| {
// if let Some(project_path) = workspace
// .project()
// .read(cx)
// .find_project_path(&path_range.path, cx)
// {
// let target = path_range.range.as_ref().map(|range| {
// Point::new(
// // Line number is 1-based
// range.start.line.saturating_sub(1),
// range.start.col.unwrap_or(0),
// )
// });
// let open_task =
// workspace.open_path(project_path, None, true, window, cx);
// window
// .spawn(cx, async move |cx| {
// let item = open_task.await?;
// if let Some(target) = target {
// if let Some(active_editor) =
// item.downcast::<Editor>()
// {
// active_editor
// .downgrade()
// .update_in(cx, |editor, window, cx| {
// editor.go_to_singleton_buffer_point(
// target, window, cx,
// );
// })
// .log_err();
// }
// }
// anyhow::Ok(())
// })
// .detach_and_log_err(cx);
// }
// }
// })
// .ok();
// }
// })
// .into_any_element();
let content = match status {
ToolUseStatus::NeedsConfirmation | ToolUseStatus::Pending | ToolUseStatus::Running => {
div()
// .child(Label::new(&self.description).size(LabelSize::Small))
.into_any_element()
}
ToolUseStatus::Finished(str) => {
dbg!(&str);
self.editor.clone().into_any_element()
}
ToolUseStatus::Error(error) => div()
.child(
Label::new(error.to_string())
.color(Color::Error)
.size(LabelSize::Small),
)
.into_any_element(),
};
v_flex()
.my_2()
.border_1()
.border_color(cx.theme().colors().border)
.rounded_sm()
.gap_1()
.child(header)
.child(content)
}
}
pub struct FindReplaceFileTool;
impl Tool for FindReplaceFileTool {
@@ -168,14 +448,32 @@ impl Tool for FindReplaceFileTool {
_messages: &[LanguageModelRequestMessage],
project: Entity<Project>,
action_log: Entity<ActionLog>,
window: Option<AnyWindowHandle>,
cx: &mut App,
) -> ToolResult {
let input = match serde_json::from_value::<FindReplaceFileToolInput>(input) {
Ok(input) => input,
Err(err) => return Task::ready(Err(anyhow!(err))).into(),
};
let card = window.and_then(|window| {
window
.update(cx, |_, window, cx| {
cx.new(|cx| {
FindReplaceFileToolCard::new(
input.path.clone(),
input.display_description.clone(),
project.clone(),
window,
cx,
)
})
})
.ok()
});
cx.spawn(async move |cx: &mut AsyncApp| {
let output = cx.spawn({
let card = card.clone();
async move |cx: &mut AsyncApp| {
let project_path = project.read_with(cx, |project, cx| {
project
.find_project_path(&input.path, cx)
@@ -183,7 +481,7 @@ impl Tool for FindReplaceFileTool {
})??;
let buffer = project
.update(cx, |project, cx| project.open_buffer(project_path, cx))?
.update(cx, |project, cx| project.open_buffer(project_path.clone(), cx))?
.await?;
let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
@@ -255,14 +553,29 @@ impl Tool for FindReplaceFileTool {
project.save_buffer(buffer, cx)
})?.await?;
let diff_str = cx.background_spawn(async move {
let new_text = snapshot.text();
language::unified_diff(&old_text, &new_text)
let new_text = snapshot.text();
let diff_str = cx.background_spawn({
// todo! probably don't need this
let old_text = old_text.clone();
let new_text = new_text.clone();
async move {
language::unified_diff(&old_text, &new_text)
}
}).await;
if let Some(card) = card {
card.update(cx, |card, cx| {
card.set_diff(project_path.path.clone(), old_text, new_text, cx);
}).log_err();
}
Ok(format!("Edited {}:\n\n```diff\n{}\n```", input.path.display(), diff_str))
}});
}).into()
ToolResult {
output,
card: card.map(|card| card.into()),
}
}
}

View File

@@ -1,7 +1,7 @@
use crate::schema::json_schema_for;
use anyhow::{Result, anyhow};
use assistant_tool::{ActionLog, Tool, ToolResult};
use gpui::{App, Entity, Task};
use gpui::{AnyWindowHandle, App, Entity, Task};
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
use project::Project;
use schemars::JsonSchema;
@@ -76,6 +76,7 @@ impl Tool for ListDirectoryTool {
_messages: &[LanguageModelRequestMessage],
project: Entity<Project>,
_action_log: Entity<ActionLog>,
_window: Option<AnyWindowHandle>,
cx: &mut App,
) -> ToolResult {
let input = match serde_json::from_value::<ListDirectoryToolInput>(input) {

View File

@@ -1,7 +1,7 @@
use crate::schema::json_schema_for;
use anyhow::{Result, anyhow};
use assistant_tool::{ActionLog, Tool, ToolResult};
use gpui::{App, AppContext, Entity, Task};
use gpui::{AnyWindowHandle, App, AppContext, Entity, Task};
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
use project::Project;
use schemars::JsonSchema;
@@ -89,6 +89,7 @@ impl Tool for MovePathTool {
_messages: &[LanguageModelRequestMessage],
project: Entity<Project>,
_action_log: Entity<ActionLog>,
_window: Option<AnyWindowHandle>,
cx: &mut App,
) -> ToolResult {
let input = match serde_json::from_value::<MovePathToolInput>(input) {

View File

@@ -4,7 +4,7 @@ use crate::schema::json_schema_for;
use anyhow::{Result, anyhow};
use assistant_tool::{ActionLog, Tool, ToolResult};
use chrono::{Local, Utc};
use gpui::{App, Entity, Task};
use gpui::{AnyWindowHandle, App, Entity, Task};
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
use project::Project;
use schemars::JsonSchema;
@@ -59,6 +59,7 @@ impl Tool for NowTool {
_messages: &[LanguageModelRequestMessage],
_project: Entity<Project>,
_action_log: Entity<ActionLog>,
_window: Option<AnyWindowHandle>,
_cx: &mut App,
) -> ToolResult {
let input: NowToolInput = match serde_json::from_value(input) {

View File

@@ -1,7 +1,7 @@
use crate::schema::json_schema_for;
use anyhow::{Context as _, Result, anyhow};
use assistant_tool::{ActionLog, Tool, ToolResult};
use gpui::{App, AppContext, Entity, Task};
use gpui::{AnyWindowHandle, App, AppContext, Entity, Task};
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
use project::Project;
use schemars::JsonSchema;
@@ -52,6 +52,7 @@ impl Tool for OpenTool {
_messages: &[LanguageModelRequestMessage],
_project: Entity<Project>,
_action_log: Entity<ActionLog>,
_window: Option<AnyWindowHandle>,
cx: &mut App,
) -> ToolResult {
let input: OpenToolInput = match serde_json::from_value(input) {

View File

@@ -1,7 +1,7 @@
use crate::schema::json_schema_for;
use anyhow::{Result, anyhow};
use assistant_tool::{ActionLog, Tool, ToolResult};
use gpui::{App, AppContext, Entity, Task};
use gpui::{AnyWindowHandle, App, AppContext, Entity, Task};
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
use project::Project;
use schemars::JsonSchema;
@@ -70,6 +70,7 @@ impl Tool for PathSearchTool {
_messages: &[LanguageModelRequestMessage],
project: Entity<Project>,
_action_log: Entity<ActionLog>,
_window: Option<AnyWindowHandle>,
cx: &mut App,
) -> ToolResult {
let (offset, glob) = match serde_json::from_value::<PathSearchToolInput>(input) {

View File

@@ -3,7 +3,7 @@ use std::sync::Arc;
use crate::{code_symbols_tool::file_outline, schema::json_schema_for};
use anyhow::{Result, anyhow};
use assistant_tool::{ActionLog, Tool, ToolResult};
use gpui::{App, Entity, Task};
use gpui::{AnyWindowHandle, App, Entity, Task};
use itertools::Itertools;
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
use project::Project;
@@ -87,6 +87,7 @@ impl Tool for ReadFileTool {
_messages: &[LanguageModelRequestMessage],
project: Entity<Project>,
action_log: Entity<ActionLog>,
_window: Option<AnyWindowHandle>,
cx: &mut App,
) -> ToolResult {
let input = match serde_json::from_value::<ReadFileToolInput>(input) {

View File

@@ -2,7 +2,7 @@ use crate::schema::json_schema_for;
use anyhow::{Result, anyhow};
use assistant_tool::{ActionLog, Tool, ToolResult};
use futures::StreamExt;
use gpui::{App, Entity, Task};
use gpui::{AnyWindowHandle, App, Entity, Task};
use language::OffsetRangeExt;
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
use project::{
@@ -91,6 +91,7 @@ impl Tool for RegexSearchTool {
_messages: &[LanguageModelRequestMessage],
project: Entity<Project>,
_action_log: Entity<ActionLog>,
_window: Option<AnyWindowHandle>,
cx: &mut App,
) -> ToolResult {
const CONTEXT_LINES: u32 = 2;

View File

@@ -1,6 +1,6 @@
use anyhow::{Context as _, Result, anyhow};
use assistant_tool::{ActionLog, Tool, ToolResult};
use gpui::{App, Entity, Task};
use gpui::{AnyWindowHandle, App, Entity, Task};
use language::{self, Buffer, ToPointUtf16};
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
use project::Project;
@@ -87,6 +87,7 @@ impl Tool for RenameTool {
_messages: &[LanguageModelRequestMessage],
project: Entity<Project>,
action_log: Entity<ActionLog>,
_window: Option<AnyWindowHandle>,
cx: &mut App,
) -> ToolResult {
let input = match serde_json::from_value::<RenameToolInput>(input) {

View File

@@ -1,6 +1,6 @@
use anyhow::{Context as _, Result, anyhow};
use assistant_tool::{ActionLog, Tool, ToolResult};
use gpui::{App, AsyncApp, Entity, Task};
use gpui::{AnyWindowHandle, App, AsyncApp, Entity, Task};
use language::{self, Anchor, Buffer, BufferSnapshot, Location, Point, ToPoint, ToPointUtf16};
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
use project::Project;
@@ -121,6 +121,7 @@ impl Tool for SymbolInfoTool {
_messages: &[LanguageModelRequestMessage],
project: Entity<Project>,
action_log: Entity<ActionLog>,
_window: Option<AnyWindowHandle>,
cx: &mut App,
) -> ToolResult {
let input = match serde_json::from_value::<SymbolInfoToolInput>(input) {

View File

@@ -3,7 +3,7 @@ use anyhow::{Context as _, Result, anyhow};
use assistant_tool::{ActionLog, Tool, ToolResult};
use futures::io::BufReader;
use futures::{AsyncBufReadExt, AsyncReadExt, FutureExt};
use gpui::{App, AppContext, Entity, Task};
use gpui::{AnyWindowHandle, App, AppContext, Entity, Task};
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
use project::Project;
use schemars::JsonSchema;
@@ -78,6 +78,7 @@ impl Tool for TerminalTool {
_messages: &[LanguageModelRequestMessage],
project: Entity<Project>,
_action_log: Entity<ActionLog>,
_window: Option<AnyWindowHandle>,
cx: &mut App,
) -> ToolResult {
let input: TerminalToolInput = match serde_json::from_value(input) {

View File

@@ -3,7 +3,7 @@ use std::sync::Arc;
use crate::schema::json_schema_for;
use anyhow::{Result, anyhow};
use assistant_tool::{ActionLog, Tool, ToolResult};
use gpui::{App, Entity, Task};
use gpui::{AnyWindowHandle, App, Entity, Task};
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
use project::Project;
use schemars::JsonSchema;
@@ -50,6 +50,7 @@ impl Tool for ThinkingTool {
_messages: &[LanguageModelRequestMessage],
_project: Entity<Project>,
_action_log: Entity<ActionLog>,
_window: Option<AnyWindowHandle>,
_cx: &mut App,
) -> ToolResult {
// This tool just "thinks out loud" and doesn't perform any actions.

View File

@@ -0,0 +1,214 @@
use std::{sync::Arc, time::Duration};
use crate::schema::json_schema_for;
use anyhow::{Context as _, Result, anyhow};
use assistant_tool::{ActionLog, Tool, ToolCard, ToolResult, ToolUseStatus};
use futures::{FutureExt, TryFutureExt};
use gpui::{
Animation, AnimationExt, AnyWindowHandle, App, AppContext, Context, Entity, IntoElement, Task,
Window, pulsating_between,
};
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
use project::Project;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use ui::{IconName, Tooltip, prelude::*};
use web_search::WebSearchRegistry;
use zed_llm_client::WebSearchResponse;
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct WebSearchToolInput {
/// The search term or question to query on the web.
query: String,
}
pub struct WebSearchTool;
impl Tool for WebSearchTool {
fn name(&self) -> String {
"web_search".into()
}
fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool {
false
}
fn description(&self) -> String {
"Search the web for information using your query. Use this when you need real-time information, facts, or data that might not be in your training. Results will include snippets and links from relevant web pages.".into()
}
fn icon(&self) -> IconName {
IconName::Globe
}
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
json_schema_for::<WebSearchToolInput>(format)
}
fn ui_text(&self, _input: &serde_json::Value) -> String {
"Web Search".to_string()
}
fn run(
self: Arc<Self>,
input: serde_json::Value,
_messages: &[LanguageModelRequestMessage],
_project: Entity<Project>,
_action_log: Entity<ActionLog>,
_window: Option<AnyWindowHandle>,
cx: &mut App,
) -> ToolResult {
let input = match serde_json::from_value::<WebSearchToolInput>(input) {
Ok(input) => input,
Err(err) => return Task::ready(Err(anyhow!(err))).into(),
};
let Some(provider) = WebSearchRegistry::read_global(cx).active_provider() else {
return Task::ready(Err(anyhow!("Web search is not available."))).into();
};
let search_task = provider.search(input.query, cx).map_err(Arc::new).shared();
let output = cx.background_spawn({
let search_task = search_task.clone();
async move {
let response = search_task.await.map_err(|err| anyhow!(err))?;
serde_json::to_string(&response).context("Failed to serialize search results")
}
});
ToolResult {
output,
card: Some(cx.new(|cx| WebSearchToolCard::new(search_task, cx)).into()),
}
}
}
struct WebSearchToolCard {
response: Option<Result<WebSearchResponse>>,
_task: Task<()>,
}
impl WebSearchToolCard {
fn new(
search_task: impl 'static + Future<Output = Result<WebSearchResponse, Arc<anyhow::Error>>>,
cx: &mut Context<Self>,
) -> Self {
let _task = cx.spawn(async move |this, cx| {
let response = search_task.await.map_err(|err| anyhow!(err));
this.update(cx, |this, cx| {
this.response = Some(response);
cx.notify();
})
.ok();
});
Self {
response: None,
_task,
}
}
}
impl ToolCard for WebSearchToolCard {
fn render(
&mut self,
_status: &ToolUseStatus,
_window: &mut Window,
cx: &mut Context<Self>,
) -> impl IntoElement {
let header = h_flex()
.id("tool-label-container")
.gap_1p5()
.max_w_full()
.overflow_x_scroll()
.child(
Icon::new(IconName::Globe)
.size(IconSize::XSmall)
.color(Color::Muted),
)
.child(match self.response.as_ref() {
Some(Ok(response)) => {
let text: SharedString = if response.citations.len() == 1 {
"1 result".into()
} else {
format!("{} results", response.citations.len()).into()
};
h_flex()
.gap_1p5()
.child(Label::new("Searched the Web").size(LabelSize::Small))
.child(
div()
.size(px(3.))
.rounded_full()
.bg(cx.theme().colors().text),
)
.child(Label::new(text).size(LabelSize::Small))
.into_any_element()
}
Some(Err(error)) => div()
.id("web-search-error")
.child(Label::new("Web Search failed").size(LabelSize::Small))
.tooltip(Tooltip::text(error.to_string()))
.into_any_element(),
None => Label::new("Searching the Web…")
.size(LabelSize::Small)
.with_animation(
"web-search-label",
Animation::new(Duration::from_secs(2))
.repeat()
.with_easing(pulsating_between(0.6, 1.)),
|label, delta| label.alpha(delta),
)
.into_any_element(),
})
.into_any();
let content =
self.response.as_ref().and_then(|response| match response {
Ok(response) => {
Some(
v_flex()
.ml_1p5()
.pl_1p5()
.border_l_1()
.border_color(cx.theme().colors().border_variant)
.gap_1()
.children(response.citations.iter().enumerate().map(
|(index, citation)| {
let title = citation.title.clone();
let url = citation.url.clone();
Button::new(("citation", index), title)
.label_size(LabelSize::Small)
.color(Color::Muted)
.icon(IconName::ArrowUpRight)
.icon_size(IconSize::XSmall)
.icon_position(IconPosition::End)
.truncate(true)
.tooltip({
let url = url.clone();
move |window, cx| {
Tooltip::with_meta(
"Citation Link",
None,
url.clone(),
window,
cx,
)
}
})
.on_click({
let url = url.clone();
move |_, _, cx| cx.open_url(&url)
})
},
))
.into_any(),
)
}
Err(_) => None,
});
v_flex().my_2().gap_1().child(header).children(content)
}
}

View File

@@ -2,7 +2,7 @@ use std::sync::Arc;
use anyhow::{Result, anyhow, bail};
use assistant_tool::{ActionLog, Tool, ToolResult, ToolSource};
use gpui::{App, Entity, Task};
use gpui::{AnyWindowHandle, App, Entity, Task};
use icons::IconName;
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
use project::Project;
@@ -77,6 +77,7 @@ impl Tool for ContextServerTool {
_messages: &[LanguageModelRequestMessage],
_project: Entity<Project>,
_action_log: Entity<ActionLog>,
_window: Option<AnyWindowHandle>,
cx: &mut App,
) -> ToolResult {
if let Some(server) = self.server_manager.read(cx).get_server(&self.server_id) {

View File

@@ -84,6 +84,11 @@ impl FeatureFlag for ZedPro {
const NAME: &'static str = "zed-pro";
}
pub struct ZedProWebSearchTool {}
impl FeatureFlag for ZedProWebSearchTool {
const NAME: &'static str = "zed-pro-web-search-tool";
}
pub struct NotebookFeatureFlag;
impl FeatureFlag for NotebookFeatureFlag {

View File

@@ -160,7 +160,11 @@ impl Render for Tooltip {
}),
)
.when_some(self.meta.clone(), |this, meta| {
this.child(Label::new(meta).size(LabelSize::Small).color(Color::Muted))
this.child(
div()
.max_w_72()
.child(Label::new(meta).size(LabelSize::Small).color(Color::Muted)),
)
})
})
}

View File

@@ -0,0 +1,20 @@
[package]
name = "web_search"
version = "0.1.0"
edition.workspace = true
publish.workspace = true
license = "GPL-3.0-or-later"
[lints]
workspace = true
[lib]
path = "src/web_search.rs"
[dependencies]
anyhow.workspace = true
collections.workspace = true
gpui.workspace = true
serde.workspace = true
workspace-hack.workspace = true
zed_llm_client.workspace = true

View File

@@ -0,0 +1 @@
../../LICENSE-GPL

View File

@@ -0,0 +1,64 @@
use anyhow::Result;
use collections::HashMap;
use gpui::{App, AppContext as _, Context, Entity, Global, SharedString, Task};
use std::sync::Arc;
use zed_llm_client::WebSearchResponse;
pub fn init(cx: &mut App) {
let registry = cx.new(|_cx| WebSearchRegistry::default());
cx.set_global(GlobalWebSearchRegistry(registry));
}
#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
pub struct WebSearchProviderId(pub SharedString);
pub trait WebSearchProvider {
fn id(&self) -> WebSearchProviderId;
fn search(&self, query: String, cx: &mut App) -> Task<Result<WebSearchResponse>>;
}
struct GlobalWebSearchRegistry(Entity<WebSearchRegistry>);
impl Global for GlobalWebSearchRegistry {}
#[derive(Default)]
pub struct WebSearchRegistry {
providers: HashMap<WebSearchProviderId, Arc<dyn WebSearchProvider>>,
active_provider: Option<Arc<dyn WebSearchProvider>>,
}
impl WebSearchRegistry {
pub fn global(cx: &App) -> Entity<Self> {
cx.global::<GlobalWebSearchRegistry>().0.clone()
}
pub fn read_global(cx: &App) -> &Self {
cx.global::<GlobalWebSearchRegistry>().0.read(cx)
}
pub fn providers(&self) -> impl Iterator<Item = &Arc<dyn WebSearchProvider>> {
self.providers.values()
}
pub fn active_provider(&self) -> Option<Arc<dyn WebSearchProvider>> {
self.active_provider.clone()
}
pub fn set_active_provider(&mut self, provider: Arc<dyn WebSearchProvider>) {
self.active_provider = Some(provider.clone());
self.providers.insert(provider.id(), provider);
}
pub fn register_provider<T: WebSearchProvider + 'static>(
&mut self,
provider: T,
_cx: &mut Context<Self>,
) {
let id = provider.id();
let provider = Arc::new(provider);
self.providers.insert(id.clone(), provider.clone());
if self.active_provider.is_none() {
self.active_provider = Some(provider);
}
}
}

View File

@@ -0,0 +1,26 @@
[package]
name = "web_search_providers"
version = "0.1.0"
edition.workspace = true
publish.workspace = true
license = "GPL-3.0-or-later"
[lints]
workspace = true
[lib]
path = "src/web_search_providers.rs"
[dependencies]
anyhow.workspace = true
client.workspace = true
feature_flags.workspace = true
futures.workspace = true
gpui.workspace = true
http_client.workspace = true
language_model.workspace = true
serde.workspace = true
serde_json.workspace = true
web_search.workspace = true
workspace-hack.workspace = true
zed_llm_client.workspace = true

View File

@@ -0,0 +1 @@
../../LICENSE-GPL

View File

@@ -0,0 +1,103 @@
use std::sync::Arc;
use anyhow::{Context as _, Result, anyhow};
use client::Client;
use futures::AsyncReadExt as _;
use gpui::{App, AppContext, Context, Entity, Subscription, Task};
use http_client::{HttpClient, Method};
use language_model::{LlmApiToken, RefreshLlmTokenListener};
use web_search::{WebSearchProvider, WebSearchProviderId};
use zed_llm_client::{WebSearchBody, WebSearchResponse};
pub struct CloudWebSearchProvider {
state: Entity<State>,
}
impl CloudWebSearchProvider {
pub fn new(client: Arc<Client>, cx: &mut App) -> Self {
let state = cx.new(|cx| State::new(client, cx));
Self { state }
}
}
pub struct State {
client: Arc<Client>,
llm_api_token: LlmApiToken,
_llm_token_subscription: Subscription,
}
impl State {
pub fn new(client: Arc<Client>, cx: &mut Context<Self>) -> Self {
let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
Self {
client,
llm_api_token: LlmApiToken::default(),
_llm_token_subscription: cx.subscribe(
&refresh_llm_token_listener,
|this, _, _event, cx| {
let client = this.client.clone();
let llm_api_token = this.llm_api_token.clone();
cx.spawn(async move |_this, _cx| {
llm_api_token.refresh(&client).await?;
anyhow::Ok(())
})
.detach_and_log_err(cx);
},
),
}
}
}
impl WebSearchProvider for CloudWebSearchProvider {
fn id(&self) -> WebSearchProviderId {
WebSearchProviderId("zed.dev".into())
}
fn search(&self, query: String, cx: &mut App) -> Task<Result<WebSearchResponse>> {
let state = self.state.read(cx);
let client = state.client.clone();
let llm_api_token = state.llm_api_token.clone();
let body = WebSearchBody { query };
cx.background_spawn(async move { perform_web_search(client, llm_api_token, body).await })
}
}
async fn perform_web_search(
client: Arc<Client>,
llm_api_token: LlmApiToken,
body: WebSearchBody,
) -> Result<WebSearchResponse> {
let http_client = &client.http_client();
let token = llm_api_token.acquire(&client).await?;
let request_builder = http_client::Request::builder().method(Method::POST);
let request_builder = if let Ok(web_search_url) = std::env::var("ZED_WEB_SEARCH_URL") {
request_builder.uri(web_search_url)
} else {
request_builder.uri(http_client.build_zed_llm_url("/web_search", &[])?.as_ref())
};
let request = request_builder
.header("Content-Type", "application/json")
.header("Authorization", format!("Bearer {token}"))
.body(serde_json::to_string(&body)?.into())?;
let mut response = http_client
.send(request)
.await
.context("failed to send web search request")?;
if response.status().is_success() {
let mut body = String::new();
response.body_mut().read_to_string(&mut body).await?;
return Ok(serde_json::from_str(&body)?);
} else {
let mut body = String::new();
response.body_mut().read_to_string(&mut body).await?;
return Err(anyhow!(
"error performing web search.\nStatus: {:?}\nBody: {body}",
response.status(),
));
}
}

View File

@@ -0,0 +1,35 @@
mod cloud;
use client::Client;
use feature_flags::{FeatureFlagAppExt, ZedProWebSearchTool};
use gpui::{App, Context};
use std::sync::Arc;
use web_search::WebSearchRegistry;
pub fn init(client: Arc<Client>, cx: &mut App) {
let registry = WebSearchRegistry::global(cx);
registry.update(cx, |registry, cx| {
register_web_search_providers(registry, client, cx);
});
}
fn register_web_search_providers(
_registry: &mut WebSearchRegistry,
client: Arc<Client>,
cx: &mut Context<WebSearchRegistry>,
) {
cx.observe_flag::<ZedProWebSearchTool, _>({
let client = client.clone();
move |is_enabled, cx| {
if is_enabled {
WebSearchRegistry::global(cx).update(cx, |registry, cx| {
registry.register_provider(
cloud::CloudWebSearchProvider::new(client.clone(), cx),
cx,
);
});
}
}
})
.detach();
}

View File

@@ -133,6 +133,8 @@ util.workspace = true
uuid.workspace = true
vim.workspace = true
vim_mode_setting.workspace = true
web_search.workspace = true
web_search_providers.workspace = true
welcome.workspace = true
workspace.workspace = true
zed_actions.workspace = true

View File

@@ -490,6 +490,8 @@ fn main() {
app_state.fs.clone(),
cx,
);
web_search::init(cx);
web_search_providers::init(app_state.client.clone(), cx);
snippet_provider::init(cx);
inline_completion_registry::init(
app_state.client.clone(),

View File

@@ -4258,6 +4258,8 @@ mod tests {
app_state.fs.clone(),
cx,
);
web_search::init(cx);
web_search_providers::init(app_state.client.clone(), cx);
let prompt_builder = PromptBuilder::load(app_state.fs.clone(), false, cx);
assistant::init(
app_state.fs.clone(),