Compare commits
73 Commits
resource-l
...
merge-conf
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
2aa1255ddb | ||
|
|
a6e2e0d24a | ||
|
|
9be44517cb | ||
|
|
389d24d7e5 | ||
|
|
d36304963e | ||
|
|
389d382f42 | ||
|
|
bd61eb0889 | ||
|
|
0d71351b02 | ||
|
|
b06fe288f3 | ||
|
|
4a35498829 | ||
|
|
fd0ffb737f | ||
|
|
e52f148304 | ||
|
|
cb0bc463f1 | ||
|
|
9a375f1419 | ||
|
|
4238e640fa | ||
|
|
0b9c9f5f2d | ||
|
|
2da80e4641 | ||
|
|
d9a94a5496 | ||
|
|
a7442d8880 | ||
|
|
6c1f19571a | ||
|
|
23cd5b59b2 | ||
|
|
f4b0332f78 | ||
|
|
abde7306e3 | ||
|
|
2b3dbe8815 | ||
|
|
7f1a5c6ad7 | ||
|
|
6307105976 | ||
|
|
8d63312eca | ||
|
|
81474a3de0 | ||
|
|
db497ac867 | ||
|
|
8ff2e3e195 | ||
|
|
96093aa465 | ||
|
|
dc87f4b32e | ||
|
|
1957e1f642 | ||
|
|
d78bd8f1d7 | ||
|
|
32975c4208 | ||
|
|
658d56bd72 | ||
|
|
13a2c53381 | ||
|
|
cd234e28ce | ||
|
|
b564b1d5d0 | ||
|
|
48ae02c1ca | ||
|
|
255bb0a3f8 | ||
|
|
628b1058be | ||
|
|
7167f193c0 | ||
|
|
7ff0f1525e | ||
|
|
7df8e05ad9 | ||
|
|
d030bb6281 | ||
|
|
b62f959528 | ||
|
|
3a04657730 | ||
|
|
42b7dbeaee | ||
|
|
bfbb18476f | ||
|
|
978b75bba9 | ||
|
|
1f20d5bf54 | ||
|
|
9de04ce215 | ||
|
|
d8fc53608e | ||
|
|
39c19abdfd | ||
|
|
b105028c05 | ||
|
|
d2162446d0 | ||
|
|
360d4db87c | ||
|
|
44953375cc | ||
|
|
2444321756 | ||
|
|
13bf45dd4a | ||
|
|
b61b71405d | ||
|
|
cc5eb24066 | ||
|
|
52a9101970 | ||
|
|
1a798830cb | ||
|
|
481e3e5092 | ||
|
|
b35e69692d | ||
|
|
add67bde43 | ||
|
|
fa3d0aaed4 | ||
|
|
094e878ccf | ||
|
|
54d4665100 | ||
|
|
2c84e33b7b | ||
|
|
bb6ea22944 |
35
.github/ISSUE_TEMPLATE/07_bug_windows_alpha.yml
vendored
Normal file
35
.github/ISSUE_TEMPLATE/07_bug_windows_alpha.yml
vendored
Normal file
@@ -0,0 +1,35 @@
|
||||
name: Bug Report (Windows Alpha)
|
||||
description: Zed Windows Alpha Related Bugs
|
||||
type: "Bug"
|
||||
labels: ["windows"]
|
||||
title: "Windows Alpha: <a short description of the Windows bug>"
|
||||
body:
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: Summary
|
||||
description: Describe the bug with a one-line summary, and provide detailed reproduction steps
|
||||
value: |
|
||||
<!-- Please insert a one-line summary of the issue below -->
|
||||
SUMMARY_SENTENCE_HERE
|
||||
|
||||
### Description
|
||||
<!-- Describe with sufficient detail to reproduce from a clean Zed install. -->
|
||||
Steps to trigger the problem:
|
||||
1.
|
||||
2.
|
||||
3.
|
||||
|
||||
**Expected Behavior**:
|
||||
**Actual Behavior**:
|
||||
|
||||
validations:
|
||||
required: true
|
||||
- type: textarea
|
||||
id: environment
|
||||
attributes:
|
||||
label: Zed Version and System Specs
|
||||
description: 'Open Zed, and in the command palette select "zed: copy system specs into clipboard"'
|
||||
placeholder: |
|
||||
Output of "zed: copy system specs into clipboard"
|
||||
validations:
|
||||
required: true
|
||||
2
.github/workflows/ci.yml
vendored
2
.github/workflows/ci.yml
vendored
@@ -718,7 +718,7 @@ jobs:
|
||||
timeout-minutes: 60
|
||||
runs-on: github-8vcpu-ubuntu-2404
|
||||
if: |
|
||||
( startsWith(github.ref, 'refs/tags/v')
|
||||
false && ( startsWith(github.ref, 'refs/tags/v')
|
||||
|| contains(github.event.pull_request.labels.*.name, 'run-bundling') )
|
||||
needs: [linux_tests]
|
||||
name: Build Zed on FreeBSD
|
||||
|
||||
52
Cargo.lock
generated
52
Cargo.lock
generated
@@ -7,20 +7,23 @@ name = "acp_thread"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"action_log",
|
||||
"agent",
|
||||
"agent-client-protocol",
|
||||
"anyhow",
|
||||
"buffer_diff",
|
||||
"collections",
|
||||
"editor",
|
||||
"env_logger 0.11.8",
|
||||
"file_icons",
|
||||
"futures 0.3.31",
|
||||
"gpui",
|
||||
"indoc",
|
||||
"itertools 0.14.0",
|
||||
"language",
|
||||
"language_model",
|
||||
"markdown",
|
||||
"parking_lot",
|
||||
"project",
|
||||
"prompt_store",
|
||||
"rand 0.8.5",
|
||||
"serde",
|
||||
"serde_json",
|
||||
@@ -29,7 +32,10 @@ dependencies = [
|
||||
"tempfile",
|
||||
"terminal",
|
||||
"ui",
|
||||
"url",
|
||||
"util",
|
||||
"uuid",
|
||||
"watch",
|
||||
"workspace-hack",
|
||||
]
|
||||
|
||||
@@ -196,6 +202,7 @@ dependencies = [
|
||||
"clock",
|
||||
"cloud_llm_client",
|
||||
"collections",
|
||||
"context_server",
|
||||
"ctor",
|
||||
"editor",
|
||||
"env_logger 0.11.8",
|
||||
@@ -204,6 +211,8 @@ dependencies = [
|
||||
"gpui",
|
||||
"gpui_tokio",
|
||||
"handlebars 4.5.0",
|
||||
"html_to_markdown",
|
||||
"http_client",
|
||||
"indoc",
|
||||
"itertools 0.14.0",
|
||||
"language",
|
||||
@@ -227,6 +236,7 @@ dependencies = [
|
||||
"task",
|
||||
"tempfile",
|
||||
"terminal",
|
||||
"text",
|
||||
"theme",
|
||||
"tree-sitter-rust",
|
||||
"ui",
|
||||
@@ -6440,6 +6450,7 @@ dependencies = [
|
||||
"log",
|
||||
"parking_lot",
|
||||
"pretty_assertions",
|
||||
"rand 0.8.5",
|
||||
"regex",
|
||||
"rope",
|
||||
"schemars",
|
||||
@@ -11144,14 +11155,13 @@ dependencies = [
|
||||
"ai_onboarding",
|
||||
"anyhow",
|
||||
"client",
|
||||
"command_palette_hooks",
|
||||
"component",
|
||||
"db",
|
||||
"documented",
|
||||
"editor",
|
||||
"feature_flags",
|
||||
"fs",
|
||||
"fuzzy",
|
||||
"git",
|
||||
"gpui",
|
||||
"itertools 0.14.0",
|
||||
"language",
|
||||
@@ -11163,6 +11173,7 @@ dependencies = [
|
||||
"schemars",
|
||||
"serde",
|
||||
"settings",
|
||||
"telemetry",
|
||||
"theme",
|
||||
"ui",
|
||||
"util",
|
||||
@@ -11238,6 +11249,7 @@ dependencies = [
|
||||
"anyhow",
|
||||
"futures 0.3.31",
|
||||
"http_client",
|
||||
"log",
|
||||
"schemars",
|
||||
"serde",
|
||||
"serde_json",
|
||||
@@ -18023,6 +18035,7 @@ dependencies = [
|
||||
"command_palette_hooks",
|
||||
"db",
|
||||
"editor",
|
||||
"env_logger 0.11.8",
|
||||
"futures 0.3.31",
|
||||
"git_ui",
|
||||
"gpui",
|
||||
@@ -18876,33 +18889,6 @@ version = "0.1.8"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "53a85b86a771b1c87058196170769dd264f66c0782acf1ae6cc51bfd64b39082"
|
||||
|
||||
[[package]]
|
||||
name = "welcome"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"client",
|
||||
"component",
|
||||
"db",
|
||||
"documented",
|
||||
"editor",
|
||||
"fuzzy",
|
||||
"gpui",
|
||||
"install_cli",
|
||||
"language",
|
||||
"picker",
|
||||
"project",
|
||||
"serde",
|
||||
"settings",
|
||||
"telemetry",
|
||||
"ui",
|
||||
"util",
|
||||
"vim_mode_setting",
|
||||
"workspace",
|
||||
"workspace-hack",
|
||||
"zed_actions",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "which"
|
||||
version = "4.4.2"
|
||||
@@ -20517,7 +20503,7 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "zed"
|
||||
version = "0.200.0"
|
||||
version = "0.201.0"
|
||||
dependencies = [
|
||||
"activity_indicator",
|
||||
"agent",
|
||||
@@ -20657,7 +20643,6 @@ dependencies = [
|
||||
"watch",
|
||||
"web_search",
|
||||
"web_search_providers",
|
||||
"welcome",
|
||||
"windows 0.61.1",
|
||||
"winresource",
|
||||
"workspace",
|
||||
@@ -20681,7 +20666,7 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "zed_emmet"
|
||||
version = "0.0.4"
|
||||
version = "0.0.6"
|
||||
dependencies = [
|
||||
"zed_extension_api 0.1.0",
|
||||
]
|
||||
@@ -20920,6 +20905,7 @@ dependencies = [
|
||||
"menu",
|
||||
"postage",
|
||||
"project",
|
||||
"rand 0.8.5",
|
||||
"regex",
|
||||
"release_channel",
|
||||
"reqwest_client",
|
||||
|
||||
@@ -185,7 +185,6 @@ members = [
|
||||
"crates/watch",
|
||||
"crates/web_search",
|
||||
"crates/web_search_providers",
|
||||
"crates/welcome",
|
||||
"crates/workspace",
|
||||
"crates/worktree",
|
||||
"crates/x_ai",
|
||||
@@ -412,7 +411,6 @@ vim_mode_setting = { path = "crates/vim_mode_setting" }
|
||||
watch = { path = "crates/watch" }
|
||||
web_search = { path = "crates/web_search" }
|
||||
web_search_providers = { path = "crates/web_search_providers" }
|
||||
welcome = { path = "crates/welcome" }
|
||||
workspace = { path = "crates/workspace" }
|
||||
worktree = { path = "crates/worktree" }
|
||||
x_ai = { path = "crates/x_ai" }
|
||||
@@ -566,6 +564,7 @@ reqwest = { git = "https://github.com/zed-industries/reqwest.git", rev = "951c77
|
||||
"socks",
|
||||
"stream",
|
||||
] }
|
||||
rodio = { version = "0.21.1", default-features = false }
|
||||
rsa = "0.9.6"
|
||||
runtimelib = { git = "https://github.com/ConradIrwin/runtimed", rev = "7130c804216b6914355d15d0b91ea91f6babd734", default-features = false, features = [
|
||||
"async-dispatcher-runtime",
|
||||
@@ -714,6 +713,7 @@ features = [
|
||||
"Win32_System_LibraryLoader",
|
||||
"Win32_System_Memory",
|
||||
"Win32_System_Ole",
|
||||
"Win32_System_Performance",
|
||||
"Win32_System_Pipes",
|
||||
"Win32_System_SystemInformation",
|
||||
"Win32_System_SystemServices",
|
||||
|
||||
BIN
assets/fonts/ibm-plex-sans/IBMPlexSans-Bold.ttf
Normal file
BIN
assets/fonts/ibm-plex-sans/IBMPlexSans-Bold.ttf
Normal file
Binary file not shown.
BIN
assets/fonts/ibm-plex-sans/IBMPlexSans-BoldItalic.ttf
Normal file
BIN
assets/fonts/ibm-plex-sans/IBMPlexSans-BoldItalic.ttf
Normal file
Binary file not shown.
BIN
assets/fonts/ibm-plex-sans/IBMPlexSans-Italic.ttf
Normal file
BIN
assets/fonts/ibm-plex-sans/IBMPlexSans-Italic.ttf
Normal file
Binary file not shown.
BIN
assets/fonts/ibm-plex-sans/IBMPlexSans-Regular.ttf
Normal file
BIN
assets/fonts/ibm-plex-sans/IBMPlexSans-Regular.ttf
Normal file
Binary file not shown.
BIN
assets/fonts/lilex/Lilex-Bold.ttf
Normal file
BIN
assets/fonts/lilex/Lilex-Bold.ttf
Normal file
Binary file not shown.
BIN
assets/fonts/lilex/Lilex-BoldItalic.ttf
Normal file
BIN
assets/fonts/lilex/Lilex-BoldItalic.ttf
Normal file
Binary file not shown.
BIN
assets/fonts/lilex/Lilex-Italic.ttf
Normal file
BIN
assets/fonts/lilex/Lilex-Italic.ttf
Normal file
Binary file not shown.
BIN
assets/fonts/lilex/Lilex-Regular.ttf
Normal file
BIN
assets/fonts/lilex/Lilex-Regular.ttf
Normal file
Binary file not shown.
@@ -1,8 +1,9 @@
|
||||
Copyright © 2017 IBM Corp. with Reserved Font Name "Plex"
|
||||
Copyright 2019 The Lilex Project Authors (https://github.com/mishamyrt/Lilex)
|
||||
|
||||
This Font Software is licensed under the SIL Open Font License, Version 1.1.
|
||||
This license is copied below, and is also available with a FAQ at:
|
||||
http://scripts.sil.org/OFL
|
||||
https://scripts.sil.org/OFL
|
||||
|
||||
|
||||
-----------------------------------------------------------
|
||||
SIL OPEN FONT LICENSE Version 1.1 - 26 February 2007
|
||||
@@ -89,4 +90,4 @@ COPYRIGHT HOLDER BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,
|
||||
INCLUDING ANY GENERAL, SPECIAL, INDIRECT, INCIDENTAL, OR CONSEQUENTIAL
|
||||
DAMAGES, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
||||
FROM, OUT OF THE USE OR INABILITY TO USE THE FONT SOFTWARE OR FROM
|
||||
OTHER DEALINGS IN THE FONT SOFTWARE.
|
||||
OTHER DEALINGS IN THE FONT SOFTWARE.
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -239,6 +239,7 @@
|
||||
"ctrl-shift-a": "agent::ToggleContextPicker",
|
||||
"ctrl-shift-j": "agent::ToggleNavigationMenu",
|
||||
"ctrl-shift-i": "agent::ToggleOptionsMenu",
|
||||
"ctrl-alt-shift-n": "agent::ToggleNewThreadMenu",
|
||||
"shift-alt-escape": "agent::ExpandMessageEditor",
|
||||
"ctrl->": "assistant::QuoteSelection",
|
||||
"ctrl-alt-e": "agent::RemoveAllContext",
|
||||
@@ -330,8 +331,6 @@
|
||||
"use_key_equivalents": true,
|
||||
"bindings": {
|
||||
"enter": "agent::Chat",
|
||||
"up": "agent::PreviousHistoryMessage",
|
||||
"down": "agent::NextHistoryMessage",
|
||||
"shift-ctrl-r": "agent::OpenAgentDiff",
|
||||
"ctrl-shift-y": "agent::KeepAll",
|
||||
"ctrl-shift-n": "agent::RejectAll"
|
||||
|
||||
@@ -279,6 +279,7 @@
|
||||
"cmd-shift-a": "agent::ToggleContextPicker",
|
||||
"cmd-shift-j": "agent::ToggleNavigationMenu",
|
||||
"cmd-shift-i": "agent::ToggleOptionsMenu",
|
||||
"cmd-alt-shift-n": "agent::ToggleNewThreadMenu",
|
||||
"shift-alt-escape": "agent::ExpandMessageEditor",
|
||||
"cmd->": "assistant::QuoteSelection",
|
||||
"cmd-alt-e": "agent::RemoveAllContext",
|
||||
@@ -382,8 +383,6 @@
|
||||
"use_key_equivalents": true,
|
||||
"bindings": {
|
||||
"enter": "agent::Chat",
|
||||
"up": "agent::PreviousHistoryMessage",
|
||||
"down": "agent::NextHistoryMessage",
|
||||
"shift-ctrl-r": "agent::OpenAgentDiff",
|
||||
"cmd-shift-y": "agent::KeepAll",
|
||||
"cmd-shift-n": "agent::RejectAll"
|
||||
|
||||
@@ -333,10 +333,14 @@
|
||||
"ctrl-x ctrl-c": "editor::ShowEditPrediction", // zed specific
|
||||
"ctrl-x ctrl-l": "editor::ToggleCodeActions", // zed specific
|
||||
"ctrl-x ctrl-z": "editor::Cancel",
|
||||
"ctrl-x ctrl-e": "vim::LineDown",
|
||||
"ctrl-x ctrl-y": "vim::LineUp",
|
||||
"ctrl-w": "editor::DeleteToPreviousWordStart",
|
||||
"ctrl-u": "editor::DeleteToBeginningOfLine",
|
||||
"ctrl-t": "vim::Indent",
|
||||
"ctrl-d": "vim::Outdent",
|
||||
"ctrl-y": "vim::InsertFromAbove",
|
||||
"ctrl-e": "vim::InsertFromBelow",
|
||||
"ctrl-k": ["vim::PushDigraph", {}],
|
||||
"ctrl-v": ["vim::PushLiteral", {}],
|
||||
"ctrl-shift-v": "editor::Paste", // note: this is *very* similar to ctrl-v in vim, but ctrl-shift-v on linux is the typical shortcut for paste when ctrl-v is already in use.
|
||||
|
||||
@@ -28,7 +28,9 @@
|
||||
"edit_prediction_provider": "zed"
|
||||
},
|
||||
// The name of a font to use for rendering text in the editor
|
||||
"buffer_font_family": "Zed Plex Mono",
|
||||
// ".ZedMono" currently aliases to Lilex
|
||||
// but this may change in the future.
|
||||
"buffer_font_family": ".ZedMono",
|
||||
// Set the buffer text's font fallbacks, this will be merged with
|
||||
// the platform's default fallbacks.
|
||||
"buffer_font_fallbacks": null,
|
||||
@@ -54,7 +56,9 @@
|
||||
"buffer_line_height": "comfortable",
|
||||
// The name of a font to use for rendering text in the UI
|
||||
// You can set this to ".SystemUIFont" to use the system font
|
||||
"ui_font_family": "Zed Plex Sans",
|
||||
// ".ZedSans" currently aliases to "IBM Plex Sans", but this may
|
||||
// change in the future
|
||||
"ui_font_family": ".ZedSans",
|
||||
// Set the UI's font fallbacks, this will be merged with the platform's
|
||||
// default font fallbacks.
|
||||
"ui_font_fallbacks": null,
|
||||
@@ -82,10 +86,10 @@
|
||||
// Layout mode of the bottom dock. Defaults to "contained"
|
||||
// choices: contained, full, left_aligned, right_aligned
|
||||
"bottom_dock_layout": "contained",
|
||||
// The direction that you want to split panes horizontally. Defaults to "up"
|
||||
"pane_split_direction_horizontal": "up",
|
||||
// The direction that you want to split panes vertically. Defaults to "left"
|
||||
"pane_split_direction_vertical": "left",
|
||||
// The direction that you want to split panes horizontally. Defaults to "down"
|
||||
"pane_split_direction_horizontal": "down",
|
||||
// The direction that you want to split panes vertically. Defaults to "right"
|
||||
"pane_split_direction_vertical": "right",
|
||||
// Centered layout related settings.
|
||||
"centered_layout": {
|
||||
// The relative width of the left padding of the central pane from the
|
||||
@@ -1402,7 +1406,7 @@
|
||||
// "font_size": 15,
|
||||
// Set the terminal's font family. If this option is not included,
|
||||
// the terminal will default to matching the buffer's font family.
|
||||
// "font_family": "Zed Plex Mono",
|
||||
// "font_family": ".ZedMono",
|
||||
// Set the terminal's font fallbacks. If this option is not included,
|
||||
// the terminal will default to matching the buffer's font fallbacks.
|
||||
// This will be merged with the platform's default font fallbacks
|
||||
|
||||
@@ -18,23 +18,29 @@ test-support = ["gpui/test-support", "project/test-support"]
|
||||
[dependencies]
|
||||
action_log.workspace = true
|
||||
agent-client-protocol.workspace = true
|
||||
agent.workspace = true
|
||||
anyhow.workspace = true
|
||||
buffer_diff.workspace = true
|
||||
collections.workspace = true
|
||||
editor.workspace = true
|
||||
file_icons.workspace = true
|
||||
futures.workspace = true
|
||||
gpui.workspace = true
|
||||
itertools.workspace = true
|
||||
language.workspace = true
|
||||
language_model.workspace = true
|
||||
markdown.workspace = true
|
||||
project.workspace = true
|
||||
prompt_store.workspace = true
|
||||
serde.workspace = true
|
||||
serde_json.workspace = true
|
||||
settings.workspace = true
|
||||
smol.workspace = true
|
||||
terminal.workspace = true
|
||||
ui.workspace = true
|
||||
url.workspace = true
|
||||
util.workspace = true
|
||||
uuid.workspace = true
|
||||
watch.workspace = true
|
||||
workspace-hack.workspace = true
|
||||
|
||||
[dev-dependencies]
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,18 +1,78 @@
|
||||
use std::{error::Error, fmt, path::Path, rc::Rc, sync::Arc};
|
||||
|
||||
use crate::AcpThread;
|
||||
use agent_client_protocol::{self as acp};
|
||||
use anyhow::Result;
|
||||
use gpui::{AsyncApp, Entity, Task};
|
||||
use language_model::LanguageModel;
|
||||
use collections::IndexMap;
|
||||
use gpui::{AsyncApp, Entity, SharedString, Task};
|
||||
use project::Project;
|
||||
use ui::App;
|
||||
use std::{error::Error, fmt, path::Path, rc::Rc, sync::Arc};
|
||||
use ui::{App, IconName};
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::AcpThread;
|
||||
#[derive(Clone, Debug, Eq, PartialEq)]
|
||||
pub struct UserMessageId(Arc<str>);
|
||||
|
||||
impl UserMessageId {
|
||||
pub fn new() -> Self {
|
||||
Self(Uuid::new_v4().to_string().into())
|
||||
}
|
||||
}
|
||||
|
||||
pub trait AgentConnection {
|
||||
fn new_thread(
|
||||
self: Rc<Self>,
|
||||
project: Entity<Project>,
|
||||
cwd: &Path,
|
||||
cx: &mut AsyncApp,
|
||||
) -> Task<Result<Entity<AcpThread>>>;
|
||||
|
||||
fn auth_methods(&self) -> &[acp::AuthMethod];
|
||||
|
||||
fn authenticate(&self, method: acp::AuthMethodId, cx: &mut App) -> Task<Result<()>>;
|
||||
|
||||
fn prompt(
|
||||
&self,
|
||||
user_message_id: Option<UserMessageId>,
|
||||
params: acp::PromptRequest,
|
||||
cx: &mut App,
|
||||
) -> Task<Result<acp::PromptResponse>>;
|
||||
|
||||
fn cancel(&self, session_id: &acp::SessionId, cx: &mut App);
|
||||
|
||||
fn session_editor(
|
||||
&self,
|
||||
_session_id: &acp::SessionId,
|
||||
_cx: &mut App,
|
||||
) -> Option<Rc<dyn AgentSessionEditor>> {
|
||||
None
|
||||
}
|
||||
|
||||
/// Returns this agent as an [Rc<dyn ModelSelector>] if the model selection capability is supported.
|
||||
///
|
||||
/// If the agent does not support model selection, returns [None].
|
||||
/// This allows sharing the selector in UI components.
|
||||
fn model_selector(&self) -> Option<Rc<dyn AgentModelSelector>> {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
pub trait AgentSessionEditor {
|
||||
fn truncate(&self, message_id: UserMessageId, cx: &mut App) -> Task<Result<()>>;
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct AuthRequired;
|
||||
|
||||
impl Error for AuthRequired {}
|
||||
impl fmt::Display for AuthRequired {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
write!(f, "AuthRequired")
|
||||
}
|
||||
}
|
||||
|
||||
/// Trait for agents that support listing, selecting, and querying language models.
|
||||
///
|
||||
/// This is an optional capability; agents indicate support via [AgentConnection::model_selector].
|
||||
pub trait ModelSelector: 'static {
|
||||
pub trait AgentModelSelector: 'static {
|
||||
/// Lists all available language models for this agent.
|
||||
///
|
||||
/// # Parameters
|
||||
@@ -20,7 +80,7 @@ pub trait ModelSelector: 'static {
|
||||
///
|
||||
/// # Returns
|
||||
/// A task resolving to the list of models or an error (e.g., if no models are configured).
|
||||
fn list_models(&self, cx: &mut AsyncApp) -> Task<Result<Vec<Arc<dyn LanguageModel>>>>;
|
||||
fn list_models(&self, cx: &mut App) -> Task<Result<AgentModelList>>;
|
||||
|
||||
/// Selects a model for a specific session (thread).
|
||||
///
|
||||
@@ -37,8 +97,8 @@ pub trait ModelSelector: 'static {
|
||||
fn select_model(
|
||||
&self,
|
||||
session_id: acp::SessionId,
|
||||
model: Arc<dyn LanguageModel>,
|
||||
cx: &mut AsyncApp,
|
||||
model_id: AgentModelId,
|
||||
cx: &mut App,
|
||||
) -> Task<Result<()>>;
|
||||
|
||||
/// Retrieves the currently selected model for a specific session (thread).
|
||||
@@ -52,42 +112,51 @@ pub trait ModelSelector: 'static {
|
||||
fn selected_model(
|
||||
&self,
|
||||
session_id: &acp::SessionId,
|
||||
cx: &mut AsyncApp,
|
||||
) -> Task<Result<Arc<dyn LanguageModel>>>;
|
||||
cx: &mut App,
|
||||
) -> Task<Result<AgentModelInfo>>;
|
||||
|
||||
/// Whenever the model list is updated the receiver will be notified.
|
||||
fn watch(&self, cx: &mut App) -> watch::Receiver<()>;
|
||||
}
|
||||
|
||||
pub trait AgentConnection {
|
||||
fn new_thread(
|
||||
self: Rc<Self>,
|
||||
project: Entity<Project>,
|
||||
cwd: &Path,
|
||||
cx: &mut AsyncApp,
|
||||
) -> Task<Result<Entity<AcpThread>>>;
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
|
||||
pub struct AgentModelId(pub SharedString);
|
||||
|
||||
fn auth_methods(&self) -> &[acp::AuthMethod];
|
||||
impl std::ops::Deref for AgentModelId {
|
||||
type Target = SharedString;
|
||||
|
||||
fn authenticate(&self, method: acp::AuthMethodId, cx: &mut App) -> Task<Result<()>>;
|
||||
|
||||
fn prompt(&self, params: acp::PromptRequest, cx: &mut App)
|
||||
-> Task<Result<acp::PromptResponse>>;
|
||||
|
||||
fn cancel(&self, session_id: &acp::SessionId, cx: &mut App);
|
||||
|
||||
/// Returns this agent as an [Rc<dyn ModelSelector>] if the model selection capability is supported.
|
||||
///
|
||||
/// If the agent does not support model selection, returns [None].
|
||||
/// This allows sharing the selector in UI components.
|
||||
fn model_selector(&self) -> Option<Rc<dyn ModelSelector>> {
|
||||
None // Default impl for agents that don't support it
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&self.0
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct AuthRequired;
|
||||
|
||||
impl Error for AuthRequired {}
|
||||
impl fmt::Display for AuthRequired {
|
||||
impl fmt::Display for AgentModelId {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
write!(f, "AuthRequired")
|
||||
self.0.fmt(f)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct AgentModelInfo {
|
||||
pub id: AgentModelId,
|
||||
pub name: SharedString,
|
||||
pub icon: Option<IconName>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
|
||||
pub struct AgentModelGroupName(pub SharedString);
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum AgentModelList {
|
||||
Flat(Vec<AgentModelInfo>),
|
||||
Grouped(IndexMap<AgentModelGroupName, Vec<AgentModelInfo>>),
|
||||
}
|
||||
|
||||
impl AgentModelList {
|
||||
pub fn is_empty(&self) -> bool {
|
||||
match self {
|
||||
AgentModelList::Flat(models) => models.is_empty(),
|
||||
AgentModelList::Grouped(groups) => groups.is_empty(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
451
crates/acp_thread/src/mention.rs
Normal file
451
crates/acp_thread/src/mention.rs
Normal file
@@ -0,0 +1,451 @@
|
||||
use agent::ThreadId;
|
||||
use anyhow::{Context as _, Result, bail};
|
||||
use file_icons::FileIcons;
|
||||
use prompt_store::{PromptId, UserPromptId};
|
||||
use std::{
|
||||
fmt,
|
||||
ops::Range,
|
||||
path::{Path, PathBuf},
|
||||
};
|
||||
use ui::{App, IconName, SharedString};
|
||||
use url::Url;
|
||||
|
||||
#[derive(Clone, Debug, PartialEq, Eq)]
|
||||
pub enum MentionUri {
|
||||
File {
|
||||
abs_path: PathBuf,
|
||||
is_directory: bool,
|
||||
},
|
||||
Symbol {
|
||||
path: PathBuf,
|
||||
name: String,
|
||||
line_range: Range<u32>,
|
||||
},
|
||||
Thread {
|
||||
id: ThreadId,
|
||||
name: String,
|
||||
},
|
||||
TextThread {
|
||||
path: PathBuf,
|
||||
name: String,
|
||||
},
|
||||
Rule {
|
||||
id: PromptId,
|
||||
name: String,
|
||||
},
|
||||
Selection {
|
||||
path: PathBuf,
|
||||
line_range: Range<u32>,
|
||||
},
|
||||
Fetch {
|
||||
url: Url,
|
||||
},
|
||||
}
|
||||
|
||||
impl MentionUri {
|
||||
pub fn parse(input: &str) -> Result<Self> {
|
||||
let url = url::Url::parse(input)?;
|
||||
let path = url.path();
|
||||
match url.scheme() {
|
||||
"file" => {
|
||||
if let Some(fragment) = url.fragment() {
|
||||
let range = fragment
|
||||
.strip_prefix("L")
|
||||
.context("Line range must start with \"L\"")?;
|
||||
let (start, end) = range
|
||||
.split_once(":")
|
||||
.context("Line range must use colon as separator")?;
|
||||
let line_range = start
|
||||
.parse::<u32>()
|
||||
.context("Parsing line range start")?
|
||||
.checked_sub(1)
|
||||
.context("Line numbers should be 1-based")?
|
||||
..end
|
||||
.parse::<u32>()
|
||||
.context("Parsing line range end")?
|
||||
.checked_sub(1)
|
||||
.context("Line numbers should be 1-based")?;
|
||||
if let Some(name) = single_query_param(&url, "symbol")? {
|
||||
Ok(Self::Symbol {
|
||||
name,
|
||||
path: path.into(),
|
||||
line_range,
|
||||
})
|
||||
} else {
|
||||
Ok(Self::Selection {
|
||||
path: path.into(),
|
||||
line_range,
|
||||
})
|
||||
}
|
||||
} else {
|
||||
let file_path =
|
||||
PathBuf::from(format!("{}{}", url.host_str().unwrap_or(""), path));
|
||||
let is_directory = input.ends_with("/");
|
||||
|
||||
Ok(Self::File {
|
||||
abs_path: file_path,
|
||||
is_directory,
|
||||
})
|
||||
}
|
||||
}
|
||||
"zed" => {
|
||||
if let Some(thread_id) = path.strip_prefix("/agent/thread/") {
|
||||
let name = single_query_param(&url, "name")?.context("Missing thread name")?;
|
||||
Ok(Self::Thread {
|
||||
id: thread_id.into(),
|
||||
name,
|
||||
})
|
||||
} else if let Some(path) = path.strip_prefix("/agent/text-thread/") {
|
||||
let name = single_query_param(&url, "name")?.context("Missing thread name")?;
|
||||
Ok(Self::TextThread {
|
||||
path: path.into(),
|
||||
name,
|
||||
})
|
||||
} else if let Some(rule_id) = path.strip_prefix("/agent/rule/") {
|
||||
let name = single_query_param(&url, "name")?.context("Missing rule name")?;
|
||||
let rule_id = UserPromptId(rule_id.parse()?);
|
||||
Ok(Self::Rule {
|
||||
id: rule_id.into(),
|
||||
name,
|
||||
})
|
||||
} else {
|
||||
bail!("invalid zed url: {:?}", input);
|
||||
}
|
||||
}
|
||||
"http" | "https" => Ok(MentionUri::Fetch { url }),
|
||||
other => bail!("unrecognized scheme {:?}", other),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn name(&self) -> String {
|
||||
match self {
|
||||
MentionUri::File { abs_path, .. } => abs_path
|
||||
.file_name()
|
||||
.unwrap_or_default()
|
||||
.to_string_lossy()
|
||||
.into_owned(),
|
||||
MentionUri::Symbol { name, .. } => name.clone(),
|
||||
MentionUri::Thread { name, .. } => name.clone(),
|
||||
MentionUri::TextThread { name, .. } => name.clone(),
|
||||
MentionUri::Rule { name, .. } => name.clone(),
|
||||
MentionUri::Selection {
|
||||
path, line_range, ..
|
||||
} => selection_name(path, line_range),
|
||||
MentionUri::Fetch { url } => url.to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn icon_path(&self, cx: &mut App) -> SharedString {
|
||||
match self {
|
||||
MentionUri::File {
|
||||
abs_path,
|
||||
is_directory,
|
||||
} => {
|
||||
if *is_directory {
|
||||
FileIcons::get_folder_icon(false, cx)
|
||||
.unwrap_or_else(|| IconName::Folder.path().into())
|
||||
} else {
|
||||
FileIcons::get_icon(&abs_path, cx)
|
||||
.unwrap_or_else(|| IconName::File.path().into())
|
||||
}
|
||||
}
|
||||
MentionUri::Symbol { .. } => IconName::Code.path().into(),
|
||||
MentionUri::Thread { .. } => IconName::Thread.path().into(),
|
||||
MentionUri::TextThread { .. } => IconName::Thread.path().into(),
|
||||
MentionUri::Rule { .. } => IconName::Reader.path().into(),
|
||||
MentionUri::Selection { .. } => IconName::Reader.path().into(),
|
||||
MentionUri::Fetch { .. } => IconName::ToolWeb.path().into(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn as_link<'a>(&'a self) -> MentionLink<'a> {
|
||||
MentionLink(self)
|
||||
}
|
||||
|
||||
pub fn to_uri(&self) -> Url {
|
||||
match self {
|
||||
MentionUri::File {
|
||||
abs_path,
|
||||
is_directory,
|
||||
} => {
|
||||
let mut url = Url::parse("file:///").unwrap();
|
||||
let mut path = abs_path.to_string_lossy().to_string();
|
||||
if *is_directory && !path.ends_with("/") {
|
||||
path.push_str("/");
|
||||
}
|
||||
url.set_path(&path);
|
||||
url
|
||||
}
|
||||
MentionUri::Symbol {
|
||||
path,
|
||||
name,
|
||||
line_range,
|
||||
} => {
|
||||
let mut url = Url::parse("file:///").unwrap();
|
||||
url.set_path(&path.to_string_lossy());
|
||||
url.query_pairs_mut().append_pair("symbol", name);
|
||||
url.set_fragment(Some(&format!(
|
||||
"L{}:{}",
|
||||
line_range.start + 1,
|
||||
line_range.end + 1
|
||||
)));
|
||||
url
|
||||
}
|
||||
MentionUri::Selection { path, line_range } => {
|
||||
let mut url = Url::parse("file:///").unwrap();
|
||||
url.set_path(&path.to_string_lossy());
|
||||
url.set_fragment(Some(&format!(
|
||||
"L{}:{}",
|
||||
line_range.start + 1,
|
||||
line_range.end + 1
|
||||
)));
|
||||
url
|
||||
}
|
||||
MentionUri::Thread { name, id } => {
|
||||
let mut url = Url::parse("zed:///").unwrap();
|
||||
url.set_path(&format!("/agent/thread/{id}"));
|
||||
url.query_pairs_mut().append_pair("name", name);
|
||||
url
|
||||
}
|
||||
MentionUri::TextThread { path, name } => {
|
||||
let mut url = Url::parse("zed:///").unwrap();
|
||||
url.set_path(&format!("/agent/text-thread/{}", path.to_string_lossy()));
|
||||
url.query_pairs_mut().append_pair("name", name);
|
||||
url
|
||||
}
|
||||
MentionUri::Rule { name, id } => {
|
||||
let mut url = Url::parse("zed:///").unwrap();
|
||||
url.set_path(&format!("/agent/rule/{id}"));
|
||||
url.query_pairs_mut().append_pair("name", name);
|
||||
url
|
||||
}
|
||||
MentionUri::Fetch { url } => url.clone(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct MentionLink<'a>(&'a MentionUri);
|
||||
|
||||
impl fmt::Display for MentionLink<'_> {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
write!(f, "[@{}]({})", self.0.name(), self.0.to_uri())
|
||||
}
|
||||
}
|
||||
|
||||
fn single_query_param(url: &Url, name: &'static str) -> Result<Option<String>> {
|
||||
let pairs = url.query_pairs().collect::<Vec<_>>();
|
||||
match pairs.as_slice() {
|
||||
[] => Ok(None),
|
||||
[(k, v)] => {
|
||||
if k != name {
|
||||
bail!("invalid query parameter")
|
||||
}
|
||||
|
||||
Ok(Some(v.to_string()))
|
||||
}
|
||||
_ => bail!("too many query pairs"),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn selection_name(path: &Path, line_range: &Range<u32>) -> String {
|
||||
format!(
|
||||
"{} ({}:{})",
|
||||
path.file_name().unwrap_or_default().display(),
|
||||
line_range.start + 1,
|
||||
line_range.end + 1
|
||||
)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_parse_file_uri() {
|
||||
let file_uri = "file:///path/to/file.rs";
|
||||
let parsed = MentionUri::parse(file_uri).unwrap();
|
||||
match &parsed {
|
||||
MentionUri::File {
|
||||
abs_path,
|
||||
is_directory,
|
||||
} => {
|
||||
assert_eq!(abs_path.to_str().unwrap(), "/path/to/file.rs");
|
||||
assert!(!is_directory);
|
||||
}
|
||||
_ => panic!("Expected File variant"),
|
||||
}
|
||||
assert_eq!(parsed.to_uri().to_string(), file_uri);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_directory_uri() {
|
||||
let file_uri = "file:///path/to/dir/";
|
||||
let parsed = MentionUri::parse(file_uri).unwrap();
|
||||
match &parsed {
|
||||
MentionUri::File {
|
||||
abs_path,
|
||||
is_directory,
|
||||
} => {
|
||||
assert_eq!(abs_path.to_str().unwrap(), "/path/to/dir");
|
||||
assert!(is_directory);
|
||||
}
|
||||
_ => panic!("Expected File variant"),
|
||||
}
|
||||
assert_eq!(parsed.to_uri().to_string(), file_uri);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_to_directory_uri_with_slash() {
|
||||
let uri = MentionUri::File {
|
||||
abs_path: PathBuf::from("/path/to/dir/"),
|
||||
is_directory: true,
|
||||
};
|
||||
assert_eq!(uri.to_uri().to_string(), "file:///path/to/dir/");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_to_directory_uri_without_slash() {
|
||||
let uri = MentionUri::File {
|
||||
abs_path: PathBuf::from("/path/to/dir"),
|
||||
is_directory: true,
|
||||
};
|
||||
assert_eq!(uri.to_uri().to_string(), "file:///path/to/dir/");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_symbol_uri() {
|
||||
let symbol_uri = "file:///path/to/file.rs?symbol=MySymbol#L10:20";
|
||||
let parsed = MentionUri::parse(symbol_uri).unwrap();
|
||||
match &parsed {
|
||||
MentionUri::Symbol {
|
||||
path,
|
||||
name,
|
||||
line_range,
|
||||
} => {
|
||||
assert_eq!(path.to_str().unwrap(), "/path/to/file.rs");
|
||||
assert_eq!(name, "MySymbol");
|
||||
assert_eq!(line_range.start, 9);
|
||||
assert_eq!(line_range.end, 19);
|
||||
}
|
||||
_ => panic!("Expected Symbol variant"),
|
||||
}
|
||||
assert_eq!(parsed.to_uri().to_string(), symbol_uri);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_selection_uri() {
|
||||
let selection_uri = "file:///path/to/file.rs#L5:15";
|
||||
let parsed = MentionUri::parse(selection_uri).unwrap();
|
||||
match &parsed {
|
||||
MentionUri::Selection { path, line_range } => {
|
||||
assert_eq!(path.to_str().unwrap(), "/path/to/file.rs");
|
||||
assert_eq!(line_range.start, 4);
|
||||
assert_eq!(line_range.end, 14);
|
||||
}
|
||||
_ => panic!("Expected Selection variant"),
|
||||
}
|
||||
assert_eq!(parsed.to_uri().to_string(), selection_uri);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_thread_uri() {
|
||||
let thread_uri = "zed:///agent/thread/session123?name=Thread+name";
|
||||
let parsed = MentionUri::parse(thread_uri).unwrap();
|
||||
match &parsed {
|
||||
MentionUri::Thread {
|
||||
id: thread_id,
|
||||
name,
|
||||
} => {
|
||||
assert_eq!(thread_id.to_string(), "session123");
|
||||
assert_eq!(name, "Thread name");
|
||||
}
|
||||
_ => panic!("Expected Thread variant"),
|
||||
}
|
||||
assert_eq!(parsed.to_uri().to_string(), thread_uri);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_rule_uri() {
|
||||
let rule_uri = "zed:///agent/rule/d8694ff2-90d5-4b6f-be33-33c1763acd52?name=Some+rule";
|
||||
let parsed = MentionUri::parse(rule_uri).unwrap();
|
||||
match &parsed {
|
||||
MentionUri::Rule { id, name } => {
|
||||
assert_eq!(id.to_string(), "d8694ff2-90d5-4b6f-be33-33c1763acd52");
|
||||
assert_eq!(name, "Some rule");
|
||||
}
|
||||
_ => panic!("Expected Rule variant"),
|
||||
}
|
||||
assert_eq!(parsed.to_uri().to_string(), rule_uri);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_fetch_http_uri() {
|
||||
let http_uri = "http://example.com/path?query=value#fragment";
|
||||
let parsed = MentionUri::parse(http_uri).unwrap();
|
||||
match &parsed {
|
||||
MentionUri::Fetch { url } => {
|
||||
assert_eq!(url.to_string(), http_uri);
|
||||
}
|
||||
_ => panic!("Expected Fetch variant"),
|
||||
}
|
||||
assert_eq!(parsed.to_uri().to_string(), http_uri);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_fetch_https_uri() {
|
||||
let https_uri = "https://example.com/api/endpoint";
|
||||
let parsed = MentionUri::parse(https_uri).unwrap();
|
||||
match &parsed {
|
||||
MentionUri::Fetch { url } => {
|
||||
assert_eq!(url.to_string(), https_uri);
|
||||
}
|
||||
_ => panic!("Expected Fetch variant"),
|
||||
}
|
||||
assert_eq!(parsed.to_uri().to_string(), https_uri);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_invalid_scheme() {
|
||||
assert!(MentionUri::parse("ftp://example.com").is_err());
|
||||
assert!(MentionUri::parse("ssh://example.com").is_err());
|
||||
assert!(MentionUri::parse("unknown://example.com").is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_invalid_zed_path() {
|
||||
assert!(MentionUri::parse("zed:///invalid/path").is_err());
|
||||
assert!(MentionUri::parse("zed:///agent/unknown/test").is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_invalid_line_range_format() {
|
||||
// Missing L prefix
|
||||
assert!(MentionUri::parse("file:///path/to/file.rs#10:20").is_err());
|
||||
|
||||
// Missing colon separator
|
||||
assert!(MentionUri::parse("file:///path/to/file.rs#L1020").is_err());
|
||||
|
||||
// Invalid numbers
|
||||
assert!(MentionUri::parse("file:///path/to/file.rs#L10:abc").is_err());
|
||||
assert!(MentionUri::parse("file:///path/to/file.rs#Labc:20").is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_invalid_query_parameters() {
|
||||
// Invalid query parameter name
|
||||
assert!(MentionUri::parse("file:///path/to/file.rs#L10:20?invalid=test").is_err());
|
||||
|
||||
// Too many query parameters
|
||||
assert!(
|
||||
MentionUri::parse("file:///path/to/file.rs#L10:20?symbol=test&another=param").is_err()
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_zero_based_line_numbers() {
|
||||
// Test that 0-based line numbers are rejected (should be 1-based)
|
||||
assert!(MentionUri::parse("file:///path/to/file.rs#L0:10").is_err());
|
||||
assert!(MentionUri::parse("file:///path/to/file.rs#L1:0").is_err());
|
||||
assert!(MentionUri::parse("file:///path/to/file.rs#L0:0").is_err());
|
||||
}
|
||||
}
|
||||
@@ -29,8 +29,14 @@ impl Terminal {
|
||||
cx: &mut Context<Self>,
|
||||
) -> Self {
|
||||
Self {
|
||||
command: cx
|
||||
.new(|cx| Markdown::new(command.into(), Some(language_registry.clone()), None, cx)),
|
||||
command: cx.new(|cx| {
|
||||
Markdown::new(
|
||||
format!("```\n{}\n```", command).into(),
|
||||
Some(language_registry.clone()),
|
||||
None,
|
||||
cx,
|
||||
)
|
||||
}),
|
||||
working_dir,
|
||||
terminal,
|
||||
started_at: Instant::now(),
|
||||
|
||||
@@ -17,8 +17,6 @@ use util::{
|
||||
pub struct ActionLog {
|
||||
/// Buffers that we want to notify the model about when they change.
|
||||
tracked_buffers: BTreeMap<Entity<Buffer>, TrackedBuffer>,
|
||||
/// Has the model edited a file since it last checked diagnostics?
|
||||
edited_since_project_diagnostics_check: bool,
|
||||
/// The project this action log is associated with
|
||||
project: Entity<Project>,
|
||||
}
|
||||
@@ -28,7 +26,6 @@ impl ActionLog {
|
||||
pub fn new(project: Entity<Project>) -> Self {
|
||||
Self {
|
||||
tracked_buffers: BTreeMap::default(),
|
||||
edited_since_project_diagnostics_check: false,
|
||||
project,
|
||||
}
|
||||
}
|
||||
@@ -37,16 +34,6 @@ impl ActionLog {
|
||||
&self.project
|
||||
}
|
||||
|
||||
/// Notifies a diagnostics check
|
||||
pub fn checked_project_diagnostics(&mut self) {
|
||||
self.edited_since_project_diagnostics_check = false;
|
||||
}
|
||||
|
||||
/// Returns true if any files have been edited since the last project diagnostics check
|
||||
pub fn has_edited_files_since_project_diagnostics_check(&self) -> bool {
|
||||
self.edited_since_project_diagnostics_check
|
||||
}
|
||||
|
||||
pub fn latest_snapshot(&self, buffer: &Entity<Buffer>) -> Option<text::BufferSnapshot> {
|
||||
Some(self.tracked_buffers.get(buffer)?.snapshot.clone())
|
||||
}
|
||||
@@ -543,14 +530,11 @@ impl ActionLog {
|
||||
|
||||
/// Mark a buffer as created by agent, so we can refresh it in the context
|
||||
pub fn buffer_created(&mut self, buffer: Entity<Buffer>, cx: &mut Context<Self>) {
|
||||
self.edited_since_project_diagnostics_check = true;
|
||||
self.track_buffer_internal(buffer.clone(), true, cx);
|
||||
}
|
||||
|
||||
/// Mark a buffer as edited by agent, so we can refresh it in the context
|
||||
pub fn buffer_edited(&mut self, buffer: Entity<Buffer>, cx: &mut Context<Self>) {
|
||||
self.edited_since_project_diagnostics_check = true;
|
||||
|
||||
let tracked_buffer = self.track_buffer_internal(buffer.clone(), false, cx);
|
||||
if let TrackedBufferStatus::Deleted = tracked_buffer.status {
|
||||
tracked_buffer.status = TrackedBufferStatus::Modified;
|
||||
|
||||
@@ -716,18 +716,10 @@ impl ActivityIndicator {
|
||||
})),
|
||||
tooltip_message: Some(Self::version_tooltip_message(&version)),
|
||||
}),
|
||||
AutoUpdateStatus::Updated {
|
||||
binary_path,
|
||||
version,
|
||||
} => Some(Content {
|
||||
AutoUpdateStatus::Updated { version } => Some(Content {
|
||||
icon: None,
|
||||
message: "Click to restart and update Zed".to_string(),
|
||||
on_click: Some(Arc::new({
|
||||
let reload = workspace::Reload {
|
||||
binary_path: Some(binary_path.clone()),
|
||||
};
|
||||
move |_, _, cx| workspace::reload(&reload, cx)
|
||||
})),
|
||||
on_click: Some(Arc::new(move |_, _, cx| workspace::reload(cx))),
|
||||
tooltip_message: Some(Self::version_tooltip_message(&version)),
|
||||
}),
|
||||
AutoUpdateStatus::Errored => Some(Content {
|
||||
|
||||
@@ -2268,6 +2268,15 @@ impl Thread {
|
||||
max_attempts: 3,
|
||||
})
|
||||
}
|
||||
Other(err)
|
||||
if err.is::<PaymentRequiredError>()
|
||||
|| err.is::<ModelRequestLimitReachedError>() =>
|
||||
{
|
||||
// Retrying won't help for Payment Required or Model Request Limit errors (where
|
||||
// the user must upgrade to usage-based billing to get more requests, or else wait
|
||||
// for a significant amount of time for the request limit to reset).
|
||||
None
|
||||
}
|
||||
// Conservatively assume that any other errors are non-retryable
|
||||
HttpResponseError { .. } | Other(..) => Some(RetryStrategy::Fixed {
|
||||
delay: BASE_RETRY_DELAY,
|
||||
|
||||
@@ -23,10 +23,13 @@ assistant_tools.workspace = true
|
||||
chrono.workspace = true
|
||||
cloud_llm_client.workspace = true
|
||||
collections.workspace = true
|
||||
context_server.workspace = true
|
||||
fs.workspace = true
|
||||
futures.workspace = true
|
||||
gpui.workspace = true
|
||||
handlebars = { workspace = true, features = ["rust-embed"] }
|
||||
html_to_markdown.workspace = true
|
||||
http_client.workspace = true
|
||||
indoc.workspace = true
|
||||
itertools.workspace = true
|
||||
language.workspace = true
|
||||
@@ -46,6 +49,7 @@ settings.workspace = true
|
||||
smol.workspace = true
|
||||
task.workspace = true
|
||||
terminal.workspace = true
|
||||
text.workspace = true
|
||||
ui.workspace = true
|
||||
util.workspace = true
|
||||
uuid.workspace = true
|
||||
@@ -58,6 +62,7 @@ workspace-hack.workspace = true
|
||||
ctor.workspace = true
|
||||
client = { workspace = true, "features" = ["test-support"] }
|
||||
clock = { workspace = true, "features" = ["test-support"] }
|
||||
context_server = { workspace = true, "features" = ["test-support"] }
|
||||
editor = { workspace = true, "features" = ["test-support"] }
|
||||
env_logger.workspace = true
|
||||
fs = { workspace = true, "features" = ["test-support"] }
|
||||
|
||||
@@ -1,21 +1,26 @@
|
||||
use crate::{AgentResponseEvent, Thread, templates::Templates};
|
||||
use crate::{
|
||||
CopyPathTool, CreateDirectoryTool, EditFileTool, FindPathTool, GrepTool, ListDirectoryTool,
|
||||
MovePathTool, NowTool, OpenTool, ReadFileTool, TerminalTool, ThinkingTool,
|
||||
ToolCallAuthorization, WebSearchTool,
|
||||
ContextServerRegistry, CopyPathTool, CreateDirectoryTool, DiagnosticsTool, EditFileTool,
|
||||
FetchTool, FindPathTool, GrepTool, ListDirectoryTool, MovePathTool, NowTool, OpenTool,
|
||||
ReadFileTool, TerminalTool, ThinkingTool, ToolCallAuthorization, UserMessageContent,
|
||||
WebSearchTool,
|
||||
};
|
||||
use acp_thread::ModelSelector;
|
||||
use acp_thread::AgentModelSelector;
|
||||
use agent_client_protocol as acp;
|
||||
use agent_settings::AgentSettings;
|
||||
use anyhow::{Context as _, Result, anyhow};
|
||||
use collections::{HashSet, IndexMap};
|
||||
use fs::Fs;
|
||||
use futures::{StreamExt, future};
|
||||
use gpui::{
|
||||
App, AppContext, AsyncApp, Context, Entity, SharedString, Subscription, Task, WeakEntity,
|
||||
};
|
||||
use language_model::{LanguageModel, LanguageModelRegistry};
|
||||
use language_model::{LanguageModel, LanguageModelProvider, LanguageModelRegistry};
|
||||
use project::{Project, ProjectItem, ProjectPath, Worktree};
|
||||
use prompt_store::{
|
||||
ProjectContext, PromptId, PromptStore, RulesFileContext, UserRulesContext, WorktreeContext,
|
||||
};
|
||||
use settings::update_settings_file;
|
||||
use std::cell::RefCell;
|
||||
use std::collections::HashMap;
|
||||
use std::path::Path;
|
||||
@@ -48,6 +53,104 @@ struct Session {
|
||||
_subscription: Subscription,
|
||||
}
|
||||
|
||||
pub struct LanguageModels {
|
||||
/// Access language model by ID
|
||||
models: HashMap<acp_thread::AgentModelId, Arc<dyn LanguageModel>>,
|
||||
/// Cached list for returning language model information
|
||||
model_list: acp_thread::AgentModelList,
|
||||
refresh_models_rx: watch::Receiver<()>,
|
||||
refresh_models_tx: watch::Sender<()>,
|
||||
}
|
||||
|
||||
impl LanguageModels {
|
||||
fn new(cx: &App) -> Self {
|
||||
let (refresh_models_tx, refresh_models_rx) = watch::channel(());
|
||||
let mut this = Self {
|
||||
models: HashMap::default(),
|
||||
model_list: acp_thread::AgentModelList::Grouped(IndexMap::default()),
|
||||
refresh_models_rx,
|
||||
refresh_models_tx,
|
||||
};
|
||||
this.refresh_list(cx);
|
||||
this
|
||||
}
|
||||
|
||||
fn refresh_list(&mut self, cx: &App) {
|
||||
let providers = LanguageModelRegistry::global(cx)
|
||||
.read(cx)
|
||||
.providers()
|
||||
.into_iter()
|
||||
.filter(|provider| provider.is_authenticated(cx))
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let mut language_model_list = IndexMap::default();
|
||||
let mut recommended_models = HashSet::default();
|
||||
|
||||
let mut recommended = Vec::new();
|
||||
for provider in &providers {
|
||||
for model in provider.recommended_models(cx) {
|
||||
recommended_models.insert(model.id());
|
||||
recommended.push(Self::map_language_model_to_info(&model, &provider));
|
||||
}
|
||||
}
|
||||
if !recommended.is_empty() {
|
||||
language_model_list.insert(
|
||||
acp_thread::AgentModelGroupName("Recommended".into()),
|
||||
recommended,
|
||||
);
|
||||
}
|
||||
|
||||
let mut models = HashMap::default();
|
||||
for provider in providers {
|
||||
let mut provider_models = Vec::new();
|
||||
for model in provider.provided_models(cx) {
|
||||
let model_info = Self::map_language_model_to_info(&model, &provider);
|
||||
let model_id = model_info.id.clone();
|
||||
if !recommended_models.contains(&model.id()) {
|
||||
provider_models.push(model_info);
|
||||
}
|
||||
models.insert(model_id, model);
|
||||
}
|
||||
if !provider_models.is_empty() {
|
||||
language_model_list.insert(
|
||||
acp_thread::AgentModelGroupName(provider.name().0.clone()),
|
||||
provider_models,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
self.models = models;
|
||||
self.model_list = acp_thread::AgentModelList::Grouped(language_model_list);
|
||||
self.refresh_models_tx.send(()).ok();
|
||||
}
|
||||
|
||||
fn watch(&self) -> watch::Receiver<()> {
|
||||
self.refresh_models_rx.clone()
|
||||
}
|
||||
|
||||
pub fn model_from_id(
|
||||
&self,
|
||||
model_id: &acp_thread::AgentModelId,
|
||||
) -> Option<Arc<dyn LanguageModel>> {
|
||||
self.models.get(model_id).cloned()
|
||||
}
|
||||
|
||||
fn map_language_model_to_info(
|
||||
model: &Arc<dyn LanguageModel>,
|
||||
provider: &Arc<dyn LanguageModelProvider>,
|
||||
) -> acp_thread::AgentModelInfo {
|
||||
acp_thread::AgentModelInfo {
|
||||
id: Self::model_id(model),
|
||||
name: model.name().0,
|
||||
icon: Some(provider.icon()),
|
||||
}
|
||||
}
|
||||
|
||||
fn model_id(model: &Arc<dyn LanguageModel>) -> acp_thread::AgentModelId {
|
||||
acp_thread::AgentModelId(format!("{}/{}", model.provider_id().0, model.id().0).into())
|
||||
}
|
||||
}
|
||||
|
||||
pub struct NativeAgent {
|
||||
/// Session ID -> Session mapping
|
||||
sessions: HashMap<acp::SessionId, Session>,
|
||||
@@ -55,10 +158,14 @@ pub struct NativeAgent {
|
||||
project_context: Rc<RefCell<ProjectContext>>,
|
||||
project_context_needs_refresh: watch::Sender<()>,
|
||||
_maintain_project_context: Task<Result<()>>,
|
||||
context_server_registry: Entity<ContextServerRegistry>,
|
||||
/// Shared templates for all threads
|
||||
templates: Arc<Templates>,
|
||||
/// Cached model information
|
||||
models: LanguageModels,
|
||||
project: Entity<Project>,
|
||||
prompt_store: Option<Entity<PromptStore>>,
|
||||
fs: Arc<dyn Fs>,
|
||||
_subscriptions: Vec<Subscription>,
|
||||
}
|
||||
|
||||
@@ -67,6 +174,7 @@ impl NativeAgent {
|
||||
project: Entity<Project>,
|
||||
templates: Arc<Templates>,
|
||||
prompt_store: Option<Entity<PromptStore>>,
|
||||
fs: Arc<dyn Fs>,
|
||||
cx: &mut AsyncApp,
|
||||
) -> Result<Entity<NativeAgent>> {
|
||||
log::info!("Creating new NativeAgent");
|
||||
@@ -76,7 +184,13 @@ impl NativeAgent {
|
||||
.await;
|
||||
|
||||
cx.new(|cx| {
|
||||
let mut subscriptions = vec![cx.subscribe(&project, Self::handle_project_event)];
|
||||
let mut subscriptions = vec![
|
||||
cx.subscribe(&project, Self::handle_project_event),
|
||||
cx.subscribe(
|
||||
&LanguageModelRegistry::global(cx),
|
||||
Self::handle_models_updated_event,
|
||||
),
|
||||
];
|
||||
if let Some(prompt_store) = prompt_store.as_ref() {
|
||||
subscriptions.push(cx.subscribe(prompt_store, Self::handle_prompts_updated_event))
|
||||
}
|
||||
@@ -90,14 +204,23 @@ impl NativeAgent {
|
||||
_maintain_project_context: cx.spawn(async move |this, cx| {
|
||||
Self::maintain_project_context(this, project_context_needs_refresh_rx, cx).await
|
||||
}),
|
||||
context_server_registry: cx.new(|cx| {
|
||||
ContextServerRegistry::new(project.read(cx).context_server_store(), cx)
|
||||
}),
|
||||
templates,
|
||||
models: LanguageModels::new(cx),
|
||||
project,
|
||||
prompt_store,
|
||||
fs,
|
||||
_subscriptions: subscriptions,
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
pub fn models(&self) -> &LanguageModels {
|
||||
&self.models
|
||||
}
|
||||
|
||||
async fn maintain_project_context(
|
||||
this: WeakEntity<Self>,
|
||||
mut needs_refresh: watch::Receiver<()>,
|
||||
@@ -293,75 +416,104 @@ impl NativeAgent {
|
||||
) {
|
||||
self.project_context_needs_refresh.send(()).ok();
|
||||
}
|
||||
|
||||
fn handle_models_updated_event(
|
||||
&mut self,
|
||||
_registry: Entity<LanguageModelRegistry>,
|
||||
_event: &language_model::Event,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
self.models.refresh_list(cx);
|
||||
for session in self.sessions.values_mut() {
|
||||
session.thread.update(cx, |thread, _| {
|
||||
let model_id = LanguageModels::model_id(&thread.selected_model);
|
||||
if let Some(model) = self.models.model_from_id(&model_id) {
|
||||
thread.selected_model = model.clone();
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Wrapper struct that implements the AgentConnection trait
|
||||
#[derive(Clone)]
|
||||
pub struct NativeAgentConnection(pub Entity<NativeAgent>);
|
||||
|
||||
impl ModelSelector for NativeAgentConnection {
|
||||
fn list_models(&self, cx: &mut AsyncApp) -> Task<Result<Vec<Arc<dyn LanguageModel>>>> {
|
||||
impl AgentModelSelector for NativeAgentConnection {
|
||||
fn list_models(&self, cx: &mut App) -> Task<Result<acp_thread::AgentModelList>> {
|
||||
log::debug!("NativeAgentConnection::list_models called");
|
||||
cx.spawn(async move |cx| {
|
||||
cx.update(|cx| {
|
||||
let registry = LanguageModelRegistry::read_global(cx);
|
||||
let models = registry.available_models(cx).collect::<Vec<_>>();
|
||||
log::info!("Found {} available models", models.len());
|
||||
if models.is_empty() {
|
||||
Err(anyhow::anyhow!("No models available"))
|
||||
} else {
|
||||
Ok(models)
|
||||
}
|
||||
})?
|
||||
let list = self.0.read(cx).models.model_list.clone();
|
||||
Task::ready(if list.is_empty() {
|
||||
Err(anyhow::anyhow!("No models available"))
|
||||
} else {
|
||||
Ok(list)
|
||||
})
|
||||
}
|
||||
|
||||
fn select_model(
|
||||
&self,
|
||||
session_id: acp::SessionId,
|
||||
model: Arc<dyn LanguageModel>,
|
||||
cx: &mut AsyncApp,
|
||||
model_id: acp_thread::AgentModelId,
|
||||
cx: &mut App,
|
||||
) -> Task<Result<()>> {
|
||||
log::info!(
|
||||
"Setting model for session {}: {:?}",
|
||||
session_id,
|
||||
model.name()
|
||||
);
|
||||
let agent = self.0.clone();
|
||||
log::info!("Setting model for session {}: {}", session_id, model_id);
|
||||
let Some(thread) = self
|
||||
.0
|
||||
.read(cx)
|
||||
.sessions
|
||||
.get(&session_id)
|
||||
.map(|session| session.thread.clone())
|
||||
else {
|
||||
return Task::ready(Err(anyhow!("Session not found")));
|
||||
};
|
||||
|
||||
cx.spawn(async move |cx| {
|
||||
agent.update(cx, |agent, cx| {
|
||||
if let Some(session) = agent.sessions.get(&session_id) {
|
||||
session.thread.update(cx, |thread, _cx| {
|
||||
thread.selected_model = model;
|
||||
});
|
||||
Ok(())
|
||||
} else {
|
||||
Err(anyhow!("Session not found"))
|
||||
}
|
||||
})?
|
||||
})
|
||||
let Some(model) = self.0.read(cx).models.model_from_id(&model_id) else {
|
||||
return Task::ready(Err(anyhow!("Invalid model ID {}", model_id)));
|
||||
};
|
||||
|
||||
thread.update(cx, |thread, _cx| {
|
||||
thread.selected_model = model.clone();
|
||||
});
|
||||
|
||||
update_settings_file::<AgentSettings>(
|
||||
self.0.read(cx).fs.clone(),
|
||||
cx,
|
||||
move |settings, _cx| {
|
||||
settings.set_model(model);
|
||||
},
|
||||
);
|
||||
|
||||
Task::ready(Ok(()))
|
||||
}
|
||||
|
||||
fn selected_model(
|
||||
&self,
|
||||
session_id: &acp::SessionId,
|
||||
cx: &mut AsyncApp,
|
||||
) -> Task<Result<Arc<dyn LanguageModel>>> {
|
||||
let agent = self.0.clone();
|
||||
cx: &mut App,
|
||||
) -> Task<Result<acp_thread::AgentModelInfo>> {
|
||||
let session_id = session_id.clone();
|
||||
cx.spawn(async move |cx| {
|
||||
let thread = agent
|
||||
.read_with(cx, |agent, _| {
|
||||
agent
|
||||
.sessions
|
||||
.get(&session_id)
|
||||
.map(|session| session.thread.clone())
|
||||
})?
|
||||
.ok_or_else(|| anyhow::anyhow!("Session not found"))?;
|
||||
let selected = thread.read_with(cx, |thread, _| thread.selected_model.clone())?;
|
||||
Ok(selected)
|
||||
})
|
||||
|
||||
let Some(thread) = self
|
||||
.0
|
||||
.read(cx)
|
||||
.sessions
|
||||
.get(&session_id)
|
||||
.map(|session| session.thread.clone())
|
||||
else {
|
||||
return Task::ready(Err(anyhow!("Session not found")));
|
||||
};
|
||||
let model = thread.read(cx).selected_model.clone();
|
||||
let Some(provider) = LanguageModelRegistry::read_global(cx).provider(&model.provider_id())
|
||||
else {
|
||||
return Task::ready(Err(anyhow!("Provider not found")));
|
||||
};
|
||||
Task::ready(Ok(LanguageModels::map_language_model_to_info(
|
||||
&model, &provider,
|
||||
)))
|
||||
}
|
||||
|
||||
fn watch(&self, cx: &mut App) -> watch::Receiver<()> {
|
||||
self.0.read(cx).models.watch()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -385,7 +537,13 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
|
||||
// Create AcpThread
|
||||
let acp_thread = cx.update(|cx| {
|
||||
cx.new(|cx| {
|
||||
acp_thread::AcpThread::new("agent2", self.clone(), project.clone(), session_id.clone(), cx)
|
||||
acp_thread::AcpThread::new(
|
||||
"agent2",
|
||||
self.clone(),
|
||||
project.clone(),
|
||||
session_id.clone(),
|
||||
cx,
|
||||
)
|
||||
})
|
||||
})?;
|
||||
let action_log = cx.update(|cx| acp_thread.read(cx).action_log().clone())?;
|
||||
@@ -403,28 +561,37 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
|
||||
|
||||
let default_model = registry
|
||||
.default_model()
|
||||
.map(|configured| {
|
||||
log::info!(
|
||||
"Using configured default model: {:?} from provider: {:?}",
|
||||
configured.model.name(),
|
||||
configured.provider.name()
|
||||
);
|
||||
configured.model
|
||||
.and_then(|default_model| {
|
||||
agent
|
||||
.models
|
||||
.model_from_id(&LanguageModels::model_id(&default_model.model))
|
||||
})
|
||||
.ok_or_else(|| {
|
||||
log::warn!("No default model configured in settings");
|
||||
anyhow!("No default model configured. Please configure a default model in settings.")
|
||||
anyhow!(
|
||||
"No default model. Please configure a default model in settings."
|
||||
)
|
||||
})?;
|
||||
|
||||
let thread = cx.new(|cx| {
|
||||
let mut thread = Thread::new(project.clone(), agent.project_context.clone(), action_log.clone(), agent.templates.clone(), default_model);
|
||||
let mut thread = Thread::new(
|
||||
project.clone(),
|
||||
agent.project_context.clone(),
|
||||
agent.context_server_registry.clone(),
|
||||
action_log.clone(),
|
||||
agent.templates.clone(),
|
||||
default_model,
|
||||
cx,
|
||||
);
|
||||
thread.add_tool(CreateDirectoryTool::new(project.clone()));
|
||||
thread.add_tool(CopyPathTool::new(project.clone()));
|
||||
thread.add_tool(DiagnosticsTool::new(project.clone()));
|
||||
thread.add_tool(MovePathTool::new(project.clone()));
|
||||
thread.add_tool(ListDirectoryTool::new(project.clone()));
|
||||
thread.add_tool(OpenTool::new(project.clone()));
|
||||
thread.add_tool(ThinkingTool);
|
||||
thread.add_tool(FindPathTool::new(project.clone()));
|
||||
thread.add_tool(FetchTool::new(project.read(cx).client().http_client()));
|
||||
thread.add_tool(GrepTool::new(project.clone()));
|
||||
thread.add_tool(ReadFileTool::new(project.clone(), action_log));
|
||||
thread.add_tool(EditFileTool::new(cx.entity()));
|
||||
@@ -448,7 +615,7 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
|
||||
acp_thread: acp_thread.downgrade(),
|
||||
_subscription: cx.observe_release(&acp_thread, |this, acp_thread, _cx| {
|
||||
this.sessions.remove(acp_thread.session_id());
|
||||
})
|
||||
}),
|
||||
},
|
||||
);
|
||||
})?;
|
||||
@@ -465,15 +632,17 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
|
||||
Task::ready(Ok(()))
|
||||
}
|
||||
|
||||
fn model_selector(&self) -> Option<Rc<dyn ModelSelector>> {
|
||||
Some(Rc::new(self.clone()) as Rc<dyn ModelSelector>)
|
||||
fn model_selector(&self) -> Option<Rc<dyn AgentModelSelector>> {
|
||||
Some(Rc::new(self.clone()) as Rc<dyn AgentModelSelector>)
|
||||
}
|
||||
|
||||
fn prompt(
|
||||
&self,
|
||||
id: Option<acp_thread::UserMessageId>,
|
||||
params: acp::PromptRequest,
|
||||
cx: &mut App,
|
||||
) -> Task<Result<acp::PromptResponse>> {
|
||||
let id = id.expect("UserMessageId is required");
|
||||
let session_id = params.session_id.clone();
|
||||
let agent = self.0.clone();
|
||||
log::info!("Received prompt request for session: {}", session_id);
|
||||
@@ -494,10 +663,14 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
|
||||
})?;
|
||||
log::debug!("Found session for: {}", session_id);
|
||||
|
||||
// Convert prompt to message
|
||||
let message = convert_prompt_to_message(params.prompt);
|
||||
log::info!("Converted prompt to message: {} chars", message.len());
|
||||
log::debug!("Message content: {}", message);
|
||||
let content: Vec<UserMessageContent> = params
|
||||
.prompt
|
||||
.into_iter()
|
||||
.map(Into::into)
|
||||
.collect::<Vec<_>>();
|
||||
log::info!("Converted prompt to message: {} chars", content.len());
|
||||
log::debug!("Message id: {:?}", id);
|
||||
log::debug!("Message content: {:?}", content);
|
||||
|
||||
// Get model using the ModelSelector capability (always available for agent2)
|
||||
// Get the selected model from the thread directly
|
||||
@@ -505,7 +678,8 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
|
||||
|
||||
// Send to thread
|
||||
log::info!("Sending message to thread with model: {:?}", model.name());
|
||||
let mut response_stream = thread.update(cx, |thread, cx| thread.send(message, cx))?;
|
||||
let mut response_stream =
|
||||
thread.update(cx, |thread, cx| thread.send(id, content, cx))?;
|
||||
|
||||
// Handle response stream and forward to session.acp_thread
|
||||
while let Some(result) = response_stream.next().await {
|
||||
@@ -599,44 +773,33 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
fn session_editor(
|
||||
&self,
|
||||
session_id: &agent_client_protocol::SessionId,
|
||||
cx: &mut App,
|
||||
) -> Option<Rc<dyn acp_thread::AgentSessionEditor>> {
|
||||
self.0.update(cx, |agent, _cx| {
|
||||
agent
|
||||
.sessions
|
||||
.get(session_id)
|
||||
.map(|session| Rc::new(NativeAgentSessionEditor(session.thread.clone())) as _)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert ACP content blocks to a message string
|
||||
fn convert_prompt_to_message(blocks: Vec<acp::ContentBlock>) -> String {
|
||||
log::debug!("Converting {} content blocks to message", blocks.len());
|
||||
let mut message = String::new();
|
||||
struct NativeAgentSessionEditor(Entity<Thread>);
|
||||
|
||||
for block in blocks {
|
||||
match block {
|
||||
acp::ContentBlock::Text(text) => {
|
||||
log::trace!("Processing text block: {} chars", text.text.len());
|
||||
message.push_str(&text.text);
|
||||
}
|
||||
acp::ContentBlock::ResourceLink(link) => {
|
||||
log::trace!("Processing resource link: {}", link.uri);
|
||||
message.push_str(&format!(" @{} ", link.uri));
|
||||
}
|
||||
acp::ContentBlock::Image(_) => {
|
||||
log::trace!("Processing image block");
|
||||
message.push_str(" [image] ");
|
||||
}
|
||||
acp::ContentBlock::Audio(_) => {
|
||||
log::trace!("Processing audio block");
|
||||
message.push_str(" [audio] ");
|
||||
}
|
||||
acp::ContentBlock::Resource(resource) => {
|
||||
log::trace!("Processing resource block: {:?}", resource.resource);
|
||||
message.push_str(&format!(" [resource: {:?}] ", resource.resource));
|
||||
}
|
||||
}
|
||||
impl acp_thread::AgentSessionEditor for NativeAgentSessionEditor {
|
||||
fn truncate(&self, message_id: acp_thread::UserMessageId, cx: &mut App) -> Task<Result<()>> {
|
||||
Task::ready(self.0.update(cx, |thread, _cx| thread.truncate(message_id)))
|
||||
}
|
||||
|
||||
message
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use acp_thread::{AgentConnection, AgentModelGroupName, AgentModelId, AgentModelInfo};
|
||||
use fs::FakeFs;
|
||||
use gpui::TestAppContext;
|
||||
use serde_json::json;
|
||||
@@ -654,9 +817,15 @@ mod tests {
|
||||
)
|
||||
.await;
|
||||
let project = Project::test(fs.clone(), [], cx).await;
|
||||
let agent = NativeAgent::new(project.clone(), Templates::new(), None, &mut cx.to_async())
|
||||
.await
|
||||
.unwrap();
|
||||
let agent = NativeAgent::new(
|
||||
project.clone(),
|
||||
Templates::new(),
|
||||
None,
|
||||
fs.clone(),
|
||||
&mut cx.to_async(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
agent.read_with(cx, |agent, _| {
|
||||
assert_eq!(agent.project_context.borrow().worktrees, vec![])
|
||||
});
|
||||
@@ -697,13 +866,131 @@ mod tests {
|
||||
});
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_listing_models(cx: &mut TestAppContext) {
|
||||
init_test(cx);
|
||||
let fs = FakeFs::new(cx.executor());
|
||||
fs.insert_tree("/", json!({ "a": {} })).await;
|
||||
let project = Project::test(fs.clone(), [], cx).await;
|
||||
let connection = NativeAgentConnection(
|
||||
NativeAgent::new(
|
||||
project.clone(),
|
||||
Templates::new(),
|
||||
None,
|
||||
fs.clone(),
|
||||
&mut cx.to_async(),
|
||||
)
|
||||
.await
|
||||
.unwrap(),
|
||||
);
|
||||
|
||||
let models = cx.update(|cx| connection.list_models(cx)).await.unwrap();
|
||||
|
||||
let acp_thread::AgentModelList::Grouped(models) = models else {
|
||||
panic!("Unexpected model group");
|
||||
};
|
||||
assert_eq!(
|
||||
models,
|
||||
IndexMap::from_iter([(
|
||||
AgentModelGroupName("Fake".into()),
|
||||
vec![AgentModelInfo {
|
||||
id: AgentModelId("fake/fake".into()),
|
||||
name: "Fake".into(),
|
||||
icon: Some(ui::IconName::ZedAssistant),
|
||||
}]
|
||||
)])
|
||||
);
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_model_selection_persists_to_settings(cx: &mut TestAppContext) {
|
||||
init_test(cx);
|
||||
let fs = FakeFs::new(cx.executor());
|
||||
fs.create_dir(paths::settings_file().parent().unwrap())
|
||||
.await
|
||||
.unwrap();
|
||||
fs.insert_file(
|
||||
paths::settings_file(),
|
||||
json!({
|
||||
"agent": {
|
||||
"default_model": {
|
||||
"provider": "foo",
|
||||
"model": "bar"
|
||||
}
|
||||
}
|
||||
})
|
||||
.to_string()
|
||||
.into_bytes(),
|
||||
)
|
||||
.await;
|
||||
let project = Project::test(fs.clone(), [], cx).await;
|
||||
|
||||
// Create the agent and connection
|
||||
let agent = NativeAgent::new(
|
||||
project.clone(),
|
||||
Templates::new(),
|
||||
None,
|
||||
fs.clone(),
|
||||
&mut cx.to_async(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
let connection = NativeAgentConnection(agent.clone());
|
||||
|
||||
// Create a thread/session
|
||||
let acp_thread = cx
|
||||
.update(|cx| {
|
||||
Rc::new(connection.clone()).new_thread(
|
||||
project.clone(),
|
||||
Path::new("/a"),
|
||||
&mut cx.to_async(),
|
||||
)
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());
|
||||
|
||||
// Select a model
|
||||
let model_id = AgentModelId("fake/fake".into());
|
||||
cx.update(|cx| connection.select_model(session_id.clone(), model_id.clone(), cx))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Verify the thread has the selected model
|
||||
agent.read_with(cx, |agent, _| {
|
||||
let session = agent.sessions.get(&session_id).unwrap();
|
||||
session.thread.read_with(cx, |thread, _| {
|
||||
assert_eq!(thread.selected_model.id().0, "fake");
|
||||
});
|
||||
});
|
||||
|
||||
cx.run_until_parked();
|
||||
|
||||
// Verify settings file was updated
|
||||
let settings_content = fs.load(paths::settings_file()).await.unwrap();
|
||||
let settings_json: serde_json::Value = serde_json::from_str(&settings_content).unwrap();
|
||||
|
||||
// Check that the agent settings contain the selected model
|
||||
assert_eq!(
|
||||
settings_json["agent"]["default_model"]["model"],
|
||||
json!("fake")
|
||||
);
|
||||
assert_eq!(
|
||||
settings_json["agent"]["default_model"]["provider"],
|
||||
json!("fake")
|
||||
);
|
||||
}
|
||||
|
||||
fn init_test(cx: &mut TestAppContext) {
|
||||
env_logger::try_init().ok();
|
||||
cx.update(|cx| {
|
||||
let settings_store = SettingsStore::test(cx);
|
||||
cx.set_global(settings_store);
|
||||
Project::init_settings(cx);
|
||||
agent_settings::init(cx);
|
||||
language::init(cx);
|
||||
LanguageModelRegistry::test(cx);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
use std::path::Path;
|
||||
use std::rc::Rc;
|
||||
use std::{path::Path, rc::Rc, sync::Arc};
|
||||
|
||||
use agent_servers::AgentServer;
|
||||
use anyhow::Result;
|
||||
use fs::Fs;
|
||||
use gpui::{App, Entity, Task};
|
||||
use project::Project;
|
||||
use prompt_store::PromptStore;
|
||||
@@ -10,7 +10,15 @@ use prompt_store::PromptStore;
|
||||
use crate::{NativeAgent, NativeAgentConnection, templates::Templates};
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct NativeAgentServer;
|
||||
pub struct NativeAgentServer {
|
||||
fs: Arc<dyn Fs>,
|
||||
}
|
||||
|
||||
impl NativeAgentServer {
|
||||
pub fn new(fs: Arc<dyn Fs>) -> Self {
|
||||
Self { fs }
|
||||
}
|
||||
}
|
||||
|
||||
impl AgentServer for NativeAgentServer {
|
||||
fn name(&self) -> &'static str {
|
||||
@@ -41,6 +49,7 @@ impl AgentServer for NativeAgentServer {
|
||||
_root_dir
|
||||
);
|
||||
let project = project.clone();
|
||||
let fs = self.fs.clone();
|
||||
let prompt_store = PromptStore::global(cx);
|
||||
cx.spawn(async move |cx| {
|
||||
log::debug!("Creating templates for native agent");
|
||||
@@ -48,7 +57,7 @@ impl AgentServer for NativeAgentServer {
|
||||
let prompt_store = prompt_store.await?;
|
||||
|
||||
log::debug!("Creating native agent entity");
|
||||
let agent = NativeAgent::new(project, templates, Some(prompt_store), cx).await?;
|
||||
let agent = NativeAgent::new(project, templates, Some(prompt_store), fs, cx).await?;
|
||||
|
||||
// Create the connection wrapper
|
||||
let connection = NativeAgentConnection(agent);
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
use super::*;
|
||||
use acp_thread::AgentConnection;
|
||||
use acp_thread::{AgentConnection, AgentModelGroupName, AgentModelList, UserMessageId};
|
||||
use action_log::ActionLog;
|
||||
use agent_client_protocol::{self as acp};
|
||||
use agent_settings::AgentProfileId;
|
||||
use anyhow::Result;
|
||||
use client::{Client, UserStore};
|
||||
use fs::{FakeFs, Fs};
|
||||
@@ -12,8 +13,8 @@ use gpui::{
|
||||
use indoc::indoc;
|
||||
use language_model::{
|
||||
LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId,
|
||||
LanguageModelRegistry, LanguageModelToolResult, LanguageModelToolUse, MessageContent, Role,
|
||||
StopReason, fake_provider::FakeLanguageModel,
|
||||
LanguageModelRegistry, LanguageModelToolResult, LanguageModelToolUse, Role, StopReason,
|
||||
fake_provider::FakeLanguageModel,
|
||||
};
|
||||
use project::Project;
|
||||
use prompt_store::ProjectContext;
|
||||
@@ -36,15 +37,19 @@ async fn test_echo(cx: &mut TestAppContext) {
|
||||
|
||||
let events = thread
|
||||
.update(cx, |thread, cx| {
|
||||
thread.send("Testing: Reply with 'Hello'", cx)
|
||||
thread.send(UserMessageId::new(), ["Testing: Reply with 'Hello'"], cx)
|
||||
})
|
||||
.collect()
|
||||
.await;
|
||||
thread.update(cx, |thread, _cx| {
|
||||
assert_eq!(
|
||||
thread.messages().last().unwrap().content,
|
||||
vec![MessageContent::Text("Hello".to_string())]
|
||||
);
|
||||
thread.last_message().unwrap().to_markdown(),
|
||||
indoc! {"
|
||||
## Assistant
|
||||
|
||||
Hello
|
||||
"}
|
||||
)
|
||||
});
|
||||
assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
|
||||
}
|
||||
@@ -57,12 +62,13 @@ async fn test_thinking(cx: &mut TestAppContext) {
|
||||
let events = thread
|
||||
.update(cx, |thread, cx| {
|
||||
thread.send(
|
||||
indoc! {"
|
||||
UserMessageId::new(),
|
||||
[indoc! {"
|
||||
Testing:
|
||||
|
||||
Generate a thinking step where you just think the word 'Think',
|
||||
and have your final answer be 'Hello'
|
||||
"},
|
||||
"}],
|
||||
cx,
|
||||
)
|
||||
})
|
||||
@@ -70,9 +76,10 @@ async fn test_thinking(cx: &mut TestAppContext) {
|
||||
.await;
|
||||
thread.update(cx, |thread, _cx| {
|
||||
assert_eq!(
|
||||
thread.messages().last().unwrap().to_markdown(),
|
||||
thread.last_message().unwrap().to_markdown(),
|
||||
indoc! {"
|
||||
## assistant
|
||||
## Assistant
|
||||
|
||||
<think>Think</think>
|
||||
Hello
|
||||
"}
|
||||
@@ -93,7 +100,9 @@ async fn test_system_prompt(cx: &mut TestAppContext) {
|
||||
|
||||
project_context.borrow_mut().shell = "test-shell".into();
|
||||
thread.update(cx, |thread, _| thread.add_tool(EchoTool));
|
||||
thread.update(cx, |thread, cx| thread.send("abc", cx));
|
||||
thread.update(cx, |thread, cx| {
|
||||
thread.send(UserMessageId::new(), ["abc"], cx)
|
||||
});
|
||||
cx.run_until_parked();
|
||||
let mut pending_completions = fake_model.pending_completions();
|
||||
assert_eq!(
|
||||
@@ -130,7 +139,8 @@ async fn test_basic_tool_calls(cx: &mut TestAppContext) {
|
||||
.update(cx, |thread, cx| {
|
||||
thread.add_tool(EchoTool);
|
||||
thread.send(
|
||||
"Now test the echo tool with 'Hello'. Does it work? Say 'Yes' or 'No'.",
|
||||
UserMessageId::new(),
|
||||
["Now test the echo tool with 'Hello'. Does it work? Say 'Yes' or 'No'."],
|
||||
cx,
|
||||
)
|
||||
})
|
||||
@@ -144,7 +154,11 @@ async fn test_basic_tool_calls(cx: &mut TestAppContext) {
|
||||
thread.remove_tool(&AgentTool::name(&EchoTool));
|
||||
thread.add_tool(DelayTool);
|
||||
thread.send(
|
||||
"Now call the delay tool with 200ms. When the timer goes off, then you echo the output of the tool.",
|
||||
UserMessageId::new(),
|
||||
[
|
||||
"Now call the delay tool with 200ms.",
|
||||
"When the timer goes off, then you echo the output of the tool.",
|
||||
],
|
||||
cx,
|
||||
)
|
||||
})
|
||||
@@ -154,18 +168,21 @@ async fn test_basic_tool_calls(cx: &mut TestAppContext) {
|
||||
thread.update(cx, |thread, _cx| {
|
||||
assert!(
|
||||
thread
|
||||
.messages()
|
||||
.last()
|
||||
.last_message()
|
||||
.unwrap()
|
||||
.as_agent_message()
|
||||
.unwrap()
|
||||
.content
|
||||
.iter()
|
||||
.any(|content| {
|
||||
if let MessageContent::Text(text) = content {
|
||||
if let AgentMessageContent::Text(text) = content {
|
||||
text.contains("Ding")
|
||||
} else {
|
||||
false
|
||||
}
|
||||
})
|
||||
}),
|
||||
"{}",
|
||||
thread.to_markdown()
|
||||
);
|
||||
});
|
||||
}
|
||||
@@ -178,7 +195,7 @@ async fn test_streaming_tool_calls(cx: &mut TestAppContext) {
|
||||
// Test a tool call that's likely to complete *before* streaming stops.
|
||||
let mut events = thread.update(cx, |thread, cx| {
|
||||
thread.add_tool(WordListTool);
|
||||
thread.send("Test the word_list tool.", cx)
|
||||
thread.send(UserMessageId::new(), ["Test the word_list tool."], cx)
|
||||
});
|
||||
|
||||
let mut saw_partial_tool_use = false;
|
||||
@@ -186,8 +203,10 @@ async fn test_streaming_tool_calls(cx: &mut TestAppContext) {
|
||||
if let Ok(AgentResponseEvent::ToolCall(tool_call)) = event {
|
||||
thread.update(cx, |thread, _cx| {
|
||||
// Look for a tool use in the thread's last message
|
||||
let last_content = thread.messages().last().unwrap().content.last().unwrap();
|
||||
if let MessageContent::ToolUse(last_tool_use) = last_content {
|
||||
let message = thread.last_message().unwrap();
|
||||
let agent_message = message.as_agent_message().unwrap();
|
||||
let last_content = agent_message.content.last().unwrap();
|
||||
if let AgentMessageContent::ToolUse(last_tool_use) = last_content {
|
||||
assert_eq!(last_tool_use.name.as_ref(), "word_list");
|
||||
if tool_call.status == acp::ToolCallStatus::Pending {
|
||||
if !last_tool_use.is_input_complete
|
||||
@@ -225,7 +244,7 @@ async fn test_tool_authorization(cx: &mut TestAppContext) {
|
||||
|
||||
let mut events = thread.update(cx, |thread, cx| {
|
||||
thread.add_tool(ToolRequiringPermission);
|
||||
thread.send("abc", cx)
|
||||
thread.send(UserMessageId::new(), ["abc"], cx)
|
||||
});
|
||||
cx.run_until_parked();
|
||||
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
|
||||
@@ -269,14 +288,14 @@ async fn test_tool_authorization(cx: &mut TestAppContext) {
|
||||
assert_eq!(
|
||||
message.content,
|
||||
vec![
|
||||
MessageContent::ToolResult(LanguageModelToolResult {
|
||||
language_model::MessageContent::ToolResult(LanguageModelToolResult {
|
||||
tool_use_id: tool_call_auth_1.tool_call.id.0.to_string().into(),
|
||||
tool_name: ToolRequiringPermission.name().into(),
|
||||
is_error: false,
|
||||
content: "Allowed".into(),
|
||||
output: Some("Allowed".into())
|
||||
}),
|
||||
MessageContent::ToolResult(LanguageModelToolResult {
|
||||
language_model::MessageContent::ToolResult(LanguageModelToolResult {
|
||||
tool_use_id: tool_call_auth_2.tool_call.id.0.to_string().into(),
|
||||
tool_name: ToolRequiringPermission.name().into(),
|
||||
is_error: true,
|
||||
@@ -309,13 +328,15 @@ async fn test_tool_authorization(cx: &mut TestAppContext) {
|
||||
let message = completion.messages.last().unwrap();
|
||||
assert_eq!(
|
||||
message.content,
|
||||
vec![MessageContent::ToolResult(LanguageModelToolResult {
|
||||
tool_use_id: tool_call_auth_3.tool_call.id.0.to_string().into(),
|
||||
tool_name: ToolRequiringPermission.name().into(),
|
||||
is_error: false,
|
||||
content: "Allowed".into(),
|
||||
output: Some("Allowed".into())
|
||||
})]
|
||||
vec![language_model::MessageContent::ToolResult(
|
||||
LanguageModelToolResult {
|
||||
tool_use_id: tool_call_auth_3.tool_call.id.0.to_string().into(),
|
||||
tool_name: ToolRequiringPermission.name().into(),
|
||||
is_error: false,
|
||||
content: "Allowed".into(),
|
||||
output: Some("Allowed".into())
|
||||
}
|
||||
)]
|
||||
);
|
||||
|
||||
// Simulate a final tool call, ensuring we don't trigger authorization.
|
||||
@@ -334,13 +355,15 @@ async fn test_tool_authorization(cx: &mut TestAppContext) {
|
||||
let message = completion.messages.last().unwrap();
|
||||
assert_eq!(
|
||||
message.content,
|
||||
vec![MessageContent::ToolResult(LanguageModelToolResult {
|
||||
tool_use_id: "tool_id_4".into(),
|
||||
tool_name: ToolRequiringPermission.name().into(),
|
||||
is_error: false,
|
||||
content: "Allowed".into(),
|
||||
output: Some("Allowed".into())
|
||||
})]
|
||||
vec![language_model::MessageContent::ToolResult(
|
||||
LanguageModelToolResult {
|
||||
tool_use_id: "tool_id_4".into(),
|
||||
tool_name: ToolRequiringPermission.name().into(),
|
||||
is_error: false,
|
||||
content: "Allowed".into(),
|
||||
output: Some("Allowed".into())
|
||||
}
|
||||
)]
|
||||
);
|
||||
}
|
||||
|
||||
@@ -349,7 +372,9 @@ async fn test_tool_hallucination(cx: &mut TestAppContext) {
|
||||
let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
|
||||
let fake_model = model.as_fake();
|
||||
|
||||
let mut events = thread.update(cx, |thread, cx| thread.send("abc", cx));
|
||||
let mut events = thread.update(cx, |thread, cx| {
|
||||
thread.send(UserMessageId::new(), ["abc"], cx)
|
||||
});
|
||||
cx.run_until_parked();
|
||||
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
|
||||
LanguageModelToolUse {
|
||||
@@ -441,7 +466,12 @@ async fn test_concurrent_tool_calls(cx: &mut TestAppContext) {
|
||||
.update(cx, |thread, cx| {
|
||||
thread.add_tool(DelayTool);
|
||||
thread.send(
|
||||
"Call the delay tool twice in the same message. Once with 100ms. Once with 300ms. When both timers are complete, describe the outputs.",
|
||||
UserMessageId::new(),
|
||||
[
|
||||
"Call the delay tool twice in the same message.",
|
||||
"Once with 100ms. Once with 300ms.",
|
||||
"When both timers are complete, describe the outputs.",
|
||||
],
|
||||
cx,
|
||||
)
|
||||
})
|
||||
@@ -452,12 +482,13 @@ async fn test_concurrent_tool_calls(cx: &mut TestAppContext) {
|
||||
assert_eq!(stop_reasons, vec![acp::StopReason::EndTurn]);
|
||||
|
||||
thread.update(cx, |thread, _cx| {
|
||||
let last_message = thread.messages().last().unwrap();
|
||||
let text = last_message
|
||||
let last_message = thread.last_message().unwrap();
|
||||
let agent_message = last_message.as_agent_message().unwrap();
|
||||
let text = agent_message
|
||||
.content
|
||||
.iter()
|
||||
.filter_map(|content| {
|
||||
if let MessageContent::Text(text) = content {
|
||||
if let AgentMessageContent::Text(text) = content {
|
||||
Some(text.as_str())
|
||||
} else {
|
||||
None
|
||||
@@ -469,6 +500,82 @@ async fn test_concurrent_tool_calls(cx: &mut TestAppContext) {
|
||||
});
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_profiles(cx: &mut TestAppContext) {
|
||||
let ThreadTest {
|
||||
model, thread, fs, ..
|
||||
} = setup(cx, TestModel::Fake).await;
|
||||
let fake_model = model.as_fake();
|
||||
|
||||
thread.update(cx, |thread, _cx| {
|
||||
thread.add_tool(DelayTool);
|
||||
thread.add_tool(EchoTool);
|
||||
thread.add_tool(InfiniteTool);
|
||||
});
|
||||
|
||||
// Override profiles and wait for settings to be loaded.
|
||||
fs.insert_file(
|
||||
paths::settings_file(),
|
||||
json!({
|
||||
"agent": {
|
||||
"profiles": {
|
||||
"test-1": {
|
||||
"name": "Test Profile 1",
|
||||
"tools": {
|
||||
EchoTool.name(): true,
|
||||
DelayTool.name(): true,
|
||||
}
|
||||
},
|
||||
"test-2": {
|
||||
"name": "Test Profile 2",
|
||||
"tools": {
|
||||
InfiniteTool.name(): true,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
.to_string()
|
||||
.into_bytes(),
|
||||
)
|
||||
.await;
|
||||
cx.run_until_parked();
|
||||
|
||||
// Test that test-1 profile (default) has echo and delay tools
|
||||
thread.update(cx, |thread, cx| {
|
||||
thread.set_profile(AgentProfileId("test-1".into()));
|
||||
thread.send(UserMessageId::new(), ["test"], cx);
|
||||
});
|
||||
cx.run_until_parked();
|
||||
|
||||
let mut pending_completions = fake_model.pending_completions();
|
||||
assert_eq!(pending_completions.len(), 1);
|
||||
let completion = pending_completions.pop().unwrap();
|
||||
let tool_names: Vec<String> = completion
|
||||
.tools
|
||||
.iter()
|
||||
.map(|tool| tool.name.clone())
|
||||
.collect();
|
||||
assert_eq!(tool_names, vec![DelayTool.name(), EchoTool.name()]);
|
||||
fake_model.end_last_completion_stream();
|
||||
|
||||
// Switch to test-2 profile, and verify that it has only the infinite tool.
|
||||
thread.update(cx, |thread, cx| {
|
||||
thread.set_profile(AgentProfileId("test-2".into()));
|
||||
thread.send(UserMessageId::new(), ["test2"], cx)
|
||||
});
|
||||
cx.run_until_parked();
|
||||
let mut pending_completions = fake_model.pending_completions();
|
||||
assert_eq!(pending_completions.len(), 1);
|
||||
let completion = pending_completions.pop().unwrap();
|
||||
let tool_names: Vec<String> = completion
|
||||
.tools
|
||||
.iter()
|
||||
.map(|tool| tool.name.clone())
|
||||
.collect();
|
||||
assert_eq!(tool_names, vec![InfiniteTool.name()]);
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
#[ignore = "can't run on CI yet"]
|
||||
async fn test_cancellation(cx: &mut TestAppContext) {
|
||||
@@ -478,7 +585,8 @@ async fn test_cancellation(cx: &mut TestAppContext) {
|
||||
thread.add_tool(InfiniteTool);
|
||||
thread.add_tool(EchoTool);
|
||||
thread.send(
|
||||
"Call the echo tool and then call the infinite tool, then explain their output",
|
||||
UserMessageId::new(),
|
||||
["Call the echo tool, then call the infinite tool, then explain their output"],
|
||||
cx,
|
||||
)
|
||||
});
|
||||
@@ -523,14 +631,20 @@ async fn test_cancellation(cx: &mut TestAppContext) {
|
||||
// Ensure we can still send a new message after cancellation.
|
||||
let events = thread
|
||||
.update(cx, |thread, cx| {
|
||||
thread.send("Testing: reply with 'Hello' then stop.", cx)
|
||||
thread.send(
|
||||
UserMessageId::new(),
|
||||
["Testing: reply with 'Hello' then stop."],
|
||||
cx,
|
||||
)
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
.await;
|
||||
thread.update(cx, |thread, _cx| {
|
||||
let message = thread.last_message().unwrap();
|
||||
let agent_message = message.as_agent_message().unwrap();
|
||||
assert_eq!(
|
||||
thread.messages().last().unwrap().content,
|
||||
vec![MessageContent::Text("Hello".to_string())]
|
||||
agent_message.content,
|
||||
vec![AgentMessageContent::Text("Hello".to_string())]
|
||||
);
|
||||
});
|
||||
assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
|
||||
@@ -541,13 +655,16 @@ async fn test_refusal(cx: &mut TestAppContext) {
|
||||
let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
|
||||
let fake_model = model.as_fake();
|
||||
|
||||
let events = thread.update(cx, |thread, cx| thread.send("Hello", cx));
|
||||
let events = thread.update(cx, |thread, cx| {
|
||||
thread.send(UserMessageId::new(), ["Hello"], cx)
|
||||
});
|
||||
cx.run_until_parked();
|
||||
thread.read_with(cx, |thread, _| {
|
||||
assert_eq!(
|
||||
thread.to_markdown(),
|
||||
indoc! {"
|
||||
## user
|
||||
## User
|
||||
|
||||
Hello
|
||||
"}
|
||||
);
|
||||
@@ -559,9 +676,12 @@ async fn test_refusal(cx: &mut TestAppContext) {
|
||||
assert_eq!(
|
||||
thread.to_markdown(),
|
||||
indoc! {"
|
||||
## user
|
||||
## User
|
||||
|
||||
Hello
|
||||
## assistant
|
||||
|
||||
## Assistant
|
||||
|
||||
Hey!
|
||||
"}
|
||||
);
|
||||
@@ -577,6 +697,85 @@ async fn test_refusal(cx: &mut TestAppContext) {
|
||||
});
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_truncate(cx: &mut TestAppContext) {
|
||||
let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
|
||||
let fake_model = model.as_fake();
|
||||
|
||||
let message_id = UserMessageId::new();
|
||||
thread.update(cx, |thread, cx| {
|
||||
thread.send(message_id.clone(), ["Hello"], cx)
|
||||
});
|
||||
cx.run_until_parked();
|
||||
thread.read_with(cx, |thread, _| {
|
||||
assert_eq!(
|
||||
thread.to_markdown(),
|
||||
indoc! {"
|
||||
## User
|
||||
|
||||
Hello
|
||||
"}
|
||||
);
|
||||
});
|
||||
|
||||
fake_model.send_last_completion_stream_text_chunk("Hey!");
|
||||
cx.run_until_parked();
|
||||
thread.read_with(cx, |thread, _| {
|
||||
assert_eq!(
|
||||
thread.to_markdown(),
|
||||
indoc! {"
|
||||
## User
|
||||
|
||||
Hello
|
||||
|
||||
## Assistant
|
||||
|
||||
Hey!
|
||||
"}
|
||||
);
|
||||
});
|
||||
|
||||
thread
|
||||
.update(cx, |thread, _cx| thread.truncate(message_id))
|
||||
.unwrap();
|
||||
cx.run_until_parked();
|
||||
thread.read_with(cx, |thread, _| {
|
||||
assert_eq!(thread.to_markdown(), "");
|
||||
});
|
||||
|
||||
// Ensure we can still send a new message after truncation.
|
||||
thread.update(cx, |thread, cx| {
|
||||
thread.send(UserMessageId::new(), ["Hi"], cx)
|
||||
});
|
||||
thread.update(cx, |thread, _cx| {
|
||||
assert_eq!(
|
||||
thread.to_markdown(),
|
||||
indoc! {"
|
||||
## User
|
||||
|
||||
Hi
|
||||
"}
|
||||
);
|
||||
});
|
||||
cx.run_until_parked();
|
||||
fake_model.send_last_completion_stream_text_chunk("Ahoy!");
|
||||
cx.run_until_parked();
|
||||
thread.read_with(cx, |thread, _| {
|
||||
assert_eq!(
|
||||
thread.to_markdown(),
|
||||
indoc! {"
|
||||
## User
|
||||
|
||||
Hi
|
||||
|
||||
## Assistant
|
||||
|
||||
Ahoy!
|
||||
"}
|
||||
);
|
||||
});
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_agent_connection(cx: &mut TestAppContext) {
|
||||
cx.update(settings::init);
|
||||
@@ -595,19 +794,26 @@ async fn test_agent_connection(cx: &mut TestAppContext) {
|
||||
language_models::init(user_store.clone(), client.clone(), cx);
|
||||
Project::init_settings(cx);
|
||||
LanguageModelRegistry::test(cx);
|
||||
agent_settings::init(cx);
|
||||
});
|
||||
cx.executor().forbid_parking();
|
||||
|
||||
// Create a project for new_thread
|
||||
let fake_fs = cx.update(|cx| fs::FakeFs::new(cx.background_executor().clone()));
|
||||
fake_fs.insert_tree(path!("/test"), json!({})).await;
|
||||
let project = Project::test(fake_fs, [Path::new("/test")], cx).await;
|
||||
let project = Project::test(fake_fs.clone(), [Path::new("/test")], cx).await;
|
||||
let cwd = Path::new("/test");
|
||||
|
||||
// Create agent and connection
|
||||
let agent = NativeAgent::new(project.clone(), templates.clone(), None, &mut cx.to_async())
|
||||
.await
|
||||
.unwrap();
|
||||
let agent = NativeAgent::new(
|
||||
project.clone(),
|
||||
templates.clone(),
|
||||
None,
|
||||
fake_fs.clone(),
|
||||
&mut cx.to_async(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
let connection = NativeAgentConnection(agent.clone());
|
||||
|
||||
// Test model_selector returns Some
|
||||
@@ -620,22 +826,22 @@ async fn test_agent_connection(cx: &mut TestAppContext) {
|
||||
|
||||
// Test list_models
|
||||
let listed_models = cx
|
||||
.update(|cx| {
|
||||
let mut async_cx = cx.to_async();
|
||||
selector.list_models(&mut async_cx)
|
||||
})
|
||||
.update(|cx| selector.list_models(cx))
|
||||
.await
|
||||
.expect("list_models should succeed");
|
||||
let AgentModelList::Grouped(listed_models) = listed_models else {
|
||||
panic!("Unexpected model list type");
|
||||
};
|
||||
assert!(!listed_models.is_empty(), "should have at least one model");
|
||||
assert_eq!(listed_models[0].id().0, "fake");
|
||||
assert_eq!(
|
||||
listed_models[&AgentModelGroupName("Fake".into())][0].id.0,
|
||||
"fake/fake"
|
||||
);
|
||||
|
||||
// Create a thread using new_thread
|
||||
let connection_rc = Rc::new(connection.clone());
|
||||
let acp_thread = cx
|
||||
.update(|cx| {
|
||||
let mut async_cx = cx.to_async();
|
||||
connection_rc.new_thread(project, cwd, &mut async_cx)
|
||||
})
|
||||
.update(|cx| connection_rc.new_thread(project, cwd, &mut cx.to_async()))
|
||||
.await
|
||||
.expect("new_thread should succeed");
|
||||
|
||||
@@ -644,12 +850,12 @@ async fn test_agent_connection(cx: &mut TestAppContext) {
|
||||
|
||||
// Test selected_model returns the default
|
||||
let model = cx
|
||||
.update(|cx| {
|
||||
let mut async_cx = cx.to_async();
|
||||
selector.selected_model(&session_id, &mut async_cx)
|
||||
})
|
||||
.update(|cx| selector.selected_model(&session_id, cx))
|
||||
.await
|
||||
.expect("selected_model should succeed");
|
||||
let model = cx
|
||||
.update(|cx| agent.read(cx).models().model_from_id(&model.id))
|
||||
.unwrap();
|
||||
let model = model.as_fake();
|
||||
assert_eq!(model.id().0, "fake", "should return default model");
|
||||
|
||||
@@ -683,6 +889,7 @@ async fn test_agent_connection(cx: &mut TestAppContext) {
|
||||
let result = cx
|
||||
.update(|cx| {
|
||||
connection.prompt(
|
||||
Some(acp_thread::UserMessageId::new()),
|
||||
acp::PromptRequest {
|
||||
session_id: session_id.clone(),
|
||||
prompt: vec!["ghi".into()],
|
||||
@@ -705,7 +912,9 @@ async fn test_tool_updates_to_completion(cx: &mut TestAppContext) {
|
||||
thread.update(cx, |thread, _cx| thread.add_tool(ThinkingTool));
|
||||
let fake_model = model.as_fake();
|
||||
|
||||
let mut events = thread.update(cx, |thread, cx| thread.send("Think", cx));
|
||||
let mut events = thread.update(cx, |thread, cx| {
|
||||
thread.send(UserMessageId::new(), ["Think"], cx)
|
||||
});
|
||||
cx.run_until_parked();
|
||||
|
||||
// Simulate streaming partial input.
|
||||
@@ -790,6 +999,7 @@ async fn test_tool_updates_to_completion(cx: &mut TestAppContext) {
|
||||
id: acp::ToolCallId("1".into()),
|
||||
fields: acp::ToolCallUpdateFields {
|
||||
status: Some(acp::ToolCallStatus::Completed),
|
||||
raw_output: Some("Finished thinking.".into()),
|
||||
..Default::default()
|
||||
},
|
||||
}
|
||||
@@ -813,6 +1023,7 @@ struct ThreadTest {
|
||||
model: Arc<dyn LanguageModel>,
|
||||
thread: Entity<Thread>,
|
||||
project_context: Rc<RefCell<ProjectContext>>,
|
||||
fs: Arc<FakeFs>,
|
||||
}
|
||||
|
||||
enum TestModel {
|
||||
@@ -835,30 +1046,57 @@ async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest {
|
||||
cx.executor().allow_parking();
|
||||
|
||||
let fs = FakeFs::new(cx.background_executor.clone());
|
||||
fs.create_dir(paths::settings_file().parent().unwrap())
|
||||
.await
|
||||
.unwrap();
|
||||
fs.insert_file(
|
||||
paths::settings_file(),
|
||||
json!({
|
||||
"agent": {
|
||||
"default_profile": "test-profile",
|
||||
"profiles": {
|
||||
"test-profile": {
|
||||
"name": "Test Profile",
|
||||
"tools": {
|
||||
EchoTool.name(): true,
|
||||
DelayTool.name(): true,
|
||||
WordListTool.name(): true,
|
||||
ToolRequiringPermission.name(): true,
|
||||
InfiniteTool.name(): true,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
.to_string()
|
||||
.into_bytes(),
|
||||
)
|
||||
.await;
|
||||
|
||||
cx.update(|cx| {
|
||||
settings::init(cx);
|
||||
watch_settings(fs.clone(), cx);
|
||||
Project::init_settings(cx);
|
||||
agent_settings::init(cx);
|
||||
gpui_tokio::init(cx);
|
||||
let http_client = ReqwestClient::user_agent("agent tests").unwrap();
|
||||
cx.set_http_client(Arc::new(http_client));
|
||||
|
||||
client::init_settings(cx);
|
||||
let client = Client::production(cx);
|
||||
let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
|
||||
language_model::init(client.clone(), cx);
|
||||
language_models::init(user_store.clone(), client.clone(), cx);
|
||||
|
||||
watch_settings(fs.clone(), cx);
|
||||
});
|
||||
|
||||
let templates = Templates::new();
|
||||
|
||||
fs.insert_tree(path!("/test"), json!({})).await;
|
||||
let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
|
||||
let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;
|
||||
|
||||
let model = cx
|
||||
.update(|cx| {
|
||||
gpui_tokio::init(cx);
|
||||
let http_client = ReqwestClient::user_agent("agent tests").unwrap();
|
||||
cx.set_http_client(Arc::new(http_client));
|
||||
|
||||
client::init_settings(cx);
|
||||
let client = Client::production(cx);
|
||||
let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
|
||||
language_model::init(client.clone(), cx);
|
||||
language_models::init(user_store.clone(), client.clone(), cx);
|
||||
|
||||
if let TestModel::Fake = model {
|
||||
Task::ready(Arc::new(FakeLanguageModel::default()) as Arc<_>)
|
||||
} else {
|
||||
@@ -881,20 +1119,25 @@ async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest {
|
||||
.await;
|
||||
|
||||
let project_context = Rc::new(RefCell::new(ProjectContext::default()));
|
||||
let context_server_registry =
|
||||
cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
|
||||
let action_log = cx.new(|_| ActionLog::new(project.clone()));
|
||||
let thread = cx.new(|_| {
|
||||
let thread = cx.new(|cx| {
|
||||
Thread::new(
|
||||
project,
|
||||
project_context.clone(),
|
||||
context_server_registry,
|
||||
action_log,
|
||||
templates,
|
||||
model.clone(),
|
||||
cx,
|
||||
)
|
||||
});
|
||||
ThreadTest {
|
||||
model,
|
||||
thread,
|
||||
project_context,
|
||||
fs,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,7 +1,10 @@
|
||||
mod context_server_registry;
|
||||
mod copy_path_tool;
|
||||
mod create_directory_tool;
|
||||
mod delete_path_tool;
|
||||
mod diagnostics_tool;
|
||||
mod edit_file_tool;
|
||||
mod fetch_tool;
|
||||
mod find_path_tool;
|
||||
mod grep_tool;
|
||||
mod list_directory_tool;
|
||||
@@ -13,10 +16,13 @@ mod terminal_tool;
|
||||
mod thinking_tool;
|
||||
mod web_search_tool;
|
||||
|
||||
pub use context_server_registry::*;
|
||||
pub use copy_path_tool::*;
|
||||
pub use create_directory_tool::*;
|
||||
pub use delete_path_tool::*;
|
||||
pub use diagnostics_tool::*;
|
||||
pub use edit_file_tool::*;
|
||||
pub use fetch_tool::*;
|
||||
pub use find_path_tool::*;
|
||||
pub use grep_tool::*;
|
||||
pub use list_directory_tool::*;
|
||||
|
||||
231
crates/agent2/src/tools/context_server_registry.rs
Normal file
231
crates/agent2/src/tools/context_server_registry.rs
Normal file
@@ -0,0 +1,231 @@
|
||||
use crate::{AgentToolOutput, AnyAgentTool, ToolCallEventStream};
|
||||
use agent_client_protocol::ToolKind;
|
||||
use anyhow::{Result, anyhow, bail};
|
||||
use collections::{BTreeMap, HashMap};
|
||||
use context_server::ContextServerId;
|
||||
use gpui::{App, Context, Entity, SharedString, Task};
|
||||
use project::context_server_store::{ContextServerStatus, ContextServerStore};
|
||||
use std::sync::Arc;
|
||||
use util::ResultExt;
|
||||
|
||||
pub struct ContextServerRegistry {
|
||||
server_store: Entity<ContextServerStore>,
|
||||
registered_servers: HashMap<ContextServerId, RegisteredContextServer>,
|
||||
_subscription: gpui::Subscription,
|
||||
}
|
||||
|
||||
struct RegisteredContextServer {
|
||||
tools: BTreeMap<SharedString, Arc<dyn AnyAgentTool>>,
|
||||
load_tools: Task<Result<()>>,
|
||||
}
|
||||
|
||||
impl ContextServerRegistry {
|
||||
pub fn new(server_store: Entity<ContextServerStore>, cx: &mut Context<Self>) -> Self {
|
||||
let mut this = Self {
|
||||
server_store: server_store.clone(),
|
||||
registered_servers: HashMap::default(),
|
||||
_subscription: cx.subscribe(&server_store, Self::handle_context_server_store_event),
|
||||
};
|
||||
for server in server_store.read(cx).running_servers() {
|
||||
this.reload_tools_for_server(server.id(), cx);
|
||||
}
|
||||
this
|
||||
}
|
||||
|
||||
pub fn servers(
|
||||
&self,
|
||||
) -> impl Iterator<
|
||||
Item = (
|
||||
&ContextServerId,
|
||||
&BTreeMap<SharedString, Arc<dyn AnyAgentTool>>,
|
||||
),
|
||||
> {
|
||||
self.registered_servers
|
||||
.iter()
|
||||
.map(|(id, server)| (id, &server.tools))
|
||||
}
|
||||
|
||||
fn reload_tools_for_server(&mut self, server_id: ContextServerId, cx: &mut Context<Self>) {
|
||||
let Some(server) = self.server_store.read(cx).get_running_server(&server_id) else {
|
||||
return;
|
||||
};
|
||||
let Some(client) = server.client() else {
|
||||
return;
|
||||
};
|
||||
if !client.capable(context_server::protocol::ServerCapability::Tools) {
|
||||
return;
|
||||
}
|
||||
|
||||
let registered_server =
|
||||
self.registered_servers
|
||||
.entry(server_id.clone())
|
||||
.or_insert(RegisteredContextServer {
|
||||
tools: BTreeMap::default(),
|
||||
load_tools: Task::ready(Ok(())),
|
||||
});
|
||||
registered_server.load_tools = cx.spawn(async move |this, cx| {
|
||||
let response = client
|
||||
.request::<context_server::types::requests::ListTools>(())
|
||||
.await;
|
||||
|
||||
this.update(cx, |this, cx| {
|
||||
let Some(registered_server) = this.registered_servers.get_mut(&server_id) else {
|
||||
return;
|
||||
};
|
||||
|
||||
registered_server.tools.clear();
|
||||
if let Some(response) = response.log_err() {
|
||||
for tool in response.tools {
|
||||
let tool = Arc::new(ContextServerTool::new(
|
||||
this.server_store.clone(),
|
||||
server.id(),
|
||||
tool,
|
||||
));
|
||||
registered_server.tools.insert(tool.name(), tool);
|
||||
}
|
||||
cx.notify();
|
||||
}
|
||||
})
|
||||
});
|
||||
}
|
||||
|
||||
fn handle_context_server_store_event(
|
||||
&mut self,
|
||||
_: Entity<ContextServerStore>,
|
||||
event: &project::context_server_store::Event,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
match event {
|
||||
project::context_server_store::Event::ServerStatusChanged { server_id, status } => {
|
||||
match status {
|
||||
ContextServerStatus::Starting => {}
|
||||
ContextServerStatus::Running => {
|
||||
self.reload_tools_for_server(server_id.clone(), cx);
|
||||
}
|
||||
ContextServerStatus::Stopped | ContextServerStatus::Error(_) => {
|
||||
self.registered_servers.remove(&server_id);
|
||||
cx.notify();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct ContextServerTool {
|
||||
store: Entity<ContextServerStore>,
|
||||
server_id: ContextServerId,
|
||||
tool: context_server::types::Tool,
|
||||
}
|
||||
|
||||
impl ContextServerTool {
|
||||
fn new(
|
||||
store: Entity<ContextServerStore>,
|
||||
server_id: ContextServerId,
|
||||
tool: context_server::types::Tool,
|
||||
) -> Self {
|
||||
Self {
|
||||
store,
|
||||
server_id,
|
||||
tool,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl AnyAgentTool for ContextServerTool {
|
||||
fn name(&self) -> SharedString {
|
||||
self.tool.name.clone().into()
|
||||
}
|
||||
|
||||
fn description(&self) -> SharedString {
|
||||
self.tool.description.clone().unwrap_or_default().into()
|
||||
}
|
||||
|
||||
fn kind(&self) -> ToolKind {
|
||||
ToolKind::Other
|
||||
}
|
||||
|
||||
fn initial_title(&self, _input: serde_json::Value) -> SharedString {
|
||||
format!("Run MCP tool `{}`", self.tool.name).into()
|
||||
}
|
||||
|
||||
fn input_schema(
|
||||
&self,
|
||||
format: language_model::LanguageModelToolSchemaFormat,
|
||||
) -> Result<serde_json::Value> {
|
||||
let mut schema = self.tool.input_schema.clone();
|
||||
assistant_tool::adapt_schema_to_format(&mut schema, format)?;
|
||||
Ok(match schema {
|
||||
serde_json::Value::Null => {
|
||||
serde_json::json!({ "type": "object", "properties": [] })
|
||||
}
|
||||
serde_json::Value::Object(map) if map.is_empty() => {
|
||||
serde_json::json!({ "type": "object", "properties": [] })
|
||||
}
|
||||
_ => schema,
|
||||
})
|
||||
}
|
||||
|
||||
fn run(
|
||||
self: Arc<Self>,
|
||||
input: serde_json::Value,
|
||||
_event_stream: ToolCallEventStream,
|
||||
cx: &mut App,
|
||||
) -> Task<Result<AgentToolOutput>> {
|
||||
let Some(server) = self.store.read(cx).get_running_server(&self.server_id) else {
|
||||
return Task::ready(Err(anyhow!("Context server not found")));
|
||||
};
|
||||
let tool_name = self.tool.name.clone();
|
||||
let server_clone = server.clone();
|
||||
let input_clone = input.clone();
|
||||
|
||||
cx.spawn(async move |_cx| {
|
||||
let Some(protocol) = server_clone.client() else {
|
||||
bail!("Context server not initialized");
|
||||
};
|
||||
|
||||
let arguments = if let serde_json::Value::Object(map) = input_clone {
|
||||
Some(map.into_iter().collect())
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
log::trace!(
|
||||
"Running tool: {} with arguments: {:?}",
|
||||
tool_name,
|
||||
arguments
|
||||
);
|
||||
let response = protocol
|
||||
.request::<context_server::types::requests::CallTool>(
|
||||
context_server::types::CallToolParams {
|
||||
name: tool_name,
|
||||
arguments,
|
||||
meta: None,
|
||||
},
|
||||
)
|
||||
.await?;
|
||||
|
||||
let mut result = String::new();
|
||||
for content in response.content {
|
||||
match content {
|
||||
context_server::types::ToolResponseContent::Text { text } => {
|
||||
result.push_str(&text);
|
||||
}
|
||||
context_server::types::ToolResponseContent::Image { .. } => {
|
||||
log::warn!("Ignoring image content from tool response");
|
||||
}
|
||||
context_server::types::ToolResponseContent::Audio { .. } => {
|
||||
log::warn!("Ignoring audio content from tool response");
|
||||
}
|
||||
context_server::types::ToolResponseContent::Resource { .. } => {
|
||||
log::warn!("Ignoring resource content from tool response");
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(AgentToolOutput {
|
||||
raw_output: result.clone().into(),
|
||||
llm_output: result.into(),
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
||||
163
crates/agent2/src/tools/diagnostics_tool.rs
Normal file
163
crates/agent2/src/tools/diagnostics_tool.rs
Normal file
@@ -0,0 +1,163 @@
|
||||
use crate::{AgentTool, ToolCallEventStream};
|
||||
use agent_client_protocol as acp;
|
||||
use anyhow::{Result, anyhow};
|
||||
use gpui::{App, Entity, Task};
|
||||
use language::{DiagnosticSeverity, OffsetRangeExt};
|
||||
use project::Project;
|
||||
use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::{fmt::Write, path::Path, sync::Arc};
|
||||
use ui::SharedString;
|
||||
use util::markdown::MarkdownInlineCode;
|
||||
|
||||
/// Get errors and warnings for the project or a specific file.
|
||||
///
|
||||
/// This tool can be invoked after a series of edits to determine if further edits are necessary, or if the user asks to fix errors or warnings in their codebase.
|
||||
///
|
||||
/// When a path is provided, shows all diagnostics for that specific file.
|
||||
/// When no path is provided, shows a summary of error and warning counts for all files in the project.
|
||||
///
|
||||
/// <example>
|
||||
/// To get diagnostics for a specific file:
|
||||
/// {
|
||||
/// "path": "src/main.rs"
|
||||
/// }
|
||||
///
|
||||
/// To get a project-wide diagnostic summary:
|
||||
/// {}
|
||||
/// </example>
|
||||
///
|
||||
/// <guidelines>
|
||||
/// - If you think you can fix a diagnostic, make 1-2 attempts and then give up.
|
||||
/// - Don't remove code you've generated just because you can't fix an error. The user can help you fix it.
|
||||
/// </guidelines>
|
||||
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
|
||||
pub struct DiagnosticsToolInput {
|
||||
/// The path to get diagnostics for. If not provided, returns a project-wide summary.
|
||||
///
|
||||
/// This path should never be absolute, and the first component
|
||||
/// of the path should always be a root directory in a project.
|
||||
///
|
||||
/// <example>
|
||||
/// If the project has the following root directories:
|
||||
///
|
||||
/// - lorem
|
||||
/// - ipsum
|
||||
///
|
||||
/// If you wanna access diagnostics for `dolor.txt` in `ipsum`, you should use the path `ipsum/dolor.txt`.
|
||||
/// </example>
|
||||
pub path: Option<String>,
|
||||
}
|
||||
|
||||
pub struct DiagnosticsTool {
|
||||
project: Entity<Project>,
|
||||
}
|
||||
|
||||
impl DiagnosticsTool {
|
||||
pub fn new(project: Entity<Project>) -> Self {
|
||||
Self { project }
|
||||
}
|
||||
}
|
||||
|
||||
impl AgentTool for DiagnosticsTool {
|
||||
type Input = DiagnosticsToolInput;
|
||||
type Output = String;
|
||||
|
||||
fn name(&self) -> SharedString {
|
||||
"diagnostics".into()
|
||||
}
|
||||
|
||||
fn kind(&self) -> acp::ToolKind {
|
||||
acp::ToolKind::Read
|
||||
}
|
||||
|
||||
fn initial_title(&self, input: Result<Self::Input, serde_json::Value>) -> SharedString {
|
||||
if let Some(path) = input.ok().and_then(|input| match input.path {
|
||||
Some(path) if !path.is_empty() => Some(path),
|
||||
_ => None,
|
||||
}) {
|
||||
format!("Check diagnostics for {}", MarkdownInlineCode(&path)).into()
|
||||
} else {
|
||||
"Check project diagnostics".into()
|
||||
}
|
||||
}
|
||||
|
||||
fn run(
|
||||
self: Arc<Self>,
|
||||
input: Self::Input,
|
||||
_event_stream: ToolCallEventStream,
|
||||
cx: &mut App,
|
||||
) -> Task<Result<Self::Output>> {
|
||||
match input.path {
|
||||
Some(path) if !path.is_empty() => {
|
||||
let Some(project_path) = self.project.read(cx).find_project_path(&path, cx) else {
|
||||
return Task::ready(Err(anyhow!("Could not find path {path} in project",)));
|
||||
};
|
||||
|
||||
let buffer = self
|
||||
.project
|
||||
.update(cx, |project, cx| project.open_buffer(project_path, cx));
|
||||
|
||||
cx.spawn(async move |cx| {
|
||||
let mut output = String::new();
|
||||
let buffer = buffer.await?;
|
||||
let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
|
||||
|
||||
for (_, group) in snapshot.diagnostic_groups(None) {
|
||||
let entry = &group.entries[group.primary_ix];
|
||||
let range = entry.range.to_point(&snapshot);
|
||||
let severity = match entry.diagnostic.severity {
|
||||
DiagnosticSeverity::ERROR => "error",
|
||||
DiagnosticSeverity::WARNING => "warning",
|
||||
_ => continue,
|
||||
};
|
||||
|
||||
writeln!(
|
||||
output,
|
||||
"{} at line {}: {}",
|
||||
severity,
|
||||
range.start.row + 1,
|
||||
entry.diagnostic.message
|
||||
)?;
|
||||
}
|
||||
|
||||
if output.is_empty() {
|
||||
Ok("File doesn't have errors or warnings!".to_string())
|
||||
} else {
|
||||
Ok(output)
|
||||
}
|
||||
})
|
||||
}
|
||||
_ => {
|
||||
let project = self.project.read(cx);
|
||||
let mut output = String::new();
|
||||
let mut has_diagnostics = false;
|
||||
|
||||
for (project_path, _, summary) in project.diagnostic_summaries(true, cx) {
|
||||
if summary.error_count > 0 || summary.warning_count > 0 {
|
||||
let Some(worktree) = project.worktree_for_id(project_path.worktree_id, cx)
|
||||
else {
|
||||
continue;
|
||||
};
|
||||
|
||||
has_diagnostics = true;
|
||||
output.push_str(&format!(
|
||||
"{}: {} error(s), {} warning(s)\n",
|
||||
Path::new(worktree.read(cx).root_name())
|
||||
.join(project_path.path)
|
||||
.display(),
|
||||
summary.error_count,
|
||||
summary.warning_count
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
if has_diagnostics {
|
||||
Task::ready(Ok(output))
|
||||
} else {
|
||||
Task::ready(Ok("No errors or warnings found in the project.".into()))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,12 +1,13 @@
|
||||
use crate::{AgentTool, Thread, ToolCallEventStream};
|
||||
use acp_thread::Diff;
|
||||
use agent_client_protocol as acp;
|
||||
use agent_client_protocol::{self as acp, ToolCallLocation, ToolCallUpdateFields};
|
||||
use anyhow::{Context as _, Result, anyhow};
|
||||
use assistant_tools::edit_agent::{EditAgent, EditAgentOutput, EditAgentOutputEvent, EditFormat};
|
||||
use cloud_llm_client::CompletionIntent;
|
||||
use collections::HashSet;
|
||||
use gpui::{App, AppContext, AsyncApp, Entity, Task};
|
||||
use indoc::formatdoc;
|
||||
use language::ToPoint;
|
||||
use language::language_settings::{self, FormatOnSave};
|
||||
use language_model::LanguageModelToolResultContent;
|
||||
use paths;
|
||||
@@ -225,6 +226,16 @@ impl AgentTool for EditFileTool {
|
||||
Ok(path) => path,
|
||||
Err(err) => return Task::ready(Err(anyhow!(err))),
|
||||
};
|
||||
let abs_path = project.read(cx).absolute_path(&project_path, cx);
|
||||
if let Some(abs_path) = abs_path.clone() {
|
||||
event_stream.update_fields(ToolCallUpdateFields {
|
||||
locations: Some(vec![acp::ToolCallLocation {
|
||||
path: abs_path,
|
||||
line: None,
|
||||
}]),
|
||||
..Default::default()
|
||||
});
|
||||
}
|
||||
|
||||
let request = self.thread.update(cx, |thread, cx| {
|
||||
thread.build_completion_request(CompletionIntent::ToolResults, cx)
|
||||
@@ -283,13 +294,38 @@ impl AgentTool for EditFileTool {
|
||||
|
||||
let mut hallucinated_old_text = false;
|
||||
let mut ambiguous_ranges = Vec::new();
|
||||
let mut emitted_location = false;
|
||||
while let Some(event) = events.next().await {
|
||||
match event {
|
||||
EditAgentOutputEvent::Edited => {},
|
||||
EditAgentOutputEvent::Edited(range) => {
|
||||
if !emitted_location {
|
||||
let line = buffer.update(cx, |buffer, _cx| {
|
||||
range.start.to_point(&buffer.snapshot()).row
|
||||
}).ok();
|
||||
if let Some(abs_path) = abs_path.clone() {
|
||||
event_stream.update_fields(ToolCallUpdateFields {
|
||||
locations: Some(vec![ToolCallLocation { path: abs_path, line }]),
|
||||
..Default::default()
|
||||
});
|
||||
}
|
||||
emitted_location = true;
|
||||
}
|
||||
},
|
||||
EditAgentOutputEvent::UnresolvedEditRange => hallucinated_old_text = true,
|
||||
EditAgentOutputEvent::AmbiguousEditRange(ranges) => ambiguous_ranges = ranges,
|
||||
EditAgentOutputEvent::ResolvingEditRange(range) => {
|
||||
diff.update(cx, |card, cx| card.reveal_range(range, cx))?;
|
||||
diff.update(cx, |card, cx| card.reveal_range(range.clone(), cx))?;
|
||||
// if !emitted_location {
|
||||
// let line = buffer.update(cx, |buffer, _cx| {
|
||||
// range.start.to_point(&buffer.snapshot()).row
|
||||
// }).ok();
|
||||
// if let Some(abs_path) = abs_path.clone() {
|
||||
// event_stream.update_fields(ToolCallUpdateFields {
|
||||
// locations: Some(vec![ToolCallLocation { path: abs_path, line }]),
|
||||
// ..Default::default()
|
||||
// });
|
||||
// }
|
||||
// }
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -454,9 +490,8 @@ fn resolve_path(
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::Templates;
|
||||
|
||||
use super::*;
|
||||
use crate::{ContextServerRegistry, Templates};
|
||||
use action_log::ActionLog;
|
||||
use client::TelemetrySettings;
|
||||
use fs::Fs;
|
||||
@@ -475,9 +510,20 @@ mod tests {
|
||||
fs.insert_tree("/root", json!({})).await;
|
||||
let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
|
||||
let action_log = cx.new(|_| ActionLog::new(project.clone()));
|
||||
let context_server_registry =
|
||||
cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
|
||||
let model = Arc::new(FakeLanguageModel::default());
|
||||
let thread =
|
||||
cx.new(|_| Thread::new(project, Rc::default(), action_log, Templates::new(), model));
|
||||
let thread = cx.new(|cx| {
|
||||
Thread::new(
|
||||
project,
|
||||
Rc::default(),
|
||||
context_server_registry,
|
||||
action_log,
|
||||
Templates::new(),
|
||||
model,
|
||||
cx,
|
||||
)
|
||||
});
|
||||
let result = cx
|
||||
.update(|cx| {
|
||||
let input = EditFileToolInput {
|
||||
@@ -661,14 +707,18 @@ mod tests {
|
||||
});
|
||||
|
||||
let action_log = cx.new(|_| ActionLog::new(project.clone()));
|
||||
let context_server_registry =
|
||||
cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
|
||||
let model = Arc::new(FakeLanguageModel::default());
|
||||
let thread = cx.new(|_| {
|
||||
let thread = cx.new(|cx| {
|
||||
Thread::new(
|
||||
project,
|
||||
Rc::default(),
|
||||
context_server_registry,
|
||||
action_log.clone(),
|
||||
Templates::new(),
|
||||
model.clone(),
|
||||
cx,
|
||||
)
|
||||
});
|
||||
|
||||
@@ -792,15 +842,19 @@ mod tests {
|
||||
.unwrap();
|
||||
|
||||
let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
|
||||
let context_server_registry =
|
||||
cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
|
||||
let action_log = cx.new(|_| ActionLog::new(project.clone()));
|
||||
let model = Arc::new(FakeLanguageModel::default());
|
||||
let thread = cx.new(|_| {
|
||||
let thread = cx.new(|cx| {
|
||||
Thread::new(
|
||||
project,
|
||||
Rc::default(),
|
||||
context_server_registry,
|
||||
action_log.clone(),
|
||||
Templates::new(),
|
||||
model.clone(),
|
||||
cx,
|
||||
)
|
||||
});
|
||||
|
||||
@@ -914,15 +968,19 @@ mod tests {
|
||||
init_test(cx);
|
||||
let fs = project::FakeFs::new(cx.executor());
|
||||
let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
|
||||
let context_server_registry =
|
||||
cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
|
||||
let action_log = cx.new(|_| ActionLog::new(project.clone()));
|
||||
let model = Arc::new(FakeLanguageModel::default());
|
||||
let thread = cx.new(|_| {
|
||||
let thread = cx.new(|cx| {
|
||||
Thread::new(
|
||||
project,
|
||||
Rc::default(),
|
||||
context_server_registry,
|
||||
action_log.clone(),
|
||||
Templates::new(),
|
||||
model.clone(),
|
||||
cx,
|
||||
)
|
||||
});
|
||||
let tool = Arc::new(EditFileTool { thread });
|
||||
@@ -1041,15 +1099,19 @@ mod tests {
|
||||
let fs = project::FakeFs::new(cx.executor());
|
||||
fs.insert_tree("/project", json!({})).await;
|
||||
let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
|
||||
let context_server_registry =
|
||||
cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
|
||||
let action_log = cx.new(|_| ActionLog::new(project.clone()));
|
||||
let model = Arc::new(FakeLanguageModel::default());
|
||||
let thread = cx.new(|_| {
|
||||
let thread = cx.new(|cx| {
|
||||
Thread::new(
|
||||
project,
|
||||
Rc::default(),
|
||||
context_server_registry,
|
||||
action_log.clone(),
|
||||
Templates::new(),
|
||||
model.clone(),
|
||||
cx,
|
||||
)
|
||||
});
|
||||
let tool = Arc::new(EditFileTool { thread });
|
||||
@@ -1148,14 +1210,18 @@ mod tests {
|
||||
.await;
|
||||
|
||||
let action_log = cx.new(|_| ActionLog::new(project.clone()));
|
||||
let context_server_registry =
|
||||
cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
|
||||
let model = Arc::new(FakeLanguageModel::default());
|
||||
let thread = cx.new(|_| {
|
||||
let thread = cx.new(|cx| {
|
||||
Thread::new(
|
||||
project.clone(),
|
||||
Rc::default(),
|
||||
context_server_registry.clone(),
|
||||
action_log.clone(),
|
||||
Templates::new(),
|
||||
model.clone(),
|
||||
cx,
|
||||
)
|
||||
});
|
||||
let tool = Arc::new(EditFileTool { thread });
|
||||
@@ -1225,14 +1291,18 @@ mod tests {
|
||||
.await;
|
||||
let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
|
||||
let action_log = cx.new(|_| ActionLog::new(project.clone()));
|
||||
let context_server_registry =
|
||||
cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
|
||||
let model = Arc::new(FakeLanguageModel::default());
|
||||
let thread = cx.new(|_| {
|
||||
let thread = cx.new(|cx| {
|
||||
Thread::new(
|
||||
project.clone(),
|
||||
Rc::default(),
|
||||
context_server_registry.clone(),
|
||||
action_log.clone(),
|
||||
Templates::new(),
|
||||
model.clone(),
|
||||
cx,
|
||||
)
|
||||
});
|
||||
let tool = Arc::new(EditFileTool { thread });
|
||||
@@ -1305,14 +1375,18 @@ mod tests {
|
||||
.await;
|
||||
let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
|
||||
let action_log = cx.new(|_| ActionLog::new(project.clone()));
|
||||
let context_server_registry =
|
||||
cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
|
||||
let model = Arc::new(FakeLanguageModel::default());
|
||||
let thread = cx.new(|_| {
|
||||
let thread = cx.new(|cx| {
|
||||
Thread::new(
|
||||
project.clone(),
|
||||
Rc::default(),
|
||||
context_server_registry.clone(),
|
||||
action_log.clone(),
|
||||
Templates::new(),
|
||||
model.clone(),
|
||||
cx,
|
||||
)
|
||||
});
|
||||
let tool = Arc::new(EditFileTool { thread });
|
||||
@@ -1382,14 +1456,18 @@ mod tests {
|
||||
let fs = project::FakeFs::new(cx.executor());
|
||||
let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
|
||||
let action_log = cx.new(|_| ActionLog::new(project.clone()));
|
||||
let context_server_registry =
|
||||
cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
|
||||
let model = Arc::new(FakeLanguageModel::default());
|
||||
let thread = cx.new(|_| {
|
||||
let thread = cx.new(|cx| {
|
||||
Thread::new(
|
||||
project.clone(),
|
||||
Rc::default(),
|
||||
context_server_registry,
|
||||
action_log.clone(),
|
||||
Templates::new(),
|
||||
model.clone(),
|
||||
cx,
|
||||
)
|
||||
});
|
||||
let tool = Arc::new(EditFileTool { thread });
|
||||
|
||||
155
crates/agent2/src/tools/fetch_tool.rs
Normal file
155
crates/agent2/src/tools/fetch_tool.rs
Normal file
@@ -0,0 +1,155 @@
|
||||
use std::rc::Rc;
|
||||
use std::sync::Arc;
|
||||
use std::{borrow::Cow, cell::RefCell};
|
||||
|
||||
use agent_client_protocol as acp;
|
||||
use anyhow::{Context as _, Result, bail};
|
||||
use futures::AsyncReadExt as _;
|
||||
use gpui::{App, AppContext as _, Task};
|
||||
use html_to_markdown::{TagHandler, convert_html_to_markdown, markdown};
|
||||
use http_client::{AsyncBody, HttpClientWithUrl};
|
||||
use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use ui::SharedString;
|
||||
use util::markdown::MarkdownEscaped;
|
||||
|
||||
use crate::{AgentTool, ToolCallEventStream};
|
||||
|
||||
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy)]
|
||||
enum ContentType {
|
||||
Html,
|
||||
Plaintext,
|
||||
Json,
|
||||
}
|
||||
|
||||
/// Fetches a URL and returns the content as Markdown.
|
||||
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
|
||||
pub struct FetchToolInput {
|
||||
/// The URL to fetch.
|
||||
url: String,
|
||||
}
|
||||
|
||||
pub struct FetchTool {
|
||||
http_client: Arc<HttpClientWithUrl>,
|
||||
}
|
||||
|
||||
impl FetchTool {
|
||||
pub fn new(http_client: Arc<HttpClientWithUrl>) -> Self {
|
||||
Self { http_client }
|
||||
}
|
||||
|
||||
async fn build_message(http_client: Arc<HttpClientWithUrl>, url: &str) -> Result<String> {
|
||||
let url = if !url.starts_with("https://") && !url.starts_with("http://") {
|
||||
Cow::Owned(format!("https://{url}"))
|
||||
} else {
|
||||
Cow::Borrowed(url)
|
||||
};
|
||||
|
||||
let mut response = http_client.get(&url, AsyncBody::default(), true).await?;
|
||||
|
||||
let mut body = Vec::new();
|
||||
response
|
||||
.body_mut()
|
||||
.read_to_end(&mut body)
|
||||
.await
|
||||
.context("error reading response body")?;
|
||||
|
||||
if response.status().is_client_error() {
|
||||
let text = String::from_utf8_lossy(body.as_slice());
|
||||
bail!(
|
||||
"status error {}, response: {text:?}",
|
||||
response.status().as_u16()
|
||||
);
|
||||
}
|
||||
|
||||
let Some(content_type) = response.headers().get("content-type") else {
|
||||
bail!("missing Content-Type header");
|
||||
};
|
||||
let content_type = content_type
|
||||
.to_str()
|
||||
.context("invalid Content-Type header")?;
|
||||
|
||||
let content_type = if content_type.starts_with("text/plain") {
|
||||
ContentType::Plaintext
|
||||
} else if content_type.starts_with("application/json") {
|
||||
ContentType::Json
|
||||
} else {
|
||||
ContentType::Html
|
||||
};
|
||||
|
||||
match content_type {
|
||||
ContentType::Html => {
|
||||
let mut handlers: Vec<TagHandler> = vec![
|
||||
Rc::new(RefCell::new(markdown::WebpageChromeRemover)),
|
||||
Rc::new(RefCell::new(markdown::ParagraphHandler)),
|
||||
Rc::new(RefCell::new(markdown::HeadingHandler)),
|
||||
Rc::new(RefCell::new(markdown::ListHandler)),
|
||||
Rc::new(RefCell::new(markdown::TableHandler::new())),
|
||||
Rc::new(RefCell::new(markdown::StyledTextHandler)),
|
||||
];
|
||||
if url.contains("wikipedia.org") {
|
||||
use html_to_markdown::structure::wikipedia;
|
||||
|
||||
handlers.push(Rc::new(RefCell::new(wikipedia::WikipediaChromeRemover)));
|
||||
handlers.push(Rc::new(RefCell::new(wikipedia::WikipediaInfoboxHandler)));
|
||||
handlers.push(Rc::new(
|
||||
RefCell::new(wikipedia::WikipediaCodeHandler::new()),
|
||||
));
|
||||
} else {
|
||||
handlers.push(Rc::new(RefCell::new(markdown::CodeHandler)));
|
||||
}
|
||||
|
||||
convert_html_to_markdown(&body[..], &mut handlers)
|
||||
}
|
||||
ContentType::Plaintext => Ok(std::str::from_utf8(&body)?.to_owned()),
|
||||
ContentType::Json => {
|
||||
let json: serde_json::Value = serde_json::from_slice(&body)?;
|
||||
|
||||
Ok(format!(
|
||||
"```json\n{}\n```",
|
||||
serde_json::to_string_pretty(&json)?
|
||||
))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl AgentTool for FetchTool {
|
||||
type Input = FetchToolInput;
|
||||
type Output = String;
|
||||
|
||||
fn name(&self) -> SharedString {
|
||||
"fetch".into()
|
||||
}
|
||||
|
||||
fn kind(&self) -> acp::ToolKind {
|
||||
acp::ToolKind::Fetch
|
||||
}
|
||||
|
||||
fn initial_title(&self, input: Result<Self::Input, serde_json::Value>) -> SharedString {
|
||||
match input {
|
||||
Ok(input) => format!("Fetch {}", MarkdownEscaped(&input.url)).into(),
|
||||
Err(_) => "Fetch URL".into(),
|
||||
}
|
||||
}
|
||||
|
||||
fn run(
|
||||
self: Arc<Self>,
|
||||
input: Self::Input,
|
||||
_event_stream: ToolCallEventStream,
|
||||
cx: &mut App,
|
||||
) -> Task<Result<Self::Output>> {
|
||||
let text = cx.background_spawn({
|
||||
let http_client = self.http_client.clone();
|
||||
async move { Self::build_message(http_client, &input.url).await }
|
||||
});
|
||||
|
||||
cx.foreground_executor().spawn(async move {
|
||||
let text = text.await?;
|
||||
if text.trim().is_empty() {
|
||||
bail!("no textual content found");
|
||||
}
|
||||
Ok(text)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -139,9 +139,6 @@ impl AgentTool for FindPathTool {
|
||||
})
|
||||
.collect(),
|
||||
),
|
||||
raw_output: Some(serde_json::json!({
|
||||
"paths": &matches,
|
||||
})),
|
||||
..Default::default()
|
||||
});
|
||||
|
||||
|
||||
@@ -101,7 +101,7 @@ impl AgentTool for GrepTool {
|
||||
fn run(
|
||||
self: Arc<Self>,
|
||||
input: Self::Input,
|
||||
event_stream: ToolCallEventStream,
|
||||
_event_stream: ToolCallEventStream,
|
||||
cx: &mut App,
|
||||
) -> Task<Result<Self::Output>> {
|
||||
const CONTEXT_LINES: u32 = 2;
|
||||
@@ -282,33 +282,22 @@ impl AgentTool for GrepTool {
|
||||
}
|
||||
}
|
||||
|
||||
event_stream.update_fields(acp::ToolCallUpdateFields {
|
||||
content: Some(vec![output.clone().into()]),
|
||||
..Default::default()
|
||||
});
|
||||
matches_found += 1;
|
||||
}
|
||||
}
|
||||
|
||||
let output = if matches_found == 0 {
|
||||
"No matches found".to_string()
|
||||
if matches_found == 0 {
|
||||
Ok("No matches found".into())
|
||||
} else if has_more_matches {
|
||||
format!(
|
||||
Ok(format!(
|
||||
"Showing matches {}-{} (there were more matches found; use offset: {} to see next page):\n{output}",
|
||||
input.offset + 1,
|
||||
input.offset + matches_found,
|
||||
input.offset + RESULTS_PER_PAGE,
|
||||
)
|
||||
))
|
||||
} else {
|
||||
format!("Found {matches_found} matches:\n{output}")
|
||||
};
|
||||
|
||||
event_stream.update_fields(acp::ToolCallUpdateFields {
|
||||
content: Some(vec![output.clone().into()]),
|
||||
..Default::default()
|
||||
});
|
||||
|
||||
Ok(output)
|
||||
Ok(format!("Found {matches_found} matches:\n{output}"))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -47,20 +47,13 @@ impl AgentTool for NowTool {
|
||||
fn run(
|
||||
self: Arc<Self>,
|
||||
input: Self::Input,
|
||||
event_stream: ToolCallEventStream,
|
||||
_event_stream: ToolCallEventStream,
|
||||
_cx: &mut App,
|
||||
) -> Task<Result<String>> {
|
||||
let now = match input.timezone {
|
||||
Timezone::Utc => Utc::now().to_rfc3339(),
|
||||
Timezone::Local => Local::now().to_rfc3339(),
|
||||
};
|
||||
let content = format!("The current datetime is {now}.");
|
||||
|
||||
event_stream.update_fields(acp::ToolCallUpdateFields {
|
||||
content: Some(vec![content.clone().into()]),
|
||||
..Default::default()
|
||||
});
|
||||
|
||||
Task::ready(Ok(content))
|
||||
Task::ready(Ok(format!("The current datetime is {now}.")))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
use action_log::ActionLog;
|
||||
use agent_client_protocol::{self as acp};
|
||||
use agent_client_protocol::{self as acp, ToolCallUpdateFields};
|
||||
use anyhow::{Context as _, Result, anyhow};
|
||||
use assistant_tool::outline;
|
||||
use gpui::{App, Entity, SharedString, Task};
|
||||
use indoc::formatdoc;
|
||||
use language::{Anchor, Point};
|
||||
use language::Point;
|
||||
use language_model::{LanguageModelImage, LanguageModelToolResultContent};
|
||||
use project::{AgentLocation, ImageItem, Project, WorktreeSettings, image_store};
|
||||
use schemars::JsonSchema;
|
||||
@@ -97,7 +97,7 @@ impl AgentTool for ReadFileTool {
|
||||
fn run(
|
||||
self: Arc<Self>,
|
||||
input: Self::Input,
|
||||
_event_stream: ToolCallEventStream,
|
||||
event_stream: ToolCallEventStream,
|
||||
cx: &mut App,
|
||||
) -> Task<Result<LanguageModelToolResultContent>> {
|
||||
let Some(project_path) = self.project.read(cx).find_project_path(&input.path, cx) else {
|
||||
@@ -166,7 +166,9 @@ impl AgentTool for ReadFileTool {
|
||||
cx.spawn(async move |cx| {
|
||||
let buffer = cx
|
||||
.update(|cx| {
|
||||
project.update(cx, |project, cx| project.open_buffer(project_path, cx))
|
||||
project.update(cx, |project, cx| {
|
||||
project.open_buffer(project_path.clone(), cx)
|
||||
})
|
||||
})?
|
||||
.await?;
|
||||
if buffer.read_with(cx, |buffer, _| {
|
||||
@@ -178,19 +180,10 @@ impl AgentTool for ReadFileTool {
|
||||
anyhow::bail!("{file_path} not found");
|
||||
}
|
||||
|
||||
project.update(cx, |project, cx| {
|
||||
project.set_agent_location(
|
||||
Some(AgentLocation {
|
||||
buffer: buffer.downgrade(),
|
||||
position: Anchor::MIN,
|
||||
}),
|
||||
cx,
|
||||
);
|
||||
})?;
|
||||
let mut anchor = None;
|
||||
|
||||
// Check if specific line ranges are provided
|
||||
if input.start_line.is_some() || input.end_line.is_some() {
|
||||
let mut anchor = None;
|
||||
let result = if input.start_line.is_some() || input.end_line.is_some() {
|
||||
let result = buffer.read_with(cx, |buffer, _cx| {
|
||||
let text = buffer.text();
|
||||
// .max(1) because despite instructions to be 1-indexed, sometimes the model passes 0.
|
||||
@@ -214,18 +207,6 @@ impl AgentTool for ReadFileTool {
|
||||
log.buffer_read(buffer.clone(), cx);
|
||||
})?;
|
||||
|
||||
if let Some(anchor) = anchor {
|
||||
project.update(cx, |project, cx| {
|
||||
project.set_agent_location(
|
||||
Some(AgentLocation {
|
||||
buffer: buffer.downgrade(),
|
||||
position: anchor,
|
||||
}),
|
||||
cx,
|
||||
);
|
||||
})?;
|
||||
}
|
||||
|
||||
Ok(result.into())
|
||||
} else {
|
||||
// No line ranges specified, so check file size to see if it's too big.
|
||||
@@ -236,7 +217,7 @@ impl AgentTool for ReadFileTool {
|
||||
let result = buffer.read_with(cx, |buffer, _cx| buffer.text())?;
|
||||
|
||||
action_log.update(cx, |log, cx| {
|
||||
log.buffer_read(buffer, cx);
|
||||
log.buffer_read(buffer.clone(), cx);
|
||||
})?;
|
||||
|
||||
Ok(result.into())
|
||||
@@ -244,7 +225,8 @@ impl AgentTool for ReadFileTool {
|
||||
// File is too big, so return the outline
|
||||
// and a suggestion to read again with line numbers.
|
||||
let outline =
|
||||
outline::file_outline(project, file_path, action_log, None, cx).await?;
|
||||
outline::file_outline(project.clone(), file_path, action_log, None, cx)
|
||||
.await?;
|
||||
Ok(formatdoc! {"
|
||||
This file was too big to read all at once.
|
||||
|
||||
@@ -261,7 +243,28 @@ impl AgentTool for ReadFileTool {
|
||||
}
|
||||
.into())
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
project.update(cx, |project, cx| {
|
||||
if let Some(abs_path) = project.absolute_path(&project_path, cx) {
|
||||
project.set_agent_location(
|
||||
Some(AgentLocation {
|
||||
buffer: buffer.downgrade(),
|
||||
position: anchor.unwrap_or(text::Anchor::MIN),
|
||||
}),
|
||||
cx,
|
||||
);
|
||||
event_stream.update_fields(ToolCallUpdateFields {
|
||||
locations: Some(vec![acp::ToolCallLocation {
|
||||
path: abs_path,
|
||||
line: input.start_line.map(|line| line.saturating_sub(1)),
|
||||
}]),
|
||||
..Default::default()
|
||||
});
|
||||
}
|
||||
})?;
|
||||
|
||||
result
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -5,7 +5,9 @@ use agent_client_protocol as acp;
|
||||
use anyhow::{Result, anyhow};
|
||||
use cloud_llm_client::WebSearchResponse;
|
||||
use gpui::{App, AppContext, Task};
|
||||
use language_model::LanguageModelToolResultContent;
|
||||
use language_model::{
|
||||
LanguageModelProviderId, LanguageModelToolResultContent, ZED_CLOUD_PROVIDER_ID,
|
||||
};
|
||||
use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use ui::prelude::*;
|
||||
@@ -50,6 +52,11 @@ impl AgentTool for WebSearchTool {
|
||||
"Searching the Web".into()
|
||||
}
|
||||
|
||||
/// We currently only support Zed Cloud as a provider.
|
||||
fn supported_provider(&self, provider: &LanguageModelProviderId) -> bool {
|
||||
provider == &ZED_CLOUD_PROVIDER_ID
|
||||
}
|
||||
|
||||
fn run(
|
||||
self: Arc<Self>,
|
||||
input: Self::Input,
|
||||
|
||||
@@ -467,6 +467,7 @@ impl AgentConnection for AcpConnection {
|
||||
|
||||
fn prompt(
|
||||
&self,
|
||||
_id: Option<acp_thread::UserMessageId>,
|
||||
params: acp::PromptRequest,
|
||||
cx: &mut App,
|
||||
) -> Task<Result<acp::PromptResponse>> {
|
||||
|
||||
@@ -171,6 +171,7 @@ impl AgentConnection for AcpConnection {
|
||||
|
||||
fn prompt(
|
||||
&self,
|
||||
_id: Option<acp_thread::UserMessageId>,
|
||||
params: acp::PromptRequest,
|
||||
cx: &mut App,
|
||||
) -> Task<Result<acp::PromptResponse>> {
|
||||
|
||||
@@ -210,6 +210,7 @@ impl AgentConnection for ClaudeAgentConnection {
|
||||
|
||||
fn prompt(
|
||||
&self,
|
||||
_id: Option<acp_thread::UserMessageId>,
|
||||
params: acp::PromptRequest,
|
||||
cx: &mut App,
|
||||
) -> Task<Result<acp::PromptResponse>> {
|
||||
@@ -423,7 +424,7 @@ impl ClaudeAgentSession {
|
||||
if !turn_state.borrow().is_cancelled() {
|
||||
thread
|
||||
.update(cx, |thread, cx| {
|
||||
thread.push_user_content_block(text.into(), cx)
|
||||
thread.push_user_content_block(None, text.into(), cx)
|
||||
})
|
||||
.log_err();
|
||||
}
|
||||
|
||||
@@ -48,6 +48,20 @@ pub struct AgentProfileSettings {
|
||||
pub context_servers: IndexMap<Arc<str>, ContextServerPreset>,
|
||||
}
|
||||
|
||||
impl AgentProfileSettings {
|
||||
pub fn is_tool_enabled(&self, tool_name: &str) -> bool {
|
||||
self.tools.get(tool_name) == Some(&true)
|
||||
}
|
||||
|
||||
pub fn is_context_server_tool_enabled(&self, server_id: &str, tool_name: &str) -> bool {
|
||||
self.enable_all_context_servers
|
||||
|| self
|
||||
.context_servers
|
||||
.get(server_id)
|
||||
.map_or(false, |preset| preset.tools.get(tool_name) == Some(&true))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct ContextServerPreset {
|
||||
pub tools: IndexMap<Arc<str>, bool>,
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
mod completion_provider;
|
||||
mod message_history;
|
||||
mod message_editor;
|
||||
mod model_selector;
|
||||
mod model_selector_popover;
|
||||
mod thread_view;
|
||||
|
||||
pub use message_history::MessageHistory;
|
||||
pub use model_selector::AcpModelSelector;
|
||||
pub use model_selector_popover::AcpModelSelectorPopover;
|
||||
pub use thread_view::AcpThreadView;
|
||||
|
||||
@@ -1,18 +1,20 @@
|
||||
use std::ops::Range;
|
||||
use std::path::Path;
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::sync::Arc;
|
||||
use std::sync::atomic::AtomicBool;
|
||||
|
||||
use anyhow::Result;
|
||||
use acp_thread::MentionUri;
|
||||
use anyhow::{Context as _, Result};
|
||||
use collections::HashMap;
|
||||
use editor::display_map::CreaseId;
|
||||
use editor::{CompletionProvider, Editor, ExcerptId};
|
||||
use file_icons::FileIcons;
|
||||
use futures::future::try_join_all;
|
||||
use gpui::{App, Entity, Task, WeakEntity};
|
||||
use language::{Buffer, CodeLabel, HighlightId};
|
||||
use lsp::CompletionContext;
|
||||
use parking_lot::Mutex;
|
||||
use project::{Completion, CompletionIntent, CompletionResponse, ProjectPath, WorktreeId};
|
||||
use project::{Completion, CompletionIntent, CompletionResponse, Project, ProjectPath, WorktreeId};
|
||||
use rope::Point;
|
||||
use text::{Anchor, ToPoint};
|
||||
use ui::prelude::*;
|
||||
@@ -23,21 +25,67 @@ use crate::context_picker::file_context_picker::{extract_file_name_and_directory
|
||||
|
||||
#[derive(Default)]
|
||||
pub struct MentionSet {
|
||||
paths_by_crease_id: HashMap<CreaseId, ProjectPath>,
|
||||
paths_by_crease_id: HashMap<CreaseId, MentionUri>,
|
||||
}
|
||||
|
||||
impl MentionSet {
|
||||
pub fn insert(&mut self, crease_id: CreaseId, path: ProjectPath) {
|
||||
self.paths_by_crease_id.insert(crease_id, path);
|
||||
}
|
||||
|
||||
pub fn path_for_crease_id(&self, crease_id: CreaseId) -> Option<ProjectPath> {
|
||||
self.paths_by_crease_id.get(&crease_id).cloned()
|
||||
pub fn insert(&mut self, crease_id: CreaseId, path: PathBuf) {
|
||||
self.paths_by_crease_id
|
||||
.insert(crease_id, MentionUri::File(path));
|
||||
}
|
||||
|
||||
pub fn drain(&mut self) -> impl Iterator<Item = CreaseId> {
|
||||
self.paths_by_crease_id.drain().map(|(id, _)| id)
|
||||
}
|
||||
|
||||
pub fn clear(&mut self) {
|
||||
self.paths_by_crease_id.clear();
|
||||
}
|
||||
|
||||
pub fn contents(
|
||||
&self,
|
||||
project: Entity<Project>,
|
||||
cx: &mut App,
|
||||
) -> Task<Result<HashMap<CreaseId, Mention>>> {
|
||||
let contents = self
|
||||
.paths_by_crease_id
|
||||
.iter()
|
||||
.map(|(crease_id, uri)| match uri {
|
||||
MentionUri::File(path) => {
|
||||
let crease_id = *crease_id;
|
||||
let uri = uri.clone();
|
||||
let path = path.to_path_buf();
|
||||
let buffer_task = project.update(cx, |project, cx| {
|
||||
let path = project
|
||||
.find_project_path(path, cx)
|
||||
.context("Failed to find project path")?;
|
||||
anyhow::Ok(project.open_buffer(path, cx))
|
||||
});
|
||||
|
||||
cx.spawn(async move |cx| {
|
||||
let buffer = buffer_task?.await?;
|
||||
let content = buffer.read_with(cx, |buffer, _cx| buffer.text())?;
|
||||
|
||||
anyhow::Ok((crease_id, Mention { uri, content }))
|
||||
})
|
||||
}
|
||||
_ => {
|
||||
// TODO
|
||||
unimplemented!()
|
||||
}
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
cx.spawn(async move |_cx| {
|
||||
let contents = try_join_all(contents).await?.into_iter().collect();
|
||||
anyhow::Ok(contents)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
pub struct Mention {
|
||||
pub uri: MentionUri,
|
||||
pub content: String,
|
||||
}
|
||||
|
||||
pub struct ContextPickerCompletionProvider {
|
||||
@@ -68,6 +116,7 @@ impl ContextPickerCompletionProvider {
|
||||
source_range: Range<Anchor>,
|
||||
editor: Entity<Editor>,
|
||||
mention_set: Arc<Mutex<MentionSet>>,
|
||||
project: Entity<Project>,
|
||||
cx: &App,
|
||||
) -> Completion {
|
||||
let (file_name, directory) =
|
||||
@@ -112,6 +161,7 @@ impl ContextPickerCompletionProvider {
|
||||
new_text_len - 1,
|
||||
editor,
|
||||
mention_set,
|
||||
project,
|
||||
)),
|
||||
}
|
||||
}
|
||||
@@ -159,6 +209,7 @@ impl CompletionProvider for ContextPickerCompletionProvider {
|
||||
return Task::ready(Ok(Vec::new()));
|
||||
};
|
||||
|
||||
let project = workspace.read(cx).project().clone();
|
||||
let snapshot = buffer.read(cx).snapshot();
|
||||
let source_range = snapshot.anchor_before(state.source_range.start)
|
||||
..snapshot.anchor_after(state.source_range.end);
|
||||
@@ -195,6 +246,7 @@ impl CompletionProvider for ContextPickerCompletionProvider {
|
||||
source_range.clone(),
|
||||
editor.clone(),
|
||||
mention_set.clone(),
|
||||
project.clone(),
|
||||
cx,
|
||||
)
|
||||
})
|
||||
@@ -254,6 +306,7 @@ fn confirm_completion_callback(
|
||||
content_len: usize,
|
||||
editor: Entity<Editor>,
|
||||
mention_set: Arc<Mutex<MentionSet>>,
|
||||
project: Entity<Project>,
|
||||
) -> Arc<dyn Fn(CompletionIntent, &mut Window, &mut App) -> bool + Send + Sync> {
|
||||
Arc::new(move |_, window, cx| {
|
||||
let crease_text = crease_text.clone();
|
||||
@@ -261,6 +314,7 @@ fn confirm_completion_callback(
|
||||
let editor = editor.clone();
|
||||
let project_path = project_path.clone();
|
||||
let mention_set = mention_set.clone();
|
||||
let project = project.clone();
|
||||
window.defer(cx, move |window, cx| {
|
||||
let crease_id = crate::context_picker::insert_crease_for_mention(
|
||||
excerpt_id,
|
||||
@@ -272,8 +326,13 @@ fn confirm_completion_callback(
|
||||
window,
|
||||
cx,
|
||||
);
|
||||
|
||||
let Some(path) = project.read(cx).absolute_path(&project_path, cx) else {
|
||||
return;
|
||||
};
|
||||
|
||||
if let Some(crease_id) = crease_id {
|
||||
mention_set.lock().insert(crease_id, project_path);
|
||||
mention_set.lock().insert(crease_id, path);
|
||||
}
|
||||
});
|
||||
false
|
||||
|
||||
456
crates/agent_ui/src/acp/message_editor.rs
Normal file
456
crates/agent_ui/src/acp/message_editor.rs
Normal file
@@ -0,0 +1,456 @@
|
||||
use crate::acp::completion_provider::ContextPickerCompletionProvider;
|
||||
use crate::acp::completion_provider::MentionSet;
|
||||
use acp_thread::MentionUri;
|
||||
use agent_client_protocol as acp;
|
||||
use anyhow::Result;
|
||||
use collections::HashSet;
|
||||
use editor::{
|
||||
AnchorRangeExt, ContextMenuOptions, ContextMenuPlacement, Editor, EditorElement, EditorMode,
|
||||
EditorStyle, MultiBuffer,
|
||||
};
|
||||
use file_icons::FileIcons;
|
||||
use gpui::{
|
||||
AppContext, Context, Entity, EventEmitter, FocusHandle, Focusable, Task, TextStyle, WeakEntity,
|
||||
};
|
||||
use language::Buffer;
|
||||
use language::Language;
|
||||
use parking_lot::Mutex;
|
||||
use project::{CompletionIntent, Project};
|
||||
use settings::Settings;
|
||||
use std::path::Path;
|
||||
use std::rc::Rc;
|
||||
use std::sync::Arc;
|
||||
use theme::ThemeSettings;
|
||||
use ui::{
|
||||
ActiveTheme, App, IconName, InteractiveElement, IntoElement, ParentElement, Render,
|
||||
SharedString, Styled, TextSize, Window, div,
|
||||
};
|
||||
use util::ResultExt;
|
||||
use workspace::Workspace;
|
||||
use zed_actions::agent::Chat;
|
||||
|
||||
pub struct MessageEditor {
|
||||
editor: Entity<Editor>,
|
||||
project: Entity<Project>,
|
||||
mention_set: Arc<Mutex<MentionSet>>,
|
||||
}
|
||||
|
||||
pub enum MessageEditorEvent {
|
||||
Send,
|
||||
Cancel,
|
||||
}
|
||||
|
||||
impl EventEmitter<MessageEditorEvent> for MessageEditor {}
|
||||
|
||||
impl MessageEditor {
|
||||
pub fn new(
|
||||
workspace: WeakEntity<Workspace>,
|
||||
project: Entity<Project>,
|
||||
mode: EditorMode,
|
||||
window: &mut Window,
|
||||
cx: &mut Context<Self>,
|
||||
) -> Self {
|
||||
let language = Language::new(
|
||||
language::LanguageConfig {
|
||||
completion_query_characters: HashSet::from_iter(['.', '-', '_', '@']),
|
||||
..Default::default()
|
||||
},
|
||||
None,
|
||||
);
|
||||
|
||||
let mention_set = Arc::new(Mutex::new(MentionSet::default()));
|
||||
let editor = cx.new(|cx| {
|
||||
let buffer = cx.new(|cx| Buffer::local("", cx).with_language(Arc::new(language), cx));
|
||||
let buffer = cx.new(|cx| MultiBuffer::singleton(buffer, cx));
|
||||
|
||||
let mut editor = Editor::new(mode, buffer, None, window, cx);
|
||||
editor.set_placeholder_text("Message the agent - @ to include files", cx);
|
||||
editor.set_show_indent_guides(false, cx);
|
||||
editor.set_soft_wrap();
|
||||
editor.set_use_modal_editing(true);
|
||||
editor.set_completion_provider(Some(Rc::new(ContextPickerCompletionProvider::new(
|
||||
mention_set.clone(),
|
||||
workspace,
|
||||
cx.weak_entity(),
|
||||
))));
|
||||
editor.set_context_menu_options(ContextMenuOptions {
|
||||
min_entries_visible: 12,
|
||||
max_entries_visible: 12,
|
||||
placement: Some(ContextMenuPlacement::Above),
|
||||
});
|
||||
editor
|
||||
});
|
||||
|
||||
Self {
|
||||
editor,
|
||||
project,
|
||||
mention_set,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn is_empty(&self, cx: &App) -> bool {
|
||||
self.editor.read(cx).is_empty(cx)
|
||||
}
|
||||
|
||||
pub fn contents(&self, cx: &mut Context<Self>) -> Task<Result<Vec<acp::ContentBlock>>> {
|
||||
let contents = self.mention_set.lock().contents(self.project.clone(), cx);
|
||||
let editor = self.editor.clone();
|
||||
|
||||
cx.spawn(async move |_, cx| {
|
||||
let contents = contents.await?;
|
||||
|
||||
editor.update(cx, |editor, cx| {
|
||||
let mut ix = 0;
|
||||
let mut chunks: Vec<acp::ContentBlock> = Vec::new();
|
||||
let text = editor.text(cx);
|
||||
editor.display_map.update(cx, |map, cx| {
|
||||
let snapshot = map.snapshot(cx);
|
||||
for (crease_id, crease) in snapshot.crease_snapshot.creases() {
|
||||
// Skip creases that have been edited out of the message buffer.
|
||||
if !crease.range().start.is_valid(&snapshot.buffer_snapshot) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if let Some(mention) = contents.get(&crease_id) {
|
||||
let crease_range = crease.range().to_offset(&snapshot.buffer_snapshot);
|
||||
if crease_range.start > ix {
|
||||
chunks.push(text[ix..crease_range.start].into());
|
||||
}
|
||||
chunks.push(acp::ContentBlock::Resource(acp::EmbeddedResource {
|
||||
annotations: None,
|
||||
resource: acp::EmbeddedResourceResource::TextResourceContents(
|
||||
acp::TextResourceContents {
|
||||
mime_type: None,
|
||||
text: mention.content.clone(),
|
||||
uri: mention.uri.to_uri(),
|
||||
},
|
||||
),
|
||||
}));
|
||||
ix = crease_range.end;
|
||||
}
|
||||
}
|
||||
|
||||
if ix < text.len() {
|
||||
let last_chunk = text[ix..].trim_end();
|
||||
if !last_chunk.is_empty() {
|
||||
chunks.push(last_chunk.into());
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
chunks
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
pub fn clear(&mut self, window: &mut Window, cx: &mut Context<Self>) {
|
||||
self.editor.update(cx, |editor, cx| {
|
||||
editor.clear(window, cx);
|
||||
editor.remove_creases(self.mention_set.lock().drain(), cx)
|
||||
});
|
||||
}
|
||||
|
||||
fn chat(&mut self, _: &Chat, _: &mut Window, cx: &mut Context<Self>) {
|
||||
cx.emit(MessageEditorEvent::Send)
|
||||
}
|
||||
|
||||
fn cancel(&mut self, _: &menu::Cancel, _: &mut Window, cx: &mut Context<Self>) {
|
||||
cx.emit(MessageEditorEvent::Cancel)
|
||||
}
|
||||
|
||||
pub fn insert_dragged_files(
|
||||
&self,
|
||||
paths: Vec<project::ProjectPath>,
|
||||
window: &mut Window,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
let buffer = self.editor.read(cx).buffer().clone();
|
||||
let Some((&excerpt_id, _, _)) = buffer.read(cx).snapshot(cx).as_singleton() else {
|
||||
return;
|
||||
};
|
||||
let Some(buffer) = buffer.read(cx).as_singleton() else {
|
||||
return;
|
||||
};
|
||||
for path in paths {
|
||||
let Some(entry) = self.project.read(cx).entry_for_path(&path, cx) else {
|
||||
continue;
|
||||
};
|
||||
let Some(abs_path) = self.project.read(cx).absolute_path(&path, cx) else {
|
||||
continue;
|
||||
};
|
||||
|
||||
let anchor = buffer.update(cx, |buffer, _cx| buffer.anchor_before(buffer.len()));
|
||||
let path_prefix = abs_path
|
||||
.file_name()
|
||||
.unwrap_or(path.path.as_os_str())
|
||||
.display()
|
||||
.to_string();
|
||||
let completion = ContextPickerCompletionProvider::completion_for_path(
|
||||
path,
|
||||
&path_prefix,
|
||||
false,
|
||||
entry.is_dir(),
|
||||
excerpt_id,
|
||||
anchor..anchor,
|
||||
self.editor.clone(),
|
||||
self.mention_set.clone(),
|
||||
self.project.clone(),
|
||||
cx,
|
||||
);
|
||||
|
||||
self.editor.update(cx, |message_editor, cx| {
|
||||
message_editor.edit(
|
||||
[(
|
||||
multi_buffer::Anchor::max()..multi_buffer::Anchor::max(),
|
||||
completion.new_text,
|
||||
)],
|
||||
cx,
|
||||
);
|
||||
});
|
||||
if let Some(confirm) = completion.confirm.clone() {
|
||||
confirm(CompletionIntent::Complete, window, cx);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn set_mode(&mut self, mode: EditorMode, cx: &mut Context<Self>) {
|
||||
self.editor.update(cx, |editor, cx| {
|
||||
editor.set_mode(mode);
|
||||
cx.notify()
|
||||
});
|
||||
}
|
||||
|
||||
pub fn set_message(
|
||||
&mut self,
|
||||
message: &[acp::ContentBlock],
|
||||
window: &mut Window,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
let mut text = String::new();
|
||||
let mut mentions = Vec::new();
|
||||
|
||||
for chunk in message {
|
||||
match chunk {
|
||||
acp::ContentBlock::Text(text_content) => {
|
||||
text.push_str(&text_content.text);
|
||||
}
|
||||
acp::ContentBlock::Resource(acp::EmbeddedResource {
|
||||
resource: acp::EmbeddedResourceResource::TextResourceContents(resource),
|
||||
..
|
||||
}) => {
|
||||
if let Some(mention) = MentionUri::parse(&resource.uri).log_err() {
|
||||
let project_path = self
|
||||
.project
|
||||
.read(cx)
|
||||
.project_path_for_absolute_path(&abs_path, cx);
|
||||
let start = text.len();
|
||||
write!(text, "{}", mention.as_link());
|
||||
let end = text.len();
|
||||
mentions.push((start..end, project_path, filename));
|
||||
}
|
||||
}
|
||||
acp::ContentBlock::Image(_)
|
||||
| acp::ContentBlock::Audio(_)
|
||||
| acp::ContentBlock::Resource(_)
|
||||
| acp::ContentBlock::ResourceLink(_) => {}
|
||||
}
|
||||
}
|
||||
|
||||
let snapshot = self.editor.update(cx, |editor, cx| {
|
||||
editor.set_text(text, window, cx);
|
||||
editor.buffer().read(cx).snapshot(cx)
|
||||
});
|
||||
|
||||
self.mention_set.lock().clear();
|
||||
for (range, project_path, filename) in mentions {
|
||||
let crease_icon_path = if project_path.path.is_dir() {
|
||||
FileIcons::get_folder_icon(false, cx)
|
||||
.unwrap_or_else(|| IconName::Folder.path().into())
|
||||
} else {
|
||||
FileIcons::get_icon(Path::new(project_path.path.as_ref()), cx)
|
||||
.unwrap_or_else(|| IconName::File.path().into())
|
||||
};
|
||||
|
||||
let anchor = snapshot.anchor_before(range.start);
|
||||
if let Some(project_path) = self.project.read(cx).absolute_path(&project_path, cx) {
|
||||
let crease_id = crate::context_picker::insert_crease_for_mention(
|
||||
anchor.excerpt_id,
|
||||
anchor.text_anchor,
|
||||
range.end - range.start,
|
||||
filename,
|
||||
crease_icon_path,
|
||||
self.editor.clone(),
|
||||
window,
|
||||
cx,
|
||||
);
|
||||
|
||||
if let Some(crease_id) = crease_id {
|
||||
self.mention_set.lock().insert(crease_id, project_path);
|
||||
}
|
||||
}
|
||||
}
|
||||
cx.notify();
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub fn set_text(&mut self, text: &str, window: &mut Window, cx: &mut Context<Self>) {
|
||||
self.editor.update(cx, |editor, cx| {
|
||||
editor.set_text(text, window, cx);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
impl Focusable for MessageEditor {
|
||||
fn focus_handle(&self, cx: &App) -> FocusHandle {
|
||||
self.editor.focus_handle(cx)
|
||||
}
|
||||
}
|
||||
|
||||
impl Render for MessageEditor {
|
||||
fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
|
||||
div()
|
||||
.key_context("MessageEditor")
|
||||
.on_action(cx.listener(Self::chat))
|
||||
.on_action(cx.listener(Self::cancel))
|
||||
.flex_1()
|
||||
.child({
|
||||
let settings = ThemeSettings::get_global(cx);
|
||||
let font_size = TextSize::Small
|
||||
.rems(cx)
|
||||
.to_pixels(settings.agent_font_size(cx));
|
||||
let line_height = settings.buffer_line_height.value() * font_size;
|
||||
|
||||
let text_style = TextStyle {
|
||||
color: cx.theme().colors().text,
|
||||
font_family: settings.buffer_font.family.clone(),
|
||||
font_fallbacks: settings.buffer_font.fallbacks.clone(),
|
||||
font_features: settings.buffer_font.features.clone(),
|
||||
font_size: font_size.into(),
|
||||
line_height: line_height.into(),
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
EditorElement::new(
|
||||
&self.editor,
|
||||
EditorStyle {
|
||||
background: cx.theme().colors().editor_background,
|
||||
local_player: cx.theme().players().local(),
|
||||
text: text_style,
|
||||
syntax: cx.theme().syntax().clone(),
|
||||
..Default::default()
|
||||
},
|
||||
)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::path::Path;
|
||||
|
||||
use agent_client_protocol as acp;
|
||||
use editor::EditorMode;
|
||||
use fs::FakeFs;
|
||||
use gpui::{AppContext, TestAppContext};
|
||||
use lsp::{CompletionContext, CompletionTriggerKind};
|
||||
use project::{CompletionIntent, Project};
|
||||
use serde_json::json;
|
||||
use util::path;
|
||||
use workspace::Workspace;
|
||||
|
||||
use crate::acp::{message_editor::MessageEditor, thread_view::tests::init_test};
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_at_mention_removal(cx: &mut TestAppContext) {
|
||||
init_test(cx);
|
||||
|
||||
let fs = FakeFs::new(cx.executor());
|
||||
fs.insert_tree("/project", json!({"file": ""})).await;
|
||||
let project = Project::test(fs, [Path::new(path!("/project"))], cx).await;
|
||||
|
||||
let (workspace, cx) =
|
||||
cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));
|
||||
|
||||
let message_editor = cx.update(|window, cx| {
|
||||
cx.new(|cx| {
|
||||
MessageEditor::new(
|
||||
workspace.downgrade(),
|
||||
project.clone(),
|
||||
EditorMode::AutoHeight {
|
||||
min_lines: 1,
|
||||
max_lines: None,
|
||||
},
|
||||
window,
|
||||
cx,
|
||||
)
|
||||
})
|
||||
});
|
||||
let editor = message_editor.update(cx, |message_editor, _| message_editor.editor.clone());
|
||||
|
||||
cx.run_until_parked();
|
||||
|
||||
let excerpt_id = editor.update(cx, |editor, cx| {
|
||||
editor
|
||||
.buffer()
|
||||
.read(cx)
|
||||
.excerpt_ids()
|
||||
.into_iter()
|
||||
.next()
|
||||
.unwrap()
|
||||
});
|
||||
let completions = editor.update_in(cx, |editor, window, cx| {
|
||||
editor.set_text("Hello @", window, cx);
|
||||
let buffer = editor.buffer().read(cx).as_singleton().unwrap();
|
||||
let completion_provider = editor.completion_provider().unwrap();
|
||||
completion_provider.completions(
|
||||
excerpt_id,
|
||||
&buffer,
|
||||
text::Anchor::MAX,
|
||||
CompletionContext {
|
||||
trigger_kind: CompletionTriggerKind::TRIGGER_CHARACTER,
|
||||
trigger_character: Some("@".into()),
|
||||
},
|
||||
window,
|
||||
cx,
|
||||
)
|
||||
});
|
||||
let [_, completion]: [_; 2] = completions
|
||||
.await
|
||||
.unwrap()
|
||||
.into_iter()
|
||||
.flat_map(|response| response.completions)
|
||||
.collect::<Vec<_>>()
|
||||
.try_into()
|
||||
.unwrap();
|
||||
|
||||
editor.update_in(cx, |editor, window, cx| {
|
||||
let snapshot = editor.buffer().read(cx).snapshot(cx);
|
||||
let start = snapshot
|
||||
.anchor_in_excerpt(excerpt_id, completion.replace_range.start)
|
||||
.unwrap();
|
||||
let end = snapshot
|
||||
.anchor_in_excerpt(excerpt_id, completion.replace_range.end)
|
||||
.unwrap();
|
||||
editor.edit([(start..end, completion.new_text)], cx);
|
||||
(completion.confirm.unwrap())(CompletionIntent::Complete, window, cx);
|
||||
});
|
||||
|
||||
cx.run_until_parked();
|
||||
|
||||
// Backspace over the inserted crease (and the following space).
|
||||
editor.update_in(cx, |editor, window, cx| {
|
||||
editor.backspace(&Default::default(), window, cx);
|
||||
editor.backspace(&Default::default(), window, cx);
|
||||
});
|
||||
|
||||
let content = message_editor
|
||||
.update_in(cx, |message_editor, _window, cx| {
|
||||
message_editor.contents(cx)
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// We don't send a resource link for the deleted crease.
|
||||
pretty_assertions::assert_matches!(content.as_slice(), [acp::ContentBlock::Text { .. }]);
|
||||
}
|
||||
}
|
||||
@@ -1,92 +0,0 @@
|
||||
pub struct MessageHistory<T> {
|
||||
items: Vec<T>,
|
||||
current: Option<usize>,
|
||||
}
|
||||
|
||||
impl<T> Default for MessageHistory<T> {
|
||||
fn default() -> Self {
|
||||
MessageHistory {
|
||||
items: Vec::new(),
|
||||
current: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> MessageHistory<T> {
|
||||
pub fn push(&mut self, message: T) {
|
||||
self.current.take();
|
||||
self.items.push(message);
|
||||
}
|
||||
|
||||
pub fn reset_position(&mut self) {
|
||||
self.current.take();
|
||||
}
|
||||
|
||||
pub fn prev(&mut self) -> Option<&T> {
|
||||
if self.items.is_empty() {
|
||||
return None;
|
||||
}
|
||||
|
||||
let new_ix = self
|
||||
.current
|
||||
.get_or_insert(self.items.len())
|
||||
.saturating_sub(1);
|
||||
|
||||
self.current = Some(new_ix);
|
||||
self.items.get(new_ix)
|
||||
}
|
||||
|
||||
pub fn next(&mut self) -> Option<&T> {
|
||||
let current = self.current.as_mut()?;
|
||||
*current += 1;
|
||||
|
||||
self.items.get(*current).or_else(|| {
|
||||
self.current.take();
|
||||
None
|
||||
})
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub fn items(&self) -> &[T] {
|
||||
&self.items
|
||||
}
|
||||
}
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_prev_next() {
|
||||
let mut history = MessageHistory::default();
|
||||
|
||||
// Test empty history
|
||||
assert_eq!(history.prev(), None);
|
||||
assert_eq!(history.next(), None);
|
||||
|
||||
// Add some messages
|
||||
history.push("first");
|
||||
history.push("second");
|
||||
history.push("third");
|
||||
|
||||
// Test prev navigation
|
||||
assert_eq!(history.prev(), Some(&"third"));
|
||||
assert_eq!(history.prev(), Some(&"second"));
|
||||
assert_eq!(history.prev(), Some(&"first"));
|
||||
assert_eq!(history.prev(), Some(&"first"));
|
||||
|
||||
assert_eq!(history.next(), Some(&"second"));
|
||||
|
||||
// Test mixed navigation
|
||||
history.push("fourth");
|
||||
assert_eq!(history.prev(), Some(&"fourth"));
|
||||
assert_eq!(history.prev(), Some(&"third"));
|
||||
assert_eq!(history.next(), Some(&"fourth"));
|
||||
assert_eq!(history.next(), None);
|
||||
|
||||
// Test that push resets navigation
|
||||
history.prev();
|
||||
history.prev();
|
||||
history.push("fifth");
|
||||
assert_eq!(history.prev(), Some(&"fifth"));
|
||||
}
|
||||
}
|
||||
472
crates/agent_ui/src/acp/model_selector.rs
Normal file
472
crates/agent_ui/src/acp/model_selector.rs
Normal file
@@ -0,0 +1,472 @@
|
||||
use std::{cmp::Reverse, rc::Rc, sync::Arc};
|
||||
|
||||
use acp_thread::{AgentModelInfo, AgentModelList, AgentModelSelector};
|
||||
use agent_client_protocol as acp;
|
||||
use anyhow::Result;
|
||||
use collections::IndexMap;
|
||||
use futures::FutureExt;
|
||||
use fuzzy::{StringMatchCandidate, match_strings};
|
||||
use gpui::{Action, AsyncWindowContext, BackgroundExecutor, DismissEvent, Task, WeakEntity};
|
||||
use ordered_float::OrderedFloat;
|
||||
use picker::{Picker, PickerDelegate};
|
||||
use ui::{
|
||||
AnyElement, App, Context, IntoElement, ListItem, ListItemSpacing, SharedString, Window,
|
||||
prelude::*, rems,
|
||||
};
|
||||
use util::ResultExt;
|
||||
|
||||
pub type AcpModelSelector = Picker<AcpModelPickerDelegate>;
|
||||
|
||||
pub fn acp_model_selector(
|
||||
session_id: acp::SessionId,
|
||||
selector: Rc<dyn AgentModelSelector>,
|
||||
window: &mut Window,
|
||||
cx: &mut Context<AcpModelSelector>,
|
||||
) -> AcpModelSelector {
|
||||
let delegate = AcpModelPickerDelegate::new(session_id, selector, window, cx);
|
||||
Picker::list(delegate, window, cx)
|
||||
.show_scrollbar(true)
|
||||
.width(rems(20.))
|
||||
.max_height(Some(rems(20.).into()))
|
||||
}
|
||||
|
||||
enum AcpModelPickerEntry {
|
||||
Separator(SharedString),
|
||||
Model(AgentModelInfo),
|
||||
}
|
||||
|
||||
pub struct AcpModelPickerDelegate {
|
||||
session_id: acp::SessionId,
|
||||
selector: Rc<dyn AgentModelSelector>,
|
||||
filtered_entries: Vec<AcpModelPickerEntry>,
|
||||
models: Option<AgentModelList>,
|
||||
selected_index: usize,
|
||||
selected_model: Option<AgentModelInfo>,
|
||||
_refresh_models_task: Task<()>,
|
||||
}
|
||||
|
||||
impl AcpModelPickerDelegate {
|
||||
fn new(
|
||||
session_id: acp::SessionId,
|
||||
selector: Rc<dyn AgentModelSelector>,
|
||||
window: &mut Window,
|
||||
cx: &mut Context<AcpModelSelector>,
|
||||
) -> Self {
|
||||
let mut rx = selector.watch(cx);
|
||||
let refresh_models_task = cx.spawn_in(window, {
|
||||
let session_id = session_id.clone();
|
||||
async move |this, cx| {
|
||||
async fn refresh(
|
||||
this: &WeakEntity<Picker<AcpModelPickerDelegate>>,
|
||||
session_id: &acp::SessionId,
|
||||
cx: &mut AsyncWindowContext,
|
||||
) -> Result<()> {
|
||||
let (models_task, selected_model_task) = this.update(cx, |this, cx| {
|
||||
(
|
||||
this.delegate.selector.list_models(cx),
|
||||
this.delegate.selector.selected_model(session_id, cx),
|
||||
)
|
||||
})?;
|
||||
|
||||
let (models, selected_model) = futures::join!(models_task, selected_model_task);
|
||||
|
||||
this.update_in(cx, |this, window, cx| {
|
||||
this.delegate.models = models.ok();
|
||||
this.delegate.selected_model = selected_model.ok();
|
||||
this.delegate.update_matches(this.query(cx), window, cx)
|
||||
})?
|
||||
.await;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
refresh(&this, &session_id, cx).await.log_err();
|
||||
while let Ok(()) = rx.recv().await {
|
||||
refresh(&this, &session_id, cx).await.log_err();
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
Self {
|
||||
session_id,
|
||||
selector,
|
||||
filtered_entries: Vec::new(),
|
||||
models: None,
|
||||
selected_model: None,
|
||||
selected_index: 0,
|
||||
_refresh_models_task: refresh_models_task,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn active_model(&self) -> Option<&AgentModelInfo> {
|
||||
self.selected_model.as_ref()
|
||||
}
|
||||
}
|
||||
|
||||
impl PickerDelegate for AcpModelPickerDelegate {
|
||||
type ListItem = AnyElement;
|
||||
|
||||
fn match_count(&self) -> usize {
|
||||
self.filtered_entries.len()
|
||||
}
|
||||
|
||||
fn selected_index(&self) -> usize {
|
||||
self.selected_index
|
||||
}
|
||||
|
||||
fn set_selected_index(&mut self, ix: usize, _: &mut Window, cx: &mut Context<Picker<Self>>) {
|
||||
self.selected_index = ix.min(self.filtered_entries.len().saturating_sub(1));
|
||||
cx.notify();
|
||||
}
|
||||
|
||||
fn can_select(
|
||||
&mut self,
|
||||
ix: usize,
|
||||
_window: &mut Window,
|
||||
_cx: &mut Context<Picker<Self>>,
|
||||
) -> bool {
|
||||
match self.filtered_entries.get(ix) {
|
||||
Some(AcpModelPickerEntry::Model(_)) => true,
|
||||
Some(AcpModelPickerEntry::Separator(_)) | None => false,
|
||||
}
|
||||
}
|
||||
|
||||
fn placeholder_text(&self, _window: &mut Window, _cx: &mut App) -> Arc<str> {
|
||||
"Select a model…".into()
|
||||
}
|
||||
|
||||
fn update_matches(
|
||||
&mut self,
|
||||
query: String,
|
||||
window: &mut Window,
|
||||
cx: &mut Context<Picker<Self>>,
|
||||
) -> Task<()> {
|
||||
cx.spawn_in(window, async move |this, cx| {
|
||||
let filtered_models = match this
|
||||
.read_with(cx, |this, cx| {
|
||||
this.delegate.models.clone().map(move |models| {
|
||||
fuzzy_search(models, query, cx.background_executor().clone())
|
||||
})
|
||||
})
|
||||
.ok()
|
||||
.flatten()
|
||||
{
|
||||
Some(task) => task.await,
|
||||
None => AgentModelList::Flat(vec![]),
|
||||
};
|
||||
|
||||
this.update_in(cx, |this, window, cx| {
|
||||
this.delegate.filtered_entries =
|
||||
info_list_to_picker_entries(filtered_models).collect();
|
||||
// Finds the currently selected model in the list
|
||||
let new_index = this
|
||||
.delegate
|
||||
.selected_model
|
||||
.as_ref()
|
||||
.and_then(|selected| {
|
||||
this.delegate.filtered_entries.iter().position(|entry| {
|
||||
if let AcpModelPickerEntry::Model(model_info) = entry {
|
||||
model_info.id == selected.id
|
||||
} else {
|
||||
false
|
||||
}
|
||||
})
|
||||
})
|
||||
.unwrap_or(0);
|
||||
this.set_selected_index(new_index, Some(picker::Direction::Down), true, window, cx);
|
||||
cx.notify();
|
||||
})
|
||||
.ok();
|
||||
})
|
||||
}
|
||||
|
||||
fn confirm(&mut self, _secondary: bool, window: &mut Window, cx: &mut Context<Picker<Self>>) {
|
||||
if let Some(AcpModelPickerEntry::Model(model_info)) =
|
||||
self.filtered_entries.get(self.selected_index)
|
||||
{
|
||||
self.selector
|
||||
.select_model(self.session_id.clone(), model_info.id.clone(), cx)
|
||||
.detach_and_log_err(cx);
|
||||
self.selected_model = Some(model_info.clone());
|
||||
let current_index = self.selected_index;
|
||||
self.set_selected_index(current_index, window, cx);
|
||||
|
||||
cx.emit(DismissEvent);
|
||||
}
|
||||
}
|
||||
|
||||
fn dismissed(&mut self, _: &mut Window, cx: &mut Context<Picker<Self>>) {
|
||||
cx.emit(DismissEvent);
|
||||
}
|
||||
|
||||
fn render_match(
|
||||
&self,
|
||||
ix: usize,
|
||||
selected: bool,
|
||||
_: &mut Window,
|
||||
cx: &mut Context<Picker<Self>>,
|
||||
) -> Option<Self::ListItem> {
|
||||
match self.filtered_entries.get(ix)? {
|
||||
AcpModelPickerEntry::Separator(title) => Some(
|
||||
div()
|
||||
.px_2()
|
||||
.pb_1()
|
||||
.when(ix > 1, |this| {
|
||||
this.mt_1()
|
||||
.pt_2()
|
||||
.border_t_1()
|
||||
.border_color(cx.theme().colors().border_variant)
|
||||
})
|
||||
.child(
|
||||
Label::new(title)
|
||||
.size(LabelSize::XSmall)
|
||||
.color(Color::Muted),
|
||||
)
|
||||
.into_any_element(),
|
||||
),
|
||||
AcpModelPickerEntry::Model(model_info) => {
|
||||
let is_selected = Some(model_info) == self.selected_model.as_ref();
|
||||
|
||||
let model_icon_color = if is_selected {
|
||||
Color::Accent
|
||||
} else {
|
||||
Color::Muted
|
||||
};
|
||||
|
||||
Some(
|
||||
ListItem::new(ix)
|
||||
.inset(true)
|
||||
.spacing(ListItemSpacing::Sparse)
|
||||
.toggle_state(selected)
|
||||
.start_slot::<Icon>(model_info.icon.map(|icon| {
|
||||
Icon::new(icon)
|
||||
.color(model_icon_color)
|
||||
.size(IconSize::Small)
|
||||
}))
|
||||
.child(
|
||||
h_flex()
|
||||
.w_full()
|
||||
.pl_0p5()
|
||||
.gap_1p5()
|
||||
.w(px(240.))
|
||||
.child(Label::new(model_info.name.clone()).truncate()),
|
||||
)
|
||||
.end_slot(div().pr_3().when(is_selected, |this| {
|
||||
this.child(
|
||||
Icon::new(IconName::Check)
|
||||
.color(Color::Accent)
|
||||
.size(IconSize::Small),
|
||||
)
|
||||
}))
|
||||
.into_any_element(),
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn render_footer(
|
||||
&self,
|
||||
_: &mut Window,
|
||||
cx: &mut Context<Picker<Self>>,
|
||||
) -> Option<gpui::AnyElement> {
|
||||
Some(
|
||||
h_flex()
|
||||
.w_full()
|
||||
.border_t_1()
|
||||
.border_color(cx.theme().colors().border_variant)
|
||||
.p_1()
|
||||
.gap_4()
|
||||
.justify_between()
|
||||
.child(
|
||||
Button::new("configure", "Configure")
|
||||
.icon(IconName::Settings)
|
||||
.icon_size(IconSize::Small)
|
||||
.icon_color(Color::Muted)
|
||||
.icon_position(IconPosition::Start)
|
||||
.on_click(|_, window, cx| {
|
||||
window.dispatch_action(
|
||||
zed_actions::agent::OpenSettings.boxed_clone(),
|
||||
cx,
|
||||
);
|
||||
}),
|
||||
)
|
||||
.into_any(),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
fn info_list_to_picker_entries(
|
||||
model_list: AgentModelList,
|
||||
) -> impl Iterator<Item = AcpModelPickerEntry> {
|
||||
match model_list {
|
||||
AgentModelList::Flat(list) => {
|
||||
itertools::Either::Left(list.into_iter().map(AcpModelPickerEntry::Model))
|
||||
}
|
||||
AgentModelList::Grouped(index_map) => {
|
||||
itertools::Either::Right(index_map.into_iter().flat_map(|(group_name, models)| {
|
||||
std::iter::once(AcpModelPickerEntry::Separator(group_name.0))
|
||||
.chain(models.into_iter().map(AcpModelPickerEntry::Model))
|
||||
}))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn fuzzy_search(
|
||||
model_list: AgentModelList,
|
||||
query: String,
|
||||
executor: BackgroundExecutor,
|
||||
) -> AgentModelList {
|
||||
async fn fuzzy_search_list(
|
||||
model_list: Vec<AgentModelInfo>,
|
||||
query: &str,
|
||||
executor: BackgroundExecutor,
|
||||
) -> Vec<AgentModelInfo> {
|
||||
let candidates = model_list
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(ix, model)| {
|
||||
StringMatchCandidate::new(ix, &format!("{}/{}", model.id, model.name))
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
let mut matches = match_strings(
|
||||
&candidates,
|
||||
&query,
|
||||
false,
|
||||
true,
|
||||
100,
|
||||
&Default::default(),
|
||||
executor,
|
||||
)
|
||||
.await;
|
||||
|
||||
matches.sort_unstable_by_key(|mat| {
|
||||
let candidate = &candidates[mat.candidate_id];
|
||||
(Reverse(OrderedFloat(mat.score)), candidate.id)
|
||||
});
|
||||
|
||||
matches
|
||||
.into_iter()
|
||||
.map(|mat| model_list[mat.candidate_id].clone())
|
||||
.collect()
|
||||
}
|
||||
|
||||
match model_list {
|
||||
AgentModelList::Flat(model_list) => {
|
||||
AgentModelList::Flat(fuzzy_search_list(model_list, &query, executor).await)
|
||||
}
|
||||
AgentModelList::Grouped(index_map) => {
|
||||
let groups =
|
||||
futures::future::join_all(index_map.into_iter().map(|(group_name, models)| {
|
||||
fuzzy_search_list(models, &query, executor.clone())
|
||||
.map(|results| (group_name, results))
|
||||
}))
|
||||
.await;
|
||||
AgentModelList::Grouped(IndexMap::from_iter(
|
||||
groups
|
||||
.into_iter()
|
||||
.filter(|(_, results)| !results.is_empty()),
|
||||
))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use gpui::TestAppContext;
|
||||
|
||||
use super::*;
|
||||
|
||||
fn create_model_list(grouped_models: Vec<(&str, Vec<&str>)>) -> AgentModelList {
|
||||
AgentModelList::Grouped(IndexMap::from_iter(grouped_models.into_iter().map(
|
||||
|(group, models)| {
|
||||
(
|
||||
acp_thread::AgentModelGroupName(group.to_string().into()),
|
||||
models
|
||||
.into_iter()
|
||||
.map(|model| acp_thread::AgentModelInfo {
|
||||
id: acp_thread::AgentModelId(model.to_string().into()),
|
||||
name: model.to_string().into(),
|
||||
icon: None,
|
||||
})
|
||||
.collect::<Vec<_>>(),
|
||||
)
|
||||
},
|
||||
)))
|
||||
}
|
||||
|
||||
fn assert_models_eq(result: AgentModelList, expected: Vec<(&str, Vec<&str>)>) {
|
||||
let AgentModelList::Grouped(groups) = result else {
|
||||
panic!("Expected LanguageModelInfoList::Grouped, got {:?}", result);
|
||||
};
|
||||
|
||||
assert_eq!(
|
||||
groups.len(),
|
||||
expected.len(),
|
||||
"Number of groups doesn't match"
|
||||
);
|
||||
|
||||
for (i, (expected_group, expected_models)) in expected.iter().enumerate() {
|
||||
let (actual_group, actual_models) = groups.get_index(i).unwrap();
|
||||
assert_eq!(
|
||||
actual_group.0.as_ref(),
|
||||
*expected_group,
|
||||
"Group at position {} doesn't match expected group",
|
||||
i
|
||||
);
|
||||
assert_eq!(
|
||||
actual_models.len(),
|
||||
expected_models.len(),
|
||||
"Number of models in group {} doesn't match",
|
||||
expected_group
|
||||
);
|
||||
|
||||
for (j, expected_model_name) in expected_models.iter().enumerate() {
|
||||
assert_eq!(
|
||||
actual_models[j].name, *expected_model_name,
|
||||
"Model at position {} in group {} doesn't match expected model",
|
||||
j, expected_group
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_fuzzy_match(cx: &mut TestAppContext) {
|
||||
let models = create_model_list(vec![
|
||||
(
|
||||
"zed",
|
||||
vec![
|
||||
"Claude 3.7 Sonnet",
|
||||
"Claude 3.7 Sonnet Thinking",
|
||||
"gpt-4.1",
|
||||
"gpt-4.1-nano",
|
||||
],
|
||||
),
|
||||
("openai", vec!["gpt-3.5-turbo", "gpt-4.1", "gpt-4.1-nano"]),
|
||||
("ollama", vec!["mistral", "deepseek"]),
|
||||
]);
|
||||
|
||||
// Results should preserve models order whenever possible.
|
||||
// In the case below, `zed/gpt-4.1` and `openai/gpt-4.1` have identical
|
||||
// similarity scores, but `zed/gpt-4.1` was higher in the models list,
|
||||
// so it should appear first in the results.
|
||||
let results = fuzzy_search(models.clone(), "41".into(), cx.executor()).await;
|
||||
assert_models_eq(
|
||||
results,
|
||||
vec![
|
||||
("zed", vec!["gpt-4.1", "gpt-4.1-nano"]),
|
||||
("openai", vec!["gpt-4.1", "gpt-4.1-nano"]),
|
||||
],
|
||||
);
|
||||
|
||||
// Fuzzy search
|
||||
let results = fuzzy_search(models.clone(), "4n".into(), cx.executor()).await;
|
||||
assert_models_eq(
|
||||
results,
|
||||
vec![
|
||||
("zed", vec!["gpt-4.1-nano"]),
|
||||
("openai", vec!["gpt-4.1-nano"]),
|
||||
],
|
||||
);
|
||||
}
|
||||
}
|
||||
85
crates/agent_ui/src/acp/model_selector_popover.rs
Normal file
85
crates/agent_ui/src/acp/model_selector_popover.rs
Normal file
@@ -0,0 +1,85 @@
|
||||
use std::rc::Rc;
|
||||
|
||||
use acp_thread::AgentModelSelector;
|
||||
use agent_client_protocol as acp;
|
||||
use gpui::{Entity, FocusHandle};
|
||||
use picker::popover_menu::PickerPopoverMenu;
|
||||
use ui::{
|
||||
ButtonLike, Context, IntoElement, PopoverMenuHandle, SharedString, Tooltip, Window, prelude::*,
|
||||
};
|
||||
use zed_actions::agent::ToggleModelSelector;
|
||||
|
||||
use crate::acp::{AcpModelSelector, model_selector::acp_model_selector};
|
||||
|
||||
pub struct AcpModelSelectorPopover {
|
||||
selector: Entity<AcpModelSelector>,
|
||||
menu_handle: PopoverMenuHandle<AcpModelSelector>,
|
||||
focus_handle: FocusHandle,
|
||||
}
|
||||
|
||||
impl AcpModelSelectorPopover {
|
||||
pub(crate) fn new(
|
||||
session_id: acp::SessionId,
|
||||
selector: Rc<dyn AgentModelSelector>,
|
||||
menu_handle: PopoverMenuHandle<AcpModelSelector>,
|
||||
focus_handle: FocusHandle,
|
||||
window: &mut Window,
|
||||
cx: &mut Context<Self>,
|
||||
) -> Self {
|
||||
Self {
|
||||
selector: cx.new(move |cx| acp_model_selector(session_id, selector, window, cx)),
|
||||
menu_handle,
|
||||
focus_handle,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn toggle(&self, window: &mut Window, cx: &mut Context<Self>) {
|
||||
self.menu_handle.toggle(window, cx);
|
||||
}
|
||||
}
|
||||
|
||||
impl Render for AcpModelSelectorPopover {
|
||||
fn render(&mut self, window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
|
||||
let model = self.selector.read(cx).delegate.active_model();
|
||||
let model_name = model
|
||||
.as_ref()
|
||||
.map(|model| model.name.clone())
|
||||
.unwrap_or_else(|| SharedString::from("Select a Model"));
|
||||
|
||||
let model_icon = model.as_ref().and_then(|model| model.icon);
|
||||
|
||||
let focus_handle = self.focus_handle.clone();
|
||||
|
||||
PickerPopoverMenu::new(
|
||||
self.selector.clone(),
|
||||
ButtonLike::new("active-model")
|
||||
.when_some(model_icon, |this, icon| {
|
||||
this.child(Icon::new(icon).color(Color::Muted).size(IconSize::XSmall))
|
||||
})
|
||||
.child(
|
||||
Label::new(model_name)
|
||||
.color(Color::Muted)
|
||||
.size(LabelSize::Small)
|
||||
.ml_0p5(),
|
||||
)
|
||||
.child(
|
||||
Icon::new(IconName::ChevronDown)
|
||||
.color(Color::Muted)
|
||||
.size(IconSize::XSmall),
|
||||
),
|
||||
move |window, cx| {
|
||||
Tooltip::for_action_in(
|
||||
"Change Model",
|
||||
&ToggleModelSelector,
|
||||
&focus_handle,
|
||||
window,
|
||||
cx,
|
||||
)
|
||||
},
|
||||
gpui::Corner::BottomRight,
|
||||
cx,
|
||||
)
|
||||
.with_handle(self.menu_handle.clone())
|
||||
.render(window, cx)
|
||||
}
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1521,7 +1521,8 @@ impl AgentDiff {
|
||||
self.update_reviewing_editors(workspace, window, cx);
|
||||
}
|
||||
}
|
||||
AcpThreadEvent::Stopped
|
||||
AcpThreadEvent::EntriesRemoved(_)
|
||||
| AcpThreadEvent::Stopped
|
||||
| AcpThreadEvent::ToolAuthorizationRequired
|
||||
| AcpThreadEvent::Error
|
||||
| AcpThreadEvent::ServerExited(_) => {}
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -64,6 +64,8 @@ actions!(
|
||||
NewTextThread,
|
||||
/// Toggles the context picker interface for adding files, symbols, or other context.
|
||||
ToggleContextPicker,
|
||||
/// Toggles the menu to create new agent threads.
|
||||
ToggleNewThreadMenu,
|
||||
/// Toggles the navigation menu for switching between threads and views.
|
||||
ToggleNavigationMenu,
|
||||
/// Toggles the options menu for agent settings and preferences.
|
||||
@@ -155,11 +157,11 @@ enum ExternalAgent {
|
||||
}
|
||||
|
||||
impl ExternalAgent {
|
||||
pub fn server(&self) -> Rc<dyn agent_servers::AgentServer> {
|
||||
pub fn server(&self, fs: Arc<dyn fs::Fs>) -> Rc<dyn agent_servers::AgentServer> {
|
||||
match self {
|
||||
ExternalAgent::Gemini => Rc::new(agent_servers::Gemini),
|
||||
ExternalAgent::ClaudeCode => Rc::new(agent_servers::ClaudeCode),
|
||||
ExternalAgent::NativeAgent => Rc::new(agent2::NativeAgentServer),
|
||||
ExternalAgent::NativeAgent => Rc::new(agent2::NativeAgentServer::new(fs)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,7 +2,7 @@ mod agent_notification;
|
||||
mod burn_mode_tooltip;
|
||||
mod context_pill;
|
||||
mod end_trial_upsell;
|
||||
mod new_thread_button;
|
||||
// mod new_thread_button;
|
||||
mod onboarding_modal;
|
||||
pub mod preview;
|
||||
|
||||
@@ -10,5 +10,5 @@ pub use agent_notification::*;
|
||||
pub use burn_mode_tooltip::*;
|
||||
pub use context_pill::*;
|
||||
pub use end_trial_upsell::*;
|
||||
pub use new_thread_button::*;
|
||||
// pub use new_thread_button::*;
|
||||
pub use onboarding_modal::*;
|
||||
|
||||
@@ -11,7 +11,7 @@ pub struct NewThreadButton {
|
||||
}
|
||||
|
||||
impl NewThreadButton {
|
||||
pub fn new(id: impl Into<ElementId>, label: impl Into<SharedString>, icon: IconName) -> Self {
|
||||
fn new(id: impl Into<ElementId>, label: impl Into<SharedString>, icon: IconName) -> Self {
|
||||
Self {
|
||||
id: id.into(),
|
||||
label: label.into(),
|
||||
@@ -21,12 +21,12 @@ impl NewThreadButton {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn keybinding(mut self, keybinding: Option<ui::KeyBinding>) -> Self {
|
||||
fn keybinding(mut self, keybinding: Option<ui::KeyBinding>) -> Self {
|
||||
self.keybinding = keybinding;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn on_click<F>(mut self, handler: F) -> Self
|
||||
fn on_click<F>(mut self, handler: F) -> Self
|
||||
where
|
||||
F: Fn(&mut Window, &mut App) + 'static,
|
||||
{
|
||||
|
||||
@@ -58,9 +58,7 @@ impl Assets {
|
||||
pub fn load_test_fonts(&self, cx: &App) {
|
||||
cx.text_system()
|
||||
.add_fonts(vec![
|
||||
self.load("fonts/plex-mono/ZedPlexMono-Regular.ttf")
|
||||
.unwrap()
|
||||
.unwrap(),
|
||||
self.load("fonts/lilex/Lilex-Regular.ttf").unwrap().unwrap(),
|
||||
])
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
@@ -86,7 +86,7 @@ impl Tool for DiagnosticsTool {
|
||||
input: serde_json::Value,
|
||||
_request: Arc<LanguageModelRequest>,
|
||||
project: Entity<Project>,
|
||||
action_log: Entity<ActionLog>,
|
||||
_action_log: Entity<ActionLog>,
|
||||
_model: Arc<dyn LanguageModel>,
|
||||
_window: Option<AnyWindowHandle>,
|
||||
cx: &mut App,
|
||||
@@ -159,10 +159,6 @@ impl Tool for DiagnosticsTool {
|
||||
}
|
||||
}
|
||||
|
||||
action_log.update(cx, |action_log, _cx| {
|
||||
action_log.checked_project_diagnostics();
|
||||
});
|
||||
|
||||
if has_diagnostics {
|
||||
Task::ready(Ok(output.into())).into()
|
||||
} else {
|
||||
|
||||
@@ -65,7 +65,7 @@ pub enum EditAgentOutputEvent {
|
||||
ResolvingEditRange(Range<Anchor>),
|
||||
UnresolvedEditRange,
|
||||
AmbiguousEditRange(Vec<Range<usize>>),
|
||||
Edited,
|
||||
Edited(Range<Anchor>),
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
|
||||
@@ -178,7 +178,9 @@ impl EditAgent {
|
||||
)
|
||||
});
|
||||
output_events_tx
|
||||
.unbounded_send(EditAgentOutputEvent::Edited)
|
||||
.unbounded_send(EditAgentOutputEvent::Edited(
|
||||
language::Anchor::MIN..language::Anchor::MAX,
|
||||
))
|
||||
.ok();
|
||||
})?;
|
||||
|
||||
@@ -200,7 +202,9 @@ impl EditAgent {
|
||||
});
|
||||
})?;
|
||||
output_events_tx
|
||||
.unbounded_send(EditAgentOutputEvent::Edited)
|
||||
.unbounded_send(EditAgentOutputEvent::Edited(
|
||||
language::Anchor::MIN..language::Anchor::MAX,
|
||||
))
|
||||
.ok();
|
||||
}
|
||||
}
|
||||
@@ -336,8 +340,8 @@ impl EditAgent {
|
||||
// Edit the buffer and report edits to the action log as part of the
|
||||
// same effect cycle, otherwise the edit will be reported as if the
|
||||
// user made it.
|
||||
cx.update(|cx| {
|
||||
let max_edit_end = buffer.update(cx, |buffer, cx| {
|
||||
let (min_edit_start, max_edit_end) = cx.update(|cx| {
|
||||
let (min_edit_start, max_edit_end) = buffer.update(cx, |buffer, cx| {
|
||||
buffer.edit(edits.iter().cloned(), None, cx);
|
||||
let max_edit_end = buffer
|
||||
.summaries_for_anchors::<Point, _>(
|
||||
@@ -345,7 +349,16 @@ impl EditAgent {
|
||||
)
|
||||
.max()
|
||||
.unwrap();
|
||||
buffer.anchor_before(max_edit_end)
|
||||
let min_edit_start = buffer
|
||||
.summaries_for_anchors::<Point, _>(
|
||||
edits.iter().map(|(range, _)| &range.start),
|
||||
)
|
||||
.min()
|
||||
.unwrap();
|
||||
(
|
||||
buffer.anchor_after(min_edit_start),
|
||||
buffer.anchor_before(max_edit_end),
|
||||
)
|
||||
});
|
||||
self.action_log
|
||||
.update(cx, |log, cx| log.buffer_edited(buffer.clone(), cx));
|
||||
@@ -358,9 +371,10 @@ impl EditAgent {
|
||||
cx,
|
||||
);
|
||||
});
|
||||
(min_edit_start, max_edit_end)
|
||||
})?;
|
||||
output_events
|
||||
.unbounded_send(EditAgentOutputEvent::Edited)
|
||||
.unbounded_send(EditAgentOutputEvent::Edited(min_edit_start..max_edit_end))
|
||||
.ok();
|
||||
}
|
||||
|
||||
@@ -755,6 +769,7 @@ mod tests {
|
||||
use gpui::{AppContext, TestAppContext};
|
||||
use indoc::indoc;
|
||||
use language_model::fake_provider::FakeLanguageModel;
|
||||
use pretty_assertions::assert_matches;
|
||||
use project::{AgentLocation, Project};
|
||||
use rand::prelude::*;
|
||||
use rand::rngs::StdRng;
|
||||
@@ -992,7 +1007,10 @@ mod tests {
|
||||
|
||||
model.send_last_completion_stream_text_chunk("<new_text>abX");
|
||||
cx.run_until_parked();
|
||||
assert_eq!(drain_events(&mut events), [EditAgentOutputEvent::Edited]);
|
||||
assert_matches!(
|
||||
drain_events(&mut events).as_slice(),
|
||||
[EditAgentOutputEvent::Edited(_)]
|
||||
);
|
||||
assert_eq!(
|
||||
buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
|
||||
"abXc\ndef\nghi\njkl"
|
||||
@@ -1007,7 +1025,10 @@ mod tests {
|
||||
|
||||
model.send_last_completion_stream_text_chunk("cY");
|
||||
cx.run_until_parked();
|
||||
assert_eq!(drain_events(&mut events), [EditAgentOutputEvent::Edited]);
|
||||
assert_matches!(
|
||||
drain_events(&mut events).as_slice(),
|
||||
[EditAgentOutputEvent::Edited { .. }]
|
||||
);
|
||||
assert_eq!(
|
||||
buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
|
||||
"abXcY\ndef\nghi\njkl"
|
||||
@@ -1118,9 +1139,9 @@ mod tests {
|
||||
|
||||
model.send_last_completion_stream_text_chunk("GHI</new_text>");
|
||||
cx.run_until_parked();
|
||||
assert_eq!(
|
||||
drain_events(&mut events),
|
||||
vec![EditAgentOutputEvent::Edited]
|
||||
assert_matches!(
|
||||
drain_events(&mut events).as_slice(),
|
||||
[EditAgentOutputEvent::Edited { .. }]
|
||||
);
|
||||
assert_eq!(
|
||||
buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
|
||||
@@ -1165,9 +1186,9 @@ mod tests {
|
||||
);
|
||||
|
||||
cx.run_until_parked();
|
||||
assert_eq!(
|
||||
drain_events(&mut events),
|
||||
vec![EditAgentOutputEvent::Edited]
|
||||
assert_matches!(
|
||||
drain_events(&mut events).as_slice(),
|
||||
[EditAgentOutputEvent::Edited(_)]
|
||||
);
|
||||
assert_eq!(
|
||||
buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
|
||||
@@ -1183,9 +1204,9 @@ mod tests {
|
||||
|
||||
chunks_tx.unbounded_send("```\njkl\n").unwrap();
|
||||
cx.run_until_parked();
|
||||
assert_eq!(
|
||||
drain_events(&mut events),
|
||||
vec![EditAgentOutputEvent::Edited]
|
||||
assert_matches!(
|
||||
drain_events(&mut events).as_slice(),
|
||||
[EditAgentOutputEvent::Edited { .. }]
|
||||
);
|
||||
assert_eq!(
|
||||
buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
|
||||
@@ -1201,9 +1222,9 @@ mod tests {
|
||||
|
||||
chunks_tx.unbounded_send("mno\n").unwrap();
|
||||
cx.run_until_parked();
|
||||
assert_eq!(
|
||||
drain_events(&mut events),
|
||||
vec![EditAgentOutputEvent::Edited]
|
||||
assert_matches!(
|
||||
drain_events(&mut events).as_slice(),
|
||||
[EditAgentOutputEvent::Edited { .. }]
|
||||
);
|
||||
assert_eq!(
|
||||
buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
|
||||
@@ -1219,9 +1240,9 @@ mod tests {
|
||||
|
||||
chunks_tx.unbounded_send("pqr\n```").unwrap();
|
||||
cx.run_until_parked();
|
||||
assert_eq!(
|
||||
drain_events(&mut events),
|
||||
vec![EditAgentOutputEvent::Edited]
|
||||
assert_matches!(
|
||||
drain_events(&mut events).as_slice(),
|
||||
[EditAgentOutputEvent::Edited(_)],
|
||||
);
|
||||
assert_eq!(
|
||||
buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
|
||||
|
||||
@@ -307,7 +307,7 @@ impl Tool for EditFileTool {
|
||||
let mut ambiguous_ranges = Vec::new();
|
||||
while let Some(event) = events.next().await {
|
||||
match event {
|
||||
EditAgentOutputEvent::Edited => {
|
||||
EditAgentOutputEvent::Edited { .. } => {
|
||||
if let Some(card) = card_clone.as_ref() {
|
||||
card.update(cx, |card, cx| card.update_diff(cx))?;
|
||||
}
|
||||
|
||||
@@ -18,6 +18,6 @@ collections.workspace = true
|
||||
derive_more.workspace = true
|
||||
gpui.workspace = true
|
||||
parking_lot.workspace = true
|
||||
rodio = { version = "0.21.1", default-features = false, features = ["wav", "playback", "tracing"] }
|
||||
rodio = { workspace = true, features = ["wav", "playback", "tracing"] }
|
||||
util.workspace = true
|
||||
workspace-hack.workspace = true
|
||||
|
||||
@@ -59,16 +59,9 @@ pub enum VersionCheckType {
|
||||
pub enum AutoUpdateStatus {
|
||||
Idle,
|
||||
Checking,
|
||||
Downloading {
|
||||
version: VersionCheckType,
|
||||
},
|
||||
Installing {
|
||||
version: VersionCheckType,
|
||||
},
|
||||
Updated {
|
||||
binary_path: PathBuf,
|
||||
version: VersionCheckType,
|
||||
},
|
||||
Downloading { version: VersionCheckType },
|
||||
Installing { version: VersionCheckType },
|
||||
Updated { version: VersionCheckType },
|
||||
Errored,
|
||||
}
|
||||
|
||||
@@ -83,6 +76,7 @@ pub struct AutoUpdater {
|
||||
current_version: SemanticVersion,
|
||||
http_client: Arc<HttpClientWithUrl>,
|
||||
pending_poll: Option<Task<Option<()>>>,
|
||||
quit_subscription: Option<gpui::Subscription>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Clone, Debug)]
|
||||
@@ -164,7 +158,7 @@ pub fn init(http_client: Arc<HttpClientWithUrl>, cx: &mut App) {
|
||||
AutoUpdateSetting::register(cx);
|
||||
|
||||
cx.observe_new(|workspace: &mut Workspace, _window, _cx| {
|
||||
workspace.register_action(|_, action: &Check, window, cx| check(action, window, cx));
|
||||
workspace.register_action(|_, action, window, cx| check(action, window, cx));
|
||||
|
||||
workspace.register_action(|_, action, _, cx| {
|
||||
view_release_notes(action, cx);
|
||||
@@ -174,7 +168,7 @@ pub fn init(http_client: Arc<HttpClientWithUrl>, cx: &mut App) {
|
||||
|
||||
let version = release_channel::AppVersion::global(cx);
|
||||
let auto_updater = cx.new(|cx| {
|
||||
let updater = AutoUpdater::new(version, http_client);
|
||||
let updater = AutoUpdater::new(version, http_client, cx);
|
||||
|
||||
let poll_for_updates = ReleaseChannel::try_global(cx)
|
||||
.map(|channel| channel.poll_for_updates())
|
||||
@@ -321,12 +315,34 @@ impl AutoUpdater {
|
||||
cx.default_global::<GlobalAutoUpdate>().0.clone()
|
||||
}
|
||||
|
||||
fn new(current_version: SemanticVersion, http_client: Arc<HttpClientWithUrl>) -> Self {
|
||||
fn new(
|
||||
current_version: SemanticVersion,
|
||||
http_client: Arc<HttpClientWithUrl>,
|
||||
cx: &mut Context<Self>,
|
||||
) -> Self {
|
||||
// On windows, executable files cannot be overwritten while they are
|
||||
// running, so we must wait to overwrite the application until quitting
|
||||
// or restarting. When quitting the app, we spawn the auto update helper
|
||||
// to finish the auto update process after Zed exits. When restarting
|
||||
// the app after an update, we use `set_restart_path` to run the auto
|
||||
// update helper instead of the app, so that it can overwrite the app
|
||||
// and then spawn the new binary.
|
||||
let quit_subscription = Some(cx.on_app_quit(|_, _| async move {
|
||||
#[cfg(target_os = "windows")]
|
||||
finalize_auto_update_on_quit();
|
||||
}));
|
||||
|
||||
cx.on_app_restart(|this, _| {
|
||||
this.quit_subscription.take();
|
||||
})
|
||||
.detach();
|
||||
|
||||
Self {
|
||||
status: AutoUpdateStatus::Idle,
|
||||
current_version,
|
||||
http_client,
|
||||
pending_poll: None,
|
||||
quit_subscription,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -536,6 +552,8 @@ impl AutoUpdater {
|
||||
)
|
||||
})?;
|
||||
|
||||
Self::check_dependencies()?;
|
||||
|
||||
this.update(&mut cx, |this, cx| {
|
||||
this.status = AutoUpdateStatus::Checking;
|
||||
cx.notify();
|
||||
@@ -582,13 +600,15 @@ impl AutoUpdater {
|
||||
cx.notify();
|
||||
})?;
|
||||
|
||||
let binary_path = Self::binary_path(installer_dir, target_path, &cx).await?;
|
||||
let new_binary_path = Self::install_release(installer_dir, target_path, &cx).await?;
|
||||
if let Some(new_binary_path) = new_binary_path {
|
||||
cx.update(|cx| cx.set_restart_path(new_binary_path))?;
|
||||
}
|
||||
|
||||
this.update(&mut cx, |this, cx| {
|
||||
this.set_should_show_update_notification(true, cx)
|
||||
.detach_and_log_err(cx);
|
||||
this.status = AutoUpdateStatus::Updated {
|
||||
binary_path,
|
||||
version: newer_version,
|
||||
};
|
||||
cx.notify();
|
||||
@@ -639,6 +659,15 @@ impl AutoUpdater {
|
||||
}
|
||||
}
|
||||
|
||||
fn check_dependencies() -> Result<()> {
|
||||
#[cfg(not(target_os = "windows"))]
|
||||
anyhow::ensure!(
|
||||
which::which("rsync").is_ok(),
|
||||
"Aborting. Could not find rsync which is required for auto-updates."
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn target_path(installer_dir: &InstallerDir) -> Result<PathBuf> {
|
||||
let filename = match OS {
|
||||
"macos" => anyhow::Ok("Zed.dmg"),
|
||||
@@ -647,20 +676,14 @@ impl AutoUpdater {
|
||||
unsupported_os => anyhow::bail!("not supported: {unsupported_os}"),
|
||||
}?;
|
||||
|
||||
#[cfg(not(target_os = "windows"))]
|
||||
anyhow::ensure!(
|
||||
which::which("rsync").is_ok(),
|
||||
"Aborting. Could not find rsync which is required for auto-updates."
|
||||
);
|
||||
|
||||
Ok(installer_dir.path().join(filename))
|
||||
}
|
||||
|
||||
async fn binary_path(
|
||||
async fn install_release(
|
||||
installer_dir: InstallerDir,
|
||||
target_path: PathBuf,
|
||||
cx: &AsyncApp,
|
||||
) -> Result<PathBuf> {
|
||||
) -> Result<Option<PathBuf>> {
|
||||
match OS {
|
||||
"macos" => install_release_macos(&installer_dir, target_path, cx).await,
|
||||
"linux" => install_release_linux(&installer_dir, target_path, cx).await,
|
||||
@@ -801,7 +824,7 @@ async fn install_release_linux(
|
||||
temp_dir: &InstallerDir,
|
||||
downloaded_tar_gz: PathBuf,
|
||||
cx: &AsyncApp,
|
||||
) -> Result<PathBuf> {
|
||||
) -> Result<Option<PathBuf>> {
|
||||
let channel = cx.update(|cx| ReleaseChannel::global(cx).dev_name())?;
|
||||
let home_dir = PathBuf::from(env::var("HOME").context("no HOME env var set")?);
|
||||
let running_app_path = cx.update(|cx| cx.app_path())??;
|
||||
@@ -861,14 +884,14 @@ async fn install_release_linux(
|
||||
String::from_utf8_lossy(&output.stderr)
|
||||
);
|
||||
|
||||
Ok(to.join(expected_suffix))
|
||||
Ok(Some(to.join(expected_suffix)))
|
||||
}
|
||||
|
||||
async fn install_release_macos(
|
||||
temp_dir: &InstallerDir,
|
||||
downloaded_dmg: PathBuf,
|
||||
cx: &AsyncApp,
|
||||
) -> Result<PathBuf> {
|
||||
) -> Result<Option<PathBuf>> {
|
||||
let running_app_path = cx.update(|cx| cx.app_path())??;
|
||||
let running_app_filename = running_app_path
|
||||
.file_name()
|
||||
@@ -910,10 +933,10 @@ async fn install_release_macos(
|
||||
String::from_utf8_lossy(&output.stderr)
|
||||
);
|
||||
|
||||
Ok(running_app_path)
|
||||
Ok(None)
|
||||
}
|
||||
|
||||
async fn install_release_windows(downloaded_installer: PathBuf) -> Result<PathBuf> {
|
||||
async fn install_release_windows(downloaded_installer: PathBuf) -> Result<Option<PathBuf>> {
|
||||
let output = Command::new(downloaded_installer)
|
||||
.arg("/verysilent")
|
||||
.arg("/update=true")
|
||||
@@ -926,29 +949,36 @@ async fn install_release_windows(downloaded_installer: PathBuf) -> Result<PathBu
|
||||
"failed to start installer: {:?}",
|
||||
String::from_utf8_lossy(&output.stderr)
|
||||
);
|
||||
Ok(std::env::current_exe()?)
|
||||
// We return the path to the update helper program, because it will
|
||||
// perform the final steps of the update process, copying the new binary,
|
||||
// deleting the old one, and launching the new binary.
|
||||
let helper_path = std::env::current_exe()?
|
||||
.parent()
|
||||
.context("No parent dir for Zed.exe")?
|
||||
.join("tools\\auto_update_helper.exe");
|
||||
Ok(Some(helper_path))
|
||||
}
|
||||
|
||||
pub fn check_pending_installation() -> bool {
|
||||
pub fn finalize_auto_update_on_quit() {
|
||||
let Some(installer_path) = std::env::current_exe()
|
||||
.ok()
|
||||
.and_then(|p| p.parent().map(|p| p.join("updates")))
|
||||
else {
|
||||
return false;
|
||||
return;
|
||||
};
|
||||
|
||||
// The installer will create a flag file after it finishes updating
|
||||
let flag_file = installer_path.join("versions.txt");
|
||||
if flag_file.exists() {
|
||||
if let Some(helper) = installer_path
|
||||
if flag_file.exists()
|
||||
&& let Some(helper) = installer_path
|
||||
.parent()
|
||||
.map(|p| p.join("tools\\auto_update_helper.exe"))
|
||||
{
|
||||
let _ = std::process::Command::new(helper).spawn();
|
||||
return true;
|
||||
}
|
||||
{
|
||||
let mut command = std::process::Command::new(helper);
|
||||
command.arg("--launch");
|
||||
command.arg("false");
|
||||
let _ = command.spawn();
|
||||
}
|
||||
false
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
@@ -1002,7 +1032,6 @@ mod tests {
|
||||
let app_commit_sha = Ok(Some("a".to_string()));
|
||||
let installed_version = SemanticVersion::new(1, 0, 0);
|
||||
let status = AutoUpdateStatus::Updated {
|
||||
binary_path: PathBuf::new(),
|
||||
version: VersionCheckType::Semantic(SemanticVersion::new(1, 0, 1)),
|
||||
};
|
||||
let fetched_version = SemanticVersion::new(1, 0, 1);
|
||||
@@ -1024,7 +1053,6 @@ mod tests {
|
||||
let app_commit_sha = Ok(Some("a".to_string()));
|
||||
let installed_version = SemanticVersion::new(1, 0, 0);
|
||||
let status = AutoUpdateStatus::Updated {
|
||||
binary_path: PathBuf::new(),
|
||||
version: VersionCheckType::Semantic(SemanticVersion::new(1, 0, 1)),
|
||||
};
|
||||
let fetched_version = SemanticVersion::new(1, 0, 2);
|
||||
@@ -1090,7 +1118,6 @@ mod tests {
|
||||
let app_commit_sha = Ok(Some("a".to_string()));
|
||||
let installed_version = SemanticVersion::new(1, 0, 0);
|
||||
let status = AutoUpdateStatus::Updated {
|
||||
binary_path: PathBuf::new(),
|
||||
version: VersionCheckType::Sha(AppCommitSha::new("b".to_string())),
|
||||
};
|
||||
let fetched_sha = "b".to_string();
|
||||
@@ -1112,7 +1139,6 @@ mod tests {
|
||||
let app_commit_sha = Ok(Some("a".to_string()));
|
||||
let installed_version = SemanticVersion::new(1, 0, 0);
|
||||
let status = AutoUpdateStatus::Updated {
|
||||
binary_path: PathBuf::new(),
|
||||
version: VersionCheckType::Sha(AppCommitSha::new("b".to_string())),
|
||||
};
|
||||
let fetched_sha = "c".to_string();
|
||||
@@ -1160,7 +1186,6 @@ mod tests {
|
||||
let app_commit_sha = Ok(None);
|
||||
let installed_version = SemanticVersion::new(1, 0, 0);
|
||||
let status = AutoUpdateStatus::Updated {
|
||||
binary_path: PathBuf::new(),
|
||||
version: VersionCheckType::Sha(AppCommitSha::new("b".to_string())),
|
||||
};
|
||||
let fetched_sha = "b".to_string();
|
||||
@@ -1183,7 +1208,6 @@ mod tests {
|
||||
let app_commit_sha = Ok(None);
|
||||
let installed_version = SemanticVersion::new(1, 0, 0);
|
||||
let status = AutoUpdateStatus::Updated {
|
||||
binary_path: PathBuf::new(),
|
||||
version: VersionCheckType::Sha(AppCommitSha::new("b".to_string())),
|
||||
};
|
||||
let fetched_sha = "c".to_string();
|
||||
|
||||
@@ -37,6 +37,11 @@ mod windows_impl {
|
||||
pub(crate) const WM_JOB_UPDATED: u32 = WM_USER + 1;
|
||||
pub(crate) const WM_TERMINATE: u32 = WM_USER + 2;
|
||||
|
||||
#[derive(Debug)]
|
||||
struct Args {
|
||||
launch: Option<bool>,
|
||||
}
|
||||
|
||||
pub(crate) fn run() -> Result<()> {
|
||||
let helper_dir = std::env::current_exe()?
|
||||
.parent()
|
||||
@@ -51,8 +56,9 @@ mod windows_impl {
|
||||
log::info!("======= Starting Zed update =======");
|
||||
let (tx, rx) = std::sync::mpsc::channel();
|
||||
let hwnd = create_dialog_window(rx)?.0 as isize;
|
||||
let args = parse_args();
|
||||
std::thread::spawn(move || {
|
||||
let result = perform_update(app_dir.as_path(), Some(hwnd));
|
||||
let result = perform_update(app_dir.as_path(), Some(hwnd), args.launch.unwrap_or(true));
|
||||
tx.send(result).ok();
|
||||
unsafe { PostMessageW(Some(HWND(hwnd as _)), WM_TERMINATE, WPARAM(0), LPARAM(0)) }.ok();
|
||||
});
|
||||
@@ -77,6 +83,41 @@ mod windows_impl {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn parse_args() -> Args {
|
||||
let mut result = Args { launch: None };
|
||||
if let Some(candidate) = std::env::args().nth(1) {
|
||||
parse_single_arg(&candidate, &mut result);
|
||||
}
|
||||
|
||||
result
|
||||
}
|
||||
|
||||
fn parse_single_arg(arg: &str, result: &mut Args) {
|
||||
let Some((key, value)) = arg.strip_prefix("--").and_then(|arg| arg.split_once('=')) else {
|
||||
log::error!(
|
||||
"Invalid argument format: '{}'. Expected format: --key=value",
|
||||
arg
|
||||
);
|
||||
return;
|
||||
};
|
||||
|
||||
match key {
|
||||
"launch" => parse_launch_arg(value, &mut result.launch),
|
||||
_ => log::error!("Unknown argument: --{}", key),
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_launch_arg(value: &str, arg: &mut Option<bool>) {
|
||||
match value {
|
||||
"true" => *arg = Some(true),
|
||||
"false" => *arg = Some(false),
|
||||
_ => log::error!(
|
||||
"Invalid value for --launch: '{}'. Expected 'true' or 'false'",
|
||||
value
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn show_error(mut content: String) {
|
||||
if content.len() > 600 {
|
||||
content.truncate(600);
|
||||
@@ -91,4 +132,47 @@ mod windows_impl {
|
||||
)
|
||||
};
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::windows_impl::{Args, parse_launch_arg, parse_single_arg};
|
||||
|
||||
#[test]
|
||||
fn test_parse_launch_arg() {
|
||||
let mut arg = None;
|
||||
parse_launch_arg("true", &mut arg);
|
||||
assert_eq!(arg, Some(true));
|
||||
|
||||
let mut arg = None;
|
||||
parse_launch_arg("false", &mut arg);
|
||||
assert_eq!(arg, Some(false));
|
||||
|
||||
let mut arg = None;
|
||||
parse_launch_arg("invalid", &mut arg);
|
||||
assert_eq!(arg, None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_single_arg() {
|
||||
let mut args = Args { launch: None };
|
||||
parse_single_arg("--launch=true", &mut args);
|
||||
assert_eq!(args.launch, Some(true));
|
||||
|
||||
let mut args = Args { launch: None };
|
||||
parse_single_arg("--launch=false", &mut args);
|
||||
assert_eq!(args.launch, Some(false));
|
||||
|
||||
let mut args = Args { launch: None };
|
||||
parse_single_arg("--launch=invalid", &mut args);
|
||||
assert_eq!(args.launch, None);
|
||||
|
||||
let mut args = Args { launch: None };
|
||||
parse_single_arg("--launch", &mut args);
|
||||
assert_eq!(args.launch, None);
|
||||
|
||||
let mut args = Args { launch: None };
|
||||
parse_single_arg("--unknown", &mut args);
|
||||
assert_eq!(args.launch, None);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -72,7 +72,7 @@ pub(crate) fn create_dialog_window(receiver: Receiver<Result<()>>) -> Result<HWN
|
||||
let hwnd = CreateWindowExW(
|
||||
WS_EX_TOPMOST,
|
||||
class_name,
|
||||
windows::core::w!("Zed Editor"),
|
||||
windows::core::w!("Zed"),
|
||||
WS_VISIBLE | WS_POPUP | WS_CAPTION,
|
||||
rect.right / 2 - width / 2,
|
||||
rect.bottom / 2 - height / 2,
|
||||
@@ -171,7 +171,7 @@ unsafe extern "system" fn wnd_proc(
|
||||
&HSTRING::from(font_name),
|
||||
);
|
||||
let temp = SelectObject(hdc, font.into());
|
||||
let string = HSTRING::from("Zed Editor is updating...");
|
||||
let string = HSTRING::from("Updating Zed...");
|
||||
return_if_failed!(TextOutW(hdc, 20, 15, &string).ok());
|
||||
return_if_failed!(DeleteObject(temp).ok());
|
||||
|
||||
|
||||
@@ -118,7 +118,7 @@ pub(crate) const JOBS: [Job; 2] = [
|
||||
},
|
||||
];
|
||||
|
||||
pub(crate) fn perform_update(app_dir: &Path, hwnd: Option<isize>) -> Result<()> {
|
||||
pub(crate) fn perform_update(app_dir: &Path, hwnd: Option<isize>, launch: bool) -> Result<()> {
|
||||
let hwnd = hwnd.map(|ptr| HWND(ptr as _));
|
||||
|
||||
for job in JOBS.iter() {
|
||||
@@ -145,9 +145,11 @@ pub(crate) fn perform_update(app_dir: &Path, hwnd: Option<isize>) -> Result<()>
|
||||
}
|
||||
}
|
||||
}
|
||||
let _ = std::process::Command::new(app_dir.join("Zed.exe"))
|
||||
.creation_flags(CREATE_NEW_PROCESS_GROUP.0)
|
||||
.spawn();
|
||||
if launch {
|
||||
let _ = std::process::Command::new(app_dir.join("Zed.exe"))
|
||||
.creation_flags(CREATE_NEW_PROCESS_GROUP.0)
|
||||
.spawn();
|
||||
}
|
||||
log::info!("Update completed successfully");
|
||||
Ok(())
|
||||
}
|
||||
@@ -159,11 +161,11 @@ mod test {
|
||||
#[test]
|
||||
fn test_perform_update() {
|
||||
let app_dir = std::path::Path::new("C:/");
|
||||
assert!(perform_update(app_dir, None).is_ok());
|
||||
assert!(perform_update(app_dir, None, false).is_ok());
|
||||
|
||||
// Simulate a timeout
|
||||
unsafe { std::env::set_var("ZED_AUTO_UPDATE", "err") };
|
||||
let ret = perform_update(app_dir, None);
|
||||
let ret = perform_update(app_dir, None, false);
|
||||
assert!(ret.is_err_and(|e| e.to_string().as_str() == "Timed out"));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -957,17 +957,14 @@ mod mac_os {
|
||||
) -> Result<()> {
|
||||
use anyhow::bail;
|
||||
|
||||
let app_id_prompt = format!("id of app \"{}\"", channel.display_name());
|
||||
let app_id_output = Command::new("osascript")
|
||||
let app_path_prompt = format!(
|
||||
"POSIX path of (path to application \"{}\")",
|
||||
channel.display_name()
|
||||
);
|
||||
let app_path_output = Command::new("osascript")
|
||||
.arg("-e")
|
||||
.arg(&app_id_prompt)
|
||||
.arg(&app_path_prompt)
|
||||
.output()?;
|
||||
if !app_id_output.status.success() {
|
||||
bail!("Could not determine app id for {}", channel.display_name());
|
||||
}
|
||||
let app_name = String::from_utf8(app_id_output.stdout)?.trim().to_owned();
|
||||
let app_path_prompt = format!("kMDItemCFBundleIdentifier == '{app_name}'");
|
||||
let app_path_output = Command::new("mdfind").arg(app_path_prompt).output()?;
|
||||
if !app_path_output.status.success() {
|
||||
bail!(
|
||||
"Could not determine app path for {}",
|
||||
|
||||
@@ -340,22 +340,35 @@ impl Telemetry {
|
||||
}
|
||||
|
||||
pub fn log_edit_event(self: &Arc<Self>, environment: &'static str, is_via_ssh: bool) {
|
||||
static LAST_EVENT_TIME: Mutex<Option<Instant>> = Mutex::new(None);
|
||||
|
||||
let mut state = self.state.lock();
|
||||
let period_data = state.event_coalescer.log_event(environment);
|
||||
drop(state);
|
||||
|
||||
if let Some((start, end, environment)) = period_data {
|
||||
let duration = end
|
||||
.saturating_duration_since(start)
|
||||
.min(Duration::from_secs(60 * 60 * 24))
|
||||
.as_millis() as i64;
|
||||
if let Some(mut last_event) = LAST_EVENT_TIME.try_lock() {
|
||||
let current_time = std::time::Instant::now();
|
||||
let last_time = last_event.get_or_insert(current_time);
|
||||
|
||||
telemetry::event!(
|
||||
"Editor Edited",
|
||||
duration = duration,
|
||||
environment = environment,
|
||||
is_via_ssh = is_via_ssh
|
||||
);
|
||||
if current_time.duration_since(*last_time) > Duration::from_secs(60 * 10) {
|
||||
*last_time = current_time;
|
||||
} else {
|
||||
return;
|
||||
}
|
||||
|
||||
if let Some((start, end, environment)) = period_data {
|
||||
let duration = end
|
||||
.saturating_duration_since(start)
|
||||
.min(Duration::from_secs(60 * 60 * 24))
|
||||
.as_millis() as i64;
|
||||
|
||||
telemetry::event!(
|
||||
"Editor Edited",
|
||||
duration = duration,
|
||||
environment = environment,
|
||||
is_via_ssh = is_via_ssh
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -21,7 +21,7 @@ use language::{
|
||||
point_from_lsp, point_to_lsp,
|
||||
};
|
||||
use lsp::{LanguageServer, LanguageServerBinary, LanguageServerId, LanguageServerName};
|
||||
use node_runtime::NodeRuntime;
|
||||
use node_runtime::{NodeRuntime, VersionCheck};
|
||||
use parking_lot::Mutex;
|
||||
use project::DisableAiSettings;
|
||||
use request::StatusNotification;
|
||||
@@ -1169,9 +1169,8 @@ async fn get_copilot_lsp(fs: Arc<dyn Fs>, node_runtime: NodeRuntime) -> anyhow::
|
||||
const SERVER_PATH: &str =
|
||||
"node_modules/@github/copilot-language-server/dist/language-server.js";
|
||||
|
||||
let latest_version = node_runtime
|
||||
.npm_package_latest_version(PACKAGE_NAME)
|
||||
.await?;
|
||||
// pinning it: https://github.com/zed-industries/zed/issues/36093
|
||||
const PINNED_VERSION: &str = "1.354";
|
||||
let server_path = paths::copilot_dir().join(SERVER_PATH);
|
||||
|
||||
fs.create_dir(paths::copilot_dir()).await?;
|
||||
@@ -1181,12 +1180,13 @@ async fn get_copilot_lsp(fs: Arc<dyn Fs>, node_runtime: NodeRuntime) -> anyhow::
|
||||
PACKAGE_NAME,
|
||||
&server_path,
|
||||
paths::copilot_dir(),
|
||||
&latest_version,
|
||||
&PINNED_VERSION,
|
||||
VersionCheck::VersionMismatch,
|
||||
)
|
||||
.await;
|
||||
if should_install {
|
||||
node_runtime
|
||||
.npm_install_packages(paths::copilot_dir(), &[(PACKAGE_NAME, &latest_version)])
|
||||
.npm_install_packages(paths::copilot_dir(), &[(PACKAGE_NAME, &PINNED_VERSION)])
|
||||
.await?;
|
||||
}
|
||||
|
||||
|
||||
@@ -2290,8 +2290,6 @@ mod tests {
|
||||
fn test_blocks_on_wrapped_lines(cx: &mut gpui::TestAppContext) {
|
||||
cx.update(init_test);
|
||||
|
||||
let _font_id = cx.text_system().font_id(&font("Helvetica")).unwrap();
|
||||
|
||||
let text = "one two three\nfour five six\nseven eight";
|
||||
|
||||
let buffer = cx.update(|cx| MultiBuffer::build_simple(text, cx));
|
||||
|
||||
@@ -1223,7 +1223,7 @@ mod tests {
|
||||
let tab_size = NonZeroU32::new(rng.gen_range(1..=4)).unwrap();
|
||||
|
||||
let font = test_font();
|
||||
let _font_id = text_system.font_id(&font);
|
||||
let _font_id = text_system.resolve_font(&font);
|
||||
let font_size = px(14.0);
|
||||
|
||||
log::info!("Tab size: {}", tab_size);
|
||||
|
||||
@@ -250,6 +250,24 @@ pub type RenderDiffHunkControlsFn = Arc<
|
||||
) -> AnyElement,
|
||||
>;
|
||||
|
||||
enum ReportEditorEvent {
|
||||
Saved { auto_saved: bool },
|
||||
EditorOpened,
|
||||
ZetaTosClicked,
|
||||
Closed,
|
||||
}
|
||||
|
||||
impl ReportEditorEvent {
|
||||
pub fn event_type(&self) -> &'static str {
|
||||
match self {
|
||||
Self::Saved { .. } => "Editor Saved",
|
||||
Self::EditorOpened => "Editor Opened",
|
||||
Self::ZetaTosClicked => "Edit Prediction Provider ToS Clicked",
|
||||
Self::Closed => "Editor Closed",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct InlineValueCache {
|
||||
enabled: bool,
|
||||
inlays: Vec<InlayId>,
|
||||
@@ -2325,7 +2343,7 @@ impl Editor {
|
||||
}
|
||||
|
||||
if editor.mode.is_full() {
|
||||
editor.report_editor_event("Editor Opened", None, cx);
|
||||
editor.report_editor_event(ReportEditorEvent::EditorOpened, None, cx);
|
||||
}
|
||||
|
||||
editor
|
||||
@@ -9124,7 +9142,7 @@ impl Editor {
|
||||
.on_mouse_down(MouseButton::Left, |_, window, _| window.prevent_default())
|
||||
.on_click(cx.listener(|this, _event, window, cx| {
|
||||
cx.stop_propagation();
|
||||
this.report_editor_event("Edit Prediction Provider ToS Clicked", None, cx);
|
||||
this.report_editor_event(ReportEditorEvent::ZetaTosClicked, None, cx);
|
||||
window.dispatch_action(
|
||||
zed_actions::OpenZedPredictOnboarding.boxed_clone(),
|
||||
cx,
|
||||
@@ -20547,7 +20565,7 @@ impl Editor {
|
||||
|
||||
fn report_editor_event(
|
||||
&self,
|
||||
event_type: &'static str,
|
||||
reported_event: ReportEditorEvent,
|
||||
file_extension: Option<String>,
|
||||
cx: &App,
|
||||
) {
|
||||
@@ -20581,15 +20599,30 @@ impl Editor {
|
||||
.show_edit_predictions;
|
||||
|
||||
let project = project.read(cx);
|
||||
telemetry::event!(
|
||||
event_type,
|
||||
file_extension,
|
||||
vim_mode,
|
||||
copilot_enabled,
|
||||
copilot_enabled_for_language,
|
||||
edit_predictions_provider,
|
||||
is_via_ssh = project.is_via_ssh(),
|
||||
);
|
||||
let event_type = reported_event.event_type();
|
||||
|
||||
if let ReportEditorEvent::Saved { auto_saved } = reported_event {
|
||||
telemetry::event!(
|
||||
event_type,
|
||||
type = if auto_saved {"autosave"} else {"manual"},
|
||||
file_extension,
|
||||
vim_mode,
|
||||
copilot_enabled,
|
||||
copilot_enabled_for_language,
|
||||
edit_predictions_provider,
|
||||
is_via_ssh = project.is_via_ssh(),
|
||||
);
|
||||
} else {
|
||||
telemetry::event!(
|
||||
event_type,
|
||||
file_extension,
|
||||
vim_mode,
|
||||
copilot_enabled,
|
||||
copilot_enabled_for_language,
|
||||
edit_predictions_provider,
|
||||
is_via_ssh = project.is_via_ssh(),
|
||||
);
|
||||
};
|
||||
}
|
||||
|
||||
/// Copy the highlighted chunks to the clipboard as JSON. The format is an array of lines,
|
||||
|
||||
@@ -22456,7 +22456,7 @@ async fn test_invisible_worktree_servers(cx: &mut TestAppContext) {
|
||||
);
|
||||
|
||||
cx.update(|_, cx| {
|
||||
workspace::reload(&workspace::Reload::default(), cx);
|
||||
workspace::reload(cx);
|
||||
});
|
||||
assert_language_servers_count(
|
||||
1,
|
||||
|
||||
@@ -3011,7 +3011,7 @@ impl EditorElement {
|
||||
.icon_color(Color::Custom(cx.theme().colors().editor_line_number))
|
||||
.selected_icon_color(Color::Custom(cx.theme().colors().editor_foreground))
|
||||
.icon_size(IconSize::Custom(rems(editor_font_size / window.rem_size())))
|
||||
.width(width.into())
|
||||
.width(width)
|
||||
.on_click(move |_, window, cx| {
|
||||
editor.update(cx, |editor, cx| {
|
||||
editor.expand_excerpt(excerpt_id, direction, window, cx);
|
||||
@@ -3627,7 +3627,7 @@ impl EditorElement {
|
||||
ButtonLike::new("toggle-buffer-fold")
|
||||
.style(ui::ButtonStyle::Transparent)
|
||||
.height(px(28.).into())
|
||||
.width(px(28.).into())
|
||||
.width(px(28.))
|
||||
.children(toggle_chevron_icon)
|
||||
.tooltip({
|
||||
let focus_handle = focus_handle.clone();
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
use crate::{
|
||||
Anchor, Autoscroll, Editor, EditorEvent, EditorSettings, ExcerptId, ExcerptRange, FormatTarget,
|
||||
MultiBuffer, MultiBufferSnapshot, NavigationData, SearchWithinRange, SelectionEffects,
|
||||
ToPoint as _,
|
||||
MultiBuffer, MultiBufferSnapshot, NavigationData, ReportEditorEvent, SearchWithinRange,
|
||||
SelectionEffects, ToPoint as _,
|
||||
display_map::HighlightKey,
|
||||
editor_settings::SeedQuerySetting,
|
||||
persistence::{DB, SerializedEditor},
|
||||
@@ -776,6 +776,10 @@ impl Item for Editor {
|
||||
}
|
||||
}
|
||||
|
||||
fn on_removed(&self, cx: &App) {
|
||||
self.report_editor_event(ReportEditorEvent::Closed, None, cx);
|
||||
}
|
||||
|
||||
fn deactivated(&mut self, _: &mut Window, cx: &mut Context<Self>) {
|
||||
let selection = self.selections.newest_anchor();
|
||||
self.push_to_nav_history(selection.head(), None, true, false, cx);
|
||||
@@ -815,9 +819,9 @@ impl Item for Editor {
|
||||
) -> Task<Result<()>> {
|
||||
// Add meta data tracking # of auto saves
|
||||
if options.autosave {
|
||||
self.report_editor_event("Editor Autosaved", None, cx);
|
||||
self.report_editor_event(ReportEditorEvent::Saved { auto_saved: true }, None, cx);
|
||||
} else {
|
||||
self.report_editor_event("Editor Saved", None, cx);
|
||||
self.report_editor_event(ReportEditorEvent::Saved { auto_saved: false }, None, cx);
|
||||
}
|
||||
|
||||
let buffers = self.buffer().clone().read(cx).all_buffers();
|
||||
@@ -896,7 +900,11 @@ impl Item for Editor {
|
||||
.path
|
||||
.extension()
|
||||
.map(|a| a.to_string_lossy().to_string());
|
||||
self.report_editor_event("Editor Saved", file_extension, cx);
|
||||
self.report_editor_event(
|
||||
ReportEditorEvent::Saved { auto_saved: false },
|
||||
file_extension,
|
||||
cx,
|
||||
);
|
||||
|
||||
project.update(cx, |project, cx| project.save_buffer_as(buffer, path, cx))
|
||||
}
|
||||
@@ -997,12 +1005,16 @@ impl Item for Editor {
|
||||
) {
|
||||
self.workspace = Some((workspace.weak_handle(), workspace.database_id()));
|
||||
if let Some(workspace) = &workspace.weak_handle().upgrade() {
|
||||
cx.subscribe(&workspace, |editor, _, event: &workspace::Event, _cx| {
|
||||
if matches!(event, workspace::Event::ModalOpened) {
|
||||
editor.mouse_context_menu.take();
|
||||
editor.inline_blame_popover.take();
|
||||
}
|
||||
})
|
||||
cx.subscribe(
|
||||
&workspace,
|
||||
|editor, _, event: &workspace::Event, _cx| match event {
|
||||
workspace::Event::ModalOpened => {
|
||||
editor.mouse_context_menu.take();
|
||||
editor.inline_blame_popover.take();
|
||||
}
|
||||
_ => {}
|
||||
},
|
||||
)
|
||||
.detach();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -53,7 +53,7 @@ pub fn marked_display_snapshot(
|
||||
let (unmarked_text, markers) = marked_text_offsets(text);
|
||||
|
||||
let font = Font {
|
||||
family: "Zed Plex Mono".into(),
|
||||
family: ".ZedMono".into(),
|
||||
features: FontFeatures::default(),
|
||||
fallbacks: None,
|
||||
weight: FontWeight::default(),
|
||||
|
||||
@@ -1118,15 +1118,17 @@ impl ExtensionStore {
|
||||
extensions_to_unload.len() - reload_count
|
||||
);
|
||||
|
||||
for extension_id in &extensions_to_load {
|
||||
if let Some(extension) = new_index.extensions.get(extension_id) {
|
||||
telemetry::event!(
|
||||
"Extension Loaded",
|
||||
extension_id,
|
||||
version = extension.manifest.version
|
||||
);
|
||||
}
|
||||
}
|
||||
let extension_ids = extensions_to_load
|
||||
.iter()
|
||||
.filter_map(|id| {
|
||||
Some((
|
||||
id.clone(),
|
||||
new_index.extensions.get(id)?.manifest.version.clone(),
|
||||
))
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
telemetry::event!("Extensions Loaded", id_and_versions = extension_ids);
|
||||
|
||||
let themes_to_remove = old_index
|
||||
.themes
|
||||
|
||||
@@ -33,13 +33,23 @@ impl FileIcons {
|
||||
// TODO: Associate a type with the languages and have the file's language
|
||||
// override these associations
|
||||
|
||||
// check if file name is in suffixes
|
||||
// e.g. catch file named `eslint.config.js` instead of `.eslint.config.js`
|
||||
if let Some(typ) = path.file_name().and_then(|typ| typ.to_str()) {
|
||||
if let Some(mut typ) = path.file_name().and_then(|typ| typ.to_str()) {
|
||||
// check if file name is in suffixes
|
||||
// e.g. catch file named `eslint.config.js` instead of `.eslint.config.js`
|
||||
let maybe_path = get_icon_from_suffix(typ);
|
||||
if maybe_path.is_some() {
|
||||
return maybe_path;
|
||||
}
|
||||
|
||||
// check if suffix based on first dot is in suffixes
|
||||
// e.g. consider `module.js` as suffix to angular's module file named `auth.module.js`
|
||||
while let Some((_, suffix)) = typ.split_once('.') {
|
||||
let maybe_path = get_icon_from_suffix(suffix);
|
||||
if maybe_path.is_some() {
|
||||
return maybe_path;
|
||||
}
|
||||
typ = suffix;
|
||||
}
|
||||
}
|
||||
|
||||
// primary case: check if the files extension or the hidden file name
|
||||
|
||||
@@ -51,6 +51,7 @@ ashpd.workspace = true
|
||||
|
||||
[dev-dependencies]
|
||||
gpui = { workspace = true, features = ["test-support"] }
|
||||
git = { workspace = true, features = ["test-support"] }
|
||||
|
||||
[features]
|
||||
test-support = ["gpui/test-support", "git/test-support"]
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
use crate::{FakeFs, Fs};
|
||||
use crate::{FakeFs, FakeFsEntry, Fs};
|
||||
use anyhow::{Context as _, Result};
|
||||
use collections::{HashMap, HashSet};
|
||||
use futures::future::{self, BoxFuture, join_all};
|
||||
use git::{
|
||||
Oid,
|
||||
blame::Blame,
|
||||
repository::{
|
||||
AskPassDelegate, Branch, CommitDetails, CommitOptions, FetchOptions, GitRepository,
|
||||
@@ -10,8 +11,9 @@ use git::{
|
||||
},
|
||||
status::{FileStatus, GitStatus, StatusCode, TrackedStatus, UnmergedStatus},
|
||||
};
|
||||
use gpui::{AsyncApp, BackgroundExecutor, SharedString};
|
||||
use gpui::{AsyncApp, BackgroundExecutor, SharedString, Task};
|
||||
use ignore::gitignore::GitignoreBuilder;
|
||||
use parking_lot::Mutex;
|
||||
use rope::Rope;
|
||||
use smol::future::FutureExt as _;
|
||||
use std::{path::PathBuf, sync::Arc};
|
||||
@@ -19,6 +21,7 @@ use std::{path::PathBuf, sync::Arc};
|
||||
#[derive(Clone)]
|
||||
pub struct FakeGitRepository {
|
||||
pub(crate) fs: Arc<FakeFs>,
|
||||
pub(crate) checkpoints: Arc<Mutex<HashMap<Oid, FakeFsEntry>>>,
|
||||
pub(crate) executor: BackgroundExecutor,
|
||||
pub(crate) dot_git_path: PathBuf,
|
||||
pub(crate) repository_dir_path: PathBuf,
|
||||
@@ -183,7 +186,7 @@ impl GitRepository for FakeGitRepository {
|
||||
async move { None }.boxed()
|
||||
}
|
||||
|
||||
fn status(&self, path_prefixes: &[RepoPath]) -> BoxFuture<'_, Result<GitStatus>> {
|
||||
fn status(&self, path_prefixes: &[RepoPath]) -> Task<Result<GitStatus>> {
|
||||
let workdir_path = self.dot_git_path.parent().unwrap();
|
||||
|
||||
// Load gitignores
|
||||
@@ -311,7 +314,10 @@ impl GitRepository for FakeGitRepository {
|
||||
entries: entries.into(),
|
||||
})
|
||||
});
|
||||
async move { result? }.boxed()
|
||||
Task::ready(match result {
|
||||
Ok(result) => result,
|
||||
Err(e) => Err(e),
|
||||
})
|
||||
}
|
||||
|
||||
fn branches(&self) -> BoxFuture<'_, Result<Vec<Branch>>> {
|
||||
@@ -466,22 +472,57 @@ impl GitRepository for FakeGitRepository {
|
||||
}
|
||||
|
||||
fn checkpoint(&self) -> BoxFuture<'static, Result<GitRepositoryCheckpoint>> {
|
||||
unimplemented!()
|
||||
let executor = self.executor.clone();
|
||||
let fs = self.fs.clone();
|
||||
let checkpoints = self.checkpoints.clone();
|
||||
let repository_dir_path = self.repository_dir_path.parent().unwrap().to_path_buf();
|
||||
async move {
|
||||
executor.simulate_random_delay().await;
|
||||
let oid = Oid::random(&mut executor.rng());
|
||||
let entry = fs.entry(&repository_dir_path)?;
|
||||
checkpoints.lock().insert(oid, entry);
|
||||
Ok(GitRepositoryCheckpoint { commit_sha: oid })
|
||||
}
|
||||
.boxed()
|
||||
}
|
||||
|
||||
fn restore_checkpoint(
|
||||
&self,
|
||||
_checkpoint: GitRepositoryCheckpoint,
|
||||
) -> BoxFuture<'_, Result<()>> {
|
||||
unimplemented!()
|
||||
fn restore_checkpoint(&self, checkpoint: GitRepositoryCheckpoint) -> BoxFuture<'_, Result<()>> {
|
||||
let executor = self.executor.clone();
|
||||
let fs = self.fs.clone();
|
||||
let checkpoints = self.checkpoints.clone();
|
||||
let repository_dir_path = self.repository_dir_path.parent().unwrap().to_path_buf();
|
||||
async move {
|
||||
executor.simulate_random_delay().await;
|
||||
let checkpoints = checkpoints.lock();
|
||||
let entry = checkpoints
|
||||
.get(&checkpoint.commit_sha)
|
||||
.context(format!("invalid checkpoint: {}", checkpoint.commit_sha))?;
|
||||
fs.insert_entry(&repository_dir_path, entry.clone())?;
|
||||
Ok(())
|
||||
}
|
||||
.boxed()
|
||||
}
|
||||
|
||||
fn compare_checkpoints(
|
||||
&self,
|
||||
_left: GitRepositoryCheckpoint,
|
||||
_right: GitRepositoryCheckpoint,
|
||||
left: GitRepositoryCheckpoint,
|
||||
right: GitRepositoryCheckpoint,
|
||||
) -> BoxFuture<'_, Result<bool>> {
|
||||
unimplemented!()
|
||||
let executor = self.executor.clone();
|
||||
let checkpoints = self.checkpoints.clone();
|
||||
async move {
|
||||
executor.simulate_random_delay().await;
|
||||
let checkpoints = checkpoints.lock();
|
||||
let left = checkpoints
|
||||
.get(&left.commit_sha)
|
||||
.context(format!("invalid left checkpoint: {}", left.commit_sha))?;
|
||||
let right = checkpoints
|
||||
.get(&right.commit_sha)
|
||||
.context(format!("invalid right checkpoint: {}", right.commit_sha))?;
|
||||
|
||||
Ok(left == right)
|
||||
}
|
||||
.boxed()
|
||||
}
|
||||
|
||||
fn diff_checkpoints(
|
||||
@@ -496,3 +537,63 @@ impl GitRepository for FakeGitRepository {
|
||||
unimplemented!()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::{FakeFs, Fs};
|
||||
use gpui::BackgroundExecutor;
|
||||
use serde_json::json;
|
||||
use std::path::Path;
|
||||
use util::path;
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_checkpoints(executor: BackgroundExecutor) {
|
||||
let fs = FakeFs::new(executor);
|
||||
fs.insert_tree(
|
||||
path!("/"),
|
||||
json!({
|
||||
"bar": {
|
||||
"baz": "qux"
|
||||
},
|
||||
"foo": {
|
||||
".git": {},
|
||||
"a": "lorem",
|
||||
"b": "ipsum",
|
||||
},
|
||||
}),
|
||||
)
|
||||
.await;
|
||||
fs.with_git_state(Path::new("/foo/.git"), true, |_git| {})
|
||||
.unwrap();
|
||||
let repository = fs.open_repo(Path::new("/foo/.git")).unwrap();
|
||||
|
||||
let checkpoint_1 = repository.checkpoint().await.unwrap();
|
||||
fs.write(Path::new("/foo/b"), b"IPSUM").await.unwrap();
|
||||
fs.write(Path::new("/foo/c"), b"dolor").await.unwrap();
|
||||
let checkpoint_2 = repository.checkpoint().await.unwrap();
|
||||
let checkpoint_3 = repository.checkpoint().await.unwrap();
|
||||
|
||||
assert!(
|
||||
repository
|
||||
.compare_checkpoints(checkpoint_2.clone(), checkpoint_3.clone())
|
||||
.await
|
||||
.unwrap()
|
||||
);
|
||||
assert!(
|
||||
!repository
|
||||
.compare_checkpoints(checkpoint_1.clone(), checkpoint_2.clone())
|
||||
.await
|
||||
.unwrap()
|
||||
);
|
||||
|
||||
repository.restore_checkpoint(checkpoint_1).await.unwrap();
|
||||
assert_eq!(
|
||||
fs.files_with_contents(Path::new("")),
|
||||
[
|
||||
(Path::new("/bar/baz").into(), b"qux".into()),
|
||||
(Path::new("/foo/a").into(), b"lorem".into()),
|
||||
(Path::new("/foo/b").into(), b"ipsum".into())
|
||||
]
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -924,7 +924,7 @@ pub struct FakeFs {
|
||||
|
||||
#[cfg(any(test, feature = "test-support"))]
|
||||
struct FakeFsState {
|
||||
root: Arc<Mutex<FakeFsEntry>>,
|
||||
root: FakeFsEntry,
|
||||
next_inode: u64,
|
||||
next_mtime: SystemTime,
|
||||
git_event_tx: smol::channel::Sender<PathBuf>,
|
||||
@@ -939,7 +939,7 @@ struct FakeFsState {
|
||||
}
|
||||
|
||||
#[cfg(any(test, feature = "test-support"))]
|
||||
#[derive(Debug)]
|
||||
#[derive(Clone, Debug)]
|
||||
enum FakeFsEntry {
|
||||
File {
|
||||
inode: u64,
|
||||
@@ -953,7 +953,7 @@ enum FakeFsEntry {
|
||||
inode: u64,
|
||||
mtime: MTime,
|
||||
len: u64,
|
||||
entries: BTreeMap<String, Arc<Mutex<FakeFsEntry>>>,
|
||||
entries: BTreeMap<String, FakeFsEntry>,
|
||||
git_repo_state: Option<Arc<Mutex<FakeGitRepositoryState>>>,
|
||||
},
|
||||
Symlink {
|
||||
@@ -961,6 +961,67 @@ enum FakeFsEntry {
|
||||
},
|
||||
}
|
||||
|
||||
#[cfg(any(test, feature = "test-support"))]
|
||||
impl PartialEq for FakeFsEntry {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
match (self, other) {
|
||||
(
|
||||
Self::File {
|
||||
inode: l_inode,
|
||||
mtime: l_mtime,
|
||||
len: l_len,
|
||||
content: l_content,
|
||||
git_dir_path: l_git_dir_path,
|
||||
},
|
||||
Self::File {
|
||||
inode: r_inode,
|
||||
mtime: r_mtime,
|
||||
len: r_len,
|
||||
content: r_content,
|
||||
git_dir_path: r_git_dir_path,
|
||||
},
|
||||
) => {
|
||||
l_inode == r_inode
|
||||
&& l_mtime == r_mtime
|
||||
&& l_len == r_len
|
||||
&& l_content == r_content
|
||||
&& l_git_dir_path == r_git_dir_path
|
||||
}
|
||||
(
|
||||
Self::Dir {
|
||||
inode: l_inode,
|
||||
mtime: l_mtime,
|
||||
len: l_len,
|
||||
entries: l_entries,
|
||||
git_repo_state: l_git_repo_state,
|
||||
},
|
||||
Self::Dir {
|
||||
inode: r_inode,
|
||||
mtime: r_mtime,
|
||||
len: r_len,
|
||||
entries: r_entries,
|
||||
git_repo_state: r_git_repo_state,
|
||||
},
|
||||
) => {
|
||||
let same_repo_state = match (l_git_repo_state.as_ref(), r_git_repo_state.as_ref()) {
|
||||
(Some(l), Some(r)) => Arc::ptr_eq(l, r),
|
||||
(None, None) => true,
|
||||
_ => false,
|
||||
};
|
||||
l_inode == r_inode
|
||||
&& l_mtime == r_mtime
|
||||
&& l_len == r_len
|
||||
&& l_entries == r_entries
|
||||
&& same_repo_state
|
||||
}
|
||||
(Self::Symlink { target: l_target }, Self::Symlink { target: r_target }) => {
|
||||
l_target == r_target
|
||||
}
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(any(test, feature = "test-support"))]
|
||||
impl FakeFsState {
|
||||
fn get_and_increment_mtime(&mut self) -> MTime {
|
||||
@@ -975,25 +1036,9 @@ impl FakeFsState {
|
||||
inode
|
||||
}
|
||||
|
||||
fn read_path(&self, target: &Path) -> Result<Arc<Mutex<FakeFsEntry>>> {
|
||||
Ok(self
|
||||
.try_read_path(target, true)
|
||||
.ok_or_else(|| {
|
||||
anyhow!(io::Error::new(
|
||||
io::ErrorKind::NotFound,
|
||||
format!("not found: {target:?}")
|
||||
))
|
||||
})?
|
||||
.0)
|
||||
}
|
||||
|
||||
fn try_read_path(
|
||||
&self,
|
||||
target: &Path,
|
||||
follow_symlink: bool,
|
||||
) -> Option<(Arc<Mutex<FakeFsEntry>>, PathBuf)> {
|
||||
let mut path = target.to_path_buf();
|
||||
fn canonicalize(&self, target: &Path, follow_symlink: bool) -> Option<PathBuf> {
|
||||
let mut canonical_path = PathBuf::new();
|
||||
let mut path = target.to_path_buf();
|
||||
let mut entry_stack = Vec::new();
|
||||
'outer: loop {
|
||||
let mut path_components = path.components().peekable();
|
||||
@@ -1003,7 +1048,7 @@ impl FakeFsState {
|
||||
Component::Prefix(prefix_component) => prefix = Some(prefix_component),
|
||||
Component::RootDir => {
|
||||
entry_stack.clear();
|
||||
entry_stack.push(self.root.clone());
|
||||
entry_stack.push(&self.root);
|
||||
canonical_path.clear();
|
||||
match prefix {
|
||||
Some(prefix_component) => {
|
||||
@@ -1020,20 +1065,18 @@ impl FakeFsState {
|
||||
canonical_path.pop();
|
||||
}
|
||||
Component::Normal(name) => {
|
||||
let current_entry = entry_stack.last().cloned()?;
|
||||
let current_entry = current_entry.lock();
|
||||
if let FakeFsEntry::Dir { entries, .. } = &*current_entry {
|
||||
let entry = entries.get(name.to_str().unwrap()).cloned()?;
|
||||
let current_entry = *entry_stack.last()?;
|
||||
if let FakeFsEntry::Dir { entries, .. } = current_entry {
|
||||
let entry = entries.get(name.to_str().unwrap())?;
|
||||
if path_components.peek().is_some() || follow_symlink {
|
||||
let entry = entry.lock();
|
||||
if let FakeFsEntry::Symlink { target, .. } = &*entry {
|
||||
if let FakeFsEntry::Symlink { target, .. } = entry {
|
||||
let mut target = target.clone();
|
||||
target.extend(path_components);
|
||||
path = target;
|
||||
continue 'outer;
|
||||
}
|
||||
}
|
||||
entry_stack.push(entry.clone());
|
||||
entry_stack.push(entry);
|
||||
canonical_path = canonical_path.join(name);
|
||||
} else {
|
||||
return None;
|
||||
@@ -1043,19 +1086,72 @@ impl FakeFsState {
|
||||
}
|
||||
break;
|
||||
}
|
||||
Some((entry_stack.pop()?, canonical_path))
|
||||
|
||||
if entry_stack.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(canonical_path)
|
||||
}
|
||||
}
|
||||
|
||||
fn write_path<Fn, T>(&self, path: &Path, callback: Fn) -> Result<T>
|
||||
fn try_entry(
|
||||
&mut self,
|
||||
target: &Path,
|
||||
follow_symlink: bool,
|
||||
) -> Option<(&mut FakeFsEntry, PathBuf)> {
|
||||
let canonical_path = self.canonicalize(target, follow_symlink)?;
|
||||
|
||||
let mut components = canonical_path.components();
|
||||
let Some(Component::RootDir) = components.next() else {
|
||||
panic!(
|
||||
"the path {:?} was not canonicalized properly {:?}",
|
||||
target, canonical_path
|
||||
)
|
||||
};
|
||||
|
||||
let mut entry = &mut self.root;
|
||||
for component in components {
|
||||
match component {
|
||||
Component::Normal(name) => {
|
||||
if let FakeFsEntry::Dir { entries, .. } = entry {
|
||||
entry = entries.get_mut(name.to_str().unwrap())?;
|
||||
} else {
|
||||
return None;
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
panic!(
|
||||
"the path {:?} was not canonicalized properly {:?}",
|
||||
target, canonical_path
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Some((entry, canonical_path))
|
||||
}
|
||||
|
||||
fn entry(&mut self, target: &Path) -> Result<&mut FakeFsEntry> {
|
||||
Ok(self
|
||||
.try_entry(target, true)
|
||||
.ok_or_else(|| {
|
||||
anyhow!(io::Error::new(
|
||||
io::ErrorKind::NotFound,
|
||||
format!("not found: {target:?}")
|
||||
))
|
||||
})?
|
||||
.0)
|
||||
}
|
||||
|
||||
fn write_path<Fn, T>(&mut self, path: &Path, callback: Fn) -> Result<T>
|
||||
where
|
||||
Fn: FnOnce(btree_map::Entry<String, Arc<Mutex<FakeFsEntry>>>) -> Result<T>,
|
||||
Fn: FnOnce(btree_map::Entry<String, FakeFsEntry>) -> Result<T>,
|
||||
{
|
||||
let path = normalize_path(path);
|
||||
let filename = path.file_name().context("cannot overwrite the root")?;
|
||||
let parent_path = path.parent().unwrap();
|
||||
|
||||
let parent = self.read_path(parent_path)?;
|
||||
let mut parent = parent.lock();
|
||||
let parent = self.entry(parent_path)?;
|
||||
let new_entry = parent
|
||||
.dir_entries(parent_path)?
|
||||
.entry(filename.to_str().unwrap().into());
|
||||
@@ -1105,13 +1201,13 @@ impl FakeFs {
|
||||
this: this.clone(),
|
||||
executor: executor.clone(),
|
||||
state: Arc::new(Mutex::new(FakeFsState {
|
||||
root: Arc::new(Mutex::new(FakeFsEntry::Dir {
|
||||
root: FakeFsEntry::Dir {
|
||||
inode: 0,
|
||||
mtime: MTime(UNIX_EPOCH),
|
||||
len: 0,
|
||||
entries: Default::default(),
|
||||
git_repo_state: None,
|
||||
})),
|
||||
},
|
||||
git_event_tx: tx,
|
||||
next_mtime: UNIX_EPOCH + Self::SYSTEMTIME_INTERVAL,
|
||||
next_inode: 1,
|
||||
@@ -1161,15 +1257,15 @@ impl FakeFs {
|
||||
.write_path(path, move |entry| {
|
||||
match entry {
|
||||
btree_map::Entry::Vacant(e) => {
|
||||
e.insert(Arc::new(Mutex::new(FakeFsEntry::File {
|
||||
e.insert(FakeFsEntry::File {
|
||||
inode: new_inode,
|
||||
mtime: new_mtime,
|
||||
content: Vec::new(),
|
||||
len: 0,
|
||||
git_dir_path: None,
|
||||
})));
|
||||
});
|
||||
}
|
||||
btree_map::Entry::Occupied(mut e) => match &mut *e.get_mut().lock() {
|
||||
btree_map::Entry::Occupied(mut e) => match &mut *e.get_mut() {
|
||||
FakeFsEntry::File { mtime, .. } => *mtime = new_mtime,
|
||||
FakeFsEntry::Dir { mtime, .. } => *mtime = new_mtime,
|
||||
FakeFsEntry::Symlink { .. } => {}
|
||||
@@ -1188,7 +1284,7 @@ impl FakeFs {
|
||||
pub async fn insert_symlink(&self, path: impl AsRef<Path>, target: PathBuf) {
|
||||
let mut state = self.state.lock();
|
||||
let path = path.as_ref();
|
||||
let file = Arc::new(Mutex::new(FakeFsEntry::Symlink { target }));
|
||||
let file = FakeFsEntry::Symlink { target };
|
||||
state
|
||||
.write_path(path.as_ref(), move |e| match e {
|
||||
btree_map::Entry::Vacant(e) => {
|
||||
@@ -1221,13 +1317,13 @@ impl FakeFs {
|
||||
match entry {
|
||||
btree_map::Entry::Vacant(e) => {
|
||||
kind = Some(PathEventKind::Created);
|
||||
e.insert(Arc::new(Mutex::new(FakeFsEntry::File {
|
||||
e.insert(FakeFsEntry::File {
|
||||
inode: new_inode,
|
||||
mtime: new_mtime,
|
||||
len: new_len,
|
||||
content: new_content,
|
||||
git_dir_path: None,
|
||||
})));
|
||||
});
|
||||
}
|
||||
btree_map::Entry::Occupied(mut e) => {
|
||||
kind = Some(PathEventKind::Changed);
|
||||
@@ -1237,7 +1333,7 @@ impl FakeFs {
|
||||
len,
|
||||
content,
|
||||
..
|
||||
} = &mut *e.get_mut().lock()
|
||||
} = e.get_mut()
|
||||
{
|
||||
*mtime = new_mtime;
|
||||
*content = new_content;
|
||||
@@ -1259,9 +1355,8 @@ impl FakeFs {
|
||||
pub fn read_file_sync(&self, path: impl AsRef<Path>) -> Result<Vec<u8>> {
|
||||
let path = path.as_ref();
|
||||
let path = normalize_path(path);
|
||||
let state = self.state.lock();
|
||||
let entry = state.read_path(&path)?;
|
||||
let entry = entry.lock();
|
||||
let mut state = self.state.lock();
|
||||
let entry = state.entry(&path)?;
|
||||
entry.file_content(&path).cloned()
|
||||
}
|
||||
|
||||
@@ -1269,9 +1364,8 @@ impl FakeFs {
|
||||
let path = path.as_ref();
|
||||
let path = normalize_path(path);
|
||||
self.simulate_random_delay().await;
|
||||
let state = self.state.lock();
|
||||
let entry = state.read_path(&path)?;
|
||||
let entry = entry.lock();
|
||||
let mut state = self.state.lock();
|
||||
let entry = state.entry(&path)?;
|
||||
entry.file_content(&path).cloned()
|
||||
}
|
||||
|
||||
@@ -1292,6 +1386,25 @@ impl FakeFs {
|
||||
self.state.lock().flush_events(count);
|
||||
}
|
||||
|
||||
pub(crate) fn entry(&self, target: &Path) -> Result<FakeFsEntry> {
|
||||
self.state.lock().entry(target).cloned()
|
||||
}
|
||||
|
||||
pub(crate) fn insert_entry(&self, target: &Path, new_entry: FakeFsEntry) -> Result<()> {
|
||||
let mut state = self.state.lock();
|
||||
state.write_path(target, |entry| {
|
||||
match entry {
|
||||
btree_map::Entry::Vacant(vacant_entry) => {
|
||||
vacant_entry.insert(new_entry);
|
||||
}
|
||||
btree_map::Entry::Occupied(mut occupied_entry) => {
|
||||
occupied_entry.insert(new_entry);
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
})
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn insert_tree<'a>(
|
||||
&'a self,
|
||||
@@ -1361,20 +1474,19 @@ impl FakeFs {
|
||||
F: FnOnce(&mut FakeGitRepositoryState, &Path, &Path) -> T,
|
||||
{
|
||||
let mut state = self.state.lock();
|
||||
let entry = state.read_path(dot_git).context("open .git")?;
|
||||
let mut entry = entry.lock();
|
||||
let git_event_tx = state.git_event_tx.clone();
|
||||
let entry = state.entry(dot_git).context("open .git")?;
|
||||
|
||||
if let FakeFsEntry::Dir { git_repo_state, .. } = &mut *entry {
|
||||
if let FakeFsEntry::Dir { git_repo_state, .. } = entry {
|
||||
let repo_state = git_repo_state.get_or_insert_with(|| {
|
||||
log::debug!("insert git state for {dot_git:?}");
|
||||
Arc::new(Mutex::new(FakeGitRepositoryState::new(
|
||||
state.git_event_tx.clone(),
|
||||
)))
|
||||
Arc::new(Mutex::new(FakeGitRepositoryState::new(git_event_tx)))
|
||||
});
|
||||
let mut repo_state = repo_state.lock();
|
||||
|
||||
let result = f(&mut repo_state, dot_git, dot_git);
|
||||
|
||||
drop(repo_state);
|
||||
if emit_git_event {
|
||||
state.emit_event([(dot_git, None)]);
|
||||
}
|
||||
@@ -1398,21 +1510,20 @@ impl FakeFs {
|
||||
}
|
||||
}
|
||||
.clone();
|
||||
drop(entry);
|
||||
let Some((git_dir_entry, canonical_path)) = state.try_read_path(&path, true) else {
|
||||
let Some((git_dir_entry, canonical_path)) = state.try_entry(&path, true) else {
|
||||
anyhow::bail!("pointed-to git dir {path:?} not found")
|
||||
};
|
||||
let FakeFsEntry::Dir {
|
||||
git_repo_state,
|
||||
entries,
|
||||
..
|
||||
} = &mut *git_dir_entry.lock()
|
||||
} = git_dir_entry
|
||||
else {
|
||||
anyhow::bail!("gitfile points to a non-directory")
|
||||
};
|
||||
let common_dir = if let Some(child) = entries.get("commondir") {
|
||||
Path::new(
|
||||
std::str::from_utf8(child.lock().file_content("commondir".as_ref())?)
|
||||
std::str::from_utf8(child.file_content("commondir".as_ref())?)
|
||||
.context("commondir content")?,
|
||||
)
|
||||
.to_owned()
|
||||
@@ -1420,15 +1531,14 @@ impl FakeFs {
|
||||
canonical_path.clone()
|
||||
};
|
||||
let repo_state = git_repo_state.get_or_insert_with(|| {
|
||||
Arc::new(Mutex::new(FakeGitRepositoryState::new(
|
||||
state.git_event_tx.clone(),
|
||||
)))
|
||||
Arc::new(Mutex::new(FakeGitRepositoryState::new(git_event_tx)))
|
||||
});
|
||||
let mut repo_state = repo_state.lock();
|
||||
|
||||
let result = f(&mut repo_state, &canonical_path, &common_dir);
|
||||
|
||||
if emit_git_event {
|
||||
drop(repo_state);
|
||||
state.emit_event([(canonical_path, None)]);
|
||||
}
|
||||
|
||||
@@ -1655,14 +1765,12 @@ impl FakeFs {
|
||||
pub fn paths(&self, include_dot_git: bool) -> Vec<PathBuf> {
|
||||
let mut result = Vec::new();
|
||||
let mut queue = collections::VecDeque::new();
|
||||
queue.push_back((
|
||||
PathBuf::from(util::path!("/")),
|
||||
self.state.lock().root.clone(),
|
||||
));
|
||||
let state = &*self.state.lock();
|
||||
queue.push_back((PathBuf::from(util::path!("/")), &state.root));
|
||||
while let Some((path, entry)) = queue.pop_front() {
|
||||
if let FakeFsEntry::Dir { entries, .. } = &*entry.lock() {
|
||||
if let FakeFsEntry::Dir { entries, .. } = entry {
|
||||
for (name, entry) in entries {
|
||||
queue.push_back((path.join(name), entry.clone()));
|
||||
queue.push_back((path.join(name), entry));
|
||||
}
|
||||
}
|
||||
if include_dot_git
|
||||
@@ -1679,14 +1787,12 @@ impl FakeFs {
|
||||
pub fn directories(&self, include_dot_git: bool) -> Vec<PathBuf> {
|
||||
let mut result = Vec::new();
|
||||
let mut queue = collections::VecDeque::new();
|
||||
queue.push_back((
|
||||
PathBuf::from(util::path!("/")),
|
||||
self.state.lock().root.clone(),
|
||||
));
|
||||
let state = &*self.state.lock();
|
||||
queue.push_back((PathBuf::from(util::path!("/")), &state.root));
|
||||
while let Some((path, entry)) = queue.pop_front() {
|
||||
if let FakeFsEntry::Dir { entries, .. } = &*entry.lock() {
|
||||
if let FakeFsEntry::Dir { entries, .. } = entry {
|
||||
for (name, entry) in entries {
|
||||
queue.push_back((path.join(name), entry.clone()));
|
||||
queue.push_back((path.join(name), entry));
|
||||
}
|
||||
if include_dot_git
|
||||
|| !path
|
||||
@@ -1703,17 +1809,14 @@ impl FakeFs {
|
||||
pub fn files(&self) -> Vec<PathBuf> {
|
||||
let mut result = Vec::new();
|
||||
let mut queue = collections::VecDeque::new();
|
||||
queue.push_back((
|
||||
PathBuf::from(util::path!("/")),
|
||||
self.state.lock().root.clone(),
|
||||
));
|
||||
let state = &*self.state.lock();
|
||||
queue.push_back((PathBuf::from(util::path!("/")), &state.root));
|
||||
while let Some((path, entry)) = queue.pop_front() {
|
||||
let e = entry.lock();
|
||||
match &*e {
|
||||
match entry {
|
||||
FakeFsEntry::File { .. } => result.push(path),
|
||||
FakeFsEntry::Dir { entries, .. } => {
|
||||
for (name, entry) in entries {
|
||||
queue.push_back((path.join(name), entry.clone()));
|
||||
queue.push_back((path.join(name), entry));
|
||||
}
|
||||
}
|
||||
FakeFsEntry::Symlink { .. } => {}
|
||||
@@ -1725,13 +1828,10 @@ impl FakeFs {
|
||||
pub fn files_with_contents(&self, prefix: &Path) -> Vec<(PathBuf, Vec<u8>)> {
|
||||
let mut result = Vec::new();
|
||||
let mut queue = collections::VecDeque::new();
|
||||
queue.push_back((
|
||||
PathBuf::from(util::path!("/")),
|
||||
self.state.lock().root.clone(),
|
||||
));
|
||||
let state = &*self.state.lock();
|
||||
queue.push_back((PathBuf::from(util::path!("/")), &state.root));
|
||||
while let Some((path, entry)) = queue.pop_front() {
|
||||
let e = entry.lock();
|
||||
match &*e {
|
||||
match entry {
|
||||
FakeFsEntry::File { content, .. } => {
|
||||
if path.starts_with(prefix) {
|
||||
result.push((path, content.clone()));
|
||||
@@ -1739,7 +1839,7 @@ impl FakeFs {
|
||||
}
|
||||
FakeFsEntry::Dir { entries, .. } => {
|
||||
for (name, entry) in entries {
|
||||
queue.push_back((path.join(name), entry.clone()));
|
||||
queue.push_back((path.join(name), entry));
|
||||
}
|
||||
}
|
||||
FakeFsEntry::Symlink { .. } => {}
|
||||
@@ -1805,10 +1905,7 @@ impl FakeFsEntry {
|
||||
}
|
||||
}
|
||||
|
||||
fn dir_entries(
|
||||
&mut self,
|
||||
path: &Path,
|
||||
) -> Result<&mut BTreeMap<String, Arc<Mutex<FakeFsEntry>>>> {
|
||||
fn dir_entries(&mut self, path: &Path) -> Result<&mut BTreeMap<String, FakeFsEntry>> {
|
||||
if let Self::Dir { entries, .. } = self {
|
||||
Ok(entries)
|
||||
} else {
|
||||
@@ -1855,12 +1952,12 @@ struct FakeHandle {
|
||||
impl FileHandle for FakeHandle {
|
||||
fn current_path(&self, fs: &Arc<dyn Fs>) -> Result<PathBuf> {
|
||||
let fs = fs.as_fake();
|
||||
let state = fs.state.lock();
|
||||
let Some(target) = state.moves.get(&self.inode) else {
|
||||
let mut state = fs.state.lock();
|
||||
let Some(target) = state.moves.get(&self.inode).cloned() else {
|
||||
anyhow::bail!("fake fd not moved")
|
||||
};
|
||||
|
||||
if state.try_read_path(&target, false).is_some() {
|
||||
if state.try_entry(&target, false).is_some() {
|
||||
return Ok(target.clone());
|
||||
}
|
||||
anyhow::bail!("fake fd target not found")
|
||||
@@ -1888,13 +1985,13 @@ impl Fs for FakeFs {
|
||||
state.write_path(&cur_path, |entry| {
|
||||
entry.or_insert_with(|| {
|
||||
created_dirs.push((cur_path.clone(), Some(PathEventKind::Created)));
|
||||
Arc::new(Mutex::new(FakeFsEntry::Dir {
|
||||
FakeFsEntry::Dir {
|
||||
inode,
|
||||
mtime,
|
||||
len: 0,
|
||||
entries: Default::default(),
|
||||
git_repo_state: None,
|
||||
}))
|
||||
}
|
||||
});
|
||||
Ok(())
|
||||
})?
|
||||
@@ -1909,13 +2006,13 @@ impl Fs for FakeFs {
|
||||
let mut state = self.state.lock();
|
||||
let inode = state.get_and_increment_inode();
|
||||
let mtime = state.get_and_increment_mtime();
|
||||
let file = Arc::new(Mutex::new(FakeFsEntry::File {
|
||||
let file = FakeFsEntry::File {
|
||||
inode,
|
||||
mtime,
|
||||
len: 0,
|
||||
content: Vec::new(),
|
||||
git_dir_path: None,
|
||||
}));
|
||||
};
|
||||
let mut kind = Some(PathEventKind::Created);
|
||||
state.write_path(path, |entry| {
|
||||
match entry {
|
||||
@@ -1939,7 +2036,7 @@ impl Fs for FakeFs {
|
||||
|
||||
async fn create_symlink(&self, path: &Path, target: PathBuf) -> Result<()> {
|
||||
let mut state = self.state.lock();
|
||||
let file = Arc::new(Mutex::new(FakeFsEntry::Symlink { target }));
|
||||
let file = FakeFsEntry::Symlink { target };
|
||||
state
|
||||
.write_path(path.as_ref(), move |e| match e {
|
||||
btree_map::Entry::Vacant(e) => {
|
||||
@@ -2002,7 +2099,7 @@ impl Fs for FakeFs {
|
||||
}
|
||||
})?;
|
||||
|
||||
let inode = match *moved_entry.lock() {
|
||||
let inode = match moved_entry {
|
||||
FakeFsEntry::File { inode, .. } => inode,
|
||||
FakeFsEntry::Dir { inode, .. } => inode,
|
||||
_ => 0,
|
||||
@@ -2051,8 +2148,8 @@ impl Fs for FakeFs {
|
||||
let mut state = self.state.lock();
|
||||
let mtime = state.get_and_increment_mtime();
|
||||
let inode = state.get_and_increment_inode();
|
||||
let source_entry = state.read_path(&source)?;
|
||||
let content = source_entry.lock().file_content(&source)?.clone();
|
||||
let source_entry = state.entry(&source)?;
|
||||
let content = source_entry.file_content(&source)?.clone();
|
||||
let mut kind = Some(PathEventKind::Created);
|
||||
state.write_path(&target, |e| match e {
|
||||
btree_map::Entry::Occupied(e) => {
|
||||
@@ -2066,13 +2163,13 @@ impl Fs for FakeFs {
|
||||
}
|
||||
}
|
||||
btree_map::Entry::Vacant(e) => Ok(Some(
|
||||
e.insert(Arc::new(Mutex::new(FakeFsEntry::File {
|
||||
e.insert(FakeFsEntry::File {
|
||||
inode,
|
||||
mtime,
|
||||
len: content.len() as u64,
|
||||
content,
|
||||
git_dir_path: None,
|
||||
})))
|
||||
})
|
||||
.clone(),
|
||||
)),
|
||||
})?;
|
||||
@@ -2088,8 +2185,7 @@ impl Fs for FakeFs {
|
||||
let base_name = path.file_name().context("cannot remove the root")?;
|
||||
|
||||
let mut state = self.state.lock();
|
||||
let parent_entry = state.read_path(parent_path)?;
|
||||
let mut parent_entry = parent_entry.lock();
|
||||
let parent_entry = state.entry(parent_path)?;
|
||||
let entry = parent_entry
|
||||
.dir_entries(parent_path)?
|
||||
.entry(base_name.to_str().unwrap().into());
|
||||
@@ -2100,15 +2196,14 @@ impl Fs for FakeFs {
|
||||
anyhow::bail!("{path:?} does not exist");
|
||||
}
|
||||
}
|
||||
btree_map::Entry::Occupied(e) => {
|
||||
btree_map::Entry::Occupied(mut entry) => {
|
||||
{
|
||||
let mut entry = e.get().lock();
|
||||
let children = entry.dir_entries(&path)?;
|
||||
let children = entry.get_mut().dir_entries(&path)?;
|
||||
if !options.recursive && !children.is_empty() {
|
||||
anyhow::bail!("{path:?} is not empty");
|
||||
}
|
||||
}
|
||||
e.remove();
|
||||
entry.remove();
|
||||
}
|
||||
}
|
||||
state.emit_event([(path, Some(PathEventKind::Removed))]);
|
||||
@@ -2122,8 +2217,7 @@ impl Fs for FakeFs {
|
||||
let parent_path = path.parent().context("cannot remove the root")?;
|
||||
let base_name = path.file_name().unwrap();
|
||||
let mut state = self.state.lock();
|
||||
let parent_entry = state.read_path(parent_path)?;
|
||||
let mut parent_entry = parent_entry.lock();
|
||||
let parent_entry = state.entry(parent_path)?;
|
||||
let entry = parent_entry
|
||||
.dir_entries(parent_path)?
|
||||
.entry(base_name.to_str().unwrap().into());
|
||||
@@ -2133,9 +2227,9 @@ impl Fs for FakeFs {
|
||||
anyhow::bail!("{path:?} does not exist");
|
||||
}
|
||||
}
|
||||
btree_map::Entry::Occupied(e) => {
|
||||
e.get().lock().file_content(&path)?;
|
||||
e.remove();
|
||||
btree_map::Entry::Occupied(mut entry) => {
|
||||
entry.get_mut().file_content(&path)?;
|
||||
entry.remove();
|
||||
}
|
||||
}
|
||||
state.emit_event([(path, Some(PathEventKind::Removed))]);
|
||||
@@ -2149,12 +2243,10 @@ impl Fs for FakeFs {
|
||||
|
||||
async fn open_handle(&self, path: &Path) -> Result<Arc<dyn FileHandle>> {
|
||||
self.simulate_random_delay().await;
|
||||
let state = self.state.lock();
|
||||
let entry = state.read_path(&path)?;
|
||||
let entry = entry.lock();
|
||||
let inode = match *entry {
|
||||
FakeFsEntry::File { inode, .. } => inode,
|
||||
FakeFsEntry::Dir { inode, .. } => inode,
|
||||
let mut state = self.state.lock();
|
||||
let inode = match state.entry(&path)? {
|
||||
FakeFsEntry::File { inode, .. } => *inode,
|
||||
FakeFsEntry::Dir { inode, .. } => *inode,
|
||||
_ => unreachable!(),
|
||||
};
|
||||
Ok(Arc::new(FakeHandle { inode }))
|
||||
@@ -2204,8 +2296,8 @@ impl Fs for FakeFs {
|
||||
let path = normalize_path(path);
|
||||
self.simulate_random_delay().await;
|
||||
let state = self.state.lock();
|
||||
let (_, canonical_path) = state
|
||||
.try_read_path(&path, true)
|
||||
let canonical_path = state
|
||||
.canonicalize(&path, true)
|
||||
.with_context(|| format!("path does not exist: {path:?}"))?;
|
||||
Ok(canonical_path)
|
||||
}
|
||||
@@ -2213,9 +2305,9 @@ impl Fs for FakeFs {
|
||||
async fn is_file(&self, path: &Path) -> bool {
|
||||
let path = normalize_path(path);
|
||||
self.simulate_random_delay().await;
|
||||
let state = self.state.lock();
|
||||
if let Some((entry, _)) = state.try_read_path(&path, true) {
|
||||
entry.lock().is_file()
|
||||
let mut state = self.state.lock();
|
||||
if let Some((entry, _)) = state.try_entry(&path, true) {
|
||||
entry.is_file()
|
||||
} else {
|
||||
false
|
||||
}
|
||||
@@ -2232,17 +2324,16 @@ impl Fs for FakeFs {
|
||||
let path = normalize_path(path);
|
||||
let mut state = self.state.lock();
|
||||
state.metadata_call_count += 1;
|
||||
if let Some((mut entry, _)) = state.try_read_path(&path, false) {
|
||||
let is_symlink = entry.lock().is_symlink();
|
||||
if let Some((mut entry, _)) = state.try_entry(&path, false) {
|
||||
let is_symlink = entry.is_symlink();
|
||||
if is_symlink {
|
||||
if let Some(e) = state.try_read_path(&path, true).map(|e| e.0) {
|
||||
if let Some(e) = state.try_entry(&path, true).map(|e| e.0) {
|
||||
entry = e;
|
||||
} else {
|
||||
return Ok(None);
|
||||
}
|
||||
}
|
||||
|
||||
let entry = entry.lock();
|
||||
Ok(Some(match &*entry {
|
||||
FakeFsEntry::File {
|
||||
inode, mtime, len, ..
|
||||
@@ -2274,12 +2365,11 @@ impl Fs for FakeFs {
|
||||
async fn read_link(&self, path: &Path) -> Result<PathBuf> {
|
||||
self.simulate_random_delay().await;
|
||||
let path = normalize_path(path);
|
||||
let state = self.state.lock();
|
||||
let mut state = self.state.lock();
|
||||
let (entry, _) = state
|
||||
.try_read_path(&path, false)
|
||||
.try_entry(&path, false)
|
||||
.with_context(|| format!("path does not exist: {path:?}"))?;
|
||||
let entry = entry.lock();
|
||||
if let FakeFsEntry::Symlink { target } = &*entry {
|
||||
if let FakeFsEntry::Symlink { target } = entry {
|
||||
Ok(target.clone())
|
||||
} else {
|
||||
anyhow::bail!("not a symlink: {path:?}")
|
||||
@@ -2294,8 +2384,7 @@ impl Fs for FakeFs {
|
||||
let path = normalize_path(path);
|
||||
let mut state = self.state.lock();
|
||||
state.read_dir_call_count += 1;
|
||||
let entry = state.read_path(&path)?;
|
||||
let mut entry = entry.lock();
|
||||
let entry = state.entry(&path)?;
|
||||
let children = entry.dir_entries(&path)?;
|
||||
let paths = children
|
||||
.keys()
|
||||
@@ -2359,6 +2448,7 @@ impl Fs for FakeFs {
|
||||
dot_git_path: abs_dot_git.to_path_buf(),
|
||||
repository_dir_path: repository_dir_path.to_owned(),
|
||||
common_dir_path: common_dir_path.to_owned(),
|
||||
checkpoints: Arc::default(),
|
||||
}) as _
|
||||
},
|
||||
)
|
||||
|
||||
@@ -12,7 +12,7 @@ workspace = true
|
||||
path = "src/git.rs"
|
||||
|
||||
[features]
|
||||
test-support = []
|
||||
test-support = ["rand"]
|
||||
|
||||
[dependencies]
|
||||
anyhow.workspace = true
|
||||
@@ -26,6 +26,7 @@ http_client.workspace = true
|
||||
log.workspace = true
|
||||
parking_lot.workspace = true
|
||||
regex.workspace = true
|
||||
rand = { workspace = true, optional = true }
|
||||
rope.workspace = true
|
||||
schemars.workspace = true
|
||||
serde.workspace = true
|
||||
@@ -47,3 +48,4 @@ text = { workspace = true, features = ["test-support"] }
|
||||
unindent.workspace = true
|
||||
gpui = { workspace = true, features = ["test-support"] }
|
||||
tempfile.workspace = true
|
||||
rand.workspace = true
|
||||
|
||||
@@ -73,6 +73,7 @@ async fn run_git_blame(
|
||||
.current_dir(working_directory)
|
||||
.arg("blame")
|
||||
.arg("--incremental")
|
||||
.arg("-w")
|
||||
.arg("--contents")
|
||||
.arg("-")
|
||||
.arg(path.as_os_str())
|
||||
|
||||
@@ -119,6 +119,13 @@ impl Oid {
|
||||
Ok(Self(oid))
|
||||
}
|
||||
|
||||
#[cfg(any(test, feature = "test-support"))]
|
||||
pub fn random(rng: &mut impl rand::Rng) -> Self {
|
||||
let mut bytes = [0; 20];
|
||||
rng.fill(&mut bytes);
|
||||
Self::from_bytes(&bytes).unwrap()
|
||||
}
|
||||
|
||||
pub fn as_bytes(&self) -> &[u8] {
|
||||
self.0.as_bytes()
|
||||
}
|
||||
|
||||
@@ -6,7 +6,7 @@ use collections::HashMap;
|
||||
use futures::future::BoxFuture;
|
||||
use futures::{AsyncWriteExt, FutureExt as _, select_biased};
|
||||
use git2::BranchType;
|
||||
use gpui::{AppContext as _, AsyncApp, BackgroundExecutor, SharedString};
|
||||
use gpui::{AppContext as _, AsyncApp, BackgroundExecutor, SharedString, Task};
|
||||
use parking_lot::Mutex;
|
||||
use rope::Rope;
|
||||
use schemars::JsonSchema;
|
||||
@@ -338,7 +338,7 @@ pub trait GitRepository: Send + Sync {
|
||||
|
||||
fn merge_message(&self) -> BoxFuture<'_, Option<String>>;
|
||||
|
||||
fn status(&self, path_prefixes: &[RepoPath]) -> BoxFuture<'_, Result<GitStatus>>;
|
||||
fn status(&self, path_prefixes: &[RepoPath]) -> Task<Result<GitStatus>>;
|
||||
|
||||
fn branches(&self) -> BoxFuture<'_, Result<Vec<Branch>>>;
|
||||
|
||||
@@ -953,25 +953,27 @@ impl GitRepository for RealGitRepository {
|
||||
.boxed()
|
||||
}
|
||||
|
||||
fn status(&self, path_prefixes: &[RepoPath]) -> BoxFuture<'_, Result<GitStatus>> {
|
||||
fn status(&self, path_prefixes: &[RepoPath]) -> Task<Result<GitStatus>> {
|
||||
let git_binary_path = self.git_binary_path.clone();
|
||||
let working_directory = self.working_directory();
|
||||
let path_prefixes = path_prefixes.to_owned();
|
||||
self.executor
|
||||
.spawn(async move {
|
||||
let output = new_std_command(&git_binary_path)
|
||||
.current_dir(working_directory?)
|
||||
.args(git_status_args(&path_prefixes))
|
||||
.output()?;
|
||||
if output.status.success() {
|
||||
let stdout = String::from_utf8_lossy(&output.stdout);
|
||||
stdout.parse()
|
||||
} else {
|
||||
let stderr = String::from_utf8_lossy(&output.stderr);
|
||||
anyhow::bail!("git status failed: {stderr}");
|
||||
}
|
||||
})
|
||||
.boxed()
|
||||
let working_directory = match self.working_directory() {
|
||||
Ok(working_directory) => working_directory,
|
||||
Err(e) => return Task::ready(Err(e)),
|
||||
};
|
||||
let args = git_status_args(&path_prefixes);
|
||||
log::debug!("Checking for git status in {path_prefixes:?}");
|
||||
self.executor.spawn(async move {
|
||||
let output = new_std_command(&git_binary_path)
|
||||
.current_dir(working_directory)
|
||||
.args(args)
|
||||
.output()?;
|
||||
if output.status.success() {
|
||||
let stdout = String::from_utf8_lossy(&output.stdout);
|
||||
stdout.parse()
|
||||
} else {
|
||||
let stderr = String::from_utf8_lossy(&output.stderr);
|
||||
anyhow::bail!("git status failed: {stderr}");
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
fn branches(&self) -> BoxFuture<'_, Result<Vec<Branch>>> {
|
||||
|
||||
@@ -2105,7 +2105,7 @@ impl GitPanel {
|
||||
Ok(_) => cx.update(|window, cx| {
|
||||
window.prompt(
|
||||
PromptLevel::Info,
|
||||
"Git Clone",
|
||||
&format!("Git Clone: {}", repo_name),
|
||||
None,
|
||||
&["Add repo to project", "Open repo in new project"],
|
||||
cx,
|
||||
|
||||
@@ -181,10 +181,6 @@ pub fn init(cx: &mut App) {
|
||||
workspace.toggle_modal(window, cx, |window, cx| {
|
||||
GitCloneModal::show(panel, window, cx)
|
||||
});
|
||||
|
||||
// panel.update(cx, |panel, cx| {
|
||||
// panel.git_clone(window, cx);
|
||||
// });
|
||||
});
|
||||
workspace.register_action(|workspace, _: &git::OpenModifiedFiles, window, cx| {
|
||||
open_modified_files(workspace, window, cx);
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
use gpui::{
|
||||
App, Application, Context, Menu, MenuItem, Window, WindowOptions, actions, div, prelude::*, rgb,
|
||||
App, Application, Context, Menu, MenuItem, SystemMenuType, Window, WindowOptions, actions, div,
|
||||
prelude::*, rgb,
|
||||
};
|
||||
|
||||
struct SetMenus;
|
||||
@@ -27,7 +28,11 @@ fn main() {
|
||||
// Add menu items
|
||||
cx.set_menus(vec![Menu {
|
||||
name: "set_menus".into(),
|
||||
items: vec![MenuItem::action("Quit", Quit)],
|
||||
items: vec![
|
||||
MenuItem::os_submenu("Services", SystemMenuType::Services),
|
||||
MenuItem::separator(),
|
||||
MenuItem::action("Quit", Quit),
|
||||
],
|
||||
}]);
|
||||
cx.open_window(WindowOptions::default(), |_, cx| cx.new(|_| SetMenus {}))
|
||||
.unwrap();
|
||||
|
||||
@@ -277,6 +277,8 @@ pub struct App {
|
||||
pub(crate) release_listeners: SubscriberSet<EntityId, ReleaseListener>,
|
||||
pub(crate) global_observers: SubscriberSet<TypeId, Handler>,
|
||||
pub(crate) quit_observers: SubscriberSet<(), QuitHandler>,
|
||||
pub(crate) restart_observers: SubscriberSet<(), Handler>,
|
||||
pub(crate) restart_path: Option<PathBuf>,
|
||||
pub(crate) window_closed_observers: SubscriberSet<(), WindowClosedHandler>,
|
||||
pub(crate) layout_id_buffer: Vec<LayoutId>, // We recycle this memory across layout requests.
|
||||
pub(crate) propagate_event: bool,
|
||||
@@ -349,6 +351,8 @@ impl App {
|
||||
keyboard_layout_observers: SubscriberSet::new(),
|
||||
global_observers: SubscriberSet::new(),
|
||||
quit_observers: SubscriberSet::new(),
|
||||
restart_observers: SubscriberSet::new(),
|
||||
restart_path: None,
|
||||
window_closed_observers: SubscriberSet::new(),
|
||||
layout_id_buffer: Default::default(),
|
||||
propagate_event: true,
|
||||
@@ -832,8 +836,16 @@ impl App {
|
||||
}
|
||||
|
||||
/// Restarts the application.
|
||||
pub fn restart(&self, binary_path: Option<PathBuf>) {
|
||||
self.platform.restart(binary_path)
|
||||
pub fn restart(&mut self) {
|
||||
self.restart_observers
|
||||
.clone()
|
||||
.retain(&(), |observer| observer(self));
|
||||
self.platform.restart(self.restart_path.take())
|
||||
}
|
||||
|
||||
/// Sets the path to use when restarting the application.
|
||||
pub fn set_restart_path(&mut self, path: PathBuf) {
|
||||
self.restart_path = Some(path);
|
||||
}
|
||||
|
||||
/// Returns the HTTP client for the application.
|
||||
@@ -1466,6 +1478,21 @@ impl App {
|
||||
subscription
|
||||
}
|
||||
|
||||
/// Register a callback to be invoked when the application is about to restart.
|
||||
///
|
||||
/// These callbacks are called before any `on_app_quit` callbacks.
|
||||
pub fn on_app_restart(&self, mut on_restart: impl 'static + FnMut(&mut App)) -> Subscription {
|
||||
let (subscription, activate) = self.restart_observers.insert(
|
||||
(),
|
||||
Box::new(move |cx| {
|
||||
on_restart(cx);
|
||||
true
|
||||
}),
|
||||
);
|
||||
activate();
|
||||
subscription
|
||||
}
|
||||
|
||||
/// Register a callback to be invoked when a window is closed
|
||||
/// The window is no longer accessible at the point this callback is invoked.
|
||||
pub fn on_window_closed(&self, mut on_closed: impl FnMut(&mut App) + 'static) -> Subscription {
|
||||
|
||||
@@ -164,6 +164,20 @@ impl<'a, T: 'static> Context<'a, T> {
|
||||
subscription
|
||||
}
|
||||
|
||||
/// Register a callback to be invoked when the application is about to restart.
|
||||
pub fn on_app_restart(
|
||||
&self,
|
||||
mut on_restart: impl FnMut(&mut T, &mut App) + 'static,
|
||||
) -> Subscription
|
||||
where
|
||||
T: 'static,
|
||||
{
|
||||
let handle = self.weak_entity();
|
||||
self.app.on_app_restart(move |cx| {
|
||||
handle.update(cx, |entity, cx| on_restart(entity, cx)).ok();
|
||||
})
|
||||
}
|
||||
|
||||
/// Arrange for the given function to be invoked whenever the application is quit.
|
||||
/// The future returned from this callback will be polled for up to [crate::SHUTDOWN_TIMEOUT] until the app fully quits.
|
||||
pub fn on_app_quit<Fut>(
|
||||
@@ -175,20 +189,15 @@ impl<'a, T: 'static> Context<'a, T> {
|
||||
T: 'static,
|
||||
{
|
||||
let handle = self.weak_entity();
|
||||
let (subscription, activate) = self.app.quit_observers.insert(
|
||||
(),
|
||||
Box::new(move |cx| {
|
||||
let future = handle.update(cx, |entity, cx| on_quit(entity, cx)).ok();
|
||||
async move {
|
||||
if let Some(future) = future {
|
||||
future.await;
|
||||
}
|
||||
self.app.on_app_quit(move |cx| {
|
||||
let future = handle.update(cx, |entity, cx| on_quit(entity, cx)).ok();
|
||||
async move {
|
||||
if let Some(future) = future {
|
||||
future.await;
|
||||
}
|
||||
.boxed_local()
|
||||
}),
|
||||
);
|
||||
activate();
|
||||
subscription
|
||||
}
|
||||
.boxed_local()
|
||||
})
|
||||
}
|
||||
|
||||
/// Tell GPUI that this entity has changed and observers of it should be notified.
|
||||
|
||||
@@ -20,6 +20,34 @@ impl Menu {
|
||||
}
|
||||
}
|
||||
|
||||
/// OS menus are menus that are recognized by the operating system
|
||||
/// This allows the operating system to provide specialized items for
|
||||
/// these menus
|
||||
pub struct OsMenu {
|
||||
/// The name of the menu
|
||||
pub name: SharedString,
|
||||
|
||||
/// The type of menu
|
||||
pub menu_type: SystemMenuType,
|
||||
}
|
||||
|
||||
impl OsMenu {
|
||||
/// Create an OwnedOsMenu from this OsMenu
|
||||
pub fn owned(self) -> OwnedOsMenu {
|
||||
OwnedOsMenu {
|
||||
name: self.name.to_string().into(),
|
||||
menu_type: self.menu_type,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// The type of system menu
|
||||
#[derive(Copy, Clone, Eq, PartialEq)]
|
||||
pub enum SystemMenuType {
|
||||
/// The 'Services' menu in the Application menu on macOS
|
||||
Services,
|
||||
}
|
||||
|
||||
/// The different kinds of items that can be in a menu
|
||||
pub enum MenuItem {
|
||||
/// A separator between items
|
||||
@@ -28,6 +56,9 @@ pub enum MenuItem {
|
||||
/// A submenu
|
||||
Submenu(Menu),
|
||||
|
||||
/// A menu, managed by the system (for example, the Services menu on macOS)
|
||||
SystemMenu(OsMenu),
|
||||
|
||||
/// An action that can be performed
|
||||
Action {
|
||||
/// The name of this menu item
|
||||
@@ -53,6 +84,14 @@ impl MenuItem {
|
||||
Self::Submenu(menu)
|
||||
}
|
||||
|
||||
/// Creates a new submenu that is populated by the OS
|
||||
pub fn os_submenu(name: impl Into<SharedString>, menu_type: SystemMenuType) -> Self {
|
||||
Self::SystemMenu(OsMenu {
|
||||
name: name.into(),
|
||||
menu_type,
|
||||
})
|
||||
}
|
||||
|
||||
/// Creates a new menu item that invokes an action
|
||||
pub fn action(name: impl Into<SharedString>, action: impl Action) -> Self {
|
||||
Self::Action {
|
||||
@@ -89,10 +128,23 @@ impl MenuItem {
|
||||
action,
|
||||
os_action,
|
||||
},
|
||||
MenuItem::SystemMenu(os_menu) => OwnedMenuItem::SystemMenu(os_menu.owned()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// OS menus are menus that are recognized by the operating system
|
||||
/// This allows the operating system to provide specialized items for
|
||||
/// these menus
|
||||
#[derive(Clone)]
|
||||
pub struct OwnedOsMenu {
|
||||
/// The name of the menu
|
||||
pub name: SharedString,
|
||||
|
||||
/// The type of menu
|
||||
pub menu_type: SystemMenuType,
|
||||
}
|
||||
|
||||
/// A menu of the application, either a main menu or a submenu
|
||||
#[derive(Clone)]
|
||||
pub struct OwnedMenu {
|
||||
@@ -111,6 +163,9 @@ pub enum OwnedMenuItem {
|
||||
/// A submenu
|
||||
Submenu(OwnedMenu),
|
||||
|
||||
/// A menu, managed by the system (for example, the Services menu on macOS)
|
||||
SystemMenu(OwnedOsMenu),
|
||||
|
||||
/// An action that can be performed
|
||||
Action {
|
||||
/// The name of this menu item
|
||||
@@ -139,6 +194,7 @@ impl Clone for OwnedMenuItem {
|
||||
action: action.boxed_clone(),
|
||||
os_action: *os_action,
|
||||
},
|
||||
OwnedMenuItem::SystemMenu(os_menu) => OwnedMenuItem::SystemMenu(os_menu.clone()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -213,11 +213,7 @@ impl CosmicTextSystemState {
|
||||
features: &FontFeatures,
|
||||
) -> Result<SmallVec<[FontId; 4]>> {
|
||||
// TODO: Determine the proper system UI font.
|
||||
let name = if name == ".SystemUIFont" {
|
||||
"Zed Plex Sans"
|
||||
} else {
|
||||
name
|
||||
};
|
||||
let name = crate::text_system::font_name_with_fallbacks(name, "IBM Plex Sans");
|
||||
|
||||
let families = self
|
||||
.font_system
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user