Compare commits
59 Commits
tool-input
...
acp-markdo
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
34c890e23e | ||
|
|
4755d6fa9d | ||
|
|
135143d51b | ||
|
|
450604b4a1 | ||
|
|
348bc52a3f | ||
|
|
d16c595d57 | ||
|
|
975a7e6f7f | ||
|
|
7d2f7cb70e | ||
|
|
5f9afdf7ba | ||
|
|
7a3105b0c6 | ||
|
|
ab0b16939d | ||
|
|
28d992487d | ||
|
|
fde15a5a68 | ||
|
|
780db30e0b | ||
|
|
7c992adfe1 | ||
|
|
825aecfd28 | ||
|
|
f2f32fb3bd | ||
|
|
d9fd8d5eee | ||
|
|
8137b3318f | ||
|
|
3ceeefe460 | ||
|
|
6f768aefa2 | ||
|
|
28ac84ed01 | ||
|
|
4d803fa628 | ||
|
|
17b2dd9a93 | ||
|
|
7abf635e20 | ||
|
|
92adcb6e63 | ||
|
|
5ed001e0df | ||
|
|
f12fffd1ba | ||
|
|
991ba08711 | ||
|
|
c728731099 | ||
|
|
ddab1cbd71 | ||
|
|
f383a7626f | ||
|
|
ee1df65569 | ||
|
|
3be45822be | ||
|
|
3b6f30a6fd | ||
|
|
779a68f868 | ||
|
|
79c37284e0 | ||
|
|
0a053cf55d | ||
|
|
fc59d9cbf3 | ||
|
|
678a42e920 | ||
|
|
75bcaf743c | ||
|
|
47c875f6b5 | ||
|
|
81b4d7e35a | ||
|
|
33ee0c3093 | ||
|
|
d68f86052f | ||
|
|
a74ffd9ee4 | ||
|
|
8b9ad1cfae | ||
|
|
adbccb1ad0 | ||
|
|
f4e2d38c29 | ||
|
|
5f10be7791 | ||
|
|
d47a920c05 | ||
|
|
24b72be154 | ||
|
|
de779a45ce | ||
|
|
b094a636cf | ||
|
|
318709b60d | ||
|
|
f1bd531a32 | ||
|
|
549eb4d826 | ||
|
|
c1e53b7fa5 | ||
|
|
ec376e0b61 |
62
Cargo.lock
generated
62
Cargo.lock
generated
@@ -2,6 +2,36 @@
|
||||
# It is not intended for manual editing.
|
||||
version = 4
|
||||
|
||||
[[package]]
|
||||
name = "acp"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"agentic-coding-protocol",
|
||||
"anyhow",
|
||||
"async-trait",
|
||||
"base64 0.22.1",
|
||||
"chrono",
|
||||
"collections",
|
||||
"editor",
|
||||
"env_logger 0.11.8",
|
||||
"futures 0.3.31",
|
||||
"gpui",
|
||||
"language",
|
||||
"log",
|
||||
"markdown",
|
||||
"parking_lot",
|
||||
"project",
|
||||
"serde_json",
|
||||
"settings",
|
||||
"smol",
|
||||
"theme",
|
||||
"ui",
|
||||
"util",
|
||||
"uuid",
|
||||
"workspace-hack",
|
||||
"zed_actions",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "activity_indicator"
|
||||
version = "0.1.0"
|
||||
@@ -130,6 +160,7 @@ dependencies = [
|
||||
name = "agent_ui"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"acp",
|
||||
"agent",
|
||||
"agent_settings",
|
||||
"anyhow",
|
||||
@@ -212,6 +243,21 @@ dependencies = [
|
||||
"zed_llm_client",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "agentic-coding-protocol"
|
||||
version = "0.0.1"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"async-trait",
|
||||
"chrono",
|
||||
"futures 0.3.31",
|
||||
"log",
|
||||
"parking_lot",
|
||||
"schemars",
|
||||
"serde",
|
||||
"serde_json",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "ahash"
|
||||
version = "0.7.8"
|
||||
@@ -2076,7 +2122,7 @@ dependencies = [
|
||||
[[package]]
|
||||
name = "blade-graphics"
|
||||
version = "0.6.0"
|
||||
source = "git+https://github.com/kvark/blade?rev=416375211bb0b5826b3584dccdb6a43369e499ad#416375211bb0b5826b3584dccdb6a43369e499ad"
|
||||
source = "git+https://github.com/kvark/blade?rev=e0ec4e720957edd51b945b64dd85605ea54bcfe5#e0ec4e720957edd51b945b64dd85605ea54bcfe5"
|
||||
dependencies = [
|
||||
"ash",
|
||||
"ash-window",
|
||||
@@ -2109,7 +2155,7 @@ dependencies = [
|
||||
[[package]]
|
||||
name = "blade-macros"
|
||||
version = "0.3.0"
|
||||
source = "git+https://github.com/kvark/blade?rev=416375211bb0b5826b3584dccdb6a43369e499ad#416375211bb0b5826b3584dccdb6a43369e499ad"
|
||||
source = "git+https://github.com/kvark/blade?rev=e0ec4e720957edd51b945b64dd85605ea54bcfe5#e0ec4e720957edd51b945b64dd85605ea54bcfe5"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
@@ -2119,7 +2165,7 @@ dependencies = [
|
||||
[[package]]
|
||||
name = "blade-util"
|
||||
version = "0.2.0"
|
||||
source = "git+https://github.com/kvark/blade?rev=416375211bb0b5826b3584dccdb6a43369e499ad#416375211bb0b5826b3584dccdb6a43369e499ad"
|
||||
source = "git+https://github.com/kvark/blade?rev=e0ec4e720957edd51b945b64dd85605ea54bcfe5#e0ec4e720957edd51b945b64dd85605ea54bcfe5"
|
||||
dependencies = [
|
||||
"blade-graphics",
|
||||
"bytemuck",
|
||||
@@ -4830,7 +4876,6 @@ dependencies = [
|
||||
"tree-sitter-python",
|
||||
"tree-sitter-rust",
|
||||
"tree-sitter-typescript",
|
||||
"tree-sitter-yaml",
|
||||
"ui",
|
||||
"unicode-script",
|
||||
"unicode-segmentation",
|
||||
@@ -12259,7 +12304,6 @@ dependencies = [
|
||||
"language",
|
||||
"log",
|
||||
"lsp",
|
||||
"markdown",
|
||||
"node_runtime",
|
||||
"parking_lot",
|
||||
"pathdiff",
|
||||
@@ -14058,6 +14102,7 @@ version = "1.0.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "fe8c9d1c68d67dd9f97ecbc6f932b60eb289c5dbddd8aa1405484a8fd2fcd984"
|
||||
dependencies = [
|
||||
"chrono",
|
||||
"dyn-clone",
|
||||
"indexmap",
|
||||
"ref-cast",
|
||||
@@ -14569,7 +14614,6 @@ dependencies = [
|
||||
name = "settings_ui"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"collections",
|
||||
"command_palette",
|
||||
"command_palette_hooks",
|
||||
@@ -14580,7 +14624,6 @@ dependencies = [
|
||||
"fs",
|
||||
"fuzzy",
|
||||
"gpui",
|
||||
"language",
|
||||
"log",
|
||||
"menu",
|
||||
"paths",
|
||||
@@ -14590,8 +14633,6 @@ dependencies = [
|
||||
"serde",
|
||||
"settings",
|
||||
"theme",
|
||||
"tree-sitter-json",
|
||||
"tree-sitter-rust",
|
||||
"ui",
|
||||
"util",
|
||||
"workspace",
|
||||
@@ -17348,7 +17389,6 @@ dependencies = [
|
||||
"rand 0.8.5",
|
||||
"regex",
|
||||
"rust-embed",
|
||||
"schemars",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"serde_json_lenient",
|
||||
@@ -19946,7 +19986,7 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "zed"
|
||||
version = "0.195.0"
|
||||
version = "0.194.0"
|
||||
dependencies = [
|
||||
"activity_indicator",
|
||||
"agent",
|
||||
|
||||
13
Cargo.toml
13
Cargo.toml
@@ -2,6 +2,7 @@
|
||||
resolver = "2"
|
||||
members = [
|
||||
"crates/activity_indicator",
|
||||
"crates/acp",
|
||||
"crates/agent_ui",
|
||||
"crates/agent",
|
||||
"crates/agent_settings",
|
||||
@@ -215,8 +216,9 @@ edition = "2024"
|
||||
# Workspace member crates
|
||||
#
|
||||
|
||||
activity_indicator = { path = "crates/activity_indicator" }
|
||||
acp = { path = "crates/acp" }
|
||||
agent = { path = "crates/agent" }
|
||||
activity_indicator = { path = "crates/activity_indicator" }
|
||||
agent_ui = { path = "crates/agent_ui" }
|
||||
agent_settings = { path = "crates/agent_settings" }
|
||||
ai = { path = "crates/ai" }
|
||||
@@ -398,6 +400,7 @@ zlog_settings = { path = "crates/zlog_settings" }
|
||||
# External crates
|
||||
#
|
||||
|
||||
agentic-coding-protocol = { path = "../agentic-coding-protocol" }
|
||||
aho-corasick = "1.1"
|
||||
alacritty_terminal = { git = "https://github.com/zed-industries/alacritty.git", branch = "add-hush-login-flag" }
|
||||
any_vec = "0.14"
|
||||
@@ -425,9 +428,9 @@ aws-smithy-runtime-api = { version = "1.7.4", features = ["http-1x", "client"] }
|
||||
aws-smithy-types = { version = "1.3.0", features = ["http-body-1-x"] }
|
||||
base64 = "0.22"
|
||||
bitflags = "2.6.0"
|
||||
blade-graphics = { git = "https://github.com/kvark/blade", rev = "416375211bb0b5826b3584dccdb6a43369e499ad" }
|
||||
blade-macros = { git = "https://github.com/kvark/blade", rev = "416375211bb0b5826b3584dccdb6a43369e499ad" }
|
||||
blade-util = { git = "https://github.com/kvark/blade", rev = "416375211bb0b5826b3584dccdb6a43369e499ad" }
|
||||
blade-graphics = { git = "https://github.com/kvark/blade", rev = "e0ec4e720957edd51b945b64dd85605ea54bcfe5" }
|
||||
blade-macros = { git = "https://github.com/kvark/blade", rev = "e0ec4e720957edd51b945b64dd85605ea54bcfe5" }
|
||||
blade-util = { git = "https://github.com/kvark/blade", rev = "e0ec4e720957edd51b945b64dd85605ea54bcfe5" }
|
||||
blake3 = "1.5.3"
|
||||
bytes = "1.0"
|
||||
cargo_metadata = "0.19"
|
||||
@@ -625,7 +628,7 @@ wasmtime = { version = "29", default-features = false, features = [
|
||||
wasmtime-wasi = "29"
|
||||
which = "6.0.0"
|
||||
workspace-hack = "0.1.0"
|
||||
zed_llm_client = "= 0.8.5"
|
||||
zed_llm_client = "0.8.5"
|
||||
zstd = "0.11"
|
||||
|
||||
[workspace.dependencies.async-stripe]
|
||||
|
||||
@@ -34,7 +34,7 @@
|
||||
"ctrl-q": "zed::Quit",
|
||||
"f4": "debugger::Start",
|
||||
"shift-f5": "debugger::Stop",
|
||||
"ctrl-shift-f5": "debugger::RerunSession",
|
||||
"ctrl-shift-f5": "debugger::Restart",
|
||||
"f6": "debugger::Pause",
|
||||
"f7": "debugger::StepOver",
|
||||
"ctrl-f11": "debugger::StepInto",
|
||||
@@ -557,13 +557,6 @@
|
||||
"ctrl-b": "workspace::ToggleLeftDock",
|
||||
"ctrl-j": "workspace::ToggleBottomDock",
|
||||
"ctrl-alt-y": "workspace::CloseAllDocks",
|
||||
"ctrl-alt-0": "workspace::ResetActiveDockSize",
|
||||
// For 0px parameter, uses UI font size value.
|
||||
"ctrl-alt--": ["workspace::DecreaseActiveDockSize", { "px": 0 }],
|
||||
"ctrl-alt-=": ["workspace::IncreaseActiveDockSize", { "px": 0 }],
|
||||
"ctrl-alt-)": "workspace::ResetOpenDocksSize",
|
||||
"ctrl-alt-_": ["workspace::DecreaseOpenDocksSize", { "px": 0 }],
|
||||
"ctrl-alt-+": ["workspace::IncreaseOpenDocksSize", { "px": 0 }],
|
||||
"shift-find": "pane::DeploySearch",
|
||||
"ctrl-shift-f": "pane::DeploySearch",
|
||||
"ctrl-shift-h": ["pane::DeploySearch", { "replace_enabled": true }],
|
||||
@@ -605,9 +598,7 @@
|
||||
// "foo-bar": ["task::Spawn", { "task_name": "MyTask", "reveal_target": "dock" }]
|
||||
// or by tag:
|
||||
// "foo-bar": ["task::Spawn", { "task_tag": "MyTag" }],
|
||||
"f5": "debugger::Rerun",
|
||||
"ctrl-f4": "workspace::CloseActiveDock",
|
||||
"ctrl-w": "workspace::CloseActiveDock"
|
||||
"f5": "debugger::RerunLastSession"
|
||||
}
|
||||
},
|
||||
{
|
||||
@@ -710,13 +701,6 @@
|
||||
"pagedown": "editor::ContextMenuLast"
|
||||
}
|
||||
},
|
||||
{
|
||||
"context": "Editor && showing_signature_help && !showing_completions",
|
||||
"bindings": {
|
||||
"up": "editor::SignatureHelpPrevious",
|
||||
"down": "editor::SignatureHelpNext"
|
||||
}
|
||||
},
|
||||
// Custom bindings
|
||||
{
|
||||
"bindings": {
|
||||
@@ -1084,13 +1068,6 @@
|
||||
"ctrl-shift-tab": "pane::ActivatePreviousItem"
|
||||
}
|
||||
},
|
||||
{
|
||||
"context": "MarkdownPreview",
|
||||
"bindings": {
|
||||
"pageup": "markdown::MovePageUp",
|
||||
"pagedown": "markdown::MovePageDown"
|
||||
}
|
||||
},
|
||||
{
|
||||
"context": "KeymapEditor",
|
||||
"use_key_equivalents": true,
|
||||
|
||||
@@ -5,10 +5,10 @@
|
||||
"bindings": {
|
||||
"f4": "debugger::Start",
|
||||
"shift-f5": "debugger::Stop",
|
||||
"shift-cmd-f5": "debugger::RerunSession",
|
||||
"shift-cmd-f5": "debugger::Restart",
|
||||
"f6": "debugger::Pause",
|
||||
"f7": "debugger::StepOver",
|
||||
"ctrl-f11": "debugger::StepInto",
|
||||
"f11": "debugger::StepInto",
|
||||
"shift-f11": "debugger::StepOut",
|
||||
"home": "menu::SelectFirst",
|
||||
"shift-pageup": "menu::SelectFirst",
|
||||
@@ -624,13 +624,6 @@
|
||||
"cmd-r": "workspace::ToggleRightDock",
|
||||
"cmd-j": "workspace::ToggleBottomDock",
|
||||
"alt-cmd-y": "workspace::CloseAllDocks",
|
||||
// For 0px parameter, uses UI font size value.
|
||||
"ctrl-alt-0": "workspace::ResetActiveDockSize",
|
||||
"ctrl-alt--": ["workspace::DecreaseActiveDockSize", { "px": 0 }],
|
||||
"ctrl-alt-=": ["workspace::IncreaseActiveDockSize", { "px": 0 }],
|
||||
"ctrl-alt-)": "workspace::ResetOpenDocksSize",
|
||||
"ctrl-alt-_": ["workspace::DecreaseOpenDocksSize", { "px": 0 }],
|
||||
"ctrl-alt-+": ["workspace::IncreaseOpenDocksSize", { "px": 0 }],
|
||||
"cmd-shift-f": "pane::DeploySearch",
|
||||
"cmd-shift-h": ["pane::DeploySearch", { "replace_enabled": true }],
|
||||
"cmd-shift-t": "pane::ReopenClosedItem",
|
||||
@@ -659,8 +652,7 @@
|
||||
"cmd-k shift-up": "workspace::SwapPaneUp",
|
||||
"cmd-k shift-down": "workspace::SwapPaneDown",
|
||||
"cmd-shift-x": "zed::Extensions",
|
||||
"f5": "debugger::Rerun",
|
||||
"cmd-w": "workspace::CloseActiveDock"
|
||||
"f5": "debugger::RerunLastSession"
|
||||
}
|
||||
},
|
||||
{
|
||||
@@ -774,13 +766,6 @@
|
||||
"pagedown": "editor::ContextMenuLast"
|
||||
}
|
||||
},
|
||||
{
|
||||
"context": "Editor && showing_signature_help && !showing_completions",
|
||||
"bindings": {
|
||||
"up": "editor::SignatureHelpPrevious",
|
||||
"down": "editor::SignatureHelpNext"
|
||||
}
|
||||
},
|
||||
// Custom bindings
|
||||
{
|
||||
"use_key_equivalents": true,
|
||||
@@ -1183,13 +1168,6 @@
|
||||
"ctrl-shift-tab": "pane::ActivatePreviousItem"
|
||||
}
|
||||
},
|
||||
{
|
||||
"context": "MarkdownPreview",
|
||||
"bindings": {
|
||||
"pageup": "markdown::MovePageUp",
|
||||
"pagedown": "markdown::MovePageDown"
|
||||
}
|
||||
},
|
||||
{
|
||||
"context": "KeymapEditor",
|
||||
"use_key_equivalents": true,
|
||||
|
||||
@@ -98,13 +98,6 @@
|
||||
"ctrl-n": "editor::ContextMenuNext"
|
||||
}
|
||||
},
|
||||
{
|
||||
"context": "Editor && showing_signature_help && !showing_completions",
|
||||
"bindings": {
|
||||
"ctrl-p": "editor::SignatureHelpPrevious",
|
||||
"ctrl-n": "editor::SignatureHelpNext"
|
||||
}
|
||||
},
|
||||
{
|
||||
"context": "Workspace",
|
||||
"bindings": {
|
||||
|
||||
@@ -98,13 +98,6 @@
|
||||
"ctrl-n": "editor::ContextMenuNext"
|
||||
}
|
||||
},
|
||||
{
|
||||
"context": "Editor && showing_signature_help && !showing_completions",
|
||||
"bindings": {
|
||||
"ctrl-p": "editor::SignatureHelpPrevious",
|
||||
"ctrl-n": "editor::SignatureHelpNext"
|
||||
}
|
||||
},
|
||||
{
|
||||
"context": "Workspace",
|
||||
"bindings": {
|
||||
|
||||
@@ -477,13 +477,6 @@
|
||||
"ctrl-n": "editor::ShowWordCompletions"
|
||||
}
|
||||
},
|
||||
{
|
||||
"context": "vim_mode == insert && showing_signature_help && !showing_completions",
|
||||
"bindings": {
|
||||
"ctrl-p": "editor::SignatureHelpPrevious",
|
||||
"ctrl-n": "editor::SignatureHelpNext"
|
||||
}
|
||||
},
|
||||
{
|
||||
"context": "vim_mode == replace",
|
||||
"bindings": {
|
||||
|
||||
@@ -746,6 +746,8 @@
|
||||
"default_width": 380
|
||||
},
|
||||
"agent": {
|
||||
// Version of this setting.
|
||||
"version": "2",
|
||||
// Whether the agent is enabled.
|
||||
"enabled": true,
|
||||
/// What completion mode to start new threads in, if available. Can be 'normal' or 'burn'.
|
||||
@@ -1290,8 +1292,6 @@
|
||||
// Whether or not selecting text in the terminal will automatically
|
||||
// copy to the system clipboard.
|
||||
"copy_on_select": false,
|
||||
// Whether to keep the text selection after copying it to the clipboard
|
||||
"keep_selection_on_copy": false,
|
||||
// Whether to show the terminal button in the status bar
|
||||
"button": true,
|
||||
// Any key-value pairs added to this list will be added to the terminal's
|
||||
@@ -1656,6 +1656,7 @@
|
||||
// Different settings for specific language models.
|
||||
"language_models": {
|
||||
"anthropic": {
|
||||
"version": "1",
|
||||
"api_url": "https://api.anthropic.com"
|
||||
},
|
||||
"google": {
|
||||
@@ -1665,6 +1666,7 @@
|
||||
"api_url": "http://localhost:11434"
|
||||
},
|
||||
"openai": {
|
||||
"version": "1",
|
||||
"api_url": "https://api.openai.com/v1"
|
||||
},
|
||||
"open_router": {
|
||||
@@ -1782,8 +1784,7 @@
|
||||
// `socks5h`. `http` will be used when no scheme is specified.
|
||||
//
|
||||
// By default no proxy will be used, or Zed will try get proxy settings from
|
||||
// environment variables. If certain hosts should not be proxied,
|
||||
// set the `no_proxy` environment variable and provide a comma-separated list.
|
||||
// environment variables.
|
||||
//
|
||||
// Examples:
|
||||
// - "proxy": "socks5h://localhost:10808"
|
||||
|
||||
48
crates/acp/Cargo.toml
Normal file
48
crates/acp/Cargo.toml
Normal file
@@ -0,0 +1,48 @@
|
||||
[package]
|
||||
name = "acp"
|
||||
version = "0.1.0"
|
||||
edition.workspace = true
|
||||
publish.workspace = true
|
||||
license = "GPL-3.0-or-later"
|
||||
|
||||
[lints]
|
||||
workspace = true
|
||||
|
||||
[lib]
|
||||
path = "src/acp.rs"
|
||||
doctest = false
|
||||
|
||||
[features]
|
||||
test-support = ["gpui/test-support", "project/test-support"]
|
||||
|
||||
[dependencies]
|
||||
agentic-coding-protocol = { path = "../../../agentic-coding-protocol" }
|
||||
anyhow.workspace = true
|
||||
async-trait.workspace = true
|
||||
base64.workspace = true
|
||||
chrono.workspace = true
|
||||
collections.workspace = true
|
||||
editor.workspace = true
|
||||
futures.workspace = true
|
||||
gpui.workspace = true
|
||||
language.workspace = true
|
||||
log.workspace = true
|
||||
markdown.workspace = true
|
||||
parking_lot.workspace = true
|
||||
project.workspace = true
|
||||
settings.workspace = true
|
||||
smol.workspace = true
|
||||
theme.workspace = true
|
||||
ui.workspace = true
|
||||
util.workspace = true
|
||||
uuid.workspace = true
|
||||
workspace-hack.workspace = true
|
||||
zed_actions.workspace = true
|
||||
|
||||
[dev-dependencies]
|
||||
env_logger.workspace = true
|
||||
gpui = { workspace = true, "features" = ["test-support"] }
|
||||
project = { workspace = true, "features" = ["test-support"] }
|
||||
serde_json.workspace = true
|
||||
util.workspace = true
|
||||
settings.workspace = true
|
||||
1
crates/acp/LICENSE-GPL
Symbolic link
1
crates/acp/LICENSE-GPL
Symbolic link
@@ -0,0 +1 @@
|
||||
../../LICENSE-GPL
|
||||
677
crates/acp/src/acp.rs
Normal file
677
crates/acp/src/acp.rs
Normal file
@@ -0,0 +1,677 @@
|
||||
mod server;
|
||||
mod thread_view;
|
||||
|
||||
use agentic_coding_protocol::{self as acp, Role};
|
||||
use anyhow::{Context as _, Result};
|
||||
use chrono::{DateTime, Utc};
|
||||
use futures::channel::oneshot;
|
||||
use gpui::{AppContext, Context, Entity, EventEmitter, SharedString, Task};
|
||||
use language::LanguageRegistry;
|
||||
use markdown::Markdown;
|
||||
use project::Project;
|
||||
use std::{mem, ops::Range, path::PathBuf, sync::Arc};
|
||||
use ui::{App, IconName};
|
||||
use util::{ResultExt, debug_panic};
|
||||
|
||||
pub use server::AcpServer;
|
||||
pub use thread_view::AcpThreadView;
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
|
||||
pub struct ThreadId(SharedString);
|
||||
|
||||
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
|
||||
pub struct FileVersion(u64);
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct AgentThreadSummary {
|
||||
pub id: ThreadId,
|
||||
pub title: String,
|
||||
pub created_at: DateTime<Utc>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, PartialEq, Eq)]
|
||||
pub struct FileContent {
|
||||
pub path: PathBuf,
|
||||
pub version: FileVersion,
|
||||
pub content: SharedString,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Eq, PartialEq)]
|
||||
pub struct Message {
|
||||
pub role: acp::Role,
|
||||
pub chunks: Vec<MessageChunk>,
|
||||
}
|
||||
|
||||
impl Message {
|
||||
fn into_acp(self, cx: &App) -> acp::Message {
|
||||
acp::Message {
|
||||
role: self.role,
|
||||
chunks: self
|
||||
.chunks
|
||||
.into_iter()
|
||||
.map(|chunk| chunk.into_acp(cx))
|
||||
.collect(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Eq, PartialEq)]
|
||||
pub enum MessageChunk {
|
||||
Text {
|
||||
chunk: Entity<Markdown>,
|
||||
},
|
||||
File {
|
||||
content: FileContent,
|
||||
},
|
||||
Directory {
|
||||
path: PathBuf,
|
||||
contents: Vec<FileContent>,
|
||||
},
|
||||
Symbol {
|
||||
path: PathBuf,
|
||||
range: Range<u64>,
|
||||
version: FileVersion,
|
||||
name: SharedString,
|
||||
content: SharedString,
|
||||
},
|
||||
Fetch {
|
||||
url: SharedString,
|
||||
content: SharedString,
|
||||
},
|
||||
}
|
||||
|
||||
impl MessageChunk {
|
||||
pub fn from_acp(
|
||||
chunk: acp::MessageChunk,
|
||||
language_registry: Arc<LanguageRegistry>,
|
||||
cx: &mut App,
|
||||
) -> Self {
|
||||
match chunk {
|
||||
acp::MessageChunk::Text { chunk } => MessageChunk::Text {
|
||||
chunk: cx.new(|cx| Markdown::new(chunk.into(), Some(language_registry), None, cx)),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
pub fn into_acp(self, cx: &App) -> acp::MessageChunk {
|
||||
match self {
|
||||
MessageChunk::Text { chunk } => acp::MessageChunk::Text {
|
||||
chunk: chunk.read(cx).source().to_string(),
|
||||
},
|
||||
MessageChunk::File { .. } => todo!(),
|
||||
MessageChunk::Directory { .. } => todo!(),
|
||||
MessageChunk::Symbol { .. } => todo!(),
|
||||
MessageChunk::Fetch { .. } => todo!(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn from_str(chunk: &str, language_registry: Arc<LanguageRegistry>, cx: &mut App) -> Self {
|
||||
MessageChunk::Text {
|
||||
chunk: cx.new(|cx| {
|
||||
Markdown::new(chunk.to_owned().into(), Some(language_registry), None, cx)
|
||||
}),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum AgentThreadEntryContent {
|
||||
Message(Message),
|
||||
ToolCall(ToolCall),
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct ToolCall {
|
||||
id: ToolCallId,
|
||||
label: Entity<Markdown>,
|
||||
icon: IconName,
|
||||
status: ToolCallStatus,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum ToolCallStatus {
|
||||
WaitingForConfirmation {
|
||||
confirmation: acp::ToolCallConfirmation,
|
||||
respond_tx: oneshot::Sender<acp::ToolCallConfirmationOutcome>,
|
||||
},
|
||||
// todo! Running?
|
||||
Allowed {
|
||||
// todo! should this be variants in crate::ToolCallStatus instead?
|
||||
status: acp::ToolCallStatus,
|
||||
content: Option<Entity<Markdown>>,
|
||||
},
|
||||
Rejected,
|
||||
}
|
||||
|
||||
/// A `ThreadEntryId` that is known to be a ToolCall
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
|
||||
pub struct ToolCallId(ThreadEntryId);
|
||||
|
||||
impl ToolCallId {
|
||||
pub fn as_u64(&self) -> u64 {
|
||||
self.0.0
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
|
||||
pub struct ThreadEntryId(pub u64);
|
||||
|
||||
impl ThreadEntryId {
|
||||
pub fn post_inc(&mut self) -> Self {
|
||||
let id = *self;
|
||||
self.0 += 1;
|
||||
id
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct ThreadEntry {
|
||||
pub id: ThreadEntryId,
|
||||
pub content: AgentThreadEntryContent,
|
||||
}
|
||||
|
||||
pub struct AcpThread {
|
||||
id: ThreadId,
|
||||
next_entry_id: ThreadEntryId,
|
||||
entries: Vec<ThreadEntry>,
|
||||
server: Arc<AcpServer>,
|
||||
title: SharedString,
|
||||
project: Entity<Project>,
|
||||
}
|
||||
|
||||
enum AcpThreadEvent {
|
||||
NewEntry,
|
||||
EntryUpdated(usize),
|
||||
}
|
||||
|
||||
impl EventEmitter<AcpThreadEvent> for AcpThread {}
|
||||
|
||||
impl AcpThread {
|
||||
pub fn new(
|
||||
server: Arc<AcpServer>,
|
||||
thread_id: ThreadId,
|
||||
entries: Vec<AgentThreadEntryContent>,
|
||||
project: Entity<Project>,
|
||||
_: &mut Context<Self>,
|
||||
) -> Self {
|
||||
let mut next_entry_id = ThreadEntryId(0);
|
||||
Self {
|
||||
title: "A new agent2 thread".into(),
|
||||
entries: entries
|
||||
.into_iter()
|
||||
.map(|entry| ThreadEntry {
|
||||
id: next_entry_id.post_inc(),
|
||||
content: entry,
|
||||
})
|
||||
.collect(),
|
||||
server,
|
||||
id: thread_id,
|
||||
next_entry_id,
|
||||
project,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn title(&self) -> SharedString {
|
||||
self.title.clone()
|
||||
}
|
||||
|
||||
pub fn entries(&self) -> &[ThreadEntry] {
|
||||
&self.entries
|
||||
}
|
||||
|
||||
pub fn push_entry(
|
||||
&mut self,
|
||||
entry: AgentThreadEntryContent,
|
||||
cx: &mut Context<Self>,
|
||||
) -> ThreadEntryId {
|
||||
let id = self.next_entry_id.post_inc();
|
||||
self.entries.push(ThreadEntry { id, content: entry });
|
||||
cx.emit(AcpThreadEvent::NewEntry);
|
||||
id
|
||||
}
|
||||
|
||||
pub fn push_assistant_chunk(&mut self, chunk: acp::MessageChunk, cx: &mut Context<Self>) {
|
||||
let entries_len = self.entries.len();
|
||||
if let Some(last_entry) = self.entries.last_mut()
|
||||
&& let AgentThreadEntryContent::Message(Message {
|
||||
ref mut chunks,
|
||||
role: Role::Assistant,
|
||||
}) = last_entry.content
|
||||
{
|
||||
cx.emit(AcpThreadEvent::EntryUpdated(entries_len - 1));
|
||||
|
||||
if let (
|
||||
Some(MessageChunk::Text { chunk: old_chunk }),
|
||||
acp::MessageChunk::Text { chunk: new_chunk },
|
||||
) = (chunks.last_mut(), &chunk)
|
||||
{
|
||||
old_chunk.update(cx, |old_chunk, cx| {
|
||||
old_chunk.append(&new_chunk, cx);
|
||||
});
|
||||
} else {
|
||||
chunks.push(MessageChunk::from_acp(
|
||||
chunk,
|
||||
self.project.read(cx).languages().clone(),
|
||||
cx,
|
||||
));
|
||||
}
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
let chunk = MessageChunk::from_acp(chunk, self.project.read(cx).languages().clone(), cx);
|
||||
|
||||
self.push_entry(
|
||||
AgentThreadEntryContent::Message(Message {
|
||||
role: Role::Assistant,
|
||||
chunks: vec![chunk],
|
||||
}),
|
||||
cx,
|
||||
);
|
||||
}
|
||||
|
||||
pub fn request_tool_call(
|
||||
&mut self,
|
||||
label: String,
|
||||
icon: acp::Icon,
|
||||
confirmation: acp::ToolCallConfirmation,
|
||||
cx: &mut Context<Self>,
|
||||
) -> ToolCallRequest {
|
||||
let (tx, rx) = oneshot::channel();
|
||||
|
||||
let status = ToolCallStatus::WaitingForConfirmation {
|
||||
confirmation,
|
||||
respond_tx: tx,
|
||||
};
|
||||
|
||||
let id = self.insert_tool_call(label, status, icon, cx);
|
||||
ToolCallRequest { id, outcome: rx }
|
||||
}
|
||||
|
||||
pub fn push_tool_call(
|
||||
&mut self,
|
||||
label: String,
|
||||
icon: acp::Icon,
|
||||
cx: &mut Context<Self>,
|
||||
) -> ToolCallId {
|
||||
let status = ToolCallStatus::Allowed {
|
||||
status: acp::ToolCallStatus::Running,
|
||||
content: None,
|
||||
};
|
||||
|
||||
self.insert_tool_call(label, status, icon, cx)
|
||||
}
|
||||
|
||||
fn insert_tool_call(
|
||||
&mut self,
|
||||
label: String,
|
||||
status: ToolCallStatus,
|
||||
icon: acp::Icon,
|
||||
cx: &mut Context<Self>,
|
||||
) -> ToolCallId {
|
||||
let language_registry = self.project.read(cx).languages().clone();
|
||||
|
||||
let entry_id = self.push_entry(
|
||||
AgentThreadEntryContent::ToolCall(ToolCall {
|
||||
// todo! clean up id creation
|
||||
id: ToolCallId(ThreadEntryId(self.entries.len() as u64)),
|
||||
label: cx.new(|cx| {
|
||||
Markdown::new(label.into(), Some(language_registry.clone()), None, cx)
|
||||
}),
|
||||
icon: acp_icon_to_ui_icon(icon),
|
||||
status,
|
||||
}),
|
||||
cx,
|
||||
);
|
||||
|
||||
ToolCallId(entry_id)
|
||||
}
|
||||
|
||||
pub fn authorize_tool_call(
|
||||
&mut self,
|
||||
id: ToolCallId,
|
||||
outcome: acp::ToolCallConfirmationOutcome,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
let Some(entry) = self.entry_mut(id.0) else {
|
||||
return;
|
||||
};
|
||||
|
||||
let AgentThreadEntryContent::ToolCall(call) = &mut entry.content else {
|
||||
debug_panic!("expected ToolCall");
|
||||
return;
|
||||
};
|
||||
|
||||
let new_status = if outcome == acp::ToolCallConfirmationOutcome::Reject {
|
||||
ToolCallStatus::Rejected
|
||||
} else {
|
||||
ToolCallStatus::Allowed {
|
||||
status: acp::ToolCallStatus::Running,
|
||||
content: None,
|
||||
}
|
||||
};
|
||||
|
||||
let curr_status = mem::replace(&mut call.status, new_status);
|
||||
|
||||
if let ToolCallStatus::WaitingForConfirmation { respond_tx, .. } = curr_status {
|
||||
respond_tx.send(outcome).log_err();
|
||||
} else {
|
||||
debug_panic!("tried to authorize an already authorized tool call");
|
||||
}
|
||||
|
||||
cx.emit(AcpThreadEvent::EntryUpdated(id.as_u64() as usize));
|
||||
}
|
||||
|
||||
pub fn update_tool_call(
|
||||
&mut self,
|
||||
id: ToolCallId,
|
||||
new_status: acp::ToolCallStatus,
|
||||
new_content: Option<acp::ToolCallContent>,
|
||||
cx: &mut Context<Self>,
|
||||
) -> Result<()> {
|
||||
let language_registry = self.project.read(cx).languages().clone();
|
||||
let entry = self.entry_mut(id.0).context("Entry not found")?;
|
||||
|
||||
match &mut entry.content {
|
||||
AgentThreadEntryContent::ToolCall(call) => match &mut call.status {
|
||||
ToolCallStatus::Allowed { content, status } => {
|
||||
*content = new_content.map(|new_content| {
|
||||
let acp::ToolCallContent::Markdown { markdown } = new_content;
|
||||
|
||||
cx.new(|cx| {
|
||||
Markdown::new(markdown.into(), Some(language_registry), None, cx)
|
||||
})
|
||||
});
|
||||
|
||||
*status = new_status;
|
||||
}
|
||||
ToolCallStatus::WaitingForConfirmation { .. } => {
|
||||
anyhow::bail!("Tool call hasn't been authorized yet")
|
||||
}
|
||||
ToolCallStatus::Rejected => {
|
||||
anyhow::bail!("Tool call was rejected and therefore can't be updated")
|
||||
}
|
||||
},
|
||||
_ => anyhow::bail!("Entry is not a tool call"),
|
||||
}
|
||||
|
||||
cx.emit(AcpThreadEvent::EntryUpdated(id.as_u64() as usize));
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn entry_mut(&mut self, id: ThreadEntryId) -> Option<&mut ThreadEntry> {
|
||||
let entry = self.entries.get_mut(id.0 as usize);
|
||||
debug_assert!(
|
||||
entry.is_some(),
|
||||
"We shouldn't give out ids to entries that don't exist"
|
||||
);
|
||||
entry
|
||||
}
|
||||
|
||||
/// Returns true if the last turn is awaiting tool authorization
|
||||
pub fn waiting_for_tool_confirmation(&self) -> bool {
|
||||
for entry in self.entries.iter().rev() {
|
||||
match &entry.content {
|
||||
AgentThreadEntryContent::ToolCall(call) => match call.status {
|
||||
ToolCallStatus::WaitingForConfirmation { .. } => return true,
|
||||
ToolCallStatus::Allowed { .. } | ToolCallStatus::Rejected => continue,
|
||||
},
|
||||
AgentThreadEntryContent::Message(_) => {
|
||||
// Reached the beginning of the turn
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
false
|
||||
}
|
||||
|
||||
pub fn send(&mut self, message: &str, cx: &mut Context<Self>) -> Task<Result<()>> {
|
||||
let agent = self.server.clone();
|
||||
let id = self.id.clone();
|
||||
let chunk = MessageChunk::from_str(message, self.project.read(cx).languages().clone(), cx);
|
||||
let message = Message {
|
||||
role: Role::User,
|
||||
chunks: vec![chunk],
|
||||
};
|
||||
self.push_entry(AgentThreadEntryContent::Message(message.clone()), cx);
|
||||
let acp_message = message.into_acp(cx);
|
||||
cx.spawn(async move |_, cx| {
|
||||
agent.send_message(id, acp_message, cx).await?;
|
||||
Ok(())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
fn acp_icon_to_ui_icon(icon: acp::Icon) -> IconName {
|
||||
match icon {
|
||||
acp::Icon::FileSearch => IconName::FileSearch,
|
||||
acp::Icon::Folder => IconName::Folder,
|
||||
acp::Icon::Globe => IconName::Globe,
|
||||
acp::Icon::Hammer => IconName::Hammer,
|
||||
acp::Icon::LightBulb => IconName::LightBulb,
|
||||
acp::Icon::Pencil => IconName::Pencil,
|
||||
acp::Icon::Regex => IconName::Regex,
|
||||
acp::Icon::Terminal => IconName::Terminal,
|
||||
}
|
||||
}
|
||||
|
||||
pub struct ToolCallRequest {
|
||||
pub id: ToolCallId,
|
||||
pub outcome: oneshot::Receiver<acp::ToolCallConfirmationOutcome>,
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use futures::{FutureExt as _, channel::mpsc, select};
|
||||
use gpui::{AsyncApp, TestAppContext};
|
||||
use project::FakeFs;
|
||||
use serde_json::json;
|
||||
use settings::SettingsStore;
|
||||
use smol::stream::StreamExt as _;
|
||||
use std::{env, path::Path, process::Stdio, time::Duration};
|
||||
use util::path;
|
||||
|
||||
fn init_test(cx: &mut TestAppContext) {
|
||||
env_logger::try_init().ok();
|
||||
cx.update(|cx| {
|
||||
let settings_store = SettingsStore::test(cx);
|
||||
cx.set_global(settings_store);
|
||||
Project::init_settings(cx);
|
||||
language::init(cx);
|
||||
});
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_gemini_basic(cx: &mut TestAppContext) {
|
||||
init_test(cx);
|
||||
|
||||
cx.executor().allow_parking();
|
||||
|
||||
let fs = FakeFs::new(cx.executor());
|
||||
let project = Project::test(fs, [], cx).await;
|
||||
let server = gemini_acp_server(project.clone(), cx.to_async()).unwrap();
|
||||
let thread = server.create_thread(&mut cx.to_async()).await.unwrap();
|
||||
thread
|
||||
.update(cx, |thread, cx| thread.send("Hello from Zed!", cx))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
thread.read_with(cx, |thread, _| {
|
||||
assert_eq!(thread.entries.len(), 2);
|
||||
assert!(matches!(
|
||||
thread.entries[0].content,
|
||||
AgentThreadEntryContent::Message(Message {
|
||||
role: Role::User,
|
||||
..
|
||||
})
|
||||
));
|
||||
assert!(matches!(
|
||||
thread.entries[1].content,
|
||||
AgentThreadEntryContent::Message(Message {
|
||||
role: Role::Assistant,
|
||||
..
|
||||
})
|
||||
));
|
||||
});
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_gemini_tool_call(cx: &mut TestAppContext) {
|
||||
init_test(cx);
|
||||
|
||||
cx.executor().allow_parking();
|
||||
|
||||
let fs = FakeFs::new(cx.executor());
|
||||
fs.insert_tree(
|
||||
path!("/private/tmp"),
|
||||
json!({"foo": "Lorem ipsum dolor", "bar": "bar", "baz": "baz"}),
|
||||
)
|
||||
.await;
|
||||
let project = Project::test(fs, [path!("/private/tmp").as_ref()], cx).await;
|
||||
let server = gemini_acp_server(project.clone(), cx.to_async()).unwrap();
|
||||
let thread = server.create_thread(&mut cx.to_async()).await.unwrap();
|
||||
thread
|
||||
.update(cx, |thread, cx| {
|
||||
thread.send(
|
||||
"Read the '/private/tmp/foo' file and tell me what you see.",
|
||||
cx,
|
||||
)
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
thread.read_with(cx, |thread, _cx| {
|
||||
assert!(matches!(
|
||||
&thread.entries()[1].content,
|
||||
AgentThreadEntryContent::ToolCall(ToolCall {
|
||||
status: ToolCallStatus::Allowed { .. },
|
||||
..
|
||||
})
|
||||
));
|
||||
|
||||
assert!(matches!(
|
||||
thread.entries[2].content,
|
||||
AgentThreadEntryContent::Message(Message {
|
||||
role: Role::Assistant,
|
||||
..
|
||||
})
|
||||
));
|
||||
});
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_gemini_tool_call_with_confirmation(cx: &mut TestAppContext) {
|
||||
init_test(cx);
|
||||
|
||||
cx.executor().allow_parking();
|
||||
|
||||
let fs = FakeFs::new(cx.executor());
|
||||
let project = Project::test(fs, [path!("/private/tmp").as_ref()], cx).await;
|
||||
let server = gemini_acp_server(project.clone(), cx.to_async()).unwrap();
|
||||
let thread = server.create_thread(&mut cx.to_async()).await.unwrap();
|
||||
let full_turn = thread.update(cx, |thread, cx| {
|
||||
thread.send(r#"Run `echo "Hello, world!"`"#, cx)
|
||||
});
|
||||
|
||||
run_until_tool_call(&thread, cx).await;
|
||||
|
||||
let tool_call_id = thread.read_with(cx, |thread, _cx| {
|
||||
let AgentThreadEntryContent::ToolCall(ToolCall {
|
||||
id,
|
||||
status:
|
||||
ToolCallStatus::WaitingForConfirmation {
|
||||
confirmation: acp::ToolCallConfirmation::Execute { root_command, .. },
|
||||
..
|
||||
},
|
||||
..
|
||||
}) = &thread.entries()[1].content
|
||||
else {
|
||||
panic!();
|
||||
};
|
||||
|
||||
assert_eq!(root_command, "echo");
|
||||
|
||||
*id
|
||||
});
|
||||
|
||||
thread.update(cx, |thread, cx| {
|
||||
thread.authorize_tool_call(tool_call_id, acp::ToolCallConfirmationOutcome::Allow, cx);
|
||||
|
||||
assert!(matches!(
|
||||
&thread.entries()[1].content,
|
||||
AgentThreadEntryContent::ToolCall(ToolCall {
|
||||
status: ToolCallStatus::Allowed { .. },
|
||||
..
|
||||
})
|
||||
));
|
||||
});
|
||||
|
||||
full_turn.await.unwrap();
|
||||
|
||||
thread.read_with(cx, |thread, cx| {
|
||||
let AgentThreadEntryContent::ToolCall(ToolCall {
|
||||
status: ToolCallStatus::Allowed { content, .. },
|
||||
..
|
||||
}) = &thread.entries()[1].content
|
||||
else {
|
||||
panic!();
|
||||
};
|
||||
|
||||
content.as_ref().unwrap().read_with(cx, |md, _cx| {
|
||||
assert!(
|
||||
md.source().contains("Hello, world!"),
|
||||
r#"Expected '{}' to contain "Hello, world!""#,
|
||||
md.source()
|
||||
);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
async fn run_until_tool_call(thread: &Entity<AcpThread>, cx: &mut TestAppContext) {
|
||||
let (mut tx, mut rx) = mpsc::channel::<()>(1);
|
||||
|
||||
let subscription = cx.update(|cx| {
|
||||
cx.subscribe(thread, move |thread, _, cx| {
|
||||
if thread
|
||||
.read(cx)
|
||||
.entries
|
||||
.iter()
|
||||
.any(|e| matches!(e.content, AgentThreadEntryContent::ToolCall(_)))
|
||||
{
|
||||
tx.try_send(()).unwrap();
|
||||
}
|
||||
})
|
||||
});
|
||||
|
||||
select! {
|
||||
_ = futures::FutureExt::fuse(smol::Timer::after(Duration::from_secs(10))) => {
|
||||
panic!("Timeout waiting for tool call")
|
||||
}
|
||||
_ = rx.next().fuse() => {
|
||||
drop(subscription);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn gemini_acp_server(project: Entity<Project>, mut cx: AsyncApp) -> Result<Arc<AcpServer>> {
|
||||
let cli_path =
|
||||
Path::new(env!("CARGO_MANIFEST_DIR")).join("../../../gemini-cli/packages/cli");
|
||||
let mut command = util::command::new_smol_command("node");
|
||||
command
|
||||
.arg(cli_path)
|
||||
.arg("--acp")
|
||||
.current_dir("/private/tmp")
|
||||
.stdin(Stdio::piped())
|
||||
.stdout(Stdio::piped())
|
||||
.stderr(Stdio::inherit())
|
||||
.kill_on_drop(true);
|
||||
|
||||
if let Ok(gemini_key) = std::env::var("GEMINI_API_KEY") {
|
||||
command.env("GEMINI_API_KEY", gemini_key);
|
||||
}
|
||||
|
||||
let child = command.spawn().unwrap();
|
||||
|
||||
Ok(AcpServer::stdio(child, project, &mut cx))
|
||||
}
|
||||
}
|
||||
322
crates/acp/src/server.rs
Normal file
322
crates/acp/src/server.rs
Normal file
@@ -0,0 +1,322 @@
|
||||
use crate::{AcpThread, ThreadEntryId, ThreadId, ToolCallId, ToolCallRequest};
|
||||
use agentic_coding_protocol as acp;
|
||||
use anyhow::{Context as _, Result};
|
||||
use async_trait::async_trait;
|
||||
use collections::HashMap;
|
||||
use gpui::{App, AppContext, AsyncApp, Context, Entity, Task, WeakEntity};
|
||||
use parking_lot::Mutex;
|
||||
use project::Project;
|
||||
use smol::process::Child;
|
||||
use std::{io::Write as _, path::Path, sync::Arc};
|
||||
use util::ResultExt;
|
||||
|
||||
pub struct AcpServer {
|
||||
connection: Arc<acp::AgentConnection>,
|
||||
threads: Arc<Mutex<HashMap<ThreadId, WeakEntity<AcpThread>>>>,
|
||||
project: Entity<Project>,
|
||||
_handler_task: Task<()>,
|
||||
_io_task: Task<()>,
|
||||
}
|
||||
|
||||
struct AcpClientDelegate {
|
||||
project: Entity<Project>,
|
||||
threads: Arc<Mutex<HashMap<ThreadId, WeakEntity<AcpThread>>>>,
|
||||
cx: AsyncApp,
|
||||
// sent_buffer_versions: HashMap<Entity<Buffer>, HashMap<u64, BufferSnapshot>>,
|
||||
}
|
||||
|
||||
impl AcpClientDelegate {
|
||||
fn new(
|
||||
project: Entity<Project>,
|
||||
threads: Arc<Mutex<HashMap<ThreadId, WeakEntity<AcpThread>>>>,
|
||||
cx: AsyncApp,
|
||||
) -> Self {
|
||||
Self {
|
||||
project,
|
||||
threads,
|
||||
cx: cx,
|
||||
}
|
||||
}
|
||||
|
||||
fn update_thread<R>(
|
||||
&self,
|
||||
thread_id: &ThreadId,
|
||||
cx: &mut App,
|
||||
callback: impl FnOnce(&mut AcpThread, &mut Context<AcpThread>) -> R,
|
||||
) -> Option<R> {
|
||||
let thread = self.threads.lock().get(&thread_id)?.clone();
|
||||
let Some(thread) = thread.upgrade() else {
|
||||
self.threads.lock().remove(&thread_id);
|
||||
return None;
|
||||
};
|
||||
Some(thread.update(cx, callback))
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait(?Send)]
|
||||
impl acp::Client for AcpClientDelegate {
|
||||
async fn stat(&self, params: acp::StatParams) -> Result<acp::StatResponse> {
|
||||
let cx = &mut self.cx.clone();
|
||||
self.project.update(cx, |project, cx| {
|
||||
let path = project
|
||||
.project_path_for_absolute_path(Path::new(¶ms.path), cx)
|
||||
.context("Failed to get project path")?;
|
||||
|
||||
match project.entry_for_path(&path, cx) {
|
||||
// todo! refresh entry?
|
||||
None => Ok(acp::StatResponse {
|
||||
exists: false,
|
||||
is_directory: false,
|
||||
}),
|
||||
Some(entry) => Ok(acp::StatResponse {
|
||||
exists: entry.is_created(),
|
||||
is_directory: entry.is_dir(),
|
||||
}),
|
||||
}
|
||||
})?
|
||||
}
|
||||
|
||||
async fn stream_message_chunk(
|
||||
&self,
|
||||
params: acp::StreamMessageChunkParams,
|
||||
) -> Result<acp::StreamMessageChunkResponse> {
|
||||
let cx = &mut self.cx.clone();
|
||||
|
||||
cx.update(|cx| {
|
||||
self.update_thread(¶ms.thread_id.into(), cx, |thread, cx| {
|
||||
thread.push_assistant_chunk(params.chunk, cx)
|
||||
});
|
||||
})?;
|
||||
|
||||
Ok(acp::StreamMessageChunkResponse)
|
||||
}
|
||||
|
||||
async fn read_text_file(
|
||||
&self,
|
||||
request: acp::ReadTextFileParams,
|
||||
) -> Result<acp::ReadTextFileResponse> {
|
||||
let cx = &mut self.cx.clone();
|
||||
let buffer = self
|
||||
.project
|
||||
.update(cx, |project, cx| {
|
||||
let path = project
|
||||
.project_path_for_absolute_path(Path::new(&request.path), cx)
|
||||
.context("Failed to get project path")?;
|
||||
anyhow::Ok(project.open_buffer(path, cx))
|
||||
})??
|
||||
.await?;
|
||||
|
||||
buffer.update(cx, |buffer, _cx| {
|
||||
let start = language::Point::new(request.line_offset.unwrap_or(0), 0);
|
||||
let end = match request.line_limit {
|
||||
None => buffer.max_point(),
|
||||
Some(limit) => start + language::Point::new(limit + 1, 0),
|
||||
};
|
||||
|
||||
let content: String = buffer.text_for_range(start..end).collect();
|
||||
|
||||
acp::ReadTextFileResponse {
|
||||
content,
|
||||
version: acp::FileVersion(0),
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
async fn read_binary_file(
|
||||
&self,
|
||||
request: acp::ReadBinaryFileParams,
|
||||
) -> Result<acp::ReadBinaryFileResponse> {
|
||||
let cx = &mut self.cx.clone();
|
||||
let file = self
|
||||
.project
|
||||
.update(cx, |project, cx| {
|
||||
let (worktree, path) = project
|
||||
.find_worktree(Path::new(&request.path), cx)
|
||||
.context("Failed to get project path")?;
|
||||
|
||||
let task = worktree.update(cx, |worktree, cx| worktree.load_binary_file(&path, cx));
|
||||
anyhow::Ok(task)
|
||||
})??
|
||||
.await?;
|
||||
|
||||
// todo! test
|
||||
let content = cx
|
||||
.background_spawn(async move {
|
||||
let start = request.byte_offset.unwrap_or(0) as usize;
|
||||
let end = request
|
||||
.byte_limit
|
||||
.map(|limit| (start + limit as usize).min(file.content.len()))
|
||||
.unwrap_or(file.content.len());
|
||||
|
||||
let range_content = &file.content[start..end];
|
||||
|
||||
let mut base64_content = Vec::new();
|
||||
let mut base64_encoder = base64::write::EncoderWriter::new(
|
||||
std::io::Cursor::new(&mut base64_content),
|
||||
&base64::engine::general_purpose::STANDARD,
|
||||
);
|
||||
base64_encoder.write_all(range_content)?;
|
||||
drop(base64_encoder);
|
||||
|
||||
// SAFETY: The base64 encoder should not produce non-UTF8.
|
||||
unsafe { anyhow::Ok(String::from_utf8_unchecked(base64_content)) }
|
||||
})
|
||||
.await?;
|
||||
|
||||
Ok(acp::ReadBinaryFileResponse {
|
||||
content,
|
||||
// todo!
|
||||
version: acp::FileVersion(0),
|
||||
})
|
||||
}
|
||||
|
||||
async fn glob_search(
|
||||
&self,
|
||||
_request: acp::GlobSearchParams,
|
||||
) -> Result<acp::GlobSearchResponse> {
|
||||
todo!()
|
||||
}
|
||||
|
||||
async fn request_tool_call_confirmation(
|
||||
&self,
|
||||
request: acp::RequestToolCallConfirmationParams,
|
||||
) -> Result<acp::RequestToolCallConfirmationResponse> {
|
||||
let cx = &mut self.cx.clone();
|
||||
let ToolCallRequest { id, outcome } = cx
|
||||
.update(|cx| {
|
||||
self.update_thread(&request.thread_id.into(), cx, |thread, cx| {
|
||||
thread.request_tool_call(request.label, request.icon, request.confirmation, cx)
|
||||
})
|
||||
})?
|
||||
.context("Failed to update thread")?;
|
||||
|
||||
Ok(acp::RequestToolCallConfirmationResponse {
|
||||
id: id.into(),
|
||||
outcome: outcome.await?,
|
||||
})
|
||||
}
|
||||
|
||||
async fn push_tool_call(
|
||||
&self,
|
||||
request: acp::PushToolCallParams,
|
||||
) -> Result<acp::PushToolCallResponse> {
|
||||
let cx = &mut self.cx.clone();
|
||||
let entry_id = cx
|
||||
.update(|cx| {
|
||||
self.update_thread(&request.thread_id.into(), cx, |thread, cx| {
|
||||
thread.push_tool_call(request.label, request.icon, cx)
|
||||
})
|
||||
})?
|
||||
.context("Failed to update thread")?;
|
||||
|
||||
Ok(acp::PushToolCallResponse {
|
||||
id: entry_id.into(),
|
||||
})
|
||||
}
|
||||
|
||||
async fn update_tool_call(
|
||||
&self,
|
||||
request: acp::UpdateToolCallParams,
|
||||
) -> Result<acp::UpdateToolCallResponse> {
|
||||
let cx = &mut self.cx.clone();
|
||||
|
||||
cx.update(|cx| {
|
||||
self.update_thread(&request.thread_id.into(), cx, |thread, cx| {
|
||||
thread.update_tool_call(
|
||||
request.tool_call_id.into(),
|
||||
request.status,
|
||||
request.content,
|
||||
cx,
|
||||
)
|
||||
})
|
||||
})?
|
||||
.context("Failed to update thread")??;
|
||||
|
||||
Ok(acp::UpdateToolCallResponse)
|
||||
}
|
||||
}
|
||||
|
||||
impl AcpServer {
|
||||
pub fn stdio(mut process: Child, project: Entity<Project>, cx: &mut AsyncApp) -> Arc<Self> {
|
||||
let stdin = process.stdin.take().expect("process didn't have stdin");
|
||||
let stdout = process.stdout.take().expect("process didn't have stdout");
|
||||
|
||||
let threads: Arc<Mutex<HashMap<ThreadId, WeakEntity<AcpThread>>>> = Default::default();
|
||||
let (connection, handler_fut, io_fut) = acp::AgentConnection::connect_to_agent(
|
||||
AcpClientDelegate::new(project.clone(), threads.clone(), cx.clone()),
|
||||
stdin,
|
||||
stdout,
|
||||
);
|
||||
|
||||
let io_task = cx.background_spawn(async move {
|
||||
io_fut.await.log_err();
|
||||
process.status().await.log_err();
|
||||
});
|
||||
|
||||
Arc::new(Self {
|
||||
project,
|
||||
connection: Arc::new(connection),
|
||||
threads,
|
||||
_handler_task: cx.foreground_executor().spawn(handler_fut),
|
||||
_io_task: io_task,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl AcpServer {
|
||||
pub async fn create_thread(self: Arc<Self>, cx: &mut AsyncApp) -> Result<Entity<AcpThread>> {
|
||||
let response = self.connection.request(acp::CreateThreadParams).await?;
|
||||
let thread_id: ThreadId = response.thread_id.into();
|
||||
let server = self.clone();
|
||||
let thread = cx.new(|_| AcpThread {
|
||||
// todo!
|
||||
title: "ACP Thread".into(),
|
||||
id: thread_id.clone(),
|
||||
next_entry_id: ThreadEntryId(0),
|
||||
entries: Vec::default(),
|
||||
project: self.project.clone(),
|
||||
server,
|
||||
})?;
|
||||
self.threads.lock().insert(thread_id, thread.downgrade());
|
||||
Ok(thread)
|
||||
}
|
||||
|
||||
pub async fn send_message(
|
||||
&self,
|
||||
thread_id: ThreadId,
|
||||
message: acp::Message,
|
||||
_cx: &mut AsyncApp,
|
||||
) -> Result<()> {
|
||||
self.connection
|
||||
.request(acp::SendMessageParams {
|
||||
thread_id: thread_id.clone().into(),
|
||||
message,
|
||||
})
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl From<acp::ThreadId> for ThreadId {
|
||||
fn from(thread_id: acp::ThreadId) -> Self {
|
||||
Self(thread_id.0.into())
|
||||
}
|
||||
}
|
||||
|
||||
impl From<ThreadId> for acp::ThreadId {
|
||||
fn from(thread_id: ThreadId) -> Self {
|
||||
acp::ThreadId(thread_id.0.to_string())
|
||||
}
|
||||
}
|
||||
|
||||
impl From<acp::ToolCallId> for ToolCallId {
|
||||
fn from(tool_call_id: acp::ToolCallId) -> Self {
|
||||
Self(ThreadEntryId(tool_call_id.0))
|
||||
}
|
||||
}
|
||||
|
||||
impl From<ToolCallId> for acp::ToolCallId {
|
||||
fn from(tool_call_id: ToolCallId) -> Self {
|
||||
acp::ToolCallId(tool_call_id.as_u64())
|
||||
}
|
||||
}
|
||||
935
crates/acp/src/thread_view.rs
Normal file
935
crates/acp/src/thread_view.rs
Normal file
@@ -0,0 +1,935 @@
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::rc::Rc;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
use agentic_coding_protocol::{self as acp, ToolCallConfirmation};
|
||||
use anyhow::Result;
|
||||
use editor::{Editor, MultiBuffer};
|
||||
use gpui::{
|
||||
Animation, AnimationExt, App, EdgesRefinement, Empty, Entity, Focusable, ListState,
|
||||
SharedString, StyleRefinement, Subscription, TextStyleRefinement, Transformation,
|
||||
UnderlineStyle, Window, div, list, percentage, prelude::*,
|
||||
};
|
||||
use gpui::{FocusHandle, Task};
|
||||
use language::Buffer;
|
||||
use markdown::{HeadingLevelStyles, MarkdownElement, MarkdownStyle};
|
||||
use project::Project;
|
||||
use settings::Settings as _;
|
||||
use theme::ThemeSettings;
|
||||
use ui::prelude::*;
|
||||
use ui::{Button, Tooltip};
|
||||
use util::ResultExt;
|
||||
use zed_actions::agent::Chat;
|
||||
|
||||
use crate::{
|
||||
AcpServer, AcpThread, AcpThreadEvent, AgentThreadEntryContent, MessageChunk, Role, ThreadEntry,
|
||||
ToolCall, ToolCallId, ToolCallStatus,
|
||||
};
|
||||
|
||||
pub struct AcpThreadView {
|
||||
thread_state: ThreadState,
|
||||
// todo! use full message editor from agent2
|
||||
message_editor: Entity<Editor>,
|
||||
list_state: ListState,
|
||||
send_task: Option<Task<Result<()>>>,
|
||||
root: Arc<Path>,
|
||||
}
|
||||
|
||||
enum ThreadState {
|
||||
Loading {
|
||||
_task: Task<()>,
|
||||
},
|
||||
Ready {
|
||||
thread: Entity<AcpThread>,
|
||||
_subscription: Subscription,
|
||||
},
|
||||
LoadError(SharedString),
|
||||
}
|
||||
|
||||
impl AcpThreadView {
|
||||
pub fn new(project: Entity<Project>, window: &mut Window, cx: &mut Context<Self>) -> Self {
|
||||
// todo!(): This should probably be contextual, like the terminal
|
||||
let Some(root_dir) = project
|
||||
.read(cx)
|
||||
.visible_worktrees(cx)
|
||||
.next()
|
||||
.map(|worktree| worktree.read(cx).abs_path())
|
||||
else {
|
||||
todo!();
|
||||
};
|
||||
|
||||
let cli_path =
|
||||
Path::new(env!("CARGO_MANIFEST_DIR")).join("../../../gemini-cli/packages/cli");
|
||||
|
||||
let child = util::command::new_smol_command("node")
|
||||
.arg(cli_path)
|
||||
.arg("--acp")
|
||||
.current_dir(&root_dir)
|
||||
.stdin(std::process::Stdio::piped())
|
||||
.stdout(std::process::Stdio::piped())
|
||||
.stderr(std::process::Stdio::inherit())
|
||||
.kill_on_drop(true)
|
||||
.spawn()
|
||||
.unwrap();
|
||||
|
||||
let message_editor = cx.new(|cx| {
|
||||
let buffer = cx.new(|cx| Buffer::local("", cx));
|
||||
let buffer = cx.new(|cx| MultiBuffer::singleton(buffer, cx));
|
||||
|
||||
let mut editor = Editor::new(
|
||||
editor::EditorMode::AutoHeight {
|
||||
min_lines: 4,
|
||||
max_lines: None,
|
||||
},
|
||||
buffer,
|
||||
None,
|
||||
window,
|
||||
cx,
|
||||
);
|
||||
editor.set_placeholder_text("Send a message", cx);
|
||||
editor.set_soft_wrap();
|
||||
editor
|
||||
});
|
||||
|
||||
let project = project.clone();
|
||||
let load_task = cx.spawn_in(window, async move |this, cx| {
|
||||
let agent = AcpServer::stdio(child, project, cx);
|
||||
let result = agent.create_thread(cx).await;
|
||||
|
||||
this.update(cx, |this, cx| {
|
||||
match result {
|
||||
Ok(thread) => {
|
||||
let subscription = cx.subscribe(&thread, |this, _, event, cx| {
|
||||
let count = this.list_state.item_count();
|
||||
match event {
|
||||
AcpThreadEvent::NewEntry => {
|
||||
this.list_state.splice(count..count, 1);
|
||||
}
|
||||
AcpThreadEvent::EntryUpdated(index) => {
|
||||
this.list_state.splice(*index..*index + 1, 1);
|
||||
}
|
||||
}
|
||||
cx.notify();
|
||||
});
|
||||
this.list_state
|
||||
.splice(0..0, thread.read(cx).entries().len());
|
||||
|
||||
this.thread_state = ThreadState::Ready {
|
||||
thread,
|
||||
_subscription: subscription,
|
||||
};
|
||||
}
|
||||
Err(e) => this.thread_state = ThreadState::LoadError(e.to_string().into()),
|
||||
};
|
||||
cx.notify();
|
||||
})
|
||||
.log_err();
|
||||
});
|
||||
|
||||
let list_state = ListState::new(
|
||||
0,
|
||||
gpui::ListAlignment::Bottom,
|
||||
px(2048.0),
|
||||
cx.processor({
|
||||
move |this: &mut Self, item: usize, window, cx| {
|
||||
let Some(entry) = this
|
||||
.thread()
|
||||
.and_then(|thread| thread.read(cx).entries.get(item))
|
||||
else {
|
||||
return Empty.into_any();
|
||||
};
|
||||
this.render_entry(entry, window, cx)
|
||||
}
|
||||
}),
|
||||
);
|
||||
|
||||
Self {
|
||||
thread_state: ThreadState::Loading { _task: load_task },
|
||||
message_editor,
|
||||
send_task: None,
|
||||
list_state: list_state,
|
||||
root: root_dir,
|
||||
}
|
||||
}
|
||||
|
||||
fn thread(&self) -> Option<&Entity<AcpThread>> {
|
||||
match &self.thread_state {
|
||||
ThreadState::Ready { thread, .. } => Some(thread),
|
||||
ThreadState::Loading { .. } | ThreadState::LoadError(..) => None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn title(&self, cx: &App) -> SharedString {
|
||||
match &self.thread_state {
|
||||
ThreadState::Ready { thread, .. } => thread.read(cx).title(),
|
||||
ThreadState::Loading { .. } => "Loading...".into(),
|
||||
ThreadState::LoadError(_) => "Failed to load".into(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn cancel(&mut self) {
|
||||
self.send_task.take();
|
||||
}
|
||||
|
||||
fn chat(&mut self, _: &Chat, window: &mut Window, cx: &mut Context<Self>) {
|
||||
let text = self.message_editor.read(cx).text(cx);
|
||||
if text.is_empty() {
|
||||
return;
|
||||
}
|
||||
let Some(thread) = self.thread() else { return };
|
||||
|
||||
let task = thread.update(cx, |thread, cx| thread.send(&text, cx));
|
||||
|
||||
self.send_task = Some(cx.spawn(async move |this, cx| {
|
||||
task.await?;
|
||||
|
||||
this.update(cx, |this, _cx| {
|
||||
this.send_task.take();
|
||||
})
|
||||
}));
|
||||
|
||||
self.message_editor.update(cx, |editor, cx| {
|
||||
editor.clear(window, cx);
|
||||
});
|
||||
}
|
||||
|
||||
fn authorize_tool_call(
|
||||
&mut self,
|
||||
id: ToolCallId,
|
||||
outcome: acp::ToolCallConfirmationOutcome,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
let Some(thread) = self.thread() else {
|
||||
return;
|
||||
};
|
||||
thread.update(cx, |thread, cx| {
|
||||
thread.authorize_tool_call(id, outcome, cx);
|
||||
});
|
||||
cx.notify();
|
||||
}
|
||||
|
||||
fn render_entry(
|
||||
&self,
|
||||
entry: &ThreadEntry,
|
||||
window: &mut Window,
|
||||
cx: &Context<Self>,
|
||||
) -> AnyElement {
|
||||
match &entry.content {
|
||||
AgentThreadEntryContent::Message(message) => {
|
||||
let style = if message.role == Role::User {
|
||||
user_message_markdown_style(window, cx)
|
||||
} else {
|
||||
default_markdown_style(window, cx)
|
||||
};
|
||||
let message_body = div()
|
||||
.children(message.chunks.iter().map(|chunk| match chunk {
|
||||
MessageChunk::Text { chunk } => {
|
||||
// todo!() open link
|
||||
MarkdownElement::new(chunk.clone(), style.clone())
|
||||
}
|
||||
_ => todo!(),
|
||||
}))
|
||||
.into_any();
|
||||
|
||||
match message.role {
|
||||
Role::User => div()
|
||||
.p_2()
|
||||
.pt_5()
|
||||
.child(
|
||||
div()
|
||||
.text_xs()
|
||||
.p_3()
|
||||
.bg(cx.theme().colors().editor_background)
|
||||
.rounded_lg()
|
||||
.shadow_md()
|
||||
.border_1()
|
||||
.border_color(cx.theme().colors().border)
|
||||
.child(message_body),
|
||||
)
|
||||
.into_any(),
|
||||
Role::Assistant => div()
|
||||
.text_ui(cx)
|
||||
.p_5()
|
||||
.pt_2()
|
||||
.child(message_body)
|
||||
.into_any(),
|
||||
}
|
||||
}
|
||||
AgentThreadEntryContent::ToolCall(tool_call) => div()
|
||||
.px_2()
|
||||
.py_4()
|
||||
.child(self.render_tool_call(tool_call, window, cx))
|
||||
.into_any(),
|
||||
}
|
||||
}
|
||||
|
||||
fn render_tool_call(&self, tool_call: &ToolCall, window: &Window, cx: &Context<Self>) -> Div {
|
||||
let status_icon = match &tool_call.status {
|
||||
ToolCallStatus::WaitingForConfirmation { .. } => Empty.into_element().into_any(),
|
||||
ToolCallStatus::Allowed {
|
||||
status: acp::ToolCallStatus::Running,
|
||||
..
|
||||
} => Icon::new(IconName::ArrowCircle)
|
||||
.color(Color::Success)
|
||||
.size(IconSize::Small)
|
||||
.with_animation(
|
||||
"running",
|
||||
Animation::new(Duration::from_secs(2)).repeat(),
|
||||
|icon, delta| icon.transform(Transformation::rotate(percentage(delta))),
|
||||
)
|
||||
.into_any_element(),
|
||||
ToolCallStatus::Allowed {
|
||||
status: acp::ToolCallStatus::Finished,
|
||||
..
|
||||
} => Icon::new(IconName::Check)
|
||||
.color(Color::Success)
|
||||
.size(IconSize::Small)
|
||||
.into_any_element(),
|
||||
ToolCallStatus::Rejected
|
||||
| ToolCallStatus::Allowed {
|
||||
status: acp::ToolCallStatus::Error,
|
||||
..
|
||||
} => Icon::new(IconName::X)
|
||||
.color(Color::Error)
|
||||
.size(IconSize::Small)
|
||||
.into_any_element(),
|
||||
};
|
||||
|
||||
let content = match &tool_call.status {
|
||||
ToolCallStatus::WaitingForConfirmation { confirmation, .. } => {
|
||||
Some(self.render_tool_call_confirmation(tool_call.id, confirmation, cx))
|
||||
}
|
||||
ToolCallStatus::Allowed { content, .. } => content.clone().map(|content| {
|
||||
div()
|
||||
.border_color(cx.theme().colors().border)
|
||||
.border_t_1()
|
||||
.px_2()
|
||||
.py_1p5()
|
||||
.child(MarkdownElement::new(
|
||||
content,
|
||||
default_markdown_style(window, cx),
|
||||
))
|
||||
.into_any_element()
|
||||
}),
|
||||
ToolCallStatus::Rejected => None,
|
||||
};
|
||||
|
||||
v_flex()
|
||||
.text_xs()
|
||||
.rounded_md()
|
||||
.border_1()
|
||||
.border_color(cx.theme().colors().border)
|
||||
.bg(cx.theme().colors().editor_background)
|
||||
.child(
|
||||
h_flex()
|
||||
.px_2()
|
||||
.py_1p5()
|
||||
.w_full()
|
||||
.gap_1p5()
|
||||
.child(
|
||||
Icon::new(tool_call.icon.into())
|
||||
.size(IconSize::Small)
|
||||
.color(Color::Muted),
|
||||
)
|
||||
// todo! danilo please help
|
||||
.child(MarkdownElement::new(
|
||||
tool_call.label.clone(),
|
||||
default_markdown_style(window, cx),
|
||||
))
|
||||
.child(div().w_full())
|
||||
.child(status_icon),
|
||||
)
|
||||
.children(content)
|
||||
}
|
||||
|
||||
fn render_tool_call_confirmation(
|
||||
&self,
|
||||
tool_call_id: ToolCallId,
|
||||
confirmation: &ToolCallConfirmation,
|
||||
cx: &Context<Self>,
|
||||
) -> AnyElement {
|
||||
match confirmation {
|
||||
ToolCallConfirmation::Edit {
|
||||
file_name,
|
||||
file_diff,
|
||||
description,
|
||||
} => v_flex()
|
||||
.border_color(cx.theme().colors().border)
|
||||
.border_t_1()
|
||||
.px_2()
|
||||
.py_1p5()
|
||||
// todo! nicer rendering
|
||||
.child(file_name.clone())
|
||||
.child(file_diff.clone())
|
||||
.children(description.clone())
|
||||
.child(
|
||||
h_flex()
|
||||
.justify_end()
|
||||
.gap_1()
|
||||
.child(
|
||||
Button::new(
|
||||
("always_allow", tool_call_id.as_u64()),
|
||||
"Always Allow Edits",
|
||||
)
|
||||
.icon(IconName::CheckDouble)
|
||||
.icon_position(IconPosition::Start)
|
||||
.icon_size(IconSize::Small)
|
||||
.icon_color(Color::Success)
|
||||
.on_click(cx.listener({
|
||||
let id = tool_call_id;
|
||||
move |this, _, _, cx| {
|
||||
this.authorize_tool_call(
|
||||
id,
|
||||
acp::ToolCallConfirmationOutcome::AlwaysAllow,
|
||||
cx,
|
||||
);
|
||||
}
|
||||
})),
|
||||
)
|
||||
.child(
|
||||
Button::new(("allow", tool_call_id.as_u64()), "Allow")
|
||||
.icon(IconName::Check)
|
||||
.icon_position(IconPosition::Start)
|
||||
.icon_size(IconSize::Small)
|
||||
.icon_color(Color::Success)
|
||||
.on_click(cx.listener({
|
||||
let id = tool_call_id;
|
||||
move |this, _, _, cx| {
|
||||
this.authorize_tool_call(
|
||||
id,
|
||||
acp::ToolCallConfirmationOutcome::Allow,
|
||||
cx,
|
||||
);
|
||||
}
|
||||
})),
|
||||
)
|
||||
.child(
|
||||
Button::new(("reject", tool_call_id.as_u64()), "Reject")
|
||||
.icon(IconName::X)
|
||||
.icon_position(IconPosition::Start)
|
||||
.icon_size(IconSize::Small)
|
||||
.icon_color(Color::Error)
|
||||
.on_click(cx.listener({
|
||||
let id = tool_call_id;
|
||||
move |this, _, _, cx| {
|
||||
this.authorize_tool_call(
|
||||
id,
|
||||
acp::ToolCallConfirmationOutcome::Reject,
|
||||
cx,
|
||||
);
|
||||
}
|
||||
})),
|
||||
),
|
||||
)
|
||||
.into_any(),
|
||||
ToolCallConfirmation::Execute {
|
||||
command,
|
||||
root_command,
|
||||
description,
|
||||
} => v_flex()
|
||||
.border_color(cx.theme().colors().border)
|
||||
.border_t_1()
|
||||
.px_2()
|
||||
.py_1p5()
|
||||
// todo! nicer rendering
|
||||
.child(command.clone())
|
||||
.children(description.clone())
|
||||
.child(
|
||||
h_flex()
|
||||
.justify_end()
|
||||
.gap_1()
|
||||
.child(
|
||||
Button::new(
|
||||
("always_allow", tool_call_id.as_u64()),
|
||||
format!("Always Allow {root_command}"),
|
||||
)
|
||||
.icon(IconName::CheckDouble)
|
||||
.icon_position(IconPosition::Start)
|
||||
.icon_size(IconSize::Small)
|
||||
.icon_color(Color::Success)
|
||||
.on_click(cx.listener({
|
||||
let id = tool_call_id;
|
||||
move |this, _, _, cx| {
|
||||
this.authorize_tool_call(
|
||||
id,
|
||||
acp::ToolCallConfirmationOutcome::AlwaysAllow,
|
||||
cx,
|
||||
);
|
||||
}
|
||||
})),
|
||||
)
|
||||
.child(
|
||||
Button::new(("allow", tool_call_id.as_u64()), "Allow")
|
||||
.icon(IconName::Check)
|
||||
.icon_position(IconPosition::Start)
|
||||
.icon_size(IconSize::Small)
|
||||
.icon_color(Color::Success)
|
||||
.on_click(cx.listener({
|
||||
let id = tool_call_id;
|
||||
move |this, _, _, cx| {
|
||||
this.authorize_tool_call(
|
||||
id,
|
||||
acp::ToolCallConfirmationOutcome::Allow,
|
||||
cx,
|
||||
);
|
||||
}
|
||||
})),
|
||||
)
|
||||
.child(
|
||||
Button::new(("reject", tool_call_id.as_u64()), "Reject")
|
||||
.icon(IconName::X)
|
||||
.icon_position(IconPosition::Start)
|
||||
.icon_size(IconSize::Small)
|
||||
.icon_color(Color::Error)
|
||||
.on_click(cx.listener({
|
||||
let id = tool_call_id;
|
||||
move |this, _, _, cx| {
|
||||
this.authorize_tool_call(
|
||||
id,
|
||||
acp::ToolCallConfirmationOutcome::Reject,
|
||||
cx,
|
||||
);
|
||||
}
|
||||
})),
|
||||
),
|
||||
)
|
||||
.into_any(),
|
||||
ToolCallConfirmation::Mcp {
|
||||
server_name,
|
||||
tool_name: _,
|
||||
tool_display_name,
|
||||
description,
|
||||
} => v_flex()
|
||||
.border_color(cx.theme().colors().border)
|
||||
.border_t_1()
|
||||
.px_2()
|
||||
.py_1p5()
|
||||
// todo! nicer rendering
|
||||
.child(format!("{server_name} - {tool_display_name}"))
|
||||
.children(description.clone())
|
||||
.child(
|
||||
h_flex()
|
||||
.justify_end()
|
||||
.gap_1()
|
||||
.child(
|
||||
Button::new(
|
||||
("always_allow_server", tool_call_id.as_u64()),
|
||||
format!("Always Allow {server_name}"),
|
||||
)
|
||||
.icon(IconName::CheckDouble)
|
||||
.icon_position(IconPosition::Start)
|
||||
.icon_size(IconSize::Small)
|
||||
.icon_color(Color::Success)
|
||||
.on_click(cx.listener({
|
||||
let id = tool_call_id;
|
||||
move |this, _, _, cx| {
|
||||
this.authorize_tool_call(
|
||||
id,
|
||||
acp::ToolCallConfirmationOutcome::AlwaysAllowMcpServer,
|
||||
cx,
|
||||
);
|
||||
}
|
||||
})),
|
||||
)
|
||||
.child(
|
||||
Button::new(
|
||||
("always_allow_tool", tool_call_id.as_u64()),
|
||||
format!("Always Allow {tool_display_name}"),
|
||||
)
|
||||
.icon(IconName::CheckDouble)
|
||||
.icon_position(IconPosition::Start)
|
||||
.icon_size(IconSize::Small)
|
||||
.icon_color(Color::Success)
|
||||
.on_click(cx.listener({
|
||||
let id = tool_call_id;
|
||||
move |this, _, _, cx| {
|
||||
this.authorize_tool_call(
|
||||
id,
|
||||
acp::ToolCallConfirmationOutcome::AlwaysAllowTool,
|
||||
cx,
|
||||
);
|
||||
}
|
||||
})),
|
||||
)
|
||||
.child(
|
||||
Button::new(("allow", tool_call_id.as_u64()), "Allow")
|
||||
.icon(IconName::Check)
|
||||
.icon_position(IconPosition::Start)
|
||||
.icon_size(IconSize::Small)
|
||||
.icon_color(Color::Success)
|
||||
.on_click(cx.listener({
|
||||
let id = tool_call_id;
|
||||
move |this, _, _, cx| {
|
||||
this.authorize_tool_call(
|
||||
id,
|
||||
acp::ToolCallConfirmationOutcome::Allow,
|
||||
cx,
|
||||
);
|
||||
}
|
||||
})),
|
||||
)
|
||||
.child(
|
||||
Button::new(("reject", tool_call_id.as_u64()), "Reject")
|
||||
.icon(IconName::X)
|
||||
.icon_position(IconPosition::Start)
|
||||
.icon_size(IconSize::Small)
|
||||
.icon_color(Color::Error)
|
||||
.on_click(cx.listener({
|
||||
let id = tool_call_id;
|
||||
move |this, _, _, cx| {
|
||||
this.authorize_tool_call(
|
||||
id,
|
||||
acp::ToolCallConfirmationOutcome::Reject,
|
||||
cx,
|
||||
);
|
||||
}
|
||||
})),
|
||||
),
|
||||
)
|
||||
.into_any(),
|
||||
ToolCallConfirmation::Fetch { description, urls } => v_flex()
|
||||
.border_color(cx.theme().colors().border)
|
||||
.border_t_1()
|
||||
.px_2()
|
||||
.py_1p5()
|
||||
// todo! nicer rendering
|
||||
.children(urls.clone())
|
||||
.children(description.clone())
|
||||
.child(
|
||||
h_flex()
|
||||
.justify_end()
|
||||
.gap_1()
|
||||
.child(
|
||||
Button::new(("always_allow", tool_call_id.as_u64()), "Always Allow")
|
||||
.icon(IconName::CheckDouble)
|
||||
.icon_position(IconPosition::Start)
|
||||
.icon_size(IconSize::Small)
|
||||
.icon_color(Color::Success)
|
||||
.on_click(cx.listener({
|
||||
let id = tool_call_id;
|
||||
move |this, _, _, cx| {
|
||||
this.authorize_tool_call(
|
||||
id,
|
||||
acp::ToolCallConfirmationOutcome::AlwaysAllow,
|
||||
cx,
|
||||
);
|
||||
}
|
||||
})),
|
||||
)
|
||||
.child(
|
||||
Button::new(("allow", tool_call_id.as_u64()), "Allow")
|
||||
.icon(IconName::Check)
|
||||
.icon_position(IconPosition::Start)
|
||||
.icon_size(IconSize::Small)
|
||||
.icon_color(Color::Success)
|
||||
.on_click(cx.listener({
|
||||
let id = tool_call_id;
|
||||
move |this, _, _, cx| {
|
||||
this.authorize_tool_call(
|
||||
id,
|
||||
acp::ToolCallConfirmationOutcome::Allow,
|
||||
cx,
|
||||
);
|
||||
}
|
||||
})),
|
||||
)
|
||||
.child(
|
||||
Button::new(("reject", tool_call_id.as_u64()), "Reject")
|
||||
.icon(IconName::X)
|
||||
.icon_position(IconPosition::Start)
|
||||
.icon_size(IconSize::Small)
|
||||
.icon_color(Color::Error)
|
||||
.on_click(cx.listener({
|
||||
let id = tool_call_id;
|
||||
move |this, _, _, cx| {
|
||||
this.authorize_tool_call(
|
||||
id,
|
||||
acp::ToolCallConfirmationOutcome::Reject,
|
||||
cx,
|
||||
);
|
||||
}
|
||||
})),
|
||||
),
|
||||
)
|
||||
.into_any(),
|
||||
ToolCallConfirmation::Other { description } => v_flex()
|
||||
.border_color(cx.theme().colors().border)
|
||||
.border_t_1()
|
||||
.px_2()
|
||||
.py_1p5()
|
||||
// todo! nicer rendering
|
||||
.child(description.clone())
|
||||
.child(
|
||||
h_flex()
|
||||
.justify_end()
|
||||
.gap_1()
|
||||
.child(
|
||||
Button::new(("always_allow", tool_call_id.as_u64()), "Always Allow")
|
||||
.icon(IconName::CheckDouble)
|
||||
.icon_position(IconPosition::Start)
|
||||
.icon_size(IconSize::Small)
|
||||
.icon_color(Color::Success)
|
||||
.on_click(cx.listener({
|
||||
let id = tool_call_id;
|
||||
move |this, _, _, cx| {
|
||||
this.authorize_tool_call(
|
||||
id,
|
||||
acp::ToolCallConfirmationOutcome::AlwaysAllow,
|
||||
cx,
|
||||
);
|
||||
}
|
||||
})),
|
||||
)
|
||||
.child(
|
||||
Button::new(("allow", tool_call_id.as_u64()), "Allow")
|
||||
.icon(IconName::Check)
|
||||
.icon_position(IconPosition::Start)
|
||||
.icon_size(IconSize::Small)
|
||||
.icon_color(Color::Success)
|
||||
.on_click(cx.listener({
|
||||
let id = tool_call_id;
|
||||
move |this, _, _, cx| {
|
||||
this.authorize_tool_call(
|
||||
id,
|
||||
acp::ToolCallConfirmationOutcome::Allow,
|
||||
cx,
|
||||
);
|
||||
}
|
||||
})),
|
||||
)
|
||||
.child(
|
||||
Button::new(("reject", tool_call_id.as_u64()), "Reject")
|
||||
.icon(IconName::X)
|
||||
.icon_position(IconPosition::Start)
|
||||
.icon_size(IconSize::Small)
|
||||
.icon_color(Color::Error)
|
||||
.on_click(cx.listener({
|
||||
let id = tool_call_id;
|
||||
move |this, _, _, cx| {
|
||||
this.authorize_tool_call(
|
||||
id,
|
||||
acp::ToolCallConfirmationOutcome::Reject,
|
||||
cx,
|
||||
);
|
||||
}
|
||||
})),
|
||||
),
|
||||
)
|
||||
.into_any(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Focusable for AcpThreadView {
|
||||
fn focus_handle(&self, cx: &App) -> FocusHandle {
|
||||
self.message_editor.focus_handle(cx)
|
||||
}
|
||||
}
|
||||
|
||||
impl Render for AcpThreadView {
|
||||
fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
|
||||
let text = self.message_editor.read(cx).text(cx);
|
||||
let is_editor_empty = text.is_empty();
|
||||
let focus_handle = self.message_editor.focus_handle(cx);
|
||||
|
||||
v_flex()
|
||||
.key_context("MessageEditor")
|
||||
.on_action(cx.listener(Self::chat))
|
||||
.h_full()
|
||||
.child(match &self.thread_state {
|
||||
ThreadState::Loading { .. } => v_flex()
|
||||
.p_2()
|
||||
.flex_1()
|
||||
.justify_end()
|
||||
.child(Label::new("Connecting to Gemini...")),
|
||||
ThreadState::LoadError(e) => div()
|
||||
.p_2()
|
||||
.flex_1()
|
||||
.justify_end()
|
||||
.child(Label::new(format!("Failed to load {e}")).into_any_element()),
|
||||
ThreadState::Ready { thread, .. } => v_flex()
|
||||
.flex_1()
|
||||
.gap_2()
|
||||
.pb_2()
|
||||
.child(
|
||||
list(self.list_state.clone())
|
||||
.with_sizing_behavior(gpui::ListSizingBehavior::Auto)
|
||||
.flex_grow(),
|
||||
)
|
||||
.child(div().px_3().children(if self.send_task.is_none() {
|
||||
None
|
||||
} else {
|
||||
Label::new(if thread.read(cx).waiting_for_tool_confirmation() {
|
||||
"Waiting for tool confirmation"
|
||||
} else {
|
||||
"Generating..."
|
||||
})
|
||||
.color(Color::Muted)
|
||||
.size(LabelSize::Small)
|
||||
.into()
|
||||
})),
|
||||
})
|
||||
.child(
|
||||
v_flex()
|
||||
.bg(cx.theme().colors().editor_background)
|
||||
.border_t_1()
|
||||
.border_color(cx.theme().colors().border)
|
||||
.p_2()
|
||||
.gap_2()
|
||||
.child(self.message_editor.clone())
|
||||
.child(h_flex().justify_end().child(if self.send_task.is_some() {
|
||||
IconButton::new("stop-generation", IconName::StopFilled)
|
||||
.icon_color(Color::Error)
|
||||
.style(ButtonStyle::Tinted(ui::TintColor::Error))
|
||||
.tooltip(move |window, cx| {
|
||||
Tooltip::for_action(
|
||||
"Stop Generation",
|
||||
&editor::actions::Cancel,
|
||||
window,
|
||||
cx,
|
||||
)
|
||||
})
|
||||
.disabled(is_editor_empty)
|
||||
.on_click(cx.listener(|this, _event, _, _| this.cancel()))
|
||||
} else {
|
||||
IconButton::new("send-message", IconName::Send)
|
||||
.icon_color(Color::Accent)
|
||||
.style(ButtonStyle::Filled)
|
||||
.disabled(is_editor_empty)
|
||||
.on_click({
|
||||
let focus_handle = focus_handle.clone();
|
||||
move |_event, window, cx| {
|
||||
focus_handle.dispatch_action(&Chat, window, cx);
|
||||
}
|
||||
})
|
||||
.when(!is_editor_empty, |button| {
|
||||
button.tooltip(move |window, cx| {
|
||||
Tooltip::for_action("Send", &Chat, window, cx)
|
||||
})
|
||||
})
|
||||
.when(is_editor_empty, |button| {
|
||||
button.tooltip(Tooltip::text("Type a message to submit"))
|
||||
})
|
||||
})),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
fn user_message_markdown_style(window: &Window, cx: &App) -> MarkdownStyle {
|
||||
let mut style = default_markdown_style(window, cx);
|
||||
let mut text_style = window.text_style();
|
||||
let theme_settings = ThemeSettings::get_global(cx);
|
||||
|
||||
let buffer_font = theme_settings.buffer_font.family.clone();
|
||||
let buffer_font_size = TextSize::Small.rems(cx);
|
||||
|
||||
text_style.refine(&TextStyleRefinement {
|
||||
font_family: Some(buffer_font),
|
||||
font_size: Some(buffer_font_size.into()),
|
||||
..Default::default()
|
||||
});
|
||||
|
||||
style.base_text_style = text_style;
|
||||
style
|
||||
}
|
||||
|
||||
fn default_markdown_style(window: &Window, cx: &App) -> MarkdownStyle {
|
||||
let theme_settings = ThemeSettings::get_global(cx);
|
||||
let colors = cx.theme().colors();
|
||||
let ui_font_size = TextSize::Default.rems(cx);
|
||||
let buffer_font_size = TextSize::Small.rems(cx);
|
||||
let mut text_style = window.text_style();
|
||||
let line_height = buffer_font_size * 1.75;
|
||||
|
||||
text_style.refine(&TextStyleRefinement {
|
||||
font_family: Some(theme_settings.ui_font.family.clone()),
|
||||
font_fallbacks: theme_settings.ui_font.fallbacks.clone(),
|
||||
font_features: Some(theme_settings.ui_font.features.clone()),
|
||||
font_size: Some(ui_font_size.into()),
|
||||
line_height: Some(line_height.into()),
|
||||
color: Some(cx.theme().colors().text),
|
||||
..Default::default()
|
||||
});
|
||||
|
||||
MarkdownStyle {
|
||||
base_text_style: text_style.clone(),
|
||||
syntax: cx.theme().syntax().clone(),
|
||||
selection_background_color: cx.theme().colors().element_selection_background,
|
||||
code_block_overflow_x_scroll: true,
|
||||
table_overflow_x_scroll: true,
|
||||
heading_level_styles: Some(HeadingLevelStyles {
|
||||
h1: Some(TextStyleRefinement {
|
||||
font_size: Some(rems(1.15).into()),
|
||||
..Default::default()
|
||||
}),
|
||||
h2: Some(TextStyleRefinement {
|
||||
font_size: Some(rems(1.1).into()),
|
||||
..Default::default()
|
||||
}),
|
||||
h3: Some(TextStyleRefinement {
|
||||
font_size: Some(rems(1.05).into()),
|
||||
..Default::default()
|
||||
}),
|
||||
h4: Some(TextStyleRefinement {
|
||||
font_size: Some(rems(1.).into()),
|
||||
..Default::default()
|
||||
}),
|
||||
h5: Some(TextStyleRefinement {
|
||||
font_size: Some(rems(0.95).into()),
|
||||
..Default::default()
|
||||
}),
|
||||
h6: Some(TextStyleRefinement {
|
||||
font_size: Some(rems(0.875).into()),
|
||||
..Default::default()
|
||||
}),
|
||||
}),
|
||||
code_block: StyleRefinement {
|
||||
padding: EdgesRefinement {
|
||||
top: Some(DefiniteLength::Absolute(AbsoluteLength::Pixels(Pixels(8.)))),
|
||||
left: Some(DefiniteLength::Absolute(AbsoluteLength::Pixels(Pixels(8.)))),
|
||||
right: Some(DefiniteLength::Absolute(AbsoluteLength::Pixels(Pixels(8.)))),
|
||||
bottom: Some(DefiniteLength::Absolute(AbsoluteLength::Pixels(Pixels(8.)))),
|
||||
},
|
||||
background: Some(colors.editor_background.into()),
|
||||
text: Some(TextStyleRefinement {
|
||||
font_family: Some(theme_settings.buffer_font.family.clone()),
|
||||
font_fallbacks: theme_settings.buffer_font.fallbacks.clone(),
|
||||
font_features: Some(theme_settings.buffer_font.features.clone()),
|
||||
font_size: Some(buffer_font_size.into()),
|
||||
..Default::default()
|
||||
}),
|
||||
..Default::default()
|
||||
},
|
||||
inline_code: TextStyleRefinement {
|
||||
font_family: Some(theme_settings.buffer_font.family.clone()),
|
||||
font_fallbacks: theme_settings.buffer_font.fallbacks.clone(),
|
||||
font_features: Some(theme_settings.buffer_font.features.clone()),
|
||||
font_size: Some(buffer_font_size.into()),
|
||||
background_color: Some(colors.editor_foreground.opacity(0.08)),
|
||||
..Default::default()
|
||||
},
|
||||
link: TextStyleRefinement {
|
||||
background_color: Some(colors.editor_foreground.opacity(0.025)),
|
||||
underline: Some(UnderlineStyle {
|
||||
color: Some(colors.text_accent.opacity(0.5)),
|
||||
thickness: px(1.),
|
||||
..Default::default()
|
||||
}),
|
||||
..Default::default()
|
||||
},
|
||||
link_callback: Some(Rc::new(move |_url, _cx| {
|
||||
// todo!()
|
||||
// if MentionLink::is_valid(url) {
|
||||
// let colors = cx.theme().colors();
|
||||
// Some(TextStyleRefinement {
|
||||
// background_color: Some(colors.element_background),
|
||||
// ..Default::default()
|
||||
// })
|
||||
// } else {
|
||||
None
|
||||
// }
|
||||
})),
|
||||
..Default::default()
|
||||
}
|
||||
}
|
||||
@@ -31,13 +31,7 @@ use workspace::{StatusItemView, Workspace, item::ItemHandle};
|
||||
|
||||
const GIT_OPERATION_DELAY: Duration = Duration::from_millis(0);
|
||||
|
||||
actions!(
|
||||
activity_indicator,
|
||||
[
|
||||
/// Displays error messages from language servers in the status bar.
|
||||
ShowErrorMessage
|
||||
]
|
||||
);
|
||||
actions!(activity_indicator, [ShowErrorMessage]);
|
||||
|
||||
pub enum Event {
|
||||
ShowStatus {
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use agent_settings::{AgentProfileId, AgentProfileSettings, AgentSettings};
|
||||
use assistant_tool::{AnyTool, ToolSource, ToolWorkingSet, UniqueToolName};
|
||||
use assistant_tool::{Tool, ToolSource, ToolWorkingSet};
|
||||
use collections::IndexMap;
|
||||
use convert_case::{Case, Casing};
|
||||
use fs::Fs;
|
||||
@@ -72,7 +72,7 @@ impl AgentProfile {
|
||||
&self.id
|
||||
}
|
||||
|
||||
pub fn enabled_tools(&self, cx: &App) -> Vec<(UniqueToolName, AnyTool)> {
|
||||
pub fn enabled_tools(&self, cx: &App) -> Vec<Arc<dyn Tool>> {
|
||||
let Some(settings) = AgentSettings::get_global(cx).profiles.get(&self.id) else {
|
||||
return Vec::new();
|
||||
};
|
||||
@@ -81,7 +81,7 @@ impl AgentProfile {
|
||||
.read(cx)
|
||||
.tools(cx)
|
||||
.into_iter()
|
||||
.filter(|(_, tool)| Self::is_enabled(settings, tool.source(), tool.name()))
|
||||
.filter(|tool| Self::is_enabled(settings, tool.source(), tool.name()))
|
||||
.collect()
|
||||
}
|
||||
|
||||
@@ -108,7 +108,7 @@ impl AgentProfile {
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use agent_settings::ContextServerPreset;
|
||||
use assistant_tool::{Tool, ToolRegistry};
|
||||
use assistant_tool::ToolRegistry;
|
||||
use collections::IndexMap;
|
||||
use gpui::SharedString;
|
||||
use gpui::{AppContext, TestAppContext};
|
||||
@@ -137,7 +137,7 @@ mod tests {
|
||||
let mut enabled_tools = cx
|
||||
.read(|cx| profile.enabled_tools(cx))
|
||||
.into_iter()
|
||||
.map(|(_, tool)| tool.name())
|
||||
.map(|tool| tool.name())
|
||||
.collect::<Vec<_>>();
|
||||
enabled_tools.sort();
|
||||
|
||||
@@ -174,7 +174,7 @@ mod tests {
|
||||
let mut enabled_tools = cx
|
||||
.read(|cx| profile.enabled_tools(cx))
|
||||
.into_iter()
|
||||
.map(|(_, tool)| tool.name())
|
||||
.map(|tool| tool.name())
|
||||
.collect::<Vec<_>>();
|
||||
enabled_tools.sort();
|
||||
|
||||
@@ -207,7 +207,7 @@ mod tests {
|
||||
let mut enabled_tools = cx
|
||||
.read(|cx| profile.enabled_tools(cx))
|
||||
.into_iter()
|
||||
.map(|(_, tool)| tool.name())
|
||||
.map(|tool| tool.name())
|
||||
.collect::<Vec<_>>();
|
||||
enabled_tools.sort();
|
||||
|
||||
@@ -267,16 +267,10 @@ mod tests {
|
||||
}
|
||||
|
||||
fn default_tool_set(cx: &mut TestAppContext) -> Entity<ToolWorkingSet> {
|
||||
cx.new(|cx| {
|
||||
cx.new(|_| {
|
||||
let mut tool_set = ToolWorkingSet::default();
|
||||
tool_set.insert(
|
||||
Arc::new(FakeTool::new("enabled_mcp_tool", "mcp")).into(),
|
||||
cx,
|
||||
);
|
||||
tool_set.insert(
|
||||
Arc::new(FakeTool::new("disabled_mcp_tool", "mcp")).into(),
|
||||
cx,
|
||||
);
|
||||
tool_set.insert(Arc::new(FakeTool::new("enabled_mcp_tool", "mcp")));
|
||||
tool_set.insert(Arc::new(FakeTool::new("disabled_mcp_tool", "mcp")));
|
||||
tool_set
|
||||
})
|
||||
}
|
||||
@@ -296,8 +290,6 @@ mod tests {
|
||||
}
|
||||
|
||||
impl Tool for FakeTool {
|
||||
type Input = ();
|
||||
|
||||
fn name(&self) -> String {
|
||||
self.name.clone()
|
||||
}
|
||||
@@ -316,17 +308,17 @@ mod tests {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
fn needs_confirmation(&self, _input: &Self::Input, _cx: &App) -> bool {
|
||||
fn needs_confirmation(&self, _input: &serde_json::Value, _cx: &App) -> bool {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
fn ui_text(&self, _input: &Self::Input) -> String {
|
||||
fn ui_text(&self, _input: &serde_json::Value) -> String {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
fn run(
|
||||
self: Arc<Self>,
|
||||
_input: Self::Input,
|
||||
_input: serde_json::Value,
|
||||
_request: Arc<language_model::LanguageModelRequest>,
|
||||
_project: Entity<Project>,
|
||||
_action_log: Entity<assistant_tool::ActionLog>,
|
||||
|
||||
@@ -29,8 +29,6 @@ impl ContextServerTool {
|
||||
}
|
||||
|
||||
impl Tool for ContextServerTool {
|
||||
type Input = serde_json::Value;
|
||||
|
||||
fn name(&self) -> String {
|
||||
self.tool.name.clone()
|
||||
}
|
||||
@@ -49,7 +47,7 @@ impl Tool for ContextServerTool {
|
||||
}
|
||||
}
|
||||
|
||||
fn needs_confirmation(&self, _: &Self::Input, _: &App) -> bool {
|
||||
fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
@@ -71,13 +69,13 @@ impl Tool for ContextServerTool {
|
||||
})
|
||||
}
|
||||
|
||||
fn ui_text(&self, _input: &Self::Input) -> String {
|
||||
fn ui_text(&self, _input: &serde_json::Value) -> String {
|
||||
format!("Run MCP tool `{}`", self.tool.name)
|
||||
}
|
||||
|
||||
fn run(
|
||||
self: Arc<Self>,
|
||||
input: Self::Input,
|
||||
input: serde_json::Value,
|
||||
_request: Arc<LanguageModelRequest>,
|
||||
_project: Entity<Project>,
|
||||
_action_log: Entity<ActionLog>,
|
||||
|
||||
@@ -10,10 +10,10 @@ use crate::{
|
||||
};
|
||||
use agent_settings::{AgentProfileId, AgentSettings, CompletionMode};
|
||||
use anyhow::{Result, anyhow};
|
||||
use assistant_tool::{ActionLog, AnyTool, AnyToolCard, ToolWorkingSet};
|
||||
use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolWorkingSet};
|
||||
use chrono::{DateTime, Utc};
|
||||
use client::{ModelRequestUsage, RequestUsage};
|
||||
use collections::HashMap;
|
||||
use collections::{HashMap, HashSet};
|
||||
use feature_flags::{self, FeatureFlagAppExt};
|
||||
use futures::{FutureExt, StreamExt as _, future::Shared};
|
||||
use git::repository::DiffType;
|
||||
@@ -960,14 +960,13 @@ impl Thread {
|
||||
model: Arc<dyn LanguageModel>,
|
||||
) -> Vec<LanguageModelRequestTool> {
|
||||
if model.supports_tools() {
|
||||
self.profile
|
||||
.enabled_tools(cx)
|
||||
resolve_tool_name_conflicts(self.profile.enabled_tools(cx).as_slice())
|
||||
.into_iter()
|
||||
.filter_map(|(name, tool)| {
|
||||
// Skip tools that cannot be supported
|
||||
let input_schema = tool.input_schema(model.tool_input_format()).ok()?;
|
||||
Some(LanguageModelRequestTool {
|
||||
name: name.into(),
|
||||
name,
|
||||
description: tool.description(),
|
||||
input_schema,
|
||||
})
|
||||
@@ -2387,7 +2386,7 @@ impl Thread {
|
||||
|
||||
let tool_list = available_tools
|
||||
.iter()
|
||||
.map(|(name, tool)| format!("- {}: {}", name, tool.description()))
|
||||
.map(|tool| format!("- {}: {}", tool.name(), tool.description()))
|
||||
.collect::<Vec<_>>()
|
||||
.join("\n");
|
||||
|
||||
@@ -2452,7 +2451,7 @@ impl Thread {
|
||||
ui_text: impl Into<SharedString>,
|
||||
input: serde_json::Value,
|
||||
request: Arc<LanguageModelRequest>,
|
||||
tool: AnyTool,
|
||||
tool: Arc<dyn Tool>,
|
||||
model: Arc<dyn LanguageModel>,
|
||||
window: Option<AnyWindowHandle>,
|
||||
cx: &mut Context<Thread>,
|
||||
@@ -2468,7 +2467,7 @@ impl Thread {
|
||||
tool_use_id: LanguageModelToolUseId,
|
||||
request: Arc<LanguageModelRequest>,
|
||||
input: serde_json::Value,
|
||||
tool: AnyTool,
|
||||
tool: Arc<dyn Tool>,
|
||||
model: Arc<dyn LanguageModel>,
|
||||
window: Option<AnyWindowHandle>,
|
||||
cx: &mut Context<Thread>,
|
||||
@@ -2607,7 +2606,7 @@ impl Thread {
|
||||
.profile
|
||||
.enabled_tools(cx)
|
||||
.iter()
|
||||
.map(|(name, _)| name.clone().into())
|
||||
.map(|tool| tool.name())
|
||||
.collect();
|
||||
|
||||
self.message_feedback.insert(message_id, feedback);
|
||||
@@ -3145,6 +3144,85 @@ struct PendingCompletion {
|
||||
_task: Task<()>,
|
||||
}
|
||||
|
||||
/// Resolves tool name conflicts by ensuring all tool names are unique.
|
||||
///
|
||||
/// When multiple tools have the same name, this function applies the following rules:
|
||||
/// 1. Native tools always keep their original name
|
||||
/// 2. Context server tools get prefixed with their server ID and an underscore
|
||||
/// 3. All tool names are truncated to MAX_TOOL_NAME_LENGTH (64 characters)
|
||||
/// 4. If conflicts still exist after prefixing, the conflicting tools are filtered out
|
||||
///
|
||||
/// Note: This function assumes that built-in tools occur before MCP tools in the tools list.
|
||||
fn resolve_tool_name_conflicts(tools: &[Arc<dyn Tool>]) -> Vec<(String, Arc<dyn Tool>)> {
|
||||
fn resolve_tool_name(tool: &Arc<dyn Tool>) -> String {
|
||||
let mut tool_name = tool.name();
|
||||
tool_name.truncate(MAX_TOOL_NAME_LENGTH);
|
||||
tool_name
|
||||
}
|
||||
|
||||
const MAX_TOOL_NAME_LENGTH: usize = 64;
|
||||
|
||||
let mut duplicated_tool_names = HashSet::default();
|
||||
let mut seen_tool_names = HashSet::default();
|
||||
for tool in tools {
|
||||
let tool_name = resolve_tool_name(tool);
|
||||
if seen_tool_names.contains(&tool_name) {
|
||||
debug_assert!(
|
||||
tool.source() != assistant_tool::ToolSource::Native,
|
||||
"There are two built-in tools with the same name: {}",
|
||||
tool_name
|
||||
);
|
||||
duplicated_tool_names.insert(tool_name);
|
||||
} else {
|
||||
seen_tool_names.insert(tool_name);
|
||||
}
|
||||
}
|
||||
|
||||
if duplicated_tool_names.is_empty() {
|
||||
return tools
|
||||
.into_iter()
|
||||
.map(|tool| (resolve_tool_name(tool), tool.clone()))
|
||||
.collect();
|
||||
}
|
||||
|
||||
tools
|
||||
.into_iter()
|
||||
.filter_map(|tool| {
|
||||
let mut tool_name = resolve_tool_name(tool);
|
||||
if !duplicated_tool_names.contains(&tool_name) {
|
||||
return Some((tool_name, tool.clone()));
|
||||
}
|
||||
match tool.source() {
|
||||
assistant_tool::ToolSource::Native => {
|
||||
// Built-in tools always keep their original name
|
||||
Some((tool_name, tool.clone()))
|
||||
}
|
||||
assistant_tool::ToolSource::ContextServer { id } => {
|
||||
// Context server tools are prefixed with the context server ID, and truncated if necessary
|
||||
tool_name.insert(0, '_');
|
||||
if tool_name.len() + id.len() > MAX_TOOL_NAME_LENGTH {
|
||||
let len = MAX_TOOL_NAME_LENGTH - tool_name.len();
|
||||
let mut id = id.to_string();
|
||||
id.truncate(len);
|
||||
tool_name.insert_str(0, &id);
|
||||
} else {
|
||||
tool_name.insert_str(0, &id);
|
||||
}
|
||||
|
||||
tool_name.truncate(MAX_TOOL_NAME_LENGTH);
|
||||
|
||||
if seen_tool_names.contains(&tool_name) {
|
||||
log::error!("Cannot resolve tool name conflict for tool {}", tool.name());
|
||||
None
|
||||
} else {
|
||||
Some((tool_name, tool.clone()))
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
@@ -3160,6 +3238,7 @@ mod tests {
|
||||
use futures::future::BoxFuture;
|
||||
use futures::stream::BoxStream;
|
||||
use gpui::TestAppContext;
|
||||
use icons::IconName;
|
||||
use language_model::fake_provider::{FakeLanguageModel, FakeLanguageModelProvider};
|
||||
use language_model::{
|
||||
LanguageModelCompletionError, LanguageModelName, LanguageModelProviderId,
|
||||
@@ -3804,6 +3883,148 @@ fn main() {{
|
||||
});
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
fn test_resolve_tool_name_conflicts() {
|
||||
use assistant_tool::{Tool, ToolSource};
|
||||
|
||||
assert_resolve_tool_name_conflicts(
|
||||
vec![
|
||||
TestTool::new("tool1", ToolSource::Native),
|
||||
TestTool::new("tool2", ToolSource::Native),
|
||||
TestTool::new("tool3", ToolSource::ContextServer { id: "mcp-1".into() }),
|
||||
],
|
||||
vec!["tool1", "tool2", "tool3"],
|
||||
);
|
||||
|
||||
assert_resolve_tool_name_conflicts(
|
||||
vec![
|
||||
TestTool::new("tool1", ToolSource::Native),
|
||||
TestTool::new("tool2", ToolSource::Native),
|
||||
TestTool::new("tool3", ToolSource::ContextServer { id: "mcp-1".into() }),
|
||||
TestTool::new("tool3", ToolSource::ContextServer { id: "mcp-2".into() }),
|
||||
],
|
||||
vec!["tool1", "tool2", "mcp-1_tool3", "mcp-2_tool3"],
|
||||
);
|
||||
|
||||
assert_resolve_tool_name_conflicts(
|
||||
vec![
|
||||
TestTool::new("tool1", ToolSource::Native),
|
||||
TestTool::new("tool2", ToolSource::Native),
|
||||
TestTool::new("tool3", ToolSource::Native),
|
||||
TestTool::new("tool3", ToolSource::ContextServer { id: "mcp-1".into() }),
|
||||
TestTool::new("tool3", ToolSource::ContextServer { id: "mcp-2".into() }),
|
||||
],
|
||||
vec!["tool1", "tool2", "tool3", "mcp-1_tool3", "mcp-2_tool3"],
|
||||
);
|
||||
|
||||
// Test that tool with very long name is always truncated
|
||||
assert_resolve_tool_name_conflicts(
|
||||
vec![TestTool::new(
|
||||
"tool-with-more-then-64-characters-blah-blah-blah-blah-blah-blah-blah-blah",
|
||||
ToolSource::Native,
|
||||
)],
|
||||
vec!["tool-with-more-then-64-characters-blah-blah-blah-blah-blah-blah-"],
|
||||
);
|
||||
|
||||
// Test deduplication of tools with very long names, in this case the mcp server name should be truncated
|
||||
assert_resolve_tool_name_conflicts(
|
||||
vec![
|
||||
TestTool::new("tool-with-very-very-very-long-name", ToolSource::Native),
|
||||
TestTool::new(
|
||||
"tool-with-very-very-very-long-name",
|
||||
ToolSource::ContextServer {
|
||||
id: "mcp-with-very-very-very-long-name".into(),
|
||||
},
|
||||
),
|
||||
],
|
||||
vec![
|
||||
"tool-with-very-very-very-long-name",
|
||||
"mcp-with-very-very-very-long-_tool-with-very-very-very-long-name",
|
||||
],
|
||||
);
|
||||
|
||||
fn assert_resolve_tool_name_conflicts(
|
||||
tools: Vec<TestTool>,
|
||||
expected: Vec<impl Into<String>>,
|
||||
) {
|
||||
let tools: Vec<Arc<dyn Tool>> = tools
|
||||
.into_iter()
|
||||
.map(|t| Arc::new(t) as Arc<dyn Tool>)
|
||||
.collect();
|
||||
let tools = resolve_tool_name_conflicts(&tools);
|
||||
assert_eq!(tools.len(), expected.len());
|
||||
for (i, expected_name) in expected.into_iter().enumerate() {
|
||||
let expected_name = expected_name.into();
|
||||
let actual_name = &tools[i].0;
|
||||
assert_eq!(
|
||||
actual_name, &expected_name,
|
||||
"Expected '{}' got '{}' at index {}",
|
||||
expected_name, actual_name, i
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
struct TestTool {
|
||||
name: String,
|
||||
source: ToolSource,
|
||||
}
|
||||
|
||||
impl TestTool {
|
||||
fn new(name: impl Into<String>, source: ToolSource) -> Self {
|
||||
Self {
|
||||
name: name.into(),
|
||||
source,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Tool for TestTool {
|
||||
fn name(&self) -> String {
|
||||
self.name.clone()
|
||||
}
|
||||
|
||||
fn icon(&self) -> IconName {
|
||||
IconName::Ai
|
||||
}
|
||||
|
||||
fn may_perform_edits(&self) -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
fn needs_confirmation(&self, _input: &serde_json::Value, _cx: &App) -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
fn source(&self) -> ToolSource {
|
||||
self.source.clone()
|
||||
}
|
||||
|
||||
fn description(&self) -> String {
|
||||
"Test tool".to_string()
|
||||
}
|
||||
|
||||
fn ui_text(&self, _input: &serde_json::Value) -> String {
|
||||
"Test tool".to_string()
|
||||
}
|
||||
|
||||
fn run(
|
||||
self: Arc<Self>,
|
||||
_input: serde_json::Value,
|
||||
_request: Arc<LanguageModelRequest>,
|
||||
_project: Entity<Project>,
|
||||
_action_log: Entity<ActionLog>,
|
||||
_model: Arc<dyn LanguageModel>,
|
||||
_window: Option<AnyWindowHandle>,
|
||||
_cx: &mut App,
|
||||
) -> assistant_tool::ToolResult {
|
||||
assistant_tool::ToolResult {
|
||||
output: Task::ready(Err(anyhow::anyhow!("No content"))),
|
||||
card: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Helper to create a model that returns errors
|
||||
enum TestError {
|
||||
Overloaded,
|
||||
|
||||
@@ -537,8 +537,8 @@ impl ThreadStore {
|
||||
}
|
||||
ContextServerStatus::Stopped | ContextServerStatus::Error(_) => {
|
||||
if let Some(tool_ids) = self.context_server_tool_ids.remove(server_id) {
|
||||
tool_working_set.update(cx, |tool_working_set, cx| {
|
||||
tool_working_set.remove(&tool_ids, cx);
|
||||
tool_working_set.update(cx, |tool_working_set, _| {
|
||||
tool_working_set.remove(&tool_ids);
|
||||
});
|
||||
}
|
||||
}
|
||||
@@ -569,18 +569,19 @@ impl ThreadStore {
|
||||
.log_err()
|
||||
{
|
||||
let tool_ids = tool_working_set
|
||||
.update(cx, |tool_working_set, cx| {
|
||||
tool_working_set.extend(
|
||||
response.tools.into_iter().map(|tool| {
|
||||
Arc::new(ContextServerTool::new(
|
||||
.update(cx, |tool_working_set, _| {
|
||||
response
|
||||
.tools
|
||||
.into_iter()
|
||||
.map(|tool| {
|
||||
log::info!("registering context server tool: {:?}", tool.name);
|
||||
tool_working_set.insert(Arc::new(ContextServerTool::new(
|
||||
context_server_store.clone(),
|
||||
server.id(),
|
||||
tool,
|
||||
))
|
||||
.into()
|
||||
}),
|
||||
cx,
|
||||
)
|
||||
)))
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
})
|
||||
.log_err();
|
||||
|
||||
|
||||
@@ -4,7 +4,7 @@ use crate::{
|
||||
};
|
||||
use anyhow::Result;
|
||||
use assistant_tool::{
|
||||
AnyTool, AnyToolCard, ToolResultContent, ToolResultOutput, ToolUseStatus, ToolWorkingSet,
|
||||
AnyToolCard, Tool, ToolResultContent, ToolResultOutput, ToolUseStatus, ToolWorkingSet,
|
||||
};
|
||||
use collections::HashMap;
|
||||
use futures::{FutureExt as _, future::Shared};
|
||||
@@ -378,7 +378,7 @@ impl ToolUseState {
|
||||
ui_text: impl Into<Arc<str>>,
|
||||
input: serde_json::Value,
|
||||
request: Arc<LanguageModelRequest>,
|
||||
tool: AnyTool,
|
||||
tool: Arc<dyn Tool>,
|
||||
) {
|
||||
if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
|
||||
let ui_text = ui_text.into();
|
||||
@@ -533,7 +533,7 @@ pub struct Confirmation {
|
||||
pub input: serde_json::Value,
|
||||
pub ui_text: Arc<str>,
|
||||
pub request: Arc<LanguageModelRequest>,
|
||||
pub tool: AnyTool,
|
||||
pub tool: Arc<dyn Tool>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
|
||||
@@ -13,12 +13,10 @@ path = "src/agent_ui.rs"
|
||||
doctest = false
|
||||
|
||||
[features]
|
||||
test-support = [
|
||||
"gpui/test-support",
|
||||
"language/test-support",
|
||||
]
|
||||
test-support = ["gpui/test-support", "language/test-support"]
|
||||
|
||||
[dependencies]
|
||||
acp.workspace = true
|
||||
agent.workspace = true
|
||||
agent_settings.workspace = true
|
||||
anyhow.workspace = true
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
use crate::context_picker::{ContextPicker, MentionLink};
|
||||
use crate::context_strip::{ContextStrip, ContextStripEvent, SuggestContextKind};
|
||||
use crate::message_editor::{extract_message_creases, insert_message_creases};
|
||||
use crate::ui::{AddedContext, AgentNotification, AgentNotificationEvent, ContextPill};
|
||||
use crate::ui::{
|
||||
AddedContext, AgentNotification, AgentNotificationEvent, AnimatedLabel, ContextPill,
|
||||
};
|
||||
use crate::{AgentPanel, ModelUsageContext};
|
||||
use agent::{
|
||||
ContextStore, LastRestoreCheckpoint, MessageCrease, MessageId, MessageSegment, TextThreadStore,
|
||||
@@ -1024,7 +1026,6 @@ impl ActiveThread {
|
||||
}
|
||||
}
|
||||
ThreadEvent::MessageAdded(message_id) => {
|
||||
self.clear_last_error();
|
||||
if let Some(rendered_message) = self.thread.update(cx, |thread, cx| {
|
||||
thread.message(*message_id).map(|message| {
|
||||
RenderedMessage::from_segments(
|
||||
@@ -1041,7 +1042,6 @@ impl ActiveThread {
|
||||
cx.notify();
|
||||
}
|
||||
ThreadEvent::MessageEdited(message_id) => {
|
||||
self.clear_last_error();
|
||||
if let Some(index) = self.messages.iter().position(|id| id == message_id) {
|
||||
if let Some(rendered_message) = self.thread.update(cx, |thread, cx| {
|
||||
thread.message(*message_id).map(|message| {
|
||||
@@ -1818,7 +1818,7 @@ impl ActiveThread {
|
||||
.my_3()
|
||||
.mx_5()
|
||||
.when(is_generating_stale || message.is_hidden, |this| {
|
||||
this.child(LoadingLabel::new("").size(LabelSize::Small))
|
||||
this.child(AnimatedLabel::new("").size(LabelSize::Small))
|
||||
})
|
||||
});
|
||||
|
||||
@@ -2584,7 +2584,7 @@ impl ActiveThread {
|
||||
.size(IconSize::XSmall)
|
||||
.color(Color::Muted),
|
||||
)
|
||||
.child(LoadingLabel::new("Thinking").size(LabelSize::Small)),
|
||||
.child(AnimatedLabel::new("Thinking").size(LabelSize::Small)),
|
||||
)
|
||||
.child(
|
||||
h_flex()
|
||||
@@ -3153,7 +3153,7 @@ impl ActiveThread {
|
||||
.border_color(self.tool_card_border_color(cx))
|
||||
.rounded_b_lg()
|
||||
.child(
|
||||
LoadingLabel::new("Waiting for Confirmation").size(LabelSize::Small)
|
||||
AnimatedLabel::new("Waiting for Confirmation").size(LabelSize::Small)
|
||||
)
|
||||
.child(
|
||||
h_flex()
|
||||
|
||||
@@ -26,8 +26,8 @@ use project::{
|
||||
};
|
||||
use settings::{Settings, update_settings_file};
|
||||
use ui::{
|
||||
ContextMenu, Disclosure, Divider, DividerColor, ElevationIndex, Indicator, PopoverMenu,
|
||||
Scrollbar, ScrollbarState, Switch, SwitchColor, Tooltip, prelude::*,
|
||||
ContextMenu, Disclosure, ElevationIndex, Indicator, PopoverMenu, Scrollbar, ScrollbarState,
|
||||
Switch, SwitchColor, Tooltip, prelude::*,
|
||||
};
|
||||
use util::ResultExt as _;
|
||||
use workspace::Workspace;
|
||||
@@ -172,29 +172,19 @@ impl AgentConfiguration {
|
||||
.unwrap_or(false);
|
||||
|
||||
v_flex()
|
||||
.when(is_expanded, |this| this.mb_2())
|
||||
.child(
|
||||
div()
|
||||
.opacity(0.6)
|
||||
.px_2()
|
||||
.child(Divider::horizontal().color(DividerColor::Border)),
|
||||
)
|
||||
.py_2()
|
||||
.gap_1p5()
|
||||
.border_t_1()
|
||||
.border_color(cx.theme().colors().border.opacity(0.6))
|
||||
.child(
|
||||
h_flex()
|
||||
.map(|this| {
|
||||
if is_expanded {
|
||||
this.mt_2().mb_1()
|
||||
} else {
|
||||
this.my_2()
|
||||
}
|
||||
})
|
||||
.w_full()
|
||||
.gap_1()
|
||||
.justify_between()
|
||||
.child(
|
||||
h_flex()
|
||||
.id(provider_id_string.clone())
|
||||
.cursor_pointer()
|
||||
.px_2()
|
||||
.py_0p5()
|
||||
.w_full()
|
||||
.justify_between()
|
||||
@@ -257,16 +247,12 @@ impl AgentConfiguration {
|
||||
)
|
||||
}),
|
||||
)
|
||||
.child(
|
||||
div()
|
||||
.px_2()
|
||||
.when(is_expanded, |parent| match configuration_view {
|
||||
Some(configuration_view) => parent.child(configuration_view),
|
||||
None => parent.child(Label::new(format!(
|
||||
"No configuration view for {provider_name}",
|
||||
))),
|
||||
}),
|
||||
)
|
||||
.when(is_expanded, |parent| match configuration_view {
|
||||
Some(configuration_view) => parent.child(configuration_view),
|
||||
None => parent.child(Label::new(format!(
|
||||
"No configuration view for {provider_name}",
|
||||
))),
|
||||
})
|
||||
}
|
||||
|
||||
fn render_provider_configuration_section(
|
||||
@@ -276,11 +262,12 @@ impl AgentConfiguration {
|
||||
let providers = LanguageModelRegistry::read_global(cx).providers();
|
||||
|
||||
v_flex()
|
||||
.p(DynamicSpacing::Base16.rems(cx))
|
||||
.pr(DynamicSpacing::Base20.rems(cx))
|
||||
.border_b_1()
|
||||
.border_color(cx.theme().colors().border)
|
||||
.child(
|
||||
v_flex()
|
||||
.p(DynamicSpacing::Base16.rems(cx))
|
||||
.pr(DynamicSpacing::Base20.rems(cx))
|
||||
.pb_0()
|
||||
.mb_2p5()
|
||||
.gap_0p5()
|
||||
.child(Headline::new("LLM Providers"))
|
||||
@@ -289,15 +276,10 @@ impl AgentConfiguration {
|
||||
.color(Color::Muted),
|
||||
),
|
||||
)
|
||||
.child(
|
||||
div()
|
||||
.pl(DynamicSpacing::Base08.rems(cx))
|
||||
.pr(DynamicSpacing::Base20.rems(cx))
|
||||
.children(
|
||||
providers.into_iter().map(|provider| {
|
||||
self.render_provider_configuration_block(&provider, cx)
|
||||
}),
|
||||
),
|
||||
.children(
|
||||
providers
|
||||
.into_iter()
|
||||
.map(|provider| self.render_provider_configuration_block(&provider, cx)),
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
@@ -379,14 +379,6 @@ impl ConfigureContextServerModal {
|
||||
};
|
||||
|
||||
self.state = State::Waiting;
|
||||
|
||||
let existing_server = self.context_server_store.read(cx).get_running_server(&id);
|
||||
if existing_server.is_some() {
|
||||
self.context_server_store.update(cx, |store, cx| {
|
||||
store.stop_server(&id, cx).log_err();
|
||||
});
|
||||
}
|
||||
|
||||
let wait_for_context_server_task =
|
||||
wait_for_context_server(&self.context_server_store, id.clone(), cx);
|
||||
cx.spawn({
|
||||
@@ -407,21 +399,13 @@ impl ConfigureContextServerModal {
|
||||
})
|
||||
.detach();
|
||||
|
||||
let settings_changed =
|
||||
ProjectSettings::get_global(cx).context_servers.get(&id.0) != Some(&settings);
|
||||
|
||||
if settings_changed {
|
||||
// When we write the settings to the file, the context server will be restarted.
|
||||
workspace.update(cx, |workspace, cx| {
|
||||
let fs = workspace.app_state().fs.clone();
|
||||
update_settings_file::<ProjectSettings>(fs.clone(), cx, |project_settings, _| {
|
||||
project_settings.context_servers.insert(id.0, settings);
|
||||
});
|
||||
// When we write the settings to the file, the context server will be restarted.
|
||||
workspace.update(cx, |workspace, cx| {
|
||||
let fs = workspace.app_state().fs.clone();
|
||||
update_settings_file::<ProjectSettings>(fs.clone(), cx, |project_settings, _| {
|
||||
project_settings.context_servers.insert(id.0, settings);
|
||||
});
|
||||
} else if let Some(existing_server) = existing_server {
|
||||
self.context_server_store
|
||||
.update(cx, |store, cx| store.start_server(existing_server, cx));
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
fn cancel(&mut self, _: &menu::Cancel, cx: &mut Context<Self>) {
|
||||
|
||||
@@ -7,6 +7,7 @@ use std::time::Duration;
|
||||
use db::kvp::{Dismissable, KEY_VALUE_STORE};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::NewGeminiThread;
|
||||
use crate::language_model_selector::ToggleModelSelector;
|
||||
use crate::{
|
||||
AddContextServer, AgentDiffPane, ContinueThread, ContinueWithBurnMode,
|
||||
@@ -109,6 +110,12 @@ pub fn init(cx: &mut App) {
|
||||
panel.update(cx, |panel, cx| panel.new_prompt_editor(window, cx));
|
||||
}
|
||||
})
|
||||
.register_action(|workspace, _: &NewGeminiThread, window, cx| {
|
||||
if let Some(panel) = workspace.panel::<AgentPanel>(cx) {
|
||||
workspace.focus_panel::<AgentPanel>(window, cx);
|
||||
panel.update(cx, |panel, cx| panel.new_gemini_thread(window, cx));
|
||||
}
|
||||
})
|
||||
.register_action(|workspace, action: &OpenRulesLibrary, window, cx| {
|
||||
if let Some(panel) = workspace.panel::<AgentPanel>(cx) {
|
||||
workspace.focus_panel::<AgentPanel>(window, cx);
|
||||
@@ -125,6 +132,7 @@ pub fn init(cx: &mut App) {
|
||||
let thread = thread.read(cx).thread().clone();
|
||||
AgentDiffPane::deploy_in_workspace(thread, workspace, window, cx);
|
||||
}
|
||||
ActiveView::AcpThread { .. } => todo!(),
|
||||
ActiveView::TextThread { .. }
|
||||
| ActiveView::History
|
||||
| ActiveView::Configuration => {}
|
||||
@@ -188,6 +196,9 @@ enum ActiveView {
|
||||
message_editor: Entity<MessageEditor>,
|
||||
_subscriptions: Vec<gpui::Subscription>,
|
||||
},
|
||||
AcpThread {
|
||||
thread_view: Entity<acp::AcpThreadView>,
|
||||
},
|
||||
TextThread {
|
||||
context_editor: Entity<TextThreadEditor>,
|
||||
title_editor: Entity<Editor>,
|
||||
@@ -207,7 +218,9 @@ enum WhichFontSize {
|
||||
impl ActiveView {
|
||||
pub fn which_font_size_used(&self) -> WhichFontSize {
|
||||
match self {
|
||||
ActiveView::Thread { .. } | ActiveView::History => WhichFontSize::AgentFont,
|
||||
ActiveView::Thread { .. } | ActiveView::AcpThread { .. } | ActiveView::History => {
|
||||
WhichFontSize::AgentFont
|
||||
}
|
||||
ActiveView::TextThread { .. } => WhichFontSize::BufferFont,
|
||||
ActiveView::Configuration => WhichFontSize::None,
|
||||
}
|
||||
@@ -238,6 +251,9 @@ impl ActiveView {
|
||||
thread.scroll_to_bottom(cx);
|
||||
});
|
||||
}
|
||||
ActiveView::AcpThread { .. } => {
|
||||
// todo!
|
||||
}
|
||||
ActiveView::TextThread { .. }
|
||||
| ActiveView::History
|
||||
| ActiveView::Configuration => {}
|
||||
@@ -653,6 +669,9 @@ impl AgentPanel {
|
||||
.clone()
|
||||
.update(cx, |thread, cx| thread.get_or_init_configured_model(cx));
|
||||
}
|
||||
ActiveView::AcpThread { .. } => {
|
||||
// todo!
|
||||
}
|
||||
ActiveView::TextThread { .. }
|
||||
| ActiveView::History
|
||||
| ActiveView::Configuration => {}
|
||||
@@ -733,6 +752,9 @@ impl AgentPanel {
|
||||
ActiveView::Thread { thread, .. } => {
|
||||
thread.update(cx, |thread, cx| thread.cancel_last_completion(window, cx));
|
||||
}
|
||||
ActiveView::AcpThread { thread_view, .. } => {
|
||||
thread_view.update(cx, |thread_element, _cx| thread_element.cancel());
|
||||
}
|
||||
ActiveView::TextThread { .. } | ActiveView::History | ActiveView::Configuration => {}
|
||||
}
|
||||
}
|
||||
@@ -740,6 +762,10 @@ impl AgentPanel {
|
||||
fn active_message_editor(&self) -> Option<&Entity<MessageEditor>> {
|
||||
match &self.active_view {
|
||||
ActiveView::Thread { message_editor, .. } => Some(message_editor),
|
||||
ActiveView::AcpThread { .. } => {
|
||||
// todo!
|
||||
None
|
||||
}
|
||||
ActiveView::TextThread { .. } | ActiveView::History | ActiveView::Configuration => None,
|
||||
}
|
||||
}
|
||||
@@ -862,6 +888,19 @@ impl AgentPanel {
|
||||
context_editor.focus_handle(cx).focus(window);
|
||||
}
|
||||
|
||||
fn new_gemini_thread(&mut self, window: &mut Window, cx: &mut Context<Self>) {
|
||||
let project = self.project.clone();
|
||||
|
||||
cx.spawn_in(window, async move |this, cx| {
|
||||
let thread_view =
|
||||
cx.new_window_entity(|window, cx| acp::AcpThreadView::new(project, window, cx))?;
|
||||
this.update_in(cx, |this, window, cx| {
|
||||
this.set_active_view(ActiveView::AcpThread { thread_view }, window, cx);
|
||||
})
|
||||
})
|
||||
.detach();
|
||||
}
|
||||
|
||||
fn deploy_rules_library(
|
||||
&mut self,
|
||||
action: &OpenRulesLibrary,
|
||||
@@ -994,6 +1033,7 @@ impl AgentPanel {
|
||||
cx,
|
||||
)
|
||||
});
|
||||
|
||||
let message_editor = cx.new(|cx| {
|
||||
MessageEditor::new(
|
||||
self.fs.clone(),
|
||||
@@ -1018,6 +1058,7 @@ impl AgentPanel {
|
||||
pub fn go_back(&mut self, _: &workspace::GoBack, window: &mut Window, cx: &mut Context<Self>) {
|
||||
match self.active_view {
|
||||
ActiveView::Configuration | ActiveView::History => {
|
||||
// todo! check go back works correctly
|
||||
if let Some(previous_view) = self.previous_view.take() {
|
||||
self.active_view = previous_view;
|
||||
|
||||
@@ -1025,6 +1066,9 @@ impl AgentPanel {
|
||||
ActiveView::Thread { message_editor, .. } => {
|
||||
message_editor.focus_handle(cx).focus(window);
|
||||
}
|
||||
ActiveView::AcpThread { .. } => {
|
||||
todo!()
|
||||
}
|
||||
ActiveView::TextThread { context_editor, .. } => {
|
||||
context_editor.focus_handle(cx).focus(window);
|
||||
}
|
||||
@@ -1144,6 +1188,7 @@ impl AgentPanel {
|
||||
})
|
||||
.log_err();
|
||||
}
|
||||
ActiveView::AcpThread { .. } => todo!(),
|
||||
ActiveView::TextThread { .. } | ActiveView::History | ActiveView::Configuration => {}
|
||||
}
|
||||
}
|
||||
@@ -1197,6 +1242,9 @@ impl AgentPanel {
|
||||
)
|
||||
.detach_and_log_err(cx);
|
||||
}
|
||||
ActiveView::AcpThread { .. } => {
|
||||
todo!()
|
||||
}
|
||||
ActiveView::TextThread { .. } | ActiveView::History | ActiveView::Configuration => {}
|
||||
}
|
||||
}
|
||||
@@ -1231,6 +1279,10 @@ impl AgentPanel {
|
||||
pub(crate) fn active_thread(&self, cx: &App) -> Option<Entity<Thread>> {
|
||||
match &self.active_view {
|
||||
ActiveView::Thread { thread, .. } => Some(thread.read(cx).thread().clone()),
|
||||
ActiveView::AcpThread { .. } => {
|
||||
// todo!
|
||||
None
|
||||
}
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
@@ -1336,6 +1388,9 @@ impl AgentPanel {
|
||||
});
|
||||
}
|
||||
}
|
||||
ActiveView::AcpThread { .. } => {
|
||||
// todo!
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
|
||||
@@ -1351,6 +1406,9 @@ impl AgentPanel {
|
||||
}
|
||||
})
|
||||
}
|
||||
ActiveView::AcpThread { .. } => {
|
||||
// todo! push history entry
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
|
||||
@@ -1437,6 +1495,7 @@ impl Focusable for AgentPanel {
|
||||
fn focus_handle(&self, cx: &App) -> FocusHandle {
|
||||
match &self.active_view {
|
||||
ActiveView::Thread { message_editor, .. } => message_editor.focus_handle(cx),
|
||||
ActiveView::AcpThread { thread_view, .. } => thread_view.focus_handle(cx),
|
||||
ActiveView::History => self.history.focus_handle(cx),
|
||||
ActiveView::TextThread { context_editor, .. } => context_editor.focus_handle(cx),
|
||||
ActiveView::Configuration => {
|
||||
@@ -1593,6 +1652,9 @@ impl AgentPanel {
|
||||
.into_any_element(),
|
||||
}
|
||||
}
|
||||
ActiveView::AcpThread { thread_view } => Label::new(thread_view.read(cx).title(cx))
|
||||
.truncate()
|
||||
.into_any_element(),
|
||||
ActiveView::TextThread {
|
||||
title_editor,
|
||||
context_editor,
|
||||
@@ -1727,6 +1789,10 @@ impl AgentPanel {
|
||||
|
||||
let active_thread = match &self.active_view {
|
||||
ActiveView::Thread { thread, .. } => Some(thread.read(cx).thread().clone()),
|
||||
ActiveView::AcpThread { .. } => {
|
||||
// todo!
|
||||
None
|
||||
}
|
||||
ActiveView::TextThread { .. } | ActiveView::History | ActiveView::Configuration => None,
|
||||
};
|
||||
|
||||
@@ -1755,6 +1821,7 @@ impl AgentPanel {
|
||||
menu = menu
|
||||
.action("New Thread", NewThread::default().boxed_clone())
|
||||
.action("New Text Thread", NewTextThread.boxed_clone())
|
||||
.action("New Gemini Thread", NewGeminiThread.boxed_clone())
|
||||
.when_some(active_thread, |this, active_thread| {
|
||||
let thread = active_thread.read(cx);
|
||||
if !thread.is_empty() {
|
||||
@@ -1893,6 +1960,10 @@ impl AgentPanel {
|
||||
message_editor,
|
||||
..
|
||||
} => (thread.read(cx), message_editor.read(cx)),
|
||||
ActiveView::AcpThread { .. } => {
|
||||
// todo!
|
||||
return None;
|
||||
}
|
||||
ActiveView::TextThread { .. } | ActiveView::History | ActiveView::Configuration => {
|
||||
return None;
|
||||
}
|
||||
@@ -2031,6 +2102,10 @@ impl AgentPanel {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
ActiveView::AcpThread { .. } => {
|
||||
// todo!
|
||||
return false;
|
||||
}
|
||||
ActiveView::TextThread { .. } | ActiveView::History | ActiveView::Configuration => {
|
||||
return false;
|
||||
}
|
||||
@@ -2615,6 +2690,10 @@ impl AgentPanel {
|
||||
) -> Option<AnyElement> {
|
||||
let active_thread = match &self.active_view {
|
||||
ActiveView::Thread { thread, .. } => thread,
|
||||
ActiveView::AcpThread { .. } => {
|
||||
// todo!
|
||||
return None;
|
||||
}
|
||||
ActiveView::TextThread { .. } | ActiveView::History | ActiveView::Configuration => {
|
||||
return None;
|
||||
}
|
||||
@@ -2961,6 +3040,9 @@ impl AgentPanel {
|
||||
.detach();
|
||||
});
|
||||
}
|
||||
ActiveView::AcpThread { .. } => {
|
||||
unimplemented!()
|
||||
}
|
||||
ActiveView::TextThread { context_editor, .. } => {
|
||||
context_editor.update(cx, |context_editor, cx| {
|
||||
TextThreadEditor::insert_dragged_files(
|
||||
@@ -3034,6 +3116,9 @@ impl Render for AgentPanel {
|
||||
});
|
||||
this.continue_conversation(window, cx);
|
||||
}
|
||||
ActiveView::AcpThread { .. } => {
|
||||
todo!()
|
||||
}
|
||||
ActiveView::TextThread { .. }
|
||||
| ActiveView::History
|
||||
| ActiveView::Configuration => {}
|
||||
@@ -3075,6 +3160,12 @@ impl Render for AgentPanel {
|
||||
})
|
||||
.child(h_flex().child(message_editor.clone()))
|
||||
.child(self.render_drag_target(cx)),
|
||||
ActiveView::AcpThread { thread_view, .. } => parent
|
||||
.relative()
|
||||
.child(thread_view.clone())
|
||||
// todo!
|
||||
// .child(h_flex().child(self.message_editor.clone()))
|
||||
.child(self.render_drag_target(cx)),
|
||||
ActiveView::History => parent.child(self.history.clone()),
|
||||
ActiveView::TextThread {
|
||||
context_editor,
|
||||
|
||||
@@ -54,76 +54,42 @@ pub use ui::preview::{all_agent_previews, get_agent_preview};
|
||||
actions!(
|
||||
agent,
|
||||
[
|
||||
/// Creates a new text-based conversation thread.
|
||||
NewTextThread,
|
||||
/// Toggles the context picker interface for adding files, symbols, or other context.
|
||||
NewGeminiThread,
|
||||
ToggleContextPicker,
|
||||
/// Toggles the navigation menu for switching between threads and views.
|
||||
ToggleNavigationMenu,
|
||||
/// Toggles the options menu for agent settings and preferences.
|
||||
ToggleOptionsMenu,
|
||||
/// Deletes the recently opened thread from history.
|
||||
DeleteRecentlyOpenThread,
|
||||
/// Toggles the profile selector for switching between agent profiles.
|
||||
ToggleProfileSelector,
|
||||
/// Removes all added context from the current conversation.
|
||||
RemoveAllContext,
|
||||
/// Expands the message editor to full size.
|
||||
ExpandMessageEditor,
|
||||
/// Opens the conversation history view.
|
||||
OpenHistory,
|
||||
/// Adds a context server to the configuration.
|
||||
AddContextServer,
|
||||
/// Removes the currently selected thread.
|
||||
RemoveSelectedThread,
|
||||
/// Starts a chat conversation with the agent.
|
||||
Chat,
|
||||
/// Starts a chat conversation with follow-up enabled.
|
||||
ChatWithFollow,
|
||||
/// Cycles to the next inline assist suggestion.
|
||||
CycleNextInlineAssist,
|
||||
/// Cycles to the previous inline assist suggestion.
|
||||
CyclePreviousInlineAssist,
|
||||
/// Moves focus up in the interface.
|
||||
FocusUp,
|
||||
/// Moves focus down in the interface.
|
||||
FocusDown,
|
||||
/// Moves focus left in the interface.
|
||||
FocusLeft,
|
||||
/// Moves focus right in the interface.
|
||||
FocusRight,
|
||||
/// Removes the currently focused context item.
|
||||
RemoveFocusedContext,
|
||||
/// Accepts the suggested context item.
|
||||
AcceptSuggestedContext,
|
||||
/// Opens the active thread as a markdown file.
|
||||
OpenActiveThreadAsMarkdown,
|
||||
/// Opens the agent diff view to review changes.
|
||||
OpenAgentDiff,
|
||||
/// Keeps the current suggestion or change.
|
||||
Keep,
|
||||
/// Rejects the current suggestion or change.
|
||||
Reject,
|
||||
/// Rejects all suggestions or changes.
|
||||
RejectAll,
|
||||
/// Keeps all suggestions or changes.
|
||||
KeepAll,
|
||||
/// Follows the agent's suggestions.
|
||||
Follow,
|
||||
/// Resets the trial upsell notification.
|
||||
ResetTrialUpsell,
|
||||
/// Resets the trial end upsell notification.
|
||||
ResetTrialEndUpsell,
|
||||
/// Continues the current thread.
|
||||
ContinueThread,
|
||||
/// Continues the thread with burn mode enabled.
|
||||
ContinueWithBurnMode,
|
||||
/// Toggles burn mode for faster responses.
|
||||
ToggleBurnMode,
|
||||
]
|
||||
);
|
||||
|
||||
/// Creates a new conversation thread, optionally based on an existing thread.
|
||||
#[derive(Default, Clone, PartialEq, Deserialize, JsonSchema, Action)]
|
||||
#[action(namespace = agent)]
|
||||
#[serde(deny_unknown_fields)]
|
||||
@@ -132,7 +98,6 @@ pub struct NewThread {
|
||||
from_thread_id: Option<ThreadId>,
|
||||
}
|
||||
|
||||
/// Opens the profile management interface for configuring agent tools and settings.
|
||||
#[derive(PartialEq, Clone, Default, Debug, Deserialize, JsonSchema, Action)]
|
||||
#[action(namespace = agent)]
|
||||
#[serde(deny_unknown_fields)]
|
||||
|
||||
@@ -18,7 +18,6 @@ use ui::{ListItem, ListItemSpacing, prelude::*};
|
||||
actions!(
|
||||
agent,
|
||||
[
|
||||
/// Toggles the language model selector dropdown.
|
||||
#[action(deprecated_aliases = ["assistant::ToggleModelSelector", "assistant2::ToggleModelSelector"])]
|
||||
ToggleModelSelector
|
||||
]
|
||||
|
||||
@@ -47,13 +47,14 @@ use ui::{
|
||||
};
|
||||
use util::ResultExt as _;
|
||||
use workspace::{CollaboratorId, Workspace};
|
||||
use zed_actions::agent::Chat;
|
||||
use zed_llm_client::CompletionIntent;
|
||||
|
||||
use crate::context_picker::{ContextPicker, ContextPickerCompletionProvider, crease_for_mention};
|
||||
use crate::context_strip::{ContextStrip, ContextStripEvent, SuggestContextKind};
|
||||
use crate::profile_selector::ProfileSelector;
|
||||
use crate::{
|
||||
ActiveThread, AgentDiffPane, Chat, ChatWithFollow, ExpandMessageEditor, Follow, KeepAll,
|
||||
ActiveThread, AgentDiffPane, ChatWithFollow, ExpandMessageEditor, Follow, KeepAll,
|
||||
ModelUsageContext, NewThread, OpenAgentDiff, RejectAll, RemoveAllContext, ToggleBurnMode,
|
||||
ToggleContextPicker, ToggleProfileSelector, register_agent_preview,
|
||||
};
|
||||
|
||||
@@ -85,24 +85,16 @@ use assistant_context::{
|
||||
actions!(
|
||||
assistant,
|
||||
[
|
||||
/// Sends the current message to the assistant.
|
||||
Assist,
|
||||
/// Confirms and executes the entered slash command.
|
||||
ConfirmCommand,
|
||||
/// Copies code from the assistant's response to the clipboard.
|
||||
CopyCode,
|
||||
/// Cycles between user and assistant message roles.
|
||||
CycleMessageRole,
|
||||
/// Inserts the selected text into the active editor.
|
||||
InsertIntoEditor,
|
||||
/// Quotes the current selection in the assistant conversation.
|
||||
QuoteSelection,
|
||||
/// Splits the conversation at the current cursor position.
|
||||
Split,
|
||||
]
|
||||
);
|
||||
|
||||
/// Inserts files that were dragged and dropped into the assistant conversation.
|
||||
#[derive(PartialEq, Clone, Action)]
|
||||
#[action(namespace = assistant, no_json, no_register)]
|
||||
pub enum InsertDraggedFiles {
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
use agent::{Thread, ThreadEvent};
|
||||
use assistant_tool::{AnyTool, ToolSource};
|
||||
use assistant_tool::{Tool, ToolSource};
|
||||
use collections::HashMap;
|
||||
use gpui::{App, Context, Entity, IntoElement, Render, Subscription, Window};
|
||||
use language_model::{LanguageModel, LanguageModelToolSchemaFormat};
|
||||
@@ -7,7 +7,7 @@ use std::sync::Arc;
|
||||
use ui::prelude::*;
|
||||
|
||||
pub struct IncompatibleToolsState {
|
||||
cache: HashMap<LanguageModelToolSchemaFormat, Vec<AnyTool>>,
|
||||
cache: HashMap<LanguageModelToolSchemaFormat, Vec<Arc<dyn Tool>>>,
|
||||
thread: Entity<Thread>,
|
||||
_thread_subscription: Subscription,
|
||||
}
|
||||
@@ -29,7 +29,11 @@ impl IncompatibleToolsState {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn incompatible_tools(&mut self, model: &Arc<dyn LanguageModel>, cx: &App) -> &[AnyTool] {
|
||||
pub fn incompatible_tools(
|
||||
&mut self,
|
||||
model: &Arc<dyn LanguageModel>,
|
||||
cx: &App,
|
||||
) -> &[Arc<dyn Tool>] {
|
||||
self.cache
|
||||
.entry(model.tool_input_format())
|
||||
.or_insert_with(|| {
|
||||
@@ -38,15 +42,15 @@ impl IncompatibleToolsState {
|
||||
.profile()
|
||||
.enabled_tools(cx)
|
||||
.iter()
|
||||
.filter(|(_, tool)| tool.input_schema(model.tool_input_format()).is_err())
|
||||
.map(|(_, tool)| tool.clone())
|
||||
.filter(|tool| tool.input_schema(model.tool_input_format()).is_err())
|
||||
.cloned()
|
||||
.collect()
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
pub struct IncompatibleToolsTooltip {
|
||||
pub incompatible_tools: Vec<AnyTool>,
|
||||
pub incompatible_tools: Vec<Arc<dyn Tool>>,
|
||||
}
|
||||
|
||||
impl Render for IncompatibleToolsTooltip {
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
mod agent_notification;
|
||||
mod animated_label;
|
||||
mod burn_mode_tooltip;
|
||||
mod context_pill;
|
||||
mod onboarding_modal;
|
||||
@@ -6,6 +7,7 @@ pub mod preview;
|
||||
mod upsell;
|
||||
|
||||
pub use agent_notification::*;
|
||||
pub use animated_label::*;
|
||||
pub use burn_mode_tooltip::*;
|
||||
pub use context_pill::*;
|
||||
pub use onboarding_modal::*;
|
||||
|
||||
@@ -1,24 +1,24 @@
|
||||
use crate::prelude::*;
|
||||
use gpui::{Animation, AnimationExt, FontWeight, pulsating_between};
|
||||
use std::time::Duration;
|
||||
use ui::prelude::*;
|
||||
|
||||
#[derive(IntoElement)]
|
||||
pub struct LoadingLabel {
|
||||
pub struct AnimatedLabel {
|
||||
base: Label,
|
||||
text: SharedString,
|
||||
}
|
||||
|
||||
impl LoadingLabel {
|
||||
impl AnimatedLabel {
|
||||
pub fn new(text: impl Into<SharedString>) -> Self {
|
||||
let text = text.into();
|
||||
LoadingLabel {
|
||||
AnimatedLabel {
|
||||
base: Label::new(text.clone()),
|
||||
text,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl LabelCommon for LoadingLabel {
|
||||
impl LabelCommon for AnimatedLabel {
|
||||
fn size(mut self, size: LabelSize) -> Self {
|
||||
self.base = self.base.size(size);
|
||||
self
|
||||
@@ -80,14 +80,14 @@ impl LabelCommon for LoadingLabel {
|
||||
}
|
||||
}
|
||||
|
||||
impl RenderOnce for LoadingLabel {
|
||||
impl RenderOnce for AnimatedLabel {
|
||||
fn render(self, _window: &mut Window, _cx: &mut App) -> impl IntoElement {
|
||||
let text = self.text.clone();
|
||||
|
||||
self.base
|
||||
.color(Color::Muted)
|
||||
.with_animations(
|
||||
"loading_label",
|
||||
"animated-label",
|
||||
vec![
|
||||
Animation::new(Duration::from_secs(1)),
|
||||
Animation::new(Duration::from_secs(1)).repeat(),
|
||||
@@ -22,7 +22,6 @@ gpui.workspace = true
|
||||
icons.workspace = true
|
||||
language.workspace = true
|
||||
language_model.workspace = true
|
||||
log.workspace = true
|
||||
parking_lot.workspace = true
|
||||
project.workspace = true
|
||||
regex.workspace = true
|
||||
|
||||
@@ -4,19 +4,25 @@ mod tool_registry;
|
||||
mod tool_schema;
|
||||
mod tool_working_set;
|
||||
|
||||
use std::{fmt, fmt::Debug, fmt::Formatter, ops::Deref, sync::Arc};
|
||||
use std::fmt;
|
||||
use std::fmt::Debug;
|
||||
use std::fmt::Formatter;
|
||||
use std::ops::Deref;
|
||||
use std::sync::Arc;
|
||||
|
||||
use anyhow::Result;
|
||||
use gpui::{
|
||||
AnyElement, AnyWindowHandle, App, Context, Entity, IntoElement, SharedString, Task, WeakEntity,
|
||||
Window,
|
||||
};
|
||||
use gpui::AnyElement;
|
||||
use gpui::AnyWindowHandle;
|
||||
use gpui::Context;
|
||||
use gpui::IntoElement;
|
||||
use gpui::Window;
|
||||
use gpui::{App, Entity, SharedString, Task, WeakEntity};
|
||||
use icons::IconName;
|
||||
use language_model::{
|
||||
LanguageModel, LanguageModelImage, LanguageModelRequest, LanguageModelToolSchemaFormat,
|
||||
};
|
||||
use language_model::LanguageModel;
|
||||
use language_model::LanguageModelImage;
|
||||
use language_model::LanguageModelRequest;
|
||||
use language_model::LanguageModelToolSchemaFormat;
|
||||
use project::Project;
|
||||
use serde::de::DeserializeOwned;
|
||||
use workspace::Workspace;
|
||||
|
||||
pub use crate::action_log::*;
|
||||
@@ -193,10 +199,7 @@ pub enum ToolSource {
|
||||
}
|
||||
|
||||
/// A tool that can be used by a language model.
|
||||
pub trait Tool: Send + Sync + 'static {
|
||||
/// The input type that is accepted by the tool.
|
||||
type Input: DeserializeOwned;
|
||||
|
||||
pub trait Tool: 'static + Send + Sync {
|
||||
/// Returns the name of the tool.
|
||||
fn name(&self) -> String;
|
||||
|
||||
@@ -213,7 +216,7 @@ pub trait Tool: Send + Sync + 'static {
|
||||
|
||||
/// Returns true if the tool needs the users's confirmation
|
||||
/// before having permission to run.
|
||||
fn needs_confirmation(&self, input: &Self::Input, cx: &App) -> bool;
|
||||
fn needs_confirmation(&self, input: &serde_json::Value, cx: &App) -> bool;
|
||||
|
||||
/// Returns true if the tool may perform edits.
|
||||
fn may_perform_edits(&self) -> bool;
|
||||
@@ -224,18 +227,18 @@ pub trait Tool: Send + Sync + 'static {
|
||||
}
|
||||
|
||||
/// Returns markdown to be displayed in the UI for this tool.
|
||||
fn ui_text(&self, input: &Self::Input) -> String;
|
||||
fn ui_text(&self, input: &serde_json::Value) -> String;
|
||||
|
||||
/// Returns markdown to be displayed in the UI for this tool, while the input JSON is still streaming
|
||||
/// (so information may be missing).
|
||||
fn still_streaming_ui_text(&self, input: &Self::Input) -> String {
|
||||
fn still_streaming_ui_text(&self, input: &serde_json::Value) -> String {
|
||||
self.ui_text(input)
|
||||
}
|
||||
|
||||
/// Runs the tool with the provided input.
|
||||
fn run(
|
||||
self: Arc<Self>,
|
||||
input: Self::Input,
|
||||
input: serde_json::Value,
|
||||
request: Arc<LanguageModelRequest>,
|
||||
project: Entity<Project>,
|
||||
action_log: Entity<ActionLog>,
|
||||
@@ -255,199 +258,7 @@ pub trait Tool: Send + Sync + 'static {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct AnyTool {
|
||||
inner: Arc<dyn ErasedTool>,
|
||||
}
|
||||
|
||||
/// Copy of `Tool` where the Input type is erased.
|
||||
trait ErasedTool: Send + Sync {
|
||||
fn name(&self) -> String;
|
||||
fn description(&self) -> String;
|
||||
fn icon(&self) -> IconName;
|
||||
fn source(&self) -> ToolSource;
|
||||
fn may_perform_edits(&self) -> bool;
|
||||
fn needs_confirmation(&self, input: &serde_json::Value, cx: &App) -> bool;
|
||||
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value>;
|
||||
fn ui_text(&self, input: &serde_json::Value) -> String;
|
||||
fn still_streaming_ui_text(&self, input: &serde_json::Value) -> String;
|
||||
fn run(
|
||||
&self,
|
||||
input: serde_json::Value,
|
||||
request: Arc<LanguageModelRequest>,
|
||||
project: Entity<Project>,
|
||||
action_log: Entity<ActionLog>,
|
||||
model: Arc<dyn LanguageModel>,
|
||||
window: Option<AnyWindowHandle>,
|
||||
cx: &mut App,
|
||||
) -> ToolResult;
|
||||
fn deserialize_card(
|
||||
&self,
|
||||
output: serde_json::Value,
|
||||
project: Entity<Project>,
|
||||
window: &mut Window,
|
||||
cx: &mut App,
|
||||
) -> Option<AnyToolCard>;
|
||||
}
|
||||
|
||||
struct ErasedToolWrapper<T: Tool> {
|
||||
tool: Arc<T>,
|
||||
}
|
||||
|
||||
impl<T: Tool> ErasedTool for ErasedToolWrapper<T> {
|
||||
fn name(&self) -> String {
|
||||
self.tool.name()
|
||||
}
|
||||
|
||||
fn description(&self) -> String {
|
||||
self.tool.description()
|
||||
}
|
||||
|
||||
fn icon(&self) -> IconName {
|
||||
self.tool.icon()
|
||||
}
|
||||
|
||||
fn source(&self) -> ToolSource {
|
||||
self.tool.source()
|
||||
}
|
||||
|
||||
fn may_perform_edits(&self) -> bool {
|
||||
self.tool.may_perform_edits()
|
||||
}
|
||||
|
||||
fn needs_confirmation(&self, input: &serde_json::Value, cx: &App) -> bool {
|
||||
match serde_json::from_value::<T::Input>(input.clone()) {
|
||||
Ok(parsed_input) => self.tool.needs_confirmation(&parsed_input, cx),
|
||||
Err(_) => true,
|
||||
}
|
||||
}
|
||||
|
||||
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
|
||||
self.tool.input_schema(format)
|
||||
}
|
||||
|
||||
fn ui_text(&self, input: &serde_json::Value) -> String {
|
||||
match serde_json::from_value::<T::Input>(input.clone()) {
|
||||
Ok(parsed_input) => self.tool.ui_text(&parsed_input),
|
||||
Err(_) => "Invalid input".to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
fn still_streaming_ui_text(&self, input: &serde_json::Value) -> String {
|
||||
match serde_json::from_value::<T::Input>(input.clone()) {
|
||||
Ok(parsed_input) => self.tool.still_streaming_ui_text(&parsed_input),
|
||||
Err(_) => "Invalid input".to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
fn run(
|
||||
&self,
|
||||
input: serde_json::Value,
|
||||
request: Arc<LanguageModelRequest>,
|
||||
project: Entity<Project>,
|
||||
action_log: Entity<ActionLog>,
|
||||
model: Arc<dyn LanguageModel>,
|
||||
window: Option<AnyWindowHandle>,
|
||||
cx: &mut App,
|
||||
) -> ToolResult {
|
||||
match serde_json::from_value::<T::Input>(input) {
|
||||
Ok(parsed_input) => self.tool.clone().run(
|
||||
parsed_input,
|
||||
request,
|
||||
project,
|
||||
action_log,
|
||||
model,
|
||||
window,
|
||||
cx,
|
||||
),
|
||||
Err(err) => ToolResult::from(Task::ready(Err(err.into()))),
|
||||
}
|
||||
}
|
||||
|
||||
fn deserialize_card(
|
||||
&self,
|
||||
output: serde_json::Value,
|
||||
project: Entity<Project>,
|
||||
window: &mut Window,
|
||||
cx: &mut App,
|
||||
) -> Option<AnyToolCard> {
|
||||
self.tool
|
||||
.clone()
|
||||
.deserialize_card(output, project, window, cx)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Tool> From<Arc<T>> for AnyTool {
|
||||
fn from(tool: Arc<T>) -> Self {
|
||||
Self {
|
||||
inner: Arc::new(ErasedToolWrapper { tool }),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl AnyTool {
|
||||
pub fn name(&self) -> String {
|
||||
self.inner.name()
|
||||
}
|
||||
|
||||
pub fn description(&self) -> String {
|
||||
self.inner.description()
|
||||
}
|
||||
|
||||
pub fn icon(&self) -> IconName {
|
||||
self.inner.icon()
|
||||
}
|
||||
|
||||
pub fn source(&self) -> ToolSource {
|
||||
self.inner.source()
|
||||
}
|
||||
|
||||
pub fn may_perform_edits(&self) -> bool {
|
||||
self.inner.may_perform_edits()
|
||||
}
|
||||
|
||||
pub fn needs_confirmation(&self, input: &serde_json::Value, cx: &App) -> bool {
|
||||
self.inner.needs_confirmation(input, cx)
|
||||
}
|
||||
|
||||
pub fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
|
||||
self.inner.input_schema(format)
|
||||
}
|
||||
|
||||
pub fn ui_text(&self, input: &serde_json::Value) -> String {
|
||||
self.inner.ui_text(input)
|
||||
}
|
||||
|
||||
pub fn still_streaming_ui_text(&self, input: &serde_json::Value) -> String {
|
||||
self.inner.still_streaming_ui_text(input)
|
||||
}
|
||||
|
||||
pub fn run(
|
||||
&self,
|
||||
input: serde_json::Value,
|
||||
request: Arc<LanguageModelRequest>,
|
||||
project: Entity<Project>,
|
||||
action_log: Entity<ActionLog>,
|
||||
model: Arc<dyn LanguageModel>,
|
||||
window: Option<AnyWindowHandle>,
|
||||
cx: &mut App,
|
||||
) -> ToolResult {
|
||||
self.inner
|
||||
.run(input, request, project, action_log, model, window, cx)
|
||||
}
|
||||
|
||||
pub fn deserialize_card(
|
||||
&self,
|
||||
output: serde_json::Value,
|
||||
project: Entity<Project>,
|
||||
window: &mut Window,
|
||||
cx: &mut App,
|
||||
) -> Option<AnyToolCard> {
|
||||
self.inner.deserialize_card(output, project, window, cx)
|
||||
}
|
||||
}
|
||||
|
||||
impl Debug for AnyTool {
|
||||
impl Debug for dyn Tool {
|
||||
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
|
||||
f.debug_struct("Tool").field("name", &self.name()).finish()
|
||||
}
|
||||
|
||||
@@ -6,7 +6,7 @@ use gpui::Global;
|
||||
use gpui::{App, ReadGlobal};
|
||||
use parking_lot::RwLock;
|
||||
|
||||
use crate::{AnyTool, Tool};
|
||||
use crate::Tool;
|
||||
|
||||
#[derive(Default, Deref, DerefMut)]
|
||||
struct GlobalToolRegistry(Arc<ToolRegistry>);
|
||||
@@ -15,7 +15,7 @@ impl Global for GlobalToolRegistry {}
|
||||
|
||||
#[derive(Default)]
|
||||
struct ToolRegistryState {
|
||||
tools: HashMap<Arc<str>, AnyTool>,
|
||||
tools: HashMap<Arc<str>, Arc<dyn Tool>>,
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
@@ -48,7 +48,7 @@ impl ToolRegistry {
|
||||
pub fn register_tool(&self, tool: impl Tool) {
|
||||
let mut state = self.state.write();
|
||||
let tool_name: Arc<str> = tool.name().into();
|
||||
state.tools.insert(tool_name, Arc::new(tool).into());
|
||||
state.tools.insert(tool_name, Arc::new(tool));
|
||||
}
|
||||
|
||||
/// Unregisters the provided [`Tool`].
|
||||
@@ -63,12 +63,12 @@ impl ToolRegistry {
|
||||
}
|
||||
|
||||
/// Returns the list of tools in the registry.
|
||||
pub fn tools(&self) -> Vec<AnyTool> {
|
||||
pub fn tools(&self) -> Vec<Arc<dyn Tool>> {
|
||||
self.state.read().tools.values().cloned().collect()
|
||||
}
|
||||
|
||||
/// Returns the [`Tool`] with the given name.
|
||||
pub fn tool(&self, name: &str) -> Option<AnyTool> {
|
||||
pub fn tool(&self, name: &str) -> Option<Arc<dyn Tool>> {
|
||||
self.state.read().tools.get(name).cloned()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,77 +1,39 @@
|
||||
use std::borrow::Borrow;
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::{AnyTool, ToolRegistry, ToolSource};
|
||||
use collections::{HashMap, HashSet, IndexMap};
|
||||
use gpui::{App, SharedString};
|
||||
use util::debug_panic;
|
||||
use collections::{HashMap, IndexMap};
|
||||
use gpui::App;
|
||||
|
||||
use crate::{Tool, ToolRegistry, ToolSource};
|
||||
|
||||
#[derive(Copy, Clone, PartialEq, Eq, Hash, Default)]
|
||||
pub struct ToolId(usize);
|
||||
|
||||
/// A unique identifier for a tool within a working set.
|
||||
#[derive(Clone, PartialEq, Eq, Hash, Default)]
|
||||
pub struct UniqueToolName(SharedString);
|
||||
|
||||
impl Borrow<str> for UniqueToolName {
|
||||
fn borrow(&self) -> &str {
|
||||
&self.0
|
||||
}
|
||||
}
|
||||
|
||||
impl From<String> for UniqueToolName {
|
||||
fn from(value: String) -> Self {
|
||||
UniqueToolName(SharedString::new(value))
|
||||
}
|
||||
}
|
||||
|
||||
impl Into<String> for UniqueToolName {
|
||||
fn into(self) -> String {
|
||||
self.0.into()
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for UniqueToolName {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
self.0.fmt(f)
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Display for UniqueToolName {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "{}", self.0.as_ref())
|
||||
}
|
||||
}
|
||||
|
||||
/// A working set of tools for use in one instance of the Assistant Panel.
|
||||
#[derive(Default)]
|
||||
pub struct ToolWorkingSet {
|
||||
context_server_tools_by_id: HashMap<ToolId, AnyTool>,
|
||||
context_server_tools_by_name: HashMap<UniqueToolName, AnyTool>,
|
||||
context_server_tools_by_id: HashMap<ToolId, Arc<dyn Tool>>,
|
||||
context_server_tools_by_name: HashMap<String, Arc<dyn Tool>>,
|
||||
next_tool_id: ToolId,
|
||||
}
|
||||
|
||||
impl ToolWorkingSet {
|
||||
pub fn tool(&self, name: &str, cx: &App) -> Option<AnyTool> {
|
||||
pub fn tool(&self, name: &str, cx: &App) -> Option<Arc<dyn Tool>> {
|
||||
self.context_server_tools_by_name
|
||||
.get(name)
|
||||
.cloned()
|
||||
.or_else(|| ToolRegistry::global(cx).tool(name))
|
||||
}
|
||||
|
||||
pub fn tools(&self, cx: &App) -> Vec<(UniqueToolName, AnyTool)> {
|
||||
let mut tools = ToolRegistry::global(cx)
|
||||
.tools()
|
||||
.into_iter()
|
||||
.map(|tool| (UniqueToolName(tool.name().into()), tool))
|
||||
.collect::<Vec<_>>();
|
||||
tools.extend(self.context_server_tools_by_name.clone());
|
||||
pub fn tools(&self, cx: &App) -> Vec<Arc<dyn Tool>> {
|
||||
let mut tools = ToolRegistry::global(cx).tools();
|
||||
tools.extend(self.context_server_tools_by_id.values().cloned());
|
||||
tools
|
||||
}
|
||||
|
||||
pub fn tools_by_source(&self, cx: &App) -> IndexMap<ToolSource, Vec<AnyTool>> {
|
||||
pub fn tools_by_source(&self, cx: &App) -> IndexMap<ToolSource, Vec<Arc<dyn Tool>>> {
|
||||
let mut tools_by_source = IndexMap::default();
|
||||
|
||||
for (_, tool) in self.tools(cx) {
|
||||
for tool in self.tools(cx) {
|
||||
tools_by_source
|
||||
.entry(tool.source())
|
||||
.or_insert_with(Vec::new)
|
||||
@@ -87,330 +49,27 @@ impl ToolWorkingSet {
|
||||
tools_by_source
|
||||
}
|
||||
|
||||
pub fn insert(&mut self, tool: AnyTool, cx: &App) -> ToolId {
|
||||
let tool_id = self.register_tool(tool);
|
||||
self.tools_changed(cx);
|
||||
tool_id
|
||||
}
|
||||
|
||||
pub fn extend(&mut self, tools: impl Iterator<Item = AnyTool>, cx: &App) -> Vec<ToolId> {
|
||||
let ids = tools.map(|tool| self.register_tool(tool)).collect();
|
||||
self.tools_changed(cx);
|
||||
ids
|
||||
}
|
||||
|
||||
pub fn remove(&mut self, tool_ids_to_remove: &[ToolId], cx: &App) {
|
||||
self.context_server_tools_by_id
|
||||
.retain(|id, _| !tool_ids_to_remove.contains(id));
|
||||
self.tools_changed(cx);
|
||||
}
|
||||
|
||||
fn register_tool(&mut self, tool: AnyTool) -> ToolId {
|
||||
pub fn insert(&mut self, tool: Arc<dyn Tool>) -> ToolId {
|
||||
let tool_id = self.next_tool_id;
|
||||
self.next_tool_id.0 += 1;
|
||||
self.context_server_tools_by_id
|
||||
.insert(tool_id, tool.clone());
|
||||
self.tools_changed();
|
||||
tool_id
|
||||
}
|
||||
|
||||
fn tools_changed(&mut self, cx: &App) {
|
||||
self.context_server_tools_by_name = resolve_context_server_tool_name_conflicts(
|
||||
&self
|
||||
.context_server_tools_by_id
|
||||
pub fn remove(&mut self, tool_ids_to_remove: &[ToolId]) {
|
||||
self.context_server_tools_by_id
|
||||
.retain(|id, _| !tool_ids_to_remove.contains(id));
|
||||
self.tools_changed();
|
||||
}
|
||||
|
||||
fn tools_changed(&mut self) {
|
||||
self.context_server_tools_by_name.clear();
|
||||
self.context_server_tools_by_name.extend(
|
||||
self.context_server_tools_by_id
|
||||
.values()
|
||||
.cloned()
|
||||
.collect::<Vec<_>>(),
|
||||
&ToolRegistry::global(cx).tools(),
|
||||
.map(|tool| (tool.name(), tool.clone())),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
fn resolve_context_server_tool_name_conflicts(
|
||||
context_server_tools: &[AnyTool],
|
||||
native_tools: &[AnyTool],
|
||||
) -> HashMap<UniqueToolName, AnyTool> {
|
||||
fn resolve_tool_name(tool: &AnyTool) -> String {
|
||||
let mut tool_name = tool.name();
|
||||
tool_name.truncate(MAX_TOOL_NAME_LENGTH);
|
||||
tool_name
|
||||
}
|
||||
|
||||
const MAX_TOOL_NAME_LENGTH: usize = 64;
|
||||
|
||||
let mut duplicated_tool_names = HashSet::default();
|
||||
let mut seen_tool_names = HashSet::default();
|
||||
seen_tool_names.extend(native_tools.iter().map(|tool| tool.name()));
|
||||
for tool in context_server_tools {
|
||||
let tool_name = resolve_tool_name(tool);
|
||||
if seen_tool_names.contains(&tool_name) {
|
||||
debug_assert!(
|
||||
tool.source() != ToolSource::Native,
|
||||
"Expected MCP tool but got a native tool: {}",
|
||||
tool_name
|
||||
);
|
||||
duplicated_tool_names.insert(tool_name);
|
||||
} else {
|
||||
seen_tool_names.insert(tool_name);
|
||||
}
|
||||
}
|
||||
|
||||
if duplicated_tool_names.is_empty() {
|
||||
return context_server_tools
|
||||
.into_iter()
|
||||
.map(|tool| (resolve_tool_name(tool).into(), tool.clone()))
|
||||
.collect();
|
||||
}
|
||||
|
||||
context_server_tools
|
||||
.into_iter()
|
||||
.filter_map(|tool| {
|
||||
let mut tool_name = resolve_tool_name(tool);
|
||||
if !duplicated_tool_names.contains(&tool_name) {
|
||||
return Some((tool_name.into(), tool.clone()));
|
||||
}
|
||||
match tool.source() {
|
||||
ToolSource::Native => {
|
||||
debug_panic!("Expected MCP tool but got a native tool: {}", tool_name);
|
||||
// Built-in tools always keep their original name
|
||||
Some((tool_name.into(), tool.clone()))
|
||||
}
|
||||
ToolSource::ContextServer { id } => {
|
||||
// Context server tools are prefixed with the context server ID, and truncated if necessary
|
||||
tool_name.insert(0, '_');
|
||||
if tool_name.len() + id.len() > MAX_TOOL_NAME_LENGTH {
|
||||
let len = MAX_TOOL_NAME_LENGTH - tool_name.len();
|
||||
let mut id = id.to_string();
|
||||
id.truncate(len);
|
||||
tool_name.insert_str(0, &id);
|
||||
} else {
|
||||
tool_name.insert_str(0, &id);
|
||||
}
|
||||
|
||||
tool_name.truncate(MAX_TOOL_NAME_LENGTH);
|
||||
|
||||
if seen_tool_names.contains(&tool_name) {
|
||||
log::error!("Cannot resolve tool name conflict for tool {}", tool.name());
|
||||
None
|
||||
} else {
|
||||
Some((tool_name.into(), tool.clone()))
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::sync::Arc;
|
||||
|
||||
use gpui::{AnyWindowHandle, Entity, Task, TestAppContext};
|
||||
use language_model::{LanguageModel, LanguageModelRequest};
|
||||
use project::Project;
|
||||
|
||||
use crate::{ActionLog, Tool, ToolResult};
|
||||
|
||||
use super::*;
|
||||
|
||||
#[gpui::test]
|
||||
fn test_unique_tool_names(cx: &mut TestAppContext) {
|
||||
fn assert_tool(
|
||||
tool_working_set: &ToolWorkingSet,
|
||||
unique_name: &str,
|
||||
expected_name: &str,
|
||||
expected_source: ToolSource,
|
||||
cx: &App,
|
||||
) {
|
||||
let tool = tool_working_set.tool(unique_name, cx).unwrap();
|
||||
assert_eq!(tool.name(), expected_name);
|
||||
assert_eq!(tool.source(), expected_source);
|
||||
}
|
||||
|
||||
let tool_registry = cx.update(ToolRegistry::default_global);
|
||||
tool_registry.register_tool(TestTool::new("tool1", ToolSource::Native));
|
||||
tool_registry.register_tool(TestTool::new("tool2", ToolSource::Native));
|
||||
|
||||
let mut tool_working_set = ToolWorkingSet::default();
|
||||
cx.update(|cx| {
|
||||
tool_working_set.extend(
|
||||
vec![
|
||||
Arc::new(TestTool::new(
|
||||
"tool2",
|
||||
ToolSource::ContextServer { id: "mcp-1".into() },
|
||||
))
|
||||
.into(),
|
||||
Arc::new(TestTool::new(
|
||||
"tool2",
|
||||
ToolSource::ContextServer { id: "mcp-2".into() },
|
||||
))
|
||||
.into(),
|
||||
]
|
||||
.into_iter(),
|
||||
cx,
|
||||
);
|
||||
});
|
||||
|
||||
cx.update(|cx| {
|
||||
assert_tool(&tool_working_set, "tool1", "tool1", ToolSource::Native, cx);
|
||||
assert_tool(&tool_working_set, "tool2", "tool2", ToolSource::Native, cx);
|
||||
assert_tool(
|
||||
&tool_working_set,
|
||||
"mcp-1_tool2",
|
||||
"tool2",
|
||||
ToolSource::ContextServer { id: "mcp-1".into() },
|
||||
cx,
|
||||
);
|
||||
assert_tool(
|
||||
&tool_working_set,
|
||||
"mcp-2_tool2",
|
||||
"tool2",
|
||||
ToolSource::ContextServer { id: "mcp-2".into() },
|
||||
cx,
|
||||
);
|
||||
})
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
fn test_resolve_context_server_tool_name_conflicts() {
|
||||
assert_resolve_context_server_tool_name_conflicts(
|
||||
vec![
|
||||
TestTool::new("tool1", ToolSource::Native),
|
||||
TestTool::new("tool2", ToolSource::Native),
|
||||
],
|
||||
vec![TestTool::new(
|
||||
"tool3",
|
||||
ToolSource::ContextServer { id: "mcp-1".into() },
|
||||
)],
|
||||
vec!["tool3"],
|
||||
);
|
||||
|
||||
assert_resolve_context_server_tool_name_conflicts(
|
||||
vec![
|
||||
TestTool::new("tool1", ToolSource::Native),
|
||||
TestTool::new("tool2", ToolSource::Native),
|
||||
],
|
||||
vec![
|
||||
TestTool::new("tool3", ToolSource::ContextServer { id: "mcp-1".into() }),
|
||||
TestTool::new("tool3", ToolSource::ContextServer { id: "mcp-2".into() }),
|
||||
],
|
||||
vec!["mcp-1_tool3", "mcp-2_tool3"],
|
||||
);
|
||||
|
||||
assert_resolve_context_server_tool_name_conflicts(
|
||||
vec![
|
||||
TestTool::new("tool1", ToolSource::Native),
|
||||
TestTool::new("tool2", ToolSource::Native),
|
||||
TestTool::new("tool3", ToolSource::Native),
|
||||
],
|
||||
vec![
|
||||
TestTool::new("tool3", ToolSource::ContextServer { id: "mcp-1".into() }),
|
||||
TestTool::new("tool3", ToolSource::ContextServer { id: "mcp-2".into() }),
|
||||
],
|
||||
vec!["mcp-1_tool3", "mcp-2_tool3"],
|
||||
);
|
||||
|
||||
// Test deduplication of tools with very long names, in this case the mcp server name should be truncated
|
||||
assert_resolve_context_server_tool_name_conflicts(
|
||||
vec![TestTool::new(
|
||||
"tool-with-very-very-very-long-name",
|
||||
ToolSource::Native,
|
||||
)],
|
||||
vec![TestTool::new(
|
||||
"tool-with-very-very-very-long-name",
|
||||
ToolSource::ContextServer {
|
||||
id: "mcp-with-very-very-very-long-name".into(),
|
||||
},
|
||||
)],
|
||||
vec!["mcp-with-very-very-very-long-_tool-with-very-very-very-long-name"],
|
||||
);
|
||||
|
||||
fn assert_resolve_context_server_tool_name_conflicts(
|
||||
builtin_tools: Vec<TestTool>,
|
||||
context_server_tools: Vec<TestTool>,
|
||||
expected: Vec<&'static str>,
|
||||
) {
|
||||
let context_server_tools: Vec<AnyTool> = context_server_tools
|
||||
.into_iter()
|
||||
.map(|t| Arc::new(t).into())
|
||||
.collect();
|
||||
let builtin_tools: Vec<AnyTool> = builtin_tools
|
||||
.into_iter()
|
||||
.map(|t| Arc::new(t).into())
|
||||
.collect();
|
||||
let tools =
|
||||
resolve_context_server_tool_name_conflicts(&context_server_tools, &builtin_tools);
|
||||
assert_eq!(tools.len(), expected.len());
|
||||
for (i, (name, _)) in tools.into_iter().enumerate() {
|
||||
assert_eq!(
|
||||
name.0.as_ref(),
|
||||
expected[i],
|
||||
"Expected '{}' got '{}' at index {}",
|
||||
expected[i],
|
||||
name,
|
||||
i
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct TestTool {
|
||||
name: String,
|
||||
source: ToolSource,
|
||||
}
|
||||
|
||||
impl TestTool {
|
||||
fn new(name: impl Into<String>, source: ToolSource) -> Self {
|
||||
Self {
|
||||
name: name.into(),
|
||||
source,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Tool for TestTool {
|
||||
type Input = ();
|
||||
|
||||
fn name(&self) -> String {
|
||||
self.name.clone()
|
||||
}
|
||||
|
||||
fn icon(&self) -> icons::IconName {
|
||||
icons::IconName::Ai
|
||||
}
|
||||
|
||||
fn may_perform_edits(&self) -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
fn needs_confirmation(&self, _input: &Self::Input, _cx: &App) -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
fn source(&self) -> ToolSource {
|
||||
self.source.clone()
|
||||
}
|
||||
|
||||
fn description(&self) -> String {
|
||||
"Test tool".to_string()
|
||||
}
|
||||
|
||||
fn ui_text(&self, _input: &Self::Input) -> String {
|
||||
"Test tool".to_string()
|
||||
}
|
||||
|
||||
fn run(
|
||||
self: Arc<Self>,
|
||||
_input: Self::Input,
|
||||
_request: Arc<LanguageModelRequest>,
|
||||
_project: Entity<Project>,
|
||||
_action_log: Entity<ActionLog>,
|
||||
_model: Arc<dyn LanguageModel>,
|
||||
_window: Option<AnyWindowHandle>,
|
||||
_cx: &mut App,
|
||||
) -> ToolResult {
|
||||
ToolResult {
|
||||
output: Task::ready(Err(anyhow::anyhow!("No content"))),
|
||||
card: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -40,13 +40,11 @@ pub struct CopyPathToolInput {
|
||||
pub struct CopyPathTool;
|
||||
|
||||
impl Tool for CopyPathTool {
|
||||
type Input = CopyPathToolInput;
|
||||
|
||||
fn name(&self) -> String {
|
||||
"copy_path".into()
|
||||
}
|
||||
|
||||
fn needs_confirmation(&self, _: &Self::Input, _: &App) -> bool {
|
||||
fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
@@ -66,15 +64,20 @@ impl Tool for CopyPathTool {
|
||||
json_schema_for::<CopyPathToolInput>(format)
|
||||
}
|
||||
|
||||
fn ui_text(&self, input: &Self::Input) -> String {
|
||||
let src = MarkdownInlineCode(&input.source_path);
|
||||
let dest = MarkdownInlineCode(&input.destination_path);
|
||||
format!("Copy {src} to {dest}")
|
||||
fn ui_text(&self, input: &serde_json::Value) -> String {
|
||||
match serde_json::from_value::<CopyPathToolInput>(input.clone()) {
|
||||
Ok(input) => {
|
||||
let src = MarkdownInlineCode(&input.source_path);
|
||||
let dest = MarkdownInlineCode(&input.destination_path);
|
||||
format!("Copy {src} to {dest}")
|
||||
}
|
||||
Err(_) => "Copy path".to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
fn run(
|
||||
self: Arc<Self>,
|
||||
input: Self::Input,
|
||||
input: serde_json::Value,
|
||||
_request: Arc<LanguageModelRequest>,
|
||||
project: Entity<Project>,
|
||||
_action_log: Entity<ActionLog>,
|
||||
@@ -82,6 +85,10 @@ impl Tool for CopyPathTool {
|
||||
_window: Option<AnyWindowHandle>,
|
||||
cx: &mut App,
|
||||
) -> ToolResult {
|
||||
let input = match serde_json::from_value::<CopyPathToolInput>(input) {
|
||||
Ok(input) => input,
|
||||
Err(err) => return Task::ready(Err(anyhow!(err))).into(),
|
||||
};
|
||||
let copy_task = project.update(cx, |project, cx| {
|
||||
match project
|
||||
.find_project_path(&input.source_path, cx)
|
||||
|
||||
@@ -29,8 +29,6 @@ pub struct CreateDirectoryToolInput {
|
||||
pub struct CreateDirectoryTool;
|
||||
|
||||
impl Tool for CreateDirectoryTool {
|
||||
type Input = CreateDirectoryToolInput;
|
||||
|
||||
fn name(&self) -> String {
|
||||
"create_directory".into()
|
||||
}
|
||||
@@ -39,7 +37,7 @@ impl Tool for CreateDirectoryTool {
|
||||
include_str!("./create_directory_tool/description.md").into()
|
||||
}
|
||||
|
||||
fn needs_confirmation(&self, _: &Self::Input, _: &App) -> bool {
|
||||
fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
@@ -55,13 +53,18 @@ impl Tool for CreateDirectoryTool {
|
||||
json_schema_for::<CreateDirectoryToolInput>(format)
|
||||
}
|
||||
|
||||
fn ui_text(&self, input: &Self::Input) -> String {
|
||||
format!("Create directory {}", MarkdownInlineCode(&input.path))
|
||||
fn ui_text(&self, input: &serde_json::Value) -> String {
|
||||
match serde_json::from_value::<CreateDirectoryToolInput>(input.clone()) {
|
||||
Ok(input) => {
|
||||
format!("Create directory {}", MarkdownInlineCode(&input.path))
|
||||
}
|
||||
Err(_) => "Create directory".to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
fn run(
|
||||
self: Arc<Self>,
|
||||
input: Self::Input,
|
||||
input: serde_json::Value,
|
||||
_request: Arc<LanguageModelRequest>,
|
||||
project: Entity<Project>,
|
||||
_action_log: Entity<ActionLog>,
|
||||
@@ -69,6 +72,10 @@ impl Tool for CreateDirectoryTool {
|
||||
_window: Option<AnyWindowHandle>,
|
||||
cx: &mut App,
|
||||
) -> ToolResult {
|
||||
let input = match serde_json::from_value::<CreateDirectoryToolInput>(input) {
|
||||
Ok(input) => input,
|
||||
Err(err) => return Task::ready(Err(anyhow!(err))).into(),
|
||||
};
|
||||
let project_path = match project.read(cx).find_project_path(&input.path, cx) {
|
||||
Some(project_path) => project_path,
|
||||
None => {
|
||||
|
||||
@@ -29,13 +29,11 @@ pub struct DeletePathToolInput {
|
||||
pub struct DeletePathTool;
|
||||
|
||||
impl Tool for DeletePathTool {
|
||||
type Input = DeletePathToolInput;
|
||||
|
||||
fn name(&self) -> String {
|
||||
"delete_path".into()
|
||||
}
|
||||
|
||||
fn needs_confirmation(&self, _: &Self::Input, _: &App) -> bool {
|
||||
fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
@@ -55,13 +53,16 @@ impl Tool for DeletePathTool {
|
||||
json_schema_for::<DeletePathToolInput>(format)
|
||||
}
|
||||
|
||||
fn ui_text(&self, input: &Self::Input) -> String {
|
||||
format!("Delete “`{}`”", input.path)
|
||||
fn ui_text(&self, input: &serde_json::Value) -> String {
|
||||
match serde_json::from_value::<DeletePathToolInput>(input.clone()) {
|
||||
Ok(input) => format!("Delete “`{}`”", input.path),
|
||||
Err(_) => "Delete path".to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
fn run(
|
||||
self: Arc<Self>,
|
||||
input: Self::Input,
|
||||
input: serde_json::Value,
|
||||
_request: Arc<LanguageModelRequest>,
|
||||
project: Entity<Project>,
|
||||
action_log: Entity<ActionLog>,
|
||||
@@ -69,7 +70,10 @@ impl Tool for DeletePathTool {
|
||||
_window: Option<AnyWindowHandle>,
|
||||
cx: &mut App,
|
||||
) -> ToolResult {
|
||||
let path_str = input.path;
|
||||
let path_str = match serde_json::from_value::<DeletePathToolInput>(input) {
|
||||
Ok(input) => input.path,
|
||||
Err(err) => return Task::ready(Err(anyhow!(err))).into(),
|
||||
};
|
||||
let Some(project_path) = project.read(cx).find_project_path(&path_str, cx) else {
|
||||
return Task::ready(Err(anyhow!(
|
||||
"Couldn't delete {path_str} because that path isn't in this project."
|
||||
|
||||
@@ -42,13 +42,11 @@ where
|
||||
pub struct DiagnosticsTool;
|
||||
|
||||
impl Tool for DiagnosticsTool {
|
||||
type Input = DiagnosticsToolInput;
|
||||
|
||||
fn name(&self) -> String {
|
||||
"diagnostics".into()
|
||||
}
|
||||
|
||||
fn needs_confirmation(&self, _: &Self::Input, _: &App) -> bool {
|
||||
fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
@@ -68,9 +66,15 @@ impl Tool for DiagnosticsTool {
|
||||
json_schema_for::<DiagnosticsToolInput>(format)
|
||||
}
|
||||
|
||||
fn ui_text(&self, input: &Self::Input) -> String {
|
||||
if let Some(path) = input.path.as_ref().filter(|p| !p.is_empty()) {
|
||||
format!("Check diagnostics for {}", MarkdownInlineCode(path))
|
||||
fn ui_text(&self, input: &serde_json::Value) -> String {
|
||||
if let Some(path) = serde_json::from_value::<DiagnosticsToolInput>(input.clone())
|
||||
.ok()
|
||||
.and_then(|input| match input.path {
|
||||
Some(path) if !path.is_empty() => Some(path),
|
||||
_ => None,
|
||||
})
|
||||
{
|
||||
format!("Check diagnostics for {}", MarkdownInlineCode(&path))
|
||||
} else {
|
||||
"Check project diagnostics".to_string()
|
||||
}
|
||||
@@ -78,7 +82,7 @@ impl Tool for DiagnosticsTool {
|
||||
|
||||
fn run(
|
||||
self: Arc<Self>,
|
||||
input: Self::Input,
|
||||
input: serde_json::Value,
|
||||
_request: Arc<LanguageModelRequest>,
|
||||
project: Entity<Project>,
|
||||
action_log: Entity<ActionLog>,
|
||||
@@ -86,7 +90,10 @@ impl Tool for DiagnosticsTool {
|
||||
_window: Option<AnyWindowHandle>,
|
||||
cx: &mut App,
|
||||
) -> ToolResult {
|
||||
match input.path {
|
||||
match serde_json::from_value::<DiagnosticsToolInput>(input)
|
||||
.ok()
|
||||
.and_then(|input| input.path)
|
||||
{
|
||||
Some(path) if !path.is_empty() => {
|
||||
let Some(project_path) = project.read(cx).find_project_path(&path, cx) else {
|
||||
return Task::ready(Err(anyhow!("Could not find path {path} in project",)))
|
||||
|
||||
@@ -9132,7 +9132,7 @@ impl Editor {
|
||||
window: &mut Window,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
self.manipulate_lines(window, cx, |lines| lines.sort())
|
||||
self.manipulate_immutable_lines(window, cx, |lines| lines.sort())
|
||||
}
|
||||
|
||||
pub fn sort_lines_case_insensitive(
|
||||
@@ -9141,7 +9141,7 @@ impl Editor {
|
||||
window: &mut Window,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
self.manipulate_lines(window, cx, |lines| {
|
||||
self.manipulate_immutable_lines(window, cx, |lines| {
|
||||
lines.sort_by_key(|line| line.to_lowercase())
|
||||
})
|
||||
}
|
||||
@@ -9152,7 +9152,7 @@ impl Editor {
|
||||
window: &mut Window,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
self.manipulate_lines(window, cx, |lines| {
|
||||
self.manipulate_immutable_lines(window, cx, |lines| {
|
||||
let mut seen = HashSet::default();
|
||||
lines.retain(|line| seen.insert(line.to_lowercase()));
|
||||
})
|
||||
@@ -9164,7 +9164,7 @@ impl Editor {
|
||||
window: &mut Window,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
self.manipulate_lines(window, cx, |lines| {
|
||||
self.manipulate_immutable_lines(window, cx, |lines| {
|
||||
let mut seen = HashSet::default();
|
||||
lines.retain(|line| seen.insert(*line));
|
||||
})
|
||||
@@ -9606,20 +9606,20 @@ impl Editor {
|
||||
}
|
||||
|
||||
pub fn reverse_lines(&mut self, _: &ReverseLines, window: &mut Window, cx: &mut Context<Self>) {
|
||||
self.manipulate_lines(window, cx, |lines| lines.reverse())
|
||||
self.manipulate_immutable_lines(window, cx, |lines| lines.reverse())
|
||||
}
|
||||
|
||||
pub fn shuffle_lines(&mut self, _: &ShuffleLines, window: &mut Window, cx: &mut Context<Self>) {
|
||||
self.manipulate_lines(window, cx, |lines| lines.shuffle(&mut thread_rng()))
|
||||
self.manipulate_immutable_lines(window, cx, |lines| lines.shuffle(&mut thread_rng()))
|
||||
}
|
||||
|
||||
fn manipulate_lines<Fn>(
|
||||
fn manipulate_lines<M>(
|
||||
&mut self,
|
||||
window: &mut Window,
|
||||
cx: &mut Context<Self>,
|
||||
mut callback: Fn,
|
||||
mut manipulate: M,
|
||||
) where
|
||||
Fn: FnMut(&mut Vec<&str>),
|
||||
M: FnMut(&str) -> LineManipulationResult,
|
||||
{
|
||||
self.hide_mouse_cursor(&HideMouseCursorOrigin::TypingAction);
|
||||
|
||||
@@ -9652,18 +9652,14 @@ impl Editor {
|
||||
.text_for_range(start_point..end_point)
|
||||
.collect::<String>();
|
||||
|
||||
let mut lines = text.split('\n').collect_vec();
|
||||
let LineManipulationResult { new_text, line_count_before, line_count_after} = manipulate(&text);
|
||||
|
||||
let lines_before = lines.len();
|
||||
callback(&mut lines);
|
||||
let lines_after = lines.len();
|
||||
|
||||
edits.push((start_point..end_point, lines.join("\n")));
|
||||
edits.push((start_point..end_point, new_text));
|
||||
|
||||
// Selections must change based on added and removed line count
|
||||
let start_row =
|
||||
MultiBufferRow(start_point.row + added_lines as u32 - removed_lines as u32);
|
||||
let end_row = MultiBufferRow(start_row.0 + lines_after.saturating_sub(1) as u32);
|
||||
let end_row = MultiBufferRow(start_row.0 + line_count_after.saturating_sub(1) as u32);
|
||||
new_selections.push(Selection {
|
||||
id: selection.id,
|
||||
start: start_row,
|
||||
@@ -9672,10 +9668,10 @@ impl Editor {
|
||||
reversed: selection.reversed,
|
||||
});
|
||||
|
||||
if lines_after > lines_before {
|
||||
added_lines += lines_after - lines_before;
|
||||
} else if lines_before > lines_after {
|
||||
removed_lines += lines_before - lines_after;
|
||||
if line_count_after > line_count_before {
|
||||
added_lines += line_count_after - line_count_before;
|
||||
} else if line_count_before > line_count_after {
|
||||
removed_lines += line_count_before - line_count_after;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -9720,6 +9716,171 @@ impl Editor {
|
||||
})
|
||||
}
|
||||
|
||||
fn manipulate_immutable_lines<Fn>(
|
||||
&mut self,
|
||||
window: &mut Window,
|
||||
cx: &mut Context<Self>,
|
||||
mut callback: Fn,
|
||||
) where
|
||||
Fn: FnMut(&mut Vec<&str>),
|
||||
{
|
||||
self.manipulate_lines(window, cx, |text| {
|
||||
let mut lines: Vec<&str> = text.split('\n').collect();
|
||||
let line_count_before = lines.len();
|
||||
|
||||
callback(&mut lines);
|
||||
|
||||
LineManipulationResult {
|
||||
new_text: lines.join("\n"),
|
||||
line_count_before,
|
||||
line_count_after: lines.len(),
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
fn manipulate_mutable_lines<Fn>(
|
||||
&mut self,
|
||||
window: &mut Window,
|
||||
cx: &mut Context<Self>,
|
||||
mut callback: Fn,
|
||||
) where
|
||||
Fn: FnMut(&mut Vec<Cow<'_, str>>),
|
||||
{
|
||||
self.manipulate_lines(window, cx, |text| {
|
||||
let mut lines: Vec<Cow<str>> = text.split('\n').map(Cow::from).collect();
|
||||
let line_count_before = lines.len();
|
||||
|
||||
callback(&mut lines);
|
||||
|
||||
LineManipulationResult {
|
||||
new_text: lines.join("\n"),
|
||||
line_count_before,
|
||||
line_count_after: lines.len(),
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
pub fn convert_indentation_to_spaces(
|
||||
&mut self,
|
||||
_: &ConvertIndentationToSpaces,
|
||||
window: &mut Window,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
let settings = self.buffer.read(cx).language_settings(cx);
|
||||
let tab_size = settings.tab_size.get() as usize;
|
||||
|
||||
self.manipulate_mutable_lines(window, cx, |lines| {
|
||||
// Allocates a reasonably sized scratch buffer once for the whole loop
|
||||
let mut reindented_line = String::with_capacity(MAX_LINE_LEN);
|
||||
// Avoids recomputing spaces that could be inserted many times
|
||||
let space_cache: Vec<Vec<char>> = (1..=tab_size)
|
||||
.map(|n| IndentSize::spaces(n as u32).chars().collect())
|
||||
.collect();
|
||||
|
||||
for line in lines.iter_mut().filter(|line| !line.is_empty()) {
|
||||
let mut chars = line.as_ref().chars();
|
||||
let mut col = 0;
|
||||
let mut changed = false;
|
||||
|
||||
while let Some(ch) = chars.next() {
|
||||
match ch {
|
||||
' ' => {
|
||||
reindented_line.push(' ');
|
||||
col += 1;
|
||||
}
|
||||
'\t' => {
|
||||
// \t are converted to spaces depending on the current column
|
||||
let spaces_len = tab_size - (col % tab_size);
|
||||
reindented_line.extend(&space_cache[spaces_len - 1]);
|
||||
col += spaces_len;
|
||||
changed = true;
|
||||
}
|
||||
_ => {
|
||||
// If we dont append before break, the character is consumed
|
||||
reindented_line.push(ch);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !changed {
|
||||
reindented_line.clear();
|
||||
continue;
|
||||
}
|
||||
// Append the rest of the line and replace old reference with new one
|
||||
reindented_line.extend(chars);
|
||||
*line = Cow::Owned(reindented_line.clone());
|
||||
reindented_line.clear();
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
pub fn convert_indentation_to_tabs(
|
||||
&mut self,
|
||||
_: &ConvertIndentationToTabs,
|
||||
window: &mut Window,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
let settings = self.buffer.read(cx).language_settings(cx);
|
||||
let tab_size = settings.tab_size.get() as usize;
|
||||
|
||||
self.manipulate_mutable_lines(window, cx, |lines| {
|
||||
// Allocates a reasonably sized buffer once for the whole loop
|
||||
let mut reindented_line = String::with_capacity(MAX_LINE_LEN);
|
||||
// Avoids recomputing spaces that could be inserted many times
|
||||
let space_cache: Vec<Vec<char>> = (1..=tab_size)
|
||||
.map(|n| IndentSize::spaces(n as u32).chars().collect())
|
||||
.collect();
|
||||
|
||||
for line in lines.iter_mut().filter(|line| !line.is_empty()) {
|
||||
let mut chars = line.chars();
|
||||
let mut spaces_count = 0;
|
||||
let mut first_non_indent_char = None;
|
||||
let mut changed = false;
|
||||
|
||||
while let Some(ch) = chars.next() {
|
||||
match ch {
|
||||
' ' => {
|
||||
// Keep track of spaces. Append \t when we reach tab_size
|
||||
spaces_count += 1;
|
||||
changed = true;
|
||||
if spaces_count == tab_size {
|
||||
reindented_line.push('\t');
|
||||
spaces_count = 0;
|
||||
}
|
||||
}
|
||||
'\t' => {
|
||||
reindented_line.push('\t');
|
||||
spaces_count = 0;
|
||||
}
|
||||
_ => {
|
||||
// Dont append it yet, we might have remaining spaces
|
||||
first_non_indent_char = Some(ch);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !changed {
|
||||
reindented_line.clear();
|
||||
continue;
|
||||
}
|
||||
// Remaining spaces that didn't make a full tab stop
|
||||
if spaces_count > 0 {
|
||||
reindented_line.extend(&space_cache[spaces_count - 1]);
|
||||
}
|
||||
// If we consume an extra character that was not indentation, add it back
|
||||
if let Some(extra_char) = first_non_indent_char {
|
||||
reindented_line.push(extra_char);
|
||||
}
|
||||
// Append the rest of the line and replace old reference with new one
|
||||
reindented_line.extend(chars);
|
||||
*line = Cow::Owned(reindented_line.clone());
|
||||
reindented_line.clear();
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
pub fn convert_to_upper_case(
|
||||
&mut self,
|
||||
_: &ConvertToUpperCase,
|
||||
@@ -21157,6 +21318,13 @@ pub struct LineHighlight {
|
||||
pub type_id: Option<TypeId>,
|
||||
}
|
||||
|
||||
struct LineManipulationResult {
|
||||
pub new_text: String,
|
||||
pub line_count_before: usize,
|
||||
pub line_count_after: usize,
|
||||
}
|
||||
|
||||
|
||||
fn render_diff_hunk_controls(
|
||||
row: u32,
|
||||
status: &DiffHunkStatus,
|
||||
|
||||
@@ -121,13 +121,11 @@ struct PartialInput {
|
||||
const DEFAULT_UI_TEXT: &str = "Editing file";
|
||||
|
||||
impl Tool for EditFileTool {
|
||||
type Input = EditFileToolInput;
|
||||
|
||||
fn name(&self) -> String {
|
||||
"edit_file".into()
|
||||
}
|
||||
|
||||
fn needs_confirmation(&self, _: &Self::Input, _: &App) -> bool {
|
||||
fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
@@ -147,20 +145,24 @@ impl Tool for EditFileTool {
|
||||
json_schema_for::<EditFileToolInput>(format)
|
||||
}
|
||||
|
||||
fn ui_text(&self, input: &Self::Input) -> String {
|
||||
input.display_description.clone()
|
||||
fn ui_text(&self, input: &serde_json::Value) -> String {
|
||||
match serde_json::from_value::<EditFileToolInput>(input.clone()) {
|
||||
Ok(input) => input.display_description,
|
||||
Err(_) => "Editing file".to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
fn still_streaming_ui_text(&self, input: &Self::Input) -> String {
|
||||
let description = input.display_description.trim();
|
||||
if !description.is_empty() {
|
||||
return description.to_string();
|
||||
}
|
||||
fn still_streaming_ui_text(&self, input: &serde_json::Value) -> String {
|
||||
if let Some(input) = serde_json::from_value::<PartialInput>(input.clone()).ok() {
|
||||
let description = input.display_description.trim();
|
||||
if !description.is_empty() {
|
||||
return description.to_string();
|
||||
}
|
||||
|
||||
let path = input.path.to_string_lossy();
|
||||
let path = path.trim();
|
||||
if !path.is_empty() {
|
||||
return path.to_string();
|
||||
let path = input.path.trim();
|
||||
if !path.is_empty() {
|
||||
return path.to_string();
|
||||
}
|
||||
}
|
||||
|
||||
DEFAULT_UI_TEXT.to_string()
|
||||
@@ -168,7 +170,7 @@ impl Tool for EditFileTool {
|
||||
|
||||
fn run(
|
||||
self: Arc<Self>,
|
||||
input: Self::Input,
|
||||
input: serde_json::Value,
|
||||
request: Arc<LanguageModelRequest>,
|
||||
project: Entity<Project>,
|
||||
action_log: Entity<ActionLog>,
|
||||
@@ -176,6 +178,11 @@ impl Tool for EditFileTool {
|
||||
window: Option<AnyWindowHandle>,
|
||||
cx: &mut App,
|
||||
) -> ToolResult {
|
||||
let input = match serde_json::from_value::<EditFileToolInput>(input) {
|
||||
Ok(input) => input,
|
||||
Err(err) => return Task::ready(Err(anyhow!(err))).into(),
|
||||
};
|
||||
|
||||
let project_path = match resolve_path(&input, project.clone(), cx) {
|
||||
Ok(path) => path,
|
||||
Err(err) => return Task::ready(Err(anyhow!(err))).into(),
|
||||
@@ -1162,11 +1169,12 @@ mod tests {
|
||||
let model = Arc::new(FakeLanguageModel::default());
|
||||
let result = cx
|
||||
.update(|cx| {
|
||||
let input = EditFileToolInput {
|
||||
let input = serde_json::to_value(EditFileToolInput {
|
||||
display_description: "Some edit".into(),
|
||||
path: "root/nonexistent_file.txt".into(),
|
||||
mode: EditFileMode::Edit,
|
||||
};
|
||||
})
|
||||
.unwrap();
|
||||
Arc::new(EditFileTool)
|
||||
.run(
|
||||
input,
|
||||
@@ -1280,22 +1288,24 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn still_streaming_ui_text_with_path() {
|
||||
let input = EditFileToolInput {
|
||||
path: "src/main.rs".into(),
|
||||
display_description: "".into(),
|
||||
mode: EditFileMode::Edit,
|
||||
};
|
||||
let input = json!({
|
||||
"path": "src/main.rs",
|
||||
"display_description": "",
|
||||
"old_string": "old code",
|
||||
"new_string": "new code"
|
||||
});
|
||||
|
||||
assert_eq!(EditFileTool.still_streaming_ui_text(&input), "src/main.rs");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn still_streaming_ui_text_with_description() {
|
||||
let input = EditFileToolInput {
|
||||
path: "".into(),
|
||||
display_description: "Fix error handling".into(),
|
||||
mode: EditFileMode::Edit,
|
||||
};
|
||||
let input = json!({
|
||||
"path": "",
|
||||
"display_description": "Fix error handling",
|
||||
"old_string": "old code",
|
||||
"new_string": "new code"
|
||||
});
|
||||
|
||||
assert_eq!(
|
||||
EditFileTool.still_streaming_ui_text(&input),
|
||||
@@ -1305,11 +1315,12 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn still_streaming_ui_text_with_path_and_description() {
|
||||
let input = EditFileToolInput {
|
||||
path: "src/main.rs".into(),
|
||||
display_description: "Fix error handling".into(),
|
||||
mode: EditFileMode::Edit,
|
||||
};
|
||||
let input = json!({
|
||||
"path": "src/main.rs",
|
||||
"display_description": "Fix error handling",
|
||||
"old_string": "old code",
|
||||
"new_string": "new code"
|
||||
});
|
||||
|
||||
assert_eq!(
|
||||
EditFileTool.still_streaming_ui_text(&input),
|
||||
@@ -1319,11 +1330,12 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn still_streaming_ui_text_no_path_or_description() {
|
||||
let input = EditFileToolInput {
|
||||
path: "".into(),
|
||||
display_description: "".into(),
|
||||
mode: EditFileMode::Edit,
|
||||
};
|
||||
let input = json!({
|
||||
"path": "",
|
||||
"display_description": "",
|
||||
"old_string": "old code",
|
||||
"new_string": "new code"
|
||||
});
|
||||
|
||||
assert_eq!(
|
||||
EditFileTool.still_streaming_ui_text(&input),
|
||||
@@ -1333,11 +1345,7 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn still_streaming_ui_text_with_null() {
|
||||
let input = EditFileToolInput {
|
||||
path: "".into(),
|
||||
display_description: "".into(),
|
||||
mode: EditFileMode::Edit,
|
||||
};
|
||||
let input = serde_json::Value::Null;
|
||||
|
||||
assert_eq!(
|
||||
EditFileTool.still_streaming_ui_text(&input),
|
||||
@@ -1449,11 +1457,12 @@ mod tests {
|
||||
// Have the model stream unformatted content
|
||||
let edit_result = {
|
||||
let edit_task = cx.update(|cx| {
|
||||
let input = EditFileToolInput {
|
||||
let input = serde_json::to_value(EditFileToolInput {
|
||||
display_description: "Create main function".into(),
|
||||
path: "root/src/main.rs".into(),
|
||||
mode: EditFileMode::Overwrite,
|
||||
};
|
||||
})
|
||||
.unwrap();
|
||||
Arc::new(EditFileTool)
|
||||
.run(
|
||||
input,
|
||||
@@ -1512,11 +1521,12 @@ mod tests {
|
||||
// Stream unformatted edits again
|
||||
let edit_result = {
|
||||
let edit_task = cx.update(|cx| {
|
||||
let input = EditFileToolInput {
|
||||
let input = serde_json::to_value(EditFileToolInput {
|
||||
display_description: "Update main function".into(),
|
||||
path: "root/src/main.rs".into(),
|
||||
mode: EditFileMode::Overwrite,
|
||||
};
|
||||
})
|
||||
.unwrap();
|
||||
Arc::new(EditFileTool)
|
||||
.run(
|
||||
input,
|
||||
@@ -1590,11 +1600,12 @@ mod tests {
|
||||
// Have the model stream content that contains trailing whitespace
|
||||
let edit_result = {
|
||||
let edit_task = cx.update(|cx| {
|
||||
let input = EditFileToolInput {
|
||||
let input = serde_json::to_value(EditFileToolInput {
|
||||
display_description: "Create main function".into(),
|
||||
path: "root/src/main.rs".into(),
|
||||
mode: EditFileMode::Overwrite,
|
||||
};
|
||||
})
|
||||
.unwrap();
|
||||
Arc::new(EditFileTool)
|
||||
.run(
|
||||
input,
|
||||
@@ -1646,11 +1657,12 @@ mod tests {
|
||||
// Stream edits again with trailing whitespace
|
||||
let edit_result = {
|
||||
let edit_task = cx.update(|cx| {
|
||||
let input = EditFileToolInput {
|
||||
let input = serde_json::to_value(EditFileToolInput {
|
||||
display_description: "Update main function".into(),
|
||||
path: "root/src/main.rs".into(),
|
||||
mode: EditFileMode::Overwrite,
|
||||
};
|
||||
})
|
||||
.unwrap();
|
||||
Arc::new(EditFileTool)
|
||||
.run(
|
||||
input,
|
||||
|
||||
@@ -3,10 +3,10 @@ use std::sync::Arc;
|
||||
use std::{borrow::Cow, cell::RefCell};
|
||||
|
||||
use crate::schema::json_schema_for;
|
||||
use anyhow::{Context as _, Result, bail};
|
||||
use anyhow::{Context as _, Result, anyhow, bail};
|
||||
use assistant_tool::{ActionLog, Tool, ToolResult};
|
||||
use futures::AsyncReadExt as _;
|
||||
use gpui::{AnyWindowHandle, App, AppContext as _, Entity};
|
||||
use gpui::{AnyWindowHandle, App, AppContext as _, Entity, Task};
|
||||
use html_to_markdown::{TagHandler, convert_html_to_markdown, markdown};
|
||||
use http_client::{AsyncBody, HttpClientWithUrl};
|
||||
use language_model::{LanguageModel, LanguageModelRequest, LanguageModelToolSchemaFormat};
|
||||
@@ -113,13 +113,11 @@ impl FetchTool {
|
||||
}
|
||||
|
||||
impl Tool for FetchTool {
|
||||
type Input = FetchToolInput;
|
||||
|
||||
fn name(&self) -> String {
|
||||
"fetch".to_string()
|
||||
}
|
||||
|
||||
fn needs_confirmation(&self, _: &Self::Input, _: &App) -> bool {
|
||||
fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
@@ -139,13 +137,16 @@ impl Tool for FetchTool {
|
||||
json_schema_for::<FetchToolInput>(format)
|
||||
}
|
||||
|
||||
fn ui_text(&self, input: &Self::Input) -> String {
|
||||
format!("Fetch {}", MarkdownEscaped(&input.url))
|
||||
fn ui_text(&self, input: &serde_json::Value) -> String {
|
||||
match serde_json::from_value::<FetchToolInput>(input.clone()) {
|
||||
Ok(input) => format!("Fetch {}", MarkdownEscaped(&input.url)),
|
||||
Err(_) => "Fetch URL".to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
fn run(
|
||||
self: Arc<Self>,
|
||||
input: Self::Input,
|
||||
input: serde_json::Value,
|
||||
_request: Arc<LanguageModelRequest>,
|
||||
_project: Entity<Project>,
|
||||
_action_log: Entity<ActionLog>,
|
||||
@@ -153,6 +154,11 @@ impl Tool for FetchTool {
|
||||
_window: Option<AnyWindowHandle>,
|
||||
cx: &mut App,
|
||||
) -> ToolResult {
|
||||
let input = match serde_json::from_value::<FetchToolInput>(input) {
|
||||
Ok(input) => input,
|
||||
Err(err) => return Task::ready(Err(anyhow!(err))).into(),
|
||||
};
|
||||
|
||||
let text = cx.background_spawn({
|
||||
let http_client = self.http_client.clone();
|
||||
async move { Self::build_message(http_client, &input.url).await }
|
||||
|
||||
@@ -51,13 +51,11 @@ const RESULTS_PER_PAGE: usize = 50;
|
||||
pub struct FindPathTool;
|
||||
|
||||
impl Tool for FindPathTool {
|
||||
type Input = FindPathToolInput;
|
||||
|
||||
fn name(&self) -> String {
|
||||
"find_path".into()
|
||||
}
|
||||
|
||||
fn needs_confirmation(&self, _: &Self::Input, _: &App) -> bool {
|
||||
fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
@@ -77,13 +75,16 @@ impl Tool for FindPathTool {
|
||||
json_schema_for::<FindPathToolInput>(format)
|
||||
}
|
||||
|
||||
fn ui_text(&self, input: &Self::Input) -> String {
|
||||
format!("Find paths matching \"`{}`\"", input.glob)
|
||||
fn ui_text(&self, input: &serde_json::Value) -> String {
|
||||
match serde_json::from_value::<FindPathToolInput>(input.clone()) {
|
||||
Ok(input) => format!("Find paths matching “`{}`”", input.glob),
|
||||
Err(_) => "Search paths".to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
fn run(
|
||||
self: Arc<Self>,
|
||||
input: Self::Input,
|
||||
input: serde_json::Value,
|
||||
_request: Arc<LanguageModelRequest>,
|
||||
project: Entity<Project>,
|
||||
_action_log: Entity<ActionLog>,
|
||||
@@ -91,7 +92,10 @@ impl Tool for FindPathTool {
|
||||
_window: Option<AnyWindowHandle>,
|
||||
cx: &mut App,
|
||||
) -> ToolResult {
|
||||
let (offset, glob) = (input.offset, input.glob);
|
||||
let (offset, glob) = match serde_json::from_value::<FindPathToolInput>(input) {
|
||||
Ok(input) => (input.offset, input.glob),
|
||||
Err(err) => return Task::ready(Err(anyhow!(err))).into(),
|
||||
};
|
||||
|
||||
let (sender, receiver) = oneshot::channel();
|
||||
|
||||
|
||||
@@ -53,13 +53,11 @@ const RESULTS_PER_PAGE: u32 = 20;
|
||||
pub struct GrepTool;
|
||||
|
||||
impl Tool for GrepTool {
|
||||
type Input = GrepToolInput;
|
||||
|
||||
fn name(&self) -> String {
|
||||
"grep".into()
|
||||
}
|
||||
|
||||
fn needs_confirmation(&self, _: &Self::Input, _: &App) -> bool {
|
||||
fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
@@ -79,25 +77,30 @@ impl Tool for GrepTool {
|
||||
json_schema_for::<GrepToolInput>(format)
|
||||
}
|
||||
|
||||
fn ui_text(&self, input: &Self::Input) -> String {
|
||||
let page = input.page();
|
||||
let regex_str = MarkdownInlineCode(&input.regex);
|
||||
let case_info = if input.case_sensitive {
|
||||
" (case-sensitive)"
|
||||
} else {
|
||||
""
|
||||
};
|
||||
fn ui_text(&self, input: &serde_json::Value) -> String {
|
||||
match serde_json::from_value::<GrepToolInput>(input.clone()) {
|
||||
Ok(input) => {
|
||||
let page = input.page();
|
||||
let regex_str = MarkdownInlineCode(&input.regex);
|
||||
let case_info = if input.case_sensitive {
|
||||
" (case-sensitive)"
|
||||
} else {
|
||||
""
|
||||
};
|
||||
|
||||
if page > 1 {
|
||||
format!("Get page {page} of search results for regex {regex_str}{case_info}")
|
||||
} else {
|
||||
format!("Search files for regex {regex_str}{case_info}")
|
||||
if page > 1 {
|
||||
format!("Get page {page} of search results for regex {regex_str}{case_info}")
|
||||
} else {
|
||||
format!("Search files for regex {regex_str}{case_info}")
|
||||
}
|
||||
}
|
||||
Err(_) => "Search with regex".to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
fn run(
|
||||
self: Arc<Self>,
|
||||
input: Self::Input,
|
||||
input: serde_json::Value,
|
||||
_request: Arc<LanguageModelRequest>,
|
||||
project: Entity<Project>,
|
||||
_action_log: Entity<ActionLog>,
|
||||
@@ -108,6 +111,13 @@ impl Tool for GrepTool {
|
||||
const CONTEXT_LINES: u32 = 2;
|
||||
const MAX_ANCESTOR_LINES: u32 = 10;
|
||||
|
||||
let input = match serde_json::from_value::<GrepToolInput>(input) {
|
||||
Ok(input) => input,
|
||||
Err(error) => {
|
||||
return Task::ready(Err(anyhow!("Failed to parse input: {error}"))).into();
|
||||
}
|
||||
};
|
||||
|
||||
let include_matcher = match PathMatcher::new(
|
||||
input
|
||||
.include_pattern
|
||||
@@ -338,12 +348,13 @@ mod tests {
|
||||
let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
|
||||
|
||||
// Test with include pattern for Rust files inside the root of the project
|
||||
let input = GrepToolInput {
|
||||
let input = serde_json::to_value(GrepToolInput {
|
||||
regex: "println".to_string(),
|
||||
include_pattern: Some("root/**/*.rs".to_string()),
|
||||
offset: 0,
|
||||
case_sensitive: false,
|
||||
};
|
||||
})
|
||||
.unwrap();
|
||||
|
||||
let result = run_grep_tool(input, project.clone(), cx).await;
|
||||
assert!(result.contains("main.rs"), "Should find matches in main.rs");
|
||||
@@ -357,12 +368,13 @@ mod tests {
|
||||
);
|
||||
|
||||
// Test with include pattern for src directory only
|
||||
let input = GrepToolInput {
|
||||
let input = serde_json::to_value(GrepToolInput {
|
||||
regex: "fn".to_string(),
|
||||
include_pattern: Some("root/**/src/**".to_string()),
|
||||
offset: 0,
|
||||
case_sensitive: false,
|
||||
};
|
||||
})
|
||||
.unwrap();
|
||||
|
||||
let result = run_grep_tool(input, project.clone(), cx).await;
|
||||
assert!(
|
||||
@@ -379,12 +391,13 @@ mod tests {
|
||||
);
|
||||
|
||||
// Test with empty include pattern (should default to all files)
|
||||
let input = GrepToolInput {
|
||||
let input = serde_json::to_value(GrepToolInput {
|
||||
regex: "fn".to_string(),
|
||||
include_pattern: None,
|
||||
offset: 0,
|
||||
case_sensitive: false,
|
||||
};
|
||||
})
|
||||
.unwrap();
|
||||
|
||||
let result = run_grep_tool(input, project.clone(), cx).await;
|
||||
assert!(result.contains("main.rs"), "Should find matches in main.rs");
|
||||
@@ -415,12 +428,13 @@ mod tests {
|
||||
let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
|
||||
|
||||
// Test case-insensitive search (default)
|
||||
let input = GrepToolInput {
|
||||
let input = serde_json::to_value(GrepToolInput {
|
||||
regex: "uppercase".to_string(),
|
||||
include_pattern: Some("**/*.txt".to_string()),
|
||||
offset: 0,
|
||||
case_sensitive: false,
|
||||
};
|
||||
})
|
||||
.unwrap();
|
||||
|
||||
let result = run_grep_tool(input, project.clone(), cx).await;
|
||||
assert!(
|
||||
@@ -429,12 +443,13 @@ mod tests {
|
||||
);
|
||||
|
||||
// Test case-sensitive search
|
||||
let input = GrepToolInput {
|
||||
let input = serde_json::to_value(GrepToolInput {
|
||||
regex: "uppercase".to_string(),
|
||||
include_pattern: Some("**/*.txt".to_string()),
|
||||
offset: 0,
|
||||
case_sensitive: true,
|
||||
};
|
||||
})
|
||||
.unwrap();
|
||||
|
||||
let result = run_grep_tool(input, project.clone(), cx).await;
|
||||
assert!(
|
||||
@@ -443,12 +458,13 @@ mod tests {
|
||||
);
|
||||
|
||||
// Test case-sensitive search
|
||||
let input = GrepToolInput {
|
||||
let input = serde_json::to_value(GrepToolInput {
|
||||
regex: "LOWERCASE".to_string(),
|
||||
include_pattern: Some("**/*.txt".to_string()),
|
||||
offset: 0,
|
||||
case_sensitive: true,
|
||||
};
|
||||
})
|
||||
.unwrap();
|
||||
|
||||
let result = run_grep_tool(input, project.clone(), cx).await;
|
||||
|
||||
@@ -458,12 +474,13 @@ mod tests {
|
||||
);
|
||||
|
||||
// Test case-sensitive search for lowercase pattern
|
||||
let input = GrepToolInput {
|
||||
let input = serde_json::to_value(GrepToolInput {
|
||||
regex: "lowercase".to_string(),
|
||||
include_pattern: Some("**/*.txt".to_string()),
|
||||
offset: 0,
|
||||
case_sensitive: true,
|
||||
};
|
||||
})
|
||||
.unwrap();
|
||||
|
||||
let result = run_grep_tool(input, project.clone(), cx).await;
|
||||
assert!(
|
||||
@@ -559,12 +576,13 @@ mod tests {
|
||||
let project = setup_syntax_test(cx).await;
|
||||
|
||||
// Test: Line at the top level of the file
|
||||
let input = GrepToolInput {
|
||||
let input = serde_json::to_value(GrepToolInput {
|
||||
regex: "This is at the top level".to_string(),
|
||||
include_pattern: Some("**/*.rs".to_string()),
|
||||
offset: 0,
|
||||
case_sensitive: false,
|
||||
};
|
||||
})
|
||||
.unwrap();
|
||||
|
||||
let result = run_grep_tool(input, project.clone(), cx).await;
|
||||
let expected = r#"
|
||||
@@ -588,12 +606,13 @@ mod tests {
|
||||
let project = setup_syntax_test(cx).await;
|
||||
|
||||
// Test: Line inside a function body
|
||||
let input = GrepToolInput {
|
||||
let input = serde_json::to_value(GrepToolInput {
|
||||
regex: "Function in nested module".to_string(),
|
||||
include_pattern: Some("**/*.rs".to_string()),
|
||||
offset: 0,
|
||||
case_sensitive: false,
|
||||
};
|
||||
})
|
||||
.unwrap();
|
||||
|
||||
let result = run_grep_tool(input, project.clone(), cx).await;
|
||||
let expected = r#"
|
||||
@@ -619,12 +638,13 @@ mod tests {
|
||||
let project = setup_syntax_test(cx).await;
|
||||
|
||||
// Test: Line with a function argument
|
||||
let input = GrepToolInput {
|
||||
let input = serde_json::to_value(GrepToolInput {
|
||||
regex: "second_arg".to_string(),
|
||||
include_pattern: Some("**/*.rs".to_string()),
|
||||
offset: 0,
|
||||
case_sensitive: false,
|
||||
};
|
||||
})
|
||||
.unwrap();
|
||||
|
||||
let result = run_grep_tool(input, project.clone(), cx).await;
|
||||
let expected = r#"
|
||||
@@ -654,12 +674,13 @@ mod tests {
|
||||
let project = setup_syntax_test(cx).await;
|
||||
|
||||
// Test: Line inside an if block
|
||||
let input = GrepToolInput {
|
||||
let input = serde_json::to_value(GrepToolInput {
|
||||
regex: "Inside if block".to_string(),
|
||||
include_pattern: Some("**/*.rs".to_string()),
|
||||
offset: 0,
|
||||
case_sensitive: false,
|
||||
};
|
||||
})
|
||||
.unwrap();
|
||||
|
||||
let result = run_grep_tool(input, project.clone(), cx).await;
|
||||
let expected = r#"
|
||||
@@ -684,12 +705,13 @@ mod tests {
|
||||
let project = setup_syntax_test(cx).await;
|
||||
|
||||
// Test: Line in the middle of a long function - should show message about remaining lines
|
||||
let input = GrepToolInput {
|
||||
let input = serde_json::to_value(GrepToolInput {
|
||||
regex: "Line 5".to_string(),
|
||||
include_pattern: Some("**/*.rs".to_string()),
|
||||
offset: 0,
|
||||
case_sensitive: false,
|
||||
};
|
||||
})
|
||||
.unwrap();
|
||||
|
||||
let result = run_grep_tool(input, project.clone(), cx).await;
|
||||
let expected = r#"
|
||||
@@ -724,12 +746,13 @@ mod tests {
|
||||
let project = setup_syntax_test(cx).await;
|
||||
|
||||
// Test: Line in the long function
|
||||
let input = GrepToolInput {
|
||||
let input = serde_json::to_value(GrepToolInput {
|
||||
regex: "Line 12".to_string(),
|
||||
include_pattern: Some("**/*.rs".to_string()),
|
||||
offset: 0,
|
||||
case_sensitive: false,
|
||||
};
|
||||
})
|
||||
.unwrap();
|
||||
|
||||
let result = run_grep_tool(input, project.clone(), cx).await;
|
||||
let expected = r#"
|
||||
@@ -751,7 +774,7 @@ mod tests {
|
||||
}
|
||||
|
||||
async fn run_grep_tool(
|
||||
input: GrepToolInput,
|
||||
input: serde_json::Value,
|
||||
project: Entity<Project>,
|
||||
cx: &mut TestAppContext,
|
||||
) -> String {
|
||||
@@ -853,12 +876,9 @@ mod tests {
|
||||
// Searching for files outside the project worktree should return no results
|
||||
let result = cx
|
||||
.update(|cx| {
|
||||
let input = GrepToolInput {
|
||||
regex: "outside_function".to_string(),
|
||||
include_pattern: None,
|
||||
offset: 0,
|
||||
case_sensitive: false,
|
||||
};
|
||||
let input = json!({
|
||||
"regex": "outside_function"
|
||||
});
|
||||
Arc::new(GrepTool)
|
||||
.run(
|
||||
input,
|
||||
@@ -882,12 +902,9 @@ mod tests {
|
||||
// Searching within the project should succeed
|
||||
let result = cx
|
||||
.update(|cx| {
|
||||
let input = GrepToolInput {
|
||||
regex: "main".to_string(),
|
||||
include_pattern: None,
|
||||
offset: 0,
|
||||
case_sensitive: false,
|
||||
};
|
||||
let input = json!({
|
||||
"regex": "main"
|
||||
});
|
||||
Arc::new(GrepTool)
|
||||
.run(
|
||||
input,
|
||||
@@ -911,12 +928,9 @@ mod tests {
|
||||
// Searching files that match file_scan_exclusions should return no results
|
||||
let result = cx
|
||||
.update(|cx| {
|
||||
let input = GrepToolInput {
|
||||
regex: "special_configuration".to_string(),
|
||||
include_pattern: None,
|
||||
offset: 0,
|
||||
case_sensitive: false,
|
||||
};
|
||||
let input = json!({
|
||||
"regex": "special_configuration"
|
||||
});
|
||||
Arc::new(GrepTool)
|
||||
.run(
|
||||
input,
|
||||
@@ -939,12 +953,9 @@ mod tests {
|
||||
|
||||
let result = cx
|
||||
.update(|cx| {
|
||||
let input = GrepToolInput {
|
||||
regex: "custom_metadata".to_string(),
|
||||
include_pattern: None,
|
||||
offset: 0,
|
||||
case_sensitive: false,
|
||||
};
|
||||
let input = json!({
|
||||
"regex": "custom_metadata"
|
||||
});
|
||||
Arc::new(GrepTool)
|
||||
.run(
|
||||
input,
|
||||
@@ -968,12 +979,9 @@ mod tests {
|
||||
// Searching private files should return no results
|
||||
let result = cx
|
||||
.update(|cx| {
|
||||
let input = GrepToolInput {
|
||||
regex: "SECRET_KEY".to_string(),
|
||||
include_pattern: None,
|
||||
offset: 0,
|
||||
case_sensitive: false,
|
||||
};
|
||||
let input = json!({
|
||||
"regex": "SECRET_KEY"
|
||||
});
|
||||
Arc::new(GrepTool)
|
||||
.run(
|
||||
input,
|
||||
@@ -996,12 +1004,9 @@ mod tests {
|
||||
|
||||
let result = cx
|
||||
.update(|cx| {
|
||||
let input = GrepToolInput {
|
||||
regex: "private_key_content".to_string(),
|
||||
include_pattern: None,
|
||||
offset: 0,
|
||||
case_sensitive: false,
|
||||
};
|
||||
let input = json!({
|
||||
"regex": "private_key_content"
|
||||
});
|
||||
Arc::new(GrepTool)
|
||||
.run(
|
||||
input,
|
||||
@@ -1024,12 +1029,9 @@ mod tests {
|
||||
|
||||
let result = cx
|
||||
.update(|cx| {
|
||||
let input = GrepToolInput {
|
||||
regex: "sensitive_data".to_string(),
|
||||
include_pattern: None,
|
||||
offset: 0,
|
||||
case_sensitive: false,
|
||||
};
|
||||
let input = json!({
|
||||
"regex": "sensitive_data"
|
||||
});
|
||||
Arc::new(GrepTool)
|
||||
.run(
|
||||
input,
|
||||
@@ -1053,12 +1055,9 @@ mod tests {
|
||||
// Searching a normal file should still work, even with private_files configured
|
||||
let result = cx
|
||||
.update(|cx| {
|
||||
let input = GrepToolInput {
|
||||
regex: "normal_file_content".to_string(),
|
||||
include_pattern: None,
|
||||
offset: 0,
|
||||
case_sensitive: false,
|
||||
};
|
||||
let input = json!({
|
||||
"regex": "normal_file_content"
|
||||
});
|
||||
Arc::new(GrepTool)
|
||||
.run(
|
||||
input,
|
||||
@@ -1082,12 +1081,10 @@ mod tests {
|
||||
// Path traversal attempts with .. in include_pattern should not escape project
|
||||
let result = cx
|
||||
.update(|cx| {
|
||||
let input = GrepToolInput {
|
||||
regex: "outside_function".to_string(),
|
||||
include_pattern: Some("../outside_project/**/*.rs".to_string()),
|
||||
offset: 0,
|
||||
case_sensitive: false,
|
||||
};
|
||||
let input = json!({
|
||||
"regex": "outside_function",
|
||||
"include_pattern": "../outside_project/**/*.rs"
|
||||
});
|
||||
Arc::new(GrepTool)
|
||||
.run(
|
||||
input,
|
||||
@@ -1188,12 +1185,10 @@ mod tests {
|
||||
// Search for "secret" - should exclude files based on worktree-specific settings
|
||||
let result = cx
|
||||
.update(|cx| {
|
||||
let input = GrepToolInput {
|
||||
regex: "secret".to_string(),
|
||||
include_pattern: None,
|
||||
offset: 0,
|
||||
case_sensitive: false,
|
||||
};
|
||||
let input = json!({
|
||||
"regex": "secret",
|
||||
"case_sensitive": false
|
||||
});
|
||||
Arc::new(GrepTool)
|
||||
.run(
|
||||
input,
|
||||
@@ -1255,12 +1250,10 @@ mod tests {
|
||||
// Test with `include_pattern` specific to one worktree
|
||||
let result = cx
|
||||
.update(|cx| {
|
||||
let input = GrepToolInput {
|
||||
regex: "secret".to_string(),
|
||||
include_pattern: Some("worktree1/**/*.rs".to_string()),
|
||||
offset: 0,
|
||||
case_sensitive: false,
|
||||
};
|
||||
let input = json!({
|
||||
"regex": "secret",
|
||||
"include_pattern": "worktree1/**/*.rs"
|
||||
});
|
||||
Arc::new(GrepTool)
|
||||
.run(
|
||||
input,
|
||||
|
||||
@@ -41,13 +41,11 @@ pub struct ListDirectoryToolInput {
|
||||
pub struct ListDirectoryTool;
|
||||
|
||||
impl Tool for ListDirectoryTool {
|
||||
type Input = ListDirectoryToolInput;
|
||||
|
||||
fn name(&self) -> String {
|
||||
"list_directory".into()
|
||||
}
|
||||
|
||||
fn needs_confirmation(&self, _: &Self::Input, _: &App) -> bool {
|
||||
fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
@@ -67,14 +65,19 @@ impl Tool for ListDirectoryTool {
|
||||
json_schema_for::<ListDirectoryToolInput>(format)
|
||||
}
|
||||
|
||||
fn ui_text(&self, input: &Self::Input) -> String {
|
||||
let path = MarkdownInlineCode(&input.path);
|
||||
format!("List the {path} directory's contents")
|
||||
fn ui_text(&self, input: &serde_json::Value) -> String {
|
||||
match serde_json::from_value::<ListDirectoryToolInput>(input.clone()) {
|
||||
Ok(input) => {
|
||||
let path = MarkdownInlineCode(&input.path);
|
||||
format!("List the {path} directory's contents")
|
||||
}
|
||||
Err(_) => "List directory".to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
fn run(
|
||||
self: Arc<Self>,
|
||||
input: Self::Input,
|
||||
input: serde_json::Value,
|
||||
_request: Arc<LanguageModelRequest>,
|
||||
project: Entity<Project>,
|
||||
_action_log: Entity<ActionLog>,
|
||||
@@ -82,6 +85,11 @@ impl Tool for ListDirectoryTool {
|
||||
_window: Option<AnyWindowHandle>,
|
||||
cx: &mut App,
|
||||
) -> ToolResult {
|
||||
let input = match serde_json::from_value::<ListDirectoryToolInput>(input) {
|
||||
Ok(input) => input,
|
||||
Err(err) => return Task::ready(Err(anyhow!(err))).into(),
|
||||
};
|
||||
|
||||
// Sometimes models will return these even though we tell it to give a path and not a glob.
|
||||
// When this happens, just list the root worktree directories.
|
||||
if matches!(input.path.as_str(), "." | "" | "./" | "*") {
|
||||
@@ -277,9 +285,9 @@ mod tests {
|
||||
let tool = Arc::new(ListDirectoryTool);
|
||||
|
||||
// Test listing root directory
|
||||
let input = ListDirectoryToolInput {
|
||||
path: "project".to_string(),
|
||||
};
|
||||
let input = json!({
|
||||
"path": "project"
|
||||
});
|
||||
|
||||
let result = cx
|
||||
.update(|cx| {
|
||||
@@ -312,9 +320,9 @@ mod tests {
|
||||
);
|
||||
|
||||
// Test listing src directory
|
||||
let input = ListDirectoryToolInput {
|
||||
path: "project/src".to_string(),
|
||||
};
|
||||
let input = json!({
|
||||
"path": "project/src"
|
||||
});
|
||||
|
||||
let result = cx
|
||||
.update(|cx| {
|
||||
@@ -347,9 +355,9 @@ mod tests {
|
||||
);
|
||||
|
||||
// Test listing directory with only files
|
||||
let input = ListDirectoryToolInput {
|
||||
path: "project/tests".to_string(),
|
||||
};
|
||||
let input = json!({
|
||||
"path": "project/tests"
|
||||
});
|
||||
|
||||
let result = cx
|
||||
.update(|cx| {
|
||||
@@ -391,9 +399,9 @@ mod tests {
|
||||
let model = Arc::new(FakeLanguageModel::default());
|
||||
let tool = Arc::new(ListDirectoryTool);
|
||||
|
||||
let input = ListDirectoryToolInput {
|
||||
path: "project/empty_dir".to_string(),
|
||||
};
|
||||
let input = json!({
|
||||
"path": "project/empty_dir"
|
||||
});
|
||||
|
||||
let result = cx
|
||||
.update(|cx| tool.run(input, Arc::default(), project, action_log, model, None, cx))
|
||||
@@ -424,9 +432,9 @@ mod tests {
|
||||
let tool = Arc::new(ListDirectoryTool);
|
||||
|
||||
// Test non-existent path
|
||||
let input = ListDirectoryToolInput {
|
||||
path: "project/nonexistent".to_string(),
|
||||
};
|
||||
let input = json!({
|
||||
"path": "project/nonexistent"
|
||||
});
|
||||
|
||||
let result = cx
|
||||
.update(|cx| {
|
||||
@@ -447,9 +455,9 @@ mod tests {
|
||||
assert!(result.unwrap_err().to_string().contains("Path not found"));
|
||||
|
||||
// Test trying to list a file instead of directory
|
||||
let input = ListDirectoryToolInput {
|
||||
path: "project/file.txt".to_string(),
|
||||
};
|
||||
let input = json!({
|
||||
"path": "project/file.txt"
|
||||
});
|
||||
|
||||
let result = cx
|
||||
.update(|cx| tool.run(input, Arc::default(), project, action_log, model, None, cx))
|
||||
@@ -519,9 +527,9 @@ mod tests {
|
||||
let tool = Arc::new(ListDirectoryTool);
|
||||
|
||||
// Listing root directory should exclude private and excluded files
|
||||
let input = ListDirectoryToolInput {
|
||||
path: "project".to_string(),
|
||||
};
|
||||
let input = json!({
|
||||
"path": "project"
|
||||
});
|
||||
|
||||
let result = cx
|
||||
.update(|cx| {
|
||||
@@ -560,9 +568,9 @@ mod tests {
|
||||
);
|
||||
|
||||
// Trying to list an excluded directory should fail
|
||||
let input = ListDirectoryToolInput {
|
||||
path: "project/.secretdir".to_string(),
|
||||
};
|
||||
let input = json!({
|
||||
"path": "project/.secretdir"
|
||||
});
|
||||
|
||||
let result = cx
|
||||
.update(|cx| {
|
||||
@@ -592,9 +600,9 @@ mod tests {
|
||||
);
|
||||
|
||||
// Listing a directory should exclude private files within it
|
||||
let input = ListDirectoryToolInput {
|
||||
path: "project/visible_dir".to_string(),
|
||||
};
|
||||
let input = json!({
|
||||
"path": "project/visible_dir"
|
||||
});
|
||||
|
||||
let result = cx
|
||||
.update(|cx| {
|
||||
@@ -712,9 +720,9 @@ mod tests {
|
||||
let tool = Arc::new(ListDirectoryTool);
|
||||
|
||||
// Test listing worktree1/src - should exclude secret.rs and config.toml based on local settings
|
||||
let input = ListDirectoryToolInput {
|
||||
path: "worktree1/src".to_string(),
|
||||
};
|
||||
let input = json!({
|
||||
"path": "worktree1/src"
|
||||
});
|
||||
|
||||
let result = cx
|
||||
.update(|cx| {
|
||||
@@ -744,9 +752,9 @@ mod tests {
|
||||
);
|
||||
|
||||
// Test listing worktree1/tests - should exclude fixture.sql based on local settings
|
||||
let input = ListDirectoryToolInput {
|
||||
path: "worktree1/tests".to_string(),
|
||||
};
|
||||
let input = json!({
|
||||
"path": "worktree1/tests"
|
||||
});
|
||||
|
||||
let result = cx
|
||||
.update(|cx| {
|
||||
@@ -772,9 +780,9 @@ mod tests {
|
||||
);
|
||||
|
||||
// Test listing worktree2/lib - should exclude private.js and data.json based on local settings
|
||||
let input = ListDirectoryToolInput {
|
||||
path: "worktree2/lib".to_string(),
|
||||
};
|
||||
let input = json!({
|
||||
"path": "worktree2/lib"
|
||||
});
|
||||
|
||||
let result = cx
|
||||
.update(|cx| {
|
||||
@@ -804,9 +812,9 @@ mod tests {
|
||||
);
|
||||
|
||||
// Test listing worktree2/docs - should exclude internal.md based on local settings
|
||||
let input = ListDirectoryToolInput {
|
||||
path: "worktree2/docs".to_string(),
|
||||
};
|
||||
let input = json!({
|
||||
"path": "worktree2/docs"
|
||||
});
|
||||
|
||||
let result = cx
|
||||
.update(|cx| {
|
||||
@@ -832,9 +840,9 @@ mod tests {
|
||||
);
|
||||
|
||||
// Test trying to list an excluded directory directly
|
||||
let input = ListDirectoryToolInput {
|
||||
path: "worktree1/src/secret.rs".to_string(),
|
||||
};
|
||||
let input = json!({
|
||||
"path": "worktree1/src/secret.rs"
|
||||
});
|
||||
|
||||
let result = cx
|
||||
.update(|cx| {
|
||||
|
||||
@@ -38,13 +38,11 @@ pub struct MovePathToolInput {
|
||||
pub struct MovePathTool;
|
||||
|
||||
impl Tool for MovePathTool {
|
||||
type Input = MovePathToolInput;
|
||||
|
||||
fn name(&self) -> String {
|
||||
"move_path".into()
|
||||
}
|
||||
|
||||
fn needs_confirmation(&self, _: &Self::Input, _: &App) -> bool {
|
||||
fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
@@ -64,29 +62,34 @@ impl Tool for MovePathTool {
|
||||
json_schema_for::<MovePathToolInput>(format)
|
||||
}
|
||||
|
||||
fn ui_text(&self, input: &Self::Input) -> String {
|
||||
let src = MarkdownInlineCode(&input.source_path);
|
||||
let dest = MarkdownInlineCode(&input.destination_path);
|
||||
let src_path = Path::new(&input.source_path);
|
||||
let dest_path = Path::new(&input.destination_path);
|
||||
fn ui_text(&self, input: &serde_json::Value) -> String {
|
||||
match serde_json::from_value::<MovePathToolInput>(input.clone()) {
|
||||
Ok(input) => {
|
||||
let src = MarkdownInlineCode(&input.source_path);
|
||||
let dest = MarkdownInlineCode(&input.destination_path);
|
||||
let src_path = Path::new(&input.source_path);
|
||||
let dest_path = Path::new(&input.destination_path);
|
||||
|
||||
match dest_path
|
||||
.file_name()
|
||||
.and_then(|os_str| os_str.to_os_string().into_string().ok())
|
||||
{
|
||||
Some(filename) if src_path.parent() == dest_path.parent() => {
|
||||
let filename = MarkdownInlineCode(&filename);
|
||||
format!("Rename {src} to {filename}")
|
||||
}
|
||||
_ => {
|
||||
format!("Move {src} to {dest}")
|
||||
match dest_path
|
||||
.file_name()
|
||||
.and_then(|os_str| os_str.to_os_string().into_string().ok())
|
||||
{
|
||||
Some(filename) if src_path.parent() == dest_path.parent() => {
|
||||
let filename = MarkdownInlineCode(&filename);
|
||||
format!("Rename {src} to {filename}")
|
||||
}
|
||||
_ => {
|
||||
format!("Move {src} to {dest}")
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(_) => "Move path".to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
fn run(
|
||||
self: Arc<Self>,
|
||||
input: Self::Input,
|
||||
input: serde_json::Value,
|
||||
_request: Arc<LanguageModelRequest>,
|
||||
project: Entity<Project>,
|
||||
_action_log: Entity<ActionLog>,
|
||||
@@ -94,6 +97,10 @@ impl Tool for MovePathTool {
|
||||
_window: Option<AnyWindowHandle>,
|
||||
cx: &mut App,
|
||||
) -> ToolResult {
|
||||
let input = match serde_json::from_value::<MovePathToolInput>(input) {
|
||||
Ok(input) => input,
|
||||
Err(err) => return Task::ready(Err(anyhow!(err))).into(),
|
||||
};
|
||||
let rename_task = project.update(cx, |project, cx| {
|
||||
match project
|
||||
.find_project_path(&input.source_path, cx)
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::schema::json_schema_for;
|
||||
use anyhow::Result;
|
||||
use anyhow::{Result, anyhow};
|
||||
use assistant_tool::{ActionLog, Tool, ToolResult};
|
||||
use chrono::{Local, Utc};
|
||||
use gpui::{AnyWindowHandle, App, Entity, Task};
|
||||
@@ -29,13 +29,11 @@ pub struct NowToolInput {
|
||||
pub struct NowTool;
|
||||
|
||||
impl Tool for NowTool {
|
||||
type Input = NowToolInput;
|
||||
|
||||
fn name(&self) -> String {
|
||||
"now".into()
|
||||
}
|
||||
|
||||
fn needs_confirmation(&self, _: &Self::Input, _: &App) -> bool {
|
||||
fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
@@ -55,13 +53,13 @@ impl Tool for NowTool {
|
||||
json_schema_for::<NowToolInput>(format)
|
||||
}
|
||||
|
||||
fn ui_text(&self, _input: &Self::Input) -> String {
|
||||
fn ui_text(&self, _input: &serde_json::Value) -> String {
|
||||
"Get current time".to_string()
|
||||
}
|
||||
|
||||
fn run(
|
||||
self: Arc<Self>,
|
||||
input: Self::Input,
|
||||
input: serde_json::Value,
|
||||
_request: Arc<LanguageModelRequest>,
|
||||
_project: Entity<Project>,
|
||||
_action_log: Entity<ActionLog>,
|
||||
@@ -69,6 +67,11 @@ impl Tool for NowTool {
|
||||
_window: Option<AnyWindowHandle>,
|
||||
_cx: &mut App,
|
||||
) -> ToolResult {
|
||||
let input: NowToolInput = match serde_json::from_value(input) {
|
||||
Ok(input) => input,
|
||||
Err(err) => return Task::ready(Err(anyhow!(err))).into(),
|
||||
};
|
||||
|
||||
let now = match input.timezone {
|
||||
Timezone::Utc => Utc::now().to_rfc3339(),
|
||||
Timezone::Local => Local::now().to_rfc3339(),
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
use crate::schema::json_schema_for;
|
||||
use anyhow::{Context as _, Result};
|
||||
use anyhow::{Context as _, Result, anyhow};
|
||||
use assistant_tool::{ActionLog, Tool, ToolResult};
|
||||
use gpui::{AnyWindowHandle, App, AppContext, Entity};
|
||||
use gpui::{AnyWindowHandle, App, AppContext, Entity, Task};
|
||||
use language_model::{LanguageModel, LanguageModelRequest, LanguageModelToolSchemaFormat};
|
||||
use project::Project;
|
||||
use schemars::JsonSchema;
|
||||
@@ -19,13 +19,11 @@ pub struct OpenToolInput {
|
||||
pub struct OpenTool;
|
||||
|
||||
impl Tool for OpenTool {
|
||||
type Input = OpenToolInput;
|
||||
|
||||
fn name(&self) -> String {
|
||||
"open".to_string()
|
||||
}
|
||||
|
||||
fn needs_confirmation(&self, _: &Self::Input, _: &App) -> bool {
|
||||
fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool {
|
||||
true
|
||||
}
|
||||
fn may_perform_edits(&self) -> bool {
|
||||
@@ -43,13 +41,16 @@ impl Tool for OpenTool {
|
||||
json_schema_for::<OpenToolInput>(format)
|
||||
}
|
||||
|
||||
fn ui_text(&self, input: &Self::Input) -> String {
|
||||
format!("Open `{}`", MarkdownEscaped(&input.path_or_url))
|
||||
fn ui_text(&self, input: &serde_json::Value) -> String {
|
||||
match serde_json::from_value::<OpenToolInput>(input.clone()) {
|
||||
Ok(input) => format!("Open `{}`", MarkdownEscaped(&input.path_or_url)),
|
||||
Err(_) => "Open file or URL".to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
fn run(
|
||||
self: Arc<Self>,
|
||||
input: Self::Input,
|
||||
input: serde_json::Value,
|
||||
_request: Arc<LanguageModelRequest>,
|
||||
project: Entity<Project>,
|
||||
_action_log: Entity<ActionLog>,
|
||||
@@ -57,6 +58,11 @@ impl Tool for OpenTool {
|
||||
_window: Option<AnyWindowHandle>,
|
||||
cx: &mut App,
|
||||
) -> ToolResult {
|
||||
let input: OpenToolInput = match serde_json::from_value(input) {
|
||||
Ok(input) => input,
|
||||
Err(err) => return Task::ready(Err(anyhow!(err))).into(),
|
||||
};
|
||||
|
||||
// If path_or_url turns out to be a path in the project, make it absolute.
|
||||
let abs_path = to_absolute_path(&input.path_or_url, project, cx);
|
||||
|
||||
|
||||
@@ -51,13 +51,11 @@ pub struct ReadFileToolInput {
|
||||
pub struct ReadFileTool;
|
||||
|
||||
impl Tool for ReadFileTool {
|
||||
type Input = ReadFileToolInput;
|
||||
|
||||
fn name(&self) -> String {
|
||||
"read_file".into()
|
||||
}
|
||||
|
||||
fn needs_confirmation(&self, _: &Self::Input, _: &App) -> bool {
|
||||
fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
@@ -77,18 +75,23 @@ impl Tool for ReadFileTool {
|
||||
json_schema_for::<ReadFileToolInput>(format)
|
||||
}
|
||||
|
||||
fn ui_text(&self, input: &Self::Input) -> String {
|
||||
let path = MarkdownInlineCode(&input.path);
|
||||
match (input.start_line, input.end_line) {
|
||||
(Some(start), None) => format!("Read file {path} (from line {start})"),
|
||||
(Some(start), Some(end)) => format!("Read file {path} (lines {start}-{end})"),
|
||||
_ => format!("Read file {path}"),
|
||||
fn ui_text(&self, input: &serde_json::Value) -> String {
|
||||
match serde_json::from_value::<ReadFileToolInput>(input.clone()) {
|
||||
Ok(input) => {
|
||||
let path = MarkdownInlineCode(&input.path);
|
||||
match (input.start_line, input.end_line) {
|
||||
(Some(start), None) => format!("Read file {path} (from line {start})"),
|
||||
(Some(start), Some(end)) => format!("Read file {path} (lines {start}-{end})"),
|
||||
_ => format!("Read file {path}"),
|
||||
}
|
||||
}
|
||||
Err(_) => "Read file".to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
fn run(
|
||||
self: Arc<Self>,
|
||||
input: Self::Input,
|
||||
input: serde_json::Value,
|
||||
_request: Arc<LanguageModelRequest>,
|
||||
project: Entity<Project>,
|
||||
action_log: Entity<ActionLog>,
|
||||
@@ -96,6 +99,11 @@ impl Tool for ReadFileTool {
|
||||
_window: Option<AnyWindowHandle>,
|
||||
cx: &mut App,
|
||||
) -> ToolResult {
|
||||
let input = match serde_json::from_value::<ReadFileToolInput>(input) {
|
||||
Ok(input) => input,
|
||||
Err(err) => return Task::ready(Err(anyhow!(err))).into(),
|
||||
};
|
||||
|
||||
let Some(project_path) = project.read(cx).find_project_path(&input.path, cx) else {
|
||||
return Task::ready(Err(anyhow!("Path {} not found in project", &input.path))).into();
|
||||
};
|
||||
@@ -300,12 +308,9 @@ mod test {
|
||||
let model = Arc::new(FakeLanguageModel::default());
|
||||
let result = cx
|
||||
.update(|cx| {
|
||||
let input = ReadFileToolInput {
|
||||
path: "root/nonexistent_file.txt".to_string(),
|
||||
start_line: None,
|
||||
end_line: None,
|
||||
};
|
||||
|
||||
let input = json!({
|
||||
"path": "root/nonexistent_file.txt"
|
||||
});
|
||||
Arc::new(ReadFileTool)
|
||||
.run(
|
||||
input,
|
||||
@@ -342,11 +347,9 @@ mod test {
|
||||
let model = Arc::new(FakeLanguageModel::default());
|
||||
let result = cx
|
||||
.update(|cx| {
|
||||
let input = ReadFileToolInput {
|
||||
path: "root/small_file.txt".to_string(),
|
||||
start_line: None,
|
||||
end_line: None,
|
||||
};
|
||||
let input = json!({
|
||||
"path": "root/small_file.txt"
|
||||
});
|
||||
Arc::new(ReadFileTool)
|
||||
.run(
|
||||
input,
|
||||
@@ -386,11 +389,9 @@ mod test {
|
||||
|
||||
let result = cx
|
||||
.update(|cx| {
|
||||
let input = ReadFileToolInput {
|
||||
path: "root/large_file.rs".to_string(),
|
||||
start_line: None,
|
||||
end_line: None,
|
||||
};
|
||||
let input = json!({
|
||||
"path": "root/large_file.rs"
|
||||
});
|
||||
Arc::new(ReadFileTool)
|
||||
.run(
|
||||
input,
|
||||
@@ -420,11 +421,10 @@ mod test {
|
||||
|
||||
let result = cx
|
||||
.update(|cx| {
|
||||
let input = ReadFileToolInput {
|
||||
path: "root/large_file.rs".to_string(),
|
||||
start_line: None,
|
||||
end_line: None,
|
||||
};
|
||||
let input = json!({
|
||||
"path": "root/large_file.rs",
|
||||
"offset": 1
|
||||
});
|
||||
Arc::new(ReadFileTool)
|
||||
.run(
|
||||
input,
|
||||
@@ -477,11 +477,11 @@ mod test {
|
||||
let model = Arc::new(FakeLanguageModel::default());
|
||||
let result = cx
|
||||
.update(|cx| {
|
||||
let input = ReadFileToolInput {
|
||||
path: "root/multiline.txt".to_string(),
|
||||
start_line: Some(2),
|
||||
end_line: Some(4),
|
||||
};
|
||||
let input = json!({
|
||||
"path": "root/multiline.txt",
|
||||
"start_line": 2,
|
||||
"end_line": 4
|
||||
});
|
||||
Arc::new(ReadFileTool)
|
||||
.run(
|
||||
input,
|
||||
@@ -520,11 +520,11 @@ mod test {
|
||||
// start_line of 0 should be treated as 1
|
||||
let result = cx
|
||||
.update(|cx| {
|
||||
let input = ReadFileToolInput {
|
||||
path: "root/multiline.txt".to_string(),
|
||||
start_line: Some(0),
|
||||
end_line: Some(2),
|
||||
};
|
||||
let input = json!({
|
||||
"path": "root/multiline.txt",
|
||||
"start_line": 0,
|
||||
"end_line": 2
|
||||
});
|
||||
Arc::new(ReadFileTool)
|
||||
.run(
|
||||
input,
|
||||
@@ -543,11 +543,11 @@ mod test {
|
||||
// end_line of 0 should result in at least 1 line
|
||||
let result = cx
|
||||
.update(|cx| {
|
||||
let input = ReadFileToolInput {
|
||||
path: "root/multiline.txt".to_string(),
|
||||
start_line: Some(1),
|
||||
end_line: Some(0),
|
||||
};
|
||||
let input = json!({
|
||||
"path": "root/multiline.txt",
|
||||
"start_line": 1,
|
||||
"end_line": 0
|
||||
});
|
||||
Arc::new(ReadFileTool)
|
||||
.run(
|
||||
input,
|
||||
@@ -566,11 +566,11 @@ mod test {
|
||||
// when start_line > end_line, should still return at least 1 line
|
||||
let result = cx
|
||||
.update(|cx| {
|
||||
let input = ReadFileToolInput {
|
||||
path: "root/multiline.txt".to_string(),
|
||||
start_line: Some(3),
|
||||
end_line: Some(2),
|
||||
};
|
||||
let input = json!({
|
||||
"path": "root/multiline.txt",
|
||||
"start_line": 3,
|
||||
"end_line": 2
|
||||
});
|
||||
Arc::new(ReadFileTool)
|
||||
.run(
|
||||
input,
|
||||
@@ -694,11 +694,9 @@ mod test {
|
||||
// Reading a file outside the project worktree should fail
|
||||
let result = cx
|
||||
.update(|cx| {
|
||||
let input = ReadFileToolInput {
|
||||
path: "/outside_project/sensitive_file.txt".to_string(),
|
||||
start_line: None,
|
||||
end_line: None,
|
||||
};
|
||||
let input = json!({
|
||||
"path": "/outside_project/sensitive_file.txt"
|
||||
});
|
||||
Arc::new(ReadFileTool)
|
||||
.run(
|
||||
input,
|
||||
@@ -720,11 +718,9 @@ mod test {
|
||||
// Reading a file within the project should succeed
|
||||
let result = cx
|
||||
.update(|cx| {
|
||||
let input = ReadFileToolInput {
|
||||
path: "project_root/allowed_file.txt".to_string(),
|
||||
start_line: None,
|
||||
end_line: None,
|
||||
};
|
||||
let input = json!({
|
||||
"path": "project_root/allowed_file.txt"
|
||||
});
|
||||
Arc::new(ReadFileTool)
|
||||
.run(
|
||||
input,
|
||||
@@ -746,11 +742,9 @@ mod test {
|
||||
// Reading files that match file_scan_exclusions should fail
|
||||
let result = cx
|
||||
.update(|cx| {
|
||||
let input = ReadFileToolInput {
|
||||
path: "project_root/.secretdir/config".to_string(),
|
||||
start_line: None,
|
||||
end_line: None,
|
||||
};
|
||||
let input = json!({
|
||||
"path": "project_root/.secretdir/config"
|
||||
});
|
||||
Arc::new(ReadFileTool)
|
||||
.run(
|
||||
input,
|
||||
@@ -771,11 +765,9 @@ mod test {
|
||||
|
||||
let result = cx
|
||||
.update(|cx| {
|
||||
let input = ReadFileToolInput {
|
||||
path: "project_root/.mymetadata".to_string(),
|
||||
start_line: None,
|
||||
end_line: None,
|
||||
};
|
||||
let input = json!({
|
||||
"path": "project_root/.mymetadata"
|
||||
});
|
||||
Arc::new(ReadFileTool)
|
||||
.run(
|
||||
input,
|
||||
@@ -797,11 +789,9 @@ mod test {
|
||||
// Reading private files should fail
|
||||
let result = cx
|
||||
.update(|cx| {
|
||||
let input = ReadFileToolInput {
|
||||
path: "project_root/secrets/.mysecrets".to_string(),
|
||||
start_line: None,
|
||||
end_line: None,
|
||||
};
|
||||
let input = json!({
|
||||
"path": "project_root/.mysecrets"
|
||||
});
|
||||
Arc::new(ReadFileTool)
|
||||
.run(
|
||||
input,
|
||||
@@ -822,11 +812,9 @@ mod test {
|
||||
|
||||
let result = cx
|
||||
.update(|cx| {
|
||||
let input = ReadFileToolInput {
|
||||
path: "project_root/subdir/special.privatekey".to_string(),
|
||||
start_line: None,
|
||||
end_line: None,
|
||||
};
|
||||
let input = json!({
|
||||
"path": "project_root/subdir/special.privatekey"
|
||||
});
|
||||
Arc::new(ReadFileTool)
|
||||
.run(
|
||||
input,
|
||||
@@ -847,11 +835,9 @@ mod test {
|
||||
|
||||
let result = cx
|
||||
.update(|cx| {
|
||||
let input = ReadFileToolInput {
|
||||
path: "project_root/subdir/data.mysensitive".to_string(),
|
||||
start_line: None,
|
||||
end_line: None,
|
||||
};
|
||||
let input = json!({
|
||||
"path": "project_root/subdir/data.mysensitive"
|
||||
});
|
||||
Arc::new(ReadFileTool)
|
||||
.run(
|
||||
input,
|
||||
@@ -873,11 +859,9 @@ mod test {
|
||||
// Reading a normal file should still work, even with private_files configured
|
||||
let result = cx
|
||||
.update(|cx| {
|
||||
let input = ReadFileToolInput {
|
||||
path: "project_root/subdir/normal_file.txt".to_string(),
|
||||
start_line: None,
|
||||
end_line: None,
|
||||
};
|
||||
let input = json!({
|
||||
"path": "project_root/subdir/normal_file.txt"
|
||||
});
|
||||
Arc::new(ReadFileTool)
|
||||
.run(
|
||||
input,
|
||||
@@ -900,11 +884,9 @@ mod test {
|
||||
// Path traversal attempts with .. should fail
|
||||
let result = cx
|
||||
.update(|cx| {
|
||||
let input = ReadFileToolInput {
|
||||
path: "project_root/../outside_project/sensitive_file.txt".to_string(),
|
||||
start_line: None,
|
||||
end_line: None,
|
||||
};
|
||||
let input = json!({
|
||||
"path": "project_root/../outside_project/sensitive_file.txt"
|
||||
});
|
||||
Arc::new(ReadFileTool)
|
||||
.run(
|
||||
input,
|
||||
@@ -999,11 +981,9 @@ mod test {
|
||||
let tool = Arc::new(ReadFileTool);
|
||||
|
||||
// Test reading allowed files in worktree1
|
||||
let input = ReadFileToolInput {
|
||||
path: "worktree1/src/main.rs".to_string(),
|
||||
start_line: None,
|
||||
end_line: None,
|
||||
};
|
||||
let input = json!({
|
||||
"path": "worktree1/src/main.rs"
|
||||
});
|
||||
|
||||
let result = cx
|
||||
.update(|cx| {
|
||||
@@ -1027,11 +1007,9 @@ mod test {
|
||||
);
|
||||
|
||||
// Test reading private file in worktree1 should fail
|
||||
let input = ReadFileToolInput {
|
||||
path: "worktree1/src/secret.rs".to_string(),
|
||||
start_line: None,
|
||||
end_line: None,
|
||||
};
|
||||
let input = json!({
|
||||
"path": "worktree1/src/secret.rs"
|
||||
});
|
||||
|
||||
let result = cx
|
||||
.update(|cx| {
|
||||
@@ -1058,11 +1036,9 @@ mod test {
|
||||
);
|
||||
|
||||
// Test reading excluded file in worktree1 should fail
|
||||
let input = ReadFileToolInput {
|
||||
path: "worktree1/tests/fixture.sql".to_string(),
|
||||
start_line: None,
|
||||
end_line: None,
|
||||
};
|
||||
let input = json!({
|
||||
"path": "worktree1/tests/fixture.sql"
|
||||
});
|
||||
|
||||
let result = cx
|
||||
.update(|cx| {
|
||||
@@ -1089,11 +1065,9 @@ mod test {
|
||||
);
|
||||
|
||||
// Test reading allowed files in worktree2
|
||||
let input = ReadFileToolInput {
|
||||
path: "worktree2/lib/public.js".to_string(),
|
||||
start_line: None,
|
||||
end_line: None,
|
||||
};
|
||||
let input = json!({
|
||||
"path": "worktree2/lib/public.js"
|
||||
});
|
||||
|
||||
let result = cx
|
||||
.update(|cx| {
|
||||
@@ -1117,11 +1091,9 @@ mod test {
|
||||
);
|
||||
|
||||
// Test reading private file in worktree2 should fail
|
||||
let input = ReadFileToolInput {
|
||||
path: "worktree2/lib/private.js".to_string(),
|
||||
start_line: None,
|
||||
end_line: None,
|
||||
};
|
||||
let input = json!({
|
||||
"path": "worktree2/lib/private.js"
|
||||
});
|
||||
|
||||
let result = cx
|
||||
.update(|cx| {
|
||||
@@ -1148,11 +1120,9 @@ mod test {
|
||||
);
|
||||
|
||||
// Test reading excluded file in worktree2 should fail
|
||||
let input = ReadFileToolInput {
|
||||
path: "worktree2/docs/internal.md".to_string(),
|
||||
start_line: None,
|
||||
end_line: None,
|
||||
};
|
||||
let input = json!({
|
||||
"path": "worktree2/docs/internal.md"
|
||||
});
|
||||
|
||||
let result = cx
|
||||
.update(|cx| {
|
||||
@@ -1180,11 +1150,9 @@ mod test {
|
||||
|
||||
// Test that files allowed in one worktree but not in another are handled correctly
|
||||
// (e.g., config.toml is private in worktree1 but doesn't exist in worktree2)
|
||||
let input = ReadFileToolInput {
|
||||
path: "worktree1/src/config.toml".to_string(),
|
||||
start_line: None,
|
||||
end_line: None,
|
||||
};
|
||||
let input = json!({
|
||||
"path": "worktree1/src/config.toml"
|
||||
});
|
||||
|
||||
let result = cx
|
||||
.update(|cx| {
|
||||
|
||||
@@ -25,7 +25,9 @@ fn schema_to_json(
|
||||
fn root_schema_for<T: JsonSchema>(format: LanguageModelToolSchemaFormat) -> Schema {
|
||||
let mut generator = match format {
|
||||
LanguageModelToolSchemaFormat::JsonSchema => SchemaSettings::draft07().into_generator(),
|
||||
LanguageModelToolSchemaFormat::JsonSchemaSubset => SchemaSettings::openapi3()
|
||||
// TODO: Gemini docs mention using a subset of OpenAPI 3, so this may benefit from using
|
||||
// `SchemaSettings::openapi3()`.
|
||||
LanguageModelToolSchemaFormat::JsonSchemaSubset => SchemaSettings::draft07()
|
||||
.with(|settings| {
|
||||
settings.meta_schema = None;
|
||||
settings.inline_subschemas = true;
|
||||
|
||||
@@ -2,7 +2,7 @@ use crate::{
|
||||
schema::json_schema_for,
|
||||
ui::{COLLAPSED_LINES, ToolOutputPreview},
|
||||
};
|
||||
use anyhow::{Context as _, Result};
|
||||
use anyhow::{Context as _, Result, anyhow};
|
||||
use assistant_tool::{ActionLog, Tool, ToolCard, ToolResult, ToolUseStatus};
|
||||
use futures::{FutureExt as _, future::Shared};
|
||||
use gpui::{
|
||||
@@ -72,13 +72,11 @@ impl TerminalTool {
|
||||
}
|
||||
|
||||
impl Tool for TerminalTool {
|
||||
type Input = TerminalToolInput;
|
||||
|
||||
fn name(&self) -> String {
|
||||
Self::NAME.to_string()
|
||||
}
|
||||
|
||||
fn needs_confirmation(&self, _: &Self::Input, _: &App) -> bool {
|
||||
fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
@@ -98,24 +96,30 @@ impl Tool for TerminalTool {
|
||||
json_schema_for::<TerminalToolInput>(format)
|
||||
}
|
||||
|
||||
fn ui_text(&self, input: &Self::Input) -> String {
|
||||
let mut lines = input.command.lines();
|
||||
let first_line = lines.next().unwrap_or_default();
|
||||
let remaining_line_count = lines.count();
|
||||
match remaining_line_count {
|
||||
0 => MarkdownInlineCode(&first_line).to_string(),
|
||||
1 => MarkdownInlineCode(&format!(
|
||||
"{} - {} more line",
|
||||
first_line, remaining_line_count
|
||||
))
|
||||
.to_string(),
|
||||
n => MarkdownInlineCode(&format!("{} - {} more lines", first_line, n)).to_string(),
|
||||
fn ui_text(&self, input: &serde_json::Value) -> String {
|
||||
match serde_json::from_value::<TerminalToolInput>(input.clone()) {
|
||||
Ok(input) => {
|
||||
let mut lines = input.command.lines();
|
||||
let first_line = lines.next().unwrap_or_default();
|
||||
let remaining_line_count = lines.count();
|
||||
match remaining_line_count {
|
||||
0 => MarkdownInlineCode(&first_line).to_string(),
|
||||
1 => MarkdownInlineCode(&format!(
|
||||
"{} - {} more line",
|
||||
first_line, remaining_line_count
|
||||
))
|
||||
.to_string(),
|
||||
n => MarkdownInlineCode(&format!("{} - {} more lines", first_line, n))
|
||||
.to_string(),
|
||||
}
|
||||
}
|
||||
Err(_) => "Run terminal command".to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
fn run(
|
||||
self: Arc<Self>,
|
||||
input: Self::Input,
|
||||
input: serde_json::Value,
|
||||
_request: Arc<LanguageModelRequest>,
|
||||
project: Entity<Project>,
|
||||
_action_log: Entity<ActionLog>,
|
||||
@@ -123,6 +127,11 @@ impl Tool for TerminalTool {
|
||||
window: Option<AnyWindowHandle>,
|
||||
cx: &mut App,
|
||||
) -> ToolResult {
|
||||
let input: TerminalToolInput = match serde_json::from_value(input) {
|
||||
Ok(input) => input,
|
||||
Err(err) => return Task::ready(Err(anyhow!(err))).into(),
|
||||
};
|
||||
|
||||
let working_dir = match working_dir(&input, &project, cx) {
|
||||
Ok(dir) => dir,
|
||||
Err(err) => return Task::ready(Err(err)).into(),
|
||||
@@ -747,7 +756,7 @@ mod tests {
|
||||
let result = cx.update(|cx| {
|
||||
TerminalTool::run(
|
||||
Arc::new(TerminalTool::new(cx)),
|
||||
input,
|
||||
serde_json::to_value(input).unwrap(),
|
||||
Arc::default(),
|
||||
project.clone(),
|
||||
action_log.clone(),
|
||||
@@ -782,7 +791,7 @@ mod tests {
|
||||
let check = |input, expected, cx: &mut App| {
|
||||
let headless_result = TerminalTool::run(
|
||||
Arc::new(TerminalTool::new(cx)),
|
||||
input,
|
||||
serde_json::to_value(input).unwrap(),
|
||||
Arc::default(),
|
||||
project.clone(),
|
||||
action_log.clone(),
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::schema::json_schema_for;
|
||||
use anyhow::Result;
|
||||
use anyhow::{Result, anyhow};
|
||||
use assistant_tool::{ActionLog, Tool, ToolResult};
|
||||
use gpui::{AnyWindowHandle, App, Entity, Task};
|
||||
use language_model::{LanguageModel, LanguageModelRequest, LanguageModelToolSchemaFormat};
|
||||
@@ -20,13 +20,11 @@ pub struct ThinkingToolInput {
|
||||
pub struct ThinkingTool;
|
||||
|
||||
impl Tool for ThinkingTool {
|
||||
type Input = ThinkingToolInput;
|
||||
|
||||
fn name(&self) -> String {
|
||||
"thinking".to_string()
|
||||
}
|
||||
|
||||
fn needs_confirmation(&self, _: &Self::Input, _: &App) -> bool {
|
||||
fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
@@ -46,13 +44,13 @@ impl Tool for ThinkingTool {
|
||||
json_schema_for::<ThinkingToolInput>(format)
|
||||
}
|
||||
|
||||
fn ui_text(&self, _input: &Self::Input) -> String {
|
||||
fn ui_text(&self, _input: &serde_json::Value) -> String {
|
||||
"Thinking".to_string()
|
||||
}
|
||||
|
||||
fn run(
|
||||
self: Arc<Self>,
|
||||
_input: Self::Input,
|
||||
input: serde_json::Value,
|
||||
_request: Arc<LanguageModelRequest>,
|
||||
_project: Entity<Project>,
|
||||
_action_log: Entity<ActionLog>,
|
||||
@@ -61,6 +59,10 @@ impl Tool for ThinkingTool {
|
||||
_cx: &mut App,
|
||||
) -> ToolResult {
|
||||
// This tool just "thinks out loud" and doesn't perform any actions.
|
||||
Task::ready(Ok("Finished thinking.".to_string().into())).into()
|
||||
Task::ready(match serde_json::from_value::<ThinkingToolInput>(input) {
|
||||
Ok(_input) => Ok("Finished thinking.".to_string().into()),
|
||||
Err(err) => Err(anyhow!(err)),
|
||||
})
|
||||
.into()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -28,13 +28,11 @@ pub struct WebSearchToolInput {
|
||||
pub struct WebSearchTool;
|
||||
|
||||
impl Tool for WebSearchTool {
|
||||
type Input = WebSearchToolInput;
|
||||
|
||||
fn name(&self) -> String {
|
||||
"web_search".into()
|
||||
}
|
||||
|
||||
fn needs_confirmation(&self, _: &Self::Input, _: &App) -> bool {
|
||||
fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
@@ -54,13 +52,13 @@ impl Tool for WebSearchTool {
|
||||
json_schema_for::<WebSearchToolInput>(format)
|
||||
}
|
||||
|
||||
fn ui_text(&self, _input: &Self::Input) -> String {
|
||||
fn ui_text(&self, _input: &serde_json::Value) -> String {
|
||||
"Searching the Web".to_string()
|
||||
}
|
||||
|
||||
fn run(
|
||||
self: Arc<Self>,
|
||||
input: Self::Input,
|
||||
input: serde_json::Value,
|
||||
_request: Arc<LanguageModelRequest>,
|
||||
_project: Entity<Project>,
|
||||
_action_log: Entity<ActionLog>,
|
||||
@@ -68,6 +66,10 @@ impl Tool for WebSearchTool {
|
||||
_window: Option<AnyWindowHandle>,
|
||||
cx: &mut App,
|
||||
) -> ToolResult {
|
||||
let input = match serde_json::from_value::<WebSearchToolInput>(input) {
|
||||
Ok(input) => input,
|
||||
Err(err) => return Task::ready(Err(anyhow!(err))).into(),
|
||||
};
|
||||
let Some(provider) = WebSearchRegistry::read_global(cx).active_provider() else {
|
||||
return Task::ready(Err(anyhow!("Web search is not available."))).into();
|
||||
};
|
||||
|
||||
@@ -28,17 +28,7 @@ use workspace::Workspace;
|
||||
const SHOULD_SHOW_UPDATE_NOTIFICATION_KEY: &str = "auto-updater-should-show-updated-notification";
|
||||
const POLL_INTERVAL: Duration = Duration::from_secs(60 * 60);
|
||||
|
||||
actions!(
|
||||
auto_update,
|
||||
[
|
||||
/// Checks for available updates.
|
||||
Check,
|
||||
/// Dismisses the update error message.
|
||||
DismissErrorMessage,
|
||||
/// Opens the release notes for the current version in a browser.
|
||||
ViewReleaseNotes,
|
||||
]
|
||||
);
|
||||
actions!(auto_update, [Check, DismissErrorMessage, ViewReleaseNotes,]);
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct UpdateRequestBody {
|
||||
|
||||
@@ -12,13 +12,7 @@ use workspace::Workspace;
|
||||
use workspace::notifications::simple_message_notification::MessageNotification;
|
||||
use workspace::notifications::{NotificationId, show_app_notification};
|
||||
|
||||
actions!(
|
||||
auto_update,
|
||||
[
|
||||
/// Opens the release notes for the current version in a new tab.
|
||||
ViewReleaseNotesLocally
|
||||
]
|
||||
);
|
||||
actions!(auto_update, [ViewReleaseNotesLocally]);
|
||||
|
||||
pub fn init(cx: &mut App) {
|
||||
notify_if_app_was_updated(cx);
|
||||
|
||||
@@ -29,7 +29,7 @@ client.workspace = true
|
||||
collections.workspace = true
|
||||
fs.workspace = true
|
||||
futures.workspace = true
|
||||
gpui = { workspace = true, features = ["screen-capture"] }
|
||||
gpui.workspace = true
|
||||
language.workspace = true
|
||||
log.workspace = true
|
||||
postage.workspace = true
|
||||
|
||||
@@ -81,17 +81,7 @@ pub const INITIAL_RECONNECTION_DELAY: Duration = Duration::from_millis(500);
|
||||
pub const MAX_RECONNECTION_DELAY: Duration = Duration::from_secs(10);
|
||||
pub const CONNECTION_TIMEOUT: Duration = Duration::from_secs(20);
|
||||
|
||||
actions!(
|
||||
client,
|
||||
[
|
||||
/// Signs in to Zed account.
|
||||
SignIn,
|
||||
/// Signs out of Zed account.
|
||||
SignOut,
|
||||
/// Reconnects to the collaboration server.
|
||||
Reconnect
|
||||
]
|
||||
);
|
||||
actions!(client, [SignIn, SignOut, Reconnect]);
|
||||
|
||||
#[derive(Clone, Default, Serialize, Deserialize, JsonSchema)]
|
||||
pub struct ClientSettingsContent {
|
||||
|
||||
@@ -35,7 +35,6 @@ dashmap.workspace = true
|
||||
derive_more.workspace = true
|
||||
envy = "0.4.2"
|
||||
futures.workspace = true
|
||||
gpui = { workspace = true, features = ["screen-capture"] }
|
||||
hex.workspace = true
|
||||
http_client.workspace = true
|
||||
jsonwebtoken.workspace = true
|
||||
|
||||
@@ -107,7 +107,7 @@ CREATE INDEX "index_worktree_entries_on_project_id" ON "worktree_entries" ("proj
|
||||
CREATE INDEX "index_worktree_entries_on_project_id_and_worktree_id" ON "worktree_entries" ("project_id", "worktree_id");
|
||||
|
||||
CREATE TABLE "project_repositories" (
|
||||
"project_id" INTEGER NOT NULL REFERENCES projects (id) ON DELETE CASCADE,
|
||||
"project_id" INTEGER NOT NULL,
|
||||
"abs_path" VARCHAR,
|
||||
"id" INTEGER NOT NULL,
|
||||
"entry_ids" VARCHAR,
|
||||
@@ -124,7 +124,7 @@ CREATE TABLE "project_repositories" (
|
||||
CREATE INDEX "index_project_repositories_on_project_id" ON "project_repositories" ("project_id");
|
||||
|
||||
CREATE TABLE "project_repository_statuses" (
|
||||
"project_id" INTEGER NOT NULL REFERENCES projects (id) ON DELETE CASCADE,
|
||||
"project_id" INTEGER NOT NULL,
|
||||
"repository_id" INTEGER NOT NULL,
|
||||
"repo_path" VARCHAR NOT NULL,
|
||||
"status" INT8 NOT NULL,
|
||||
|
||||
@@ -1,25 +0,0 @@
|
||||
DELETE FROM project_repositories
|
||||
WHERE project_id NOT IN (SELECT id FROM projects);
|
||||
|
||||
ALTER TABLE project_repositories
|
||||
ADD CONSTRAINT fk_project_repositories_project_id
|
||||
FOREIGN KEY (project_id)
|
||||
REFERENCES projects (id)
|
||||
ON DELETE CASCADE
|
||||
NOT VALID;
|
||||
|
||||
ALTER TABLE project_repositories
|
||||
VALIDATE CONSTRAINT fk_project_repositories_project_id;
|
||||
|
||||
DELETE FROM project_repository_statuses
|
||||
WHERE project_id NOT IN (SELECT id FROM projects);
|
||||
|
||||
ALTER TABLE project_repository_statuses
|
||||
ADD CONSTRAINT fk_project_repository_statuses_project_id
|
||||
FOREIGN KEY (project_id)
|
||||
REFERENCES projects (id)
|
||||
ON DELETE CASCADE
|
||||
NOT VALID;
|
||||
|
||||
ALTER TABLE project_repository_statuses
|
||||
VALIDATE CONSTRAINT fk_project_repository_statuses_project_id;
|
||||
@@ -1404,9 +1404,6 @@ async fn sync_model_request_usage_with_stripe(
|
||||
llm_db: &Arc<LlmDatabase>,
|
||||
stripe_billing: &Arc<StripeBilling>,
|
||||
) -> anyhow::Result<()> {
|
||||
log::info!("Stripe usage sync: Starting");
|
||||
let started_at = Utc::now();
|
||||
|
||||
let staff_users = app.db.get_staff_users().await?;
|
||||
let staff_user_ids = staff_users
|
||||
.iter()
|
||||
@@ -1451,10 +1448,6 @@ async fn sync_model_request_usage_with_stripe(
|
||||
.find_price_by_lookup_key("claude-3-7-sonnet-requests-max")
|
||||
.await?;
|
||||
|
||||
let usage_meter_count = usage_meters.len();
|
||||
|
||||
log::info!("Stripe usage sync: Syncing {usage_meter_count} usage meters");
|
||||
|
||||
for (usage_meter, usage) in usage_meters {
|
||||
maybe!(async {
|
||||
let Some((billing_customer, billing_subscription)) =
|
||||
@@ -1511,10 +1504,5 @@ async fn sync_model_request_usage_with_stripe(
|
||||
.log_err();
|
||||
}
|
||||
|
||||
log::info!(
|
||||
"Stripe usage sync: Synced {usage_meter_count} usage meters in {:?}",
|
||||
Utc::now() - started_at
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -4,19 +4,20 @@ mod tables;
|
||||
#[cfg(test)]
|
||||
pub mod tests;
|
||||
|
||||
use crate::{Error, Result};
|
||||
use crate::{Error, Result, executor::Executor};
|
||||
use anyhow::{Context as _, anyhow};
|
||||
use collections::{BTreeMap, BTreeSet, HashMap, HashSet};
|
||||
use dashmap::DashMap;
|
||||
use futures::StreamExt;
|
||||
use project_repository_statuses::StatusKind;
|
||||
use rand::{Rng, SeedableRng, prelude::StdRng};
|
||||
use rpc::ExtensionProvides;
|
||||
use rpc::{
|
||||
ConnectionId, ExtensionMetadata,
|
||||
proto::{self},
|
||||
};
|
||||
use sea_orm::{
|
||||
ActiveValue, Condition, ConnectionTrait, DatabaseConnection, DatabaseTransaction,
|
||||
ActiveValue, Condition, ConnectionTrait, DatabaseConnection, DatabaseTransaction, DbErr,
|
||||
FromQueryResult, IntoActiveModel, IsolationLevel, JoinType, QueryOrder, QuerySelect, Statement,
|
||||
TransactionTrait,
|
||||
entity::prelude::*,
|
||||
@@ -32,6 +33,7 @@ use std::{
|
||||
ops::{Deref, DerefMut},
|
||||
rc::Rc,
|
||||
sync::Arc,
|
||||
time::Duration,
|
||||
};
|
||||
use time::PrimitiveDateTime;
|
||||
use tokio::sync::{Mutex, OwnedMutexGuard};
|
||||
@@ -56,7 +58,6 @@ pub use tables::*;
|
||||
|
||||
#[cfg(test)]
|
||||
pub struct DatabaseTestOptions {
|
||||
pub executor: gpui::BackgroundExecutor,
|
||||
pub runtime: tokio::runtime::Runtime,
|
||||
pub query_failure_probability: parking_lot::Mutex<f64>,
|
||||
}
|
||||
@@ -68,6 +69,8 @@ pub struct Database {
|
||||
pool: DatabaseConnection,
|
||||
rooms: DashMap<RoomId, Arc<Mutex<()>>>,
|
||||
projects: DashMap<ProjectId, Arc<Mutex<()>>>,
|
||||
rng: Mutex<StdRng>,
|
||||
executor: Executor,
|
||||
notification_kinds_by_id: HashMap<NotificationKindId, &'static str>,
|
||||
notification_kinds_by_name: HashMap<String, NotificationKindId>,
|
||||
#[cfg(test)]
|
||||
@@ -78,15 +81,17 @@ pub struct Database {
|
||||
// separate files in the `queries` folder.
|
||||
impl Database {
|
||||
/// Connects to the database with the given options
|
||||
pub async fn new(options: ConnectOptions) -> Result<Self> {
|
||||
pub async fn new(options: ConnectOptions, executor: Executor) -> Result<Self> {
|
||||
sqlx::any::install_default_drivers();
|
||||
Ok(Self {
|
||||
options: options.clone(),
|
||||
pool: sea_orm::Database::connect(options).await?,
|
||||
rooms: DashMap::with_capacity(16384),
|
||||
projects: DashMap::with_capacity(16384),
|
||||
rng: Mutex::new(StdRng::seed_from_u64(0)),
|
||||
notification_kinds_by_id: HashMap::default(),
|
||||
notification_kinds_by_name: HashMap::default(),
|
||||
executor,
|
||||
#[cfg(test)]
|
||||
test_options: None,
|
||||
})
|
||||
@@ -102,13 +107,48 @@ impl Database {
|
||||
self.projects.clear();
|
||||
}
|
||||
|
||||
/// Transaction runs things in a transaction. If you want to call other methods
|
||||
/// and pass the transaction around you need to reborrow the transaction at each
|
||||
/// call site with: `&*tx`.
|
||||
pub async fn transaction<F, Fut, T>(&self, f: F) -> Result<T>
|
||||
where
|
||||
F: Send + Fn(TransactionHandle) -> Fut,
|
||||
Fut: Send + Future<Output = Result<T>>,
|
||||
{
|
||||
let body = async {
|
||||
let (tx, result) = self.with_transaction(&f).await?;
|
||||
let mut i = 0;
|
||||
loop {
|
||||
let (tx, result) = self.with_transaction(&f).await?;
|
||||
match result {
|
||||
Ok(result) => match tx.commit().await.map_err(Into::into) {
|
||||
Ok(()) => return Ok(result),
|
||||
Err(error) => {
|
||||
if !self.retry_on_serialization_error(&error, i).await {
|
||||
return Err(error);
|
||||
}
|
||||
}
|
||||
},
|
||||
Err(error) => {
|
||||
tx.rollback().await?;
|
||||
if !self.retry_on_serialization_error(&error, i).await {
|
||||
return Err(error);
|
||||
}
|
||||
}
|
||||
}
|
||||
i += 1;
|
||||
}
|
||||
};
|
||||
|
||||
self.run(body).await
|
||||
}
|
||||
|
||||
pub async fn weak_transaction<F, Fut, T>(&self, f: F) -> Result<T>
|
||||
where
|
||||
F: Send + Fn(TransactionHandle) -> Fut,
|
||||
Fut: Send + Future<Output = Result<T>>,
|
||||
{
|
||||
let body = async {
|
||||
let (tx, result) = self.with_weak_transaction(&f).await?;
|
||||
match result {
|
||||
Ok(result) => match tx.commit().await.map_err(Into::into) {
|
||||
Ok(()) => Ok(result),
|
||||
@@ -134,28 +174,44 @@ impl Database {
|
||||
Fut: Send + Future<Output = Result<Option<(RoomId, T)>>>,
|
||||
{
|
||||
let body = async {
|
||||
let (tx, result) = self.with_transaction(&f).await?;
|
||||
match result {
|
||||
Ok(Some((room_id, data))) => {
|
||||
let lock = self.rooms.entry(room_id).or_default().clone();
|
||||
let _guard = lock.lock_owned().await;
|
||||
match tx.commit().await.map_err(Into::into) {
|
||||
Ok(()) => Ok(Some(TransactionGuard {
|
||||
data,
|
||||
_guard,
|
||||
_not_send: PhantomData,
|
||||
})),
|
||||
Err(error) => Err(error),
|
||||
let mut i = 0;
|
||||
loop {
|
||||
let (tx, result) = self.with_transaction(&f).await?;
|
||||
match result {
|
||||
Ok(Some((room_id, data))) => {
|
||||
let lock = self.rooms.entry(room_id).or_default().clone();
|
||||
let _guard = lock.lock_owned().await;
|
||||
match tx.commit().await.map_err(Into::into) {
|
||||
Ok(()) => {
|
||||
return Ok(Some(TransactionGuard {
|
||||
data,
|
||||
_guard,
|
||||
_not_send: PhantomData,
|
||||
}));
|
||||
}
|
||||
Err(error) => {
|
||||
if !self.retry_on_serialization_error(&error, i).await {
|
||||
return Err(error);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(None) => match tx.commit().await.map_err(Into::into) {
|
||||
Ok(()) => return Ok(None),
|
||||
Err(error) => {
|
||||
if !self.retry_on_serialization_error(&error, i).await {
|
||||
return Err(error);
|
||||
}
|
||||
}
|
||||
},
|
||||
Err(error) => {
|
||||
tx.rollback().await?;
|
||||
if !self.retry_on_serialization_error(&error, i).await {
|
||||
return Err(error);
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(None) => match tx.commit().await.map_err(Into::into) {
|
||||
Ok(()) => Ok(None),
|
||||
Err(error) => Err(error),
|
||||
},
|
||||
Err(error) => {
|
||||
tx.rollback().await?;
|
||||
Err(error)
|
||||
}
|
||||
i += 1;
|
||||
}
|
||||
};
|
||||
|
||||
@@ -173,26 +229,38 @@ impl Database {
|
||||
{
|
||||
let room_id = Database::room_id_for_project(self, project_id).await?;
|
||||
let body = async {
|
||||
let lock = if let Some(room_id) = room_id {
|
||||
self.rooms.entry(room_id).or_default().clone()
|
||||
} else {
|
||||
self.projects.entry(project_id).or_default().clone()
|
||||
};
|
||||
let _guard = lock.lock_owned().await;
|
||||
let (tx, result) = self.with_transaction(&f).await?;
|
||||
match result {
|
||||
Ok(data) => match tx.commit().await.map_err(Into::into) {
|
||||
Ok(()) => Ok(TransactionGuard {
|
||||
data,
|
||||
_guard,
|
||||
_not_send: PhantomData,
|
||||
}),
|
||||
Err(error) => Err(error),
|
||||
},
|
||||
Err(error) => {
|
||||
tx.rollback().await?;
|
||||
Err(error)
|
||||
let mut i = 0;
|
||||
loop {
|
||||
let lock = if let Some(room_id) = room_id {
|
||||
self.rooms.entry(room_id).or_default().clone()
|
||||
} else {
|
||||
self.projects.entry(project_id).or_default().clone()
|
||||
};
|
||||
let _guard = lock.lock_owned().await;
|
||||
let (tx, result) = self.with_transaction(&f).await?;
|
||||
match result {
|
||||
Ok(data) => match tx.commit().await.map_err(Into::into) {
|
||||
Ok(()) => {
|
||||
return Ok(TransactionGuard {
|
||||
data,
|
||||
_guard,
|
||||
_not_send: PhantomData,
|
||||
});
|
||||
}
|
||||
Err(error) => {
|
||||
if !self.retry_on_serialization_error(&error, i).await {
|
||||
return Err(error);
|
||||
}
|
||||
}
|
||||
},
|
||||
Err(error) => {
|
||||
tx.rollback().await?;
|
||||
if !self.retry_on_serialization_error(&error, i).await {
|
||||
return Err(error);
|
||||
}
|
||||
}
|
||||
}
|
||||
i += 1;
|
||||
}
|
||||
};
|
||||
|
||||
@@ -212,22 +280,34 @@ impl Database {
|
||||
Fut: Send + Future<Output = Result<T>>,
|
||||
{
|
||||
let body = async {
|
||||
let lock = self.rooms.entry(room_id).or_default().clone();
|
||||
let _guard = lock.lock_owned().await;
|
||||
let (tx, result) = self.with_transaction(&f).await?;
|
||||
match result {
|
||||
Ok(data) => match tx.commit().await.map_err(Into::into) {
|
||||
Ok(()) => Ok(TransactionGuard {
|
||||
data,
|
||||
_guard,
|
||||
_not_send: PhantomData,
|
||||
}),
|
||||
Err(error) => Err(error),
|
||||
},
|
||||
Err(error) => {
|
||||
tx.rollback().await?;
|
||||
Err(error)
|
||||
let mut i = 0;
|
||||
loop {
|
||||
let lock = self.rooms.entry(room_id).or_default().clone();
|
||||
let _guard = lock.lock_owned().await;
|
||||
let (tx, result) = self.with_transaction(&f).await?;
|
||||
match result {
|
||||
Ok(data) => match tx.commit().await.map_err(Into::into) {
|
||||
Ok(()) => {
|
||||
return Ok(TransactionGuard {
|
||||
data,
|
||||
_guard,
|
||||
_not_send: PhantomData,
|
||||
});
|
||||
}
|
||||
Err(error) => {
|
||||
if !self.retry_on_serialization_error(&error, i).await {
|
||||
return Err(error);
|
||||
}
|
||||
}
|
||||
},
|
||||
Err(error) => {
|
||||
tx.rollback().await?;
|
||||
if !self.retry_on_serialization_error(&error, i).await {
|
||||
return Err(error);
|
||||
}
|
||||
}
|
||||
}
|
||||
i += 1;
|
||||
}
|
||||
};
|
||||
|
||||
@@ -235,6 +315,28 @@ impl Database {
|
||||
}
|
||||
|
||||
async fn with_transaction<F, Fut, T>(&self, f: &F) -> Result<(DatabaseTransaction, Result<T>)>
|
||||
where
|
||||
F: Send + Fn(TransactionHandle) -> Fut,
|
||||
Fut: Send + Future<Output = Result<T>>,
|
||||
{
|
||||
let tx = self
|
||||
.pool
|
||||
.begin_with_config(Some(IsolationLevel::Serializable), None)
|
||||
.await?;
|
||||
|
||||
let mut tx = Arc::new(Some(tx));
|
||||
let result = f(TransactionHandle(tx.clone())).await;
|
||||
let tx = Arc::get_mut(&mut tx)
|
||||
.and_then(|tx| tx.take())
|
||||
.context("couldn't complete transaction because it's still in use")?;
|
||||
|
||||
Ok((tx, result))
|
||||
}
|
||||
|
||||
async fn with_weak_transaction<F, Fut, T>(
|
||||
&self,
|
||||
f: &F,
|
||||
) -> Result<(DatabaseTransaction, Result<T>)>
|
||||
where
|
||||
F: Send + Fn(TransactionHandle) -> Fut,
|
||||
Fut: Send + Future<Output = Result<T>>,
|
||||
@@ -259,13 +361,13 @@ impl Database {
|
||||
{
|
||||
#[cfg(test)]
|
||||
{
|
||||
use rand::prelude::*;
|
||||
|
||||
let test_options = self.test_options.as_ref().unwrap();
|
||||
test_options.executor.simulate_random_delay().await;
|
||||
let fail_probability = *test_options.query_failure_probability.lock();
|
||||
if test_options.executor.rng().gen_bool(fail_probability) {
|
||||
return Err(anyhow!("simulated query failure"))?;
|
||||
if let Executor::Deterministic(executor) = &self.executor {
|
||||
executor.simulate_random_delay().await;
|
||||
let fail_probability = *test_options.query_failure_probability.lock();
|
||||
if executor.rng().gen_bool(fail_probability) {
|
||||
return Err(anyhow!("simulated query failure"))?;
|
||||
}
|
||||
}
|
||||
|
||||
test_options.runtime.block_on(future)
|
||||
@@ -276,6 +378,46 @@ impl Database {
|
||||
future.await
|
||||
}
|
||||
}
|
||||
|
||||
async fn retry_on_serialization_error(&self, error: &Error, prev_attempt_count: usize) -> bool {
|
||||
// If the error is due to a failure to serialize concurrent transactions, then retry
|
||||
// this transaction after a delay. With each subsequent retry, double the delay duration.
|
||||
// Also vary the delay randomly in order to ensure different database connections retry
|
||||
// at different times.
|
||||
const SLEEPS: [f32; 10] = [10., 20., 40., 80., 160., 320., 640., 1280., 2560., 5120.];
|
||||
if is_serialization_error(error) && prev_attempt_count < SLEEPS.len() {
|
||||
let base_delay = SLEEPS[prev_attempt_count];
|
||||
let randomized_delay = base_delay * self.rng.lock().await.gen_range(0.5..=2.0);
|
||||
log::warn!(
|
||||
"retrying transaction after serialization error. delay: {} ms.",
|
||||
randomized_delay
|
||||
);
|
||||
self.executor
|
||||
.sleep(Duration::from_millis(randomized_delay as u64))
|
||||
.await;
|
||||
true
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn is_serialization_error(error: &Error) -> bool {
|
||||
const SERIALIZATION_FAILURE_CODE: &str = "40001";
|
||||
match error {
|
||||
Error::Database(
|
||||
DbErr::Exec(sea_orm::RuntimeErr::SqlxError(error))
|
||||
| DbErr::Query(sea_orm::RuntimeErr::SqlxError(error)),
|
||||
) if error
|
||||
.as_database_error()
|
||||
.and_then(|error| error.code())
|
||||
.as_deref()
|
||||
== Some(SERIALIZATION_FAILURE_CODE) =>
|
||||
{
|
||||
true
|
||||
}
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
|
||||
/// A handle to a [`DatabaseTransaction`].
|
||||
|
||||
@@ -20,7 +20,7 @@ impl Database {
|
||||
&self,
|
||||
params: &CreateBillingCustomerParams,
|
||||
) -> Result<billing_customer::Model> {
|
||||
self.transaction(|tx| async move {
|
||||
self.weak_transaction(|tx| async move {
|
||||
let customer = billing_customer::Entity::insert(billing_customer::ActiveModel {
|
||||
user_id: ActiveValue::set(params.user_id),
|
||||
stripe_customer_id: ActiveValue::set(params.stripe_customer_id.clone()),
|
||||
@@ -40,7 +40,7 @@ impl Database {
|
||||
id: BillingCustomerId,
|
||||
params: &UpdateBillingCustomerParams,
|
||||
) -> Result<()> {
|
||||
self.transaction(|tx| async move {
|
||||
self.weak_transaction(|tx| async move {
|
||||
billing_customer::Entity::update(billing_customer::ActiveModel {
|
||||
id: ActiveValue::set(id),
|
||||
user_id: params.user_id.clone(),
|
||||
@@ -61,7 +61,7 @@ impl Database {
|
||||
&self,
|
||||
id: BillingCustomerId,
|
||||
) -> Result<Option<billing_customer::Model>> {
|
||||
self.transaction(|tx| async move {
|
||||
self.weak_transaction(|tx| async move {
|
||||
Ok(billing_customer::Entity::find()
|
||||
.filter(billing_customer::Column::Id.eq(id))
|
||||
.one(&*tx)
|
||||
@@ -75,7 +75,7 @@ impl Database {
|
||||
&self,
|
||||
user_id: UserId,
|
||||
) -> Result<Option<billing_customer::Model>> {
|
||||
self.transaction(|tx| async move {
|
||||
self.weak_transaction(|tx| async move {
|
||||
Ok(billing_customer::Entity::find()
|
||||
.filter(billing_customer::Column::UserId.eq(user_id))
|
||||
.one(&*tx)
|
||||
@@ -89,7 +89,7 @@ impl Database {
|
||||
&self,
|
||||
stripe_customer_id: &str,
|
||||
) -> Result<Option<billing_customer::Model>> {
|
||||
self.transaction(|tx| async move {
|
||||
self.weak_transaction(|tx| async move {
|
||||
Ok(billing_customer::Entity::find()
|
||||
.filter(billing_customer::Column::StripeCustomerId.eq(stripe_customer_id))
|
||||
.one(&*tx)
|
||||
|
||||
@@ -22,7 +22,7 @@ impl Database {
|
||||
&self,
|
||||
user_id: UserId,
|
||||
) -> Result<Option<billing_preference::Model>> {
|
||||
self.transaction(|tx| async move {
|
||||
self.weak_transaction(|tx| async move {
|
||||
Ok(billing_preference::Entity::find()
|
||||
.filter(billing_preference::Column::UserId.eq(user_id))
|
||||
.one(&*tx)
|
||||
@@ -37,7 +37,7 @@ impl Database {
|
||||
user_id: UserId,
|
||||
params: &CreateBillingPreferencesParams,
|
||||
) -> Result<billing_preference::Model> {
|
||||
self.transaction(|tx| async move {
|
||||
self.weak_transaction(|tx| async move {
|
||||
let preferences = billing_preference::Entity::insert(billing_preference::ActiveModel {
|
||||
user_id: ActiveValue::set(user_id),
|
||||
max_monthly_llm_usage_spending_in_cents: ActiveValue::set(
|
||||
@@ -65,7 +65,7 @@ impl Database {
|
||||
user_id: UserId,
|
||||
params: &UpdateBillingPreferencesParams,
|
||||
) -> Result<billing_preference::Model> {
|
||||
self.transaction(|tx| async move {
|
||||
self.weak_transaction(|tx| async move {
|
||||
let preferences = billing_preference::Entity::update_many()
|
||||
.set(billing_preference::ActiveModel {
|
||||
max_monthly_llm_usage_spending_in_cents: params
|
||||
|
||||
@@ -35,7 +35,7 @@ impl Database {
|
||||
&self,
|
||||
params: &CreateBillingSubscriptionParams,
|
||||
) -> Result<billing_subscription::Model> {
|
||||
self.transaction(|tx| async move {
|
||||
self.weak_transaction(|tx| async move {
|
||||
let id = billing_subscription::Entity::insert(billing_subscription::ActiveModel {
|
||||
billing_customer_id: ActiveValue::set(params.billing_customer_id),
|
||||
kind: ActiveValue::set(params.kind),
|
||||
@@ -64,7 +64,7 @@ impl Database {
|
||||
id: BillingSubscriptionId,
|
||||
params: &UpdateBillingSubscriptionParams,
|
||||
) -> Result<()> {
|
||||
self.transaction(|tx| async move {
|
||||
self.weak_transaction(|tx| async move {
|
||||
billing_subscription::Entity::update(billing_subscription::ActiveModel {
|
||||
id: ActiveValue::set(id),
|
||||
billing_customer_id: params.billing_customer_id.clone(),
|
||||
@@ -90,7 +90,7 @@ impl Database {
|
||||
&self,
|
||||
id: BillingSubscriptionId,
|
||||
) -> Result<Option<billing_subscription::Model>> {
|
||||
self.transaction(|tx| async move {
|
||||
self.weak_transaction(|tx| async move {
|
||||
Ok(billing_subscription::Entity::find_by_id(id)
|
||||
.one(&*tx)
|
||||
.await?)
|
||||
@@ -103,7 +103,7 @@ impl Database {
|
||||
&self,
|
||||
stripe_subscription_id: &str,
|
||||
) -> Result<Option<billing_subscription::Model>> {
|
||||
self.transaction(|tx| async move {
|
||||
self.weak_transaction(|tx| async move {
|
||||
Ok(billing_subscription::Entity::find()
|
||||
.filter(
|
||||
billing_subscription::Column::StripeSubscriptionId.eq(stripe_subscription_id),
|
||||
@@ -118,7 +118,7 @@ impl Database {
|
||||
&self,
|
||||
user_id: UserId,
|
||||
) -> Result<Option<billing_subscription::Model>> {
|
||||
self.transaction(|tx| async move {
|
||||
self.weak_transaction(|tx| async move {
|
||||
Ok(billing_subscription::Entity::find()
|
||||
.inner_join(billing_customer::Entity)
|
||||
.filter(billing_customer::Column::UserId.eq(user_id))
|
||||
@@ -152,7 +152,7 @@ impl Database {
|
||||
&self,
|
||||
user_id: UserId,
|
||||
) -> Result<Vec<billing_subscription::Model>> {
|
||||
self.transaction(|tx| async move {
|
||||
self.weak_transaction(|tx| async move {
|
||||
let subscriptions = billing_subscription::Entity::find()
|
||||
.inner_join(billing_customer::Entity)
|
||||
.filter(billing_customer::Column::UserId.eq(user_id))
|
||||
@@ -169,7 +169,7 @@ impl Database {
|
||||
&self,
|
||||
user_ids: HashSet<UserId>,
|
||||
) -> Result<HashMap<UserId, (billing_customer::Model, billing_subscription::Model)>> {
|
||||
self.transaction(|tx| {
|
||||
self.weak_transaction(|tx| {
|
||||
let user_ids = user_ids.clone();
|
||||
async move {
|
||||
let mut rows = billing_subscription::Entity::find()
|
||||
@@ -201,7 +201,7 @@ impl Database {
|
||||
&self,
|
||||
user_ids: HashSet<UserId>,
|
||||
) -> Result<HashMap<UserId, (billing_customer::Model, billing_subscription::Model)>> {
|
||||
self.transaction(|tx| {
|
||||
self.weak_transaction(|tx| {
|
||||
let user_ids = user_ids.clone();
|
||||
async move {
|
||||
let mut rows = billing_subscription::Entity::find()
|
||||
@@ -236,7 +236,7 @@ impl Database {
|
||||
|
||||
/// Returns the count of the active billing subscriptions for the user with the specified ID.
|
||||
pub async fn count_active_billing_subscriptions(&self, user_id: UserId) -> Result<usize> {
|
||||
self.transaction(|tx| async move {
|
||||
self.weak_transaction(|tx| async move {
|
||||
let count = billing_subscription::Entity::find()
|
||||
.inner_join(billing_customer::Entity)
|
||||
.filter(
|
||||
|
||||
@@ -501,8 +501,10 @@ impl Database {
|
||||
|
||||
/// Returns all channels for the user with the given ID.
|
||||
pub async fn get_channels_for_user(&self, user_id: UserId) -> Result<ChannelsForUser> {
|
||||
self.transaction(|tx| async move { self.get_user_channels(user_id, None, true, &tx).await })
|
||||
.await
|
||||
self.weak_transaction(
|
||||
|tx| async move { self.get_user_channels(user_id, None, true, &tx).await },
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
/// Returns all channels for the user with the given ID that are descendants
|
||||
|
||||
@@ -15,7 +15,7 @@ impl Database {
|
||||
user_b_busy: bool,
|
||||
}
|
||||
|
||||
self.transaction(|tx| async move {
|
||||
self.weak_transaction(|tx| async move {
|
||||
let user_a_participant = Alias::new("user_a_participant");
|
||||
let user_b_participant = Alias::new("user_b_participant");
|
||||
let mut db_contacts = contact::Entity::find()
|
||||
@@ -91,7 +91,7 @@ impl Database {
|
||||
|
||||
/// Returns whether the given user is a busy (on a call).
|
||||
pub async fn is_user_busy(&self, user_id: UserId) -> Result<bool> {
|
||||
self.transaction(|tx| async move {
|
||||
self.weak_transaction(|tx| async move {
|
||||
let participant = room_participant::Entity::find()
|
||||
.filter(room_participant::Column::UserId.eq(user_id))
|
||||
.one(&*tx)
|
||||
|
||||
@@ -9,7 +9,7 @@ pub enum ContributorSelector {
|
||||
impl Database {
|
||||
/// Retrieves the GitHub logins of all users who have signed the CLA.
|
||||
pub async fn get_contributors(&self) -> Result<Vec<String>> {
|
||||
self.transaction(|tx| async move {
|
||||
self.weak_transaction(|tx| async move {
|
||||
#[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)]
|
||||
enum QueryGithubLogin {
|
||||
GithubLogin,
|
||||
@@ -32,7 +32,7 @@ impl Database {
|
||||
&self,
|
||||
selector: &ContributorSelector,
|
||||
) -> Result<Option<DateTime>> {
|
||||
self.transaction(|tx| async move {
|
||||
self.weak_transaction(|tx| async move {
|
||||
let condition = match selector {
|
||||
ContributorSelector::GitHubUserId { github_user_id } => {
|
||||
user::Column::GithubUserId.eq(*github_user_id)
|
||||
@@ -69,7 +69,7 @@ impl Database {
|
||||
github_user_created_at: DateTimeUtc,
|
||||
initial_channel_id: Option<ChannelId>,
|
||||
) -> Result<()> {
|
||||
self.transaction(|tx| async move {
|
||||
self.weak_transaction(|tx| async move {
|
||||
let user = self
|
||||
.update_or_create_user_by_github_account_tx(
|
||||
github_login,
|
||||
|
||||
@@ -8,7 +8,7 @@ impl Database {
|
||||
model: &str,
|
||||
digests: &[Vec<u8>],
|
||||
) -> Result<HashMap<Vec<u8>, Vec<f32>>> {
|
||||
self.transaction(|tx| async move {
|
||||
self.weak_transaction(|tx| async move {
|
||||
let embeddings = {
|
||||
let mut db_embeddings = embedding::Entity::find()
|
||||
.filter(
|
||||
@@ -52,7 +52,7 @@ impl Database {
|
||||
model: &str,
|
||||
embeddings: &HashMap<Vec<u8>, Vec<f32>>,
|
||||
) -> Result<()> {
|
||||
self.transaction(|tx| async move {
|
||||
self.weak_transaction(|tx| async move {
|
||||
embedding::Entity::insert_many(embeddings.iter().map(|(digest, dimensions)| {
|
||||
let now_offset_datetime = OffsetDateTime::now_utc();
|
||||
let retrieved_at =
|
||||
@@ -78,7 +78,7 @@ impl Database {
|
||||
}
|
||||
|
||||
pub async fn purge_old_embeddings(&self) -> Result<()> {
|
||||
self.transaction(|tx| async move {
|
||||
self.weak_transaction(|tx| async move {
|
||||
embedding::Entity::delete_many()
|
||||
.filter(
|
||||
embedding::Column::RetrievedAt
|
||||
|
||||
@@ -15,7 +15,7 @@ impl Database {
|
||||
max_schema_version: i32,
|
||||
limit: usize,
|
||||
) -> Result<Vec<ExtensionMetadata>> {
|
||||
self.transaction(|tx| async move {
|
||||
self.weak_transaction(|tx| async move {
|
||||
let mut condition = Condition::all()
|
||||
.add(
|
||||
extension::Column::LatestVersion
|
||||
@@ -43,7 +43,7 @@ impl Database {
|
||||
ids: &[&str],
|
||||
constraints: Option<&ExtensionVersionConstraints>,
|
||||
) -> Result<Vec<ExtensionMetadata>> {
|
||||
self.transaction(|tx| async move {
|
||||
self.weak_transaction(|tx| async move {
|
||||
let extensions = extension::Entity::find()
|
||||
.filter(extension::Column::ExternalId.is_in(ids.iter().copied()))
|
||||
.all(&*tx)
|
||||
@@ -123,7 +123,7 @@ impl Database {
|
||||
&self,
|
||||
extension_id: &str,
|
||||
) -> Result<Vec<ExtensionMetadata>> {
|
||||
self.transaction(|tx| async move {
|
||||
self.weak_transaction(|tx| async move {
|
||||
let condition = extension::Column::ExternalId
|
||||
.eq(extension_id)
|
||||
.into_condition();
|
||||
@@ -162,7 +162,7 @@ impl Database {
|
||||
extension_id: &str,
|
||||
constraints: Option<&ExtensionVersionConstraints>,
|
||||
) -> Result<Option<ExtensionMetadata>> {
|
||||
self.transaction(|tx| async move {
|
||||
self.weak_transaction(|tx| async move {
|
||||
let extension = extension::Entity::find()
|
||||
.filter(extension::Column::ExternalId.eq(extension_id))
|
||||
.one(&*tx)
|
||||
@@ -187,7 +187,7 @@ impl Database {
|
||||
extension_id: &str,
|
||||
version: &str,
|
||||
) -> Result<Option<ExtensionMetadata>> {
|
||||
self.transaction(|tx| async move {
|
||||
self.weak_transaction(|tx| async move {
|
||||
let extension = extension::Entity::find()
|
||||
.filter(extension::Column::ExternalId.eq(extension_id))
|
||||
.filter(extension_version::Column::Version.eq(version))
|
||||
@@ -204,7 +204,7 @@ impl Database {
|
||||
}
|
||||
|
||||
pub async fn get_known_extension_versions(&self) -> Result<HashMap<String, Vec<String>>> {
|
||||
self.transaction(|tx| async move {
|
||||
self.weak_transaction(|tx| async move {
|
||||
let mut extension_external_ids_by_id = HashMap::default();
|
||||
|
||||
let mut rows = extension::Entity::find().stream(&*tx).await?;
|
||||
@@ -242,7 +242,7 @@ impl Database {
|
||||
&self,
|
||||
versions_by_extension_id: &HashMap<&str, Vec<NewExtensionVersion>>,
|
||||
) -> Result<()> {
|
||||
self.transaction(|tx| async move {
|
||||
self.weak_transaction(|tx| async move {
|
||||
for (external_id, versions) in versions_by_extension_id {
|
||||
if versions.is_empty() {
|
||||
continue;
|
||||
@@ -349,7 +349,7 @@ impl Database {
|
||||
}
|
||||
|
||||
pub async fn record_extension_download(&self, extension: &str, version: &str) -> Result<bool> {
|
||||
self.transaction(|tx| async move {
|
||||
self.weak_transaction(|tx| async move {
|
||||
#[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)]
|
||||
enum QueryId {
|
||||
Id,
|
||||
|
||||
@@ -13,7 +13,7 @@ impl Database {
|
||||
&self,
|
||||
params: &CreateProcessedStripeEventParams,
|
||||
) -> Result<()> {
|
||||
self.transaction(|tx| async move {
|
||||
self.weak_transaction(|tx| async move {
|
||||
processed_stripe_event::Entity::insert(processed_stripe_event::ActiveModel {
|
||||
stripe_event_id: ActiveValue::set(params.stripe_event_id.clone()),
|
||||
stripe_event_type: ActiveValue::set(params.stripe_event_type.clone()),
|
||||
@@ -35,7 +35,7 @@ impl Database {
|
||||
&self,
|
||||
event_id: &str,
|
||||
) -> Result<Option<processed_stripe_event::Model>> {
|
||||
self.transaction(|tx| async move {
|
||||
self.weak_transaction(|tx| async move {
|
||||
Ok(processed_stripe_event::Entity::find_by_id(event_id)
|
||||
.one(&*tx)
|
||||
.await?)
|
||||
@@ -48,7 +48,7 @@ impl Database {
|
||||
&self,
|
||||
event_ids: &[&str],
|
||||
) -> Result<Vec<processed_stripe_event::Model>> {
|
||||
self.transaction(|tx| async move {
|
||||
self.weak_transaction(|tx| async move {
|
||||
Ok(processed_stripe_event::Entity::find()
|
||||
.filter(
|
||||
processed_stripe_event::Column::StripeEventId.is_in(event_ids.iter().copied()),
|
||||
|
||||
@@ -112,7 +112,7 @@ impl Database {
|
||||
}
|
||||
|
||||
pub async fn delete_project(&self, project_id: ProjectId) -> Result<()> {
|
||||
self.transaction(|tx| async move {
|
||||
self.weak_transaction(|tx| async move {
|
||||
project::Entity::delete_by_id(project_id).exec(&*tx).await?;
|
||||
Ok(())
|
||||
})
|
||||
|
||||
@@ -80,7 +80,7 @@ impl Database {
|
||||
&self,
|
||||
user_id: UserId,
|
||||
) -> Result<Option<proto::IncomingCall>> {
|
||||
self.transaction(|tx| async move {
|
||||
self.weak_transaction(|tx| async move {
|
||||
let pending_participant = room_participant::Entity::find()
|
||||
.filter(
|
||||
room_participant::Column::UserId
|
||||
|
||||
@@ -142,50 +142,6 @@ impl Database {
|
||||
}
|
||||
}
|
||||
|
||||
loop {
|
||||
let delete_query = Query::delete()
|
||||
.from_table(project_repository_statuses::Entity)
|
||||
.and_where(
|
||||
Expr::tuple([Expr::col((
|
||||
project_repository_statuses::Entity,
|
||||
project_repository_statuses::Column::ProjectId,
|
||||
))
|
||||
.into()])
|
||||
.in_subquery(
|
||||
Query::select()
|
||||
.columns([(
|
||||
project_repository_statuses::Entity,
|
||||
project_repository_statuses::Column::ProjectId,
|
||||
)])
|
||||
.from(project_repository_statuses::Entity)
|
||||
.inner_join(
|
||||
project::Entity,
|
||||
Expr::col((project::Entity, project::Column::Id)).equals((
|
||||
project_repository_statuses::Entity,
|
||||
project_repository_statuses::Column::ProjectId,
|
||||
)),
|
||||
)
|
||||
.and_where(project::Column::HostConnectionServerId.ne(server_id))
|
||||
.limit(10000)
|
||||
.to_owned(),
|
||||
),
|
||||
)
|
||||
.to_owned();
|
||||
|
||||
let statement = Statement::from_sql_and_values(
|
||||
tx.get_database_backend(),
|
||||
delete_query
|
||||
.to_string(sea_orm::sea_query::PostgresQueryBuilder)
|
||||
.as_str(),
|
||||
vec![],
|
||||
);
|
||||
|
||||
let result = tx.execute(statement).await?;
|
||||
if result.rows_affected() == 0 {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
})
|
||||
.await
|
||||
|
||||
@@ -382,7 +382,7 @@ impl Database {
|
||||
|
||||
/// Returns the active flags for the user.
|
||||
pub async fn get_user_flags(&self, user: UserId) -> Result<Vec<String>> {
|
||||
self.transaction(|tx| async move {
|
||||
self.weak_transaction(|tx| async move {
|
||||
#[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)]
|
||||
enum QueryAs {
|
||||
Flag,
|
||||
|
||||
@@ -17,15 +17,11 @@ use crate::migrations::run_database_migrations;
|
||||
use super::*;
|
||||
use gpui::BackgroundExecutor;
|
||||
use parking_lot::Mutex;
|
||||
use rand::prelude::*;
|
||||
use sea_orm::ConnectionTrait;
|
||||
use sqlx::migrate::MigrateDatabase;
|
||||
use std::{
|
||||
sync::{
|
||||
Arc,
|
||||
atomic::{AtomicI32, AtomicU32, Ordering::SeqCst},
|
||||
},
|
||||
time::Duration,
|
||||
use std::sync::{
|
||||
Arc,
|
||||
atomic::{AtomicI32, AtomicU32, Ordering::SeqCst},
|
||||
};
|
||||
|
||||
pub struct TestDb {
|
||||
@@ -45,7 +41,9 @@ impl TestDb {
|
||||
let mut db = runtime.block_on(async {
|
||||
let mut options = ConnectOptions::new(url);
|
||||
options.max_connections(5);
|
||||
let mut db = Database::new(options).await.unwrap();
|
||||
let mut db = Database::new(options, Executor::Deterministic(executor.clone()))
|
||||
.await
|
||||
.unwrap();
|
||||
let sql = include_str!(concat!(
|
||||
env!("CARGO_MANIFEST_DIR"),
|
||||
"/migrations.sqlite/20221109000000_test_schema.sql"
|
||||
@@ -62,7 +60,6 @@ impl TestDb {
|
||||
});
|
||||
|
||||
db.test_options = Some(DatabaseTestOptions {
|
||||
executor,
|
||||
runtime,
|
||||
query_failure_probability: parking_lot::Mutex::new(0.0),
|
||||
});
|
||||
@@ -96,7 +93,9 @@ impl TestDb {
|
||||
options
|
||||
.max_connections(5)
|
||||
.idle_timeout(Duration::from_secs(0));
|
||||
let mut db = Database::new(options).await.unwrap();
|
||||
let mut db = Database::new(options, Executor::Deterministic(executor.clone()))
|
||||
.await
|
||||
.unwrap();
|
||||
let migrations_path = concat!(env!("CARGO_MANIFEST_DIR"), "/migrations");
|
||||
run_database_migrations(db.options(), migrations_path)
|
||||
.await
|
||||
@@ -106,7 +105,6 @@ impl TestDb {
|
||||
});
|
||||
|
||||
db.test_options = Some(DatabaseTestOptions {
|
||||
executor,
|
||||
runtime,
|
||||
query_failure_probability: parking_lot::Mutex::new(0.0),
|
||||
});
|
||||
|
||||
@@ -49,7 +49,7 @@ async fn test_purge_old_embeddings(cx: &mut gpui::TestAppContext) {
|
||||
db.save_embeddings(model, &embeddings).await.unwrap();
|
||||
|
||||
// Reach into the DB and change the retrieved at to be > 60 days
|
||||
db.transaction(|tx| {
|
||||
db.weak_transaction(|tx| {
|
||||
let digest = digest.clone();
|
||||
async move {
|
||||
let sixty_days_ago = OffsetDateTime::now_utc().sub(Duration::days(61));
|
||||
|
||||
@@ -285,7 +285,7 @@ impl AppState {
|
||||
pub async fn new(config: Config, executor: Executor) -> Result<Arc<Self>> {
|
||||
let mut db_options = db::ConnectOptions::new(config.database_url.clone());
|
||||
db_options.max_connections(config.database_max_connections);
|
||||
let mut db = Database::new(db_options).await?;
|
||||
let mut db = Database::new(db_options, Executor::Production).await?;
|
||||
db.initialize_notification_kinds().await?;
|
||||
|
||||
let llm_db = if let Some((llm_database_url, llm_database_max_connections)) = config
|
||||
|
||||
@@ -59,7 +59,7 @@ async fn main() -> Result<()> {
|
||||
let config = envy::from_env::<Config>().expect("error loading config");
|
||||
let db_options = db::ConnectOptions::new(config.database_url.clone());
|
||||
|
||||
let mut db = Database::new(db_options).await?;
|
||||
let mut db = Database::new(db_options, Executor::Production).await?;
|
||||
db.initialize_notification_kinds().await?;
|
||||
|
||||
collab::seed::seed(&config, &db, false).await?;
|
||||
@@ -253,7 +253,7 @@ async fn main() -> Result<()> {
|
||||
|
||||
async fn setup_app_database(config: &Config) -> Result<()> {
|
||||
let db_options = db::ConnectOptions::new(config.database_url.clone());
|
||||
let mut db = Database::new(db_options).await?;
|
||||
let mut db = Database::new(db_options, Executor::Production).await?;
|
||||
|
||||
let migrations_path = config.migrations_path.as_deref().unwrap_or_else(|| {
|
||||
#[cfg(feature = "sqlite")]
|
||||
|
||||
@@ -22,9 +22,7 @@ use gpui::{
|
||||
use language::{
|
||||
Diagnostic, DiagnosticEntry, DiagnosticSourceKind, FakeLspAdapter, Language, LanguageConfig,
|
||||
LanguageMatcher, LineEnding, OffsetRangeExt, Point, Rope,
|
||||
language_settings::{
|
||||
AllLanguageSettings, Formatter, FormatterList, PrettierSettings, SelectedFormatter,
|
||||
},
|
||||
language_settings::{AllLanguageSettings, Formatter, PrettierSettings, SelectedFormatter},
|
||||
tree_sitter_rust, tree_sitter_typescript,
|
||||
};
|
||||
use lsp::{LanguageServerId, OneOf};
|
||||
@@ -4591,14 +4589,13 @@ async fn test_formatting_buffer(
|
||||
cx_a.update(|cx| {
|
||||
SettingsStore::update_global(cx, |store, cx| {
|
||||
store.update_user_settings::<AllLanguageSettings>(cx, |file| {
|
||||
file.defaults.formatter = Some(SelectedFormatter::List(FormatterList::Single(
|
||||
Formatter::External {
|
||||
file.defaults.formatter =
|
||||
Some(SelectedFormatter::List(vec![Formatter::External {
|
||||
command: "awk".into(),
|
||||
arguments: Some(
|
||||
vec!["{sub(/two/,\"{buffer_path}\")}1".to_string()].into(),
|
||||
),
|
||||
},
|
||||
)));
|
||||
}]));
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -4698,9 +4695,10 @@ async fn test_prettier_formatting_buffer(
|
||||
cx_b.update(|cx| {
|
||||
SettingsStore::update_global(cx, |store, cx| {
|
||||
store.update_user_settings::<AllLanguageSettings>(cx, |file| {
|
||||
file.defaults.formatter = Some(SelectedFormatter::List(FormatterList::Single(
|
||||
Formatter::LanguageServer { name: None },
|
||||
)));
|
||||
file.defaults.formatter =
|
||||
Some(SelectedFormatter::List(vec![Formatter::LanguageServer {
|
||||
name: None,
|
||||
}]));
|
||||
file.defaults.prettier = Some(PrettierSettings {
|
||||
allowed: true,
|
||||
..PrettierSettings::default()
|
||||
@@ -4821,7 +4819,7 @@ async fn test_definition(
|
||||
);
|
||||
|
||||
let definitions_1 = project_b
|
||||
.update(cx_b, |p, cx| p.definitions(&buffer_b, 23, cx))
|
||||
.update(cx_b, |p, cx| p.definition(&buffer_b, 23, cx))
|
||||
.await
|
||||
.unwrap();
|
||||
cx_b.read(|cx| {
|
||||
@@ -4852,7 +4850,7 @@ async fn test_definition(
|
||||
);
|
||||
|
||||
let definitions_2 = project_b
|
||||
.update(cx_b, |p, cx| p.definitions(&buffer_b, 33, cx))
|
||||
.update(cx_b, |p, cx| p.definition(&buffer_b, 33, cx))
|
||||
.await
|
||||
.unwrap();
|
||||
cx_b.read(|cx| {
|
||||
@@ -4889,7 +4887,7 @@ async fn test_definition(
|
||||
);
|
||||
|
||||
let type_definitions = project_b
|
||||
.update(cx_b, |p, cx| p.type_definitions(&buffer_b, 7, cx))
|
||||
.update(cx_b, |p, cx| p.type_definition(&buffer_b, 7, cx))
|
||||
.await
|
||||
.unwrap();
|
||||
cx_b.read(|cx| {
|
||||
@@ -5057,7 +5055,7 @@ async fn test_references(
|
||||
lsp_response_tx
|
||||
.unbounded_send(Err(anyhow!("can't find references")))
|
||||
.unwrap();
|
||||
assert_eq!(references.await.unwrap(), []);
|
||||
references.await.unwrap_err();
|
||||
|
||||
// User is informed that the request is no longer pending.
|
||||
executor.run_until_parked();
|
||||
@@ -5641,7 +5639,7 @@ async fn test_open_buffer_while_getting_definition_pointing_to_it(
|
||||
let definitions;
|
||||
let buffer_b2;
|
||||
if rng.r#gen() {
|
||||
definitions = project_b.update(cx_b, |p, cx| p.definitions(&buffer_b1, 23, cx));
|
||||
definitions = project_b.update(cx_b, |p, cx| p.definition(&buffer_b1, 23, cx));
|
||||
(buffer_b2, _) = project_b
|
||||
.update(cx_b, |p, cx| {
|
||||
p.open_buffer_with_lsp((worktree_id, "b.rs"), cx)
|
||||
@@ -5655,7 +5653,7 @@ async fn test_open_buffer_while_getting_definition_pointing_to_it(
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
definitions = project_b.update(cx_b, |p, cx| p.definitions(&buffer_b1, 23, cx));
|
||||
definitions = project_b.update(cx_b, |p, cx| p.definition(&buffer_b1, 23, cx));
|
||||
}
|
||||
|
||||
let definitions = definitions.await.unwrap();
|
||||
|
||||
@@ -838,7 +838,7 @@ impl RandomizedTest for ProjectCollaborationTest {
|
||||
.map(|_| Ok(()))
|
||||
.boxed(),
|
||||
LspRequestKind::Definition => project
|
||||
.definitions(&buffer, offset, cx)
|
||||
.definition(&buffer, offset, cx)
|
||||
.map_ok(|_| ())
|
||||
.boxed(),
|
||||
LspRequestKind::Highlights => project
|
||||
|
||||
@@ -14,8 +14,7 @@ use http_client::BlockedHttpClient;
|
||||
use language::{
|
||||
FakeLspAdapter, Language, LanguageConfig, LanguageMatcher, LanguageRegistry,
|
||||
language_settings::{
|
||||
AllLanguageSettings, Formatter, FormatterList, PrettierSettings, SelectedFormatter,
|
||||
language_settings,
|
||||
AllLanguageSettings, Formatter, PrettierSettings, SelectedFormatter, language_settings,
|
||||
},
|
||||
tree_sitter_typescript,
|
||||
};
|
||||
@@ -505,9 +504,10 @@ async fn test_ssh_collaboration_formatting_with_prettier(
|
||||
cx_b.update(|cx| {
|
||||
SettingsStore::update_global(cx, |store, cx| {
|
||||
store.update_user_settings::<AllLanguageSettings>(cx, |file| {
|
||||
file.defaults.formatter = Some(SelectedFormatter::List(FormatterList::Single(
|
||||
Formatter::LanguageServer { name: None },
|
||||
)));
|
||||
file.defaults.formatter =
|
||||
Some(SelectedFormatter::List(vec![Formatter::LanguageServer {
|
||||
name: None,
|
||||
}]));
|
||||
file.defaults.prettier = Some(PrettierSettings {
|
||||
allowed: true,
|
||||
..PrettierSettings::default()
|
||||
|
||||
@@ -30,13 +30,7 @@ use workspace::{
|
||||
};
|
||||
use workspace::{item::Dedup, notifications::NotificationId};
|
||||
|
||||
actions!(
|
||||
collab,
|
||||
[
|
||||
/// Copies a link to the current position in the channel buffer.
|
||||
CopyLink
|
||||
]
|
||||
);
|
||||
actions!(collab, [CopyLink]);
|
||||
|
||||
pub fn init(cx: &mut App) {
|
||||
workspace::FollowableViewRegistry::register::<ChannelView>(cx)
|
||||
|
||||
@@ -71,13 +71,7 @@ struct SerializedChatPanel {
|
||||
width: Option<Pixels>,
|
||||
}
|
||||
|
||||
actions!(
|
||||
chat_panel,
|
||||
[
|
||||
/// Toggles focus on the chat panel.
|
||||
ToggleFocus
|
||||
]
|
||||
);
|
||||
actions!(chat_panel, [ToggleFocus]);
|
||||
|
||||
impl ChatPanel {
|
||||
pub fn new(
|
||||
|
||||
@@ -44,25 +44,15 @@ use workspace::{
|
||||
actions!(
|
||||
collab_panel,
|
||||
[
|
||||
/// Toggles focus on the collaboration panel.
|
||||
ToggleFocus,
|
||||
/// Removes the selected channel or contact.
|
||||
Remove,
|
||||
/// Opens the context menu for the selected item.
|
||||
Secondary,
|
||||
/// Collapses the selected channel in the tree view.
|
||||
CollapseSelectedChannel,
|
||||
/// Expands the selected channel in the tree view.
|
||||
ExpandSelectedChannel,
|
||||
/// Starts moving a channel to a new location.
|
||||
StartMoveChannel,
|
||||
/// Moves the selected item to the current location.
|
||||
MoveSelected,
|
||||
/// Inserts a space character in the filter input.
|
||||
InsertSpace,
|
||||
/// Moves the selected channel up in the list.
|
||||
MoveChannelUp,
|
||||
/// Moves the selected channel down in the list.
|
||||
MoveChannelDown,
|
||||
]
|
||||
);
|
||||
|
||||
@@ -17,13 +17,9 @@ use workspace::{ModalView, notifications::DetachAndPromptErr};
|
||||
actions!(
|
||||
channel_modal,
|
||||
[
|
||||
/// Selects the next control in the channel modal.
|
||||
SelectNextControl,
|
||||
/// Toggles between invite members and manage members mode.
|
||||
ToggleMode,
|
||||
/// Toggles admin status for the selected member.
|
||||
ToggleMemberAdmin,
|
||||
/// Removes the selected member from the channel.
|
||||
RemoveMember
|
||||
]
|
||||
);
|
||||
|
||||
@@ -74,13 +74,7 @@ pub struct NotificationPresenter {
|
||||
pub can_navigate: bool,
|
||||
}
|
||||
|
||||
actions!(
|
||||
notification_panel,
|
||||
[
|
||||
/// Toggles focus on the notification panel.
|
||||
ToggleFocus
|
||||
]
|
||||
);
|
||||
actions!(notification_panel, [ToggleFocus]);
|
||||
|
||||
pub fn init(cx: &mut App) {
|
||||
cx.observe_new(|workspace: &mut Workspace, _, _| {
|
||||
|
||||
@@ -61,7 +61,7 @@ impl RenderOnce for ComponentExample {
|
||||
12.0,
|
||||
12.0,
|
||||
))
|
||||
.shadow_xs()
|
||||
.shadow_sm()
|
||||
.child(self.element),
|
||||
)
|
||||
.into_any_element()
|
||||
|
||||
@@ -46,17 +46,11 @@ pub use crate::sign_in::{CopilotCodeVerification, initiate_sign_in, reinstall_an
|
||||
actions!(
|
||||
copilot,
|
||||
[
|
||||
/// Requests a code completion suggestion from Copilot.
|
||||
Suggest,
|
||||
/// Cycles to the next Copilot suggestion.
|
||||
NextSuggestion,
|
||||
/// Cycles to the previous Copilot suggestion.
|
||||
PreviousSuggestion,
|
||||
/// Reinstalls the Copilot language server.
|
||||
Reinstall,
|
||||
/// Signs in to GitHub Copilot.
|
||||
SignIn,
|
||||
/// Signs out of GitHub Copilot.
|
||||
SignOut
|
||||
]
|
||||
);
|
||||
|
||||
@@ -79,9 +79,9 @@ impl JsDebugAdapter {
|
||||
let command = configuration.get("command")?.as_str()?.to_owned();
|
||||
let mut args = shlex::split(&command)?.into_iter();
|
||||
let program = args.next()?;
|
||||
configuration.insert("runtimeExecutable".to_owned(), program.into());
|
||||
configuration.insert("program".to_owned(), program.into());
|
||||
configuration.insert(
|
||||
"runtimeArgs".to_owned(),
|
||||
"args".to_owned(),
|
||||
args.map(Value::from).collect::<Vec<_>>().into(),
|
||||
);
|
||||
configuration.insert("console".to_owned(), "externalTerminal".into());
|
||||
@@ -522,11 +522,7 @@ impl DebugAdapter for JsDebugAdapter {
|
||||
}
|
||||
|
||||
fn label_for_child_session(&self, args: &StartDebuggingRequestArguments) -> Option<String> {
|
||||
let label = args
|
||||
.configuration
|
||||
.get("name")?
|
||||
.as_str()
|
||||
.filter(|name| !name.is_empty())?;
|
||||
let label = args.configuration.get("name")?.as_str()?;
|
||||
Some(label.to_owned())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -660,15 +660,6 @@ impl DebugAdapter for PythonDebugAdapter {
|
||||
self.get_installed_binary(delegate, &config, None, user_args, toolchain, false)
|
||||
.await
|
||||
}
|
||||
|
||||
fn label_for_child_session(&self, args: &StartDebuggingRequestArguments) -> Option<String> {
|
||||
let label = args
|
||||
.configuration
|
||||
.get("name")?
|
||||
.as_str()
|
||||
.filter(|label| !label.is_empty())?;
|
||||
Some(label.to_owned())
|
||||
}
|
||||
}
|
||||
|
||||
async fn fetch_latest_adapter_version_from_github(
|
||||
|
||||
@@ -918,13 +918,7 @@ impl Render for DapLogView {
|
||||
}
|
||||
}
|
||||
|
||||
actions!(
|
||||
dev,
|
||||
[
|
||||
/// Opens the debug adapter protocol logs viewer.
|
||||
OpenDebugAdapterLogs
|
||||
]
|
||||
);
|
||||
actions!(dev, [OpenDebugAdapterLogs]);
|
||||
|
||||
pub fn init(cx: &mut App) {
|
||||
let log_store = cx.new(|cx| LogStore::new(cx));
|
||||
|
||||
@@ -5,7 +5,7 @@ use crate::session::running::breakpoint_list::BreakpointList;
|
||||
use crate::{
|
||||
ClearAllBreakpoints, Continue, CopyDebugAdapterArguments, Detach, FocusBreakpointList,
|
||||
FocusConsole, FocusFrames, FocusLoadedSources, FocusModules, FocusTerminal, FocusVariables,
|
||||
NewProcessModal, NewProcessMode, Pause, RerunSession, StepInto, StepOut, StepOver, Stop,
|
||||
NewProcessModal, NewProcessMode, Pause, Restart, StepInto, StepOut, StepOver, Stop,
|
||||
ToggleExpandItem, ToggleSessionPicker, ToggleThreadPicker, persistence, spawn_task_or_modal,
|
||||
};
|
||||
use anyhow::{Context as _, Result, anyhow};
|
||||
@@ -25,7 +25,7 @@ use gpui::{
|
||||
use itertools::Itertools as _;
|
||||
use language::Buffer;
|
||||
use project::debugger::session::{Session, SessionStateEvent};
|
||||
use project::{DebugScenarioContext, Fs, ProjectPath, WorktreeId};
|
||||
use project::{Fs, ProjectPath, WorktreeId};
|
||||
use project::{Project, debugger::session::ThreadStatus};
|
||||
use rpc::proto::{self};
|
||||
use settings::Settings;
|
||||
@@ -197,7 +197,6 @@ impl DebugPanel {
|
||||
.and_then(|buffer| buffer.read(cx).file())
|
||||
.map(|f| f.worktree_id(cx))
|
||||
});
|
||||
|
||||
let Some(worktree) = worktree
|
||||
.and_then(|id| self.project.read(cx).worktree_for_id(id, cx))
|
||||
.or_else(|| self.project.read(cx).visible_worktrees(cx).next())
|
||||
@@ -205,7 +204,6 @@ impl DebugPanel {
|
||||
log::debug!("Could not find a worktree to spawn the debug session in");
|
||||
return;
|
||||
};
|
||||
|
||||
self.debug_scenario_scheduled_last = true;
|
||||
if let Some(inventory) = self
|
||||
.project
|
||||
@@ -216,15 +214,7 @@ impl DebugPanel {
|
||||
.cloned()
|
||||
{
|
||||
inventory.update(cx, |inventory, _| {
|
||||
inventory.scenario_scheduled(
|
||||
scenario.clone(),
|
||||
// todo(debugger): Task context is cloned three times
|
||||
// once in Session,inventory, and in resolve scenario
|
||||
// we should wrap it in an RC instead to save some memory
|
||||
task_context.clone(),
|
||||
worktree_id,
|
||||
active_buffer.as_ref().map(|buffer| buffer.downgrade()),
|
||||
);
|
||||
inventory.scenario_scheduled(scenario.clone());
|
||||
})
|
||||
}
|
||||
let task = cx.spawn_in(window, {
|
||||
@@ -235,16 +225,6 @@ impl DebugPanel {
|
||||
let definition = debug_session
|
||||
.update_in(cx, |debug_session, window, cx| {
|
||||
debug_session.running_state().update(cx, |running, cx| {
|
||||
if scenario.build.is_some() {
|
||||
running.scenario = Some(scenario.clone());
|
||||
running.scenario_context = Some(DebugScenarioContext {
|
||||
active_buffer: active_buffer
|
||||
.as_ref()
|
||||
.map(|entity| entity.downgrade()),
|
||||
task_context: task_context.clone(),
|
||||
worktree_id: worktree_id,
|
||||
});
|
||||
};
|
||||
running.resolve_scenario(
|
||||
scenario,
|
||||
task_context,
|
||||
@@ -293,8 +273,7 @@ impl DebugPanel {
|
||||
return;
|
||||
};
|
||||
let workspace = self.workspace.clone();
|
||||
let Some((scenario, context)) = task_inventory.read(cx).last_scheduled_scenario().cloned()
|
||||
else {
|
||||
let Some(scenario) = task_inventory.read(cx).last_scheduled_scenario().cloned() else {
|
||||
window.defer(cx, move |window, cx| {
|
||||
workspace
|
||||
.update(cx, |workspace, cx| {
|
||||
@@ -305,22 +284,28 @@ impl DebugPanel {
|
||||
return;
|
||||
};
|
||||
|
||||
let DebugScenarioContext {
|
||||
task_context,
|
||||
worktree_id,
|
||||
active_buffer,
|
||||
} = context;
|
||||
cx.spawn_in(window, async move |this, cx| {
|
||||
let task_contexts = workspace
|
||||
.update_in(cx, |workspace, window, cx| {
|
||||
tasks_ui::task_contexts(workspace, window, cx)
|
||||
})?
|
||||
.await;
|
||||
|
||||
let active_buffer = active_buffer.and_then(|buffer| buffer.upgrade());
|
||||
let task_context = task_contexts.active_context().cloned().unwrap_or_default();
|
||||
let worktree_id = task_contexts.worktree();
|
||||
|
||||
self.start_session(
|
||||
scenario,
|
||||
task_context,
|
||||
active_buffer,
|
||||
worktree_id,
|
||||
window,
|
||||
cx,
|
||||
);
|
||||
this.update_in(cx, |this, window, cx| {
|
||||
this.start_session(
|
||||
scenario.clone(),
|
||||
task_context,
|
||||
None,
|
||||
worktree_id,
|
||||
window,
|
||||
cx,
|
||||
);
|
||||
})
|
||||
})
|
||||
.detach();
|
||||
}
|
||||
|
||||
pub(crate) async fn register_session(
|
||||
@@ -773,16 +758,16 @@ impl DebugPanel {
|
||||
.icon_size(IconSize::XSmall)
|
||||
.on_click(window.listener_for(
|
||||
&running_state,
|
||||
|this, _, window, cx| {
|
||||
this.rerun_session(window, cx);
|
||||
|this, _, _window, cx| {
|
||||
this.restart_session(cx);
|
||||
},
|
||||
))
|
||||
.tooltip({
|
||||
let focus_handle = focus_handle.clone();
|
||||
move |window, cx| {
|
||||
Tooltip::for_action_in(
|
||||
"Rerun Session",
|
||||
&RerunSession,
|
||||
"Restart",
|
||||
&Restart,
|
||||
&focus_handle,
|
||||
window,
|
||||
cx,
|
||||
@@ -1313,13 +1298,11 @@ impl Render for DebugPanel {
|
||||
}
|
||||
|
||||
v_flex()
|
||||
.when(!self.is_zoomed, |this| {
|
||||
this.when_else(
|
||||
self.position(window, cx) == DockPosition::Bottom,
|
||||
|this| this.max_h(self.size),
|
||||
|this| this.max_w(self.size),
|
||||
)
|
||||
})
|
||||
.when_else(
|
||||
self.position(window, cx) == DockPosition::Bottom,
|
||||
|this| this.max_h(self.size),
|
||||
|this| this.max_w(self.size),
|
||||
)
|
||||
.size_full()
|
||||
.key_context("DebugPanel")
|
||||
.child(h_flex().children(self.top_controls_strip(window, cx)))
|
||||
@@ -1617,13 +1600,12 @@ impl workspace::DebuggerProvider for DebuggerProvider {
|
||||
definition: DebugScenario,
|
||||
context: TaskContext,
|
||||
buffer: Option<Entity<Buffer>>,
|
||||
worktree_id: Option<WorktreeId>,
|
||||
window: &mut Window,
|
||||
cx: &mut App,
|
||||
) {
|
||||
self.0.update(cx, |_, cx| {
|
||||
cx.defer_in(window, move |this, window, cx| {
|
||||
this.start_session(definition, context, buffer, worktree_id, window, cx);
|
||||
cx.defer_in(window, |this, window, cx| {
|
||||
this.start_session(definition, context, buffer, None, window, cx);
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
@@ -32,67 +32,34 @@ pub mod tests;
|
||||
actions!(
|
||||
debugger,
|
||||
[
|
||||
/// Starts a new debugging session.
|
||||
Start,
|
||||
/// Continues execution until the next breakpoint.
|
||||
Continue,
|
||||
/// Detaches the debugger from the running process.
|
||||
Detach,
|
||||
/// Pauses the currently running program.
|
||||
Pause,
|
||||
/// Restarts the current debugging session.
|
||||
Restart,
|
||||
/// Reruns the current debugging session with the same configuration.
|
||||
RerunSession,
|
||||
/// Steps into the next function call.
|
||||
StepInto,
|
||||
/// Steps over the current line.
|
||||
StepOver,
|
||||
/// Steps out of the current function.
|
||||
StepOut,
|
||||
/// Steps back to the previous statement.
|
||||
StepBack,
|
||||
/// Stops the debugging session.
|
||||
Stop,
|
||||
/// Toggles whether to ignore all breakpoints.
|
||||
ToggleIgnoreBreakpoints,
|
||||
/// Clears all breakpoints in the project.
|
||||
ClearAllBreakpoints,
|
||||
/// Focuses on the debugger console panel.
|
||||
FocusConsole,
|
||||
/// Focuses on the variables panel.
|
||||
FocusVariables,
|
||||
/// Focuses on the breakpoint list panel.
|
||||
FocusBreakpointList,
|
||||
/// Focuses on the call stack frames panel.
|
||||
FocusFrames,
|
||||
/// Focuses on the loaded modules panel.
|
||||
FocusModules,
|
||||
/// Focuses on the loaded sources panel.
|
||||
FocusLoadedSources,
|
||||
/// Focuses on the terminal panel.
|
||||
FocusTerminal,
|
||||
/// Shows the stack trace for the current thread.
|
||||
ShowStackTrace,
|
||||
/// Toggles the thread picker dropdown.
|
||||
ToggleThreadPicker,
|
||||
/// Toggles the session picker dropdown.
|
||||
ToggleSessionPicker,
|
||||
/// Reruns the last debugging session.
|
||||
#[action(deprecated_aliases = ["debugger::RerunLastSession"])]
|
||||
Rerun,
|
||||
/// Toggles expansion of the selected item in the debugger UI.
|
||||
RerunLastSession,
|
||||
ToggleExpandItem,
|
||||
]
|
||||
);
|
||||
|
||||
actions!(
|
||||
dev,
|
||||
[
|
||||
/// Copies debug adapter launch arguments to clipboard.
|
||||
CopyDebugAdapterArguments
|
||||
]
|
||||
);
|
||||
actions!(dev, [CopyDebugAdapterArguments]);
|
||||
|
||||
pub fn init(cx: &mut App) {
|
||||
DebuggerSettings::register(cx);
|
||||
@@ -107,15 +74,17 @@ pub fn init(cx: &mut App) {
|
||||
.register_action(|workspace: &mut Workspace, _: &Start, window, cx| {
|
||||
NewProcessModal::show(workspace, window, NewProcessMode::Debug, None, cx);
|
||||
})
|
||||
.register_action(|workspace: &mut Workspace, _: &Rerun, window, cx| {
|
||||
let Some(debug_panel) = workspace.panel::<DebugPanel>(cx) else {
|
||||
return;
|
||||
};
|
||||
.register_action(
|
||||
|workspace: &mut Workspace, _: &RerunLastSession, window, cx| {
|
||||
let Some(debug_panel) = workspace.panel::<DebugPanel>(cx) else {
|
||||
return;
|
||||
};
|
||||
|
||||
debug_panel.update(cx, |debug_panel, cx| {
|
||||
debug_panel.rerun_last_session(workspace, window, cx);
|
||||
})
|
||||
})
|
||||
debug_panel.update(cx, |debug_panel, cx| {
|
||||
debug_panel.rerun_last_session(workspace, window, cx);
|
||||
})
|
||||
},
|
||||
)
|
||||
.register_action(
|
||||
|workspace: &mut Workspace, _: &ShutdownDebugAdapters, _window, cx| {
|
||||
workspace.project().update(cx, |project, cx| {
|
||||
@@ -241,14 +210,6 @@ pub fn init(cx: &mut App) {
|
||||
.ok();
|
||||
}
|
||||
})
|
||||
.on_action({
|
||||
let active_item = active_item.clone();
|
||||
move |_: &RerunSession, window, cx| {
|
||||
active_item
|
||||
.update(cx, |item, cx| item.rerun_session(window, cx))
|
||||
.ok();
|
||||
}
|
||||
})
|
||||
.on_action({
|
||||
let active_item = active_item.clone();
|
||||
move |_: &Stop, _, cx| {
|
||||
|
||||
@@ -4,7 +4,6 @@ use collections::HashMap;
|
||||
use gpui::{Animation, AnimationExt as _, Entity, Transformation, percentage};
|
||||
use project::debugger::session::{ThreadId, ThreadStatus};
|
||||
use ui::{ContextMenu, DropdownMenu, DropdownStyle, Indicator, prelude::*};
|
||||
use util::truncate_and_trailoff;
|
||||
|
||||
use crate::{
|
||||
debugger_panel::DebugPanel,
|
||||
@@ -13,8 +12,6 @@ use crate::{
|
||||
|
||||
impl DebugPanel {
|
||||
fn dropdown_label(label: impl Into<SharedString>) -> Label {
|
||||
const MAX_LABEL_CHARS: usize = 50;
|
||||
let label = truncate_and_trailoff(&label.into(), MAX_LABEL_CHARS);
|
||||
Label::new(label).size(LabelSize::Small)
|
||||
}
|
||||
|
||||
@@ -173,8 +170,6 @@ impl DebugPanel {
|
||||
window: &mut Window,
|
||||
cx: &mut Context<Self>,
|
||||
) -> Option<DropdownMenu> {
|
||||
const MAX_LABEL_CHARS: usize = 150;
|
||||
|
||||
let running_state = running_state.clone();
|
||||
let running_state_read = running_state.read(cx);
|
||||
let thread_id = running_state_read.thread_id();
|
||||
@@ -207,7 +202,6 @@ impl DebugPanel {
|
||||
.is_empty()
|
||||
.then(|| format!("Tid: {}", thread.id))
|
||||
.unwrap_or_else(|| thread.name);
|
||||
let entry_name = truncate_and_trailoff(&entry_name, MAX_LABEL_CHARS);
|
||||
|
||||
this = this.entry(entry_name, None, move |window, cx| {
|
||||
running_state.update(cx, |running_state, cx| {
|
||||
|
||||
@@ -23,9 +23,7 @@ use gpui::{
|
||||
};
|
||||
use itertools::Itertools as _;
|
||||
use picker::{Picker, PickerDelegate, highlighted_match_with_paths::HighlightedMatch};
|
||||
use project::{
|
||||
DebugScenarioContext, ProjectPath, TaskContexts, TaskSourceKind, task_store::TaskStore,
|
||||
};
|
||||
use project::{ProjectPath, TaskContexts, TaskSourceKind, task_store::TaskStore};
|
||||
use settings::{Settings, initial_local_debug_tasks_content};
|
||||
use task::{DebugScenario, RevealTarget, ZedDebugConfig};
|
||||
use theme::ThemeSettings;
|
||||
@@ -94,7 +92,6 @@ impl NewProcessModal {
|
||||
|
||||
cx.spawn_in(window, async move |workspace, cx| {
|
||||
let task_contexts = workspace.update_in(cx, |workspace, window, cx| {
|
||||
// todo(debugger): get the buffer here (if the active item is an editor) and store it so we can pass it to start_session later
|
||||
tasks_ui::task_contexts(workspace, window, cx)
|
||||
})?;
|
||||
workspace.update_in(cx, |workspace, window, cx| {
|
||||
@@ -1113,11 +1110,7 @@ pub(super) struct TaskMode {
|
||||
|
||||
pub(super) struct DebugDelegate {
|
||||
task_store: Entity<TaskStore>,
|
||||
candidates: Vec<(
|
||||
Option<TaskSourceKind>,
|
||||
DebugScenario,
|
||||
Option<DebugScenarioContext>,
|
||||
)>,
|
||||
candidates: Vec<(Option<TaskSourceKind>, DebugScenario)>,
|
||||
selected_index: usize,
|
||||
matches: Vec<StringMatch>,
|
||||
prompt: String,
|
||||
@@ -1215,11 +1208,7 @@ impl DebugDelegate {
|
||||
|
||||
this.delegate.candidates = recent
|
||||
.into_iter()
|
||||
.map(|(scenario, context)| {
|
||||
let (kind, scenario) =
|
||||
Self::get_scenario_kind(&languages, &dap_registry, scenario);
|
||||
(kind, scenario, Some(context))
|
||||
})
|
||||
.map(|scenario| Self::get_scenario_kind(&languages, &dap_registry, scenario))
|
||||
.chain(
|
||||
scenarios
|
||||
.into_iter()
|
||||
@@ -1234,7 +1223,7 @@ impl DebugDelegate {
|
||||
.map(|(kind, scenario)| {
|
||||
let (language, scenario) =
|
||||
Self::get_scenario_kind(&languages, &dap_registry, scenario);
|
||||
(language.or(Some(kind)), scenario, None)
|
||||
(language.or(Some(kind)), scenario)
|
||||
}),
|
||||
)
|
||||
.collect();
|
||||
@@ -1280,7 +1269,7 @@ impl PickerDelegate for DebugDelegate {
|
||||
let candidates: Vec<_> = candidates
|
||||
.into_iter()
|
||||
.enumerate()
|
||||
.map(|(index, (_, candidate, _))| {
|
||||
.map(|(index, (_, candidate))| {
|
||||
StringMatchCandidate::new(index, candidate.label.as_ref())
|
||||
})
|
||||
.collect();
|
||||
@@ -1445,40 +1434,25 @@ impl PickerDelegate for DebugDelegate {
|
||||
.get(self.selected_index())
|
||||
.and_then(|match_candidate| self.candidates.get(match_candidate.candidate_id).cloned());
|
||||
|
||||
let Some((_, debug_scenario, context)) = debug_scenario else {
|
||||
let Some((_, debug_scenario)) = debug_scenario else {
|
||||
return;
|
||||
};
|
||||
|
||||
let context = context.unwrap_or_else(|| {
|
||||
self.task_contexts
|
||||
.as_ref()
|
||||
.and_then(|task_contexts| {
|
||||
Some(DebugScenarioContext {
|
||||
task_context: task_contexts.active_context().cloned()?,
|
||||
active_buffer: None,
|
||||
worktree_id: task_contexts.worktree(),
|
||||
})
|
||||
})
|
||||
.unwrap_or_default()
|
||||
});
|
||||
let DebugScenarioContext {
|
||||
task_context,
|
||||
active_buffer,
|
||||
worktree_id,
|
||||
} = context;
|
||||
let active_buffer = active_buffer.and_then(|buffer| buffer.upgrade());
|
||||
let (task_context, worktree_id) = self
|
||||
.task_contexts
|
||||
.as_ref()
|
||||
.and_then(|task_contexts| {
|
||||
Some((
|
||||
task_contexts.active_context().cloned()?,
|
||||
task_contexts.worktree(),
|
||||
))
|
||||
})
|
||||
.unwrap_or_default();
|
||||
|
||||
send_telemetry(&debug_scenario, TelemetrySpawnLocation::ScenarioList, cx);
|
||||
self.debug_panel
|
||||
.update(cx, |panel, cx| {
|
||||
panel.start_session(
|
||||
debug_scenario,
|
||||
task_context,
|
||||
active_buffer,
|
||||
worktree_id,
|
||||
window,
|
||||
cx,
|
||||
);
|
||||
panel.start_session(debug_scenario, task_context, None, worktree_id, window, cx);
|
||||
})
|
||||
.ok();
|
||||
|
||||
|
||||
@@ -12,7 +12,6 @@ use rpc::proto;
|
||||
use running::RunningState;
|
||||
use std::{cell::OnceCell, sync::OnceLock};
|
||||
use ui::{Indicator, Tooltip, prelude::*};
|
||||
use util::truncate_and_trailoff;
|
||||
use workspace::{
|
||||
CollaboratorId, FollowableItem, ViewId, Workspace,
|
||||
item::{self, Item},
|
||||
@@ -127,10 +126,7 @@ impl DebugSession {
|
||||
}
|
||||
|
||||
pub(crate) fn label_element(&self, depth: usize, cx: &App) -> AnyElement {
|
||||
const MAX_LABEL_CHARS: usize = 150;
|
||||
|
||||
let label = self.label(cx);
|
||||
let label = truncate_and_trailoff(&label, MAX_LABEL_CHARS);
|
||||
|
||||
let is_terminated = self
|
||||
.running_state
|
||||
|
||||
@@ -33,7 +33,7 @@ use language::Buffer;
|
||||
use loaded_source_list::LoadedSourceList;
|
||||
use module_list::ModuleList;
|
||||
use project::{
|
||||
DebugScenarioContext, Project, WorktreeId,
|
||||
Project, WorktreeId,
|
||||
debugger::session::{Session, SessionEvent, ThreadId, ThreadStatus},
|
||||
terminals::TerminalKind,
|
||||
};
|
||||
@@ -79,8 +79,6 @@ pub struct RunningState {
|
||||
pane_close_subscriptions: HashMap<EntityId, Subscription>,
|
||||
dock_axis: Axis,
|
||||
_schedule_serialize: Option<Task<()>>,
|
||||
pub(crate) scenario: Option<DebugScenario>,
|
||||
pub(crate) scenario_context: Option<DebugScenarioContext>,
|
||||
}
|
||||
|
||||
impl RunningState {
|
||||
@@ -833,8 +831,6 @@ impl RunningState {
|
||||
debug_terminal,
|
||||
dock_axis,
|
||||
_schedule_serialize: None,
|
||||
scenario: None,
|
||||
scenario_context: None,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1043,7 +1039,7 @@ impl RunningState {
|
||||
let scenario = dap_registry
|
||||
.adapter(&adapter)
|
||||
.with_context(|| anyhow!("{}: is not a valid adapter name", &adapter))?.config_from_zed_format(zed_config)
|
||||
.await?;
|
||||
.await?;
|
||||
config = scenario.config;
|
||||
util::merge_non_null_json_value_into(extra_config, &mut config);
|
||||
|
||||
@@ -1529,34 +1525,6 @@ impl RunningState {
|
||||
});
|
||||
}
|
||||
|
||||
pub fn rerun_session(&mut self, window: &mut Window, cx: &mut Context<Self>) {
|
||||
if let Some((scenario, context)) = self.scenario.take().zip(self.scenario_context.take())
|
||||
&& scenario.build.is_some()
|
||||
{
|
||||
let DebugScenarioContext {
|
||||
task_context,
|
||||
active_buffer,
|
||||
worktree_id,
|
||||
} = context;
|
||||
let active_buffer = active_buffer.and_then(|buffer| buffer.upgrade());
|
||||
|
||||
self.workspace
|
||||
.update(cx, |workspace, cx| {
|
||||
workspace.start_debug_session(
|
||||
scenario,
|
||||
task_context,
|
||||
active_buffer,
|
||||
worktree_id,
|
||||
window,
|
||||
cx,
|
||||
)
|
||||
})
|
||||
.ok();
|
||||
} else {
|
||||
self.restart_session(cx);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn restart_session(&self, cx: &mut Context<Self>) {
|
||||
self.session().update(cx, |state, cx| {
|
||||
state.restart(None, cx);
|
||||
|
||||
@@ -33,12 +33,7 @@ use zed_actions::{ToggleEnableBreakpoint, UnsetBreakpoint};
|
||||
|
||||
actions!(
|
||||
debugger,
|
||||
[
|
||||
/// Navigates to the previous breakpoint property in the list.
|
||||
PreviousBreakpointProperty,
|
||||
/// Navigates to the next breakpoint property in the list.
|
||||
NextBreakpointProperty
|
||||
]
|
||||
[PreviousBreakpointProperty, NextBreakpointProperty]
|
||||
);
|
||||
#[derive(Clone, Copy, PartialEq)]
|
||||
pub(crate) enum SelectedBreakpointKind {
|
||||
|
||||
@@ -13,26 +13,17 @@ use gpui::{
|
||||
Render, Subscription, Task, TextStyle, WeakEntity, actions,
|
||||
};
|
||||
use language::{Buffer, CodeLabel, ToOffset};
|
||||
use menu::{Confirm, SelectNext, SelectPrevious};
|
||||
use menu::Confirm;
|
||||
use project::{
|
||||
Completion, CompletionResponse,
|
||||
debugger::session::{CompletionsQuery, OutputToken, Session},
|
||||
search_history::{SearchHistory, SearchHistoryCursor},
|
||||
debugger::session::{CompletionsQuery, OutputToken, Session, SessionEvent},
|
||||
};
|
||||
use settings::Settings;
|
||||
use std::fmt::Write;
|
||||
use std::{cell::RefCell, ops::Range, rc::Rc, usize};
|
||||
use theme::{Theme, ThemeSettings};
|
||||
use ui::{ContextMenu, Divider, PopoverMenu, SplitButton, Tooltip, prelude::*};
|
||||
use util::ResultExt;
|
||||
|
||||
actions!(
|
||||
console,
|
||||
[
|
||||
/// Adds an expression to the watch list.
|
||||
WatchExpression
|
||||
]
|
||||
);
|
||||
actions!(console, [WatchExpression]);
|
||||
|
||||
pub struct Console {
|
||||
console: Entity<Editor>,
|
||||
@@ -42,10 +33,8 @@ pub struct Console {
|
||||
variable_list: Entity<VariableList>,
|
||||
stack_frame_list: Entity<StackFrameList>,
|
||||
last_token: OutputToken,
|
||||
update_output_task: Option<Task<()>>,
|
||||
update_output_task: Task<()>,
|
||||
focus_handle: FocusHandle,
|
||||
history: SearchHistory,
|
||||
cursor: SearchHistoryCursor,
|
||||
}
|
||||
|
||||
impl Console {
|
||||
@@ -94,6 +83,11 @@ impl Console {
|
||||
|
||||
let _subscriptions = vec![
|
||||
cx.subscribe(&stack_frame_list, Self::handle_stack_frame_list_events),
|
||||
cx.subscribe_in(&session, window, |this, _, event, window, cx| {
|
||||
if let SessionEvent::ConsoleOutput = event {
|
||||
this.update_output(window, cx)
|
||||
}
|
||||
}),
|
||||
cx.on_focus(&focus_handle, window, |console, window, cx| {
|
||||
if console.is_running(cx) {
|
||||
console.query_bar.focus_handle(cx).focus(window);
|
||||
@@ -108,14 +102,9 @@ impl Console {
|
||||
variable_list,
|
||||
_subscriptions,
|
||||
stack_frame_list,
|
||||
update_output_task: None,
|
||||
update_output_task: Task::ready(()),
|
||||
last_token: OutputToken(0),
|
||||
focus_handle,
|
||||
history: SearchHistory::new(
|
||||
None,
|
||||
project::search_history::QueryInsertionBehavior::ReplacePreviousIfContains,
|
||||
),
|
||||
cursor: Default::default(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -144,116 +133,202 @@ impl Console {
|
||||
self.session.read(cx).has_new_output(self.last_token)
|
||||
}
|
||||
|
||||
fn add_messages(
|
||||
pub fn add_messages<'a>(
|
||||
&mut self,
|
||||
events: Vec<OutputEvent>,
|
||||
events: impl Iterator<Item = &'a OutputEvent>,
|
||||
window: &mut Window,
|
||||
cx: &mut App,
|
||||
) -> Task<Result<()>> {
|
||||
self.console.update(cx, |_, cx| {
|
||||
cx.spawn_in(window, async move |console, cx| {
|
||||
let mut len = console.update(cx, |this, cx| this.buffer().read(cx).len(cx))?;
|
||||
let (output, spans, background_spans) = cx
|
||||
.background_spawn(async move {
|
||||
let mut all_spans = Vec::new();
|
||||
let mut all_background_spans = Vec::new();
|
||||
let mut to_insert = String::new();
|
||||
let mut scratch = String::new();
|
||||
) {
|
||||
self.console.update(cx, |console, cx| {
|
||||
console.set_read_only(false);
|
||||
|
||||
for event in &events {
|
||||
scratch.clear();
|
||||
let mut ansi_handler = ConsoleHandler::default();
|
||||
let mut ansi_processor =
|
||||
ansi::Processor::<ansi::StdSyncHandler>::default();
|
||||
for event in events {
|
||||
let to_insert = format!("{}\n", event.output.trim_end());
|
||||
|
||||
let trimmed_output = event.output.trim_end();
|
||||
let _ = writeln!(&mut scratch, "{trimmed_output}");
|
||||
ansi_processor.advance(&mut ansi_handler, scratch.as_bytes());
|
||||
let output = std::mem::take(&mut ansi_handler.output);
|
||||
to_insert.extend(output.chars());
|
||||
let mut spans = std::mem::take(&mut ansi_handler.spans);
|
||||
let mut background_spans =
|
||||
std::mem::take(&mut ansi_handler.background_spans);
|
||||
if ansi_handler.current_range_start < output.len() {
|
||||
spans.push((
|
||||
ansi_handler.current_range_start..output.len(),
|
||||
ansi_handler.current_color,
|
||||
));
|
||||
let mut ansi_handler = ConsoleHandler::default();
|
||||
let mut ansi_processor = ansi::Processor::<ansi::StdSyncHandler>::default();
|
||||
|
||||
let len = console.buffer().read(cx).len(cx);
|
||||
ansi_processor.advance(&mut ansi_handler, to_insert.as_bytes());
|
||||
let output = std::mem::take(&mut ansi_handler.output);
|
||||
let mut spans = std::mem::take(&mut ansi_handler.spans);
|
||||
let mut background_spans = std::mem::take(&mut ansi_handler.background_spans);
|
||||
if ansi_handler.current_range_start < output.len() {
|
||||
spans.push((
|
||||
ansi_handler.current_range_start..output.len(),
|
||||
ansi_handler.current_color,
|
||||
));
|
||||
}
|
||||
if ansi_handler.current_background_range_start < output.len() {
|
||||
background_spans.push((
|
||||
ansi_handler.current_background_range_start..output.len(),
|
||||
ansi_handler.current_background_color,
|
||||
));
|
||||
}
|
||||
console.move_to_end(&editor::actions::MoveToEnd, window, cx);
|
||||
console.insert(&output, window, cx);
|
||||
let buffer = console.buffer().read(cx).snapshot(cx);
|
||||
|
||||
struct ConsoleAnsiHighlight;
|
||||
|
||||
for (range, color) in spans {
|
||||
let Some(color) = color else { continue };
|
||||
let start_offset = len + range.start;
|
||||
let range = start_offset..len + range.end;
|
||||
let range = buffer.anchor_after(range.start)..buffer.anchor_before(range.end);
|
||||
let style = HighlightStyle {
|
||||
color: Some(terminal_view::terminal_element::convert_color(
|
||||
&color,
|
||||
cx.theme(),
|
||||
)),
|
||||
..Default::default()
|
||||
};
|
||||
console.highlight_text_key::<ConsoleAnsiHighlight>(
|
||||
start_offset,
|
||||
vec![range],
|
||||
style,
|
||||
cx,
|
||||
);
|
||||
}
|
||||
|
||||
for (range, color) in background_spans {
|
||||
let Some(color) = color else { continue };
|
||||
let start_offset = len + range.start;
|
||||
let range = start_offset..len + range.end;
|
||||
let range = buffer.anchor_after(range.start)..buffer.anchor_before(range.end);
|
||||
|
||||
let color_fetcher: fn(&Theme) -> Hsla = match color {
|
||||
// Named and theme defined colors
|
||||
ansi::Color::Named(n) => match n {
|
||||
ansi::NamedColor::Black => |theme| theme.colors().terminal_ansi_black,
|
||||
ansi::NamedColor::Red => |theme| theme.colors().terminal_ansi_red,
|
||||
ansi::NamedColor::Green => |theme| theme.colors().terminal_ansi_green,
|
||||
ansi::NamedColor::Yellow => |theme| theme.colors().terminal_ansi_yellow,
|
||||
ansi::NamedColor::Blue => |theme| theme.colors().terminal_ansi_blue,
|
||||
ansi::NamedColor::Magenta => {
|
||||
|theme| theme.colors().terminal_ansi_magenta
|
||||
}
|
||||
if ansi_handler.current_background_range_start < output.len() {
|
||||
background_spans.push((
|
||||
ansi_handler.current_background_range_start..output.len(),
|
||||
ansi_handler.current_background_color,
|
||||
));
|
||||
ansi::NamedColor::Cyan => |theme| theme.colors().terminal_ansi_cyan,
|
||||
ansi::NamedColor::White => |theme| theme.colors().terminal_ansi_white,
|
||||
ansi::NamedColor::BrightBlack => {
|
||||
|theme| theme.colors().terminal_ansi_bright_black
|
||||
}
|
||||
|
||||
for (range, _) in spans.iter_mut() {
|
||||
let start_offset = len + range.start;
|
||||
*range = start_offset..len + range.end;
|
||||
ansi::NamedColor::BrightRed => {
|
||||
|theme| theme.colors().terminal_ansi_bright_red
|
||||
}
|
||||
|
||||
for (range, _) in background_spans.iter_mut() {
|
||||
let start_offset = len + range.start;
|
||||
*range = start_offset..len + range.end;
|
||||
ansi::NamedColor::BrightGreen => {
|
||||
|theme| theme.colors().terminal_ansi_bright_green
|
||||
}
|
||||
ansi::NamedColor::BrightYellow => {
|
||||
|theme| theme.colors().terminal_ansi_bright_yellow
|
||||
}
|
||||
ansi::NamedColor::BrightBlue => {
|
||||
|theme| theme.colors().terminal_ansi_bright_blue
|
||||
}
|
||||
ansi::NamedColor::BrightMagenta => {
|
||||
|theme| theme.colors().terminal_ansi_bright_magenta
|
||||
}
|
||||
ansi::NamedColor::BrightCyan => {
|
||||
|theme| theme.colors().terminal_ansi_bright_cyan
|
||||
}
|
||||
ansi::NamedColor::BrightWhite => {
|
||||
|theme| theme.colors().terminal_ansi_bright_white
|
||||
}
|
||||
ansi::NamedColor::Foreground => {
|
||||
|theme| theme.colors().terminal_foreground
|
||||
}
|
||||
ansi::NamedColor::Background => {
|
||||
|theme| theme.colors().terminal_background
|
||||
}
|
||||
ansi::NamedColor::Cursor => |theme| theme.players().local().cursor,
|
||||
ansi::NamedColor::DimBlack => {
|
||||
|theme| theme.colors().terminal_ansi_dim_black
|
||||
}
|
||||
ansi::NamedColor::DimRed => {
|
||||
|theme| theme.colors().terminal_ansi_dim_red
|
||||
}
|
||||
ansi::NamedColor::DimGreen => {
|
||||
|theme| theme.colors().terminal_ansi_dim_green
|
||||
}
|
||||
ansi::NamedColor::DimYellow => {
|
||||
|theme| theme.colors().terminal_ansi_dim_yellow
|
||||
}
|
||||
ansi::NamedColor::DimBlue => {
|
||||
|theme| theme.colors().terminal_ansi_dim_blue
|
||||
}
|
||||
ansi::NamedColor::DimMagenta => {
|
||||
|theme| theme.colors().terminal_ansi_dim_magenta
|
||||
}
|
||||
ansi::NamedColor::DimCyan => {
|
||||
|theme| theme.colors().terminal_ansi_dim_cyan
|
||||
}
|
||||
ansi::NamedColor::DimWhite => {
|
||||
|theme| theme.colors().terminal_ansi_dim_white
|
||||
}
|
||||
ansi::NamedColor::BrightForeground => {
|
||||
|theme| theme.colors().terminal_bright_foreground
|
||||
}
|
||||
ansi::NamedColor::DimForeground => {
|
||||
|theme| theme.colors().terminal_dim_foreground
|
||||
}
|
||||
},
|
||||
// 'True' colors
|
||||
ansi::Color::Spec(_) => |theme| theme.colors().editor_background,
|
||||
// 8 bit, indexed colors
|
||||
ansi::Color::Indexed(i) => {
|
||||
match i {
|
||||
// 0-15 are the same as the named colors above
|
||||
0 => |theme| theme.colors().terminal_ansi_black,
|
||||
1 => |theme| theme.colors().terminal_ansi_red,
|
||||
2 => |theme| theme.colors().terminal_ansi_green,
|
||||
3 => |theme| theme.colors().terminal_ansi_yellow,
|
||||
4 => |theme| theme.colors().terminal_ansi_blue,
|
||||
5 => |theme| theme.colors().terminal_ansi_magenta,
|
||||
6 => |theme| theme.colors().terminal_ansi_cyan,
|
||||
7 => |theme| theme.colors().terminal_ansi_white,
|
||||
8 => |theme| theme.colors().terminal_ansi_bright_black,
|
||||
9 => |theme| theme.colors().terminal_ansi_bright_red,
|
||||
10 => |theme| theme.colors().terminal_ansi_bright_green,
|
||||
11 => |theme| theme.colors().terminal_ansi_bright_yellow,
|
||||
12 => |theme| theme.colors().terminal_ansi_bright_blue,
|
||||
13 => |theme| theme.colors().terminal_ansi_bright_magenta,
|
||||
14 => |theme| theme.colors().terminal_ansi_bright_cyan,
|
||||
15 => |theme| theme.colors().terminal_ansi_bright_white,
|
||||
// 16-231 are a 6x6x6 RGB color cube, mapped to 0-255 using steps defined by XTerm.
|
||||
// See: https://github.com/xterm-x11/xterm-snapshots/blob/master/256colres.pl
|
||||
// 16..=231 => {
|
||||
// let (r, g, b) = rgb_for_index(index as u8);
|
||||
// rgba_color(
|
||||
// if r == 0 { 0 } else { r * 40 + 55 },
|
||||
// if g == 0 { 0 } else { g * 40 + 55 },
|
||||
// if b == 0 { 0 } else { b * 40 + 55 },
|
||||
// )
|
||||
// }
|
||||
// 232-255 are a 24-step grayscale ramp from (8, 8, 8) to (238, 238, 238).
|
||||
// 232..=255 => {
|
||||
// let i = index as u8 - 232; // Align index to 0..24
|
||||
// let value = i * 10 + 8;
|
||||
// rgba_color(value, value, value)
|
||||
// }
|
||||
// For compatibility with the alacritty::Colors interface
|
||||
// See: https://github.com/alacritty/alacritty/blob/master/alacritty_terminal/src/term/color.rs
|
||||
_ => |_| gpui::black(),
|
||||
}
|
||||
|
||||
len += output.len();
|
||||
|
||||
all_spans.extend(spans);
|
||||
all_background_spans.extend(background_spans);
|
||||
}
|
||||
(to_insert, all_spans, all_background_spans)
|
||||
})
|
||||
.await;
|
||||
console.update_in(cx, |console, window, cx| {
|
||||
console.set_read_only(false);
|
||||
console.move_to_end(&editor::actions::MoveToEnd, window, cx);
|
||||
console.insert(&output, window, cx);
|
||||
console.set_read_only(true);
|
||||
};
|
||||
|
||||
struct ConsoleAnsiHighlight;
|
||||
console.highlight_background_key::<ConsoleAnsiHighlight>(
|
||||
start_offset,
|
||||
&[range],
|
||||
color_fetcher,
|
||||
cx,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
let buffer = console.buffer().read(cx).snapshot(cx);
|
||||
|
||||
for (range, color) in spans {
|
||||
let Some(color) = color else { continue };
|
||||
let start_offset = range.start;
|
||||
let range =
|
||||
buffer.anchor_after(range.start)..buffer.anchor_before(range.end);
|
||||
let style = HighlightStyle {
|
||||
color: Some(terminal_view::terminal_element::convert_color(
|
||||
&color,
|
||||
cx.theme(),
|
||||
)),
|
||||
..Default::default()
|
||||
};
|
||||
console.highlight_text_key::<ConsoleAnsiHighlight>(
|
||||
start_offset,
|
||||
vec![range],
|
||||
style,
|
||||
cx,
|
||||
);
|
||||
}
|
||||
|
||||
for (range, color) in background_spans {
|
||||
let Some(color) = color else { continue };
|
||||
let start_offset = range.start;
|
||||
let range =
|
||||
buffer.anchor_after(range.start)..buffer.anchor_before(range.end);
|
||||
console.highlight_background_key::<ConsoleAnsiHighlight>(
|
||||
start_offset,
|
||||
&[range],
|
||||
color_fetcher(color),
|
||||
cx,
|
||||
);
|
||||
}
|
||||
|
||||
cx.notify();
|
||||
})?;
|
||||
|
||||
Ok(())
|
||||
})
|
||||
})
|
||||
console.set_read_only(true);
|
||||
cx.notify();
|
||||
});
|
||||
}
|
||||
|
||||
pub fn watch_expression(
|
||||
@@ -270,8 +345,7 @@ impl Console {
|
||||
|
||||
expression
|
||||
});
|
||||
self.history.add(&mut self.cursor, expression.clone());
|
||||
self.cursor.reset();
|
||||
|
||||
self.session.update(cx, |session, cx| {
|
||||
session
|
||||
.evaluate(
|
||||
@@ -291,28 +365,7 @@ impl Console {
|
||||
});
|
||||
}
|
||||
|
||||
fn previous_query(&mut self, _: &SelectPrevious, window: &mut Window, cx: &mut Context<Self>) {
|
||||
let prev = self.history.previous(&mut self.cursor);
|
||||
if let Some(prev) = prev {
|
||||
self.query_bar.update(cx, |editor, cx| {
|
||||
editor.set_text(prev, window, cx);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
fn next_query(&mut self, _: &SelectNext, window: &mut Window, cx: &mut Context<Self>) {
|
||||
let next = self.history.next(&mut self.cursor);
|
||||
let query = next.unwrap_or_else(|| {
|
||||
self.cursor.reset();
|
||||
""
|
||||
});
|
||||
|
||||
self.query_bar.update(cx, |editor, cx| {
|
||||
editor.set_text(query, window, cx);
|
||||
});
|
||||
}
|
||||
|
||||
fn evaluate(&mut self, _: &Confirm, window: &mut Window, cx: &mut Context<Self>) {
|
||||
pub fn evaluate(&mut self, _: &Confirm, window: &mut Window, cx: &mut Context<Self>) {
|
||||
let expression = self.query_bar.update(cx, |editor, cx| {
|
||||
let expression = editor.text(cx);
|
||||
cx.defer_in(window, |editor, window, cx| {
|
||||
@@ -322,8 +375,6 @@ impl Console {
|
||||
expression
|
||||
});
|
||||
|
||||
self.history.add(&mut self.cursor, expression.clone());
|
||||
self.cursor.reset();
|
||||
self.session.update(cx, |session, cx| {
|
||||
session
|
||||
.evaluate(
|
||||
@@ -407,50 +458,31 @@ impl Console {
|
||||
EditorElement::new(&self.query_bar, Self::editor_style(&self.query_bar, cx))
|
||||
}
|
||||
|
||||
pub(crate) fn update_output(&mut self, window: &mut Window, cx: &mut Context<Self>) {
|
||||
if self.update_output_task.is_some() {
|
||||
return;
|
||||
}
|
||||
fn update_output(&mut self, window: &mut Window, cx: &mut Context<Self>) {
|
||||
let session = self.session.clone();
|
||||
let token = self.last_token;
|
||||
self.update_output_task = Some(cx.spawn_in(window, async move |this, cx| {
|
||||
let Some((last_processed_token, task)) = session
|
||||
.update_in(cx, |session, window, cx| {
|
||||
let (output, last_processed_token) = session.output(token);
|
||||
|
||||
this.update(cx, |this, cx| {
|
||||
if last_processed_token == this.last_token {
|
||||
return None;
|
||||
}
|
||||
Some((
|
||||
last_processed_token,
|
||||
this.add_messages(output.cloned().collect(), window, cx),
|
||||
))
|
||||
})
|
||||
.ok()
|
||||
.flatten()
|
||||
})
|
||||
.ok()
|
||||
.flatten()
|
||||
else {
|
||||
_ = this.update(cx, |this, _| {
|
||||
this.update_output_task.take();
|
||||
self.update_output_task = cx.spawn_in(window, async move |this, cx| {
|
||||
_ = session.update_in(cx, move |session, window, cx| {
|
||||
let (output, last_processed_token) = session.output(token);
|
||||
|
||||
_ = this.update(cx, |this, cx| {
|
||||
if last_processed_token == this.last_token {
|
||||
return;
|
||||
}
|
||||
this.add_messages(output, window, cx);
|
||||
|
||||
this.last_token = last_processed_token;
|
||||
});
|
||||
return;
|
||||
};
|
||||
_ = task.await.log_err();
|
||||
_ = this.update(cx, |this, _| {
|
||||
this.last_token = last_processed_token;
|
||||
this.update_output_task.take();
|
||||
});
|
||||
}));
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
impl Render for Console {
|
||||
fn render(&mut self, window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
|
||||
fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
|
||||
let query_focus_handle = self.query_bar.focus_handle(cx);
|
||||
self.update_output(window, cx);
|
||||
|
||||
v_flex()
|
||||
.track_focus(&self.focus_handle)
|
||||
.key_context("DebugConsole")
|
||||
@@ -461,8 +493,6 @@ impl Render for Console {
|
||||
.when(self.is_running(cx), |this| {
|
||||
this.child(Divider::horizontal()).child(
|
||||
h_flex()
|
||||
.on_action(cx.listener(Self::previous_query))
|
||||
.on_action(cx.listener(Self::next_query))
|
||||
.gap_1()
|
||||
.bg(cx.theme().colors().editor_background)
|
||||
.child(self.render_query_bar(cx))
|
||||
@@ -815,84 +845,3 @@ impl ansi::Handler for ConsoleHandler {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn color_fetcher(color: ansi::Color) -> fn(&Theme) -> Hsla {
|
||||
let color_fetcher: fn(&Theme) -> Hsla = match color {
|
||||
// Named and theme defined colors
|
||||
ansi::Color::Named(n) => match n {
|
||||
ansi::NamedColor::Black => |theme| theme.colors().terminal_ansi_black,
|
||||
ansi::NamedColor::Red => |theme| theme.colors().terminal_ansi_red,
|
||||
ansi::NamedColor::Green => |theme| theme.colors().terminal_ansi_green,
|
||||
ansi::NamedColor::Yellow => |theme| theme.colors().terminal_ansi_yellow,
|
||||
ansi::NamedColor::Blue => |theme| theme.colors().terminal_ansi_blue,
|
||||
ansi::NamedColor::Magenta => |theme| theme.colors().terminal_ansi_magenta,
|
||||
ansi::NamedColor::Cyan => |theme| theme.colors().terminal_ansi_cyan,
|
||||
ansi::NamedColor::White => |theme| theme.colors().terminal_ansi_white,
|
||||
ansi::NamedColor::BrightBlack => |theme| theme.colors().terminal_ansi_bright_black,
|
||||
ansi::NamedColor::BrightRed => |theme| theme.colors().terminal_ansi_bright_red,
|
||||
ansi::NamedColor::BrightGreen => |theme| theme.colors().terminal_ansi_bright_green,
|
||||
ansi::NamedColor::BrightYellow => |theme| theme.colors().terminal_ansi_bright_yellow,
|
||||
ansi::NamedColor::BrightBlue => |theme| theme.colors().terminal_ansi_bright_blue,
|
||||
ansi::NamedColor::BrightMagenta => |theme| theme.colors().terminal_ansi_bright_magenta,
|
||||
ansi::NamedColor::BrightCyan => |theme| theme.colors().terminal_ansi_bright_cyan,
|
||||
ansi::NamedColor::BrightWhite => |theme| theme.colors().terminal_ansi_bright_white,
|
||||
ansi::NamedColor::Foreground => |theme| theme.colors().terminal_foreground,
|
||||
ansi::NamedColor::Background => |theme| theme.colors().terminal_background,
|
||||
ansi::NamedColor::Cursor => |theme| theme.players().local().cursor,
|
||||
ansi::NamedColor::DimBlack => |theme| theme.colors().terminal_ansi_dim_black,
|
||||
ansi::NamedColor::DimRed => |theme| theme.colors().terminal_ansi_dim_red,
|
||||
ansi::NamedColor::DimGreen => |theme| theme.colors().terminal_ansi_dim_green,
|
||||
ansi::NamedColor::DimYellow => |theme| theme.colors().terminal_ansi_dim_yellow,
|
||||
ansi::NamedColor::DimBlue => |theme| theme.colors().terminal_ansi_dim_blue,
|
||||
ansi::NamedColor::DimMagenta => |theme| theme.colors().terminal_ansi_dim_magenta,
|
||||
ansi::NamedColor::DimCyan => |theme| theme.colors().terminal_ansi_dim_cyan,
|
||||
ansi::NamedColor::DimWhite => |theme| theme.colors().terminal_ansi_dim_white,
|
||||
ansi::NamedColor::BrightForeground => |theme| theme.colors().terminal_bright_foreground,
|
||||
ansi::NamedColor::DimForeground => |theme| theme.colors().terminal_dim_foreground,
|
||||
},
|
||||
// 'True' colors
|
||||
ansi::Color::Spec(_) => |theme| theme.colors().editor_background,
|
||||
// 8 bit, indexed colors
|
||||
ansi::Color::Indexed(i) => {
|
||||
match i {
|
||||
// 0-15 are the same as the named colors above
|
||||
0 => |theme| theme.colors().terminal_ansi_black,
|
||||
1 => |theme| theme.colors().terminal_ansi_red,
|
||||
2 => |theme| theme.colors().terminal_ansi_green,
|
||||
3 => |theme| theme.colors().terminal_ansi_yellow,
|
||||
4 => |theme| theme.colors().terminal_ansi_blue,
|
||||
5 => |theme| theme.colors().terminal_ansi_magenta,
|
||||
6 => |theme| theme.colors().terminal_ansi_cyan,
|
||||
7 => |theme| theme.colors().terminal_ansi_white,
|
||||
8 => |theme| theme.colors().terminal_ansi_bright_black,
|
||||
9 => |theme| theme.colors().terminal_ansi_bright_red,
|
||||
10 => |theme| theme.colors().terminal_ansi_bright_green,
|
||||
11 => |theme| theme.colors().terminal_ansi_bright_yellow,
|
||||
12 => |theme| theme.colors().terminal_ansi_bright_blue,
|
||||
13 => |theme| theme.colors().terminal_ansi_bright_magenta,
|
||||
14 => |theme| theme.colors().terminal_ansi_bright_cyan,
|
||||
15 => |theme| theme.colors().terminal_ansi_bright_white,
|
||||
// 16-231 are a 6x6x6 RGB color cube, mapped to 0-255 using steps defined by XTerm.
|
||||
// See: https://github.com/xterm-x11/xterm-snapshots/blob/master/256colres.pl
|
||||
// 16..=231 => {
|
||||
// let (r, g, b) = rgb_for_index(index as u8);
|
||||
// rgba_color(
|
||||
// if r == 0 { 0 } else { r * 40 + 55 },
|
||||
// if g == 0 { 0 } else { g * 40 + 55 },
|
||||
// if b == 0 { 0 } else { b * 40 + 55 },
|
||||
// )
|
||||
// }
|
||||
// 232-255 are a 24-step grayscale ramp from (8, 8, 8) to (238, 238, 238).
|
||||
// 232..=255 => {
|
||||
// let i = index as u8 - 232; // Align index to 0..24
|
||||
// let value = i * 10 + 8;
|
||||
// rgba_color(value, value, value)
|
||||
// }
|
||||
// For compatibility with the alacritty::Colors interface
|
||||
// See: https://github.com/alacritty/alacritty/blob/master/alacritty_terminal/src/term/color.rs
|
||||
_ => |_| gpui::black(),
|
||||
}
|
||||
}
|
||||
};
|
||||
color_fetcher
|
||||
}
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user