Compare commits

..

59 Commits

Author SHA1 Message Date
Mikayla Maki
34c890e23e WIP
co-authored-by: Nathan <nathan@zed.dev>
2025-07-02 15:51:38 -07:00
Agus Zubiaga
4755d6fa9d Display tool icons
Co-authored-by: Mikayla Maki <mikayla.c.maki@gmail.com>
Co-authored-by: Antonio Scandurra <me@as-cii.com>
Co-authored-by: Nathan Sobo <nathan@zed.dev>
2025-07-02 13:48:57 -03:00
Agus Zubiaga
135143d51b Rename display_name to label
Co-authored-by: Antonio Scandurra <me@as-cii.com>
Co-authored-by: Nathan Sobo <nathan@zed.dev>
Co-authored-by: Conrad Irwin <conrad.irwin@gmail.com>
2025-07-02 13:16:30 -03:00
Agus Zubiaga
450604b4a1 Add tool call with confirmation test 2025-07-02 12:13:20 -03:00
Agus Zubiaga
348bc52a3f Merge branch 'acp' of github.com:zed-industries/zed into acp 2025-07-02 11:33:22 -03:00
Agus Zubiaga
d16c595d57 Fix always allow, and update acp confirmation types 2025-07-02 11:31:51 -03:00
Antonio Scandurra
975a7e6f7f Fix clicking on tool confirmation buttons
Co-authored-by: Ben Brandt <benjamin.j.brandt@gmail.com>
2025-07-02 14:54:24 +02:00
Antonio Scandurra
7d2f7cb70e Replace title with display_name for tool calls
Co-authored-by: Ben Brandt <benjamin.j.brandt@gmail.com>
2025-07-02 14:40:16 +02:00
Ben Brandt
5f9afdf7ba Add buttons for more outcomes and handle tools that don't need
authorization

Co-authored-by: Antonio Scandurra <me@as-cii.com>
2025-07-02 12:56:03 +02:00
Ben Brandt
7a3105b0c6 Wire up push_tool_call
Co-authored-by: Antonio Scandurra <me@as-cii.com>
2025-07-02 12:03:35 +02:00
Ben Brandt
ab0b16939d Update tool call confirmation 2025-07-02 11:32:03 +02:00
Agus Zubiaga
28d992487d Better temporary title 2025-07-02 00:58:05 -03:00
Agus Zubiaga
fde15a5a68 Update tool calls via ACP 2025-07-02 00:47:28 -03:00
Agus Zubiaga
780db30e0b Handle waiting for tool confirmation in UI 2025-07-01 23:48:09 -03:00
Agus Zubiaga
7c992adfe1 Improve spacing even more 2025-07-01 23:35:29 -03:00
Agus Zubiaga
825aecfd28 Fix spacing and list scrolling 2025-07-01 23:27:12 -03:00
Agus Zubiaga
f2f32fb3bd Proper allow/reject UI 2025-07-01 23:13:56 -03:00
Agus Zubiaga
d9fd8d5eee Improve spacing 2025-07-01 21:50:14 -03:00
Agus Zubiaga
8137b3318f Remove ReadFile entry and test tool call 2025-07-01 21:37:31 -03:00
Agus Zubiaga
3ceeefe460 Tool authorization 2025-07-01 20:32:21 -03:00
Agus Zubiaga
6f768aefa2 Copy
Co-authored-by: Conrad Irwin <conrad.irwin@gmail.com>
2025-07-01 17:15:57 -03:00
Agus Zubiaga
28ac84ed01 Jump to gemini thread view immediately
Co-authored-by: Conrad Irwin <conrad.irwin@gmail.com>
2025-07-01 17:15:20 -03:00
Agus Zubiaga
4d803fa628 message markdown
Co-authored-by: Conrad Irwin <conrad.irwin@gmail.com>
2025-07-01 16:57:22 -03:00
Agus Zubiaga
17b2dd9a93 Update list incrementally
Co-authored-by: Conrad Irwin <conrad.irwin@gmail.com>
2025-07-01 16:13:16 -03:00
Mikayla Maki
7abf635e20 Use a list to render items
Co-authored-by: Agus Zubiaga <agus@zed.dev>
2025-07-01 11:48:03 -07:00
Antonio Scandurra
92adcb6e63 WIP 2025-07-01 19:01:02 +02:00
Antonio Scandurra
5ed001e0df Merge remote-tracking branch 'origin/main' into agent2
# Conflicts:
#	Cargo.lock
2025-07-01 18:30:08 +02:00
Antonio Scandurra
f12fffd1ba WIP 2025-07-01 18:23:21 +02:00
Agus Zubiaga
991ba08711 Stop button 2025-06-26 14:37:22 -03:00
Agus Zubiaga
c728731099 Merge last chunk 2025-06-26 14:30:59 -03:00
Agus Zubiaga
ddab1cbd71 Fix notify and margin
Co-authored-by: Smit Barmase <heysmitbarmase@gmail.com>
2025-06-26 14:23:39 -03:00
Agus Zubiaga
f383a7626f Improve user message
Co-authored-by: Smit Barmase <heysmitbarmase@gmail.com>
2025-06-26 14:16:30 -03:00
Agus Zubiaga
ee1df65569 Start displaying messages in new thread element
Co-authored-by: Smit Barmase <heysmitbarmase@gmail.com>
2025-06-26 14:05:59 -03:00
Agus Zubiaga
3be45822be agent2 basic message editor
Co-authored-by: Smit Barmase <heysmitbarmase@gmail.com>
2025-06-26 13:37:23 -03:00
Agus Zubiaga
3b6f30a6fd Add ThreadElement and render it when active
Co-authored-by: Smit Barmase <heysmitbarmase@gmail.com>
2025-06-26 13:07:02 -03:00
Agus Zubiaga
779a68f868 Merge branch 'main' into agent2
Co-authored-by: Smit Barmase <heysmitbarmase@gmail.com>
2025-06-26 12:50:36 -03:00
Agus Zubiaga
79c37284e0 Move ActiveThread into ActiveView::Thread
Co-authored-by: Antonio Scandurra <me@as-cii.com>
2025-06-26 11:36:05 -03:00
Ben Brandt
0a053cf55d Merge branch 'main' into agent2 2025-06-26 14:36:39 +02:00
Ben Brandt
fc59d9cbf3 Clean up tests
Co-authored-by: Agus Zubiaga <agus@zed.dev>
Co-authored-by: Antonio Scandurra <me@as-cii.com>
2025-06-26 14:22:13 +02:00
Ben Brandt
678a42e920 Fix missing variant 2025-06-26 14:00:21 +02:00
Ben Brandt
75bcaf743c Put user messages into thread 2025-06-26 13:59:41 +02:00
Ben Brandt
47c875f6b5 Pass GEMINI_API_KEY to agent process if available 2025-06-26 12:25:23 +02:00
Max Brunsfeld
81b4d7e35a Start on using agent2 from agent_ui 2025-06-25 20:23:41 -07:00
Max Brunsfeld
33ee0c3093 Return an Arc from AcpAgent::stdio 2025-06-25 20:23:18 -07:00
Max Brunsfeld
d68f86052f Merge branch 'main' into agent2 2025-06-25 15:57:59 -07:00
Max Brunsfeld
a74ffd9ee4 In test, start gemini in the right directory
Co-authored-by: Conrad Irwin <conrad.irwin@gmail.com>
2025-06-25 14:59:07 -07:00
Conrad Irwin
8b9ad1cfae passing roundtrip test
Co-authored-by: Ben Brandt <benjamin.j.brandt@gmail.com>
Co-authored-by: Max Brunsfeld <maxbrunsfeld@gmail.com>
Co-authored-by: Agus Zubiaga <agus@zed.dev>
2025-06-25 15:18:42 -06:00
Max Brunsfeld
adbccb1ad0 Get agent2 compiling
Co-authored-by: Conrad Irwin <conrad.irwin@gmail.com>
Co-authored-by: Antonio Scandurra <me@as-cii.com>
2025-06-25 10:30:52 -07:00
Agus Zubiaga
f4e2d38c29 --wip-- 2025-06-25 13:54:31 -03:00
Ben Brandt
5f10be7791 Start implementing send
Co-authored-by: Antonio Scandurra <me@as-cii.com>
Co-authored-by: Agus Zubiaga <agus@zed.dev>
2025-06-25 14:40:33 +02:00
Ben Brandt
d47a920c05 Implement ACP threads
The `create_thread` and `get_threads` methods are now implemented for
the ACP agent. A test is added to verify the file reading flow.
2025-06-25 13:10:43 +02:00
Ben Brandt
24b72be154 Add debug/clone to structs for testing 2025-06-25 10:11:50 +02:00
Max Brunsfeld
de779a45ce Get one test passing w/ gemini cli 2025-06-24 20:07:41 -07:00
Agus Zubiaga
b094a636cf Checkpoint: Wiring up acp crate
Co-authored-by: Conrad Irwin <conrad.irwin@gmail.com> Co-authored-by:
Co-authored-by: Ben Brandt <benjamin.j.brandt@gmail.com>
Co-authored-by: Max <max@zed.dev>
2025-06-24 18:27:25 -03:00
Agus Zubiaga
318709b60d Fix typo
Co-authored-by: Ben Brandt <benjamin.j.brandt@gmail.com>
Co-authored-by: Max Brunsfeld <maxbrunsfeld@gmail.com>
2025-06-24 16:51:43 -03:00
Agus Zubiaga
f1bd531a32 Handle pending requests
Co-authored-by: Ben Brandt <benjamin.j.brandt@gmail.com>
2025-06-24 16:30:29 -03:00
Ben Brandt
549eb4d826 wip: request / response in send loop
Co-authored-by: Antonio Scandurra <me@as-cii.com>
Co-authored-by: Agus Zubiaga <agus@zed.dev>
2025-06-24 14:50:48 +02:00
Ben Brandt
c1e53b7fa5 wip: test
Co-authored-by: Antonio Scandurra <me@as-cii.com>
2025-06-24 12:31:04 +02:00
Ben Brandt
ec376e0b61 Sketch out new Agent traits
Co-authored-by: Antonio Scandurra <me@as-cii.com>
2025-06-24 12:26:40 +02:00
302 changed files with 6514 additions and 8517 deletions

62
Cargo.lock generated
View File

@@ -2,6 +2,36 @@
# It is not intended for manual editing.
version = 4
[[package]]
name = "acp"
version = "0.1.0"
dependencies = [
"agentic-coding-protocol",
"anyhow",
"async-trait",
"base64 0.22.1",
"chrono",
"collections",
"editor",
"env_logger 0.11.8",
"futures 0.3.31",
"gpui",
"language",
"log",
"markdown",
"parking_lot",
"project",
"serde_json",
"settings",
"smol",
"theme",
"ui",
"util",
"uuid",
"workspace-hack",
"zed_actions",
]
[[package]]
name = "activity_indicator"
version = "0.1.0"
@@ -130,6 +160,7 @@ dependencies = [
name = "agent_ui"
version = "0.1.0"
dependencies = [
"acp",
"agent",
"agent_settings",
"anyhow",
@@ -212,6 +243,21 @@ dependencies = [
"zed_llm_client",
]
[[package]]
name = "agentic-coding-protocol"
version = "0.0.1"
dependencies = [
"anyhow",
"async-trait",
"chrono",
"futures 0.3.31",
"log",
"parking_lot",
"schemars",
"serde",
"serde_json",
]
[[package]]
name = "ahash"
version = "0.7.8"
@@ -2076,7 +2122,7 @@ dependencies = [
[[package]]
name = "blade-graphics"
version = "0.6.0"
source = "git+https://github.com/kvark/blade?rev=416375211bb0b5826b3584dccdb6a43369e499ad#416375211bb0b5826b3584dccdb6a43369e499ad"
source = "git+https://github.com/kvark/blade?rev=e0ec4e720957edd51b945b64dd85605ea54bcfe5#e0ec4e720957edd51b945b64dd85605ea54bcfe5"
dependencies = [
"ash",
"ash-window",
@@ -2109,7 +2155,7 @@ dependencies = [
[[package]]
name = "blade-macros"
version = "0.3.0"
source = "git+https://github.com/kvark/blade?rev=416375211bb0b5826b3584dccdb6a43369e499ad#416375211bb0b5826b3584dccdb6a43369e499ad"
source = "git+https://github.com/kvark/blade?rev=e0ec4e720957edd51b945b64dd85605ea54bcfe5#e0ec4e720957edd51b945b64dd85605ea54bcfe5"
dependencies = [
"proc-macro2",
"quote",
@@ -2119,7 +2165,7 @@ dependencies = [
[[package]]
name = "blade-util"
version = "0.2.0"
source = "git+https://github.com/kvark/blade?rev=416375211bb0b5826b3584dccdb6a43369e499ad#416375211bb0b5826b3584dccdb6a43369e499ad"
source = "git+https://github.com/kvark/blade?rev=e0ec4e720957edd51b945b64dd85605ea54bcfe5#e0ec4e720957edd51b945b64dd85605ea54bcfe5"
dependencies = [
"blade-graphics",
"bytemuck",
@@ -4830,7 +4876,6 @@ dependencies = [
"tree-sitter-python",
"tree-sitter-rust",
"tree-sitter-typescript",
"tree-sitter-yaml",
"ui",
"unicode-script",
"unicode-segmentation",
@@ -12259,7 +12304,6 @@ dependencies = [
"language",
"log",
"lsp",
"markdown",
"node_runtime",
"parking_lot",
"pathdiff",
@@ -14058,6 +14102,7 @@ version = "1.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fe8c9d1c68d67dd9f97ecbc6f932b60eb289c5dbddd8aa1405484a8fd2fcd984"
dependencies = [
"chrono",
"dyn-clone",
"indexmap",
"ref-cast",
@@ -14569,7 +14614,6 @@ dependencies = [
name = "settings_ui"
version = "0.1.0"
dependencies = [
"anyhow",
"collections",
"command_palette",
"command_palette_hooks",
@@ -14580,7 +14624,6 @@ dependencies = [
"fs",
"fuzzy",
"gpui",
"language",
"log",
"menu",
"paths",
@@ -14590,8 +14633,6 @@ dependencies = [
"serde",
"settings",
"theme",
"tree-sitter-json",
"tree-sitter-rust",
"ui",
"util",
"workspace",
@@ -17348,7 +17389,6 @@ dependencies = [
"rand 0.8.5",
"regex",
"rust-embed",
"schemars",
"serde",
"serde_json",
"serde_json_lenient",
@@ -19946,7 +19986,7 @@ dependencies = [
[[package]]
name = "zed"
version = "0.195.0"
version = "0.194.0"
dependencies = [
"activity_indicator",
"agent",

View File

@@ -2,6 +2,7 @@
resolver = "2"
members = [
"crates/activity_indicator",
"crates/acp",
"crates/agent_ui",
"crates/agent",
"crates/agent_settings",
@@ -215,8 +216,9 @@ edition = "2024"
# Workspace member crates
#
activity_indicator = { path = "crates/activity_indicator" }
acp = { path = "crates/acp" }
agent = { path = "crates/agent" }
activity_indicator = { path = "crates/activity_indicator" }
agent_ui = { path = "crates/agent_ui" }
agent_settings = { path = "crates/agent_settings" }
ai = { path = "crates/ai" }
@@ -398,6 +400,7 @@ zlog_settings = { path = "crates/zlog_settings" }
# External crates
#
agentic-coding-protocol = { path = "../agentic-coding-protocol" }
aho-corasick = "1.1"
alacritty_terminal = { git = "https://github.com/zed-industries/alacritty.git", branch = "add-hush-login-flag" }
any_vec = "0.14"
@@ -425,9 +428,9 @@ aws-smithy-runtime-api = { version = "1.7.4", features = ["http-1x", "client"] }
aws-smithy-types = { version = "1.3.0", features = ["http-body-1-x"] }
base64 = "0.22"
bitflags = "2.6.0"
blade-graphics = { git = "https://github.com/kvark/blade", rev = "416375211bb0b5826b3584dccdb6a43369e499ad" }
blade-macros = { git = "https://github.com/kvark/blade", rev = "416375211bb0b5826b3584dccdb6a43369e499ad" }
blade-util = { git = "https://github.com/kvark/blade", rev = "416375211bb0b5826b3584dccdb6a43369e499ad" }
blade-graphics = { git = "https://github.com/kvark/blade", rev = "e0ec4e720957edd51b945b64dd85605ea54bcfe5" }
blade-macros = { git = "https://github.com/kvark/blade", rev = "e0ec4e720957edd51b945b64dd85605ea54bcfe5" }
blade-util = { git = "https://github.com/kvark/blade", rev = "e0ec4e720957edd51b945b64dd85605ea54bcfe5" }
blake3 = "1.5.3"
bytes = "1.0"
cargo_metadata = "0.19"
@@ -625,7 +628,7 @@ wasmtime = { version = "29", default-features = false, features = [
wasmtime-wasi = "29"
which = "6.0.0"
workspace-hack = "0.1.0"
zed_llm_client = "= 0.8.5"
zed_llm_client = "0.8.5"
zstd = "0.11"
[workspace.dependencies.async-stripe]

View File

@@ -34,7 +34,7 @@
"ctrl-q": "zed::Quit",
"f4": "debugger::Start",
"shift-f5": "debugger::Stop",
"ctrl-shift-f5": "debugger::RerunSession",
"ctrl-shift-f5": "debugger::Restart",
"f6": "debugger::Pause",
"f7": "debugger::StepOver",
"ctrl-f11": "debugger::StepInto",
@@ -557,13 +557,6 @@
"ctrl-b": "workspace::ToggleLeftDock",
"ctrl-j": "workspace::ToggleBottomDock",
"ctrl-alt-y": "workspace::CloseAllDocks",
"ctrl-alt-0": "workspace::ResetActiveDockSize",
// For 0px parameter, uses UI font size value.
"ctrl-alt--": ["workspace::DecreaseActiveDockSize", { "px": 0 }],
"ctrl-alt-=": ["workspace::IncreaseActiveDockSize", { "px": 0 }],
"ctrl-alt-)": "workspace::ResetOpenDocksSize",
"ctrl-alt-_": ["workspace::DecreaseOpenDocksSize", { "px": 0 }],
"ctrl-alt-+": ["workspace::IncreaseOpenDocksSize", { "px": 0 }],
"shift-find": "pane::DeploySearch",
"ctrl-shift-f": "pane::DeploySearch",
"ctrl-shift-h": ["pane::DeploySearch", { "replace_enabled": true }],
@@ -605,9 +598,7 @@
// "foo-bar": ["task::Spawn", { "task_name": "MyTask", "reveal_target": "dock" }]
// or by tag:
// "foo-bar": ["task::Spawn", { "task_tag": "MyTag" }],
"f5": "debugger::Rerun",
"ctrl-f4": "workspace::CloseActiveDock",
"ctrl-w": "workspace::CloseActiveDock"
"f5": "debugger::RerunLastSession"
}
},
{
@@ -710,13 +701,6 @@
"pagedown": "editor::ContextMenuLast"
}
},
{
"context": "Editor && showing_signature_help && !showing_completions",
"bindings": {
"up": "editor::SignatureHelpPrevious",
"down": "editor::SignatureHelpNext"
}
},
// Custom bindings
{
"bindings": {
@@ -1084,13 +1068,6 @@
"ctrl-shift-tab": "pane::ActivatePreviousItem"
}
},
{
"context": "MarkdownPreview",
"bindings": {
"pageup": "markdown::MovePageUp",
"pagedown": "markdown::MovePageDown"
}
},
{
"context": "KeymapEditor",
"use_key_equivalents": true,

View File

@@ -5,10 +5,10 @@
"bindings": {
"f4": "debugger::Start",
"shift-f5": "debugger::Stop",
"shift-cmd-f5": "debugger::RerunSession",
"shift-cmd-f5": "debugger::Restart",
"f6": "debugger::Pause",
"f7": "debugger::StepOver",
"ctrl-f11": "debugger::StepInto",
"f11": "debugger::StepInto",
"shift-f11": "debugger::StepOut",
"home": "menu::SelectFirst",
"shift-pageup": "menu::SelectFirst",
@@ -624,13 +624,6 @@
"cmd-r": "workspace::ToggleRightDock",
"cmd-j": "workspace::ToggleBottomDock",
"alt-cmd-y": "workspace::CloseAllDocks",
// For 0px parameter, uses UI font size value.
"ctrl-alt-0": "workspace::ResetActiveDockSize",
"ctrl-alt--": ["workspace::DecreaseActiveDockSize", { "px": 0 }],
"ctrl-alt-=": ["workspace::IncreaseActiveDockSize", { "px": 0 }],
"ctrl-alt-)": "workspace::ResetOpenDocksSize",
"ctrl-alt-_": ["workspace::DecreaseOpenDocksSize", { "px": 0 }],
"ctrl-alt-+": ["workspace::IncreaseOpenDocksSize", { "px": 0 }],
"cmd-shift-f": "pane::DeploySearch",
"cmd-shift-h": ["pane::DeploySearch", { "replace_enabled": true }],
"cmd-shift-t": "pane::ReopenClosedItem",
@@ -659,8 +652,7 @@
"cmd-k shift-up": "workspace::SwapPaneUp",
"cmd-k shift-down": "workspace::SwapPaneDown",
"cmd-shift-x": "zed::Extensions",
"f5": "debugger::Rerun",
"cmd-w": "workspace::CloseActiveDock"
"f5": "debugger::RerunLastSession"
}
},
{
@@ -774,13 +766,6 @@
"pagedown": "editor::ContextMenuLast"
}
},
{
"context": "Editor && showing_signature_help && !showing_completions",
"bindings": {
"up": "editor::SignatureHelpPrevious",
"down": "editor::SignatureHelpNext"
}
},
// Custom bindings
{
"use_key_equivalents": true,
@@ -1183,13 +1168,6 @@
"ctrl-shift-tab": "pane::ActivatePreviousItem"
}
},
{
"context": "MarkdownPreview",
"bindings": {
"pageup": "markdown::MovePageUp",
"pagedown": "markdown::MovePageDown"
}
},
{
"context": "KeymapEditor",
"use_key_equivalents": true,

View File

@@ -98,13 +98,6 @@
"ctrl-n": "editor::ContextMenuNext"
}
},
{
"context": "Editor && showing_signature_help && !showing_completions",
"bindings": {
"ctrl-p": "editor::SignatureHelpPrevious",
"ctrl-n": "editor::SignatureHelpNext"
}
},
{
"context": "Workspace",
"bindings": {

View File

@@ -98,13 +98,6 @@
"ctrl-n": "editor::ContextMenuNext"
}
},
{
"context": "Editor && showing_signature_help && !showing_completions",
"bindings": {
"ctrl-p": "editor::SignatureHelpPrevious",
"ctrl-n": "editor::SignatureHelpNext"
}
},
{
"context": "Workspace",
"bindings": {

View File

@@ -477,13 +477,6 @@
"ctrl-n": "editor::ShowWordCompletions"
}
},
{
"context": "vim_mode == insert && showing_signature_help && !showing_completions",
"bindings": {
"ctrl-p": "editor::SignatureHelpPrevious",
"ctrl-n": "editor::SignatureHelpNext"
}
},
{
"context": "vim_mode == replace",
"bindings": {

View File

@@ -746,6 +746,8 @@
"default_width": 380
},
"agent": {
// Version of this setting.
"version": "2",
// Whether the agent is enabled.
"enabled": true,
/// What completion mode to start new threads in, if available. Can be 'normal' or 'burn'.
@@ -1290,8 +1292,6 @@
// Whether or not selecting text in the terminal will automatically
// copy to the system clipboard.
"copy_on_select": false,
// Whether to keep the text selection after copying it to the clipboard
"keep_selection_on_copy": false,
// Whether to show the terminal button in the status bar
"button": true,
// Any key-value pairs added to this list will be added to the terminal's
@@ -1656,6 +1656,7 @@
// Different settings for specific language models.
"language_models": {
"anthropic": {
"version": "1",
"api_url": "https://api.anthropic.com"
},
"google": {
@@ -1665,6 +1666,7 @@
"api_url": "http://localhost:11434"
},
"openai": {
"version": "1",
"api_url": "https://api.openai.com/v1"
},
"open_router": {
@@ -1782,8 +1784,7 @@
// `socks5h`. `http` will be used when no scheme is specified.
//
// By default no proxy will be used, or Zed will try get proxy settings from
// environment variables. If certain hosts should not be proxied,
// set the `no_proxy` environment variable and provide a comma-separated list.
// environment variables.
//
// Examples:
// - "proxy": "socks5h://localhost:10808"

48
crates/acp/Cargo.toml Normal file
View File

@@ -0,0 +1,48 @@
[package]
name = "acp"
version = "0.1.0"
edition.workspace = true
publish.workspace = true
license = "GPL-3.0-or-later"
[lints]
workspace = true
[lib]
path = "src/acp.rs"
doctest = false
[features]
test-support = ["gpui/test-support", "project/test-support"]
[dependencies]
agentic-coding-protocol = { path = "../../../agentic-coding-protocol" }
anyhow.workspace = true
async-trait.workspace = true
base64.workspace = true
chrono.workspace = true
collections.workspace = true
editor.workspace = true
futures.workspace = true
gpui.workspace = true
language.workspace = true
log.workspace = true
markdown.workspace = true
parking_lot.workspace = true
project.workspace = true
settings.workspace = true
smol.workspace = true
theme.workspace = true
ui.workspace = true
util.workspace = true
uuid.workspace = true
workspace-hack.workspace = true
zed_actions.workspace = true
[dev-dependencies]
env_logger.workspace = true
gpui = { workspace = true, "features" = ["test-support"] }
project = { workspace = true, "features" = ["test-support"] }
serde_json.workspace = true
util.workspace = true
settings.workspace = true

1
crates/acp/LICENSE-GPL Symbolic link
View File

@@ -0,0 +1 @@
../../LICENSE-GPL

677
crates/acp/src/acp.rs Normal file
View File

@@ -0,0 +1,677 @@
mod server;
mod thread_view;
use agentic_coding_protocol::{self as acp, Role};
use anyhow::{Context as _, Result};
use chrono::{DateTime, Utc};
use futures::channel::oneshot;
use gpui::{AppContext, Context, Entity, EventEmitter, SharedString, Task};
use language::LanguageRegistry;
use markdown::Markdown;
use project::Project;
use std::{mem, ops::Range, path::PathBuf, sync::Arc};
use ui::{App, IconName};
use util::{ResultExt, debug_panic};
pub use server::AcpServer;
pub use thread_view::AcpThreadView;
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct ThreadId(SharedString);
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub struct FileVersion(u64);
#[derive(Debug)]
pub struct AgentThreadSummary {
pub id: ThreadId,
pub title: String,
pub created_at: DateTime<Utc>,
}
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct FileContent {
pub path: PathBuf,
pub version: FileVersion,
pub content: SharedString,
}
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct Message {
pub role: acp::Role,
pub chunks: Vec<MessageChunk>,
}
impl Message {
fn into_acp(self, cx: &App) -> acp::Message {
acp::Message {
role: self.role,
chunks: self
.chunks
.into_iter()
.map(|chunk| chunk.into_acp(cx))
.collect(),
}
}
}
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum MessageChunk {
Text {
chunk: Entity<Markdown>,
},
File {
content: FileContent,
},
Directory {
path: PathBuf,
contents: Vec<FileContent>,
},
Symbol {
path: PathBuf,
range: Range<u64>,
version: FileVersion,
name: SharedString,
content: SharedString,
},
Fetch {
url: SharedString,
content: SharedString,
},
}
impl MessageChunk {
pub fn from_acp(
chunk: acp::MessageChunk,
language_registry: Arc<LanguageRegistry>,
cx: &mut App,
) -> Self {
match chunk {
acp::MessageChunk::Text { chunk } => MessageChunk::Text {
chunk: cx.new(|cx| Markdown::new(chunk.into(), Some(language_registry), None, cx)),
},
}
}
pub fn into_acp(self, cx: &App) -> acp::MessageChunk {
match self {
MessageChunk::Text { chunk } => acp::MessageChunk::Text {
chunk: chunk.read(cx).source().to_string(),
},
MessageChunk::File { .. } => todo!(),
MessageChunk::Directory { .. } => todo!(),
MessageChunk::Symbol { .. } => todo!(),
MessageChunk::Fetch { .. } => todo!(),
}
}
pub fn from_str(chunk: &str, language_registry: Arc<LanguageRegistry>, cx: &mut App) -> Self {
MessageChunk::Text {
chunk: cx.new(|cx| {
Markdown::new(chunk.to_owned().into(), Some(language_registry), None, cx)
}),
}
}
}
#[derive(Debug)]
pub enum AgentThreadEntryContent {
Message(Message),
ToolCall(ToolCall),
}
#[derive(Debug)]
pub struct ToolCall {
id: ToolCallId,
label: Entity<Markdown>,
icon: IconName,
status: ToolCallStatus,
}
#[derive(Debug)]
pub enum ToolCallStatus {
WaitingForConfirmation {
confirmation: acp::ToolCallConfirmation,
respond_tx: oneshot::Sender<acp::ToolCallConfirmationOutcome>,
},
// todo! Running?
Allowed {
// todo! should this be variants in crate::ToolCallStatus instead?
status: acp::ToolCallStatus,
content: Option<Entity<Markdown>>,
},
Rejected,
}
/// A `ThreadEntryId` that is known to be a ToolCall
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct ToolCallId(ThreadEntryId);
impl ToolCallId {
pub fn as_u64(&self) -> u64 {
self.0.0
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct ThreadEntryId(pub u64);
impl ThreadEntryId {
pub fn post_inc(&mut self) -> Self {
let id = *self;
self.0 += 1;
id
}
}
#[derive(Debug)]
pub struct ThreadEntry {
pub id: ThreadEntryId,
pub content: AgentThreadEntryContent,
}
pub struct AcpThread {
id: ThreadId,
next_entry_id: ThreadEntryId,
entries: Vec<ThreadEntry>,
server: Arc<AcpServer>,
title: SharedString,
project: Entity<Project>,
}
enum AcpThreadEvent {
NewEntry,
EntryUpdated(usize),
}
impl EventEmitter<AcpThreadEvent> for AcpThread {}
impl AcpThread {
pub fn new(
server: Arc<AcpServer>,
thread_id: ThreadId,
entries: Vec<AgentThreadEntryContent>,
project: Entity<Project>,
_: &mut Context<Self>,
) -> Self {
let mut next_entry_id = ThreadEntryId(0);
Self {
title: "A new agent2 thread".into(),
entries: entries
.into_iter()
.map(|entry| ThreadEntry {
id: next_entry_id.post_inc(),
content: entry,
})
.collect(),
server,
id: thread_id,
next_entry_id,
project,
}
}
pub fn title(&self) -> SharedString {
self.title.clone()
}
pub fn entries(&self) -> &[ThreadEntry] {
&self.entries
}
pub fn push_entry(
&mut self,
entry: AgentThreadEntryContent,
cx: &mut Context<Self>,
) -> ThreadEntryId {
let id = self.next_entry_id.post_inc();
self.entries.push(ThreadEntry { id, content: entry });
cx.emit(AcpThreadEvent::NewEntry);
id
}
pub fn push_assistant_chunk(&mut self, chunk: acp::MessageChunk, cx: &mut Context<Self>) {
let entries_len = self.entries.len();
if let Some(last_entry) = self.entries.last_mut()
&& let AgentThreadEntryContent::Message(Message {
ref mut chunks,
role: Role::Assistant,
}) = last_entry.content
{
cx.emit(AcpThreadEvent::EntryUpdated(entries_len - 1));
if let (
Some(MessageChunk::Text { chunk: old_chunk }),
acp::MessageChunk::Text { chunk: new_chunk },
) = (chunks.last_mut(), &chunk)
{
old_chunk.update(cx, |old_chunk, cx| {
old_chunk.append(&new_chunk, cx);
});
} else {
chunks.push(MessageChunk::from_acp(
chunk,
self.project.read(cx).languages().clone(),
cx,
));
}
return;
}
let chunk = MessageChunk::from_acp(chunk, self.project.read(cx).languages().clone(), cx);
self.push_entry(
AgentThreadEntryContent::Message(Message {
role: Role::Assistant,
chunks: vec![chunk],
}),
cx,
);
}
pub fn request_tool_call(
&mut self,
label: String,
icon: acp::Icon,
confirmation: acp::ToolCallConfirmation,
cx: &mut Context<Self>,
) -> ToolCallRequest {
let (tx, rx) = oneshot::channel();
let status = ToolCallStatus::WaitingForConfirmation {
confirmation,
respond_tx: tx,
};
let id = self.insert_tool_call(label, status, icon, cx);
ToolCallRequest { id, outcome: rx }
}
pub fn push_tool_call(
&mut self,
label: String,
icon: acp::Icon,
cx: &mut Context<Self>,
) -> ToolCallId {
let status = ToolCallStatus::Allowed {
status: acp::ToolCallStatus::Running,
content: None,
};
self.insert_tool_call(label, status, icon, cx)
}
fn insert_tool_call(
&mut self,
label: String,
status: ToolCallStatus,
icon: acp::Icon,
cx: &mut Context<Self>,
) -> ToolCallId {
let language_registry = self.project.read(cx).languages().clone();
let entry_id = self.push_entry(
AgentThreadEntryContent::ToolCall(ToolCall {
// todo! clean up id creation
id: ToolCallId(ThreadEntryId(self.entries.len() as u64)),
label: cx.new(|cx| {
Markdown::new(label.into(), Some(language_registry.clone()), None, cx)
}),
icon: acp_icon_to_ui_icon(icon),
status,
}),
cx,
);
ToolCallId(entry_id)
}
pub fn authorize_tool_call(
&mut self,
id: ToolCallId,
outcome: acp::ToolCallConfirmationOutcome,
cx: &mut Context<Self>,
) {
let Some(entry) = self.entry_mut(id.0) else {
return;
};
let AgentThreadEntryContent::ToolCall(call) = &mut entry.content else {
debug_panic!("expected ToolCall");
return;
};
let new_status = if outcome == acp::ToolCallConfirmationOutcome::Reject {
ToolCallStatus::Rejected
} else {
ToolCallStatus::Allowed {
status: acp::ToolCallStatus::Running,
content: None,
}
};
let curr_status = mem::replace(&mut call.status, new_status);
if let ToolCallStatus::WaitingForConfirmation { respond_tx, .. } = curr_status {
respond_tx.send(outcome).log_err();
} else {
debug_panic!("tried to authorize an already authorized tool call");
}
cx.emit(AcpThreadEvent::EntryUpdated(id.as_u64() as usize));
}
pub fn update_tool_call(
&mut self,
id: ToolCallId,
new_status: acp::ToolCallStatus,
new_content: Option<acp::ToolCallContent>,
cx: &mut Context<Self>,
) -> Result<()> {
let language_registry = self.project.read(cx).languages().clone();
let entry = self.entry_mut(id.0).context("Entry not found")?;
match &mut entry.content {
AgentThreadEntryContent::ToolCall(call) => match &mut call.status {
ToolCallStatus::Allowed { content, status } => {
*content = new_content.map(|new_content| {
let acp::ToolCallContent::Markdown { markdown } = new_content;
cx.new(|cx| {
Markdown::new(markdown.into(), Some(language_registry), None, cx)
})
});
*status = new_status;
}
ToolCallStatus::WaitingForConfirmation { .. } => {
anyhow::bail!("Tool call hasn't been authorized yet")
}
ToolCallStatus::Rejected => {
anyhow::bail!("Tool call was rejected and therefore can't be updated")
}
},
_ => anyhow::bail!("Entry is not a tool call"),
}
cx.emit(AcpThreadEvent::EntryUpdated(id.as_u64() as usize));
Ok(())
}
fn entry_mut(&mut self, id: ThreadEntryId) -> Option<&mut ThreadEntry> {
let entry = self.entries.get_mut(id.0 as usize);
debug_assert!(
entry.is_some(),
"We shouldn't give out ids to entries that don't exist"
);
entry
}
/// Returns true if the last turn is awaiting tool authorization
pub fn waiting_for_tool_confirmation(&self) -> bool {
for entry in self.entries.iter().rev() {
match &entry.content {
AgentThreadEntryContent::ToolCall(call) => match call.status {
ToolCallStatus::WaitingForConfirmation { .. } => return true,
ToolCallStatus::Allowed { .. } | ToolCallStatus::Rejected => continue,
},
AgentThreadEntryContent::Message(_) => {
// Reached the beginning of the turn
return false;
}
}
}
false
}
pub fn send(&mut self, message: &str, cx: &mut Context<Self>) -> Task<Result<()>> {
let agent = self.server.clone();
let id = self.id.clone();
let chunk = MessageChunk::from_str(message, self.project.read(cx).languages().clone(), cx);
let message = Message {
role: Role::User,
chunks: vec![chunk],
};
self.push_entry(AgentThreadEntryContent::Message(message.clone()), cx);
let acp_message = message.into_acp(cx);
cx.spawn(async move |_, cx| {
agent.send_message(id, acp_message, cx).await?;
Ok(())
})
}
}
fn acp_icon_to_ui_icon(icon: acp::Icon) -> IconName {
match icon {
acp::Icon::FileSearch => IconName::FileSearch,
acp::Icon::Folder => IconName::Folder,
acp::Icon::Globe => IconName::Globe,
acp::Icon::Hammer => IconName::Hammer,
acp::Icon::LightBulb => IconName::LightBulb,
acp::Icon::Pencil => IconName::Pencil,
acp::Icon::Regex => IconName::Regex,
acp::Icon::Terminal => IconName::Terminal,
}
}
pub struct ToolCallRequest {
pub id: ToolCallId,
pub outcome: oneshot::Receiver<acp::ToolCallConfirmationOutcome>,
}
#[cfg(test)]
mod tests {
use super::*;
use futures::{FutureExt as _, channel::mpsc, select};
use gpui::{AsyncApp, TestAppContext};
use project::FakeFs;
use serde_json::json;
use settings::SettingsStore;
use smol::stream::StreamExt as _;
use std::{env, path::Path, process::Stdio, time::Duration};
use util::path;
fn init_test(cx: &mut TestAppContext) {
env_logger::try_init().ok();
cx.update(|cx| {
let settings_store = SettingsStore::test(cx);
cx.set_global(settings_store);
Project::init_settings(cx);
language::init(cx);
});
}
#[gpui::test]
async fn test_gemini_basic(cx: &mut TestAppContext) {
init_test(cx);
cx.executor().allow_parking();
let fs = FakeFs::new(cx.executor());
let project = Project::test(fs, [], cx).await;
let server = gemini_acp_server(project.clone(), cx.to_async()).unwrap();
let thread = server.create_thread(&mut cx.to_async()).await.unwrap();
thread
.update(cx, |thread, cx| thread.send("Hello from Zed!", cx))
.await
.unwrap();
thread.read_with(cx, |thread, _| {
assert_eq!(thread.entries.len(), 2);
assert!(matches!(
thread.entries[0].content,
AgentThreadEntryContent::Message(Message {
role: Role::User,
..
})
));
assert!(matches!(
thread.entries[1].content,
AgentThreadEntryContent::Message(Message {
role: Role::Assistant,
..
})
));
});
}
#[gpui::test]
async fn test_gemini_tool_call(cx: &mut TestAppContext) {
init_test(cx);
cx.executor().allow_parking();
let fs = FakeFs::new(cx.executor());
fs.insert_tree(
path!("/private/tmp"),
json!({"foo": "Lorem ipsum dolor", "bar": "bar", "baz": "baz"}),
)
.await;
let project = Project::test(fs, [path!("/private/tmp").as_ref()], cx).await;
let server = gemini_acp_server(project.clone(), cx.to_async()).unwrap();
let thread = server.create_thread(&mut cx.to_async()).await.unwrap();
thread
.update(cx, |thread, cx| {
thread.send(
"Read the '/private/tmp/foo' file and tell me what you see.",
cx,
)
})
.await
.unwrap();
thread.read_with(cx, |thread, _cx| {
assert!(matches!(
&thread.entries()[1].content,
AgentThreadEntryContent::ToolCall(ToolCall {
status: ToolCallStatus::Allowed { .. },
..
})
));
assert!(matches!(
thread.entries[2].content,
AgentThreadEntryContent::Message(Message {
role: Role::Assistant,
..
})
));
});
}
#[gpui::test]
async fn test_gemini_tool_call_with_confirmation(cx: &mut TestAppContext) {
init_test(cx);
cx.executor().allow_parking();
let fs = FakeFs::new(cx.executor());
let project = Project::test(fs, [path!("/private/tmp").as_ref()], cx).await;
let server = gemini_acp_server(project.clone(), cx.to_async()).unwrap();
let thread = server.create_thread(&mut cx.to_async()).await.unwrap();
let full_turn = thread.update(cx, |thread, cx| {
thread.send(r#"Run `echo "Hello, world!"`"#, cx)
});
run_until_tool_call(&thread, cx).await;
let tool_call_id = thread.read_with(cx, |thread, _cx| {
let AgentThreadEntryContent::ToolCall(ToolCall {
id,
status:
ToolCallStatus::WaitingForConfirmation {
confirmation: acp::ToolCallConfirmation::Execute { root_command, .. },
..
},
..
}) = &thread.entries()[1].content
else {
panic!();
};
assert_eq!(root_command, "echo");
*id
});
thread.update(cx, |thread, cx| {
thread.authorize_tool_call(tool_call_id, acp::ToolCallConfirmationOutcome::Allow, cx);
assert!(matches!(
&thread.entries()[1].content,
AgentThreadEntryContent::ToolCall(ToolCall {
status: ToolCallStatus::Allowed { .. },
..
})
));
});
full_turn.await.unwrap();
thread.read_with(cx, |thread, cx| {
let AgentThreadEntryContent::ToolCall(ToolCall {
status: ToolCallStatus::Allowed { content, .. },
..
}) = &thread.entries()[1].content
else {
panic!();
};
content.as_ref().unwrap().read_with(cx, |md, _cx| {
assert!(
md.source().contains("Hello, world!"),
r#"Expected '{}' to contain "Hello, world!""#,
md.source()
);
});
});
}
async fn run_until_tool_call(thread: &Entity<AcpThread>, cx: &mut TestAppContext) {
let (mut tx, mut rx) = mpsc::channel::<()>(1);
let subscription = cx.update(|cx| {
cx.subscribe(thread, move |thread, _, cx| {
if thread
.read(cx)
.entries
.iter()
.any(|e| matches!(e.content, AgentThreadEntryContent::ToolCall(_)))
{
tx.try_send(()).unwrap();
}
})
});
select! {
_ = futures::FutureExt::fuse(smol::Timer::after(Duration::from_secs(10))) => {
panic!("Timeout waiting for tool call")
}
_ = rx.next().fuse() => {
drop(subscription);
}
}
}
pub fn gemini_acp_server(project: Entity<Project>, mut cx: AsyncApp) -> Result<Arc<AcpServer>> {
let cli_path =
Path::new(env!("CARGO_MANIFEST_DIR")).join("../../../gemini-cli/packages/cli");
let mut command = util::command::new_smol_command("node");
command
.arg(cli_path)
.arg("--acp")
.current_dir("/private/tmp")
.stdin(Stdio::piped())
.stdout(Stdio::piped())
.stderr(Stdio::inherit())
.kill_on_drop(true);
if let Ok(gemini_key) = std::env::var("GEMINI_API_KEY") {
command.env("GEMINI_API_KEY", gemini_key);
}
let child = command.spawn().unwrap();
Ok(AcpServer::stdio(child, project, &mut cx))
}
}

322
crates/acp/src/server.rs Normal file
View File

@@ -0,0 +1,322 @@
use crate::{AcpThread, ThreadEntryId, ThreadId, ToolCallId, ToolCallRequest};
use agentic_coding_protocol as acp;
use anyhow::{Context as _, Result};
use async_trait::async_trait;
use collections::HashMap;
use gpui::{App, AppContext, AsyncApp, Context, Entity, Task, WeakEntity};
use parking_lot::Mutex;
use project::Project;
use smol::process::Child;
use std::{io::Write as _, path::Path, sync::Arc};
use util::ResultExt;
pub struct AcpServer {
connection: Arc<acp::AgentConnection>,
threads: Arc<Mutex<HashMap<ThreadId, WeakEntity<AcpThread>>>>,
project: Entity<Project>,
_handler_task: Task<()>,
_io_task: Task<()>,
}
struct AcpClientDelegate {
project: Entity<Project>,
threads: Arc<Mutex<HashMap<ThreadId, WeakEntity<AcpThread>>>>,
cx: AsyncApp,
// sent_buffer_versions: HashMap<Entity<Buffer>, HashMap<u64, BufferSnapshot>>,
}
impl AcpClientDelegate {
fn new(
project: Entity<Project>,
threads: Arc<Mutex<HashMap<ThreadId, WeakEntity<AcpThread>>>>,
cx: AsyncApp,
) -> Self {
Self {
project,
threads,
cx: cx,
}
}
fn update_thread<R>(
&self,
thread_id: &ThreadId,
cx: &mut App,
callback: impl FnOnce(&mut AcpThread, &mut Context<AcpThread>) -> R,
) -> Option<R> {
let thread = self.threads.lock().get(&thread_id)?.clone();
let Some(thread) = thread.upgrade() else {
self.threads.lock().remove(&thread_id);
return None;
};
Some(thread.update(cx, callback))
}
}
#[async_trait(?Send)]
impl acp::Client for AcpClientDelegate {
async fn stat(&self, params: acp::StatParams) -> Result<acp::StatResponse> {
let cx = &mut self.cx.clone();
self.project.update(cx, |project, cx| {
let path = project
.project_path_for_absolute_path(Path::new(&params.path), cx)
.context("Failed to get project path")?;
match project.entry_for_path(&path, cx) {
// todo! refresh entry?
None => Ok(acp::StatResponse {
exists: false,
is_directory: false,
}),
Some(entry) => Ok(acp::StatResponse {
exists: entry.is_created(),
is_directory: entry.is_dir(),
}),
}
})?
}
async fn stream_message_chunk(
&self,
params: acp::StreamMessageChunkParams,
) -> Result<acp::StreamMessageChunkResponse> {
let cx = &mut self.cx.clone();
cx.update(|cx| {
self.update_thread(&params.thread_id.into(), cx, |thread, cx| {
thread.push_assistant_chunk(params.chunk, cx)
});
})?;
Ok(acp::StreamMessageChunkResponse)
}
async fn read_text_file(
&self,
request: acp::ReadTextFileParams,
) -> Result<acp::ReadTextFileResponse> {
let cx = &mut self.cx.clone();
let buffer = self
.project
.update(cx, |project, cx| {
let path = project
.project_path_for_absolute_path(Path::new(&request.path), cx)
.context("Failed to get project path")?;
anyhow::Ok(project.open_buffer(path, cx))
})??
.await?;
buffer.update(cx, |buffer, _cx| {
let start = language::Point::new(request.line_offset.unwrap_or(0), 0);
let end = match request.line_limit {
None => buffer.max_point(),
Some(limit) => start + language::Point::new(limit + 1, 0),
};
let content: String = buffer.text_for_range(start..end).collect();
acp::ReadTextFileResponse {
content,
version: acp::FileVersion(0),
}
})
}
async fn read_binary_file(
&self,
request: acp::ReadBinaryFileParams,
) -> Result<acp::ReadBinaryFileResponse> {
let cx = &mut self.cx.clone();
let file = self
.project
.update(cx, |project, cx| {
let (worktree, path) = project
.find_worktree(Path::new(&request.path), cx)
.context("Failed to get project path")?;
let task = worktree.update(cx, |worktree, cx| worktree.load_binary_file(&path, cx));
anyhow::Ok(task)
})??
.await?;
// todo! test
let content = cx
.background_spawn(async move {
let start = request.byte_offset.unwrap_or(0) as usize;
let end = request
.byte_limit
.map(|limit| (start + limit as usize).min(file.content.len()))
.unwrap_or(file.content.len());
let range_content = &file.content[start..end];
let mut base64_content = Vec::new();
let mut base64_encoder = base64::write::EncoderWriter::new(
std::io::Cursor::new(&mut base64_content),
&base64::engine::general_purpose::STANDARD,
);
base64_encoder.write_all(range_content)?;
drop(base64_encoder);
// SAFETY: The base64 encoder should not produce non-UTF8.
unsafe { anyhow::Ok(String::from_utf8_unchecked(base64_content)) }
})
.await?;
Ok(acp::ReadBinaryFileResponse {
content,
// todo!
version: acp::FileVersion(0),
})
}
async fn glob_search(
&self,
_request: acp::GlobSearchParams,
) -> Result<acp::GlobSearchResponse> {
todo!()
}
async fn request_tool_call_confirmation(
&self,
request: acp::RequestToolCallConfirmationParams,
) -> Result<acp::RequestToolCallConfirmationResponse> {
let cx = &mut self.cx.clone();
let ToolCallRequest { id, outcome } = cx
.update(|cx| {
self.update_thread(&request.thread_id.into(), cx, |thread, cx| {
thread.request_tool_call(request.label, request.icon, request.confirmation, cx)
})
})?
.context("Failed to update thread")?;
Ok(acp::RequestToolCallConfirmationResponse {
id: id.into(),
outcome: outcome.await?,
})
}
async fn push_tool_call(
&self,
request: acp::PushToolCallParams,
) -> Result<acp::PushToolCallResponse> {
let cx = &mut self.cx.clone();
let entry_id = cx
.update(|cx| {
self.update_thread(&request.thread_id.into(), cx, |thread, cx| {
thread.push_tool_call(request.label, request.icon, cx)
})
})?
.context("Failed to update thread")?;
Ok(acp::PushToolCallResponse {
id: entry_id.into(),
})
}
async fn update_tool_call(
&self,
request: acp::UpdateToolCallParams,
) -> Result<acp::UpdateToolCallResponse> {
let cx = &mut self.cx.clone();
cx.update(|cx| {
self.update_thread(&request.thread_id.into(), cx, |thread, cx| {
thread.update_tool_call(
request.tool_call_id.into(),
request.status,
request.content,
cx,
)
})
})?
.context("Failed to update thread")??;
Ok(acp::UpdateToolCallResponse)
}
}
impl AcpServer {
pub fn stdio(mut process: Child, project: Entity<Project>, cx: &mut AsyncApp) -> Arc<Self> {
let stdin = process.stdin.take().expect("process didn't have stdin");
let stdout = process.stdout.take().expect("process didn't have stdout");
let threads: Arc<Mutex<HashMap<ThreadId, WeakEntity<AcpThread>>>> = Default::default();
let (connection, handler_fut, io_fut) = acp::AgentConnection::connect_to_agent(
AcpClientDelegate::new(project.clone(), threads.clone(), cx.clone()),
stdin,
stdout,
);
let io_task = cx.background_spawn(async move {
io_fut.await.log_err();
process.status().await.log_err();
});
Arc::new(Self {
project,
connection: Arc::new(connection),
threads,
_handler_task: cx.foreground_executor().spawn(handler_fut),
_io_task: io_task,
})
}
}
impl AcpServer {
pub async fn create_thread(self: Arc<Self>, cx: &mut AsyncApp) -> Result<Entity<AcpThread>> {
let response = self.connection.request(acp::CreateThreadParams).await?;
let thread_id: ThreadId = response.thread_id.into();
let server = self.clone();
let thread = cx.new(|_| AcpThread {
// todo!
title: "ACP Thread".into(),
id: thread_id.clone(),
next_entry_id: ThreadEntryId(0),
entries: Vec::default(),
project: self.project.clone(),
server,
})?;
self.threads.lock().insert(thread_id, thread.downgrade());
Ok(thread)
}
pub async fn send_message(
&self,
thread_id: ThreadId,
message: acp::Message,
_cx: &mut AsyncApp,
) -> Result<()> {
self.connection
.request(acp::SendMessageParams {
thread_id: thread_id.clone().into(),
message,
})
.await?;
Ok(())
}
}
impl From<acp::ThreadId> for ThreadId {
fn from(thread_id: acp::ThreadId) -> Self {
Self(thread_id.0.into())
}
}
impl From<ThreadId> for acp::ThreadId {
fn from(thread_id: ThreadId) -> Self {
acp::ThreadId(thread_id.0.to_string())
}
}
impl From<acp::ToolCallId> for ToolCallId {
fn from(tool_call_id: acp::ToolCallId) -> Self {
Self(ThreadEntryId(tool_call_id.0))
}
}
impl From<ToolCallId> for acp::ToolCallId {
fn from(tool_call_id: ToolCallId) -> Self {
acp::ToolCallId(tool_call_id.as_u64())
}
}

View File

@@ -0,0 +1,935 @@
use std::path::{Path, PathBuf};
use std::rc::Rc;
use std::sync::Arc;
use std::time::Duration;
use agentic_coding_protocol::{self as acp, ToolCallConfirmation};
use anyhow::Result;
use editor::{Editor, MultiBuffer};
use gpui::{
Animation, AnimationExt, App, EdgesRefinement, Empty, Entity, Focusable, ListState,
SharedString, StyleRefinement, Subscription, TextStyleRefinement, Transformation,
UnderlineStyle, Window, div, list, percentage, prelude::*,
};
use gpui::{FocusHandle, Task};
use language::Buffer;
use markdown::{HeadingLevelStyles, MarkdownElement, MarkdownStyle};
use project::Project;
use settings::Settings as _;
use theme::ThemeSettings;
use ui::prelude::*;
use ui::{Button, Tooltip};
use util::ResultExt;
use zed_actions::agent::Chat;
use crate::{
AcpServer, AcpThread, AcpThreadEvent, AgentThreadEntryContent, MessageChunk, Role, ThreadEntry,
ToolCall, ToolCallId, ToolCallStatus,
};
pub struct AcpThreadView {
thread_state: ThreadState,
// todo! use full message editor from agent2
message_editor: Entity<Editor>,
list_state: ListState,
send_task: Option<Task<Result<()>>>,
root: Arc<Path>,
}
enum ThreadState {
Loading {
_task: Task<()>,
},
Ready {
thread: Entity<AcpThread>,
_subscription: Subscription,
},
LoadError(SharedString),
}
impl AcpThreadView {
pub fn new(project: Entity<Project>, window: &mut Window, cx: &mut Context<Self>) -> Self {
// todo!(): This should probably be contextual, like the terminal
let Some(root_dir) = project
.read(cx)
.visible_worktrees(cx)
.next()
.map(|worktree| worktree.read(cx).abs_path())
else {
todo!();
};
let cli_path =
Path::new(env!("CARGO_MANIFEST_DIR")).join("../../../gemini-cli/packages/cli");
let child = util::command::new_smol_command("node")
.arg(cli_path)
.arg("--acp")
.current_dir(&root_dir)
.stdin(std::process::Stdio::piped())
.stdout(std::process::Stdio::piped())
.stderr(std::process::Stdio::inherit())
.kill_on_drop(true)
.spawn()
.unwrap();
let message_editor = cx.new(|cx| {
let buffer = cx.new(|cx| Buffer::local("", cx));
let buffer = cx.new(|cx| MultiBuffer::singleton(buffer, cx));
let mut editor = Editor::new(
editor::EditorMode::AutoHeight {
min_lines: 4,
max_lines: None,
},
buffer,
None,
window,
cx,
);
editor.set_placeholder_text("Send a message", cx);
editor.set_soft_wrap();
editor
});
let project = project.clone();
let load_task = cx.spawn_in(window, async move |this, cx| {
let agent = AcpServer::stdio(child, project, cx);
let result = agent.create_thread(cx).await;
this.update(cx, |this, cx| {
match result {
Ok(thread) => {
let subscription = cx.subscribe(&thread, |this, _, event, cx| {
let count = this.list_state.item_count();
match event {
AcpThreadEvent::NewEntry => {
this.list_state.splice(count..count, 1);
}
AcpThreadEvent::EntryUpdated(index) => {
this.list_state.splice(*index..*index + 1, 1);
}
}
cx.notify();
});
this.list_state
.splice(0..0, thread.read(cx).entries().len());
this.thread_state = ThreadState::Ready {
thread,
_subscription: subscription,
};
}
Err(e) => this.thread_state = ThreadState::LoadError(e.to_string().into()),
};
cx.notify();
})
.log_err();
});
let list_state = ListState::new(
0,
gpui::ListAlignment::Bottom,
px(2048.0),
cx.processor({
move |this: &mut Self, item: usize, window, cx| {
let Some(entry) = this
.thread()
.and_then(|thread| thread.read(cx).entries.get(item))
else {
return Empty.into_any();
};
this.render_entry(entry, window, cx)
}
}),
);
Self {
thread_state: ThreadState::Loading { _task: load_task },
message_editor,
send_task: None,
list_state: list_state,
root: root_dir,
}
}
fn thread(&self) -> Option<&Entity<AcpThread>> {
match &self.thread_state {
ThreadState::Ready { thread, .. } => Some(thread),
ThreadState::Loading { .. } | ThreadState::LoadError(..) => None,
}
}
pub fn title(&self, cx: &App) -> SharedString {
match &self.thread_state {
ThreadState::Ready { thread, .. } => thread.read(cx).title(),
ThreadState::Loading { .. } => "Loading...".into(),
ThreadState::LoadError(_) => "Failed to load".into(),
}
}
pub fn cancel(&mut self) {
self.send_task.take();
}
fn chat(&mut self, _: &Chat, window: &mut Window, cx: &mut Context<Self>) {
let text = self.message_editor.read(cx).text(cx);
if text.is_empty() {
return;
}
let Some(thread) = self.thread() else { return };
let task = thread.update(cx, |thread, cx| thread.send(&text, cx));
self.send_task = Some(cx.spawn(async move |this, cx| {
task.await?;
this.update(cx, |this, _cx| {
this.send_task.take();
})
}));
self.message_editor.update(cx, |editor, cx| {
editor.clear(window, cx);
});
}
fn authorize_tool_call(
&mut self,
id: ToolCallId,
outcome: acp::ToolCallConfirmationOutcome,
cx: &mut Context<Self>,
) {
let Some(thread) = self.thread() else {
return;
};
thread.update(cx, |thread, cx| {
thread.authorize_tool_call(id, outcome, cx);
});
cx.notify();
}
fn render_entry(
&self,
entry: &ThreadEntry,
window: &mut Window,
cx: &Context<Self>,
) -> AnyElement {
match &entry.content {
AgentThreadEntryContent::Message(message) => {
let style = if message.role == Role::User {
user_message_markdown_style(window, cx)
} else {
default_markdown_style(window, cx)
};
let message_body = div()
.children(message.chunks.iter().map(|chunk| match chunk {
MessageChunk::Text { chunk } => {
// todo!() open link
MarkdownElement::new(chunk.clone(), style.clone())
}
_ => todo!(),
}))
.into_any();
match message.role {
Role::User => div()
.p_2()
.pt_5()
.child(
div()
.text_xs()
.p_3()
.bg(cx.theme().colors().editor_background)
.rounded_lg()
.shadow_md()
.border_1()
.border_color(cx.theme().colors().border)
.child(message_body),
)
.into_any(),
Role::Assistant => div()
.text_ui(cx)
.p_5()
.pt_2()
.child(message_body)
.into_any(),
}
}
AgentThreadEntryContent::ToolCall(tool_call) => div()
.px_2()
.py_4()
.child(self.render_tool_call(tool_call, window, cx))
.into_any(),
}
}
fn render_tool_call(&self, tool_call: &ToolCall, window: &Window, cx: &Context<Self>) -> Div {
let status_icon = match &tool_call.status {
ToolCallStatus::WaitingForConfirmation { .. } => Empty.into_element().into_any(),
ToolCallStatus::Allowed {
status: acp::ToolCallStatus::Running,
..
} => Icon::new(IconName::ArrowCircle)
.color(Color::Success)
.size(IconSize::Small)
.with_animation(
"running",
Animation::new(Duration::from_secs(2)).repeat(),
|icon, delta| icon.transform(Transformation::rotate(percentage(delta))),
)
.into_any_element(),
ToolCallStatus::Allowed {
status: acp::ToolCallStatus::Finished,
..
} => Icon::new(IconName::Check)
.color(Color::Success)
.size(IconSize::Small)
.into_any_element(),
ToolCallStatus::Rejected
| ToolCallStatus::Allowed {
status: acp::ToolCallStatus::Error,
..
} => Icon::new(IconName::X)
.color(Color::Error)
.size(IconSize::Small)
.into_any_element(),
};
let content = match &tool_call.status {
ToolCallStatus::WaitingForConfirmation { confirmation, .. } => {
Some(self.render_tool_call_confirmation(tool_call.id, confirmation, cx))
}
ToolCallStatus::Allowed { content, .. } => content.clone().map(|content| {
div()
.border_color(cx.theme().colors().border)
.border_t_1()
.px_2()
.py_1p5()
.child(MarkdownElement::new(
content,
default_markdown_style(window, cx),
))
.into_any_element()
}),
ToolCallStatus::Rejected => None,
};
v_flex()
.text_xs()
.rounded_md()
.border_1()
.border_color(cx.theme().colors().border)
.bg(cx.theme().colors().editor_background)
.child(
h_flex()
.px_2()
.py_1p5()
.w_full()
.gap_1p5()
.child(
Icon::new(tool_call.icon.into())
.size(IconSize::Small)
.color(Color::Muted),
)
// todo! danilo please help
.child(MarkdownElement::new(
tool_call.label.clone(),
default_markdown_style(window, cx),
))
.child(div().w_full())
.child(status_icon),
)
.children(content)
}
fn render_tool_call_confirmation(
&self,
tool_call_id: ToolCallId,
confirmation: &ToolCallConfirmation,
cx: &Context<Self>,
) -> AnyElement {
match confirmation {
ToolCallConfirmation::Edit {
file_name,
file_diff,
description,
} => v_flex()
.border_color(cx.theme().colors().border)
.border_t_1()
.px_2()
.py_1p5()
// todo! nicer rendering
.child(file_name.clone())
.child(file_diff.clone())
.children(description.clone())
.child(
h_flex()
.justify_end()
.gap_1()
.child(
Button::new(
("always_allow", tool_call_id.as_u64()),
"Always Allow Edits",
)
.icon(IconName::CheckDouble)
.icon_position(IconPosition::Start)
.icon_size(IconSize::Small)
.icon_color(Color::Success)
.on_click(cx.listener({
let id = tool_call_id;
move |this, _, _, cx| {
this.authorize_tool_call(
id,
acp::ToolCallConfirmationOutcome::AlwaysAllow,
cx,
);
}
})),
)
.child(
Button::new(("allow", tool_call_id.as_u64()), "Allow")
.icon(IconName::Check)
.icon_position(IconPosition::Start)
.icon_size(IconSize::Small)
.icon_color(Color::Success)
.on_click(cx.listener({
let id = tool_call_id;
move |this, _, _, cx| {
this.authorize_tool_call(
id,
acp::ToolCallConfirmationOutcome::Allow,
cx,
);
}
})),
)
.child(
Button::new(("reject", tool_call_id.as_u64()), "Reject")
.icon(IconName::X)
.icon_position(IconPosition::Start)
.icon_size(IconSize::Small)
.icon_color(Color::Error)
.on_click(cx.listener({
let id = tool_call_id;
move |this, _, _, cx| {
this.authorize_tool_call(
id,
acp::ToolCallConfirmationOutcome::Reject,
cx,
);
}
})),
),
)
.into_any(),
ToolCallConfirmation::Execute {
command,
root_command,
description,
} => v_flex()
.border_color(cx.theme().colors().border)
.border_t_1()
.px_2()
.py_1p5()
// todo! nicer rendering
.child(command.clone())
.children(description.clone())
.child(
h_flex()
.justify_end()
.gap_1()
.child(
Button::new(
("always_allow", tool_call_id.as_u64()),
format!("Always Allow {root_command}"),
)
.icon(IconName::CheckDouble)
.icon_position(IconPosition::Start)
.icon_size(IconSize::Small)
.icon_color(Color::Success)
.on_click(cx.listener({
let id = tool_call_id;
move |this, _, _, cx| {
this.authorize_tool_call(
id,
acp::ToolCallConfirmationOutcome::AlwaysAllow,
cx,
);
}
})),
)
.child(
Button::new(("allow", tool_call_id.as_u64()), "Allow")
.icon(IconName::Check)
.icon_position(IconPosition::Start)
.icon_size(IconSize::Small)
.icon_color(Color::Success)
.on_click(cx.listener({
let id = tool_call_id;
move |this, _, _, cx| {
this.authorize_tool_call(
id,
acp::ToolCallConfirmationOutcome::Allow,
cx,
);
}
})),
)
.child(
Button::new(("reject", tool_call_id.as_u64()), "Reject")
.icon(IconName::X)
.icon_position(IconPosition::Start)
.icon_size(IconSize::Small)
.icon_color(Color::Error)
.on_click(cx.listener({
let id = tool_call_id;
move |this, _, _, cx| {
this.authorize_tool_call(
id,
acp::ToolCallConfirmationOutcome::Reject,
cx,
);
}
})),
),
)
.into_any(),
ToolCallConfirmation::Mcp {
server_name,
tool_name: _,
tool_display_name,
description,
} => v_flex()
.border_color(cx.theme().colors().border)
.border_t_1()
.px_2()
.py_1p5()
// todo! nicer rendering
.child(format!("{server_name} - {tool_display_name}"))
.children(description.clone())
.child(
h_flex()
.justify_end()
.gap_1()
.child(
Button::new(
("always_allow_server", tool_call_id.as_u64()),
format!("Always Allow {server_name}"),
)
.icon(IconName::CheckDouble)
.icon_position(IconPosition::Start)
.icon_size(IconSize::Small)
.icon_color(Color::Success)
.on_click(cx.listener({
let id = tool_call_id;
move |this, _, _, cx| {
this.authorize_tool_call(
id,
acp::ToolCallConfirmationOutcome::AlwaysAllowMcpServer,
cx,
);
}
})),
)
.child(
Button::new(
("always_allow_tool", tool_call_id.as_u64()),
format!("Always Allow {tool_display_name}"),
)
.icon(IconName::CheckDouble)
.icon_position(IconPosition::Start)
.icon_size(IconSize::Small)
.icon_color(Color::Success)
.on_click(cx.listener({
let id = tool_call_id;
move |this, _, _, cx| {
this.authorize_tool_call(
id,
acp::ToolCallConfirmationOutcome::AlwaysAllowTool,
cx,
);
}
})),
)
.child(
Button::new(("allow", tool_call_id.as_u64()), "Allow")
.icon(IconName::Check)
.icon_position(IconPosition::Start)
.icon_size(IconSize::Small)
.icon_color(Color::Success)
.on_click(cx.listener({
let id = tool_call_id;
move |this, _, _, cx| {
this.authorize_tool_call(
id,
acp::ToolCallConfirmationOutcome::Allow,
cx,
);
}
})),
)
.child(
Button::new(("reject", tool_call_id.as_u64()), "Reject")
.icon(IconName::X)
.icon_position(IconPosition::Start)
.icon_size(IconSize::Small)
.icon_color(Color::Error)
.on_click(cx.listener({
let id = tool_call_id;
move |this, _, _, cx| {
this.authorize_tool_call(
id,
acp::ToolCallConfirmationOutcome::Reject,
cx,
);
}
})),
),
)
.into_any(),
ToolCallConfirmation::Fetch { description, urls } => v_flex()
.border_color(cx.theme().colors().border)
.border_t_1()
.px_2()
.py_1p5()
// todo! nicer rendering
.children(urls.clone())
.children(description.clone())
.child(
h_flex()
.justify_end()
.gap_1()
.child(
Button::new(("always_allow", tool_call_id.as_u64()), "Always Allow")
.icon(IconName::CheckDouble)
.icon_position(IconPosition::Start)
.icon_size(IconSize::Small)
.icon_color(Color::Success)
.on_click(cx.listener({
let id = tool_call_id;
move |this, _, _, cx| {
this.authorize_tool_call(
id,
acp::ToolCallConfirmationOutcome::AlwaysAllow,
cx,
);
}
})),
)
.child(
Button::new(("allow", tool_call_id.as_u64()), "Allow")
.icon(IconName::Check)
.icon_position(IconPosition::Start)
.icon_size(IconSize::Small)
.icon_color(Color::Success)
.on_click(cx.listener({
let id = tool_call_id;
move |this, _, _, cx| {
this.authorize_tool_call(
id,
acp::ToolCallConfirmationOutcome::Allow,
cx,
);
}
})),
)
.child(
Button::new(("reject", tool_call_id.as_u64()), "Reject")
.icon(IconName::X)
.icon_position(IconPosition::Start)
.icon_size(IconSize::Small)
.icon_color(Color::Error)
.on_click(cx.listener({
let id = tool_call_id;
move |this, _, _, cx| {
this.authorize_tool_call(
id,
acp::ToolCallConfirmationOutcome::Reject,
cx,
);
}
})),
),
)
.into_any(),
ToolCallConfirmation::Other { description } => v_flex()
.border_color(cx.theme().colors().border)
.border_t_1()
.px_2()
.py_1p5()
// todo! nicer rendering
.child(description.clone())
.child(
h_flex()
.justify_end()
.gap_1()
.child(
Button::new(("always_allow", tool_call_id.as_u64()), "Always Allow")
.icon(IconName::CheckDouble)
.icon_position(IconPosition::Start)
.icon_size(IconSize::Small)
.icon_color(Color::Success)
.on_click(cx.listener({
let id = tool_call_id;
move |this, _, _, cx| {
this.authorize_tool_call(
id,
acp::ToolCallConfirmationOutcome::AlwaysAllow,
cx,
);
}
})),
)
.child(
Button::new(("allow", tool_call_id.as_u64()), "Allow")
.icon(IconName::Check)
.icon_position(IconPosition::Start)
.icon_size(IconSize::Small)
.icon_color(Color::Success)
.on_click(cx.listener({
let id = tool_call_id;
move |this, _, _, cx| {
this.authorize_tool_call(
id,
acp::ToolCallConfirmationOutcome::Allow,
cx,
);
}
})),
)
.child(
Button::new(("reject", tool_call_id.as_u64()), "Reject")
.icon(IconName::X)
.icon_position(IconPosition::Start)
.icon_size(IconSize::Small)
.icon_color(Color::Error)
.on_click(cx.listener({
let id = tool_call_id;
move |this, _, _, cx| {
this.authorize_tool_call(
id,
acp::ToolCallConfirmationOutcome::Reject,
cx,
);
}
})),
),
)
.into_any(),
}
}
}
impl Focusable for AcpThreadView {
fn focus_handle(&self, cx: &App) -> FocusHandle {
self.message_editor.focus_handle(cx)
}
}
impl Render for AcpThreadView {
fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
let text = self.message_editor.read(cx).text(cx);
let is_editor_empty = text.is_empty();
let focus_handle = self.message_editor.focus_handle(cx);
v_flex()
.key_context("MessageEditor")
.on_action(cx.listener(Self::chat))
.h_full()
.child(match &self.thread_state {
ThreadState::Loading { .. } => v_flex()
.p_2()
.flex_1()
.justify_end()
.child(Label::new("Connecting to Gemini...")),
ThreadState::LoadError(e) => div()
.p_2()
.flex_1()
.justify_end()
.child(Label::new(format!("Failed to load {e}")).into_any_element()),
ThreadState::Ready { thread, .. } => v_flex()
.flex_1()
.gap_2()
.pb_2()
.child(
list(self.list_state.clone())
.with_sizing_behavior(gpui::ListSizingBehavior::Auto)
.flex_grow(),
)
.child(div().px_3().children(if self.send_task.is_none() {
None
} else {
Label::new(if thread.read(cx).waiting_for_tool_confirmation() {
"Waiting for tool confirmation"
} else {
"Generating..."
})
.color(Color::Muted)
.size(LabelSize::Small)
.into()
})),
})
.child(
v_flex()
.bg(cx.theme().colors().editor_background)
.border_t_1()
.border_color(cx.theme().colors().border)
.p_2()
.gap_2()
.child(self.message_editor.clone())
.child(h_flex().justify_end().child(if self.send_task.is_some() {
IconButton::new("stop-generation", IconName::StopFilled)
.icon_color(Color::Error)
.style(ButtonStyle::Tinted(ui::TintColor::Error))
.tooltip(move |window, cx| {
Tooltip::for_action(
"Stop Generation",
&editor::actions::Cancel,
window,
cx,
)
})
.disabled(is_editor_empty)
.on_click(cx.listener(|this, _event, _, _| this.cancel()))
} else {
IconButton::new("send-message", IconName::Send)
.icon_color(Color::Accent)
.style(ButtonStyle::Filled)
.disabled(is_editor_empty)
.on_click({
let focus_handle = focus_handle.clone();
move |_event, window, cx| {
focus_handle.dispatch_action(&Chat, window, cx);
}
})
.when(!is_editor_empty, |button| {
button.tooltip(move |window, cx| {
Tooltip::for_action("Send", &Chat, window, cx)
})
})
.when(is_editor_empty, |button| {
button.tooltip(Tooltip::text("Type a message to submit"))
})
})),
)
}
}
fn user_message_markdown_style(window: &Window, cx: &App) -> MarkdownStyle {
let mut style = default_markdown_style(window, cx);
let mut text_style = window.text_style();
let theme_settings = ThemeSettings::get_global(cx);
let buffer_font = theme_settings.buffer_font.family.clone();
let buffer_font_size = TextSize::Small.rems(cx);
text_style.refine(&TextStyleRefinement {
font_family: Some(buffer_font),
font_size: Some(buffer_font_size.into()),
..Default::default()
});
style.base_text_style = text_style;
style
}
fn default_markdown_style(window: &Window, cx: &App) -> MarkdownStyle {
let theme_settings = ThemeSettings::get_global(cx);
let colors = cx.theme().colors();
let ui_font_size = TextSize::Default.rems(cx);
let buffer_font_size = TextSize::Small.rems(cx);
let mut text_style = window.text_style();
let line_height = buffer_font_size * 1.75;
text_style.refine(&TextStyleRefinement {
font_family: Some(theme_settings.ui_font.family.clone()),
font_fallbacks: theme_settings.ui_font.fallbacks.clone(),
font_features: Some(theme_settings.ui_font.features.clone()),
font_size: Some(ui_font_size.into()),
line_height: Some(line_height.into()),
color: Some(cx.theme().colors().text),
..Default::default()
});
MarkdownStyle {
base_text_style: text_style.clone(),
syntax: cx.theme().syntax().clone(),
selection_background_color: cx.theme().colors().element_selection_background,
code_block_overflow_x_scroll: true,
table_overflow_x_scroll: true,
heading_level_styles: Some(HeadingLevelStyles {
h1: Some(TextStyleRefinement {
font_size: Some(rems(1.15).into()),
..Default::default()
}),
h2: Some(TextStyleRefinement {
font_size: Some(rems(1.1).into()),
..Default::default()
}),
h3: Some(TextStyleRefinement {
font_size: Some(rems(1.05).into()),
..Default::default()
}),
h4: Some(TextStyleRefinement {
font_size: Some(rems(1.).into()),
..Default::default()
}),
h5: Some(TextStyleRefinement {
font_size: Some(rems(0.95).into()),
..Default::default()
}),
h6: Some(TextStyleRefinement {
font_size: Some(rems(0.875).into()),
..Default::default()
}),
}),
code_block: StyleRefinement {
padding: EdgesRefinement {
top: Some(DefiniteLength::Absolute(AbsoluteLength::Pixels(Pixels(8.)))),
left: Some(DefiniteLength::Absolute(AbsoluteLength::Pixels(Pixels(8.)))),
right: Some(DefiniteLength::Absolute(AbsoluteLength::Pixels(Pixels(8.)))),
bottom: Some(DefiniteLength::Absolute(AbsoluteLength::Pixels(Pixels(8.)))),
},
background: Some(colors.editor_background.into()),
text: Some(TextStyleRefinement {
font_family: Some(theme_settings.buffer_font.family.clone()),
font_fallbacks: theme_settings.buffer_font.fallbacks.clone(),
font_features: Some(theme_settings.buffer_font.features.clone()),
font_size: Some(buffer_font_size.into()),
..Default::default()
}),
..Default::default()
},
inline_code: TextStyleRefinement {
font_family: Some(theme_settings.buffer_font.family.clone()),
font_fallbacks: theme_settings.buffer_font.fallbacks.clone(),
font_features: Some(theme_settings.buffer_font.features.clone()),
font_size: Some(buffer_font_size.into()),
background_color: Some(colors.editor_foreground.opacity(0.08)),
..Default::default()
},
link: TextStyleRefinement {
background_color: Some(colors.editor_foreground.opacity(0.025)),
underline: Some(UnderlineStyle {
color: Some(colors.text_accent.opacity(0.5)),
thickness: px(1.),
..Default::default()
}),
..Default::default()
},
link_callback: Some(Rc::new(move |_url, _cx| {
// todo!()
// if MentionLink::is_valid(url) {
// let colors = cx.theme().colors();
// Some(TextStyleRefinement {
// background_color: Some(colors.element_background),
// ..Default::default()
// })
// } else {
None
// }
})),
..Default::default()
}
}

View File

@@ -31,13 +31,7 @@ use workspace::{StatusItemView, Workspace, item::ItemHandle};
const GIT_OPERATION_DELAY: Duration = Duration::from_millis(0);
actions!(
activity_indicator,
[
/// Displays error messages from language servers in the status bar.
ShowErrorMessage
]
);
actions!(activity_indicator, [ShowErrorMessage]);
pub enum Event {
ShowStatus {

View File

@@ -1,7 +1,7 @@
use std::sync::Arc;
use agent_settings::{AgentProfileId, AgentProfileSettings, AgentSettings};
use assistant_tool::{AnyTool, ToolSource, ToolWorkingSet, UniqueToolName};
use assistant_tool::{Tool, ToolSource, ToolWorkingSet};
use collections::IndexMap;
use convert_case::{Case, Casing};
use fs::Fs;
@@ -72,7 +72,7 @@ impl AgentProfile {
&self.id
}
pub fn enabled_tools(&self, cx: &App) -> Vec<(UniqueToolName, AnyTool)> {
pub fn enabled_tools(&self, cx: &App) -> Vec<Arc<dyn Tool>> {
let Some(settings) = AgentSettings::get_global(cx).profiles.get(&self.id) else {
return Vec::new();
};
@@ -81,7 +81,7 @@ impl AgentProfile {
.read(cx)
.tools(cx)
.into_iter()
.filter(|(_, tool)| Self::is_enabled(settings, tool.source(), tool.name()))
.filter(|tool| Self::is_enabled(settings, tool.source(), tool.name()))
.collect()
}
@@ -108,7 +108,7 @@ impl AgentProfile {
#[cfg(test)]
mod tests {
use agent_settings::ContextServerPreset;
use assistant_tool::{Tool, ToolRegistry};
use assistant_tool::ToolRegistry;
use collections::IndexMap;
use gpui::SharedString;
use gpui::{AppContext, TestAppContext};
@@ -137,7 +137,7 @@ mod tests {
let mut enabled_tools = cx
.read(|cx| profile.enabled_tools(cx))
.into_iter()
.map(|(_, tool)| tool.name())
.map(|tool| tool.name())
.collect::<Vec<_>>();
enabled_tools.sort();
@@ -174,7 +174,7 @@ mod tests {
let mut enabled_tools = cx
.read(|cx| profile.enabled_tools(cx))
.into_iter()
.map(|(_, tool)| tool.name())
.map(|tool| tool.name())
.collect::<Vec<_>>();
enabled_tools.sort();
@@ -207,7 +207,7 @@ mod tests {
let mut enabled_tools = cx
.read(|cx| profile.enabled_tools(cx))
.into_iter()
.map(|(_, tool)| tool.name())
.map(|tool| tool.name())
.collect::<Vec<_>>();
enabled_tools.sort();
@@ -267,16 +267,10 @@ mod tests {
}
fn default_tool_set(cx: &mut TestAppContext) -> Entity<ToolWorkingSet> {
cx.new(|cx| {
cx.new(|_| {
let mut tool_set = ToolWorkingSet::default();
tool_set.insert(
Arc::new(FakeTool::new("enabled_mcp_tool", "mcp")).into(),
cx,
);
tool_set.insert(
Arc::new(FakeTool::new("disabled_mcp_tool", "mcp")).into(),
cx,
);
tool_set.insert(Arc::new(FakeTool::new("enabled_mcp_tool", "mcp")));
tool_set.insert(Arc::new(FakeTool::new("disabled_mcp_tool", "mcp")));
tool_set
})
}
@@ -296,8 +290,6 @@ mod tests {
}
impl Tool for FakeTool {
type Input = ();
fn name(&self) -> String {
self.name.clone()
}
@@ -316,17 +308,17 @@ mod tests {
unimplemented!()
}
fn needs_confirmation(&self, _input: &Self::Input, _cx: &App) -> bool {
fn needs_confirmation(&self, _input: &serde_json::Value, _cx: &App) -> bool {
unimplemented!()
}
fn ui_text(&self, _input: &Self::Input) -> String {
fn ui_text(&self, _input: &serde_json::Value) -> String {
unimplemented!()
}
fn run(
self: Arc<Self>,
_input: Self::Input,
_input: serde_json::Value,
_request: Arc<language_model::LanguageModelRequest>,
_project: Entity<Project>,
_action_log: Entity<assistant_tool::ActionLog>,

View File

@@ -29,8 +29,6 @@ impl ContextServerTool {
}
impl Tool for ContextServerTool {
type Input = serde_json::Value;
fn name(&self) -> String {
self.tool.name.clone()
}
@@ -49,7 +47,7 @@ impl Tool for ContextServerTool {
}
}
fn needs_confirmation(&self, _: &Self::Input, _: &App) -> bool {
fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool {
true
}
@@ -71,13 +69,13 @@ impl Tool for ContextServerTool {
})
}
fn ui_text(&self, _input: &Self::Input) -> String {
fn ui_text(&self, _input: &serde_json::Value) -> String {
format!("Run MCP tool `{}`", self.tool.name)
}
fn run(
self: Arc<Self>,
input: Self::Input,
input: serde_json::Value,
_request: Arc<LanguageModelRequest>,
_project: Entity<Project>,
_action_log: Entity<ActionLog>,

View File

@@ -10,10 +10,10 @@ use crate::{
};
use agent_settings::{AgentProfileId, AgentSettings, CompletionMode};
use anyhow::{Result, anyhow};
use assistant_tool::{ActionLog, AnyTool, AnyToolCard, ToolWorkingSet};
use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolWorkingSet};
use chrono::{DateTime, Utc};
use client::{ModelRequestUsage, RequestUsage};
use collections::HashMap;
use collections::{HashMap, HashSet};
use feature_flags::{self, FeatureFlagAppExt};
use futures::{FutureExt, StreamExt as _, future::Shared};
use git::repository::DiffType;
@@ -960,14 +960,13 @@ impl Thread {
model: Arc<dyn LanguageModel>,
) -> Vec<LanguageModelRequestTool> {
if model.supports_tools() {
self.profile
.enabled_tools(cx)
resolve_tool_name_conflicts(self.profile.enabled_tools(cx).as_slice())
.into_iter()
.filter_map(|(name, tool)| {
// Skip tools that cannot be supported
let input_schema = tool.input_schema(model.tool_input_format()).ok()?;
Some(LanguageModelRequestTool {
name: name.into(),
name,
description: tool.description(),
input_schema,
})
@@ -2387,7 +2386,7 @@ impl Thread {
let tool_list = available_tools
.iter()
.map(|(name, tool)| format!("- {}: {}", name, tool.description()))
.map(|tool| format!("- {}: {}", tool.name(), tool.description()))
.collect::<Vec<_>>()
.join("\n");
@@ -2452,7 +2451,7 @@ impl Thread {
ui_text: impl Into<SharedString>,
input: serde_json::Value,
request: Arc<LanguageModelRequest>,
tool: AnyTool,
tool: Arc<dyn Tool>,
model: Arc<dyn LanguageModel>,
window: Option<AnyWindowHandle>,
cx: &mut Context<Thread>,
@@ -2468,7 +2467,7 @@ impl Thread {
tool_use_id: LanguageModelToolUseId,
request: Arc<LanguageModelRequest>,
input: serde_json::Value,
tool: AnyTool,
tool: Arc<dyn Tool>,
model: Arc<dyn LanguageModel>,
window: Option<AnyWindowHandle>,
cx: &mut Context<Thread>,
@@ -2607,7 +2606,7 @@ impl Thread {
.profile
.enabled_tools(cx)
.iter()
.map(|(name, _)| name.clone().into())
.map(|tool| tool.name())
.collect();
self.message_feedback.insert(message_id, feedback);
@@ -3145,6 +3144,85 @@ struct PendingCompletion {
_task: Task<()>,
}
/// Resolves tool name conflicts by ensuring all tool names are unique.
///
/// When multiple tools have the same name, this function applies the following rules:
/// 1. Native tools always keep their original name
/// 2. Context server tools get prefixed with their server ID and an underscore
/// 3. All tool names are truncated to MAX_TOOL_NAME_LENGTH (64 characters)
/// 4. If conflicts still exist after prefixing, the conflicting tools are filtered out
///
/// Note: This function assumes that built-in tools occur before MCP tools in the tools list.
fn resolve_tool_name_conflicts(tools: &[Arc<dyn Tool>]) -> Vec<(String, Arc<dyn Tool>)> {
fn resolve_tool_name(tool: &Arc<dyn Tool>) -> String {
let mut tool_name = tool.name();
tool_name.truncate(MAX_TOOL_NAME_LENGTH);
tool_name
}
const MAX_TOOL_NAME_LENGTH: usize = 64;
let mut duplicated_tool_names = HashSet::default();
let mut seen_tool_names = HashSet::default();
for tool in tools {
let tool_name = resolve_tool_name(tool);
if seen_tool_names.contains(&tool_name) {
debug_assert!(
tool.source() != assistant_tool::ToolSource::Native,
"There are two built-in tools with the same name: {}",
tool_name
);
duplicated_tool_names.insert(tool_name);
} else {
seen_tool_names.insert(tool_name);
}
}
if duplicated_tool_names.is_empty() {
return tools
.into_iter()
.map(|tool| (resolve_tool_name(tool), tool.clone()))
.collect();
}
tools
.into_iter()
.filter_map(|tool| {
let mut tool_name = resolve_tool_name(tool);
if !duplicated_tool_names.contains(&tool_name) {
return Some((tool_name, tool.clone()));
}
match tool.source() {
assistant_tool::ToolSource::Native => {
// Built-in tools always keep their original name
Some((tool_name, tool.clone()))
}
assistant_tool::ToolSource::ContextServer { id } => {
// Context server tools are prefixed with the context server ID, and truncated if necessary
tool_name.insert(0, '_');
if tool_name.len() + id.len() > MAX_TOOL_NAME_LENGTH {
let len = MAX_TOOL_NAME_LENGTH - tool_name.len();
let mut id = id.to_string();
id.truncate(len);
tool_name.insert_str(0, &id);
} else {
tool_name.insert_str(0, &id);
}
tool_name.truncate(MAX_TOOL_NAME_LENGTH);
if seen_tool_names.contains(&tool_name) {
log::error!("Cannot resolve tool name conflict for tool {}", tool.name());
None
} else {
Some((tool_name, tool.clone()))
}
}
}
})
.collect()
}
#[cfg(test)]
mod tests {
use super::*;
@@ -3160,6 +3238,7 @@ mod tests {
use futures::future::BoxFuture;
use futures::stream::BoxStream;
use gpui::TestAppContext;
use icons::IconName;
use language_model::fake_provider::{FakeLanguageModel, FakeLanguageModelProvider};
use language_model::{
LanguageModelCompletionError, LanguageModelName, LanguageModelProviderId,
@@ -3804,6 +3883,148 @@ fn main() {{
});
}
#[gpui::test]
fn test_resolve_tool_name_conflicts() {
use assistant_tool::{Tool, ToolSource};
assert_resolve_tool_name_conflicts(
vec![
TestTool::new("tool1", ToolSource::Native),
TestTool::new("tool2", ToolSource::Native),
TestTool::new("tool3", ToolSource::ContextServer { id: "mcp-1".into() }),
],
vec!["tool1", "tool2", "tool3"],
);
assert_resolve_tool_name_conflicts(
vec![
TestTool::new("tool1", ToolSource::Native),
TestTool::new("tool2", ToolSource::Native),
TestTool::new("tool3", ToolSource::ContextServer { id: "mcp-1".into() }),
TestTool::new("tool3", ToolSource::ContextServer { id: "mcp-2".into() }),
],
vec!["tool1", "tool2", "mcp-1_tool3", "mcp-2_tool3"],
);
assert_resolve_tool_name_conflicts(
vec![
TestTool::new("tool1", ToolSource::Native),
TestTool::new("tool2", ToolSource::Native),
TestTool::new("tool3", ToolSource::Native),
TestTool::new("tool3", ToolSource::ContextServer { id: "mcp-1".into() }),
TestTool::new("tool3", ToolSource::ContextServer { id: "mcp-2".into() }),
],
vec!["tool1", "tool2", "tool3", "mcp-1_tool3", "mcp-2_tool3"],
);
// Test that tool with very long name is always truncated
assert_resolve_tool_name_conflicts(
vec![TestTool::new(
"tool-with-more-then-64-characters-blah-blah-blah-blah-blah-blah-blah-blah",
ToolSource::Native,
)],
vec!["tool-with-more-then-64-characters-blah-blah-blah-blah-blah-blah-"],
);
// Test deduplication of tools with very long names, in this case the mcp server name should be truncated
assert_resolve_tool_name_conflicts(
vec![
TestTool::new("tool-with-very-very-very-long-name", ToolSource::Native),
TestTool::new(
"tool-with-very-very-very-long-name",
ToolSource::ContextServer {
id: "mcp-with-very-very-very-long-name".into(),
},
),
],
vec![
"tool-with-very-very-very-long-name",
"mcp-with-very-very-very-long-_tool-with-very-very-very-long-name",
],
);
fn assert_resolve_tool_name_conflicts(
tools: Vec<TestTool>,
expected: Vec<impl Into<String>>,
) {
let tools: Vec<Arc<dyn Tool>> = tools
.into_iter()
.map(|t| Arc::new(t) as Arc<dyn Tool>)
.collect();
let tools = resolve_tool_name_conflicts(&tools);
assert_eq!(tools.len(), expected.len());
for (i, expected_name) in expected.into_iter().enumerate() {
let expected_name = expected_name.into();
let actual_name = &tools[i].0;
assert_eq!(
actual_name, &expected_name,
"Expected '{}' got '{}' at index {}",
expected_name, actual_name, i
);
}
}
struct TestTool {
name: String,
source: ToolSource,
}
impl TestTool {
fn new(name: impl Into<String>, source: ToolSource) -> Self {
Self {
name: name.into(),
source,
}
}
}
impl Tool for TestTool {
fn name(&self) -> String {
self.name.clone()
}
fn icon(&self) -> IconName {
IconName::Ai
}
fn may_perform_edits(&self) -> bool {
false
}
fn needs_confirmation(&self, _input: &serde_json::Value, _cx: &App) -> bool {
true
}
fn source(&self) -> ToolSource {
self.source.clone()
}
fn description(&self) -> String {
"Test tool".to_string()
}
fn ui_text(&self, _input: &serde_json::Value) -> String {
"Test tool".to_string()
}
fn run(
self: Arc<Self>,
_input: serde_json::Value,
_request: Arc<LanguageModelRequest>,
_project: Entity<Project>,
_action_log: Entity<ActionLog>,
_model: Arc<dyn LanguageModel>,
_window: Option<AnyWindowHandle>,
_cx: &mut App,
) -> assistant_tool::ToolResult {
assistant_tool::ToolResult {
output: Task::ready(Err(anyhow::anyhow!("No content"))),
card: None,
}
}
}
}
// Helper to create a model that returns errors
enum TestError {
Overloaded,

View File

@@ -537,8 +537,8 @@ impl ThreadStore {
}
ContextServerStatus::Stopped | ContextServerStatus::Error(_) => {
if let Some(tool_ids) = self.context_server_tool_ids.remove(server_id) {
tool_working_set.update(cx, |tool_working_set, cx| {
tool_working_set.remove(&tool_ids, cx);
tool_working_set.update(cx, |tool_working_set, _| {
tool_working_set.remove(&tool_ids);
});
}
}
@@ -569,18 +569,19 @@ impl ThreadStore {
.log_err()
{
let tool_ids = tool_working_set
.update(cx, |tool_working_set, cx| {
tool_working_set.extend(
response.tools.into_iter().map(|tool| {
Arc::new(ContextServerTool::new(
.update(cx, |tool_working_set, _| {
response
.tools
.into_iter()
.map(|tool| {
log::info!("registering context server tool: {:?}", tool.name);
tool_working_set.insert(Arc::new(ContextServerTool::new(
context_server_store.clone(),
server.id(),
tool,
))
.into()
}),
cx,
)
)))
})
.collect::<Vec<_>>()
})
.log_err();

View File

@@ -4,7 +4,7 @@ use crate::{
};
use anyhow::Result;
use assistant_tool::{
AnyTool, AnyToolCard, ToolResultContent, ToolResultOutput, ToolUseStatus, ToolWorkingSet,
AnyToolCard, Tool, ToolResultContent, ToolResultOutput, ToolUseStatus, ToolWorkingSet,
};
use collections::HashMap;
use futures::{FutureExt as _, future::Shared};
@@ -378,7 +378,7 @@ impl ToolUseState {
ui_text: impl Into<Arc<str>>,
input: serde_json::Value,
request: Arc<LanguageModelRequest>,
tool: AnyTool,
tool: Arc<dyn Tool>,
) {
if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
let ui_text = ui_text.into();
@@ -533,7 +533,7 @@ pub struct Confirmation {
pub input: serde_json::Value,
pub ui_text: Arc<str>,
pub request: Arc<LanguageModelRequest>,
pub tool: AnyTool,
pub tool: Arc<dyn Tool>,
}
#[derive(Debug, Clone)]

View File

@@ -13,12 +13,10 @@ path = "src/agent_ui.rs"
doctest = false
[features]
test-support = [
"gpui/test-support",
"language/test-support",
]
test-support = ["gpui/test-support", "language/test-support"]
[dependencies]
acp.workspace = true
agent.workspace = true
agent_settings.workspace = true
anyhow.workspace = true

View File

@@ -1,7 +1,9 @@
use crate::context_picker::{ContextPicker, MentionLink};
use crate::context_strip::{ContextStrip, ContextStripEvent, SuggestContextKind};
use crate::message_editor::{extract_message_creases, insert_message_creases};
use crate::ui::{AddedContext, AgentNotification, AgentNotificationEvent, ContextPill};
use crate::ui::{
AddedContext, AgentNotification, AgentNotificationEvent, AnimatedLabel, ContextPill,
};
use crate::{AgentPanel, ModelUsageContext};
use agent::{
ContextStore, LastRestoreCheckpoint, MessageCrease, MessageId, MessageSegment, TextThreadStore,
@@ -1024,7 +1026,6 @@ impl ActiveThread {
}
}
ThreadEvent::MessageAdded(message_id) => {
self.clear_last_error();
if let Some(rendered_message) = self.thread.update(cx, |thread, cx| {
thread.message(*message_id).map(|message| {
RenderedMessage::from_segments(
@@ -1041,7 +1042,6 @@ impl ActiveThread {
cx.notify();
}
ThreadEvent::MessageEdited(message_id) => {
self.clear_last_error();
if let Some(index) = self.messages.iter().position(|id| id == message_id) {
if let Some(rendered_message) = self.thread.update(cx, |thread, cx| {
thread.message(*message_id).map(|message| {
@@ -1818,7 +1818,7 @@ impl ActiveThread {
.my_3()
.mx_5()
.when(is_generating_stale || message.is_hidden, |this| {
this.child(LoadingLabel::new("").size(LabelSize::Small))
this.child(AnimatedLabel::new("").size(LabelSize::Small))
})
});
@@ -2584,7 +2584,7 @@ impl ActiveThread {
.size(IconSize::XSmall)
.color(Color::Muted),
)
.child(LoadingLabel::new("Thinking").size(LabelSize::Small)),
.child(AnimatedLabel::new("Thinking").size(LabelSize::Small)),
)
.child(
h_flex()
@@ -3153,7 +3153,7 @@ impl ActiveThread {
.border_color(self.tool_card_border_color(cx))
.rounded_b_lg()
.child(
LoadingLabel::new("Waiting for Confirmation").size(LabelSize::Small)
AnimatedLabel::new("Waiting for Confirmation").size(LabelSize::Small)
)
.child(
h_flex()

View File

@@ -26,8 +26,8 @@ use project::{
};
use settings::{Settings, update_settings_file};
use ui::{
ContextMenu, Disclosure, Divider, DividerColor, ElevationIndex, Indicator, PopoverMenu,
Scrollbar, ScrollbarState, Switch, SwitchColor, Tooltip, prelude::*,
ContextMenu, Disclosure, ElevationIndex, Indicator, PopoverMenu, Scrollbar, ScrollbarState,
Switch, SwitchColor, Tooltip, prelude::*,
};
use util::ResultExt as _;
use workspace::Workspace;
@@ -172,29 +172,19 @@ impl AgentConfiguration {
.unwrap_or(false);
v_flex()
.when(is_expanded, |this| this.mb_2())
.child(
div()
.opacity(0.6)
.px_2()
.child(Divider::horizontal().color(DividerColor::Border)),
)
.py_2()
.gap_1p5()
.border_t_1()
.border_color(cx.theme().colors().border.opacity(0.6))
.child(
h_flex()
.map(|this| {
if is_expanded {
this.mt_2().mb_1()
} else {
this.my_2()
}
})
.w_full()
.gap_1()
.justify_between()
.child(
h_flex()
.id(provider_id_string.clone())
.cursor_pointer()
.px_2()
.py_0p5()
.w_full()
.justify_between()
@@ -257,16 +247,12 @@ impl AgentConfiguration {
)
}),
)
.child(
div()
.px_2()
.when(is_expanded, |parent| match configuration_view {
Some(configuration_view) => parent.child(configuration_view),
None => parent.child(Label::new(format!(
"No configuration view for {provider_name}",
))),
}),
)
.when(is_expanded, |parent| match configuration_view {
Some(configuration_view) => parent.child(configuration_view),
None => parent.child(Label::new(format!(
"No configuration view for {provider_name}",
))),
})
}
fn render_provider_configuration_section(
@@ -276,11 +262,12 @@ impl AgentConfiguration {
let providers = LanguageModelRegistry::read_global(cx).providers();
v_flex()
.p(DynamicSpacing::Base16.rems(cx))
.pr(DynamicSpacing::Base20.rems(cx))
.border_b_1()
.border_color(cx.theme().colors().border)
.child(
v_flex()
.p(DynamicSpacing::Base16.rems(cx))
.pr(DynamicSpacing::Base20.rems(cx))
.pb_0()
.mb_2p5()
.gap_0p5()
.child(Headline::new("LLM Providers"))
@@ -289,15 +276,10 @@ impl AgentConfiguration {
.color(Color::Muted),
),
)
.child(
div()
.pl(DynamicSpacing::Base08.rems(cx))
.pr(DynamicSpacing::Base20.rems(cx))
.children(
providers.into_iter().map(|provider| {
self.render_provider_configuration_block(&provider, cx)
}),
),
.children(
providers
.into_iter()
.map(|provider| self.render_provider_configuration_block(&provider, cx)),
)
}

View File

@@ -379,14 +379,6 @@ impl ConfigureContextServerModal {
};
self.state = State::Waiting;
let existing_server = self.context_server_store.read(cx).get_running_server(&id);
if existing_server.is_some() {
self.context_server_store.update(cx, |store, cx| {
store.stop_server(&id, cx).log_err();
});
}
let wait_for_context_server_task =
wait_for_context_server(&self.context_server_store, id.clone(), cx);
cx.spawn({
@@ -407,21 +399,13 @@ impl ConfigureContextServerModal {
})
.detach();
let settings_changed =
ProjectSettings::get_global(cx).context_servers.get(&id.0) != Some(&settings);
if settings_changed {
// When we write the settings to the file, the context server will be restarted.
workspace.update(cx, |workspace, cx| {
let fs = workspace.app_state().fs.clone();
update_settings_file::<ProjectSettings>(fs.clone(), cx, |project_settings, _| {
project_settings.context_servers.insert(id.0, settings);
});
// When we write the settings to the file, the context server will be restarted.
workspace.update(cx, |workspace, cx| {
let fs = workspace.app_state().fs.clone();
update_settings_file::<ProjectSettings>(fs.clone(), cx, |project_settings, _| {
project_settings.context_servers.insert(id.0, settings);
});
} else if let Some(existing_server) = existing_server {
self.context_server_store
.update(cx, |store, cx| store.start_server(existing_server, cx));
}
});
}
fn cancel(&mut self, _: &menu::Cancel, cx: &mut Context<Self>) {

View File

@@ -7,6 +7,7 @@ use std::time::Duration;
use db::kvp::{Dismissable, KEY_VALUE_STORE};
use serde::{Deserialize, Serialize};
use crate::NewGeminiThread;
use crate::language_model_selector::ToggleModelSelector;
use crate::{
AddContextServer, AgentDiffPane, ContinueThread, ContinueWithBurnMode,
@@ -109,6 +110,12 @@ pub fn init(cx: &mut App) {
panel.update(cx, |panel, cx| panel.new_prompt_editor(window, cx));
}
})
.register_action(|workspace, _: &NewGeminiThread, window, cx| {
if let Some(panel) = workspace.panel::<AgentPanel>(cx) {
workspace.focus_panel::<AgentPanel>(window, cx);
panel.update(cx, |panel, cx| panel.new_gemini_thread(window, cx));
}
})
.register_action(|workspace, action: &OpenRulesLibrary, window, cx| {
if let Some(panel) = workspace.panel::<AgentPanel>(cx) {
workspace.focus_panel::<AgentPanel>(window, cx);
@@ -125,6 +132,7 @@ pub fn init(cx: &mut App) {
let thread = thread.read(cx).thread().clone();
AgentDiffPane::deploy_in_workspace(thread, workspace, window, cx);
}
ActiveView::AcpThread { .. } => todo!(),
ActiveView::TextThread { .. }
| ActiveView::History
| ActiveView::Configuration => {}
@@ -188,6 +196,9 @@ enum ActiveView {
message_editor: Entity<MessageEditor>,
_subscriptions: Vec<gpui::Subscription>,
},
AcpThread {
thread_view: Entity<acp::AcpThreadView>,
},
TextThread {
context_editor: Entity<TextThreadEditor>,
title_editor: Entity<Editor>,
@@ -207,7 +218,9 @@ enum WhichFontSize {
impl ActiveView {
pub fn which_font_size_used(&self) -> WhichFontSize {
match self {
ActiveView::Thread { .. } | ActiveView::History => WhichFontSize::AgentFont,
ActiveView::Thread { .. } | ActiveView::AcpThread { .. } | ActiveView::History => {
WhichFontSize::AgentFont
}
ActiveView::TextThread { .. } => WhichFontSize::BufferFont,
ActiveView::Configuration => WhichFontSize::None,
}
@@ -238,6 +251,9 @@ impl ActiveView {
thread.scroll_to_bottom(cx);
});
}
ActiveView::AcpThread { .. } => {
// todo!
}
ActiveView::TextThread { .. }
| ActiveView::History
| ActiveView::Configuration => {}
@@ -653,6 +669,9 @@ impl AgentPanel {
.clone()
.update(cx, |thread, cx| thread.get_or_init_configured_model(cx));
}
ActiveView::AcpThread { .. } => {
// todo!
}
ActiveView::TextThread { .. }
| ActiveView::History
| ActiveView::Configuration => {}
@@ -733,6 +752,9 @@ impl AgentPanel {
ActiveView::Thread { thread, .. } => {
thread.update(cx, |thread, cx| thread.cancel_last_completion(window, cx));
}
ActiveView::AcpThread { thread_view, .. } => {
thread_view.update(cx, |thread_element, _cx| thread_element.cancel());
}
ActiveView::TextThread { .. } | ActiveView::History | ActiveView::Configuration => {}
}
}
@@ -740,6 +762,10 @@ impl AgentPanel {
fn active_message_editor(&self) -> Option<&Entity<MessageEditor>> {
match &self.active_view {
ActiveView::Thread { message_editor, .. } => Some(message_editor),
ActiveView::AcpThread { .. } => {
// todo!
None
}
ActiveView::TextThread { .. } | ActiveView::History | ActiveView::Configuration => None,
}
}
@@ -862,6 +888,19 @@ impl AgentPanel {
context_editor.focus_handle(cx).focus(window);
}
fn new_gemini_thread(&mut self, window: &mut Window, cx: &mut Context<Self>) {
let project = self.project.clone();
cx.spawn_in(window, async move |this, cx| {
let thread_view =
cx.new_window_entity(|window, cx| acp::AcpThreadView::new(project, window, cx))?;
this.update_in(cx, |this, window, cx| {
this.set_active_view(ActiveView::AcpThread { thread_view }, window, cx);
})
})
.detach();
}
fn deploy_rules_library(
&mut self,
action: &OpenRulesLibrary,
@@ -994,6 +1033,7 @@ impl AgentPanel {
cx,
)
});
let message_editor = cx.new(|cx| {
MessageEditor::new(
self.fs.clone(),
@@ -1018,6 +1058,7 @@ impl AgentPanel {
pub fn go_back(&mut self, _: &workspace::GoBack, window: &mut Window, cx: &mut Context<Self>) {
match self.active_view {
ActiveView::Configuration | ActiveView::History => {
// todo! check go back works correctly
if let Some(previous_view) = self.previous_view.take() {
self.active_view = previous_view;
@@ -1025,6 +1066,9 @@ impl AgentPanel {
ActiveView::Thread { message_editor, .. } => {
message_editor.focus_handle(cx).focus(window);
}
ActiveView::AcpThread { .. } => {
todo!()
}
ActiveView::TextThread { context_editor, .. } => {
context_editor.focus_handle(cx).focus(window);
}
@@ -1144,6 +1188,7 @@ impl AgentPanel {
})
.log_err();
}
ActiveView::AcpThread { .. } => todo!(),
ActiveView::TextThread { .. } | ActiveView::History | ActiveView::Configuration => {}
}
}
@@ -1197,6 +1242,9 @@ impl AgentPanel {
)
.detach_and_log_err(cx);
}
ActiveView::AcpThread { .. } => {
todo!()
}
ActiveView::TextThread { .. } | ActiveView::History | ActiveView::Configuration => {}
}
}
@@ -1231,6 +1279,10 @@ impl AgentPanel {
pub(crate) fn active_thread(&self, cx: &App) -> Option<Entity<Thread>> {
match &self.active_view {
ActiveView::Thread { thread, .. } => Some(thread.read(cx).thread().clone()),
ActiveView::AcpThread { .. } => {
// todo!
None
}
_ => None,
}
}
@@ -1336,6 +1388,9 @@ impl AgentPanel {
});
}
}
ActiveView::AcpThread { .. } => {
// todo!
}
_ => {}
}
@@ -1351,6 +1406,9 @@ impl AgentPanel {
}
})
}
ActiveView::AcpThread { .. } => {
// todo! push history entry
}
_ => {}
}
@@ -1437,6 +1495,7 @@ impl Focusable for AgentPanel {
fn focus_handle(&self, cx: &App) -> FocusHandle {
match &self.active_view {
ActiveView::Thread { message_editor, .. } => message_editor.focus_handle(cx),
ActiveView::AcpThread { thread_view, .. } => thread_view.focus_handle(cx),
ActiveView::History => self.history.focus_handle(cx),
ActiveView::TextThread { context_editor, .. } => context_editor.focus_handle(cx),
ActiveView::Configuration => {
@@ -1593,6 +1652,9 @@ impl AgentPanel {
.into_any_element(),
}
}
ActiveView::AcpThread { thread_view } => Label::new(thread_view.read(cx).title(cx))
.truncate()
.into_any_element(),
ActiveView::TextThread {
title_editor,
context_editor,
@@ -1727,6 +1789,10 @@ impl AgentPanel {
let active_thread = match &self.active_view {
ActiveView::Thread { thread, .. } => Some(thread.read(cx).thread().clone()),
ActiveView::AcpThread { .. } => {
// todo!
None
}
ActiveView::TextThread { .. } | ActiveView::History | ActiveView::Configuration => None,
};
@@ -1755,6 +1821,7 @@ impl AgentPanel {
menu = menu
.action("New Thread", NewThread::default().boxed_clone())
.action("New Text Thread", NewTextThread.boxed_clone())
.action("New Gemini Thread", NewGeminiThread.boxed_clone())
.when_some(active_thread, |this, active_thread| {
let thread = active_thread.read(cx);
if !thread.is_empty() {
@@ -1893,6 +1960,10 @@ impl AgentPanel {
message_editor,
..
} => (thread.read(cx), message_editor.read(cx)),
ActiveView::AcpThread { .. } => {
// todo!
return None;
}
ActiveView::TextThread { .. } | ActiveView::History | ActiveView::Configuration => {
return None;
}
@@ -2031,6 +2102,10 @@ impl AgentPanel {
return false;
}
}
ActiveView::AcpThread { .. } => {
// todo!
return false;
}
ActiveView::TextThread { .. } | ActiveView::History | ActiveView::Configuration => {
return false;
}
@@ -2615,6 +2690,10 @@ impl AgentPanel {
) -> Option<AnyElement> {
let active_thread = match &self.active_view {
ActiveView::Thread { thread, .. } => thread,
ActiveView::AcpThread { .. } => {
// todo!
return None;
}
ActiveView::TextThread { .. } | ActiveView::History | ActiveView::Configuration => {
return None;
}
@@ -2961,6 +3040,9 @@ impl AgentPanel {
.detach();
});
}
ActiveView::AcpThread { .. } => {
unimplemented!()
}
ActiveView::TextThread { context_editor, .. } => {
context_editor.update(cx, |context_editor, cx| {
TextThreadEditor::insert_dragged_files(
@@ -3034,6 +3116,9 @@ impl Render for AgentPanel {
});
this.continue_conversation(window, cx);
}
ActiveView::AcpThread { .. } => {
todo!()
}
ActiveView::TextThread { .. }
| ActiveView::History
| ActiveView::Configuration => {}
@@ -3075,6 +3160,12 @@ impl Render for AgentPanel {
})
.child(h_flex().child(message_editor.clone()))
.child(self.render_drag_target(cx)),
ActiveView::AcpThread { thread_view, .. } => parent
.relative()
.child(thread_view.clone())
// todo!
// .child(h_flex().child(self.message_editor.clone()))
.child(self.render_drag_target(cx)),
ActiveView::History => parent.child(self.history.clone()),
ActiveView::TextThread {
context_editor,

View File

@@ -54,76 +54,42 @@ pub use ui::preview::{all_agent_previews, get_agent_preview};
actions!(
agent,
[
/// Creates a new text-based conversation thread.
NewTextThread,
/// Toggles the context picker interface for adding files, symbols, or other context.
NewGeminiThread,
ToggleContextPicker,
/// Toggles the navigation menu for switching between threads and views.
ToggleNavigationMenu,
/// Toggles the options menu for agent settings and preferences.
ToggleOptionsMenu,
/// Deletes the recently opened thread from history.
DeleteRecentlyOpenThread,
/// Toggles the profile selector for switching between agent profiles.
ToggleProfileSelector,
/// Removes all added context from the current conversation.
RemoveAllContext,
/// Expands the message editor to full size.
ExpandMessageEditor,
/// Opens the conversation history view.
OpenHistory,
/// Adds a context server to the configuration.
AddContextServer,
/// Removes the currently selected thread.
RemoveSelectedThread,
/// Starts a chat conversation with the agent.
Chat,
/// Starts a chat conversation with follow-up enabled.
ChatWithFollow,
/// Cycles to the next inline assist suggestion.
CycleNextInlineAssist,
/// Cycles to the previous inline assist suggestion.
CyclePreviousInlineAssist,
/// Moves focus up in the interface.
FocusUp,
/// Moves focus down in the interface.
FocusDown,
/// Moves focus left in the interface.
FocusLeft,
/// Moves focus right in the interface.
FocusRight,
/// Removes the currently focused context item.
RemoveFocusedContext,
/// Accepts the suggested context item.
AcceptSuggestedContext,
/// Opens the active thread as a markdown file.
OpenActiveThreadAsMarkdown,
/// Opens the agent diff view to review changes.
OpenAgentDiff,
/// Keeps the current suggestion or change.
Keep,
/// Rejects the current suggestion or change.
Reject,
/// Rejects all suggestions or changes.
RejectAll,
/// Keeps all suggestions or changes.
KeepAll,
/// Follows the agent's suggestions.
Follow,
/// Resets the trial upsell notification.
ResetTrialUpsell,
/// Resets the trial end upsell notification.
ResetTrialEndUpsell,
/// Continues the current thread.
ContinueThread,
/// Continues the thread with burn mode enabled.
ContinueWithBurnMode,
/// Toggles burn mode for faster responses.
ToggleBurnMode,
]
);
/// Creates a new conversation thread, optionally based on an existing thread.
#[derive(Default, Clone, PartialEq, Deserialize, JsonSchema, Action)]
#[action(namespace = agent)]
#[serde(deny_unknown_fields)]
@@ -132,7 +98,6 @@ pub struct NewThread {
from_thread_id: Option<ThreadId>,
}
/// Opens the profile management interface for configuring agent tools and settings.
#[derive(PartialEq, Clone, Default, Debug, Deserialize, JsonSchema, Action)]
#[action(namespace = agent)]
#[serde(deny_unknown_fields)]

View File

@@ -18,7 +18,6 @@ use ui::{ListItem, ListItemSpacing, prelude::*};
actions!(
agent,
[
/// Toggles the language model selector dropdown.
#[action(deprecated_aliases = ["assistant::ToggleModelSelector", "assistant2::ToggleModelSelector"])]
ToggleModelSelector
]

View File

@@ -47,13 +47,14 @@ use ui::{
};
use util::ResultExt as _;
use workspace::{CollaboratorId, Workspace};
use zed_actions::agent::Chat;
use zed_llm_client::CompletionIntent;
use crate::context_picker::{ContextPicker, ContextPickerCompletionProvider, crease_for_mention};
use crate::context_strip::{ContextStrip, ContextStripEvent, SuggestContextKind};
use crate::profile_selector::ProfileSelector;
use crate::{
ActiveThread, AgentDiffPane, Chat, ChatWithFollow, ExpandMessageEditor, Follow, KeepAll,
ActiveThread, AgentDiffPane, ChatWithFollow, ExpandMessageEditor, Follow, KeepAll,
ModelUsageContext, NewThread, OpenAgentDiff, RejectAll, RemoveAllContext, ToggleBurnMode,
ToggleContextPicker, ToggleProfileSelector, register_agent_preview,
};

View File

@@ -85,24 +85,16 @@ use assistant_context::{
actions!(
assistant,
[
/// Sends the current message to the assistant.
Assist,
/// Confirms and executes the entered slash command.
ConfirmCommand,
/// Copies code from the assistant's response to the clipboard.
CopyCode,
/// Cycles between user and assistant message roles.
CycleMessageRole,
/// Inserts the selected text into the active editor.
InsertIntoEditor,
/// Quotes the current selection in the assistant conversation.
QuoteSelection,
/// Splits the conversation at the current cursor position.
Split,
]
);
/// Inserts files that were dragged and dropped into the assistant conversation.
#[derive(PartialEq, Clone, Action)]
#[action(namespace = assistant, no_json, no_register)]
pub enum InsertDraggedFiles {

View File

@@ -1,5 +1,5 @@
use agent::{Thread, ThreadEvent};
use assistant_tool::{AnyTool, ToolSource};
use assistant_tool::{Tool, ToolSource};
use collections::HashMap;
use gpui::{App, Context, Entity, IntoElement, Render, Subscription, Window};
use language_model::{LanguageModel, LanguageModelToolSchemaFormat};
@@ -7,7 +7,7 @@ use std::sync::Arc;
use ui::prelude::*;
pub struct IncompatibleToolsState {
cache: HashMap<LanguageModelToolSchemaFormat, Vec<AnyTool>>,
cache: HashMap<LanguageModelToolSchemaFormat, Vec<Arc<dyn Tool>>>,
thread: Entity<Thread>,
_thread_subscription: Subscription,
}
@@ -29,7 +29,11 @@ impl IncompatibleToolsState {
}
}
pub fn incompatible_tools(&mut self, model: &Arc<dyn LanguageModel>, cx: &App) -> &[AnyTool] {
pub fn incompatible_tools(
&mut self,
model: &Arc<dyn LanguageModel>,
cx: &App,
) -> &[Arc<dyn Tool>] {
self.cache
.entry(model.tool_input_format())
.or_insert_with(|| {
@@ -38,15 +42,15 @@ impl IncompatibleToolsState {
.profile()
.enabled_tools(cx)
.iter()
.filter(|(_, tool)| tool.input_schema(model.tool_input_format()).is_err())
.map(|(_, tool)| tool.clone())
.filter(|tool| tool.input_schema(model.tool_input_format()).is_err())
.cloned()
.collect()
})
}
}
pub struct IncompatibleToolsTooltip {
pub incompatible_tools: Vec<AnyTool>,
pub incompatible_tools: Vec<Arc<dyn Tool>>,
}
impl Render for IncompatibleToolsTooltip {

View File

@@ -1,4 +1,5 @@
mod agent_notification;
mod animated_label;
mod burn_mode_tooltip;
mod context_pill;
mod onboarding_modal;
@@ -6,6 +7,7 @@ pub mod preview;
mod upsell;
pub use agent_notification::*;
pub use animated_label::*;
pub use burn_mode_tooltip::*;
pub use context_pill::*;
pub use onboarding_modal::*;

View File

@@ -1,24 +1,24 @@
use crate::prelude::*;
use gpui::{Animation, AnimationExt, FontWeight, pulsating_between};
use std::time::Duration;
use ui::prelude::*;
#[derive(IntoElement)]
pub struct LoadingLabel {
pub struct AnimatedLabel {
base: Label,
text: SharedString,
}
impl LoadingLabel {
impl AnimatedLabel {
pub fn new(text: impl Into<SharedString>) -> Self {
let text = text.into();
LoadingLabel {
AnimatedLabel {
base: Label::new(text.clone()),
text,
}
}
}
impl LabelCommon for LoadingLabel {
impl LabelCommon for AnimatedLabel {
fn size(mut self, size: LabelSize) -> Self {
self.base = self.base.size(size);
self
@@ -80,14 +80,14 @@ impl LabelCommon for LoadingLabel {
}
}
impl RenderOnce for LoadingLabel {
impl RenderOnce for AnimatedLabel {
fn render(self, _window: &mut Window, _cx: &mut App) -> impl IntoElement {
let text = self.text.clone();
self.base
.color(Color::Muted)
.with_animations(
"loading_label",
"animated-label",
vec![
Animation::new(Duration::from_secs(1)),
Animation::new(Duration::from_secs(1)).repeat(),

View File

@@ -22,7 +22,6 @@ gpui.workspace = true
icons.workspace = true
language.workspace = true
language_model.workspace = true
log.workspace = true
parking_lot.workspace = true
project.workspace = true
regex.workspace = true

View File

@@ -4,19 +4,25 @@ mod tool_registry;
mod tool_schema;
mod tool_working_set;
use std::{fmt, fmt::Debug, fmt::Formatter, ops::Deref, sync::Arc};
use std::fmt;
use std::fmt::Debug;
use std::fmt::Formatter;
use std::ops::Deref;
use std::sync::Arc;
use anyhow::Result;
use gpui::{
AnyElement, AnyWindowHandle, App, Context, Entity, IntoElement, SharedString, Task, WeakEntity,
Window,
};
use gpui::AnyElement;
use gpui::AnyWindowHandle;
use gpui::Context;
use gpui::IntoElement;
use gpui::Window;
use gpui::{App, Entity, SharedString, Task, WeakEntity};
use icons::IconName;
use language_model::{
LanguageModel, LanguageModelImage, LanguageModelRequest, LanguageModelToolSchemaFormat,
};
use language_model::LanguageModel;
use language_model::LanguageModelImage;
use language_model::LanguageModelRequest;
use language_model::LanguageModelToolSchemaFormat;
use project::Project;
use serde::de::DeserializeOwned;
use workspace::Workspace;
pub use crate::action_log::*;
@@ -193,10 +199,7 @@ pub enum ToolSource {
}
/// A tool that can be used by a language model.
pub trait Tool: Send + Sync + 'static {
/// The input type that is accepted by the tool.
type Input: DeserializeOwned;
pub trait Tool: 'static + Send + Sync {
/// Returns the name of the tool.
fn name(&self) -> String;
@@ -213,7 +216,7 @@ pub trait Tool: Send + Sync + 'static {
/// Returns true if the tool needs the users's confirmation
/// before having permission to run.
fn needs_confirmation(&self, input: &Self::Input, cx: &App) -> bool;
fn needs_confirmation(&self, input: &serde_json::Value, cx: &App) -> bool;
/// Returns true if the tool may perform edits.
fn may_perform_edits(&self) -> bool;
@@ -224,18 +227,18 @@ pub trait Tool: Send + Sync + 'static {
}
/// Returns markdown to be displayed in the UI for this tool.
fn ui_text(&self, input: &Self::Input) -> String;
fn ui_text(&self, input: &serde_json::Value) -> String;
/// Returns markdown to be displayed in the UI for this tool, while the input JSON is still streaming
/// (so information may be missing).
fn still_streaming_ui_text(&self, input: &Self::Input) -> String {
fn still_streaming_ui_text(&self, input: &serde_json::Value) -> String {
self.ui_text(input)
}
/// Runs the tool with the provided input.
fn run(
self: Arc<Self>,
input: Self::Input,
input: serde_json::Value,
request: Arc<LanguageModelRequest>,
project: Entity<Project>,
action_log: Entity<ActionLog>,
@@ -255,199 +258,7 @@ pub trait Tool: Send + Sync + 'static {
}
}
#[derive(Clone)]
pub struct AnyTool {
inner: Arc<dyn ErasedTool>,
}
/// Copy of `Tool` where the Input type is erased.
trait ErasedTool: Send + Sync {
fn name(&self) -> String;
fn description(&self) -> String;
fn icon(&self) -> IconName;
fn source(&self) -> ToolSource;
fn may_perform_edits(&self) -> bool;
fn needs_confirmation(&self, input: &serde_json::Value, cx: &App) -> bool;
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value>;
fn ui_text(&self, input: &serde_json::Value) -> String;
fn still_streaming_ui_text(&self, input: &serde_json::Value) -> String;
fn run(
&self,
input: serde_json::Value,
request: Arc<LanguageModelRequest>,
project: Entity<Project>,
action_log: Entity<ActionLog>,
model: Arc<dyn LanguageModel>,
window: Option<AnyWindowHandle>,
cx: &mut App,
) -> ToolResult;
fn deserialize_card(
&self,
output: serde_json::Value,
project: Entity<Project>,
window: &mut Window,
cx: &mut App,
) -> Option<AnyToolCard>;
}
struct ErasedToolWrapper<T: Tool> {
tool: Arc<T>,
}
impl<T: Tool> ErasedTool for ErasedToolWrapper<T> {
fn name(&self) -> String {
self.tool.name()
}
fn description(&self) -> String {
self.tool.description()
}
fn icon(&self) -> IconName {
self.tool.icon()
}
fn source(&self) -> ToolSource {
self.tool.source()
}
fn may_perform_edits(&self) -> bool {
self.tool.may_perform_edits()
}
fn needs_confirmation(&self, input: &serde_json::Value, cx: &App) -> bool {
match serde_json::from_value::<T::Input>(input.clone()) {
Ok(parsed_input) => self.tool.needs_confirmation(&parsed_input, cx),
Err(_) => true,
}
}
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
self.tool.input_schema(format)
}
fn ui_text(&self, input: &serde_json::Value) -> String {
match serde_json::from_value::<T::Input>(input.clone()) {
Ok(parsed_input) => self.tool.ui_text(&parsed_input),
Err(_) => "Invalid input".to_string(),
}
}
fn still_streaming_ui_text(&self, input: &serde_json::Value) -> String {
match serde_json::from_value::<T::Input>(input.clone()) {
Ok(parsed_input) => self.tool.still_streaming_ui_text(&parsed_input),
Err(_) => "Invalid input".to_string(),
}
}
fn run(
&self,
input: serde_json::Value,
request: Arc<LanguageModelRequest>,
project: Entity<Project>,
action_log: Entity<ActionLog>,
model: Arc<dyn LanguageModel>,
window: Option<AnyWindowHandle>,
cx: &mut App,
) -> ToolResult {
match serde_json::from_value::<T::Input>(input) {
Ok(parsed_input) => self.tool.clone().run(
parsed_input,
request,
project,
action_log,
model,
window,
cx,
),
Err(err) => ToolResult::from(Task::ready(Err(err.into()))),
}
}
fn deserialize_card(
&self,
output: serde_json::Value,
project: Entity<Project>,
window: &mut Window,
cx: &mut App,
) -> Option<AnyToolCard> {
self.tool
.clone()
.deserialize_card(output, project, window, cx)
}
}
impl<T: Tool> From<Arc<T>> for AnyTool {
fn from(tool: Arc<T>) -> Self {
Self {
inner: Arc::new(ErasedToolWrapper { tool }),
}
}
}
impl AnyTool {
pub fn name(&self) -> String {
self.inner.name()
}
pub fn description(&self) -> String {
self.inner.description()
}
pub fn icon(&self) -> IconName {
self.inner.icon()
}
pub fn source(&self) -> ToolSource {
self.inner.source()
}
pub fn may_perform_edits(&self) -> bool {
self.inner.may_perform_edits()
}
pub fn needs_confirmation(&self, input: &serde_json::Value, cx: &App) -> bool {
self.inner.needs_confirmation(input, cx)
}
pub fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
self.inner.input_schema(format)
}
pub fn ui_text(&self, input: &serde_json::Value) -> String {
self.inner.ui_text(input)
}
pub fn still_streaming_ui_text(&self, input: &serde_json::Value) -> String {
self.inner.still_streaming_ui_text(input)
}
pub fn run(
&self,
input: serde_json::Value,
request: Arc<LanguageModelRequest>,
project: Entity<Project>,
action_log: Entity<ActionLog>,
model: Arc<dyn LanguageModel>,
window: Option<AnyWindowHandle>,
cx: &mut App,
) -> ToolResult {
self.inner
.run(input, request, project, action_log, model, window, cx)
}
pub fn deserialize_card(
&self,
output: serde_json::Value,
project: Entity<Project>,
window: &mut Window,
cx: &mut App,
) -> Option<AnyToolCard> {
self.inner.deserialize_card(output, project, window, cx)
}
}
impl Debug for AnyTool {
impl Debug for dyn Tool {
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
f.debug_struct("Tool").field("name", &self.name()).finish()
}

View File

@@ -6,7 +6,7 @@ use gpui::Global;
use gpui::{App, ReadGlobal};
use parking_lot::RwLock;
use crate::{AnyTool, Tool};
use crate::Tool;
#[derive(Default, Deref, DerefMut)]
struct GlobalToolRegistry(Arc<ToolRegistry>);
@@ -15,7 +15,7 @@ impl Global for GlobalToolRegistry {}
#[derive(Default)]
struct ToolRegistryState {
tools: HashMap<Arc<str>, AnyTool>,
tools: HashMap<Arc<str>, Arc<dyn Tool>>,
}
#[derive(Default)]
@@ -48,7 +48,7 @@ impl ToolRegistry {
pub fn register_tool(&self, tool: impl Tool) {
let mut state = self.state.write();
let tool_name: Arc<str> = tool.name().into();
state.tools.insert(tool_name, Arc::new(tool).into());
state.tools.insert(tool_name, Arc::new(tool));
}
/// Unregisters the provided [`Tool`].
@@ -63,12 +63,12 @@ impl ToolRegistry {
}
/// Returns the list of tools in the registry.
pub fn tools(&self) -> Vec<AnyTool> {
pub fn tools(&self) -> Vec<Arc<dyn Tool>> {
self.state.read().tools.values().cloned().collect()
}
/// Returns the [`Tool`] with the given name.
pub fn tool(&self, name: &str) -> Option<AnyTool> {
pub fn tool(&self, name: &str) -> Option<Arc<dyn Tool>> {
self.state.read().tools.get(name).cloned()
}
}

View File

@@ -1,77 +1,39 @@
use std::borrow::Borrow;
use std::sync::Arc;
use crate::{AnyTool, ToolRegistry, ToolSource};
use collections::{HashMap, HashSet, IndexMap};
use gpui::{App, SharedString};
use util::debug_panic;
use collections::{HashMap, IndexMap};
use gpui::App;
use crate::{Tool, ToolRegistry, ToolSource};
#[derive(Copy, Clone, PartialEq, Eq, Hash, Default)]
pub struct ToolId(usize);
/// A unique identifier for a tool within a working set.
#[derive(Clone, PartialEq, Eq, Hash, Default)]
pub struct UniqueToolName(SharedString);
impl Borrow<str> for UniqueToolName {
fn borrow(&self) -> &str {
&self.0
}
}
impl From<String> for UniqueToolName {
fn from(value: String) -> Self {
UniqueToolName(SharedString::new(value))
}
}
impl Into<String> for UniqueToolName {
fn into(self) -> String {
self.0.into()
}
}
impl std::fmt::Debug for UniqueToolName {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
self.0.fmt(f)
}
}
impl std::fmt::Display for UniqueToolName {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", self.0.as_ref())
}
}
/// A working set of tools for use in one instance of the Assistant Panel.
#[derive(Default)]
pub struct ToolWorkingSet {
context_server_tools_by_id: HashMap<ToolId, AnyTool>,
context_server_tools_by_name: HashMap<UniqueToolName, AnyTool>,
context_server_tools_by_id: HashMap<ToolId, Arc<dyn Tool>>,
context_server_tools_by_name: HashMap<String, Arc<dyn Tool>>,
next_tool_id: ToolId,
}
impl ToolWorkingSet {
pub fn tool(&self, name: &str, cx: &App) -> Option<AnyTool> {
pub fn tool(&self, name: &str, cx: &App) -> Option<Arc<dyn Tool>> {
self.context_server_tools_by_name
.get(name)
.cloned()
.or_else(|| ToolRegistry::global(cx).tool(name))
}
pub fn tools(&self, cx: &App) -> Vec<(UniqueToolName, AnyTool)> {
let mut tools = ToolRegistry::global(cx)
.tools()
.into_iter()
.map(|tool| (UniqueToolName(tool.name().into()), tool))
.collect::<Vec<_>>();
tools.extend(self.context_server_tools_by_name.clone());
pub fn tools(&self, cx: &App) -> Vec<Arc<dyn Tool>> {
let mut tools = ToolRegistry::global(cx).tools();
tools.extend(self.context_server_tools_by_id.values().cloned());
tools
}
pub fn tools_by_source(&self, cx: &App) -> IndexMap<ToolSource, Vec<AnyTool>> {
pub fn tools_by_source(&self, cx: &App) -> IndexMap<ToolSource, Vec<Arc<dyn Tool>>> {
let mut tools_by_source = IndexMap::default();
for (_, tool) in self.tools(cx) {
for tool in self.tools(cx) {
tools_by_source
.entry(tool.source())
.or_insert_with(Vec::new)
@@ -87,330 +49,27 @@ impl ToolWorkingSet {
tools_by_source
}
pub fn insert(&mut self, tool: AnyTool, cx: &App) -> ToolId {
let tool_id = self.register_tool(tool);
self.tools_changed(cx);
tool_id
}
pub fn extend(&mut self, tools: impl Iterator<Item = AnyTool>, cx: &App) -> Vec<ToolId> {
let ids = tools.map(|tool| self.register_tool(tool)).collect();
self.tools_changed(cx);
ids
}
pub fn remove(&mut self, tool_ids_to_remove: &[ToolId], cx: &App) {
self.context_server_tools_by_id
.retain(|id, _| !tool_ids_to_remove.contains(id));
self.tools_changed(cx);
}
fn register_tool(&mut self, tool: AnyTool) -> ToolId {
pub fn insert(&mut self, tool: Arc<dyn Tool>) -> ToolId {
let tool_id = self.next_tool_id;
self.next_tool_id.0 += 1;
self.context_server_tools_by_id
.insert(tool_id, tool.clone());
self.tools_changed();
tool_id
}
fn tools_changed(&mut self, cx: &App) {
self.context_server_tools_by_name = resolve_context_server_tool_name_conflicts(
&self
.context_server_tools_by_id
pub fn remove(&mut self, tool_ids_to_remove: &[ToolId]) {
self.context_server_tools_by_id
.retain(|id, _| !tool_ids_to_remove.contains(id));
self.tools_changed();
}
fn tools_changed(&mut self) {
self.context_server_tools_by_name.clear();
self.context_server_tools_by_name.extend(
self.context_server_tools_by_id
.values()
.cloned()
.collect::<Vec<_>>(),
&ToolRegistry::global(cx).tools(),
.map(|tool| (tool.name(), tool.clone())),
);
}
}
fn resolve_context_server_tool_name_conflicts(
context_server_tools: &[AnyTool],
native_tools: &[AnyTool],
) -> HashMap<UniqueToolName, AnyTool> {
fn resolve_tool_name(tool: &AnyTool) -> String {
let mut tool_name = tool.name();
tool_name.truncate(MAX_TOOL_NAME_LENGTH);
tool_name
}
const MAX_TOOL_NAME_LENGTH: usize = 64;
let mut duplicated_tool_names = HashSet::default();
let mut seen_tool_names = HashSet::default();
seen_tool_names.extend(native_tools.iter().map(|tool| tool.name()));
for tool in context_server_tools {
let tool_name = resolve_tool_name(tool);
if seen_tool_names.contains(&tool_name) {
debug_assert!(
tool.source() != ToolSource::Native,
"Expected MCP tool but got a native tool: {}",
tool_name
);
duplicated_tool_names.insert(tool_name);
} else {
seen_tool_names.insert(tool_name);
}
}
if duplicated_tool_names.is_empty() {
return context_server_tools
.into_iter()
.map(|tool| (resolve_tool_name(tool).into(), tool.clone()))
.collect();
}
context_server_tools
.into_iter()
.filter_map(|tool| {
let mut tool_name = resolve_tool_name(tool);
if !duplicated_tool_names.contains(&tool_name) {
return Some((tool_name.into(), tool.clone()));
}
match tool.source() {
ToolSource::Native => {
debug_panic!("Expected MCP tool but got a native tool: {}", tool_name);
// Built-in tools always keep their original name
Some((tool_name.into(), tool.clone()))
}
ToolSource::ContextServer { id } => {
// Context server tools are prefixed with the context server ID, and truncated if necessary
tool_name.insert(0, '_');
if tool_name.len() + id.len() > MAX_TOOL_NAME_LENGTH {
let len = MAX_TOOL_NAME_LENGTH - tool_name.len();
let mut id = id.to_string();
id.truncate(len);
tool_name.insert_str(0, &id);
} else {
tool_name.insert_str(0, &id);
}
tool_name.truncate(MAX_TOOL_NAME_LENGTH);
if seen_tool_names.contains(&tool_name) {
log::error!("Cannot resolve tool name conflict for tool {}", tool.name());
None
} else {
Some((tool_name.into(), tool.clone()))
}
}
}
})
.collect()
}
#[cfg(test)]
mod tests {
use std::sync::Arc;
use gpui::{AnyWindowHandle, Entity, Task, TestAppContext};
use language_model::{LanguageModel, LanguageModelRequest};
use project::Project;
use crate::{ActionLog, Tool, ToolResult};
use super::*;
#[gpui::test]
fn test_unique_tool_names(cx: &mut TestAppContext) {
fn assert_tool(
tool_working_set: &ToolWorkingSet,
unique_name: &str,
expected_name: &str,
expected_source: ToolSource,
cx: &App,
) {
let tool = tool_working_set.tool(unique_name, cx).unwrap();
assert_eq!(tool.name(), expected_name);
assert_eq!(tool.source(), expected_source);
}
let tool_registry = cx.update(ToolRegistry::default_global);
tool_registry.register_tool(TestTool::new("tool1", ToolSource::Native));
tool_registry.register_tool(TestTool::new("tool2", ToolSource::Native));
let mut tool_working_set = ToolWorkingSet::default();
cx.update(|cx| {
tool_working_set.extend(
vec![
Arc::new(TestTool::new(
"tool2",
ToolSource::ContextServer { id: "mcp-1".into() },
))
.into(),
Arc::new(TestTool::new(
"tool2",
ToolSource::ContextServer { id: "mcp-2".into() },
))
.into(),
]
.into_iter(),
cx,
);
});
cx.update(|cx| {
assert_tool(&tool_working_set, "tool1", "tool1", ToolSource::Native, cx);
assert_tool(&tool_working_set, "tool2", "tool2", ToolSource::Native, cx);
assert_tool(
&tool_working_set,
"mcp-1_tool2",
"tool2",
ToolSource::ContextServer { id: "mcp-1".into() },
cx,
);
assert_tool(
&tool_working_set,
"mcp-2_tool2",
"tool2",
ToolSource::ContextServer { id: "mcp-2".into() },
cx,
);
})
}
#[gpui::test]
fn test_resolve_context_server_tool_name_conflicts() {
assert_resolve_context_server_tool_name_conflicts(
vec![
TestTool::new("tool1", ToolSource::Native),
TestTool::new("tool2", ToolSource::Native),
],
vec![TestTool::new(
"tool3",
ToolSource::ContextServer { id: "mcp-1".into() },
)],
vec!["tool3"],
);
assert_resolve_context_server_tool_name_conflicts(
vec![
TestTool::new("tool1", ToolSource::Native),
TestTool::new("tool2", ToolSource::Native),
],
vec![
TestTool::new("tool3", ToolSource::ContextServer { id: "mcp-1".into() }),
TestTool::new("tool3", ToolSource::ContextServer { id: "mcp-2".into() }),
],
vec!["mcp-1_tool3", "mcp-2_tool3"],
);
assert_resolve_context_server_tool_name_conflicts(
vec![
TestTool::new("tool1", ToolSource::Native),
TestTool::new("tool2", ToolSource::Native),
TestTool::new("tool3", ToolSource::Native),
],
vec![
TestTool::new("tool3", ToolSource::ContextServer { id: "mcp-1".into() }),
TestTool::new("tool3", ToolSource::ContextServer { id: "mcp-2".into() }),
],
vec!["mcp-1_tool3", "mcp-2_tool3"],
);
// Test deduplication of tools with very long names, in this case the mcp server name should be truncated
assert_resolve_context_server_tool_name_conflicts(
vec![TestTool::new(
"tool-with-very-very-very-long-name",
ToolSource::Native,
)],
vec![TestTool::new(
"tool-with-very-very-very-long-name",
ToolSource::ContextServer {
id: "mcp-with-very-very-very-long-name".into(),
},
)],
vec!["mcp-with-very-very-very-long-_tool-with-very-very-very-long-name"],
);
fn assert_resolve_context_server_tool_name_conflicts(
builtin_tools: Vec<TestTool>,
context_server_tools: Vec<TestTool>,
expected: Vec<&'static str>,
) {
let context_server_tools: Vec<AnyTool> = context_server_tools
.into_iter()
.map(|t| Arc::new(t).into())
.collect();
let builtin_tools: Vec<AnyTool> = builtin_tools
.into_iter()
.map(|t| Arc::new(t).into())
.collect();
let tools =
resolve_context_server_tool_name_conflicts(&context_server_tools, &builtin_tools);
assert_eq!(tools.len(), expected.len());
for (i, (name, _)) in tools.into_iter().enumerate() {
assert_eq!(
name.0.as_ref(),
expected[i],
"Expected '{}' got '{}' at index {}",
expected[i],
name,
i
);
}
}
}
struct TestTool {
name: String,
source: ToolSource,
}
impl TestTool {
fn new(name: impl Into<String>, source: ToolSource) -> Self {
Self {
name: name.into(),
source,
}
}
}
impl Tool for TestTool {
type Input = ();
fn name(&self) -> String {
self.name.clone()
}
fn icon(&self) -> icons::IconName {
icons::IconName::Ai
}
fn may_perform_edits(&self) -> bool {
false
}
fn needs_confirmation(&self, _input: &Self::Input, _cx: &App) -> bool {
true
}
fn source(&self) -> ToolSource {
self.source.clone()
}
fn description(&self) -> String {
"Test tool".to_string()
}
fn ui_text(&self, _input: &Self::Input) -> String {
"Test tool".to_string()
}
fn run(
self: Arc<Self>,
_input: Self::Input,
_request: Arc<LanguageModelRequest>,
_project: Entity<Project>,
_action_log: Entity<ActionLog>,
_model: Arc<dyn LanguageModel>,
_window: Option<AnyWindowHandle>,
_cx: &mut App,
) -> ToolResult {
ToolResult {
output: Task::ready(Err(anyhow::anyhow!("No content"))),
card: None,
}
}
}
}

View File

@@ -40,13 +40,11 @@ pub struct CopyPathToolInput {
pub struct CopyPathTool;
impl Tool for CopyPathTool {
type Input = CopyPathToolInput;
fn name(&self) -> String {
"copy_path".into()
}
fn needs_confirmation(&self, _: &Self::Input, _: &App) -> bool {
fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool {
false
}
@@ -66,15 +64,20 @@ impl Tool for CopyPathTool {
json_schema_for::<CopyPathToolInput>(format)
}
fn ui_text(&self, input: &Self::Input) -> String {
let src = MarkdownInlineCode(&input.source_path);
let dest = MarkdownInlineCode(&input.destination_path);
format!("Copy {src} to {dest}")
fn ui_text(&self, input: &serde_json::Value) -> String {
match serde_json::from_value::<CopyPathToolInput>(input.clone()) {
Ok(input) => {
let src = MarkdownInlineCode(&input.source_path);
let dest = MarkdownInlineCode(&input.destination_path);
format!("Copy {src} to {dest}")
}
Err(_) => "Copy path".to_string(),
}
}
fn run(
self: Arc<Self>,
input: Self::Input,
input: serde_json::Value,
_request: Arc<LanguageModelRequest>,
project: Entity<Project>,
_action_log: Entity<ActionLog>,
@@ -82,6 +85,10 @@ impl Tool for CopyPathTool {
_window: Option<AnyWindowHandle>,
cx: &mut App,
) -> ToolResult {
let input = match serde_json::from_value::<CopyPathToolInput>(input) {
Ok(input) => input,
Err(err) => return Task::ready(Err(anyhow!(err))).into(),
};
let copy_task = project.update(cx, |project, cx| {
match project
.find_project_path(&input.source_path, cx)

View File

@@ -29,8 +29,6 @@ pub struct CreateDirectoryToolInput {
pub struct CreateDirectoryTool;
impl Tool for CreateDirectoryTool {
type Input = CreateDirectoryToolInput;
fn name(&self) -> String {
"create_directory".into()
}
@@ -39,7 +37,7 @@ impl Tool for CreateDirectoryTool {
include_str!("./create_directory_tool/description.md").into()
}
fn needs_confirmation(&self, _: &Self::Input, _: &App) -> bool {
fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool {
false
}
@@ -55,13 +53,18 @@ impl Tool for CreateDirectoryTool {
json_schema_for::<CreateDirectoryToolInput>(format)
}
fn ui_text(&self, input: &Self::Input) -> String {
format!("Create directory {}", MarkdownInlineCode(&input.path))
fn ui_text(&self, input: &serde_json::Value) -> String {
match serde_json::from_value::<CreateDirectoryToolInput>(input.clone()) {
Ok(input) => {
format!("Create directory {}", MarkdownInlineCode(&input.path))
}
Err(_) => "Create directory".to_string(),
}
}
fn run(
self: Arc<Self>,
input: Self::Input,
input: serde_json::Value,
_request: Arc<LanguageModelRequest>,
project: Entity<Project>,
_action_log: Entity<ActionLog>,
@@ -69,6 +72,10 @@ impl Tool for CreateDirectoryTool {
_window: Option<AnyWindowHandle>,
cx: &mut App,
) -> ToolResult {
let input = match serde_json::from_value::<CreateDirectoryToolInput>(input) {
Ok(input) => input,
Err(err) => return Task::ready(Err(anyhow!(err))).into(),
};
let project_path = match project.read(cx).find_project_path(&input.path, cx) {
Some(project_path) => project_path,
None => {

View File

@@ -29,13 +29,11 @@ pub struct DeletePathToolInput {
pub struct DeletePathTool;
impl Tool for DeletePathTool {
type Input = DeletePathToolInput;
fn name(&self) -> String {
"delete_path".into()
}
fn needs_confirmation(&self, _: &Self::Input, _: &App) -> bool {
fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool {
false
}
@@ -55,13 +53,16 @@ impl Tool for DeletePathTool {
json_schema_for::<DeletePathToolInput>(format)
}
fn ui_text(&self, input: &Self::Input) -> String {
format!("Delete “`{}`”", input.path)
fn ui_text(&self, input: &serde_json::Value) -> String {
match serde_json::from_value::<DeletePathToolInput>(input.clone()) {
Ok(input) => format!("Delete “`{}`”", input.path),
Err(_) => "Delete path".to_string(),
}
}
fn run(
self: Arc<Self>,
input: Self::Input,
input: serde_json::Value,
_request: Arc<LanguageModelRequest>,
project: Entity<Project>,
action_log: Entity<ActionLog>,
@@ -69,7 +70,10 @@ impl Tool for DeletePathTool {
_window: Option<AnyWindowHandle>,
cx: &mut App,
) -> ToolResult {
let path_str = input.path;
let path_str = match serde_json::from_value::<DeletePathToolInput>(input) {
Ok(input) => input.path,
Err(err) => return Task::ready(Err(anyhow!(err))).into(),
};
let Some(project_path) = project.read(cx).find_project_path(&path_str, cx) else {
return Task::ready(Err(anyhow!(
"Couldn't delete {path_str} because that path isn't in this project."

View File

@@ -42,13 +42,11 @@ where
pub struct DiagnosticsTool;
impl Tool for DiagnosticsTool {
type Input = DiagnosticsToolInput;
fn name(&self) -> String {
"diagnostics".into()
}
fn needs_confirmation(&self, _: &Self::Input, _: &App) -> bool {
fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool {
false
}
@@ -68,9 +66,15 @@ impl Tool for DiagnosticsTool {
json_schema_for::<DiagnosticsToolInput>(format)
}
fn ui_text(&self, input: &Self::Input) -> String {
if let Some(path) = input.path.as_ref().filter(|p| !p.is_empty()) {
format!("Check diagnostics for {}", MarkdownInlineCode(path))
fn ui_text(&self, input: &serde_json::Value) -> String {
if let Some(path) = serde_json::from_value::<DiagnosticsToolInput>(input.clone())
.ok()
.and_then(|input| match input.path {
Some(path) if !path.is_empty() => Some(path),
_ => None,
})
{
format!("Check diagnostics for {}", MarkdownInlineCode(&path))
} else {
"Check project diagnostics".to_string()
}
@@ -78,7 +82,7 @@ impl Tool for DiagnosticsTool {
fn run(
self: Arc<Self>,
input: Self::Input,
input: serde_json::Value,
_request: Arc<LanguageModelRequest>,
project: Entity<Project>,
action_log: Entity<ActionLog>,
@@ -86,7 +90,10 @@ impl Tool for DiagnosticsTool {
_window: Option<AnyWindowHandle>,
cx: &mut App,
) -> ToolResult {
match input.path {
match serde_json::from_value::<DiagnosticsToolInput>(input)
.ok()
.and_then(|input| input.path)
{
Some(path) if !path.is_empty() => {
let Some(project_path) = project.read(cx).find_project_path(&path, cx) else {
return Task::ready(Err(anyhow!("Could not find path {path} in project",)))

View File

@@ -9132,7 +9132,7 @@ impl Editor {
window: &mut Window,
cx: &mut Context<Self>,
) {
self.manipulate_lines(window, cx, |lines| lines.sort())
self.manipulate_immutable_lines(window, cx, |lines| lines.sort())
}
pub fn sort_lines_case_insensitive(
@@ -9141,7 +9141,7 @@ impl Editor {
window: &mut Window,
cx: &mut Context<Self>,
) {
self.manipulate_lines(window, cx, |lines| {
self.manipulate_immutable_lines(window, cx, |lines| {
lines.sort_by_key(|line| line.to_lowercase())
})
}
@@ -9152,7 +9152,7 @@ impl Editor {
window: &mut Window,
cx: &mut Context<Self>,
) {
self.manipulate_lines(window, cx, |lines| {
self.manipulate_immutable_lines(window, cx, |lines| {
let mut seen = HashSet::default();
lines.retain(|line| seen.insert(line.to_lowercase()));
})
@@ -9164,7 +9164,7 @@ impl Editor {
window: &mut Window,
cx: &mut Context<Self>,
) {
self.manipulate_lines(window, cx, |lines| {
self.manipulate_immutable_lines(window, cx, |lines| {
let mut seen = HashSet::default();
lines.retain(|line| seen.insert(*line));
})
@@ -9606,20 +9606,20 @@ impl Editor {
}
pub fn reverse_lines(&mut self, _: &ReverseLines, window: &mut Window, cx: &mut Context<Self>) {
self.manipulate_lines(window, cx, |lines| lines.reverse())
self.manipulate_immutable_lines(window, cx, |lines| lines.reverse())
}
pub fn shuffle_lines(&mut self, _: &ShuffleLines, window: &mut Window, cx: &mut Context<Self>) {
self.manipulate_lines(window, cx, |lines| lines.shuffle(&mut thread_rng()))
self.manipulate_immutable_lines(window, cx, |lines| lines.shuffle(&mut thread_rng()))
}
fn manipulate_lines<Fn>(
fn manipulate_lines<M>(
&mut self,
window: &mut Window,
cx: &mut Context<Self>,
mut callback: Fn,
mut manipulate: M,
) where
Fn: FnMut(&mut Vec<&str>),
M: FnMut(&str) -> LineManipulationResult,
{
self.hide_mouse_cursor(&HideMouseCursorOrigin::TypingAction);
@@ -9652,18 +9652,14 @@ impl Editor {
.text_for_range(start_point..end_point)
.collect::<String>();
let mut lines = text.split('\n').collect_vec();
let LineManipulationResult { new_text, line_count_before, line_count_after} = manipulate(&text);
let lines_before = lines.len();
callback(&mut lines);
let lines_after = lines.len();
edits.push((start_point..end_point, lines.join("\n")));
edits.push((start_point..end_point, new_text));
// Selections must change based on added and removed line count
let start_row =
MultiBufferRow(start_point.row + added_lines as u32 - removed_lines as u32);
let end_row = MultiBufferRow(start_row.0 + lines_after.saturating_sub(1) as u32);
let end_row = MultiBufferRow(start_row.0 + line_count_after.saturating_sub(1) as u32);
new_selections.push(Selection {
id: selection.id,
start: start_row,
@@ -9672,10 +9668,10 @@ impl Editor {
reversed: selection.reversed,
});
if lines_after > lines_before {
added_lines += lines_after - lines_before;
} else if lines_before > lines_after {
removed_lines += lines_before - lines_after;
if line_count_after > line_count_before {
added_lines += line_count_after - line_count_before;
} else if line_count_before > line_count_after {
removed_lines += line_count_before - line_count_after;
}
}
@@ -9720,6 +9716,171 @@ impl Editor {
})
}
fn manipulate_immutable_lines<Fn>(
&mut self,
window: &mut Window,
cx: &mut Context<Self>,
mut callback: Fn,
) where
Fn: FnMut(&mut Vec<&str>),
{
self.manipulate_lines(window, cx, |text| {
let mut lines: Vec<&str> = text.split('\n').collect();
let line_count_before = lines.len();
callback(&mut lines);
LineManipulationResult {
new_text: lines.join("\n"),
line_count_before,
line_count_after: lines.len(),
}
});
}
fn manipulate_mutable_lines<Fn>(
&mut self,
window: &mut Window,
cx: &mut Context<Self>,
mut callback: Fn,
) where
Fn: FnMut(&mut Vec<Cow<'_, str>>),
{
self.manipulate_lines(window, cx, |text| {
let mut lines: Vec<Cow<str>> = text.split('\n').map(Cow::from).collect();
let line_count_before = lines.len();
callback(&mut lines);
LineManipulationResult {
new_text: lines.join("\n"),
line_count_before,
line_count_after: lines.len(),
}
});
}
pub fn convert_indentation_to_spaces(
&mut self,
_: &ConvertIndentationToSpaces,
window: &mut Window,
cx: &mut Context<Self>,
) {
let settings = self.buffer.read(cx).language_settings(cx);
let tab_size = settings.tab_size.get() as usize;
self.manipulate_mutable_lines(window, cx, |lines| {
// Allocates a reasonably sized scratch buffer once for the whole loop
let mut reindented_line = String::with_capacity(MAX_LINE_LEN);
// Avoids recomputing spaces that could be inserted many times
let space_cache: Vec<Vec<char>> = (1..=tab_size)
.map(|n| IndentSize::spaces(n as u32).chars().collect())
.collect();
for line in lines.iter_mut().filter(|line| !line.is_empty()) {
let mut chars = line.as_ref().chars();
let mut col = 0;
let mut changed = false;
while let Some(ch) = chars.next() {
match ch {
' ' => {
reindented_line.push(' ');
col += 1;
}
'\t' => {
// \t are converted to spaces depending on the current column
let spaces_len = tab_size - (col % tab_size);
reindented_line.extend(&space_cache[spaces_len - 1]);
col += spaces_len;
changed = true;
}
_ => {
// If we dont append before break, the character is consumed
reindented_line.push(ch);
break;
}
}
}
if !changed {
reindented_line.clear();
continue;
}
// Append the rest of the line and replace old reference with new one
reindented_line.extend(chars);
*line = Cow::Owned(reindented_line.clone());
reindented_line.clear();
}
});
}
pub fn convert_indentation_to_tabs(
&mut self,
_: &ConvertIndentationToTabs,
window: &mut Window,
cx: &mut Context<Self>,
) {
let settings = self.buffer.read(cx).language_settings(cx);
let tab_size = settings.tab_size.get() as usize;
self.manipulate_mutable_lines(window, cx, |lines| {
// Allocates a reasonably sized buffer once for the whole loop
let mut reindented_line = String::with_capacity(MAX_LINE_LEN);
// Avoids recomputing spaces that could be inserted many times
let space_cache: Vec<Vec<char>> = (1..=tab_size)
.map(|n| IndentSize::spaces(n as u32).chars().collect())
.collect();
for line in lines.iter_mut().filter(|line| !line.is_empty()) {
let mut chars = line.chars();
let mut spaces_count = 0;
let mut first_non_indent_char = None;
let mut changed = false;
while let Some(ch) = chars.next() {
match ch {
' ' => {
// Keep track of spaces. Append \t when we reach tab_size
spaces_count += 1;
changed = true;
if spaces_count == tab_size {
reindented_line.push('\t');
spaces_count = 0;
}
}
'\t' => {
reindented_line.push('\t');
spaces_count = 0;
}
_ => {
// Dont append it yet, we might have remaining spaces
first_non_indent_char = Some(ch);
break;
}
}
}
if !changed {
reindented_line.clear();
continue;
}
// Remaining spaces that didn't make a full tab stop
if spaces_count > 0 {
reindented_line.extend(&space_cache[spaces_count - 1]);
}
// If we consume an extra character that was not indentation, add it back
if let Some(extra_char) = first_non_indent_char {
reindented_line.push(extra_char);
}
// Append the rest of the line and replace old reference with new one
reindented_line.extend(chars);
*line = Cow::Owned(reindented_line.clone());
reindented_line.clear();
}
});
}
pub fn convert_to_upper_case(
&mut self,
_: &ConvertToUpperCase,
@@ -21157,6 +21318,13 @@ pub struct LineHighlight {
pub type_id: Option<TypeId>,
}
struct LineManipulationResult {
pub new_text: String,
pub line_count_before: usize,
pub line_count_after: usize,
}
fn render_diff_hunk_controls(
row: u32,
status: &DiffHunkStatus,

View File

@@ -121,13 +121,11 @@ struct PartialInput {
const DEFAULT_UI_TEXT: &str = "Editing file";
impl Tool for EditFileTool {
type Input = EditFileToolInput;
fn name(&self) -> String {
"edit_file".into()
}
fn needs_confirmation(&self, _: &Self::Input, _: &App) -> bool {
fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool {
false
}
@@ -147,20 +145,24 @@ impl Tool for EditFileTool {
json_schema_for::<EditFileToolInput>(format)
}
fn ui_text(&self, input: &Self::Input) -> String {
input.display_description.clone()
fn ui_text(&self, input: &serde_json::Value) -> String {
match serde_json::from_value::<EditFileToolInput>(input.clone()) {
Ok(input) => input.display_description,
Err(_) => "Editing file".to_string(),
}
}
fn still_streaming_ui_text(&self, input: &Self::Input) -> String {
let description = input.display_description.trim();
if !description.is_empty() {
return description.to_string();
}
fn still_streaming_ui_text(&self, input: &serde_json::Value) -> String {
if let Some(input) = serde_json::from_value::<PartialInput>(input.clone()).ok() {
let description = input.display_description.trim();
if !description.is_empty() {
return description.to_string();
}
let path = input.path.to_string_lossy();
let path = path.trim();
if !path.is_empty() {
return path.to_string();
let path = input.path.trim();
if !path.is_empty() {
return path.to_string();
}
}
DEFAULT_UI_TEXT.to_string()
@@ -168,7 +170,7 @@ impl Tool for EditFileTool {
fn run(
self: Arc<Self>,
input: Self::Input,
input: serde_json::Value,
request: Arc<LanguageModelRequest>,
project: Entity<Project>,
action_log: Entity<ActionLog>,
@@ -176,6 +178,11 @@ impl Tool for EditFileTool {
window: Option<AnyWindowHandle>,
cx: &mut App,
) -> ToolResult {
let input = match serde_json::from_value::<EditFileToolInput>(input) {
Ok(input) => input,
Err(err) => return Task::ready(Err(anyhow!(err))).into(),
};
let project_path = match resolve_path(&input, project.clone(), cx) {
Ok(path) => path,
Err(err) => return Task::ready(Err(anyhow!(err))).into(),
@@ -1162,11 +1169,12 @@ mod tests {
let model = Arc::new(FakeLanguageModel::default());
let result = cx
.update(|cx| {
let input = EditFileToolInput {
let input = serde_json::to_value(EditFileToolInput {
display_description: "Some edit".into(),
path: "root/nonexistent_file.txt".into(),
mode: EditFileMode::Edit,
};
})
.unwrap();
Arc::new(EditFileTool)
.run(
input,
@@ -1280,22 +1288,24 @@ mod tests {
#[test]
fn still_streaming_ui_text_with_path() {
let input = EditFileToolInput {
path: "src/main.rs".into(),
display_description: "".into(),
mode: EditFileMode::Edit,
};
let input = json!({
"path": "src/main.rs",
"display_description": "",
"old_string": "old code",
"new_string": "new code"
});
assert_eq!(EditFileTool.still_streaming_ui_text(&input), "src/main.rs");
}
#[test]
fn still_streaming_ui_text_with_description() {
let input = EditFileToolInput {
path: "".into(),
display_description: "Fix error handling".into(),
mode: EditFileMode::Edit,
};
let input = json!({
"path": "",
"display_description": "Fix error handling",
"old_string": "old code",
"new_string": "new code"
});
assert_eq!(
EditFileTool.still_streaming_ui_text(&input),
@@ -1305,11 +1315,12 @@ mod tests {
#[test]
fn still_streaming_ui_text_with_path_and_description() {
let input = EditFileToolInput {
path: "src/main.rs".into(),
display_description: "Fix error handling".into(),
mode: EditFileMode::Edit,
};
let input = json!({
"path": "src/main.rs",
"display_description": "Fix error handling",
"old_string": "old code",
"new_string": "new code"
});
assert_eq!(
EditFileTool.still_streaming_ui_text(&input),
@@ -1319,11 +1330,12 @@ mod tests {
#[test]
fn still_streaming_ui_text_no_path_or_description() {
let input = EditFileToolInput {
path: "".into(),
display_description: "".into(),
mode: EditFileMode::Edit,
};
let input = json!({
"path": "",
"display_description": "",
"old_string": "old code",
"new_string": "new code"
});
assert_eq!(
EditFileTool.still_streaming_ui_text(&input),
@@ -1333,11 +1345,7 @@ mod tests {
#[test]
fn still_streaming_ui_text_with_null() {
let input = EditFileToolInput {
path: "".into(),
display_description: "".into(),
mode: EditFileMode::Edit,
};
let input = serde_json::Value::Null;
assert_eq!(
EditFileTool.still_streaming_ui_text(&input),
@@ -1449,11 +1457,12 @@ mod tests {
// Have the model stream unformatted content
let edit_result = {
let edit_task = cx.update(|cx| {
let input = EditFileToolInput {
let input = serde_json::to_value(EditFileToolInput {
display_description: "Create main function".into(),
path: "root/src/main.rs".into(),
mode: EditFileMode::Overwrite,
};
})
.unwrap();
Arc::new(EditFileTool)
.run(
input,
@@ -1512,11 +1521,12 @@ mod tests {
// Stream unformatted edits again
let edit_result = {
let edit_task = cx.update(|cx| {
let input = EditFileToolInput {
let input = serde_json::to_value(EditFileToolInput {
display_description: "Update main function".into(),
path: "root/src/main.rs".into(),
mode: EditFileMode::Overwrite,
};
})
.unwrap();
Arc::new(EditFileTool)
.run(
input,
@@ -1590,11 +1600,12 @@ mod tests {
// Have the model stream content that contains trailing whitespace
let edit_result = {
let edit_task = cx.update(|cx| {
let input = EditFileToolInput {
let input = serde_json::to_value(EditFileToolInput {
display_description: "Create main function".into(),
path: "root/src/main.rs".into(),
mode: EditFileMode::Overwrite,
};
})
.unwrap();
Arc::new(EditFileTool)
.run(
input,
@@ -1646,11 +1657,12 @@ mod tests {
// Stream edits again with trailing whitespace
let edit_result = {
let edit_task = cx.update(|cx| {
let input = EditFileToolInput {
let input = serde_json::to_value(EditFileToolInput {
display_description: "Update main function".into(),
path: "root/src/main.rs".into(),
mode: EditFileMode::Overwrite,
};
})
.unwrap();
Arc::new(EditFileTool)
.run(
input,

View File

@@ -3,10 +3,10 @@ use std::sync::Arc;
use std::{borrow::Cow, cell::RefCell};
use crate::schema::json_schema_for;
use anyhow::{Context as _, Result, bail};
use anyhow::{Context as _, Result, anyhow, bail};
use assistant_tool::{ActionLog, Tool, ToolResult};
use futures::AsyncReadExt as _;
use gpui::{AnyWindowHandle, App, AppContext as _, Entity};
use gpui::{AnyWindowHandle, App, AppContext as _, Entity, Task};
use html_to_markdown::{TagHandler, convert_html_to_markdown, markdown};
use http_client::{AsyncBody, HttpClientWithUrl};
use language_model::{LanguageModel, LanguageModelRequest, LanguageModelToolSchemaFormat};
@@ -113,13 +113,11 @@ impl FetchTool {
}
impl Tool for FetchTool {
type Input = FetchToolInput;
fn name(&self) -> String {
"fetch".to_string()
}
fn needs_confirmation(&self, _: &Self::Input, _: &App) -> bool {
fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool {
false
}
@@ -139,13 +137,16 @@ impl Tool for FetchTool {
json_schema_for::<FetchToolInput>(format)
}
fn ui_text(&self, input: &Self::Input) -> String {
format!("Fetch {}", MarkdownEscaped(&input.url))
fn ui_text(&self, input: &serde_json::Value) -> String {
match serde_json::from_value::<FetchToolInput>(input.clone()) {
Ok(input) => format!("Fetch {}", MarkdownEscaped(&input.url)),
Err(_) => "Fetch URL".to_string(),
}
}
fn run(
self: Arc<Self>,
input: Self::Input,
input: serde_json::Value,
_request: Arc<LanguageModelRequest>,
_project: Entity<Project>,
_action_log: Entity<ActionLog>,
@@ -153,6 +154,11 @@ impl Tool for FetchTool {
_window: Option<AnyWindowHandle>,
cx: &mut App,
) -> ToolResult {
let input = match serde_json::from_value::<FetchToolInput>(input) {
Ok(input) => input,
Err(err) => return Task::ready(Err(anyhow!(err))).into(),
};
let text = cx.background_spawn({
let http_client = self.http_client.clone();
async move { Self::build_message(http_client, &input.url).await }

View File

@@ -51,13 +51,11 @@ const RESULTS_PER_PAGE: usize = 50;
pub struct FindPathTool;
impl Tool for FindPathTool {
type Input = FindPathToolInput;
fn name(&self) -> String {
"find_path".into()
}
fn needs_confirmation(&self, _: &Self::Input, _: &App) -> bool {
fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool {
false
}
@@ -77,13 +75,16 @@ impl Tool for FindPathTool {
json_schema_for::<FindPathToolInput>(format)
}
fn ui_text(&self, input: &Self::Input) -> String {
format!("Find paths matching \"`{}`\"", input.glob)
fn ui_text(&self, input: &serde_json::Value) -> String {
match serde_json::from_value::<FindPathToolInput>(input.clone()) {
Ok(input) => format!("Find paths matching “`{}`”", input.glob),
Err(_) => "Search paths".to_string(),
}
}
fn run(
self: Arc<Self>,
input: Self::Input,
input: serde_json::Value,
_request: Arc<LanguageModelRequest>,
project: Entity<Project>,
_action_log: Entity<ActionLog>,
@@ -91,7 +92,10 @@ impl Tool for FindPathTool {
_window: Option<AnyWindowHandle>,
cx: &mut App,
) -> ToolResult {
let (offset, glob) = (input.offset, input.glob);
let (offset, glob) = match serde_json::from_value::<FindPathToolInput>(input) {
Ok(input) => (input.offset, input.glob),
Err(err) => return Task::ready(Err(anyhow!(err))).into(),
};
let (sender, receiver) = oneshot::channel();

View File

@@ -53,13 +53,11 @@ const RESULTS_PER_PAGE: u32 = 20;
pub struct GrepTool;
impl Tool for GrepTool {
type Input = GrepToolInput;
fn name(&self) -> String {
"grep".into()
}
fn needs_confirmation(&self, _: &Self::Input, _: &App) -> bool {
fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool {
false
}
@@ -79,25 +77,30 @@ impl Tool for GrepTool {
json_schema_for::<GrepToolInput>(format)
}
fn ui_text(&self, input: &Self::Input) -> String {
let page = input.page();
let regex_str = MarkdownInlineCode(&input.regex);
let case_info = if input.case_sensitive {
" (case-sensitive)"
} else {
""
};
fn ui_text(&self, input: &serde_json::Value) -> String {
match serde_json::from_value::<GrepToolInput>(input.clone()) {
Ok(input) => {
let page = input.page();
let regex_str = MarkdownInlineCode(&input.regex);
let case_info = if input.case_sensitive {
" (case-sensitive)"
} else {
""
};
if page > 1 {
format!("Get page {page} of search results for regex {regex_str}{case_info}")
} else {
format!("Search files for regex {regex_str}{case_info}")
if page > 1 {
format!("Get page {page} of search results for regex {regex_str}{case_info}")
} else {
format!("Search files for regex {regex_str}{case_info}")
}
}
Err(_) => "Search with regex".to_string(),
}
}
fn run(
self: Arc<Self>,
input: Self::Input,
input: serde_json::Value,
_request: Arc<LanguageModelRequest>,
project: Entity<Project>,
_action_log: Entity<ActionLog>,
@@ -108,6 +111,13 @@ impl Tool for GrepTool {
const CONTEXT_LINES: u32 = 2;
const MAX_ANCESTOR_LINES: u32 = 10;
let input = match serde_json::from_value::<GrepToolInput>(input) {
Ok(input) => input,
Err(error) => {
return Task::ready(Err(anyhow!("Failed to parse input: {error}"))).into();
}
};
let include_matcher = match PathMatcher::new(
input
.include_pattern
@@ -338,12 +348,13 @@ mod tests {
let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
// Test with include pattern for Rust files inside the root of the project
let input = GrepToolInput {
let input = serde_json::to_value(GrepToolInput {
regex: "println".to_string(),
include_pattern: Some("root/**/*.rs".to_string()),
offset: 0,
case_sensitive: false,
};
})
.unwrap();
let result = run_grep_tool(input, project.clone(), cx).await;
assert!(result.contains("main.rs"), "Should find matches in main.rs");
@@ -357,12 +368,13 @@ mod tests {
);
// Test with include pattern for src directory only
let input = GrepToolInput {
let input = serde_json::to_value(GrepToolInput {
regex: "fn".to_string(),
include_pattern: Some("root/**/src/**".to_string()),
offset: 0,
case_sensitive: false,
};
})
.unwrap();
let result = run_grep_tool(input, project.clone(), cx).await;
assert!(
@@ -379,12 +391,13 @@ mod tests {
);
// Test with empty include pattern (should default to all files)
let input = GrepToolInput {
let input = serde_json::to_value(GrepToolInput {
regex: "fn".to_string(),
include_pattern: None,
offset: 0,
case_sensitive: false,
};
})
.unwrap();
let result = run_grep_tool(input, project.clone(), cx).await;
assert!(result.contains("main.rs"), "Should find matches in main.rs");
@@ -415,12 +428,13 @@ mod tests {
let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
// Test case-insensitive search (default)
let input = GrepToolInput {
let input = serde_json::to_value(GrepToolInput {
regex: "uppercase".to_string(),
include_pattern: Some("**/*.txt".to_string()),
offset: 0,
case_sensitive: false,
};
})
.unwrap();
let result = run_grep_tool(input, project.clone(), cx).await;
assert!(
@@ -429,12 +443,13 @@ mod tests {
);
// Test case-sensitive search
let input = GrepToolInput {
let input = serde_json::to_value(GrepToolInput {
regex: "uppercase".to_string(),
include_pattern: Some("**/*.txt".to_string()),
offset: 0,
case_sensitive: true,
};
})
.unwrap();
let result = run_grep_tool(input, project.clone(), cx).await;
assert!(
@@ -443,12 +458,13 @@ mod tests {
);
// Test case-sensitive search
let input = GrepToolInput {
let input = serde_json::to_value(GrepToolInput {
regex: "LOWERCASE".to_string(),
include_pattern: Some("**/*.txt".to_string()),
offset: 0,
case_sensitive: true,
};
})
.unwrap();
let result = run_grep_tool(input, project.clone(), cx).await;
@@ -458,12 +474,13 @@ mod tests {
);
// Test case-sensitive search for lowercase pattern
let input = GrepToolInput {
let input = serde_json::to_value(GrepToolInput {
regex: "lowercase".to_string(),
include_pattern: Some("**/*.txt".to_string()),
offset: 0,
case_sensitive: true,
};
})
.unwrap();
let result = run_grep_tool(input, project.clone(), cx).await;
assert!(
@@ -559,12 +576,13 @@ mod tests {
let project = setup_syntax_test(cx).await;
// Test: Line at the top level of the file
let input = GrepToolInput {
let input = serde_json::to_value(GrepToolInput {
regex: "This is at the top level".to_string(),
include_pattern: Some("**/*.rs".to_string()),
offset: 0,
case_sensitive: false,
};
})
.unwrap();
let result = run_grep_tool(input, project.clone(), cx).await;
let expected = r#"
@@ -588,12 +606,13 @@ mod tests {
let project = setup_syntax_test(cx).await;
// Test: Line inside a function body
let input = GrepToolInput {
let input = serde_json::to_value(GrepToolInput {
regex: "Function in nested module".to_string(),
include_pattern: Some("**/*.rs".to_string()),
offset: 0,
case_sensitive: false,
};
})
.unwrap();
let result = run_grep_tool(input, project.clone(), cx).await;
let expected = r#"
@@ -619,12 +638,13 @@ mod tests {
let project = setup_syntax_test(cx).await;
// Test: Line with a function argument
let input = GrepToolInput {
let input = serde_json::to_value(GrepToolInput {
regex: "second_arg".to_string(),
include_pattern: Some("**/*.rs".to_string()),
offset: 0,
case_sensitive: false,
};
})
.unwrap();
let result = run_grep_tool(input, project.clone(), cx).await;
let expected = r#"
@@ -654,12 +674,13 @@ mod tests {
let project = setup_syntax_test(cx).await;
// Test: Line inside an if block
let input = GrepToolInput {
let input = serde_json::to_value(GrepToolInput {
regex: "Inside if block".to_string(),
include_pattern: Some("**/*.rs".to_string()),
offset: 0,
case_sensitive: false,
};
})
.unwrap();
let result = run_grep_tool(input, project.clone(), cx).await;
let expected = r#"
@@ -684,12 +705,13 @@ mod tests {
let project = setup_syntax_test(cx).await;
// Test: Line in the middle of a long function - should show message about remaining lines
let input = GrepToolInput {
let input = serde_json::to_value(GrepToolInput {
regex: "Line 5".to_string(),
include_pattern: Some("**/*.rs".to_string()),
offset: 0,
case_sensitive: false,
};
})
.unwrap();
let result = run_grep_tool(input, project.clone(), cx).await;
let expected = r#"
@@ -724,12 +746,13 @@ mod tests {
let project = setup_syntax_test(cx).await;
// Test: Line in the long function
let input = GrepToolInput {
let input = serde_json::to_value(GrepToolInput {
regex: "Line 12".to_string(),
include_pattern: Some("**/*.rs".to_string()),
offset: 0,
case_sensitive: false,
};
})
.unwrap();
let result = run_grep_tool(input, project.clone(), cx).await;
let expected = r#"
@@ -751,7 +774,7 @@ mod tests {
}
async fn run_grep_tool(
input: GrepToolInput,
input: serde_json::Value,
project: Entity<Project>,
cx: &mut TestAppContext,
) -> String {
@@ -853,12 +876,9 @@ mod tests {
// Searching for files outside the project worktree should return no results
let result = cx
.update(|cx| {
let input = GrepToolInput {
regex: "outside_function".to_string(),
include_pattern: None,
offset: 0,
case_sensitive: false,
};
let input = json!({
"regex": "outside_function"
});
Arc::new(GrepTool)
.run(
input,
@@ -882,12 +902,9 @@ mod tests {
// Searching within the project should succeed
let result = cx
.update(|cx| {
let input = GrepToolInput {
regex: "main".to_string(),
include_pattern: None,
offset: 0,
case_sensitive: false,
};
let input = json!({
"regex": "main"
});
Arc::new(GrepTool)
.run(
input,
@@ -911,12 +928,9 @@ mod tests {
// Searching files that match file_scan_exclusions should return no results
let result = cx
.update(|cx| {
let input = GrepToolInput {
regex: "special_configuration".to_string(),
include_pattern: None,
offset: 0,
case_sensitive: false,
};
let input = json!({
"regex": "special_configuration"
});
Arc::new(GrepTool)
.run(
input,
@@ -939,12 +953,9 @@ mod tests {
let result = cx
.update(|cx| {
let input = GrepToolInput {
regex: "custom_metadata".to_string(),
include_pattern: None,
offset: 0,
case_sensitive: false,
};
let input = json!({
"regex": "custom_metadata"
});
Arc::new(GrepTool)
.run(
input,
@@ -968,12 +979,9 @@ mod tests {
// Searching private files should return no results
let result = cx
.update(|cx| {
let input = GrepToolInput {
regex: "SECRET_KEY".to_string(),
include_pattern: None,
offset: 0,
case_sensitive: false,
};
let input = json!({
"regex": "SECRET_KEY"
});
Arc::new(GrepTool)
.run(
input,
@@ -996,12 +1004,9 @@ mod tests {
let result = cx
.update(|cx| {
let input = GrepToolInput {
regex: "private_key_content".to_string(),
include_pattern: None,
offset: 0,
case_sensitive: false,
};
let input = json!({
"regex": "private_key_content"
});
Arc::new(GrepTool)
.run(
input,
@@ -1024,12 +1029,9 @@ mod tests {
let result = cx
.update(|cx| {
let input = GrepToolInput {
regex: "sensitive_data".to_string(),
include_pattern: None,
offset: 0,
case_sensitive: false,
};
let input = json!({
"regex": "sensitive_data"
});
Arc::new(GrepTool)
.run(
input,
@@ -1053,12 +1055,9 @@ mod tests {
// Searching a normal file should still work, even with private_files configured
let result = cx
.update(|cx| {
let input = GrepToolInput {
regex: "normal_file_content".to_string(),
include_pattern: None,
offset: 0,
case_sensitive: false,
};
let input = json!({
"regex": "normal_file_content"
});
Arc::new(GrepTool)
.run(
input,
@@ -1082,12 +1081,10 @@ mod tests {
// Path traversal attempts with .. in include_pattern should not escape project
let result = cx
.update(|cx| {
let input = GrepToolInput {
regex: "outside_function".to_string(),
include_pattern: Some("../outside_project/**/*.rs".to_string()),
offset: 0,
case_sensitive: false,
};
let input = json!({
"regex": "outside_function",
"include_pattern": "../outside_project/**/*.rs"
});
Arc::new(GrepTool)
.run(
input,
@@ -1188,12 +1185,10 @@ mod tests {
// Search for "secret" - should exclude files based on worktree-specific settings
let result = cx
.update(|cx| {
let input = GrepToolInput {
regex: "secret".to_string(),
include_pattern: None,
offset: 0,
case_sensitive: false,
};
let input = json!({
"regex": "secret",
"case_sensitive": false
});
Arc::new(GrepTool)
.run(
input,
@@ -1255,12 +1250,10 @@ mod tests {
// Test with `include_pattern` specific to one worktree
let result = cx
.update(|cx| {
let input = GrepToolInput {
regex: "secret".to_string(),
include_pattern: Some("worktree1/**/*.rs".to_string()),
offset: 0,
case_sensitive: false,
};
let input = json!({
"regex": "secret",
"include_pattern": "worktree1/**/*.rs"
});
Arc::new(GrepTool)
.run(
input,

View File

@@ -41,13 +41,11 @@ pub struct ListDirectoryToolInput {
pub struct ListDirectoryTool;
impl Tool for ListDirectoryTool {
type Input = ListDirectoryToolInput;
fn name(&self) -> String {
"list_directory".into()
}
fn needs_confirmation(&self, _: &Self::Input, _: &App) -> bool {
fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool {
false
}
@@ -67,14 +65,19 @@ impl Tool for ListDirectoryTool {
json_schema_for::<ListDirectoryToolInput>(format)
}
fn ui_text(&self, input: &Self::Input) -> String {
let path = MarkdownInlineCode(&input.path);
format!("List the {path} directory's contents")
fn ui_text(&self, input: &serde_json::Value) -> String {
match serde_json::from_value::<ListDirectoryToolInput>(input.clone()) {
Ok(input) => {
let path = MarkdownInlineCode(&input.path);
format!("List the {path} directory's contents")
}
Err(_) => "List directory".to_string(),
}
}
fn run(
self: Arc<Self>,
input: Self::Input,
input: serde_json::Value,
_request: Arc<LanguageModelRequest>,
project: Entity<Project>,
_action_log: Entity<ActionLog>,
@@ -82,6 +85,11 @@ impl Tool for ListDirectoryTool {
_window: Option<AnyWindowHandle>,
cx: &mut App,
) -> ToolResult {
let input = match serde_json::from_value::<ListDirectoryToolInput>(input) {
Ok(input) => input,
Err(err) => return Task::ready(Err(anyhow!(err))).into(),
};
// Sometimes models will return these even though we tell it to give a path and not a glob.
// When this happens, just list the root worktree directories.
if matches!(input.path.as_str(), "." | "" | "./" | "*") {
@@ -277,9 +285,9 @@ mod tests {
let tool = Arc::new(ListDirectoryTool);
// Test listing root directory
let input = ListDirectoryToolInput {
path: "project".to_string(),
};
let input = json!({
"path": "project"
});
let result = cx
.update(|cx| {
@@ -312,9 +320,9 @@ mod tests {
);
// Test listing src directory
let input = ListDirectoryToolInput {
path: "project/src".to_string(),
};
let input = json!({
"path": "project/src"
});
let result = cx
.update(|cx| {
@@ -347,9 +355,9 @@ mod tests {
);
// Test listing directory with only files
let input = ListDirectoryToolInput {
path: "project/tests".to_string(),
};
let input = json!({
"path": "project/tests"
});
let result = cx
.update(|cx| {
@@ -391,9 +399,9 @@ mod tests {
let model = Arc::new(FakeLanguageModel::default());
let tool = Arc::new(ListDirectoryTool);
let input = ListDirectoryToolInput {
path: "project/empty_dir".to_string(),
};
let input = json!({
"path": "project/empty_dir"
});
let result = cx
.update(|cx| tool.run(input, Arc::default(), project, action_log, model, None, cx))
@@ -424,9 +432,9 @@ mod tests {
let tool = Arc::new(ListDirectoryTool);
// Test non-existent path
let input = ListDirectoryToolInput {
path: "project/nonexistent".to_string(),
};
let input = json!({
"path": "project/nonexistent"
});
let result = cx
.update(|cx| {
@@ -447,9 +455,9 @@ mod tests {
assert!(result.unwrap_err().to_string().contains("Path not found"));
// Test trying to list a file instead of directory
let input = ListDirectoryToolInput {
path: "project/file.txt".to_string(),
};
let input = json!({
"path": "project/file.txt"
});
let result = cx
.update(|cx| tool.run(input, Arc::default(), project, action_log, model, None, cx))
@@ -519,9 +527,9 @@ mod tests {
let tool = Arc::new(ListDirectoryTool);
// Listing root directory should exclude private and excluded files
let input = ListDirectoryToolInput {
path: "project".to_string(),
};
let input = json!({
"path": "project"
});
let result = cx
.update(|cx| {
@@ -560,9 +568,9 @@ mod tests {
);
// Trying to list an excluded directory should fail
let input = ListDirectoryToolInput {
path: "project/.secretdir".to_string(),
};
let input = json!({
"path": "project/.secretdir"
});
let result = cx
.update(|cx| {
@@ -592,9 +600,9 @@ mod tests {
);
// Listing a directory should exclude private files within it
let input = ListDirectoryToolInput {
path: "project/visible_dir".to_string(),
};
let input = json!({
"path": "project/visible_dir"
});
let result = cx
.update(|cx| {
@@ -712,9 +720,9 @@ mod tests {
let tool = Arc::new(ListDirectoryTool);
// Test listing worktree1/src - should exclude secret.rs and config.toml based on local settings
let input = ListDirectoryToolInput {
path: "worktree1/src".to_string(),
};
let input = json!({
"path": "worktree1/src"
});
let result = cx
.update(|cx| {
@@ -744,9 +752,9 @@ mod tests {
);
// Test listing worktree1/tests - should exclude fixture.sql based on local settings
let input = ListDirectoryToolInput {
path: "worktree1/tests".to_string(),
};
let input = json!({
"path": "worktree1/tests"
});
let result = cx
.update(|cx| {
@@ -772,9 +780,9 @@ mod tests {
);
// Test listing worktree2/lib - should exclude private.js and data.json based on local settings
let input = ListDirectoryToolInput {
path: "worktree2/lib".to_string(),
};
let input = json!({
"path": "worktree2/lib"
});
let result = cx
.update(|cx| {
@@ -804,9 +812,9 @@ mod tests {
);
// Test listing worktree2/docs - should exclude internal.md based on local settings
let input = ListDirectoryToolInput {
path: "worktree2/docs".to_string(),
};
let input = json!({
"path": "worktree2/docs"
});
let result = cx
.update(|cx| {
@@ -832,9 +840,9 @@ mod tests {
);
// Test trying to list an excluded directory directly
let input = ListDirectoryToolInput {
path: "worktree1/src/secret.rs".to_string(),
};
let input = json!({
"path": "worktree1/src/secret.rs"
});
let result = cx
.update(|cx| {

View File

@@ -38,13 +38,11 @@ pub struct MovePathToolInput {
pub struct MovePathTool;
impl Tool for MovePathTool {
type Input = MovePathToolInput;
fn name(&self) -> String {
"move_path".into()
}
fn needs_confirmation(&self, _: &Self::Input, _: &App) -> bool {
fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool {
false
}
@@ -64,29 +62,34 @@ impl Tool for MovePathTool {
json_schema_for::<MovePathToolInput>(format)
}
fn ui_text(&self, input: &Self::Input) -> String {
let src = MarkdownInlineCode(&input.source_path);
let dest = MarkdownInlineCode(&input.destination_path);
let src_path = Path::new(&input.source_path);
let dest_path = Path::new(&input.destination_path);
fn ui_text(&self, input: &serde_json::Value) -> String {
match serde_json::from_value::<MovePathToolInput>(input.clone()) {
Ok(input) => {
let src = MarkdownInlineCode(&input.source_path);
let dest = MarkdownInlineCode(&input.destination_path);
let src_path = Path::new(&input.source_path);
let dest_path = Path::new(&input.destination_path);
match dest_path
.file_name()
.and_then(|os_str| os_str.to_os_string().into_string().ok())
{
Some(filename) if src_path.parent() == dest_path.parent() => {
let filename = MarkdownInlineCode(&filename);
format!("Rename {src} to {filename}")
}
_ => {
format!("Move {src} to {dest}")
match dest_path
.file_name()
.and_then(|os_str| os_str.to_os_string().into_string().ok())
{
Some(filename) if src_path.parent() == dest_path.parent() => {
let filename = MarkdownInlineCode(&filename);
format!("Rename {src} to {filename}")
}
_ => {
format!("Move {src} to {dest}")
}
}
}
Err(_) => "Move path".to_string(),
}
}
fn run(
self: Arc<Self>,
input: Self::Input,
input: serde_json::Value,
_request: Arc<LanguageModelRequest>,
project: Entity<Project>,
_action_log: Entity<ActionLog>,
@@ -94,6 +97,10 @@ impl Tool for MovePathTool {
_window: Option<AnyWindowHandle>,
cx: &mut App,
) -> ToolResult {
let input = match serde_json::from_value::<MovePathToolInput>(input) {
Ok(input) => input,
Err(err) => return Task::ready(Err(anyhow!(err))).into(),
};
let rename_task = project.update(cx, |project, cx| {
match project
.find_project_path(&input.source_path, cx)

View File

@@ -1,7 +1,7 @@
use std::sync::Arc;
use crate::schema::json_schema_for;
use anyhow::Result;
use anyhow::{Result, anyhow};
use assistant_tool::{ActionLog, Tool, ToolResult};
use chrono::{Local, Utc};
use gpui::{AnyWindowHandle, App, Entity, Task};
@@ -29,13 +29,11 @@ pub struct NowToolInput {
pub struct NowTool;
impl Tool for NowTool {
type Input = NowToolInput;
fn name(&self) -> String {
"now".into()
}
fn needs_confirmation(&self, _: &Self::Input, _: &App) -> bool {
fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool {
false
}
@@ -55,13 +53,13 @@ impl Tool for NowTool {
json_schema_for::<NowToolInput>(format)
}
fn ui_text(&self, _input: &Self::Input) -> String {
fn ui_text(&self, _input: &serde_json::Value) -> String {
"Get current time".to_string()
}
fn run(
self: Arc<Self>,
input: Self::Input,
input: serde_json::Value,
_request: Arc<LanguageModelRequest>,
_project: Entity<Project>,
_action_log: Entity<ActionLog>,
@@ -69,6 +67,11 @@ impl Tool for NowTool {
_window: Option<AnyWindowHandle>,
_cx: &mut App,
) -> ToolResult {
let input: NowToolInput = match serde_json::from_value(input) {
Ok(input) => input,
Err(err) => return Task::ready(Err(anyhow!(err))).into(),
};
let now = match input.timezone {
Timezone::Utc => Utc::now().to_rfc3339(),
Timezone::Local => Local::now().to_rfc3339(),

View File

@@ -1,7 +1,7 @@
use crate::schema::json_schema_for;
use anyhow::{Context as _, Result};
use anyhow::{Context as _, Result, anyhow};
use assistant_tool::{ActionLog, Tool, ToolResult};
use gpui::{AnyWindowHandle, App, AppContext, Entity};
use gpui::{AnyWindowHandle, App, AppContext, Entity, Task};
use language_model::{LanguageModel, LanguageModelRequest, LanguageModelToolSchemaFormat};
use project::Project;
use schemars::JsonSchema;
@@ -19,13 +19,11 @@ pub struct OpenToolInput {
pub struct OpenTool;
impl Tool for OpenTool {
type Input = OpenToolInput;
fn name(&self) -> String {
"open".to_string()
}
fn needs_confirmation(&self, _: &Self::Input, _: &App) -> bool {
fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool {
true
}
fn may_perform_edits(&self) -> bool {
@@ -43,13 +41,16 @@ impl Tool for OpenTool {
json_schema_for::<OpenToolInput>(format)
}
fn ui_text(&self, input: &Self::Input) -> String {
format!("Open `{}`", MarkdownEscaped(&input.path_or_url))
fn ui_text(&self, input: &serde_json::Value) -> String {
match serde_json::from_value::<OpenToolInput>(input.clone()) {
Ok(input) => format!("Open `{}`", MarkdownEscaped(&input.path_or_url)),
Err(_) => "Open file or URL".to_string(),
}
}
fn run(
self: Arc<Self>,
input: Self::Input,
input: serde_json::Value,
_request: Arc<LanguageModelRequest>,
project: Entity<Project>,
_action_log: Entity<ActionLog>,
@@ -57,6 +58,11 @@ impl Tool for OpenTool {
_window: Option<AnyWindowHandle>,
cx: &mut App,
) -> ToolResult {
let input: OpenToolInput = match serde_json::from_value(input) {
Ok(input) => input,
Err(err) => return Task::ready(Err(anyhow!(err))).into(),
};
// If path_or_url turns out to be a path in the project, make it absolute.
let abs_path = to_absolute_path(&input.path_or_url, project, cx);

View File

@@ -51,13 +51,11 @@ pub struct ReadFileToolInput {
pub struct ReadFileTool;
impl Tool for ReadFileTool {
type Input = ReadFileToolInput;
fn name(&self) -> String {
"read_file".into()
}
fn needs_confirmation(&self, _: &Self::Input, _: &App) -> bool {
fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool {
false
}
@@ -77,18 +75,23 @@ impl Tool for ReadFileTool {
json_schema_for::<ReadFileToolInput>(format)
}
fn ui_text(&self, input: &Self::Input) -> String {
let path = MarkdownInlineCode(&input.path);
match (input.start_line, input.end_line) {
(Some(start), None) => format!("Read file {path} (from line {start})"),
(Some(start), Some(end)) => format!("Read file {path} (lines {start}-{end})"),
_ => format!("Read file {path}"),
fn ui_text(&self, input: &serde_json::Value) -> String {
match serde_json::from_value::<ReadFileToolInput>(input.clone()) {
Ok(input) => {
let path = MarkdownInlineCode(&input.path);
match (input.start_line, input.end_line) {
(Some(start), None) => format!("Read file {path} (from line {start})"),
(Some(start), Some(end)) => format!("Read file {path} (lines {start}-{end})"),
_ => format!("Read file {path}"),
}
}
Err(_) => "Read file".to_string(),
}
}
fn run(
self: Arc<Self>,
input: Self::Input,
input: serde_json::Value,
_request: Arc<LanguageModelRequest>,
project: Entity<Project>,
action_log: Entity<ActionLog>,
@@ -96,6 +99,11 @@ impl Tool for ReadFileTool {
_window: Option<AnyWindowHandle>,
cx: &mut App,
) -> ToolResult {
let input = match serde_json::from_value::<ReadFileToolInput>(input) {
Ok(input) => input,
Err(err) => return Task::ready(Err(anyhow!(err))).into(),
};
let Some(project_path) = project.read(cx).find_project_path(&input.path, cx) else {
return Task::ready(Err(anyhow!("Path {} not found in project", &input.path))).into();
};
@@ -300,12 +308,9 @@ mod test {
let model = Arc::new(FakeLanguageModel::default());
let result = cx
.update(|cx| {
let input = ReadFileToolInput {
path: "root/nonexistent_file.txt".to_string(),
start_line: None,
end_line: None,
};
let input = json!({
"path": "root/nonexistent_file.txt"
});
Arc::new(ReadFileTool)
.run(
input,
@@ -342,11 +347,9 @@ mod test {
let model = Arc::new(FakeLanguageModel::default());
let result = cx
.update(|cx| {
let input = ReadFileToolInput {
path: "root/small_file.txt".to_string(),
start_line: None,
end_line: None,
};
let input = json!({
"path": "root/small_file.txt"
});
Arc::new(ReadFileTool)
.run(
input,
@@ -386,11 +389,9 @@ mod test {
let result = cx
.update(|cx| {
let input = ReadFileToolInput {
path: "root/large_file.rs".to_string(),
start_line: None,
end_line: None,
};
let input = json!({
"path": "root/large_file.rs"
});
Arc::new(ReadFileTool)
.run(
input,
@@ -420,11 +421,10 @@ mod test {
let result = cx
.update(|cx| {
let input = ReadFileToolInput {
path: "root/large_file.rs".to_string(),
start_line: None,
end_line: None,
};
let input = json!({
"path": "root/large_file.rs",
"offset": 1
});
Arc::new(ReadFileTool)
.run(
input,
@@ -477,11 +477,11 @@ mod test {
let model = Arc::new(FakeLanguageModel::default());
let result = cx
.update(|cx| {
let input = ReadFileToolInput {
path: "root/multiline.txt".to_string(),
start_line: Some(2),
end_line: Some(4),
};
let input = json!({
"path": "root/multiline.txt",
"start_line": 2,
"end_line": 4
});
Arc::new(ReadFileTool)
.run(
input,
@@ -520,11 +520,11 @@ mod test {
// start_line of 0 should be treated as 1
let result = cx
.update(|cx| {
let input = ReadFileToolInput {
path: "root/multiline.txt".to_string(),
start_line: Some(0),
end_line: Some(2),
};
let input = json!({
"path": "root/multiline.txt",
"start_line": 0,
"end_line": 2
});
Arc::new(ReadFileTool)
.run(
input,
@@ -543,11 +543,11 @@ mod test {
// end_line of 0 should result in at least 1 line
let result = cx
.update(|cx| {
let input = ReadFileToolInput {
path: "root/multiline.txt".to_string(),
start_line: Some(1),
end_line: Some(0),
};
let input = json!({
"path": "root/multiline.txt",
"start_line": 1,
"end_line": 0
});
Arc::new(ReadFileTool)
.run(
input,
@@ -566,11 +566,11 @@ mod test {
// when start_line > end_line, should still return at least 1 line
let result = cx
.update(|cx| {
let input = ReadFileToolInput {
path: "root/multiline.txt".to_string(),
start_line: Some(3),
end_line: Some(2),
};
let input = json!({
"path": "root/multiline.txt",
"start_line": 3,
"end_line": 2
});
Arc::new(ReadFileTool)
.run(
input,
@@ -694,11 +694,9 @@ mod test {
// Reading a file outside the project worktree should fail
let result = cx
.update(|cx| {
let input = ReadFileToolInput {
path: "/outside_project/sensitive_file.txt".to_string(),
start_line: None,
end_line: None,
};
let input = json!({
"path": "/outside_project/sensitive_file.txt"
});
Arc::new(ReadFileTool)
.run(
input,
@@ -720,11 +718,9 @@ mod test {
// Reading a file within the project should succeed
let result = cx
.update(|cx| {
let input = ReadFileToolInput {
path: "project_root/allowed_file.txt".to_string(),
start_line: None,
end_line: None,
};
let input = json!({
"path": "project_root/allowed_file.txt"
});
Arc::new(ReadFileTool)
.run(
input,
@@ -746,11 +742,9 @@ mod test {
// Reading files that match file_scan_exclusions should fail
let result = cx
.update(|cx| {
let input = ReadFileToolInput {
path: "project_root/.secretdir/config".to_string(),
start_line: None,
end_line: None,
};
let input = json!({
"path": "project_root/.secretdir/config"
});
Arc::new(ReadFileTool)
.run(
input,
@@ -771,11 +765,9 @@ mod test {
let result = cx
.update(|cx| {
let input = ReadFileToolInput {
path: "project_root/.mymetadata".to_string(),
start_line: None,
end_line: None,
};
let input = json!({
"path": "project_root/.mymetadata"
});
Arc::new(ReadFileTool)
.run(
input,
@@ -797,11 +789,9 @@ mod test {
// Reading private files should fail
let result = cx
.update(|cx| {
let input = ReadFileToolInput {
path: "project_root/secrets/.mysecrets".to_string(),
start_line: None,
end_line: None,
};
let input = json!({
"path": "project_root/.mysecrets"
});
Arc::new(ReadFileTool)
.run(
input,
@@ -822,11 +812,9 @@ mod test {
let result = cx
.update(|cx| {
let input = ReadFileToolInput {
path: "project_root/subdir/special.privatekey".to_string(),
start_line: None,
end_line: None,
};
let input = json!({
"path": "project_root/subdir/special.privatekey"
});
Arc::new(ReadFileTool)
.run(
input,
@@ -847,11 +835,9 @@ mod test {
let result = cx
.update(|cx| {
let input = ReadFileToolInput {
path: "project_root/subdir/data.mysensitive".to_string(),
start_line: None,
end_line: None,
};
let input = json!({
"path": "project_root/subdir/data.mysensitive"
});
Arc::new(ReadFileTool)
.run(
input,
@@ -873,11 +859,9 @@ mod test {
// Reading a normal file should still work, even with private_files configured
let result = cx
.update(|cx| {
let input = ReadFileToolInput {
path: "project_root/subdir/normal_file.txt".to_string(),
start_line: None,
end_line: None,
};
let input = json!({
"path": "project_root/subdir/normal_file.txt"
});
Arc::new(ReadFileTool)
.run(
input,
@@ -900,11 +884,9 @@ mod test {
// Path traversal attempts with .. should fail
let result = cx
.update(|cx| {
let input = ReadFileToolInput {
path: "project_root/../outside_project/sensitive_file.txt".to_string(),
start_line: None,
end_line: None,
};
let input = json!({
"path": "project_root/../outside_project/sensitive_file.txt"
});
Arc::new(ReadFileTool)
.run(
input,
@@ -999,11 +981,9 @@ mod test {
let tool = Arc::new(ReadFileTool);
// Test reading allowed files in worktree1
let input = ReadFileToolInput {
path: "worktree1/src/main.rs".to_string(),
start_line: None,
end_line: None,
};
let input = json!({
"path": "worktree1/src/main.rs"
});
let result = cx
.update(|cx| {
@@ -1027,11 +1007,9 @@ mod test {
);
// Test reading private file in worktree1 should fail
let input = ReadFileToolInput {
path: "worktree1/src/secret.rs".to_string(),
start_line: None,
end_line: None,
};
let input = json!({
"path": "worktree1/src/secret.rs"
});
let result = cx
.update(|cx| {
@@ -1058,11 +1036,9 @@ mod test {
);
// Test reading excluded file in worktree1 should fail
let input = ReadFileToolInput {
path: "worktree1/tests/fixture.sql".to_string(),
start_line: None,
end_line: None,
};
let input = json!({
"path": "worktree1/tests/fixture.sql"
});
let result = cx
.update(|cx| {
@@ -1089,11 +1065,9 @@ mod test {
);
// Test reading allowed files in worktree2
let input = ReadFileToolInput {
path: "worktree2/lib/public.js".to_string(),
start_line: None,
end_line: None,
};
let input = json!({
"path": "worktree2/lib/public.js"
});
let result = cx
.update(|cx| {
@@ -1117,11 +1091,9 @@ mod test {
);
// Test reading private file in worktree2 should fail
let input = ReadFileToolInput {
path: "worktree2/lib/private.js".to_string(),
start_line: None,
end_line: None,
};
let input = json!({
"path": "worktree2/lib/private.js"
});
let result = cx
.update(|cx| {
@@ -1148,11 +1120,9 @@ mod test {
);
// Test reading excluded file in worktree2 should fail
let input = ReadFileToolInput {
path: "worktree2/docs/internal.md".to_string(),
start_line: None,
end_line: None,
};
let input = json!({
"path": "worktree2/docs/internal.md"
});
let result = cx
.update(|cx| {
@@ -1180,11 +1150,9 @@ mod test {
// Test that files allowed in one worktree but not in another are handled correctly
// (e.g., config.toml is private in worktree1 but doesn't exist in worktree2)
let input = ReadFileToolInput {
path: "worktree1/src/config.toml".to_string(),
start_line: None,
end_line: None,
};
let input = json!({
"path": "worktree1/src/config.toml"
});
let result = cx
.update(|cx| {

View File

@@ -25,7 +25,9 @@ fn schema_to_json(
fn root_schema_for<T: JsonSchema>(format: LanguageModelToolSchemaFormat) -> Schema {
let mut generator = match format {
LanguageModelToolSchemaFormat::JsonSchema => SchemaSettings::draft07().into_generator(),
LanguageModelToolSchemaFormat::JsonSchemaSubset => SchemaSettings::openapi3()
// TODO: Gemini docs mention using a subset of OpenAPI 3, so this may benefit from using
// `SchemaSettings::openapi3()`.
LanguageModelToolSchemaFormat::JsonSchemaSubset => SchemaSettings::draft07()
.with(|settings| {
settings.meta_schema = None;
settings.inline_subschemas = true;

View File

@@ -2,7 +2,7 @@ use crate::{
schema::json_schema_for,
ui::{COLLAPSED_LINES, ToolOutputPreview},
};
use anyhow::{Context as _, Result};
use anyhow::{Context as _, Result, anyhow};
use assistant_tool::{ActionLog, Tool, ToolCard, ToolResult, ToolUseStatus};
use futures::{FutureExt as _, future::Shared};
use gpui::{
@@ -72,13 +72,11 @@ impl TerminalTool {
}
impl Tool for TerminalTool {
type Input = TerminalToolInput;
fn name(&self) -> String {
Self::NAME.to_string()
}
fn needs_confirmation(&self, _: &Self::Input, _: &App) -> bool {
fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool {
true
}
@@ -98,24 +96,30 @@ impl Tool for TerminalTool {
json_schema_for::<TerminalToolInput>(format)
}
fn ui_text(&self, input: &Self::Input) -> String {
let mut lines = input.command.lines();
let first_line = lines.next().unwrap_or_default();
let remaining_line_count = lines.count();
match remaining_line_count {
0 => MarkdownInlineCode(&first_line).to_string(),
1 => MarkdownInlineCode(&format!(
"{} - {} more line",
first_line, remaining_line_count
))
.to_string(),
n => MarkdownInlineCode(&format!("{} - {} more lines", first_line, n)).to_string(),
fn ui_text(&self, input: &serde_json::Value) -> String {
match serde_json::from_value::<TerminalToolInput>(input.clone()) {
Ok(input) => {
let mut lines = input.command.lines();
let first_line = lines.next().unwrap_or_default();
let remaining_line_count = lines.count();
match remaining_line_count {
0 => MarkdownInlineCode(&first_line).to_string(),
1 => MarkdownInlineCode(&format!(
"{} - {} more line",
first_line, remaining_line_count
))
.to_string(),
n => MarkdownInlineCode(&format!("{} - {} more lines", first_line, n))
.to_string(),
}
}
Err(_) => "Run terminal command".to_string(),
}
}
fn run(
self: Arc<Self>,
input: Self::Input,
input: serde_json::Value,
_request: Arc<LanguageModelRequest>,
project: Entity<Project>,
_action_log: Entity<ActionLog>,
@@ -123,6 +127,11 @@ impl Tool for TerminalTool {
window: Option<AnyWindowHandle>,
cx: &mut App,
) -> ToolResult {
let input: TerminalToolInput = match serde_json::from_value(input) {
Ok(input) => input,
Err(err) => return Task::ready(Err(anyhow!(err))).into(),
};
let working_dir = match working_dir(&input, &project, cx) {
Ok(dir) => dir,
Err(err) => return Task::ready(Err(err)).into(),
@@ -747,7 +756,7 @@ mod tests {
let result = cx.update(|cx| {
TerminalTool::run(
Arc::new(TerminalTool::new(cx)),
input,
serde_json::to_value(input).unwrap(),
Arc::default(),
project.clone(),
action_log.clone(),
@@ -782,7 +791,7 @@ mod tests {
let check = |input, expected, cx: &mut App| {
let headless_result = TerminalTool::run(
Arc::new(TerminalTool::new(cx)),
input,
serde_json::to_value(input).unwrap(),
Arc::default(),
project.clone(),
action_log.clone(),

View File

@@ -1,7 +1,7 @@
use std::sync::Arc;
use crate::schema::json_schema_for;
use anyhow::Result;
use anyhow::{Result, anyhow};
use assistant_tool::{ActionLog, Tool, ToolResult};
use gpui::{AnyWindowHandle, App, Entity, Task};
use language_model::{LanguageModel, LanguageModelRequest, LanguageModelToolSchemaFormat};
@@ -20,13 +20,11 @@ pub struct ThinkingToolInput {
pub struct ThinkingTool;
impl Tool for ThinkingTool {
type Input = ThinkingToolInput;
fn name(&self) -> String {
"thinking".to_string()
}
fn needs_confirmation(&self, _: &Self::Input, _: &App) -> bool {
fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool {
false
}
@@ -46,13 +44,13 @@ impl Tool for ThinkingTool {
json_schema_for::<ThinkingToolInput>(format)
}
fn ui_text(&self, _input: &Self::Input) -> String {
fn ui_text(&self, _input: &serde_json::Value) -> String {
"Thinking".to_string()
}
fn run(
self: Arc<Self>,
_input: Self::Input,
input: serde_json::Value,
_request: Arc<LanguageModelRequest>,
_project: Entity<Project>,
_action_log: Entity<ActionLog>,
@@ -61,6 +59,10 @@ impl Tool for ThinkingTool {
_cx: &mut App,
) -> ToolResult {
// This tool just "thinks out loud" and doesn't perform any actions.
Task::ready(Ok("Finished thinking.".to_string().into())).into()
Task::ready(match serde_json::from_value::<ThinkingToolInput>(input) {
Ok(_input) => Ok("Finished thinking.".to_string().into()),
Err(err) => Err(anyhow!(err)),
})
.into()
}
}

View File

@@ -28,13 +28,11 @@ pub struct WebSearchToolInput {
pub struct WebSearchTool;
impl Tool for WebSearchTool {
type Input = WebSearchToolInput;
fn name(&self) -> String {
"web_search".into()
}
fn needs_confirmation(&self, _: &Self::Input, _: &App) -> bool {
fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool {
false
}
@@ -54,13 +52,13 @@ impl Tool for WebSearchTool {
json_schema_for::<WebSearchToolInput>(format)
}
fn ui_text(&self, _input: &Self::Input) -> String {
fn ui_text(&self, _input: &serde_json::Value) -> String {
"Searching the Web".to_string()
}
fn run(
self: Arc<Self>,
input: Self::Input,
input: serde_json::Value,
_request: Arc<LanguageModelRequest>,
_project: Entity<Project>,
_action_log: Entity<ActionLog>,
@@ -68,6 +66,10 @@ impl Tool for WebSearchTool {
_window: Option<AnyWindowHandle>,
cx: &mut App,
) -> ToolResult {
let input = match serde_json::from_value::<WebSearchToolInput>(input) {
Ok(input) => input,
Err(err) => return Task::ready(Err(anyhow!(err))).into(),
};
let Some(provider) = WebSearchRegistry::read_global(cx).active_provider() else {
return Task::ready(Err(anyhow!("Web search is not available."))).into();
};

View File

@@ -28,17 +28,7 @@ use workspace::Workspace;
const SHOULD_SHOW_UPDATE_NOTIFICATION_KEY: &str = "auto-updater-should-show-updated-notification";
const POLL_INTERVAL: Duration = Duration::from_secs(60 * 60);
actions!(
auto_update,
[
/// Checks for available updates.
Check,
/// Dismisses the update error message.
DismissErrorMessage,
/// Opens the release notes for the current version in a browser.
ViewReleaseNotes,
]
);
actions!(auto_update, [Check, DismissErrorMessage, ViewReleaseNotes,]);
#[derive(Serialize)]
struct UpdateRequestBody {

View File

@@ -12,13 +12,7 @@ use workspace::Workspace;
use workspace::notifications::simple_message_notification::MessageNotification;
use workspace::notifications::{NotificationId, show_app_notification};
actions!(
auto_update,
[
/// Opens the release notes for the current version in a new tab.
ViewReleaseNotesLocally
]
);
actions!(auto_update, [ViewReleaseNotesLocally]);
pub fn init(cx: &mut App) {
notify_if_app_was_updated(cx);

View File

@@ -29,7 +29,7 @@ client.workspace = true
collections.workspace = true
fs.workspace = true
futures.workspace = true
gpui = { workspace = true, features = ["screen-capture"] }
gpui.workspace = true
language.workspace = true
log.workspace = true
postage.workspace = true

View File

@@ -81,17 +81,7 @@ pub const INITIAL_RECONNECTION_DELAY: Duration = Duration::from_millis(500);
pub const MAX_RECONNECTION_DELAY: Duration = Duration::from_secs(10);
pub const CONNECTION_TIMEOUT: Duration = Duration::from_secs(20);
actions!(
client,
[
/// Signs in to Zed account.
SignIn,
/// Signs out of Zed account.
SignOut,
/// Reconnects to the collaboration server.
Reconnect
]
);
actions!(client, [SignIn, SignOut, Reconnect]);
#[derive(Clone, Default, Serialize, Deserialize, JsonSchema)]
pub struct ClientSettingsContent {

View File

@@ -35,7 +35,6 @@ dashmap.workspace = true
derive_more.workspace = true
envy = "0.4.2"
futures.workspace = true
gpui = { workspace = true, features = ["screen-capture"] }
hex.workspace = true
http_client.workspace = true
jsonwebtoken.workspace = true

View File

@@ -107,7 +107,7 @@ CREATE INDEX "index_worktree_entries_on_project_id" ON "worktree_entries" ("proj
CREATE INDEX "index_worktree_entries_on_project_id_and_worktree_id" ON "worktree_entries" ("project_id", "worktree_id");
CREATE TABLE "project_repositories" (
"project_id" INTEGER NOT NULL REFERENCES projects (id) ON DELETE CASCADE,
"project_id" INTEGER NOT NULL,
"abs_path" VARCHAR,
"id" INTEGER NOT NULL,
"entry_ids" VARCHAR,
@@ -124,7 +124,7 @@ CREATE TABLE "project_repositories" (
CREATE INDEX "index_project_repositories_on_project_id" ON "project_repositories" ("project_id");
CREATE TABLE "project_repository_statuses" (
"project_id" INTEGER NOT NULL REFERENCES projects (id) ON DELETE CASCADE,
"project_id" INTEGER NOT NULL,
"repository_id" INTEGER NOT NULL,
"repo_path" VARCHAR NOT NULL,
"status" INT8 NOT NULL,

View File

@@ -1,25 +0,0 @@
DELETE FROM project_repositories
WHERE project_id NOT IN (SELECT id FROM projects);
ALTER TABLE project_repositories
ADD CONSTRAINT fk_project_repositories_project_id
FOREIGN KEY (project_id)
REFERENCES projects (id)
ON DELETE CASCADE
NOT VALID;
ALTER TABLE project_repositories
VALIDATE CONSTRAINT fk_project_repositories_project_id;
DELETE FROM project_repository_statuses
WHERE project_id NOT IN (SELECT id FROM projects);
ALTER TABLE project_repository_statuses
ADD CONSTRAINT fk_project_repository_statuses_project_id
FOREIGN KEY (project_id)
REFERENCES projects (id)
ON DELETE CASCADE
NOT VALID;
ALTER TABLE project_repository_statuses
VALIDATE CONSTRAINT fk_project_repository_statuses_project_id;

View File

@@ -1404,9 +1404,6 @@ async fn sync_model_request_usage_with_stripe(
llm_db: &Arc<LlmDatabase>,
stripe_billing: &Arc<StripeBilling>,
) -> anyhow::Result<()> {
log::info!("Stripe usage sync: Starting");
let started_at = Utc::now();
let staff_users = app.db.get_staff_users().await?;
let staff_user_ids = staff_users
.iter()
@@ -1451,10 +1448,6 @@ async fn sync_model_request_usage_with_stripe(
.find_price_by_lookup_key("claude-3-7-sonnet-requests-max")
.await?;
let usage_meter_count = usage_meters.len();
log::info!("Stripe usage sync: Syncing {usage_meter_count} usage meters");
for (usage_meter, usage) in usage_meters {
maybe!(async {
let Some((billing_customer, billing_subscription)) =
@@ -1511,10 +1504,5 @@ async fn sync_model_request_usage_with_stripe(
.log_err();
}
log::info!(
"Stripe usage sync: Synced {usage_meter_count} usage meters in {:?}",
Utc::now() - started_at
);
Ok(())
}

View File

@@ -4,19 +4,20 @@ mod tables;
#[cfg(test)]
pub mod tests;
use crate::{Error, Result};
use crate::{Error, Result, executor::Executor};
use anyhow::{Context as _, anyhow};
use collections::{BTreeMap, BTreeSet, HashMap, HashSet};
use dashmap::DashMap;
use futures::StreamExt;
use project_repository_statuses::StatusKind;
use rand::{Rng, SeedableRng, prelude::StdRng};
use rpc::ExtensionProvides;
use rpc::{
ConnectionId, ExtensionMetadata,
proto::{self},
};
use sea_orm::{
ActiveValue, Condition, ConnectionTrait, DatabaseConnection, DatabaseTransaction,
ActiveValue, Condition, ConnectionTrait, DatabaseConnection, DatabaseTransaction, DbErr,
FromQueryResult, IntoActiveModel, IsolationLevel, JoinType, QueryOrder, QuerySelect, Statement,
TransactionTrait,
entity::prelude::*,
@@ -32,6 +33,7 @@ use std::{
ops::{Deref, DerefMut},
rc::Rc,
sync::Arc,
time::Duration,
};
use time::PrimitiveDateTime;
use tokio::sync::{Mutex, OwnedMutexGuard};
@@ -56,7 +58,6 @@ pub use tables::*;
#[cfg(test)]
pub struct DatabaseTestOptions {
pub executor: gpui::BackgroundExecutor,
pub runtime: tokio::runtime::Runtime,
pub query_failure_probability: parking_lot::Mutex<f64>,
}
@@ -68,6 +69,8 @@ pub struct Database {
pool: DatabaseConnection,
rooms: DashMap<RoomId, Arc<Mutex<()>>>,
projects: DashMap<ProjectId, Arc<Mutex<()>>>,
rng: Mutex<StdRng>,
executor: Executor,
notification_kinds_by_id: HashMap<NotificationKindId, &'static str>,
notification_kinds_by_name: HashMap<String, NotificationKindId>,
#[cfg(test)]
@@ -78,15 +81,17 @@ pub struct Database {
// separate files in the `queries` folder.
impl Database {
/// Connects to the database with the given options
pub async fn new(options: ConnectOptions) -> Result<Self> {
pub async fn new(options: ConnectOptions, executor: Executor) -> Result<Self> {
sqlx::any::install_default_drivers();
Ok(Self {
options: options.clone(),
pool: sea_orm::Database::connect(options).await?,
rooms: DashMap::with_capacity(16384),
projects: DashMap::with_capacity(16384),
rng: Mutex::new(StdRng::seed_from_u64(0)),
notification_kinds_by_id: HashMap::default(),
notification_kinds_by_name: HashMap::default(),
executor,
#[cfg(test)]
test_options: None,
})
@@ -102,13 +107,48 @@ impl Database {
self.projects.clear();
}
/// Transaction runs things in a transaction. If you want to call other methods
/// and pass the transaction around you need to reborrow the transaction at each
/// call site with: `&*tx`.
pub async fn transaction<F, Fut, T>(&self, f: F) -> Result<T>
where
F: Send + Fn(TransactionHandle) -> Fut,
Fut: Send + Future<Output = Result<T>>,
{
let body = async {
let (tx, result) = self.with_transaction(&f).await?;
let mut i = 0;
loop {
let (tx, result) = self.with_transaction(&f).await?;
match result {
Ok(result) => match tx.commit().await.map_err(Into::into) {
Ok(()) => return Ok(result),
Err(error) => {
if !self.retry_on_serialization_error(&error, i).await {
return Err(error);
}
}
},
Err(error) => {
tx.rollback().await?;
if !self.retry_on_serialization_error(&error, i).await {
return Err(error);
}
}
}
i += 1;
}
};
self.run(body).await
}
pub async fn weak_transaction<F, Fut, T>(&self, f: F) -> Result<T>
where
F: Send + Fn(TransactionHandle) -> Fut,
Fut: Send + Future<Output = Result<T>>,
{
let body = async {
let (tx, result) = self.with_weak_transaction(&f).await?;
match result {
Ok(result) => match tx.commit().await.map_err(Into::into) {
Ok(()) => Ok(result),
@@ -134,28 +174,44 @@ impl Database {
Fut: Send + Future<Output = Result<Option<(RoomId, T)>>>,
{
let body = async {
let (tx, result) = self.with_transaction(&f).await?;
match result {
Ok(Some((room_id, data))) => {
let lock = self.rooms.entry(room_id).or_default().clone();
let _guard = lock.lock_owned().await;
match tx.commit().await.map_err(Into::into) {
Ok(()) => Ok(Some(TransactionGuard {
data,
_guard,
_not_send: PhantomData,
})),
Err(error) => Err(error),
let mut i = 0;
loop {
let (tx, result) = self.with_transaction(&f).await?;
match result {
Ok(Some((room_id, data))) => {
let lock = self.rooms.entry(room_id).or_default().clone();
let _guard = lock.lock_owned().await;
match tx.commit().await.map_err(Into::into) {
Ok(()) => {
return Ok(Some(TransactionGuard {
data,
_guard,
_not_send: PhantomData,
}));
}
Err(error) => {
if !self.retry_on_serialization_error(&error, i).await {
return Err(error);
}
}
}
}
Ok(None) => match tx.commit().await.map_err(Into::into) {
Ok(()) => return Ok(None),
Err(error) => {
if !self.retry_on_serialization_error(&error, i).await {
return Err(error);
}
}
},
Err(error) => {
tx.rollback().await?;
if !self.retry_on_serialization_error(&error, i).await {
return Err(error);
}
}
}
Ok(None) => match tx.commit().await.map_err(Into::into) {
Ok(()) => Ok(None),
Err(error) => Err(error),
},
Err(error) => {
tx.rollback().await?;
Err(error)
}
i += 1;
}
};
@@ -173,26 +229,38 @@ impl Database {
{
let room_id = Database::room_id_for_project(self, project_id).await?;
let body = async {
let lock = if let Some(room_id) = room_id {
self.rooms.entry(room_id).or_default().clone()
} else {
self.projects.entry(project_id).or_default().clone()
};
let _guard = lock.lock_owned().await;
let (tx, result) = self.with_transaction(&f).await?;
match result {
Ok(data) => match tx.commit().await.map_err(Into::into) {
Ok(()) => Ok(TransactionGuard {
data,
_guard,
_not_send: PhantomData,
}),
Err(error) => Err(error),
},
Err(error) => {
tx.rollback().await?;
Err(error)
let mut i = 0;
loop {
let lock = if let Some(room_id) = room_id {
self.rooms.entry(room_id).or_default().clone()
} else {
self.projects.entry(project_id).or_default().clone()
};
let _guard = lock.lock_owned().await;
let (tx, result) = self.with_transaction(&f).await?;
match result {
Ok(data) => match tx.commit().await.map_err(Into::into) {
Ok(()) => {
return Ok(TransactionGuard {
data,
_guard,
_not_send: PhantomData,
});
}
Err(error) => {
if !self.retry_on_serialization_error(&error, i).await {
return Err(error);
}
}
},
Err(error) => {
tx.rollback().await?;
if !self.retry_on_serialization_error(&error, i).await {
return Err(error);
}
}
}
i += 1;
}
};
@@ -212,22 +280,34 @@ impl Database {
Fut: Send + Future<Output = Result<T>>,
{
let body = async {
let lock = self.rooms.entry(room_id).or_default().clone();
let _guard = lock.lock_owned().await;
let (tx, result) = self.with_transaction(&f).await?;
match result {
Ok(data) => match tx.commit().await.map_err(Into::into) {
Ok(()) => Ok(TransactionGuard {
data,
_guard,
_not_send: PhantomData,
}),
Err(error) => Err(error),
},
Err(error) => {
tx.rollback().await?;
Err(error)
let mut i = 0;
loop {
let lock = self.rooms.entry(room_id).or_default().clone();
let _guard = lock.lock_owned().await;
let (tx, result) = self.with_transaction(&f).await?;
match result {
Ok(data) => match tx.commit().await.map_err(Into::into) {
Ok(()) => {
return Ok(TransactionGuard {
data,
_guard,
_not_send: PhantomData,
});
}
Err(error) => {
if !self.retry_on_serialization_error(&error, i).await {
return Err(error);
}
}
},
Err(error) => {
tx.rollback().await?;
if !self.retry_on_serialization_error(&error, i).await {
return Err(error);
}
}
}
i += 1;
}
};
@@ -235,6 +315,28 @@ impl Database {
}
async fn with_transaction<F, Fut, T>(&self, f: &F) -> Result<(DatabaseTransaction, Result<T>)>
where
F: Send + Fn(TransactionHandle) -> Fut,
Fut: Send + Future<Output = Result<T>>,
{
let tx = self
.pool
.begin_with_config(Some(IsolationLevel::Serializable), None)
.await?;
let mut tx = Arc::new(Some(tx));
let result = f(TransactionHandle(tx.clone())).await;
let tx = Arc::get_mut(&mut tx)
.and_then(|tx| tx.take())
.context("couldn't complete transaction because it's still in use")?;
Ok((tx, result))
}
async fn with_weak_transaction<F, Fut, T>(
&self,
f: &F,
) -> Result<(DatabaseTransaction, Result<T>)>
where
F: Send + Fn(TransactionHandle) -> Fut,
Fut: Send + Future<Output = Result<T>>,
@@ -259,13 +361,13 @@ impl Database {
{
#[cfg(test)]
{
use rand::prelude::*;
let test_options = self.test_options.as_ref().unwrap();
test_options.executor.simulate_random_delay().await;
let fail_probability = *test_options.query_failure_probability.lock();
if test_options.executor.rng().gen_bool(fail_probability) {
return Err(anyhow!("simulated query failure"))?;
if let Executor::Deterministic(executor) = &self.executor {
executor.simulate_random_delay().await;
let fail_probability = *test_options.query_failure_probability.lock();
if executor.rng().gen_bool(fail_probability) {
return Err(anyhow!("simulated query failure"))?;
}
}
test_options.runtime.block_on(future)
@@ -276,6 +378,46 @@ impl Database {
future.await
}
}
async fn retry_on_serialization_error(&self, error: &Error, prev_attempt_count: usize) -> bool {
// If the error is due to a failure to serialize concurrent transactions, then retry
// this transaction after a delay. With each subsequent retry, double the delay duration.
// Also vary the delay randomly in order to ensure different database connections retry
// at different times.
const SLEEPS: [f32; 10] = [10., 20., 40., 80., 160., 320., 640., 1280., 2560., 5120.];
if is_serialization_error(error) && prev_attempt_count < SLEEPS.len() {
let base_delay = SLEEPS[prev_attempt_count];
let randomized_delay = base_delay * self.rng.lock().await.gen_range(0.5..=2.0);
log::warn!(
"retrying transaction after serialization error. delay: {} ms.",
randomized_delay
);
self.executor
.sleep(Duration::from_millis(randomized_delay as u64))
.await;
true
} else {
false
}
}
}
fn is_serialization_error(error: &Error) -> bool {
const SERIALIZATION_FAILURE_CODE: &str = "40001";
match error {
Error::Database(
DbErr::Exec(sea_orm::RuntimeErr::SqlxError(error))
| DbErr::Query(sea_orm::RuntimeErr::SqlxError(error)),
) if error
.as_database_error()
.and_then(|error| error.code())
.as_deref()
== Some(SERIALIZATION_FAILURE_CODE) =>
{
true
}
_ => false,
}
}
/// A handle to a [`DatabaseTransaction`].

View File

@@ -20,7 +20,7 @@ impl Database {
&self,
params: &CreateBillingCustomerParams,
) -> Result<billing_customer::Model> {
self.transaction(|tx| async move {
self.weak_transaction(|tx| async move {
let customer = billing_customer::Entity::insert(billing_customer::ActiveModel {
user_id: ActiveValue::set(params.user_id),
stripe_customer_id: ActiveValue::set(params.stripe_customer_id.clone()),
@@ -40,7 +40,7 @@ impl Database {
id: BillingCustomerId,
params: &UpdateBillingCustomerParams,
) -> Result<()> {
self.transaction(|tx| async move {
self.weak_transaction(|tx| async move {
billing_customer::Entity::update(billing_customer::ActiveModel {
id: ActiveValue::set(id),
user_id: params.user_id.clone(),
@@ -61,7 +61,7 @@ impl Database {
&self,
id: BillingCustomerId,
) -> Result<Option<billing_customer::Model>> {
self.transaction(|tx| async move {
self.weak_transaction(|tx| async move {
Ok(billing_customer::Entity::find()
.filter(billing_customer::Column::Id.eq(id))
.one(&*tx)
@@ -75,7 +75,7 @@ impl Database {
&self,
user_id: UserId,
) -> Result<Option<billing_customer::Model>> {
self.transaction(|tx| async move {
self.weak_transaction(|tx| async move {
Ok(billing_customer::Entity::find()
.filter(billing_customer::Column::UserId.eq(user_id))
.one(&*tx)
@@ -89,7 +89,7 @@ impl Database {
&self,
stripe_customer_id: &str,
) -> Result<Option<billing_customer::Model>> {
self.transaction(|tx| async move {
self.weak_transaction(|tx| async move {
Ok(billing_customer::Entity::find()
.filter(billing_customer::Column::StripeCustomerId.eq(stripe_customer_id))
.one(&*tx)

View File

@@ -22,7 +22,7 @@ impl Database {
&self,
user_id: UserId,
) -> Result<Option<billing_preference::Model>> {
self.transaction(|tx| async move {
self.weak_transaction(|tx| async move {
Ok(billing_preference::Entity::find()
.filter(billing_preference::Column::UserId.eq(user_id))
.one(&*tx)
@@ -37,7 +37,7 @@ impl Database {
user_id: UserId,
params: &CreateBillingPreferencesParams,
) -> Result<billing_preference::Model> {
self.transaction(|tx| async move {
self.weak_transaction(|tx| async move {
let preferences = billing_preference::Entity::insert(billing_preference::ActiveModel {
user_id: ActiveValue::set(user_id),
max_monthly_llm_usage_spending_in_cents: ActiveValue::set(
@@ -65,7 +65,7 @@ impl Database {
user_id: UserId,
params: &UpdateBillingPreferencesParams,
) -> Result<billing_preference::Model> {
self.transaction(|tx| async move {
self.weak_transaction(|tx| async move {
let preferences = billing_preference::Entity::update_many()
.set(billing_preference::ActiveModel {
max_monthly_llm_usage_spending_in_cents: params

View File

@@ -35,7 +35,7 @@ impl Database {
&self,
params: &CreateBillingSubscriptionParams,
) -> Result<billing_subscription::Model> {
self.transaction(|tx| async move {
self.weak_transaction(|tx| async move {
let id = billing_subscription::Entity::insert(billing_subscription::ActiveModel {
billing_customer_id: ActiveValue::set(params.billing_customer_id),
kind: ActiveValue::set(params.kind),
@@ -64,7 +64,7 @@ impl Database {
id: BillingSubscriptionId,
params: &UpdateBillingSubscriptionParams,
) -> Result<()> {
self.transaction(|tx| async move {
self.weak_transaction(|tx| async move {
billing_subscription::Entity::update(billing_subscription::ActiveModel {
id: ActiveValue::set(id),
billing_customer_id: params.billing_customer_id.clone(),
@@ -90,7 +90,7 @@ impl Database {
&self,
id: BillingSubscriptionId,
) -> Result<Option<billing_subscription::Model>> {
self.transaction(|tx| async move {
self.weak_transaction(|tx| async move {
Ok(billing_subscription::Entity::find_by_id(id)
.one(&*tx)
.await?)
@@ -103,7 +103,7 @@ impl Database {
&self,
stripe_subscription_id: &str,
) -> Result<Option<billing_subscription::Model>> {
self.transaction(|tx| async move {
self.weak_transaction(|tx| async move {
Ok(billing_subscription::Entity::find()
.filter(
billing_subscription::Column::StripeSubscriptionId.eq(stripe_subscription_id),
@@ -118,7 +118,7 @@ impl Database {
&self,
user_id: UserId,
) -> Result<Option<billing_subscription::Model>> {
self.transaction(|tx| async move {
self.weak_transaction(|tx| async move {
Ok(billing_subscription::Entity::find()
.inner_join(billing_customer::Entity)
.filter(billing_customer::Column::UserId.eq(user_id))
@@ -152,7 +152,7 @@ impl Database {
&self,
user_id: UserId,
) -> Result<Vec<billing_subscription::Model>> {
self.transaction(|tx| async move {
self.weak_transaction(|tx| async move {
let subscriptions = billing_subscription::Entity::find()
.inner_join(billing_customer::Entity)
.filter(billing_customer::Column::UserId.eq(user_id))
@@ -169,7 +169,7 @@ impl Database {
&self,
user_ids: HashSet<UserId>,
) -> Result<HashMap<UserId, (billing_customer::Model, billing_subscription::Model)>> {
self.transaction(|tx| {
self.weak_transaction(|tx| {
let user_ids = user_ids.clone();
async move {
let mut rows = billing_subscription::Entity::find()
@@ -201,7 +201,7 @@ impl Database {
&self,
user_ids: HashSet<UserId>,
) -> Result<HashMap<UserId, (billing_customer::Model, billing_subscription::Model)>> {
self.transaction(|tx| {
self.weak_transaction(|tx| {
let user_ids = user_ids.clone();
async move {
let mut rows = billing_subscription::Entity::find()
@@ -236,7 +236,7 @@ impl Database {
/// Returns the count of the active billing subscriptions for the user with the specified ID.
pub async fn count_active_billing_subscriptions(&self, user_id: UserId) -> Result<usize> {
self.transaction(|tx| async move {
self.weak_transaction(|tx| async move {
let count = billing_subscription::Entity::find()
.inner_join(billing_customer::Entity)
.filter(

View File

@@ -501,8 +501,10 @@ impl Database {
/// Returns all channels for the user with the given ID.
pub async fn get_channels_for_user(&self, user_id: UserId) -> Result<ChannelsForUser> {
self.transaction(|tx| async move { self.get_user_channels(user_id, None, true, &tx).await })
.await
self.weak_transaction(
|tx| async move { self.get_user_channels(user_id, None, true, &tx).await },
)
.await
}
/// Returns all channels for the user with the given ID that are descendants

View File

@@ -15,7 +15,7 @@ impl Database {
user_b_busy: bool,
}
self.transaction(|tx| async move {
self.weak_transaction(|tx| async move {
let user_a_participant = Alias::new("user_a_participant");
let user_b_participant = Alias::new("user_b_participant");
let mut db_contacts = contact::Entity::find()
@@ -91,7 +91,7 @@ impl Database {
/// Returns whether the given user is a busy (on a call).
pub async fn is_user_busy(&self, user_id: UserId) -> Result<bool> {
self.transaction(|tx| async move {
self.weak_transaction(|tx| async move {
let participant = room_participant::Entity::find()
.filter(room_participant::Column::UserId.eq(user_id))
.one(&*tx)

View File

@@ -9,7 +9,7 @@ pub enum ContributorSelector {
impl Database {
/// Retrieves the GitHub logins of all users who have signed the CLA.
pub async fn get_contributors(&self) -> Result<Vec<String>> {
self.transaction(|tx| async move {
self.weak_transaction(|tx| async move {
#[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)]
enum QueryGithubLogin {
GithubLogin,
@@ -32,7 +32,7 @@ impl Database {
&self,
selector: &ContributorSelector,
) -> Result<Option<DateTime>> {
self.transaction(|tx| async move {
self.weak_transaction(|tx| async move {
let condition = match selector {
ContributorSelector::GitHubUserId { github_user_id } => {
user::Column::GithubUserId.eq(*github_user_id)
@@ -69,7 +69,7 @@ impl Database {
github_user_created_at: DateTimeUtc,
initial_channel_id: Option<ChannelId>,
) -> Result<()> {
self.transaction(|tx| async move {
self.weak_transaction(|tx| async move {
let user = self
.update_or_create_user_by_github_account_tx(
github_login,

View File

@@ -8,7 +8,7 @@ impl Database {
model: &str,
digests: &[Vec<u8>],
) -> Result<HashMap<Vec<u8>, Vec<f32>>> {
self.transaction(|tx| async move {
self.weak_transaction(|tx| async move {
let embeddings = {
let mut db_embeddings = embedding::Entity::find()
.filter(
@@ -52,7 +52,7 @@ impl Database {
model: &str,
embeddings: &HashMap<Vec<u8>, Vec<f32>>,
) -> Result<()> {
self.transaction(|tx| async move {
self.weak_transaction(|tx| async move {
embedding::Entity::insert_many(embeddings.iter().map(|(digest, dimensions)| {
let now_offset_datetime = OffsetDateTime::now_utc();
let retrieved_at =
@@ -78,7 +78,7 @@ impl Database {
}
pub async fn purge_old_embeddings(&self) -> Result<()> {
self.transaction(|tx| async move {
self.weak_transaction(|tx| async move {
embedding::Entity::delete_many()
.filter(
embedding::Column::RetrievedAt

View File

@@ -15,7 +15,7 @@ impl Database {
max_schema_version: i32,
limit: usize,
) -> Result<Vec<ExtensionMetadata>> {
self.transaction(|tx| async move {
self.weak_transaction(|tx| async move {
let mut condition = Condition::all()
.add(
extension::Column::LatestVersion
@@ -43,7 +43,7 @@ impl Database {
ids: &[&str],
constraints: Option<&ExtensionVersionConstraints>,
) -> Result<Vec<ExtensionMetadata>> {
self.transaction(|tx| async move {
self.weak_transaction(|tx| async move {
let extensions = extension::Entity::find()
.filter(extension::Column::ExternalId.is_in(ids.iter().copied()))
.all(&*tx)
@@ -123,7 +123,7 @@ impl Database {
&self,
extension_id: &str,
) -> Result<Vec<ExtensionMetadata>> {
self.transaction(|tx| async move {
self.weak_transaction(|tx| async move {
let condition = extension::Column::ExternalId
.eq(extension_id)
.into_condition();
@@ -162,7 +162,7 @@ impl Database {
extension_id: &str,
constraints: Option<&ExtensionVersionConstraints>,
) -> Result<Option<ExtensionMetadata>> {
self.transaction(|tx| async move {
self.weak_transaction(|tx| async move {
let extension = extension::Entity::find()
.filter(extension::Column::ExternalId.eq(extension_id))
.one(&*tx)
@@ -187,7 +187,7 @@ impl Database {
extension_id: &str,
version: &str,
) -> Result<Option<ExtensionMetadata>> {
self.transaction(|tx| async move {
self.weak_transaction(|tx| async move {
let extension = extension::Entity::find()
.filter(extension::Column::ExternalId.eq(extension_id))
.filter(extension_version::Column::Version.eq(version))
@@ -204,7 +204,7 @@ impl Database {
}
pub async fn get_known_extension_versions(&self) -> Result<HashMap<String, Vec<String>>> {
self.transaction(|tx| async move {
self.weak_transaction(|tx| async move {
let mut extension_external_ids_by_id = HashMap::default();
let mut rows = extension::Entity::find().stream(&*tx).await?;
@@ -242,7 +242,7 @@ impl Database {
&self,
versions_by_extension_id: &HashMap<&str, Vec<NewExtensionVersion>>,
) -> Result<()> {
self.transaction(|tx| async move {
self.weak_transaction(|tx| async move {
for (external_id, versions) in versions_by_extension_id {
if versions.is_empty() {
continue;
@@ -349,7 +349,7 @@ impl Database {
}
pub async fn record_extension_download(&self, extension: &str, version: &str) -> Result<bool> {
self.transaction(|tx| async move {
self.weak_transaction(|tx| async move {
#[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)]
enum QueryId {
Id,

View File

@@ -13,7 +13,7 @@ impl Database {
&self,
params: &CreateProcessedStripeEventParams,
) -> Result<()> {
self.transaction(|tx| async move {
self.weak_transaction(|tx| async move {
processed_stripe_event::Entity::insert(processed_stripe_event::ActiveModel {
stripe_event_id: ActiveValue::set(params.stripe_event_id.clone()),
stripe_event_type: ActiveValue::set(params.stripe_event_type.clone()),
@@ -35,7 +35,7 @@ impl Database {
&self,
event_id: &str,
) -> Result<Option<processed_stripe_event::Model>> {
self.transaction(|tx| async move {
self.weak_transaction(|tx| async move {
Ok(processed_stripe_event::Entity::find_by_id(event_id)
.one(&*tx)
.await?)
@@ -48,7 +48,7 @@ impl Database {
&self,
event_ids: &[&str],
) -> Result<Vec<processed_stripe_event::Model>> {
self.transaction(|tx| async move {
self.weak_transaction(|tx| async move {
Ok(processed_stripe_event::Entity::find()
.filter(
processed_stripe_event::Column::StripeEventId.is_in(event_ids.iter().copied()),

View File

@@ -112,7 +112,7 @@ impl Database {
}
pub async fn delete_project(&self, project_id: ProjectId) -> Result<()> {
self.transaction(|tx| async move {
self.weak_transaction(|tx| async move {
project::Entity::delete_by_id(project_id).exec(&*tx).await?;
Ok(())
})

View File

@@ -80,7 +80,7 @@ impl Database {
&self,
user_id: UserId,
) -> Result<Option<proto::IncomingCall>> {
self.transaction(|tx| async move {
self.weak_transaction(|tx| async move {
let pending_participant = room_participant::Entity::find()
.filter(
room_participant::Column::UserId

View File

@@ -142,50 +142,6 @@ impl Database {
}
}
loop {
let delete_query = Query::delete()
.from_table(project_repository_statuses::Entity)
.and_where(
Expr::tuple([Expr::col((
project_repository_statuses::Entity,
project_repository_statuses::Column::ProjectId,
))
.into()])
.in_subquery(
Query::select()
.columns([(
project_repository_statuses::Entity,
project_repository_statuses::Column::ProjectId,
)])
.from(project_repository_statuses::Entity)
.inner_join(
project::Entity,
Expr::col((project::Entity, project::Column::Id)).equals((
project_repository_statuses::Entity,
project_repository_statuses::Column::ProjectId,
)),
)
.and_where(project::Column::HostConnectionServerId.ne(server_id))
.limit(10000)
.to_owned(),
),
)
.to_owned();
let statement = Statement::from_sql_and_values(
tx.get_database_backend(),
delete_query
.to_string(sea_orm::sea_query::PostgresQueryBuilder)
.as_str(),
vec![],
);
let result = tx.execute(statement).await?;
if result.rows_affected() == 0 {
break;
}
}
Ok(())
})
.await

View File

@@ -382,7 +382,7 @@ impl Database {
/// Returns the active flags for the user.
pub async fn get_user_flags(&self, user: UserId) -> Result<Vec<String>> {
self.transaction(|tx| async move {
self.weak_transaction(|tx| async move {
#[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)]
enum QueryAs {
Flag,

View File

@@ -17,15 +17,11 @@ use crate::migrations::run_database_migrations;
use super::*;
use gpui::BackgroundExecutor;
use parking_lot::Mutex;
use rand::prelude::*;
use sea_orm::ConnectionTrait;
use sqlx::migrate::MigrateDatabase;
use std::{
sync::{
Arc,
atomic::{AtomicI32, AtomicU32, Ordering::SeqCst},
},
time::Duration,
use std::sync::{
Arc,
atomic::{AtomicI32, AtomicU32, Ordering::SeqCst},
};
pub struct TestDb {
@@ -45,7 +41,9 @@ impl TestDb {
let mut db = runtime.block_on(async {
let mut options = ConnectOptions::new(url);
options.max_connections(5);
let mut db = Database::new(options).await.unwrap();
let mut db = Database::new(options, Executor::Deterministic(executor.clone()))
.await
.unwrap();
let sql = include_str!(concat!(
env!("CARGO_MANIFEST_DIR"),
"/migrations.sqlite/20221109000000_test_schema.sql"
@@ -62,7 +60,6 @@ impl TestDb {
});
db.test_options = Some(DatabaseTestOptions {
executor,
runtime,
query_failure_probability: parking_lot::Mutex::new(0.0),
});
@@ -96,7 +93,9 @@ impl TestDb {
options
.max_connections(5)
.idle_timeout(Duration::from_secs(0));
let mut db = Database::new(options).await.unwrap();
let mut db = Database::new(options, Executor::Deterministic(executor.clone()))
.await
.unwrap();
let migrations_path = concat!(env!("CARGO_MANIFEST_DIR"), "/migrations");
run_database_migrations(db.options(), migrations_path)
.await
@@ -106,7 +105,6 @@ impl TestDb {
});
db.test_options = Some(DatabaseTestOptions {
executor,
runtime,
query_failure_probability: parking_lot::Mutex::new(0.0),
});

View File

@@ -49,7 +49,7 @@ async fn test_purge_old_embeddings(cx: &mut gpui::TestAppContext) {
db.save_embeddings(model, &embeddings).await.unwrap();
// Reach into the DB and change the retrieved at to be > 60 days
db.transaction(|tx| {
db.weak_transaction(|tx| {
let digest = digest.clone();
async move {
let sixty_days_ago = OffsetDateTime::now_utc().sub(Duration::days(61));

View File

@@ -285,7 +285,7 @@ impl AppState {
pub async fn new(config: Config, executor: Executor) -> Result<Arc<Self>> {
let mut db_options = db::ConnectOptions::new(config.database_url.clone());
db_options.max_connections(config.database_max_connections);
let mut db = Database::new(db_options).await?;
let mut db = Database::new(db_options, Executor::Production).await?;
db.initialize_notification_kinds().await?;
let llm_db = if let Some((llm_database_url, llm_database_max_connections)) = config

View File

@@ -59,7 +59,7 @@ async fn main() -> Result<()> {
let config = envy::from_env::<Config>().expect("error loading config");
let db_options = db::ConnectOptions::new(config.database_url.clone());
let mut db = Database::new(db_options).await?;
let mut db = Database::new(db_options, Executor::Production).await?;
db.initialize_notification_kinds().await?;
collab::seed::seed(&config, &db, false).await?;
@@ -253,7 +253,7 @@ async fn main() -> Result<()> {
async fn setup_app_database(config: &Config) -> Result<()> {
let db_options = db::ConnectOptions::new(config.database_url.clone());
let mut db = Database::new(db_options).await?;
let mut db = Database::new(db_options, Executor::Production).await?;
let migrations_path = config.migrations_path.as_deref().unwrap_or_else(|| {
#[cfg(feature = "sqlite")]

View File

@@ -22,9 +22,7 @@ use gpui::{
use language::{
Diagnostic, DiagnosticEntry, DiagnosticSourceKind, FakeLspAdapter, Language, LanguageConfig,
LanguageMatcher, LineEnding, OffsetRangeExt, Point, Rope,
language_settings::{
AllLanguageSettings, Formatter, FormatterList, PrettierSettings, SelectedFormatter,
},
language_settings::{AllLanguageSettings, Formatter, PrettierSettings, SelectedFormatter},
tree_sitter_rust, tree_sitter_typescript,
};
use lsp::{LanguageServerId, OneOf};
@@ -4591,14 +4589,13 @@ async fn test_formatting_buffer(
cx_a.update(|cx| {
SettingsStore::update_global(cx, |store, cx| {
store.update_user_settings::<AllLanguageSettings>(cx, |file| {
file.defaults.formatter = Some(SelectedFormatter::List(FormatterList::Single(
Formatter::External {
file.defaults.formatter =
Some(SelectedFormatter::List(vec![Formatter::External {
command: "awk".into(),
arguments: Some(
vec!["{sub(/two/,\"{buffer_path}\")}1".to_string()].into(),
),
},
)));
}]));
});
});
});
@@ -4698,9 +4695,10 @@ async fn test_prettier_formatting_buffer(
cx_b.update(|cx| {
SettingsStore::update_global(cx, |store, cx| {
store.update_user_settings::<AllLanguageSettings>(cx, |file| {
file.defaults.formatter = Some(SelectedFormatter::List(FormatterList::Single(
Formatter::LanguageServer { name: None },
)));
file.defaults.formatter =
Some(SelectedFormatter::List(vec![Formatter::LanguageServer {
name: None,
}]));
file.defaults.prettier = Some(PrettierSettings {
allowed: true,
..PrettierSettings::default()
@@ -4821,7 +4819,7 @@ async fn test_definition(
);
let definitions_1 = project_b
.update(cx_b, |p, cx| p.definitions(&buffer_b, 23, cx))
.update(cx_b, |p, cx| p.definition(&buffer_b, 23, cx))
.await
.unwrap();
cx_b.read(|cx| {
@@ -4852,7 +4850,7 @@ async fn test_definition(
);
let definitions_2 = project_b
.update(cx_b, |p, cx| p.definitions(&buffer_b, 33, cx))
.update(cx_b, |p, cx| p.definition(&buffer_b, 33, cx))
.await
.unwrap();
cx_b.read(|cx| {
@@ -4889,7 +4887,7 @@ async fn test_definition(
);
let type_definitions = project_b
.update(cx_b, |p, cx| p.type_definitions(&buffer_b, 7, cx))
.update(cx_b, |p, cx| p.type_definition(&buffer_b, 7, cx))
.await
.unwrap();
cx_b.read(|cx| {
@@ -5057,7 +5055,7 @@ async fn test_references(
lsp_response_tx
.unbounded_send(Err(anyhow!("can't find references")))
.unwrap();
assert_eq!(references.await.unwrap(), []);
references.await.unwrap_err();
// User is informed that the request is no longer pending.
executor.run_until_parked();
@@ -5641,7 +5639,7 @@ async fn test_open_buffer_while_getting_definition_pointing_to_it(
let definitions;
let buffer_b2;
if rng.r#gen() {
definitions = project_b.update(cx_b, |p, cx| p.definitions(&buffer_b1, 23, cx));
definitions = project_b.update(cx_b, |p, cx| p.definition(&buffer_b1, 23, cx));
(buffer_b2, _) = project_b
.update(cx_b, |p, cx| {
p.open_buffer_with_lsp((worktree_id, "b.rs"), cx)
@@ -5655,7 +5653,7 @@ async fn test_open_buffer_while_getting_definition_pointing_to_it(
})
.await
.unwrap();
definitions = project_b.update(cx_b, |p, cx| p.definitions(&buffer_b1, 23, cx));
definitions = project_b.update(cx_b, |p, cx| p.definition(&buffer_b1, 23, cx));
}
let definitions = definitions.await.unwrap();

View File

@@ -838,7 +838,7 @@ impl RandomizedTest for ProjectCollaborationTest {
.map(|_| Ok(()))
.boxed(),
LspRequestKind::Definition => project
.definitions(&buffer, offset, cx)
.definition(&buffer, offset, cx)
.map_ok(|_| ())
.boxed(),
LspRequestKind::Highlights => project

View File

@@ -14,8 +14,7 @@ use http_client::BlockedHttpClient;
use language::{
FakeLspAdapter, Language, LanguageConfig, LanguageMatcher, LanguageRegistry,
language_settings::{
AllLanguageSettings, Formatter, FormatterList, PrettierSettings, SelectedFormatter,
language_settings,
AllLanguageSettings, Formatter, PrettierSettings, SelectedFormatter, language_settings,
},
tree_sitter_typescript,
};
@@ -505,9 +504,10 @@ async fn test_ssh_collaboration_formatting_with_prettier(
cx_b.update(|cx| {
SettingsStore::update_global(cx, |store, cx| {
store.update_user_settings::<AllLanguageSettings>(cx, |file| {
file.defaults.formatter = Some(SelectedFormatter::List(FormatterList::Single(
Formatter::LanguageServer { name: None },
)));
file.defaults.formatter =
Some(SelectedFormatter::List(vec![Formatter::LanguageServer {
name: None,
}]));
file.defaults.prettier = Some(PrettierSettings {
allowed: true,
..PrettierSettings::default()

View File

@@ -30,13 +30,7 @@ use workspace::{
};
use workspace::{item::Dedup, notifications::NotificationId};
actions!(
collab,
[
/// Copies a link to the current position in the channel buffer.
CopyLink
]
);
actions!(collab, [CopyLink]);
pub fn init(cx: &mut App) {
workspace::FollowableViewRegistry::register::<ChannelView>(cx)

View File

@@ -71,13 +71,7 @@ struct SerializedChatPanel {
width: Option<Pixels>,
}
actions!(
chat_panel,
[
/// Toggles focus on the chat panel.
ToggleFocus
]
);
actions!(chat_panel, [ToggleFocus]);
impl ChatPanel {
pub fn new(

View File

@@ -44,25 +44,15 @@ use workspace::{
actions!(
collab_panel,
[
/// Toggles focus on the collaboration panel.
ToggleFocus,
/// Removes the selected channel or contact.
Remove,
/// Opens the context menu for the selected item.
Secondary,
/// Collapses the selected channel in the tree view.
CollapseSelectedChannel,
/// Expands the selected channel in the tree view.
ExpandSelectedChannel,
/// Starts moving a channel to a new location.
StartMoveChannel,
/// Moves the selected item to the current location.
MoveSelected,
/// Inserts a space character in the filter input.
InsertSpace,
/// Moves the selected channel up in the list.
MoveChannelUp,
/// Moves the selected channel down in the list.
MoveChannelDown,
]
);

View File

@@ -17,13 +17,9 @@ use workspace::{ModalView, notifications::DetachAndPromptErr};
actions!(
channel_modal,
[
/// Selects the next control in the channel modal.
SelectNextControl,
/// Toggles between invite members and manage members mode.
ToggleMode,
/// Toggles admin status for the selected member.
ToggleMemberAdmin,
/// Removes the selected member from the channel.
RemoveMember
]
);

View File

@@ -74,13 +74,7 @@ pub struct NotificationPresenter {
pub can_navigate: bool,
}
actions!(
notification_panel,
[
/// Toggles focus on the notification panel.
ToggleFocus
]
);
actions!(notification_panel, [ToggleFocus]);
pub fn init(cx: &mut App) {
cx.observe_new(|workspace: &mut Workspace, _, _| {

View File

@@ -61,7 +61,7 @@ impl RenderOnce for ComponentExample {
12.0,
12.0,
))
.shadow_xs()
.shadow_sm()
.child(self.element),
)
.into_any_element()

View File

@@ -46,17 +46,11 @@ pub use crate::sign_in::{CopilotCodeVerification, initiate_sign_in, reinstall_an
actions!(
copilot,
[
/// Requests a code completion suggestion from Copilot.
Suggest,
/// Cycles to the next Copilot suggestion.
NextSuggestion,
/// Cycles to the previous Copilot suggestion.
PreviousSuggestion,
/// Reinstalls the Copilot language server.
Reinstall,
/// Signs in to GitHub Copilot.
SignIn,
/// Signs out of GitHub Copilot.
SignOut
]
);

View File

@@ -79,9 +79,9 @@ impl JsDebugAdapter {
let command = configuration.get("command")?.as_str()?.to_owned();
let mut args = shlex::split(&command)?.into_iter();
let program = args.next()?;
configuration.insert("runtimeExecutable".to_owned(), program.into());
configuration.insert("program".to_owned(), program.into());
configuration.insert(
"runtimeArgs".to_owned(),
"args".to_owned(),
args.map(Value::from).collect::<Vec<_>>().into(),
);
configuration.insert("console".to_owned(), "externalTerminal".into());
@@ -522,11 +522,7 @@ impl DebugAdapter for JsDebugAdapter {
}
fn label_for_child_session(&self, args: &StartDebuggingRequestArguments) -> Option<String> {
let label = args
.configuration
.get("name")?
.as_str()
.filter(|name| !name.is_empty())?;
let label = args.configuration.get("name")?.as_str()?;
Some(label.to_owned())
}
}

View File

@@ -660,15 +660,6 @@ impl DebugAdapter for PythonDebugAdapter {
self.get_installed_binary(delegate, &config, None, user_args, toolchain, false)
.await
}
fn label_for_child_session(&self, args: &StartDebuggingRequestArguments) -> Option<String> {
let label = args
.configuration
.get("name")?
.as_str()
.filter(|label| !label.is_empty())?;
Some(label.to_owned())
}
}
async fn fetch_latest_adapter_version_from_github(

View File

@@ -918,13 +918,7 @@ impl Render for DapLogView {
}
}
actions!(
dev,
[
/// Opens the debug adapter protocol logs viewer.
OpenDebugAdapterLogs
]
);
actions!(dev, [OpenDebugAdapterLogs]);
pub fn init(cx: &mut App) {
let log_store = cx.new(|cx| LogStore::new(cx));

View File

@@ -5,7 +5,7 @@ use crate::session::running::breakpoint_list::BreakpointList;
use crate::{
ClearAllBreakpoints, Continue, CopyDebugAdapterArguments, Detach, FocusBreakpointList,
FocusConsole, FocusFrames, FocusLoadedSources, FocusModules, FocusTerminal, FocusVariables,
NewProcessModal, NewProcessMode, Pause, RerunSession, StepInto, StepOut, StepOver, Stop,
NewProcessModal, NewProcessMode, Pause, Restart, StepInto, StepOut, StepOver, Stop,
ToggleExpandItem, ToggleSessionPicker, ToggleThreadPicker, persistence, spawn_task_or_modal,
};
use anyhow::{Context as _, Result, anyhow};
@@ -25,7 +25,7 @@ use gpui::{
use itertools::Itertools as _;
use language::Buffer;
use project::debugger::session::{Session, SessionStateEvent};
use project::{DebugScenarioContext, Fs, ProjectPath, WorktreeId};
use project::{Fs, ProjectPath, WorktreeId};
use project::{Project, debugger::session::ThreadStatus};
use rpc::proto::{self};
use settings::Settings;
@@ -197,7 +197,6 @@ impl DebugPanel {
.and_then(|buffer| buffer.read(cx).file())
.map(|f| f.worktree_id(cx))
});
let Some(worktree) = worktree
.and_then(|id| self.project.read(cx).worktree_for_id(id, cx))
.or_else(|| self.project.read(cx).visible_worktrees(cx).next())
@@ -205,7 +204,6 @@ impl DebugPanel {
log::debug!("Could not find a worktree to spawn the debug session in");
return;
};
self.debug_scenario_scheduled_last = true;
if let Some(inventory) = self
.project
@@ -216,15 +214,7 @@ impl DebugPanel {
.cloned()
{
inventory.update(cx, |inventory, _| {
inventory.scenario_scheduled(
scenario.clone(),
// todo(debugger): Task context is cloned three times
// once in Session,inventory, and in resolve scenario
// we should wrap it in an RC instead to save some memory
task_context.clone(),
worktree_id,
active_buffer.as_ref().map(|buffer| buffer.downgrade()),
);
inventory.scenario_scheduled(scenario.clone());
})
}
let task = cx.spawn_in(window, {
@@ -235,16 +225,6 @@ impl DebugPanel {
let definition = debug_session
.update_in(cx, |debug_session, window, cx| {
debug_session.running_state().update(cx, |running, cx| {
if scenario.build.is_some() {
running.scenario = Some(scenario.clone());
running.scenario_context = Some(DebugScenarioContext {
active_buffer: active_buffer
.as_ref()
.map(|entity| entity.downgrade()),
task_context: task_context.clone(),
worktree_id: worktree_id,
});
};
running.resolve_scenario(
scenario,
task_context,
@@ -293,8 +273,7 @@ impl DebugPanel {
return;
};
let workspace = self.workspace.clone();
let Some((scenario, context)) = task_inventory.read(cx).last_scheduled_scenario().cloned()
else {
let Some(scenario) = task_inventory.read(cx).last_scheduled_scenario().cloned() else {
window.defer(cx, move |window, cx| {
workspace
.update(cx, |workspace, cx| {
@@ -305,22 +284,28 @@ impl DebugPanel {
return;
};
let DebugScenarioContext {
task_context,
worktree_id,
active_buffer,
} = context;
cx.spawn_in(window, async move |this, cx| {
let task_contexts = workspace
.update_in(cx, |workspace, window, cx| {
tasks_ui::task_contexts(workspace, window, cx)
})?
.await;
let active_buffer = active_buffer.and_then(|buffer| buffer.upgrade());
let task_context = task_contexts.active_context().cloned().unwrap_or_default();
let worktree_id = task_contexts.worktree();
self.start_session(
scenario,
task_context,
active_buffer,
worktree_id,
window,
cx,
);
this.update_in(cx, |this, window, cx| {
this.start_session(
scenario.clone(),
task_context,
None,
worktree_id,
window,
cx,
);
})
})
.detach();
}
pub(crate) async fn register_session(
@@ -773,16 +758,16 @@ impl DebugPanel {
.icon_size(IconSize::XSmall)
.on_click(window.listener_for(
&running_state,
|this, _, window, cx| {
this.rerun_session(window, cx);
|this, _, _window, cx| {
this.restart_session(cx);
},
))
.tooltip({
let focus_handle = focus_handle.clone();
move |window, cx| {
Tooltip::for_action_in(
"Rerun Session",
&RerunSession,
"Restart",
&Restart,
&focus_handle,
window,
cx,
@@ -1313,13 +1298,11 @@ impl Render for DebugPanel {
}
v_flex()
.when(!self.is_zoomed, |this| {
this.when_else(
self.position(window, cx) == DockPosition::Bottom,
|this| this.max_h(self.size),
|this| this.max_w(self.size),
)
})
.when_else(
self.position(window, cx) == DockPosition::Bottom,
|this| this.max_h(self.size),
|this| this.max_w(self.size),
)
.size_full()
.key_context("DebugPanel")
.child(h_flex().children(self.top_controls_strip(window, cx)))
@@ -1617,13 +1600,12 @@ impl workspace::DebuggerProvider for DebuggerProvider {
definition: DebugScenario,
context: TaskContext,
buffer: Option<Entity<Buffer>>,
worktree_id: Option<WorktreeId>,
window: &mut Window,
cx: &mut App,
) {
self.0.update(cx, |_, cx| {
cx.defer_in(window, move |this, window, cx| {
this.start_session(definition, context, buffer, worktree_id, window, cx);
cx.defer_in(window, |this, window, cx| {
this.start_session(definition, context, buffer, None, window, cx);
})
})
}

View File

@@ -32,67 +32,34 @@ pub mod tests;
actions!(
debugger,
[
/// Starts a new debugging session.
Start,
/// Continues execution until the next breakpoint.
Continue,
/// Detaches the debugger from the running process.
Detach,
/// Pauses the currently running program.
Pause,
/// Restarts the current debugging session.
Restart,
/// Reruns the current debugging session with the same configuration.
RerunSession,
/// Steps into the next function call.
StepInto,
/// Steps over the current line.
StepOver,
/// Steps out of the current function.
StepOut,
/// Steps back to the previous statement.
StepBack,
/// Stops the debugging session.
Stop,
/// Toggles whether to ignore all breakpoints.
ToggleIgnoreBreakpoints,
/// Clears all breakpoints in the project.
ClearAllBreakpoints,
/// Focuses on the debugger console panel.
FocusConsole,
/// Focuses on the variables panel.
FocusVariables,
/// Focuses on the breakpoint list panel.
FocusBreakpointList,
/// Focuses on the call stack frames panel.
FocusFrames,
/// Focuses on the loaded modules panel.
FocusModules,
/// Focuses on the loaded sources panel.
FocusLoadedSources,
/// Focuses on the terminal panel.
FocusTerminal,
/// Shows the stack trace for the current thread.
ShowStackTrace,
/// Toggles the thread picker dropdown.
ToggleThreadPicker,
/// Toggles the session picker dropdown.
ToggleSessionPicker,
/// Reruns the last debugging session.
#[action(deprecated_aliases = ["debugger::RerunLastSession"])]
Rerun,
/// Toggles expansion of the selected item in the debugger UI.
RerunLastSession,
ToggleExpandItem,
]
);
actions!(
dev,
[
/// Copies debug adapter launch arguments to clipboard.
CopyDebugAdapterArguments
]
);
actions!(dev, [CopyDebugAdapterArguments]);
pub fn init(cx: &mut App) {
DebuggerSettings::register(cx);
@@ -107,15 +74,17 @@ pub fn init(cx: &mut App) {
.register_action(|workspace: &mut Workspace, _: &Start, window, cx| {
NewProcessModal::show(workspace, window, NewProcessMode::Debug, None, cx);
})
.register_action(|workspace: &mut Workspace, _: &Rerun, window, cx| {
let Some(debug_panel) = workspace.panel::<DebugPanel>(cx) else {
return;
};
.register_action(
|workspace: &mut Workspace, _: &RerunLastSession, window, cx| {
let Some(debug_panel) = workspace.panel::<DebugPanel>(cx) else {
return;
};
debug_panel.update(cx, |debug_panel, cx| {
debug_panel.rerun_last_session(workspace, window, cx);
})
})
debug_panel.update(cx, |debug_panel, cx| {
debug_panel.rerun_last_session(workspace, window, cx);
})
},
)
.register_action(
|workspace: &mut Workspace, _: &ShutdownDebugAdapters, _window, cx| {
workspace.project().update(cx, |project, cx| {
@@ -241,14 +210,6 @@ pub fn init(cx: &mut App) {
.ok();
}
})
.on_action({
let active_item = active_item.clone();
move |_: &RerunSession, window, cx| {
active_item
.update(cx, |item, cx| item.rerun_session(window, cx))
.ok();
}
})
.on_action({
let active_item = active_item.clone();
move |_: &Stop, _, cx| {

View File

@@ -4,7 +4,6 @@ use collections::HashMap;
use gpui::{Animation, AnimationExt as _, Entity, Transformation, percentage};
use project::debugger::session::{ThreadId, ThreadStatus};
use ui::{ContextMenu, DropdownMenu, DropdownStyle, Indicator, prelude::*};
use util::truncate_and_trailoff;
use crate::{
debugger_panel::DebugPanel,
@@ -13,8 +12,6 @@ use crate::{
impl DebugPanel {
fn dropdown_label(label: impl Into<SharedString>) -> Label {
const MAX_LABEL_CHARS: usize = 50;
let label = truncate_and_trailoff(&label.into(), MAX_LABEL_CHARS);
Label::new(label).size(LabelSize::Small)
}
@@ -173,8 +170,6 @@ impl DebugPanel {
window: &mut Window,
cx: &mut Context<Self>,
) -> Option<DropdownMenu> {
const MAX_LABEL_CHARS: usize = 150;
let running_state = running_state.clone();
let running_state_read = running_state.read(cx);
let thread_id = running_state_read.thread_id();
@@ -207,7 +202,6 @@ impl DebugPanel {
.is_empty()
.then(|| format!("Tid: {}", thread.id))
.unwrap_or_else(|| thread.name);
let entry_name = truncate_and_trailoff(&entry_name, MAX_LABEL_CHARS);
this = this.entry(entry_name, None, move |window, cx| {
running_state.update(cx, |running_state, cx| {

View File

@@ -23,9 +23,7 @@ use gpui::{
};
use itertools::Itertools as _;
use picker::{Picker, PickerDelegate, highlighted_match_with_paths::HighlightedMatch};
use project::{
DebugScenarioContext, ProjectPath, TaskContexts, TaskSourceKind, task_store::TaskStore,
};
use project::{ProjectPath, TaskContexts, TaskSourceKind, task_store::TaskStore};
use settings::{Settings, initial_local_debug_tasks_content};
use task::{DebugScenario, RevealTarget, ZedDebugConfig};
use theme::ThemeSettings;
@@ -94,7 +92,6 @@ impl NewProcessModal {
cx.spawn_in(window, async move |workspace, cx| {
let task_contexts = workspace.update_in(cx, |workspace, window, cx| {
// todo(debugger): get the buffer here (if the active item is an editor) and store it so we can pass it to start_session later
tasks_ui::task_contexts(workspace, window, cx)
})?;
workspace.update_in(cx, |workspace, window, cx| {
@@ -1113,11 +1110,7 @@ pub(super) struct TaskMode {
pub(super) struct DebugDelegate {
task_store: Entity<TaskStore>,
candidates: Vec<(
Option<TaskSourceKind>,
DebugScenario,
Option<DebugScenarioContext>,
)>,
candidates: Vec<(Option<TaskSourceKind>, DebugScenario)>,
selected_index: usize,
matches: Vec<StringMatch>,
prompt: String,
@@ -1215,11 +1208,7 @@ impl DebugDelegate {
this.delegate.candidates = recent
.into_iter()
.map(|(scenario, context)| {
let (kind, scenario) =
Self::get_scenario_kind(&languages, &dap_registry, scenario);
(kind, scenario, Some(context))
})
.map(|scenario| Self::get_scenario_kind(&languages, &dap_registry, scenario))
.chain(
scenarios
.into_iter()
@@ -1234,7 +1223,7 @@ impl DebugDelegate {
.map(|(kind, scenario)| {
let (language, scenario) =
Self::get_scenario_kind(&languages, &dap_registry, scenario);
(language.or(Some(kind)), scenario, None)
(language.or(Some(kind)), scenario)
}),
)
.collect();
@@ -1280,7 +1269,7 @@ impl PickerDelegate for DebugDelegate {
let candidates: Vec<_> = candidates
.into_iter()
.enumerate()
.map(|(index, (_, candidate, _))| {
.map(|(index, (_, candidate))| {
StringMatchCandidate::new(index, candidate.label.as_ref())
})
.collect();
@@ -1445,40 +1434,25 @@ impl PickerDelegate for DebugDelegate {
.get(self.selected_index())
.and_then(|match_candidate| self.candidates.get(match_candidate.candidate_id).cloned());
let Some((_, debug_scenario, context)) = debug_scenario else {
let Some((_, debug_scenario)) = debug_scenario else {
return;
};
let context = context.unwrap_or_else(|| {
self.task_contexts
.as_ref()
.and_then(|task_contexts| {
Some(DebugScenarioContext {
task_context: task_contexts.active_context().cloned()?,
active_buffer: None,
worktree_id: task_contexts.worktree(),
})
})
.unwrap_or_default()
});
let DebugScenarioContext {
task_context,
active_buffer,
worktree_id,
} = context;
let active_buffer = active_buffer.and_then(|buffer| buffer.upgrade());
let (task_context, worktree_id) = self
.task_contexts
.as_ref()
.and_then(|task_contexts| {
Some((
task_contexts.active_context().cloned()?,
task_contexts.worktree(),
))
})
.unwrap_or_default();
send_telemetry(&debug_scenario, TelemetrySpawnLocation::ScenarioList, cx);
self.debug_panel
.update(cx, |panel, cx| {
panel.start_session(
debug_scenario,
task_context,
active_buffer,
worktree_id,
window,
cx,
);
panel.start_session(debug_scenario, task_context, None, worktree_id, window, cx);
})
.ok();

View File

@@ -12,7 +12,6 @@ use rpc::proto;
use running::RunningState;
use std::{cell::OnceCell, sync::OnceLock};
use ui::{Indicator, Tooltip, prelude::*};
use util::truncate_and_trailoff;
use workspace::{
CollaboratorId, FollowableItem, ViewId, Workspace,
item::{self, Item},
@@ -127,10 +126,7 @@ impl DebugSession {
}
pub(crate) fn label_element(&self, depth: usize, cx: &App) -> AnyElement {
const MAX_LABEL_CHARS: usize = 150;
let label = self.label(cx);
let label = truncate_and_trailoff(&label, MAX_LABEL_CHARS);
let is_terminated = self
.running_state

View File

@@ -33,7 +33,7 @@ use language::Buffer;
use loaded_source_list::LoadedSourceList;
use module_list::ModuleList;
use project::{
DebugScenarioContext, Project, WorktreeId,
Project, WorktreeId,
debugger::session::{Session, SessionEvent, ThreadId, ThreadStatus},
terminals::TerminalKind,
};
@@ -79,8 +79,6 @@ pub struct RunningState {
pane_close_subscriptions: HashMap<EntityId, Subscription>,
dock_axis: Axis,
_schedule_serialize: Option<Task<()>>,
pub(crate) scenario: Option<DebugScenario>,
pub(crate) scenario_context: Option<DebugScenarioContext>,
}
impl RunningState {
@@ -833,8 +831,6 @@ impl RunningState {
debug_terminal,
dock_axis,
_schedule_serialize: None,
scenario: None,
scenario_context: None,
}
}
@@ -1043,7 +1039,7 @@ impl RunningState {
let scenario = dap_registry
.adapter(&adapter)
.with_context(|| anyhow!("{}: is not a valid adapter name", &adapter))?.config_from_zed_format(zed_config)
.await?;
.await?;
config = scenario.config;
util::merge_non_null_json_value_into(extra_config, &mut config);
@@ -1529,34 +1525,6 @@ impl RunningState {
});
}
pub fn rerun_session(&mut self, window: &mut Window, cx: &mut Context<Self>) {
if let Some((scenario, context)) = self.scenario.take().zip(self.scenario_context.take())
&& scenario.build.is_some()
{
let DebugScenarioContext {
task_context,
active_buffer,
worktree_id,
} = context;
let active_buffer = active_buffer.and_then(|buffer| buffer.upgrade());
self.workspace
.update(cx, |workspace, cx| {
workspace.start_debug_session(
scenario,
task_context,
active_buffer,
worktree_id,
window,
cx,
)
})
.ok();
} else {
self.restart_session(cx);
}
}
pub fn restart_session(&self, cx: &mut Context<Self>) {
self.session().update(cx, |state, cx| {
state.restart(None, cx);

View File

@@ -33,12 +33,7 @@ use zed_actions::{ToggleEnableBreakpoint, UnsetBreakpoint};
actions!(
debugger,
[
/// Navigates to the previous breakpoint property in the list.
PreviousBreakpointProperty,
/// Navigates to the next breakpoint property in the list.
NextBreakpointProperty
]
[PreviousBreakpointProperty, NextBreakpointProperty]
);
#[derive(Clone, Copy, PartialEq)]
pub(crate) enum SelectedBreakpointKind {

View File

@@ -13,26 +13,17 @@ use gpui::{
Render, Subscription, Task, TextStyle, WeakEntity, actions,
};
use language::{Buffer, CodeLabel, ToOffset};
use menu::{Confirm, SelectNext, SelectPrevious};
use menu::Confirm;
use project::{
Completion, CompletionResponse,
debugger::session::{CompletionsQuery, OutputToken, Session},
search_history::{SearchHistory, SearchHistoryCursor},
debugger::session::{CompletionsQuery, OutputToken, Session, SessionEvent},
};
use settings::Settings;
use std::fmt::Write;
use std::{cell::RefCell, ops::Range, rc::Rc, usize};
use theme::{Theme, ThemeSettings};
use ui::{ContextMenu, Divider, PopoverMenu, SplitButton, Tooltip, prelude::*};
use util::ResultExt;
actions!(
console,
[
/// Adds an expression to the watch list.
WatchExpression
]
);
actions!(console, [WatchExpression]);
pub struct Console {
console: Entity<Editor>,
@@ -42,10 +33,8 @@ pub struct Console {
variable_list: Entity<VariableList>,
stack_frame_list: Entity<StackFrameList>,
last_token: OutputToken,
update_output_task: Option<Task<()>>,
update_output_task: Task<()>,
focus_handle: FocusHandle,
history: SearchHistory,
cursor: SearchHistoryCursor,
}
impl Console {
@@ -94,6 +83,11 @@ impl Console {
let _subscriptions = vec![
cx.subscribe(&stack_frame_list, Self::handle_stack_frame_list_events),
cx.subscribe_in(&session, window, |this, _, event, window, cx| {
if let SessionEvent::ConsoleOutput = event {
this.update_output(window, cx)
}
}),
cx.on_focus(&focus_handle, window, |console, window, cx| {
if console.is_running(cx) {
console.query_bar.focus_handle(cx).focus(window);
@@ -108,14 +102,9 @@ impl Console {
variable_list,
_subscriptions,
stack_frame_list,
update_output_task: None,
update_output_task: Task::ready(()),
last_token: OutputToken(0),
focus_handle,
history: SearchHistory::new(
None,
project::search_history::QueryInsertionBehavior::ReplacePreviousIfContains,
),
cursor: Default::default(),
}
}
@@ -144,116 +133,202 @@ impl Console {
self.session.read(cx).has_new_output(self.last_token)
}
fn add_messages(
pub fn add_messages<'a>(
&mut self,
events: Vec<OutputEvent>,
events: impl Iterator<Item = &'a OutputEvent>,
window: &mut Window,
cx: &mut App,
) -> Task<Result<()>> {
self.console.update(cx, |_, cx| {
cx.spawn_in(window, async move |console, cx| {
let mut len = console.update(cx, |this, cx| this.buffer().read(cx).len(cx))?;
let (output, spans, background_spans) = cx
.background_spawn(async move {
let mut all_spans = Vec::new();
let mut all_background_spans = Vec::new();
let mut to_insert = String::new();
let mut scratch = String::new();
) {
self.console.update(cx, |console, cx| {
console.set_read_only(false);
for event in &events {
scratch.clear();
let mut ansi_handler = ConsoleHandler::default();
let mut ansi_processor =
ansi::Processor::<ansi::StdSyncHandler>::default();
for event in events {
let to_insert = format!("{}\n", event.output.trim_end());
let trimmed_output = event.output.trim_end();
let _ = writeln!(&mut scratch, "{trimmed_output}");
ansi_processor.advance(&mut ansi_handler, scratch.as_bytes());
let output = std::mem::take(&mut ansi_handler.output);
to_insert.extend(output.chars());
let mut spans = std::mem::take(&mut ansi_handler.spans);
let mut background_spans =
std::mem::take(&mut ansi_handler.background_spans);
if ansi_handler.current_range_start < output.len() {
spans.push((
ansi_handler.current_range_start..output.len(),
ansi_handler.current_color,
));
let mut ansi_handler = ConsoleHandler::default();
let mut ansi_processor = ansi::Processor::<ansi::StdSyncHandler>::default();
let len = console.buffer().read(cx).len(cx);
ansi_processor.advance(&mut ansi_handler, to_insert.as_bytes());
let output = std::mem::take(&mut ansi_handler.output);
let mut spans = std::mem::take(&mut ansi_handler.spans);
let mut background_spans = std::mem::take(&mut ansi_handler.background_spans);
if ansi_handler.current_range_start < output.len() {
spans.push((
ansi_handler.current_range_start..output.len(),
ansi_handler.current_color,
));
}
if ansi_handler.current_background_range_start < output.len() {
background_spans.push((
ansi_handler.current_background_range_start..output.len(),
ansi_handler.current_background_color,
));
}
console.move_to_end(&editor::actions::MoveToEnd, window, cx);
console.insert(&output, window, cx);
let buffer = console.buffer().read(cx).snapshot(cx);
struct ConsoleAnsiHighlight;
for (range, color) in spans {
let Some(color) = color else { continue };
let start_offset = len + range.start;
let range = start_offset..len + range.end;
let range = buffer.anchor_after(range.start)..buffer.anchor_before(range.end);
let style = HighlightStyle {
color: Some(terminal_view::terminal_element::convert_color(
&color,
cx.theme(),
)),
..Default::default()
};
console.highlight_text_key::<ConsoleAnsiHighlight>(
start_offset,
vec![range],
style,
cx,
);
}
for (range, color) in background_spans {
let Some(color) = color else { continue };
let start_offset = len + range.start;
let range = start_offset..len + range.end;
let range = buffer.anchor_after(range.start)..buffer.anchor_before(range.end);
let color_fetcher: fn(&Theme) -> Hsla = match color {
// Named and theme defined colors
ansi::Color::Named(n) => match n {
ansi::NamedColor::Black => |theme| theme.colors().terminal_ansi_black,
ansi::NamedColor::Red => |theme| theme.colors().terminal_ansi_red,
ansi::NamedColor::Green => |theme| theme.colors().terminal_ansi_green,
ansi::NamedColor::Yellow => |theme| theme.colors().terminal_ansi_yellow,
ansi::NamedColor::Blue => |theme| theme.colors().terminal_ansi_blue,
ansi::NamedColor::Magenta => {
|theme| theme.colors().terminal_ansi_magenta
}
if ansi_handler.current_background_range_start < output.len() {
background_spans.push((
ansi_handler.current_background_range_start..output.len(),
ansi_handler.current_background_color,
));
ansi::NamedColor::Cyan => |theme| theme.colors().terminal_ansi_cyan,
ansi::NamedColor::White => |theme| theme.colors().terminal_ansi_white,
ansi::NamedColor::BrightBlack => {
|theme| theme.colors().terminal_ansi_bright_black
}
for (range, _) in spans.iter_mut() {
let start_offset = len + range.start;
*range = start_offset..len + range.end;
ansi::NamedColor::BrightRed => {
|theme| theme.colors().terminal_ansi_bright_red
}
for (range, _) in background_spans.iter_mut() {
let start_offset = len + range.start;
*range = start_offset..len + range.end;
ansi::NamedColor::BrightGreen => {
|theme| theme.colors().terminal_ansi_bright_green
}
ansi::NamedColor::BrightYellow => {
|theme| theme.colors().terminal_ansi_bright_yellow
}
ansi::NamedColor::BrightBlue => {
|theme| theme.colors().terminal_ansi_bright_blue
}
ansi::NamedColor::BrightMagenta => {
|theme| theme.colors().terminal_ansi_bright_magenta
}
ansi::NamedColor::BrightCyan => {
|theme| theme.colors().terminal_ansi_bright_cyan
}
ansi::NamedColor::BrightWhite => {
|theme| theme.colors().terminal_ansi_bright_white
}
ansi::NamedColor::Foreground => {
|theme| theme.colors().terminal_foreground
}
ansi::NamedColor::Background => {
|theme| theme.colors().terminal_background
}
ansi::NamedColor::Cursor => |theme| theme.players().local().cursor,
ansi::NamedColor::DimBlack => {
|theme| theme.colors().terminal_ansi_dim_black
}
ansi::NamedColor::DimRed => {
|theme| theme.colors().terminal_ansi_dim_red
}
ansi::NamedColor::DimGreen => {
|theme| theme.colors().terminal_ansi_dim_green
}
ansi::NamedColor::DimYellow => {
|theme| theme.colors().terminal_ansi_dim_yellow
}
ansi::NamedColor::DimBlue => {
|theme| theme.colors().terminal_ansi_dim_blue
}
ansi::NamedColor::DimMagenta => {
|theme| theme.colors().terminal_ansi_dim_magenta
}
ansi::NamedColor::DimCyan => {
|theme| theme.colors().terminal_ansi_dim_cyan
}
ansi::NamedColor::DimWhite => {
|theme| theme.colors().terminal_ansi_dim_white
}
ansi::NamedColor::BrightForeground => {
|theme| theme.colors().terminal_bright_foreground
}
ansi::NamedColor::DimForeground => {
|theme| theme.colors().terminal_dim_foreground
}
},
// 'True' colors
ansi::Color::Spec(_) => |theme| theme.colors().editor_background,
// 8 bit, indexed colors
ansi::Color::Indexed(i) => {
match i {
// 0-15 are the same as the named colors above
0 => |theme| theme.colors().terminal_ansi_black,
1 => |theme| theme.colors().terminal_ansi_red,
2 => |theme| theme.colors().terminal_ansi_green,
3 => |theme| theme.colors().terminal_ansi_yellow,
4 => |theme| theme.colors().terminal_ansi_blue,
5 => |theme| theme.colors().terminal_ansi_magenta,
6 => |theme| theme.colors().terminal_ansi_cyan,
7 => |theme| theme.colors().terminal_ansi_white,
8 => |theme| theme.colors().terminal_ansi_bright_black,
9 => |theme| theme.colors().terminal_ansi_bright_red,
10 => |theme| theme.colors().terminal_ansi_bright_green,
11 => |theme| theme.colors().terminal_ansi_bright_yellow,
12 => |theme| theme.colors().terminal_ansi_bright_blue,
13 => |theme| theme.colors().terminal_ansi_bright_magenta,
14 => |theme| theme.colors().terminal_ansi_bright_cyan,
15 => |theme| theme.colors().terminal_ansi_bright_white,
// 16-231 are a 6x6x6 RGB color cube, mapped to 0-255 using steps defined by XTerm.
// See: https://github.com/xterm-x11/xterm-snapshots/blob/master/256colres.pl
// 16..=231 => {
// let (r, g, b) = rgb_for_index(index as u8);
// rgba_color(
// if r == 0 { 0 } else { r * 40 + 55 },
// if g == 0 { 0 } else { g * 40 + 55 },
// if b == 0 { 0 } else { b * 40 + 55 },
// )
// }
// 232-255 are a 24-step grayscale ramp from (8, 8, 8) to (238, 238, 238).
// 232..=255 => {
// let i = index as u8 - 232; // Align index to 0..24
// let value = i * 10 + 8;
// rgba_color(value, value, value)
// }
// For compatibility with the alacritty::Colors interface
// See: https://github.com/alacritty/alacritty/blob/master/alacritty_terminal/src/term/color.rs
_ => |_| gpui::black(),
}
len += output.len();
all_spans.extend(spans);
all_background_spans.extend(background_spans);
}
(to_insert, all_spans, all_background_spans)
})
.await;
console.update_in(cx, |console, window, cx| {
console.set_read_only(false);
console.move_to_end(&editor::actions::MoveToEnd, window, cx);
console.insert(&output, window, cx);
console.set_read_only(true);
};
struct ConsoleAnsiHighlight;
console.highlight_background_key::<ConsoleAnsiHighlight>(
start_offset,
&[range],
color_fetcher,
cx,
);
}
}
let buffer = console.buffer().read(cx).snapshot(cx);
for (range, color) in spans {
let Some(color) = color else { continue };
let start_offset = range.start;
let range =
buffer.anchor_after(range.start)..buffer.anchor_before(range.end);
let style = HighlightStyle {
color: Some(terminal_view::terminal_element::convert_color(
&color,
cx.theme(),
)),
..Default::default()
};
console.highlight_text_key::<ConsoleAnsiHighlight>(
start_offset,
vec![range],
style,
cx,
);
}
for (range, color) in background_spans {
let Some(color) = color else { continue };
let start_offset = range.start;
let range =
buffer.anchor_after(range.start)..buffer.anchor_before(range.end);
console.highlight_background_key::<ConsoleAnsiHighlight>(
start_offset,
&[range],
color_fetcher(color),
cx,
);
}
cx.notify();
})?;
Ok(())
})
})
console.set_read_only(true);
cx.notify();
});
}
pub fn watch_expression(
@@ -270,8 +345,7 @@ impl Console {
expression
});
self.history.add(&mut self.cursor, expression.clone());
self.cursor.reset();
self.session.update(cx, |session, cx| {
session
.evaluate(
@@ -291,28 +365,7 @@ impl Console {
});
}
fn previous_query(&mut self, _: &SelectPrevious, window: &mut Window, cx: &mut Context<Self>) {
let prev = self.history.previous(&mut self.cursor);
if let Some(prev) = prev {
self.query_bar.update(cx, |editor, cx| {
editor.set_text(prev, window, cx);
});
}
}
fn next_query(&mut self, _: &SelectNext, window: &mut Window, cx: &mut Context<Self>) {
let next = self.history.next(&mut self.cursor);
let query = next.unwrap_or_else(|| {
self.cursor.reset();
""
});
self.query_bar.update(cx, |editor, cx| {
editor.set_text(query, window, cx);
});
}
fn evaluate(&mut self, _: &Confirm, window: &mut Window, cx: &mut Context<Self>) {
pub fn evaluate(&mut self, _: &Confirm, window: &mut Window, cx: &mut Context<Self>) {
let expression = self.query_bar.update(cx, |editor, cx| {
let expression = editor.text(cx);
cx.defer_in(window, |editor, window, cx| {
@@ -322,8 +375,6 @@ impl Console {
expression
});
self.history.add(&mut self.cursor, expression.clone());
self.cursor.reset();
self.session.update(cx, |session, cx| {
session
.evaluate(
@@ -407,50 +458,31 @@ impl Console {
EditorElement::new(&self.query_bar, Self::editor_style(&self.query_bar, cx))
}
pub(crate) fn update_output(&mut self, window: &mut Window, cx: &mut Context<Self>) {
if self.update_output_task.is_some() {
return;
}
fn update_output(&mut self, window: &mut Window, cx: &mut Context<Self>) {
let session = self.session.clone();
let token = self.last_token;
self.update_output_task = Some(cx.spawn_in(window, async move |this, cx| {
let Some((last_processed_token, task)) = session
.update_in(cx, |session, window, cx| {
let (output, last_processed_token) = session.output(token);
this.update(cx, |this, cx| {
if last_processed_token == this.last_token {
return None;
}
Some((
last_processed_token,
this.add_messages(output.cloned().collect(), window, cx),
))
})
.ok()
.flatten()
})
.ok()
.flatten()
else {
_ = this.update(cx, |this, _| {
this.update_output_task.take();
self.update_output_task = cx.spawn_in(window, async move |this, cx| {
_ = session.update_in(cx, move |session, window, cx| {
let (output, last_processed_token) = session.output(token);
_ = this.update(cx, |this, cx| {
if last_processed_token == this.last_token {
return;
}
this.add_messages(output, window, cx);
this.last_token = last_processed_token;
});
return;
};
_ = task.await.log_err();
_ = this.update(cx, |this, _| {
this.last_token = last_processed_token;
this.update_output_task.take();
});
}));
});
}
}
impl Render for Console {
fn render(&mut self, window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
let query_focus_handle = self.query_bar.focus_handle(cx);
self.update_output(window, cx);
v_flex()
.track_focus(&self.focus_handle)
.key_context("DebugConsole")
@@ -461,8 +493,6 @@ impl Render for Console {
.when(self.is_running(cx), |this| {
this.child(Divider::horizontal()).child(
h_flex()
.on_action(cx.listener(Self::previous_query))
.on_action(cx.listener(Self::next_query))
.gap_1()
.bg(cx.theme().colors().editor_background)
.child(self.render_query_bar(cx))
@@ -815,84 +845,3 @@ impl ansi::Handler for ConsoleHandler {
}
}
}
fn color_fetcher(color: ansi::Color) -> fn(&Theme) -> Hsla {
let color_fetcher: fn(&Theme) -> Hsla = match color {
// Named and theme defined colors
ansi::Color::Named(n) => match n {
ansi::NamedColor::Black => |theme| theme.colors().terminal_ansi_black,
ansi::NamedColor::Red => |theme| theme.colors().terminal_ansi_red,
ansi::NamedColor::Green => |theme| theme.colors().terminal_ansi_green,
ansi::NamedColor::Yellow => |theme| theme.colors().terminal_ansi_yellow,
ansi::NamedColor::Blue => |theme| theme.colors().terminal_ansi_blue,
ansi::NamedColor::Magenta => |theme| theme.colors().terminal_ansi_magenta,
ansi::NamedColor::Cyan => |theme| theme.colors().terminal_ansi_cyan,
ansi::NamedColor::White => |theme| theme.colors().terminal_ansi_white,
ansi::NamedColor::BrightBlack => |theme| theme.colors().terminal_ansi_bright_black,
ansi::NamedColor::BrightRed => |theme| theme.colors().terminal_ansi_bright_red,
ansi::NamedColor::BrightGreen => |theme| theme.colors().terminal_ansi_bright_green,
ansi::NamedColor::BrightYellow => |theme| theme.colors().terminal_ansi_bright_yellow,
ansi::NamedColor::BrightBlue => |theme| theme.colors().terminal_ansi_bright_blue,
ansi::NamedColor::BrightMagenta => |theme| theme.colors().terminal_ansi_bright_magenta,
ansi::NamedColor::BrightCyan => |theme| theme.colors().terminal_ansi_bright_cyan,
ansi::NamedColor::BrightWhite => |theme| theme.colors().terminal_ansi_bright_white,
ansi::NamedColor::Foreground => |theme| theme.colors().terminal_foreground,
ansi::NamedColor::Background => |theme| theme.colors().terminal_background,
ansi::NamedColor::Cursor => |theme| theme.players().local().cursor,
ansi::NamedColor::DimBlack => |theme| theme.colors().terminal_ansi_dim_black,
ansi::NamedColor::DimRed => |theme| theme.colors().terminal_ansi_dim_red,
ansi::NamedColor::DimGreen => |theme| theme.colors().terminal_ansi_dim_green,
ansi::NamedColor::DimYellow => |theme| theme.colors().terminal_ansi_dim_yellow,
ansi::NamedColor::DimBlue => |theme| theme.colors().terminal_ansi_dim_blue,
ansi::NamedColor::DimMagenta => |theme| theme.colors().terminal_ansi_dim_magenta,
ansi::NamedColor::DimCyan => |theme| theme.colors().terminal_ansi_dim_cyan,
ansi::NamedColor::DimWhite => |theme| theme.colors().terminal_ansi_dim_white,
ansi::NamedColor::BrightForeground => |theme| theme.colors().terminal_bright_foreground,
ansi::NamedColor::DimForeground => |theme| theme.colors().terminal_dim_foreground,
},
// 'True' colors
ansi::Color::Spec(_) => |theme| theme.colors().editor_background,
// 8 bit, indexed colors
ansi::Color::Indexed(i) => {
match i {
// 0-15 are the same as the named colors above
0 => |theme| theme.colors().terminal_ansi_black,
1 => |theme| theme.colors().terminal_ansi_red,
2 => |theme| theme.colors().terminal_ansi_green,
3 => |theme| theme.colors().terminal_ansi_yellow,
4 => |theme| theme.colors().terminal_ansi_blue,
5 => |theme| theme.colors().terminal_ansi_magenta,
6 => |theme| theme.colors().terminal_ansi_cyan,
7 => |theme| theme.colors().terminal_ansi_white,
8 => |theme| theme.colors().terminal_ansi_bright_black,
9 => |theme| theme.colors().terminal_ansi_bright_red,
10 => |theme| theme.colors().terminal_ansi_bright_green,
11 => |theme| theme.colors().terminal_ansi_bright_yellow,
12 => |theme| theme.colors().terminal_ansi_bright_blue,
13 => |theme| theme.colors().terminal_ansi_bright_magenta,
14 => |theme| theme.colors().terminal_ansi_bright_cyan,
15 => |theme| theme.colors().terminal_ansi_bright_white,
// 16-231 are a 6x6x6 RGB color cube, mapped to 0-255 using steps defined by XTerm.
// See: https://github.com/xterm-x11/xterm-snapshots/blob/master/256colres.pl
// 16..=231 => {
// let (r, g, b) = rgb_for_index(index as u8);
// rgba_color(
// if r == 0 { 0 } else { r * 40 + 55 },
// if g == 0 { 0 } else { g * 40 + 55 },
// if b == 0 { 0 } else { b * 40 + 55 },
// )
// }
// 232-255 are a 24-step grayscale ramp from (8, 8, 8) to (238, 238, 238).
// 232..=255 => {
// let i = index as u8 - 232; // Align index to 0..24
// let value = i * 10 + 8;
// rgba_color(value, value, value)
// }
// For compatibility with the alacritty::Colors interface
// See: https://github.com/alacritty/alacritty/blob/master/alacritty_terminal/src/term/color.rs
_ => |_| gpui::black(),
}
}
};
color_fetcher
}

Some files were not shown because too many files have changed in this diff Show More