Compare commits
52 Commits
keymap_edi
...
follow-age
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
7ab62666c3 | ||
|
|
35539847a4 | ||
|
|
f619d5f02a | ||
|
|
ba59305510 | ||
|
|
672a1dd553 | ||
|
|
93cc4946d8 | ||
|
|
0c0a4ed866 | ||
|
|
51f1998107 | ||
|
|
1ffedf4a08 | ||
|
|
d25da9728b | ||
|
|
e1e3f2e423 | ||
|
|
92b9ecd7d2 | ||
|
|
758d260cec | ||
|
|
8d4d3badf3 | ||
|
|
7c23d13773 | ||
|
|
ad87c545c7 | ||
|
|
23fbab15ee | ||
|
|
d7e181576e | ||
|
|
9788aff4b1 | ||
|
|
2a319efade | ||
|
|
50ec26c163 | ||
|
|
39dd133b1c | ||
|
|
24eb039752 | ||
|
|
bffa53d706 | ||
|
|
0e5e8f9f8d | ||
|
|
96d785cb45 | ||
|
|
57610c9935 | ||
|
|
5bf1b4f0a8 | ||
|
|
f891dfb358 | ||
|
|
e3a2d52472 | ||
|
|
122af4fd53 | ||
|
|
e07ffe7cf1 | ||
|
|
5e4be013af | ||
|
|
f055dca592 | ||
|
|
5872276511 | ||
|
|
1bf9e15f26 | ||
|
|
f046d70625 | ||
|
|
afeb3d4fd9 | ||
|
|
92dd6b67c7 | ||
|
|
38ede4bae3 | ||
|
|
fc920bf63d | ||
|
|
04c68dc0cf | ||
|
|
399eced884 | ||
|
|
50f705e779 | ||
|
|
8173534ad5 | ||
|
|
8c03934b26 | ||
|
|
84e4891d54 | ||
|
|
d03d8ccec1 | ||
|
|
4d934f2884 | ||
|
|
e697cf9747 | ||
|
|
e4e455ae6b | ||
|
|
ff4900c942 |
2
.github/workflows/eval.yml
vendored
2
.github/workflows/eval.yml
vendored
@@ -69,7 +69,7 @@ jobs:
|
||||
run: cargo build --package=eval
|
||||
|
||||
- name: Run eval
|
||||
run: cargo run --package=eval -- --repetitions=3 --concurrency=1
|
||||
run: cargo run --package=eval -- --repetitions=8 --concurrency=1
|
||||
|
||||
# Even the Linux runner is not stateful, in theory there is no need to do this cleanup.
|
||||
# But, to avoid potential issues in the future if we choose to use a stateful Linux runner and forget to add code
|
||||
|
||||
@@ -46,5 +46,17 @@
|
||||
"formatter": "auto",
|
||||
"remove_trailing_whitespace_on_save": true,
|
||||
"ensure_final_newline_on_save": true,
|
||||
"file_scan_exclusions": ["crates/eval/worktrees/", "crates/eval/repos/"]
|
||||
"file_scan_exclusions": [
|
||||
"crates/eval/worktrees/",
|
||||
"crates/eval/repos/",
|
||||
"**/.git",
|
||||
"**/.svn",
|
||||
"**/.hg",
|
||||
"**/.jj",
|
||||
"**/CVS",
|
||||
"**/.DS_Store",
|
||||
"**/Thumbs.db",
|
||||
"**/.classpath",
|
||||
"**/.settings"
|
||||
]
|
||||
}
|
||||
|
||||
178
Cargo.lock
generated
178
Cargo.lock
generated
@@ -68,6 +68,7 @@ dependencies = [
|
||||
"convert_case 0.8.0",
|
||||
"db",
|
||||
"editor",
|
||||
"extension",
|
||||
"feature_flags",
|
||||
"file_icons",
|
||||
"fs",
|
||||
@@ -81,6 +82,7 @@ dependencies = [
|
||||
"indexmap",
|
||||
"indoc",
|
||||
"itertools 0.14.0",
|
||||
"jsonschema",
|
||||
"language",
|
||||
"language_model",
|
||||
"language_model_selector",
|
||||
@@ -90,6 +92,7 @@ dependencies = [
|
||||
"markdown",
|
||||
"menu",
|
||||
"multi_buffer",
|
||||
"notifications",
|
||||
"ordered-float 2.10.1",
|
||||
"parking_lot",
|
||||
"paths",
|
||||
@@ -106,6 +109,7 @@ dependencies = [
|
||||
"schemars",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"serde_json_lenient",
|
||||
"settings",
|
||||
"smallvec",
|
||||
"smol",
|
||||
@@ -148,7 +152,9 @@ checksum = "e89da841a80418a9b391ebaea17f5c112ffaaa96f621d2c285b5174da76b9011"
|
||||
dependencies = [
|
||||
"cfg-if",
|
||||
"const-random",
|
||||
"getrandom 0.2.15",
|
||||
"once_cell",
|
||||
"serde",
|
||||
"version_check",
|
||||
"zerocopy 0.7.35",
|
||||
]
|
||||
@@ -690,6 +696,7 @@ dependencies = [
|
||||
"pretty_assertions",
|
||||
"project",
|
||||
"rand 0.8.5",
|
||||
"regex",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"settings",
|
||||
@@ -703,7 +710,9 @@ dependencies = [
|
||||
name = "assistant_tools"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"aho-corasick",
|
||||
"anyhow",
|
||||
"assistant_settings",
|
||||
"assistant_tool",
|
||||
"buffer_diff",
|
||||
"chrono",
|
||||
@@ -711,26 +720,38 @@ dependencies = [
|
||||
"clock",
|
||||
"collections",
|
||||
"component",
|
||||
"derive_more",
|
||||
"editor",
|
||||
"feature_flags",
|
||||
"fs",
|
||||
"futures 0.3.31",
|
||||
"gpui",
|
||||
"gpui_tokio",
|
||||
"handlebars 4.5.0",
|
||||
"html_to_markdown",
|
||||
"http_client",
|
||||
"indoc",
|
||||
"itertools 0.14.0",
|
||||
"language",
|
||||
"language_model",
|
||||
"language_models",
|
||||
"linkme",
|
||||
"open",
|
||||
"pretty_assertions",
|
||||
"project",
|
||||
"rand 0.8.5",
|
||||
"regex",
|
||||
"reqwest_client",
|
||||
"rust-embed",
|
||||
"schemars",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"settings",
|
||||
"smallvec",
|
||||
"streaming_diff",
|
||||
"strsim",
|
||||
"task",
|
||||
"tempfile",
|
||||
"terminal",
|
||||
"terminal_view",
|
||||
"tree-sitter-rust",
|
||||
@@ -2171,6 +2192,12 @@ dependencies = [
|
||||
"piper",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "borrow-or-share"
|
||||
version = "0.2.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "3eeab4423108c5d7c744f4d234de88d18d636100093ae04caf4825134b9c3a32"
|
||||
|
||||
[[package]]
|
||||
name = "borsh"
|
||||
version = "1.5.7"
|
||||
@@ -2286,6 +2313,12 @@ dependencies = [
|
||||
"syn 1.0.109",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "bytecount"
|
||||
version = "0.6.8"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "5ce89b21cab1437276d2650d57e971f9d548a2d9037cc231abdc0562b97498ce"
|
||||
|
||||
[[package]]
|
||||
name = "bytemuck"
|
||||
version = "1.22.0"
|
||||
@@ -3206,7 +3239,9 @@ dependencies = [
|
||||
name = "component_preview"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"agent",
|
||||
"anyhow",
|
||||
"assistant_tool",
|
||||
"client",
|
||||
"collections",
|
||||
"component",
|
||||
@@ -3216,6 +3251,7 @@ dependencies = [
|
||||
"log",
|
||||
"notifications",
|
||||
"project",
|
||||
"prompt_store",
|
||||
"serde",
|
||||
"ui",
|
||||
"ui_input",
|
||||
@@ -4369,6 +4405,7 @@ dependencies = [
|
||||
"ctor",
|
||||
"editor",
|
||||
"env_logger 0.11.8",
|
||||
"futures 0.3.31",
|
||||
"gpui",
|
||||
"indoc",
|
||||
"language",
|
||||
@@ -4764,6 +4801,15 @@ dependencies = [
|
||||
"zeroize",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "email_address"
|
||||
version = "0.2.9"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e079f19b08ca6239f47f8ba8509c11cf3ea30095831f7fed61441475edd8c449"
|
||||
dependencies = [
|
||||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "embed-resource"
|
||||
version = "3.0.2"
|
||||
@@ -4997,6 +5043,7 @@ dependencies = [
|
||||
"node_runtime",
|
||||
"pathdiff",
|
||||
"paths",
|
||||
"pretty_assertions",
|
||||
"project",
|
||||
"prompt_store",
|
||||
"regex",
|
||||
@@ -5178,6 +5225,7 @@ dependencies = [
|
||||
"collections",
|
||||
"db",
|
||||
"editor",
|
||||
"extension",
|
||||
"extension_host",
|
||||
"fs",
|
||||
"fuzzy",
|
||||
@@ -5410,6 +5458,17 @@ version = "1.0.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8bf7cc16383c4b8d58b9905a8509f02926ce3058053c056376248d958c9df1e8"
|
||||
|
||||
[[package]]
|
||||
name = "fluent-uri"
|
||||
version = "0.3.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1918b65d96df47d3591bed19c5cca17e3fa5d0707318e4b5ef2eae01764df7e5"
|
||||
dependencies = [
|
||||
"borrow-or-share",
|
||||
"ref-cast",
|
||||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "flume"
|
||||
version = "0.11.1"
|
||||
@@ -5564,6 +5623,16 @@ dependencies = [
|
||||
"percent-encoding",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "fraction"
|
||||
version = "0.15.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0f158e3ff0a1b334408dc9fb811cd99b446986f4d8b741bb08f9df1604085ae7"
|
||||
dependencies = [
|
||||
"lazy_static",
|
||||
"num",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "freetype-sys"
|
||||
version = "0.20.1"
|
||||
@@ -6361,6 +6430,7 @@ dependencies = [
|
||||
"log",
|
||||
"pest",
|
||||
"pest_derive",
|
||||
"rust-embed",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"thiserror 1.0.69",
|
||||
@@ -7566,6 +7636,33 @@ dependencies = [
|
||||
"wasm-bindgen",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "jsonschema"
|
||||
version = "0.30.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f1b46a0365a611fbf1d2143104dcf910aada96fafd295bab16c60b802bf6fa1d"
|
||||
dependencies = [
|
||||
"ahash 0.8.11",
|
||||
"base64 0.22.1",
|
||||
"bytecount",
|
||||
"email_address",
|
||||
"fancy-regex 0.14.0",
|
||||
"fraction",
|
||||
"idna",
|
||||
"itoa",
|
||||
"num-cmp",
|
||||
"num-traits",
|
||||
"once_cell",
|
||||
"percent-encoding",
|
||||
"referencing",
|
||||
"regex",
|
||||
"regex-syntax 0.8.5",
|
||||
"reqwest 0.12.15 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"uuid-simd",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "jsonwebtoken"
|
||||
version = "9.3.1"
|
||||
@@ -8250,7 +8347,7 @@ dependencies = [
|
||||
"prost 0.9.0",
|
||||
"prost-build 0.9.0",
|
||||
"prost-types 0.9.0",
|
||||
"reqwest 0.12.15",
|
||||
"reqwest 0.12.15 (git+https://github.com/zed-industries/reqwest.git?rev=951c770a32f1998d6e999cef3e59e0013e6c4415)",
|
||||
"serde",
|
||||
"workspace-hack",
|
||||
]
|
||||
@@ -9160,6 +9257,12 @@ dependencies = [
|
||||
"zeroize",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "num-cmp"
|
||||
version = "0.1.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "63335b2e2c34fae2fb0aa2cecfd9f0832a1e24b3b32ecec612c3426d46dc8aaa"
|
||||
|
||||
[[package]]
|
||||
name = "num-complex"
|
||||
version = "0.4.6"
|
||||
@@ -11753,6 +11856,20 @@ dependencies = [
|
||||
"syn 2.0.100",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "referencing"
|
||||
version = "0.30.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c8eff4fa778b5c2a57e85c5f2fe3a709c52f0e60d23146e2151cbef5893f420e"
|
||||
dependencies = [
|
||||
"ahash 0.8.11",
|
||||
"fluent-uri",
|
||||
"once_cell",
|
||||
"parking_lot",
|
||||
"percent-encoding",
|
||||
"serde_json",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "refineable"
|
||||
version = "0.1.0"
|
||||
@@ -12022,6 +12139,43 @@ dependencies = [
|
||||
"winreg 0.50.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "reqwest"
|
||||
version = "0.12.15"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d19c46a6fdd48bc4dab94b6103fccc55d34c67cc0ad04653aad4ea2a07cd7bbb"
|
||||
dependencies = [
|
||||
"base64 0.22.1",
|
||||
"bytes 1.10.1",
|
||||
"futures-channel",
|
||||
"futures-core",
|
||||
"futures-util",
|
||||
"http 1.3.1",
|
||||
"http-body 1.0.1",
|
||||
"http-body-util",
|
||||
"hyper 1.6.0",
|
||||
"hyper-util",
|
||||
"ipnet",
|
||||
"js-sys",
|
||||
"log",
|
||||
"mime",
|
||||
"once_cell",
|
||||
"percent-encoding",
|
||||
"pin-project-lite",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"serde_urlencoded",
|
||||
"sync_wrapper 1.0.2",
|
||||
"tokio",
|
||||
"tower 0.5.2",
|
||||
"tower-service",
|
||||
"url",
|
||||
"wasm-bindgen",
|
||||
"wasm-bindgen-futures",
|
||||
"web-sys",
|
||||
"windows-registry 0.4.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "reqwest"
|
||||
version = "0.12.15"
|
||||
@@ -12082,7 +12236,7 @@ dependencies = [
|
||||
"http_client_tls",
|
||||
"log",
|
||||
"regex",
|
||||
"reqwest 0.12.15",
|
||||
"reqwest 0.12.15 (git+https://github.com/zed-industries/reqwest.git?rev=951c770a32f1998d6e999cef3e59e0013e6c4415)",
|
||||
"serde",
|
||||
"smol",
|
||||
"tokio",
|
||||
@@ -15933,6 +16087,17 @@ dependencies = [
|
||||
"sha1_smol",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "uuid-simd"
|
||||
version = "0.8.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "23b082222b4f6619906941c17eb2297fff4c2fb96cb60164170522942a200bd8"
|
||||
dependencies = [
|
||||
"outref",
|
||||
"uuid",
|
||||
"vsimd",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "v_frame"
|
||||
version = "0.3.8"
|
||||
@@ -16878,18 +17043,22 @@ version = "0.1.0"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"client",
|
||||
"component",
|
||||
"db",
|
||||
"documented",
|
||||
"editor",
|
||||
"fuzzy",
|
||||
"gpui",
|
||||
"install_cli",
|
||||
"language",
|
||||
"linkme",
|
||||
"picker",
|
||||
"project",
|
||||
"schemars",
|
||||
"serde",
|
||||
"settings",
|
||||
"telemetry",
|
||||
"theme",
|
||||
"ui",
|
||||
"util",
|
||||
"vim_mode_setting",
|
||||
@@ -18022,12 +18191,14 @@ dependencies = [
|
||||
"getrandom 0.2.15",
|
||||
"getrandom 0.3.2",
|
||||
"gimli",
|
||||
"handlebars 4.5.0",
|
||||
"hashbrown 0.14.5",
|
||||
"hashbrown 0.15.2",
|
||||
"heck 0.4.1",
|
||||
"hmac",
|
||||
"hyper 0.14.32",
|
||||
"hyper-rustls 0.27.5",
|
||||
"idna",
|
||||
"indexmap",
|
||||
"inout",
|
||||
"itertools 0.12.1",
|
||||
@@ -18051,6 +18222,7 @@ dependencies = [
|
||||
"num-bigint-dig",
|
||||
"num-integer",
|
||||
"num-iter",
|
||||
"num-rational",
|
||||
"num-traits",
|
||||
"object",
|
||||
"once_cell",
|
||||
@@ -18465,7 +18637,7 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "zed"
|
||||
version = "0.185.0"
|
||||
version = "0.186.0"
|
||||
dependencies = [
|
||||
"activity_indicator",
|
||||
"agent",
|
||||
|
||||
@@ -435,6 +435,7 @@ dap-types = { git = "https://github.com/zed-industries/dap-types", rev = "be69a0
|
||||
dashmap = "6.0"
|
||||
derive_more = "0.99.17"
|
||||
dirs = "4.0"
|
||||
documented = "0.9.1"
|
||||
dotenv = "0.15.0"
|
||||
ec4rs = "1.1"
|
||||
emojis = "0.6.1"
|
||||
@@ -461,6 +462,7 @@ indexmap = { version = "2.7.0", features = ["serde"] }
|
||||
indoc = "2"
|
||||
inventory = "0.3.19"
|
||||
itertools = "0.14.0"
|
||||
jsonschema = "0.30.0"
|
||||
jsonwebtoken = "9.3"
|
||||
jupyter-protocol = { git = "https://github.com/ConradIrwin/runtimed", rev = "7130c804216b6914355d15d0b91ea91f6babd734" }
|
||||
jupyter-websocket-client = { git = "https://github.com/ConradIrwin/runtimed" ,rev = "7130c804216b6914355d15d0b91ea91f6babd734" }
|
||||
@@ -797,5 +799,6 @@ ignored = [
|
||||
"serde",
|
||||
"component",
|
||||
"linkme",
|
||||
"documented",
|
||||
"workspace-hack",
|
||||
]
|
||||
|
||||
1
assets/icons/hammer.svg
Normal file
1
assets/icons/hammer.svg
Normal file
@@ -0,0 +1 @@
|
||||
<svg xmlns="http://www.w3.org/2000/svg" width="24" height="24" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round" class="lucide lucide-hammer-icon lucide-hammer"><path d="m15 12-8.373 8.373a1 1 0 1 1-3-3L12 9"/><path d="m18 15 4-4"/><path d="m21.5 11.5-1.914-1.914A2 2 0 0 1 19 8.172V7l-2.26-2.26a6 6 0 0 0-4.202-1.756L9 2.96l.92.82A6.18 6.18 0 0 1 12 8.4V10l2 2h1.172a2 2 0 0 1 1.414.586L18.5 14.5"/></svg>
|
||||
|
After Width: | Height: | Size: 475 B |
1
assets/icons/user_check.svg
Normal file
1
assets/icons/user_check.svg
Normal file
@@ -0,0 +1 @@
|
||||
<svg xmlns="http://www.w3.org/2000/svg" width="24" height="24" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round" class="lucide lucide-user-round-check-icon lucide-user-round-check"><path d="M2 21a8 8 0 0 1 13.292-6"/><circle cx="10" cy="8" r="5"/><path d="m16 19 2 2 4-4"/></svg>
|
||||
|
After Width: | Height: | Size: 348 B |
@@ -248,7 +248,6 @@
|
||||
"ctrl-shift-o": "agent::ToggleNavigationMenu",
|
||||
"ctrl-shift-i": "agent::ToggleOptionsMenu",
|
||||
"shift-escape": "agent::ExpandMessageEditor",
|
||||
"ctrl-e": "agent::ChatMode",
|
||||
"ctrl-alt-e": "agent::RemoveAllContext"
|
||||
}
|
||||
},
|
||||
@@ -962,5 +961,20 @@
|
||||
"bindings": {
|
||||
"escape": "menu::Cancel"
|
||||
}
|
||||
},
|
||||
{
|
||||
"context": "ConfigureContextServerModal > Editor",
|
||||
"bindings": {
|
||||
"escape": "menu::Cancel",
|
||||
"enter": "editor::Newline",
|
||||
"ctrl-enter": "menu::Confirm"
|
||||
}
|
||||
},
|
||||
{
|
||||
"context": "Diagnostics",
|
||||
"use_key_equivalents": true,
|
||||
"bindings": {
|
||||
"ctrl-r": "diagnostics::ToggleDiagnosticsRefresh"
|
||||
}
|
||||
}
|
||||
]
|
||||
|
||||
@@ -293,7 +293,6 @@
|
||||
"cmd-shift-o": "agent::ToggleNavigationMenu",
|
||||
"cmd-shift-i": "agent::ToggleOptionsMenu",
|
||||
"shift-escape": "agent::ExpandMessageEditor",
|
||||
"cmd-e": "agent::ChatMode",
|
||||
"cmd-alt-e": "agent::RemoveAllContext"
|
||||
}
|
||||
},
|
||||
@@ -1068,5 +1067,21 @@
|
||||
"bindings": {
|
||||
"escape": "menu::Cancel"
|
||||
}
|
||||
},
|
||||
{
|
||||
"context": "ConfigureContextServerModal > Editor",
|
||||
"use_key_equivalents": true,
|
||||
"bindings": {
|
||||
"escape": "menu::Cancel",
|
||||
"enter": "editor::Newline",
|
||||
"cmd-enter": "menu::Confirm"
|
||||
}
|
||||
},
|
||||
{
|
||||
"context": "Diagnostics",
|
||||
"use_key_equivalents": true,
|
||||
"bindings": {
|
||||
"ctrl-r": "diagnostics::ToggleDiagnosticsRefresh"
|
||||
}
|
||||
}
|
||||
]
|
||||
|
||||
@@ -3,11 +3,12 @@ You are a highly skilled software engineer with extensive knowledge in many prog
|
||||
## Communication
|
||||
|
||||
1. Be conversational but professional.
|
||||
2. Refer to the USER in the second person and yourself in the first person.
|
||||
2. Refer to the user in the second person and yourself in the first person.
|
||||
3. Format your responses in markdown. Use backticks to format file, directory, function, and class names.
|
||||
4. NEVER lie or make things up.
|
||||
5. Refrain from apologizing all the time when results are unexpected. Instead, just try your best to proceed or explain the circumstances to the user without apologizing.
|
||||
|
||||
{{#if has_tools}}
|
||||
## Tool Use
|
||||
|
||||
1. Make sure to adhere to the tools schema.
|
||||
@@ -22,6 +23,7 @@ You are a highly skilled software engineer with extensive knowledge in many prog
|
||||
If you are unsure how to fulfill the user's request, gather more information with tool calls and/or clarifying questions.
|
||||
|
||||
{{! TODO: If there are files, we should mention it but otherwise omit that fact }}
|
||||
{{#if has_tools}}
|
||||
If appropriate, use tool calls to explore the current project, which contains the following root directories:
|
||||
|
||||
{{#each worktrees}}
|
||||
@@ -36,6 +38,14 @@ If appropriate, use tool calls to explore the current project, which contains th
|
||||
- As you learn about the structure of the project, use that information to scope `grep` searches to targeted subtrees of the project.
|
||||
- The user might specify a partial file path. If you don't know the full path, use `find_path` (not `grep`) before you read the file.
|
||||
{{/if}}
|
||||
{{/if}}
|
||||
{{else}}
|
||||
You are being tasked with providing a response, but you have no ability to use tools or to read or write any aspect of the user's system (other than any context the user might have provided to you).
|
||||
|
||||
As such, if you need the user to perform any actions for you, you must request them explicitly. Bias towards giving a response to the best of your ability, and then making requests for the user to take action (e.g. to give you more context) only optionally.
|
||||
|
||||
The one exception to this is if the user references something you don't know about - for example, the name of a source code file, function, type, or other piece of code that you have no awareness of. In this case, you MUST NOT MAKE SOMETHING UP, or assume you know what that thing is or how it works. Instead, you must ask the user for clarification rather than giving a response.
|
||||
{{/if}}
|
||||
|
||||
## Code Block Formatting
|
||||
|
||||
@@ -111,6 +121,8 @@ In Markdown, hash marks signify headings. For example:
|
||||
```
|
||||
</bad_example_do_not_do_this>
|
||||
This example is unacceptable because the path is in the wrong place. The path must be directly after the opening backticks.
|
||||
|
||||
{{#if has_tools}}
|
||||
## Fixing Diagnostics
|
||||
|
||||
1. Make 1-2 attempts at fixing diagnostics, then defer to the user.
|
||||
@@ -124,10 +136,11 @@ Otherwise, follow debugging best practices:
|
||||
2. Add descriptive logging statements and error messages to track variable and code state.
|
||||
3. Add test functions and statements to isolate the problem.
|
||||
|
||||
{{/if}}
|
||||
## Calling External APIs
|
||||
|
||||
1. Unless explicitly requested by the user, use the best suited external APIs and packages to solve the task. There is no need to ask the user for permission.
|
||||
2. When selecting which version of an API or package to use, choose one that is compatible with the user's dependency management file. If no such file exists or if the package is not present, use the latest version that is in your training data.
|
||||
2. When selecting which version of an API or package to use, choose one that is compatible with the user's dependency management file(s). If no such file exists or if the package is not present, use the latest version that is in your training data.
|
||||
3. If an external API requires an API Key, be sure to point this out to the user. Adhere to best security practices (e.g. DO NOT hardcode an API key in a place where it can be exposed)
|
||||
|
||||
## System Information
|
||||
@@ -135,10 +148,10 @@ Otherwise, follow debugging best practices:
|
||||
Operating System: {{os}}
|
||||
Default Shell: {{shell}}
|
||||
|
||||
{{#if (or has_rules has_default_user_rules)}}
|
||||
{{#if (or has_rules has_user_rules)}}
|
||||
## User's Custom Instructions
|
||||
|
||||
The following additional instructions are provided by the user, and should be followed to the best of your ability without interfering with the tool use guidelines.
|
||||
The following additional instructions are provided by the user, and should be followed to the best of your ability{{#if has_tools}} without interfering with the tool use guidelines{{/if}}.
|
||||
|
||||
{{#if has_rules}}
|
||||
There are project rules that apply to these root directories:
|
||||
|
||||
@@ -657,6 +657,8 @@
|
||||
},
|
||||
// When enabled, the agent can run potentially destructive actions without asking for your confirmation.
|
||||
"always_allow_tool_actions": false,
|
||||
// When enabled, the agent will stream edits.
|
||||
"stream_edits": false,
|
||||
"default_profile": "write",
|
||||
"profiles": {
|
||||
"ask": {
|
||||
@@ -671,6 +673,7 @@
|
||||
"now": true,
|
||||
"find_path": true,
|
||||
"read_file": true,
|
||||
"open": true,
|
||||
"grep": true,
|
||||
"thinking": true,
|
||||
"web_search": true
|
||||
@@ -834,7 +837,20 @@
|
||||
// "modal_max_width": "full"
|
||||
//
|
||||
// Default: small
|
||||
"modal_max_width": "small"
|
||||
"modal_max_width": "small",
|
||||
// Determines whether the file finder should skip focus for the active file in search results.
|
||||
// There are 2 possible values:
|
||||
//
|
||||
// 1. true: When searching for files, if the currently active file appears as the first result,
|
||||
// auto-focus will skip it and focus the second result instead.
|
||||
// "skip_focus_for_active_in_search": true
|
||||
//
|
||||
// 2. false: When searching for files, the first result will always receive focus,
|
||||
// even if it's the currently active file.
|
||||
// "skip_focus_for_active_in_search": false
|
||||
//
|
||||
// Default: true
|
||||
"skip_focus_for_active_in_search": true
|
||||
},
|
||||
// Whether or not to remove any trailing whitespace from lines of a buffer
|
||||
// before saving it.
|
||||
@@ -917,6 +933,11 @@
|
||||
// The minimum severity of the diagnostics to show inline.
|
||||
// Shows all diagnostics when not specified.
|
||||
"max_severity": null
|
||||
},
|
||||
"rust": {
|
||||
// When enabled, Zed disables rust-analyzer's check on save and starts to query
|
||||
// Cargo diagnostics separately.
|
||||
"fetch_cargo_diagnostics": false
|
||||
}
|
||||
},
|
||||
// Files or globs of files that will be excluded by Zed entirely. They will be skipped during file
|
||||
|
||||
@@ -35,6 +35,7 @@ context_server.workspace = true
|
||||
convert_case.workspace = true
|
||||
db.workspace = true
|
||||
editor.workspace = true
|
||||
extension.workspace = true
|
||||
feature_flags.workspace = true
|
||||
file_icons.workspace = true
|
||||
fs.workspace = true
|
||||
@@ -47,6 +48,7 @@ html_to_markdown.workspace = true
|
||||
http_client.workspace = true
|
||||
indexmap.workspace = true
|
||||
itertools.workspace = true
|
||||
jsonschema.workspace = true
|
||||
language.workspace = true
|
||||
language_model.workspace = true
|
||||
language_model_selector.workspace = true
|
||||
@@ -56,6 +58,7 @@ lsp.workspace = true
|
||||
markdown.workspace = true
|
||||
menu.workspace = true
|
||||
multi_buffer.workspace = true
|
||||
notifications.workspace = true
|
||||
ordered-float.workspace = true
|
||||
parking_lot.workspace = true
|
||||
paths.workspace = true
|
||||
@@ -71,6 +74,7 @@ rope.workspace = true
|
||||
schemars.workspace = true
|
||||
serde.workspace = true
|
||||
serde_json.workspace = true
|
||||
serde_json_lenient.workspace = true
|
||||
settings.workspace = true
|
||||
smallvec.workspace = true
|
||||
smol.workspace = true
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
use crate::context::{AgentContextHandle, RULES_ICON};
|
||||
use crate::context_picker::MentionLink;
|
||||
use crate::context_picker::{ContextPicker, MentionLink};
|
||||
use crate::context_store::ContextStore;
|
||||
use crate::context_strip::{ContextStrip, ContextStripEvent, SuggestContextKind};
|
||||
use crate::thread::{
|
||||
LastRestoreCheckpoint, MessageId, MessageSegment, Thread, ThreadError, ThreadEvent,
|
||||
ThreadFeedback,
|
||||
@@ -14,14 +16,16 @@ use anyhow::Context as _;
|
||||
use assistant_settings::{AssistantSettings, NotifyWhenAgentWaiting};
|
||||
use assistant_tool::ToolUseStatus;
|
||||
use collections::{HashMap, HashSet};
|
||||
use editor::actions::{MoveUp, Paste};
|
||||
use editor::scroll::Autoscroll;
|
||||
use editor::{Editor, EditorElement, EditorEvent, EditorStyle, MultiBuffer};
|
||||
use gpui::{
|
||||
AbsoluteLength, Animation, AnimationExt, AnyElement, App, ClickEvent, ClipboardItem,
|
||||
DefiniteLength, EdgesRefinement, Empty, Entity, EventEmitter, Focusable, Hsla, ListAlignment,
|
||||
ListState, MouseButton, PlatformDisplay, ScrollHandle, Stateful, StyleRefinement, Subscription,
|
||||
Task, TextStyle, TextStyleRefinement, Transformation, UnderlineStyle, WeakEntity, WindowHandle,
|
||||
linear_color_stop, linear_gradient, list, percentage, pulsating_between,
|
||||
AbsoluteLength, Animation, AnimationExt, AnyElement, App, ClickEvent, ClipboardEntry,
|
||||
ClipboardItem, DefiniteLength, EdgesRefinement, Empty, Entity, EventEmitter, Focusable, Hsla,
|
||||
ListAlignment, ListState, MouseButton, PlatformDisplay, ScrollHandle, Stateful,
|
||||
StyleRefinement, Subscription, Task, TextStyle, TextStyleRefinement, Transformation,
|
||||
UnderlineStyle, WeakEntity, WindowHandle, linear_color_stop, linear_gradient, list, percentage,
|
||||
pulsating_between,
|
||||
};
|
||||
use language::{Buffer, Language, LanguageRegistry};
|
||||
use language_model::{
|
||||
@@ -41,7 +45,8 @@ use std::time::Duration;
|
||||
use text::ToPoint;
|
||||
use theme::ThemeSettings;
|
||||
use ui::{
|
||||
Disclosure, IconButton, KeyBinding, Scrollbar, ScrollbarState, TextSize, Tooltip, prelude::*,
|
||||
Disclosure, IconButton, KeyBinding, PopoverMenuHandle, Scrollbar, ScrollbarState, TextSize,
|
||||
Tooltip, prelude::*,
|
||||
};
|
||||
use util::ResultExt as _;
|
||||
use util::markdown::MarkdownCodeBlock;
|
||||
@@ -49,6 +54,7 @@ use workspace::Workspace;
|
||||
use zed_actions::assistant::OpenRulesLibrary;
|
||||
|
||||
pub struct ActiveThread {
|
||||
context_store: Entity<ContextStore>,
|
||||
language_registry: Arc<LanguageRegistry>,
|
||||
thread_store: Entity<ThreadStore>,
|
||||
thread: Entity<Thread>,
|
||||
@@ -61,7 +67,7 @@ pub struct ActiveThread {
|
||||
hide_scrollbar_task: Option<Task<()>>,
|
||||
rendered_messages_by_id: HashMap<MessageId, RenderedMessage>,
|
||||
rendered_tool_uses: HashMap<LanguageModelToolUseId, RenderedToolUse>,
|
||||
editing_message: Option<(MessageId, EditMessageState)>,
|
||||
editing_message: Option<(MessageId, EditingMessageState)>,
|
||||
expanded_tool_uses: HashMap<LanguageModelToolUseId, bool>,
|
||||
expanded_thinking_segments: HashMap<(MessageId, usize), bool>,
|
||||
expanded_code_blocks: HashMap<(MessageId, usize), bool>,
|
||||
@@ -72,6 +78,7 @@ pub struct ActiveThread {
|
||||
_subscriptions: Vec<Subscription>,
|
||||
notification_subscriptions: HashMap<WindowHandle<AgentNotification>, Vec<Subscription>>,
|
||||
open_feedback_editors: HashMap<MessageId, Entity<Editor>>,
|
||||
_load_edited_message_context_task: Option<Task<()>>,
|
||||
}
|
||||
|
||||
struct RenderedMessage {
|
||||
@@ -725,10 +732,12 @@ fn open_markdown_link(
|
||||
}
|
||||
}
|
||||
|
||||
struct EditMessageState {
|
||||
struct EditingMessageState {
|
||||
editor: Entity<Editor>,
|
||||
context_strip: Entity<ContextStrip>,
|
||||
context_picker_menu_handle: PopoverMenuHandle<ContextPicker>,
|
||||
last_estimated_token_count: Option<usize>,
|
||||
_subscription: Subscription,
|
||||
_subscriptions: [Subscription; 2],
|
||||
_update_token_count_task: Option<Task<()>>,
|
||||
}
|
||||
|
||||
@@ -736,6 +745,7 @@ impl ActiveThread {
|
||||
pub fn new(
|
||||
thread: Entity<Thread>,
|
||||
thread_store: Entity<ThreadStore>,
|
||||
context_store: Entity<ContextStore>,
|
||||
language_registry: Arc<LanguageRegistry>,
|
||||
workspace: WeakEntity<Workspace>,
|
||||
window: &mut Window,
|
||||
@@ -758,6 +768,7 @@ impl ActiveThread {
|
||||
let mut this = Self {
|
||||
language_registry,
|
||||
thread_store,
|
||||
context_store,
|
||||
thread: thread.clone(),
|
||||
workspace,
|
||||
save_thread_task: None,
|
||||
@@ -779,6 +790,7 @@ impl ActiveThread {
|
||||
_subscriptions: subscriptions,
|
||||
notification_subscriptions: HashMap::default(),
|
||||
open_feedback_editors: HashMap::default(),
|
||||
_load_edited_message_context_task: None,
|
||||
};
|
||||
|
||||
for message in thread.read(cx).messages().cloned().collect::<Vec<_>>() {
|
||||
@@ -1237,33 +1249,49 @@ impl ActiveThread {
|
||||
return;
|
||||
};
|
||||
|
||||
let buffer = cx.new(|cx| {
|
||||
MultiBuffer::singleton(cx.new(|cx| Buffer::local(message_text.clone(), cx)), cx)
|
||||
});
|
||||
let editor = cx.new(|cx| {
|
||||
let mut editor = Editor::new(
|
||||
editor::EditorMode::AutoHeight { max_lines: 8 },
|
||||
buffer,
|
||||
None,
|
||||
window,
|
||||
cx,
|
||||
);
|
||||
let editor = crate::message_editor::create_editor(
|
||||
self.workspace.clone(),
|
||||
self.context_store.downgrade(),
|
||||
self.thread_store.downgrade(),
|
||||
window,
|
||||
cx,
|
||||
);
|
||||
editor.update(cx, |editor, cx| {
|
||||
editor.set_text(message_text.clone(), window, cx);
|
||||
editor.focus_handle(cx).focus(window);
|
||||
editor.move_to_end(&editor::actions::MoveToEnd, window, cx);
|
||||
editor
|
||||
});
|
||||
let subscription = cx.subscribe(&editor, |this, _, event, cx| match event {
|
||||
let buffer_edited_subscription = cx.subscribe(&editor, |this, _, event, cx| match event {
|
||||
EditorEvent::BufferEdited => {
|
||||
this.update_editing_message_token_count(true, cx);
|
||||
}
|
||||
_ => {}
|
||||
});
|
||||
|
||||
let context_picker_menu_handle = PopoverMenuHandle::default();
|
||||
let context_strip = cx.new(|cx| {
|
||||
ContextStrip::new(
|
||||
self.context_store.clone(),
|
||||
self.workspace.clone(),
|
||||
Some(self.thread_store.downgrade()),
|
||||
context_picker_menu_handle.clone(),
|
||||
SuggestContextKind::File,
|
||||
window,
|
||||
cx,
|
||||
)
|
||||
});
|
||||
|
||||
let context_strip_subscription =
|
||||
cx.subscribe_in(&context_strip, window, Self::handle_context_strip_event);
|
||||
|
||||
self.editing_message = Some((
|
||||
message_id,
|
||||
EditMessageState {
|
||||
EditingMessageState {
|
||||
editor: editor.clone(),
|
||||
context_strip,
|
||||
context_picker_menu_handle,
|
||||
last_estimated_token_count: None,
|
||||
_subscription: subscription,
|
||||
_subscriptions: [buffer_edited_subscription, context_strip_subscription],
|
||||
_update_token_count_task: None,
|
||||
},
|
||||
));
|
||||
@@ -1271,6 +1299,26 @@ impl ActiveThread {
|
||||
cx.notify();
|
||||
}
|
||||
|
||||
fn handle_context_strip_event(
|
||||
&mut self,
|
||||
_context_strip: &Entity<ContextStrip>,
|
||||
event: &ContextStripEvent,
|
||||
window: &mut Window,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
if let Some((_, state)) = self.editing_message.as_ref() {
|
||||
match event {
|
||||
ContextStripEvent::PickerDismissed
|
||||
| ContextStripEvent::BlurredEmpty
|
||||
| ContextStripEvent::BlurredDown => {
|
||||
let editor_focus_handle = state.editor.focus_handle(cx);
|
||||
window.focus(&editor_focus_handle);
|
||||
}
|
||||
ContextStripEvent::BlurredUp => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn update_editing_message_token_count(&mut self, debounce: bool, cx: &mut Context<Self>) {
|
||||
let Some((message_id, state)) = self.editing_message.as_mut() else {
|
||||
return;
|
||||
@@ -1357,6 +1405,68 @@ impl ActiveThread {
|
||||
}));
|
||||
}
|
||||
|
||||
fn toggle_context_picker(
|
||||
&mut self,
|
||||
_: &crate::ToggleContextPicker,
|
||||
window: &mut Window,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
if let Some((_, state)) = self.editing_message.as_mut() {
|
||||
let handle = state.context_picker_menu_handle.clone();
|
||||
window.defer(cx, move |window, cx| {
|
||||
handle.toggle(window, cx);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
fn remove_all_context(
|
||||
&mut self,
|
||||
_: &crate::RemoveAllContext,
|
||||
_window: &mut Window,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
self.context_store.update(cx, |store, _cx| store.clear());
|
||||
cx.notify();
|
||||
}
|
||||
|
||||
fn move_up(&mut self, _: &MoveUp, window: &mut Window, cx: &mut Context<Self>) {
|
||||
if let Some((_, state)) = self.editing_message.as_mut() {
|
||||
if state.context_picker_menu_handle.is_deployed() {
|
||||
cx.propagate();
|
||||
} else {
|
||||
state.context_strip.focus_handle(cx).focus(window);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn paste(&mut self, _: &Paste, _window: &mut Window, cx: &mut Context<Self>) {
|
||||
let images = cx
|
||||
.read_from_clipboard()
|
||||
.map(|item| {
|
||||
item.into_entries()
|
||||
.filter_map(|entry| {
|
||||
if let ClipboardEntry::Image(image) = entry {
|
||||
Some(image)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
})
|
||||
.unwrap_or_default();
|
||||
|
||||
if images.is_empty() {
|
||||
return;
|
||||
}
|
||||
cx.stop_propagation();
|
||||
|
||||
self.context_store.update(cx, |store, cx| {
|
||||
for image in images {
|
||||
store.add_image_instance(Arc::new(image), cx);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
fn cancel_editing_message(&mut self, _: &menu::Cancel, _: &mut Window, cx: &mut Context<Self>) {
|
||||
self.editing_message.take();
|
||||
cx.notify();
|
||||
@@ -1371,21 +1481,11 @@ impl ActiveThread {
|
||||
let Some((message_id, state)) = self.editing_message.take() else {
|
||||
return;
|
||||
};
|
||||
let edited_text = state.editor.read(cx).text(cx);
|
||||
let thread_model = self.thread.update(cx, |thread, cx| {
|
||||
thread.edit_message(
|
||||
message_id,
|
||||
Role::User,
|
||||
vec![MessageSegment::Text(edited_text)],
|
||||
cx,
|
||||
);
|
||||
for message_id in self.messages_after(message_id) {
|
||||
thread.delete_message(*message_id, cx);
|
||||
}
|
||||
thread.get_or_init_configured_model(cx)
|
||||
});
|
||||
|
||||
let Some(model) = thread_model else {
|
||||
let Some(model) = self
|
||||
.thread
|
||||
.update(cx, |thread, cx| thread.get_or_init_configured_model(cx))
|
||||
else {
|
||||
return;
|
||||
};
|
||||
|
||||
@@ -1394,11 +1494,45 @@ impl ActiveThread {
|
||||
return;
|
||||
}
|
||||
|
||||
self.thread.update(cx, |thread, cx| {
|
||||
thread.advance_prompt_id();
|
||||
thread.send_to_model(model.model, Some(window.window_handle()), cx);
|
||||
});
|
||||
cx.notify();
|
||||
let edited_text = state.editor.read(cx).text(cx);
|
||||
|
||||
let new_context = self
|
||||
.context_store
|
||||
.read(cx)
|
||||
.new_context_for_thread(self.thread.read(cx), Some(message_id));
|
||||
|
||||
let project = self.thread.read(cx).project().clone();
|
||||
let prompt_store = self.thread_store.read(cx).prompt_store().clone();
|
||||
|
||||
let load_context_task =
|
||||
crate::context::load_context(new_context, &project, &prompt_store, cx);
|
||||
self._load_edited_message_context_task =
|
||||
Some(cx.spawn_in(window, async move |this, cx| {
|
||||
let context = load_context_task.await;
|
||||
let _ = this
|
||||
.update_in(cx, |this, window, cx| {
|
||||
this.thread.update(cx, |thread, cx| {
|
||||
thread.edit_message(
|
||||
message_id,
|
||||
Role::User,
|
||||
vec![MessageSegment::Text(edited_text)],
|
||||
Some(context.loaded_context),
|
||||
cx,
|
||||
);
|
||||
for message_id in this.messages_after(message_id) {
|
||||
thread.delete_message(*message_id, cx);
|
||||
}
|
||||
});
|
||||
|
||||
this.thread.update(cx, |thread, cx| {
|
||||
thread.advance_prompt_id();
|
||||
thread.send_to_model(model.model, Some(window.window_handle()), cx);
|
||||
});
|
||||
this._load_edited_message_context_task = None;
|
||||
cx.notify();
|
||||
})
|
||||
.log_err();
|
||||
}));
|
||||
}
|
||||
|
||||
fn messages_after(&self, message_id: MessageId) -> &[MessageId] {
|
||||
@@ -1519,6 +1653,53 @@ impl ActiveThread {
|
||||
}
|
||||
}
|
||||
|
||||
fn render_edit_message_editor(
|
||||
&self,
|
||||
state: &EditingMessageState,
|
||||
window: &mut Window,
|
||||
cx: &Context<Self>,
|
||||
) -> impl IntoElement {
|
||||
let settings = ThemeSettings::get_global(cx);
|
||||
let font_size = TextSize::Small.rems(cx);
|
||||
let line_height = font_size.to_pixels(window.rem_size()) * 1.75;
|
||||
|
||||
let colors = cx.theme().colors();
|
||||
|
||||
let text_style = TextStyle {
|
||||
color: colors.text,
|
||||
font_family: settings.buffer_font.family.clone(),
|
||||
font_fallbacks: settings.buffer_font.fallbacks.clone(),
|
||||
font_features: settings.buffer_font.features.clone(),
|
||||
font_size: font_size.into(),
|
||||
line_height: line_height.into(),
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
v_flex()
|
||||
.key_context("EditMessageEditor")
|
||||
.on_action(cx.listener(Self::toggle_context_picker))
|
||||
.on_action(cx.listener(Self::remove_all_context))
|
||||
.on_action(cx.listener(Self::move_up))
|
||||
.on_action(cx.listener(Self::cancel_editing_message))
|
||||
.on_action(cx.listener(Self::confirm_editing_message))
|
||||
.capture_action(cx.listener(Self::paste))
|
||||
.min_h_6()
|
||||
.flex_grow()
|
||||
.w_full()
|
||||
.gap_2()
|
||||
.child(EditorElement::new(
|
||||
&state.editor,
|
||||
EditorStyle {
|
||||
background: colors.editor_background,
|
||||
local_player: cx.theme().players().local(),
|
||||
text: text_style,
|
||||
syntax: cx.theme().syntax().clone(),
|
||||
..Default::default()
|
||||
},
|
||||
))
|
||||
.child(state.context_strip.clone())
|
||||
}
|
||||
|
||||
fn render_message(&self, ix: usize, window: &mut Window, cx: &mut Context<Self>) -> AnyElement {
|
||||
let message_id = self.messages[ix];
|
||||
let Some(message) = self.thread.read(cx).message(message_id) else {
|
||||
@@ -1551,11 +1732,11 @@ impl ActiveThread {
|
||||
let generating_label = (is_generating && is_last_message)
|
||||
.then(|| AnimatedLabel::new("Generating").size(LabelSize::Small));
|
||||
|
||||
let edit_message_editor = self
|
||||
let editing_message_state = self
|
||||
.editing_message
|
||||
.as_ref()
|
||||
.filter(|(id, _)| *id == message_id)
|
||||
.map(|(_, state)| state.editor.clone());
|
||||
.map(|(_, state)| state);
|
||||
|
||||
let colors = cx.theme().colors();
|
||||
let editor_bg_color = colors.editor_background;
|
||||
@@ -1690,77 +1871,43 @@ impl ActiveThread {
|
||||
let has_content = !message_is_empty || !added_context.is_empty();
|
||||
|
||||
let message_content = has_content.then(|| {
|
||||
v_flex()
|
||||
.w_full()
|
||||
.gap_1()
|
||||
.when(!message_is_empty, |parent| {
|
||||
parent.child(
|
||||
if let Some(edit_message_editor) = edit_message_editor.clone() {
|
||||
let settings = ThemeSettings::get_global(cx);
|
||||
let font_size = TextSize::Small.rems(cx);
|
||||
let line_height = font_size.to_pixels(window.rem_size()) * 1.75;
|
||||
|
||||
let text_style = TextStyle {
|
||||
color: cx.theme().colors().text,
|
||||
font_family: settings.buffer_font.family.clone(),
|
||||
font_fallbacks: settings.buffer_font.fallbacks.clone(),
|
||||
font_features: settings.buffer_font.features.clone(),
|
||||
font_size: font_size.into(),
|
||||
line_height: line_height.into(),
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
div()
|
||||
.key_context("EditMessageEditor")
|
||||
.on_action(cx.listener(Self::cancel_editing_message))
|
||||
.on_action(cx.listener(Self::confirm_editing_message))
|
||||
.min_h_6()
|
||||
.flex_grow()
|
||||
.w_full()
|
||||
.child(EditorElement::new(
|
||||
&edit_message_editor,
|
||||
EditorStyle {
|
||||
background: colors.editor_background,
|
||||
local_player: cx.theme().players().local(),
|
||||
text: text_style,
|
||||
syntax: cx.theme().syntax().clone(),
|
||||
..Default::default()
|
||||
},
|
||||
))
|
||||
.into_any()
|
||||
} else {
|
||||
div()
|
||||
.min_h_6()
|
||||
.child(self.render_message_content(
|
||||
message_id,
|
||||
rendered_message,
|
||||
has_tool_uses,
|
||||
workspace.clone(),
|
||||
window,
|
||||
cx,
|
||||
))
|
||||
.into_any()
|
||||
},
|
||||
)
|
||||
})
|
||||
.when(!added_context.is_empty(), |parent| {
|
||||
parent.child(h_flex().flex_wrap().gap_1().children(
|
||||
added_context.into_iter().map(|added_context| {
|
||||
let context = added_context.handle.clone();
|
||||
ContextPill::added(added_context, false, false, None).on_click(Rc::new(
|
||||
cx.listener({
|
||||
let workspace = workspace.clone();
|
||||
move |_, _, window, cx| {
|
||||
if let Some(workspace) = workspace.upgrade() {
|
||||
open_context(&context, workspace, window, cx);
|
||||
cx.notify();
|
||||
if let Some(state) = editing_message_state.as_ref() {
|
||||
self.render_edit_message_editor(state, window, cx)
|
||||
.into_any_element()
|
||||
} else {
|
||||
v_flex()
|
||||
.w_full()
|
||||
.gap_1()
|
||||
.when(!message_is_empty, |parent| {
|
||||
parent.child(div().min_h_6().child(self.render_message_content(
|
||||
message_id,
|
||||
rendered_message,
|
||||
has_tool_uses,
|
||||
workspace.clone(),
|
||||
window,
|
||||
cx,
|
||||
)))
|
||||
})
|
||||
.when(!added_context.is_empty(), |parent| {
|
||||
parent.child(h_flex().flex_wrap().gap_1().children(
|
||||
added_context.into_iter().map(|added_context| {
|
||||
let context = added_context.handle.clone();
|
||||
ContextPill::added(added_context, false, false, None).on_click(
|
||||
Rc::new(cx.listener({
|
||||
let workspace = workspace.clone();
|
||||
move |_, _, window, cx| {
|
||||
if let Some(workspace) = workspace.upgrade() {
|
||||
open_context(&context, workspace, window, cx);
|
||||
cx.notify();
|
||||
}
|
||||
}
|
||||
}
|
||||
}),
|
||||
))
|
||||
}),
|
||||
))
|
||||
})
|
||||
})),
|
||||
)
|
||||
}),
|
||||
))
|
||||
})
|
||||
.into_any_element()
|
||||
}
|
||||
});
|
||||
|
||||
let styled_message = match message.role {
|
||||
@@ -1785,8 +1932,8 @@ impl ActiveThread {
|
||||
.p_2p5()
|
||||
.gap_1()
|
||||
.children(message_content)
|
||||
.when_some(edit_message_editor.clone(), |this, edit_editor| {
|
||||
let edit_editor_clone = edit_editor.clone();
|
||||
.when_some(editing_message_state, |this, state| {
|
||||
let focus_handle = state.editor.focus_handle(cx).clone();
|
||||
this.w_full().justify_between().child(
|
||||
h_flex()
|
||||
.gap_0p5()
|
||||
@@ -1797,16 +1944,17 @@ impl ActiveThread {
|
||||
)
|
||||
.shape(ui::IconButtonShape::Square)
|
||||
.icon_color(Color::Error)
|
||||
.tooltip(move |window, cx| {
|
||||
let focus_handle =
|
||||
edit_editor_clone.focus_handle(cx);
|
||||
Tooltip::for_action_in(
|
||||
"Cancel Edit",
|
||||
&menu::Cancel,
|
||||
&focus_handle,
|
||||
window,
|
||||
cx,
|
||||
)
|
||||
.tooltip({
|
||||
let focus_handle = focus_handle.clone();
|
||||
move |window, cx| {
|
||||
Tooltip::for_action_in(
|
||||
"Cancel Edit",
|
||||
&menu::Cancel,
|
||||
&focus_handle,
|
||||
window,
|
||||
cx,
|
||||
)
|
||||
}
|
||||
})
|
||||
.on_click(cx.listener(Self::handle_cancel_click)),
|
||||
)
|
||||
@@ -1815,18 +1963,20 @@ impl ActiveThread {
|
||||
"confirm-edit-message",
|
||||
IconName::Check,
|
||||
)
|
||||
.disabled(edit_editor.read(cx).is_empty(cx))
|
||||
.disabled(state.editor.read(cx).is_empty(cx))
|
||||
.shape(ui::IconButtonShape::Square)
|
||||
.icon_color(Color::Success)
|
||||
.tooltip(move |window, cx| {
|
||||
let focus_handle = edit_editor.focus_handle(cx);
|
||||
Tooltip::for_action_in(
|
||||
"Regenerate",
|
||||
&menu::Confirm,
|
||||
&focus_handle,
|
||||
window,
|
||||
cx,
|
||||
)
|
||||
.tooltip({
|
||||
let focus_handle = focus_handle.clone();
|
||||
move |window, cx| {
|
||||
Tooltip::for_action_in(
|
||||
"Regenerate",
|
||||
&menu::Confirm,
|
||||
&focus_handle,
|
||||
window,
|
||||
cx,
|
||||
)
|
||||
}
|
||||
})
|
||||
.on_click(
|
||||
cx.listener(Self::handle_regenerate_click),
|
||||
@@ -1835,7 +1985,7 @@ impl ActiveThread {
|
||||
)
|
||||
}),
|
||||
)
|
||||
.when(edit_message_editor.is_none(), |this| {
|
||||
.when(editing_message_state.is_none(), |this| {
|
||||
this.tooltip(Tooltip::text("Click To Edit"))
|
||||
})
|
||||
.on_click(cx.listener({
|
||||
|
||||
@@ -6,6 +6,7 @@ mod assistant_panel;
|
||||
mod buffer_codegen;
|
||||
mod context;
|
||||
mod context_picker;
|
||||
mod context_server_configuration;
|
||||
mod context_store;
|
||||
mod context_strip;
|
||||
mod history_store;
|
||||
@@ -30,6 +31,7 @@ use command_palette_hooks::CommandPaletteFilter;
|
||||
use feature_flags::{Assistant2FeatureFlag, FeatureFlagAppExt};
|
||||
use fs::Fs;
|
||||
use gpui::{App, actions, impl_actions};
|
||||
use language::LanguageRegistry;
|
||||
use prompt_store::PromptBuilder;
|
||||
use schemars::JsonSchema;
|
||||
use serde::Deserialize;
|
||||
@@ -44,6 +46,8 @@ pub use crate::inline_assistant::InlineAssistant;
|
||||
pub use crate::thread::{Message, MessageSegment, Thread, ThreadEvent};
|
||||
pub use crate::thread_store::ThreadStore;
|
||||
pub use agent_diff::{AgentDiff, AgentDiffToolbar};
|
||||
pub use context_store::ContextStore;
|
||||
pub use ui::{all_agent_previews, get_agent_preview};
|
||||
|
||||
actions!(
|
||||
agent,
|
||||
@@ -60,7 +64,6 @@ actions!(
|
||||
AddContextServer,
|
||||
RemoveSelectedThread,
|
||||
Chat,
|
||||
ChatMode,
|
||||
CycleNextInlineAssist,
|
||||
CyclePreviousInlineAssist,
|
||||
FocusUp,
|
||||
@@ -107,11 +110,13 @@ pub fn init(
|
||||
fs: Arc<dyn Fs>,
|
||||
client: Arc<Client>,
|
||||
prompt_builder: Arc<PromptBuilder>,
|
||||
language_registry: Arc<LanguageRegistry>,
|
||||
cx: &mut App,
|
||||
) {
|
||||
AssistantSettings::register(cx);
|
||||
thread_store::init(cx);
|
||||
assistant_panel::init(cx);
|
||||
context_server_configuration::init(language_registry, cx);
|
||||
|
||||
inline_assistant::init(
|
||||
fs.clone(),
|
||||
|
||||
@@ -1,16 +1,18 @@
|
||||
mod add_context_server_modal;
|
||||
mod configure_context_server_modal;
|
||||
mod manage_profiles_modal;
|
||||
mod tool_picker;
|
||||
|
||||
use std::sync::Arc;
|
||||
use std::{sync::Arc, time::Duration};
|
||||
|
||||
use assistant_settings::AssistantSettings;
|
||||
use assistant_tool::{ToolSource, ToolWorkingSet};
|
||||
use collections::HashMap;
|
||||
use context_server::manager::ContextServerManager;
|
||||
use context_server::manager::{ContextServer, ContextServerManager, ContextServerStatus};
|
||||
use fs::Fs;
|
||||
use gpui::{
|
||||
Action, AnyView, App, Entity, EventEmitter, FocusHandle, Focusable, ScrollHandle, Subscription,
|
||||
Action, Animation, AnimationExt as _, AnyView, App, Entity, EventEmitter, FocusHandle,
|
||||
Focusable, ScrollHandle, Subscription, pulsating_between,
|
||||
};
|
||||
use language_model::{LanguageModelProvider, LanguageModelProviderId, LanguageModelRegistry};
|
||||
use settings::{Settings, update_settings_file};
|
||||
@@ -22,6 +24,7 @@ use util::ResultExt as _;
|
||||
use zed_actions::ExtensionCategoryFilter;
|
||||
|
||||
pub(crate) use add_context_server_modal::AddContextServerModal;
|
||||
pub(crate) use configure_context_server_modal::ConfigureContextServerModal;
|
||||
pub(crate) use manage_profiles_modal::ManageProfilesModal;
|
||||
|
||||
use crate::AddContextServer;
|
||||
@@ -254,10 +257,12 @@ impl AssistantConfiguration {
|
||||
)
|
||||
}
|
||||
|
||||
fn render_context_servers_section(&mut self, cx: &mut Context<Self>) -> impl IntoElement {
|
||||
fn render_context_servers_section(
|
||||
&mut self,
|
||||
window: &mut Window,
|
||||
cx: &mut Context<Self>,
|
||||
) -> impl IntoElement {
|
||||
let context_servers = self.context_server_manager.read(cx).all_servers().clone();
|
||||
let tools_by_source = self.tools.read(cx).tools_by_source(cx);
|
||||
let empty = Vec::new();
|
||||
|
||||
const SUBHEADING: &str = "Connect to context servers via the Model Context Protocol either via Zed extensions or directly.";
|
||||
|
||||
@@ -272,136 +277,11 @@ impl AssistantConfiguration {
|
||||
.child(Headline::new("Model Context Protocol (MCP) Servers"))
|
||||
.child(Label::new(SUBHEADING).color(Color::Muted)),
|
||||
)
|
||||
.children(context_servers.into_iter().map(|context_server| {
|
||||
let is_running = context_server.client().is_some();
|
||||
let are_tools_expanded = self
|
||||
.expanded_context_server_tools
|
||||
.get(&context_server.id())
|
||||
.copied()
|
||||
.unwrap_or_default();
|
||||
|
||||
let tools = tools_by_source
|
||||
.get(&ToolSource::ContextServer {
|
||||
id: context_server.id().into(),
|
||||
})
|
||||
.unwrap_or_else(|| &empty);
|
||||
let tool_count = tools.len();
|
||||
|
||||
v_flex()
|
||||
.id(SharedString::from(context_server.id()))
|
||||
.border_1()
|
||||
.rounded_md()
|
||||
.border_color(cx.theme().colors().border)
|
||||
.bg(cx.theme().colors().background.opacity(0.25))
|
||||
.child(
|
||||
h_flex()
|
||||
.p_1()
|
||||
.justify_between()
|
||||
.when(are_tools_expanded && tool_count > 1, |element| {
|
||||
element
|
||||
.border_b_1()
|
||||
.border_color(cx.theme().colors().border)
|
||||
})
|
||||
.child(
|
||||
h_flex()
|
||||
.gap_2()
|
||||
.child(
|
||||
Disclosure::new("tool-list-disclosure", are_tools_expanded)
|
||||
.disabled(tool_count == 0)
|
||||
.on_click(cx.listener({
|
||||
let context_server_id = context_server.id();
|
||||
move |this, _event, _window, _cx| {
|
||||
let is_open = this
|
||||
.expanded_context_server_tools
|
||||
.entry(context_server_id.clone())
|
||||
.or_insert(false);
|
||||
|
||||
*is_open = !*is_open;
|
||||
}
|
||||
})),
|
||||
)
|
||||
.child(Indicator::dot().color(if is_running {
|
||||
Color::Success
|
||||
} else {
|
||||
Color::Error
|
||||
}))
|
||||
.child(Label::new(context_server.id()))
|
||||
.child(
|
||||
Label::new(format!("{tool_count} tools"))
|
||||
.color(Color::Muted)
|
||||
.size(LabelSize::Small),
|
||||
),
|
||||
)
|
||||
.child(
|
||||
Switch::new("context-server-switch", is_running.into())
|
||||
.color(SwitchColor::Accent)
|
||||
.on_click({
|
||||
let context_server_manager =
|
||||
self.context_server_manager.clone();
|
||||
let context_server = context_server.clone();
|
||||
move |state, _window, cx| match state {
|
||||
ToggleState::Unselected
|
||||
| ToggleState::Indeterminate => {
|
||||
context_server_manager.update(cx, |this, cx| {
|
||||
this.stop_server(context_server.clone(), cx)
|
||||
.log_err();
|
||||
});
|
||||
}
|
||||
ToggleState::Selected => {
|
||||
cx.spawn({
|
||||
let context_server_manager =
|
||||
context_server_manager.clone();
|
||||
let context_server = context_server.clone();
|
||||
async move |cx| {
|
||||
if let Some(start_server_task) =
|
||||
context_server_manager
|
||||
.update(cx, |this, cx| {
|
||||
this.start_server(
|
||||
context_server,
|
||||
cx,
|
||||
)
|
||||
})
|
||||
.log_err()
|
||||
{
|
||||
start_server_task.await.log_err();
|
||||
}
|
||||
}
|
||||
})
|
||||
.detach();
|
||||
}
|
||||
}
|
||||
}),
|
||||
),
|
||||
)
|
||||
.map(|parent| {
|
||||
if !are_tools_expanded {
|
||||
return parent;
|
||||
}
|
||||
|
||||
parent.child(v_flex().py_1p5().px_1().gap_1().children(
|
||||
tools.into_iter().enumerate().map(|(ix, tool)| {
|
||||
h_flex()
|
||||
.id(("tool-item", ix))
|
||||
.px_1()
|
||||
.gap_2()
|
||||
.justify_between()
|
||||
.hover(|style| style.bg(cx.theme().colors().element_hover))
|
||||
.rounded_sm()
|
||||
.child(
|
||||
Label::new(tool.name())
|
||||
.buffer_font(cx)
|
||||
.size(LabelSize::Small),
|
||||
)
|
||||
.child(
|
||||
Icon::new(IconName::Info)
|
||||
.size(IconSize::Small)
|
||||
.color(Color::Ignored),
|
||||
)
|
||||
.tooltip(Tooltip::text(tool.description()))
|
||||
}),
|
||||
))
|
||||
})
|
||||
}))
|
||||
.children(
|
||||
context_servers
|
||||
.into_iter()
|
||||
.map(|context_server| self.render_context_server(context_server, window, cx)),
|
||||
)
|
||||
.child(
|
||||
h_flex()
|
||||
.justify_between()
|
||||
@@ -429,7 +309,7 @@ impl AssistantConfiguration {
|
||||
.style(ButtonStyle::Filled)
|
||||
.layer(ElevationIndex::ModalSurface)
|
||||
.full_width()
|
||||
.icon(IconName::DatabaseZap)
|
||||
.icon(IconName::Hammer)
|
||||
.icon_size(IconSize::Small)
|
||||
.icon_position(IconPosition::Start)
|
||||
.on_click(|_event, window, cx| {
|
||||
@@ -447,10 +327,214 @@ impl AssistantConfiguration {
|
||||
),
|
||||
)
|
||||
}
|
||||
|
||||
fn render_context_server(
|
||||
&self,
|
||||
context_server: Arc<ContextServer>,
|
||||
window: &mut Window,
|
||||
cx: &mut Context<Self>,
|
||||
) -> impl use<> + IntoElement {
|
||||
let tools_by_source = self.tools.read(cx).tools_by_source(cx);
|
||||
let server_status = self
|
||||
.context_server_manager
|
||||
.read(cx)
|
||||
.status_for_server(&context_server.id());
|
||||
|
||||
let is_running = matches!(server_status, Some(ContextServerStatus::Running));
|
||||
|
||||
let error = if let Some(ContextServerStatus::Error(error)) = server_status.clone() {
|
||||
Some(error)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let are_tools_expanded = self
|
||||
.expanded_context_server_tools
|
||||
.get(&context_server.id())
|
||||
.copied()
|
||||
.unwrap_or_default();
|
||||
|
||||
let tools = tools_by_source
|
||||
.get(&ToolSource::ContextServer {
|
||||
id: context_server.id().into(),
|
||||
})
|
||||
.map_or([].as_slice(), |tools| tools.as_slice());
|
||||
let tool_count = tools.len();
|
||||
|
||||
let border_color = cx.theme().colors().border.opacity(0.6);
|
||||
|
||||
v_flex()
|
||||
.id(SharedString::from(context_server.id()))
|
||||
.border_1()
|
||||
.rounded_md()
|
||||
.border_color(border_color)
|
||||
.bg(cx.theme().colors().background.opacity(0.2))
|
||||
.overflow_hidden()
|
||||
.child(
|
||||
h_flex()
|
||||
.p_1()
|
||||
.justify_between()
|
||||
.when(
|
||||
error.is_some() || are_tools_expanded && tool_count > 1,
|
||||
|element| element.border_b_1().border_color(border_color),
|
||||
)
|
||||
.child(
|
||||
h_flex()
|
||||
.gap_1p5()
|
||||
.child(
|
||||
Disclosure::new(
|
||||
"tool-list-disclosure",
|
||||
are_tools_expanded || error.is_some(),
|
||||
)
|
||||
.disabled(tool_count == 0)
|
||||
.on_click(cx.listener({
|
||||
let context_server_id = context_server.id();
|
||||
move |this, _event, _window, _cx| {
|
||||
let is_open = this
|
||||
.expanded_context_server_tools
|
||||
.entry(context_server_id.clone())
|
||||
.or_insert(false);
|
||||
|
||||
*is_open = !*is_open;
|
||||
}
|
||||
})),
|
||||
)
|
||||
.child(match server_status {
|
||||
Some(ContextServerStatus::Starting) => {
|
||||
let color = Color::Success.color(cx);
|
||||
Indicator::dot()
|
||||
.color(Color::Success)
|
||||
.with_animation(
|
||||
SharedString::from(format!(
|
||||
"{}-starting",
|
||||
context_server.id(),
|
||||
)),
|
||||
Animation::new(Duration::from_secs(2))
|
||||
.repeat()
|
||||
.with_easing(pulsating_between(0.4, 1.)),
|
||||
move |this, delta| {
|
||||
this.color(color.alpha(delta).into())
|
||||
},
|
||||
)
|
||||
.into_any_element()
|
||||
}
|
||||
Some(ContextServerStatus::Running) => {
|
||||
Indicator::dot().color(Color::Success).into_any_element()
|
||||
}
|
||||
Some(ContextServerStatus::Error(_)) => {
|
||||
Indicator::dot().color(Color::Error).into_any_element()
|
||||
}
|
||||
None => Indicator::dot().color(Color::Muted).into_any_element(),
|
||||
})
|
||||
.child(Label::new(context_server.id()).ml_0p5())
|
||||
.when(is_running, |this| {
|
||||
this.child(
|
||||
Label::new(if tool_count == 1 {
|
||||
SharedString::from("1 tool")
|
||||
} else {
|
||||
SharedString::from(format!("{} tools", tool_count))
|
||||
})
|
||||
.color(Color::Muted)
|
||||
.size(LabelSize::Small),
|
||||
)
|
||||
}),
|
||||
)
|
||||
.child(
|
||||
Switch::new("context-server-switch", is_running.into())
|
||||
.color(SwitchColor::Accent)
|
||||
.on_click({
|
||||
let context_server_manager = self.context_server_manager.clone();
|
||||
let context_server = context_server.clone();
|
||||
move |state, _window, cx| match state {
|
||||
ToggleState::Unselected | ToggleState::Indeterminate => {
|
||||
context_server_manager.update(cx, |this, cx| {
|
||||
this.stop_server(context_server.clone(), cx).log_err();
|
||||
});
|
||||
}
|
||||
ToggleState::Selected => {
|
||||
cx.spawn({
|
||||
let context_server_manager =
|
||||
context_server_manager.clone();
|
||||
let context_server = context_server.clone();
|
||||
async move |cx| {
|
||||
if let Some(start_server_task) =
|
||||
context_server_manager
|
||||
.update(cx, |this, cx| {
|
||||
this.start_server(context_server, cx)
|
||||
})
|
||||
.log_err()
|
||||
{
|
||||
start_server_task.await.log_err();
|
||||
}
|
||||
}
|
||||
})
|
||||
.detach();
|
||||
}
|
||||
}
|
||||
}),
|
||||
),
|
||||
)
|
||||
.map(|parent| {
|
||||
if let Some(error) = error {
|
||||
return parent.child(
|
||||
h_flex()
|
||||
.p_2()
|
||||
.gap_2()
|
||||
.items_start()
|
||||
.child(
|
||||
h_flex()
|
||||
.flex_none()
|
||||
.h(window.line_height() / 1.6_f32)
|
||||
.justify_center()
|
||||
.child(
|
||||
Icon::new(IconName::XCircle)
|
||||
.size(IconSize::XSmall)
|
||||
.color(Color::Error),
|
||||
),
|
||||
)
|
||||
.child(
|
||||
div().w_full().child(
|
||||
Label::new(error)
|
||||
.buffer_font(cx)
|
||||
.color(Color::Muted)
|
||||
.size(LabelSize::Small),
|
||||
),
|
||||
),
|
||||
);
|
||||
}
|
||||
|
||||
if !are_tools_expanded || tools.is_empty() {
|
||||
return parent;
|
||||
}
|
||||
|
||||
parent.child(v_flex().py_1p5().px_1().gap_1().children(
|
||||
tools.into_iter().enumerate().map(|(ix, tool)| {
|
||||
h_flex()
|
||||
.id(("tool-item", ix))
|
||||
.px_1()
|
||||
.gap_2()
|
||||
.justify_between()
|
||||
.hover(|style| style.bg(cx.theme().colors().element_hover))
|
||||
.rounded_sm()
|
||||
.child(
|
||||
Label::new(tool.name())
|
||||
.buffer_font(cx)
|
||||
.size(LabelSize::Small),
|
||||
)
|
||||
.child(
|
||||
Icon::new(IconName::Info)
|
||||
.size(IconSize::Small)
|
||||
.color(Color::Ignored),
|
||||
)
|
||||
.tooltip(Tooltip::text(tool.description()))
|
||||
}),
|
||||
))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Render for AssistantConfiguration {
|
||||
fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
|
||||
fn render(&mut self, window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
|
||||
v_flex()
|
||||
.id("assistant-configuration")
|
||||
.key_context("AgentConfiguration")
|
||||
@@ -467,7 +551,7 @@ impl Render for AssistantConfiguration {
|
||||
.overflow_y_scroll()
|
||||
.child(self.render_command_permission(cx))
|
||||
.child(Divider::horizontal().color(DividerColor::Border))
|
||||
.child(self.render_context_servers_section(cx))
|
||||
.child(self.render_context_servers_section(window, cx))
|
||||
.child(Divider::horizontal().color(DividerColor::Border))
|
||||
.child(self.render_provider_configuration_section(cx)),
|
||||
)
|
||||
|
||||
@@ -0,0 +1,443 @@
|
||||
use std::{
|
||||
sync::{Arc, Mutex},
|
||||
time::Duration,
|
||||
};
|
||||
|
||||
use anyhow::Context as _;
|
||||
use context_server::manager::{ContextServerManager, ContextServerStatus};
|
||||
use editor::{Editor, EditorElement, EditorStyle};
|
||||
use extension::ContextServerConfiguration;
|
||||
use gpui::{
|
||||
Animation, AnimationExt, App, DismissEvent, Entity, EventEmitter, FocusHandle, Focusable, Task,
|
||||
TextStyle, TextStyleRefinement, Transformation, UnderlineStyle, WeakEntity, percentage,
|
||||
};
|
||||
use language::{Language, LanguageRegistry};
|
||||
use markdown::{Markdown, MarkdownElement, MarkdownStyle};
|
||||
use notifications::status_toast::{StatusToast, ToastIcon};
|
||||
use settings::{Settings as _, update_settings_file};
|
||||
use theme::ThemeSettings;
|
||||
use ui::{KeyBinding, Modal, ModalFooter, ModalHeader, Section, prelude::*};
|
||||
use util::ResultExt;
|
||||
use workspace::{ModalView, Workspace};
|
||||
|
||||
pub(crate) struct ConfigureContextServerModal {
|
||||
workspace: WeakEntity<Workspace>,
|
||||
context_servers_to_setup: Vec<ConfigureContextServer>,
|
||||
context_server_manager: Entity<ContextServerManager>,
|
||||
}
|
||||
|
||||
struct ConfigureContextServer {
|
||||
id: Arc<str>,
|
||||
installation_instructions: Entity<markdown::Markdown>,
|
||||
settings_validator: Option<jsonschema::Validator>,
|
||||
settings_editor: Entity<Editor>,
|
||||
last_error: Option<SharedString>,
|
||||
waiting_for_context_server: bool,
|
||||
}
|
||||
|
||||
impl ConfigureContextServerModal {
|
||||
pub fn new(
|
||||
configurations: impl Iterator<Item = (Arc<str>, ContextServerConfiguration)>,
|
||||
jsonc_language: Option<Arc<Language>>,
|
||||
context_server_manager: Entity<ContextServerManager>,
|
||||
language_registry: Arc<LanguageRegistry>,
|
||||
workspace: WeakEntity<Workspace>,
|
||||
window: &mut Window,
|
||||
cx: &mut App,
|
||||
) -> Option<Self> {
|
||||
let context_servers_to_setup = configurations
|
||||
.map(|(id, manifest)| {
|
||||
let jsonc_language = jsonc_language.clone();
|
||||
let settings_validator = jsonschema::validator_for(&manifest.settings_schema)
|
||||
.context("Failed to load JSON schema for context server settings")
|
||||
.log_err();
|
||||
ConfigureContextServer {
|
||||
id: id.clone(),
|
||||
installation_instructions: cx.new(|cx| {
|
||||
Markdown::new(
|
||||
manifest.installation_instructions.clone().into(),
|
||||
Some(language_registry.clone()),
|
||||
None,
|
||||
cx,
|
||||
)
|
||||
}),
|
||||
settings_validator,
|
||||
settings_editor: cx.new(|cx| {
|
||||
let mut editor = Editor::auto_height(16, window, cx);
|
||||
editor.set_text(manifest.default_settings.trim(), window, cx);
|
||||
if let Some(buffer) = editor.buffer().read(cx).as_singleton() {
|
||||
buffer.update(cx, |buffer, cx| buffer.set_language(jsonc_language, cx))
|
||||
}
|
||||
editor
|
||||
}),
|
||||
waiting_for_context_server: false,
|
||||
last_error: None,
|
||||
}
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
if context_servers_to_setup.is_empty() {
|
||||
return None;
|
||||
}
|
||||
|
||||
Some(Self {
|
||||
workspace,
|
||||
context_servers_to_setup,
|
||||
context_server_manager,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl ConfigureContextServerModal {
|
||||
pub fn confirm(&mut self, cx: &mut Context<Self>) {
|
||||
if self.context_servers_to_setup.is_empty() {
|
||||
return;
|
||||
}
|
||||
|
||||
let Some(workspace) = self.workspace.upgrade() else {
|
||||
return;
|
||||
};
|
||||
|
||||
let configuration = &mut self.context_servers_to_setup[0];
|
||||
if configuration.waiting_for_context_server {
|
||||
return;
|
||||
}
|
||||
|
||||
let settings_value = match serde_json_lenient::from_str::<serde_json::Value>(
|
||||
&configuration.settings_editor.read(cx).text(cx),
|
||||
) {
|
||||
Ok(value) => value,
|
||||
Err(error) => {
|
||||
configuration.last_error = Some(error.to_string().into());
|
||||
cx.notify();
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
if let Some(validator) = configuration.settings_validator.as_ref() {
|
||||
if let Err(error) = validator.validate(&settings_value) {
|
||||
configuration.last_error = Some(error.to_string().into());
|
||||
cx.notify();
|
||||
return;
|
||||
}
|
||||
}
|
||||
let id = configuration.id.clone();
|
||||
|
||||
let settings_changed = context_server::ContextServerSettings::get_global(cx)
|
||||
.context_servers
|
||||
.get(&id)
|
||||
.map_or(true, |config| {
|
||||
config.settings.as_ref() != Some(&settings_value)
|
||||
});
|
||||
|
||||
let is_running = self.context_server_manager.read(cx).status_for_server(&id)
|
||||
== Some(ContextServerStatus::Running);
|
||||
|
||||
if !settings_changed && is_running {
|
||||
self.complete_setup(id, cx);
|
||||
return;
|
||||
}
|
||||
|
||||
configuration.waiting_for_context_server = true;
|
||||
|
||||
let task = wait_for_context_server(&self.context_server_manager, id.clone(), cx);
|
||||
cx.spawn({
|
||||
let id = id.clone();
|
||||
async move |this, cx| {
|
||||
let result = task.await;
|
||||
this.update(cx, |this, cx| match result {
|
||||
Ok(_) => {
|
||||
this.complete_setup(id, cx);
|
||||
}
|
||||
Err(err) => {
|
||||
if let Some(configuration) = this.context_servers_to_setup.get_mut(0) {
|
||||
configuration.last_error = Some(err.into());
|
||||
configuration.waiting_for_context_server = false;
|
||||
} else {
|
||||
this.dismiss(cx);
|
||||
}
|
||||
cx.notify();
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
.detach();
|
||||
|
||||
// When we write the settings to the file, the context server will be restarted.
|
||||
update_settings_file::<context_server::ContextServerSettings>(
|
||||
workspace.read(cx).app_state().fs.clone(),
|
||||
cx,
|
||||
{
|
||||
let id = id.clone();
|
||||
|settings, _| {
|
||||
if let Some(server_config) = settings.context_servers.get_mut(&id) {
|
||||
server_config.settings = Some(settings_value);
|
||||
} else {
|
||||
settings.context_servers.insert(
|
||||
id,
|
||||
context_server::ServerConfig {
|
||||
settings: Some(settings_value),
|
||||
..Default::default()
|
||||
},
|
||||
);
|
||||
}
|
||||
}
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
fn complete_setup(&mut self, id: Arc<str>, cx: &mut Context<Self>) {
|
||||
self.context_servers_to_setup.remove(0);
|
||||
cx.notify();
|
||||
|
||||
if !self.context_servers_to_setup.is_empty() {
|
||||
return;
|
||||
}
|
||||
|
||||
self.workspace
|
||||
.update(cx, {
|
||||
|workspace, cx| {
|
||||
let status_toast = StatusToast::new(
|
||||
format!("{} configured successfully.", id),
|
||||
cx,
|
||||
|this, _cx| {
|
||||
this.icon(ToastIcon::new(IconName::Hammer).color(Color::Muted))
|
||||
.action("Dismiss", |_, _| {})
|
||||
},
|
||||
);
|
||||
|
||||
workspace.toggle_status_toast(status_toast, cx);
|
||||
}
|
||||
})
|
||||
.log_err();
|
||||
|
||||
self.dismiss(cx);
|
||||
}
|
||||
|
||||
fn dismiss(&self, cx: &mut Context<Self>) {
|
||||
cx.emit(DismissEvent);
|
||||
}
|
||||
}
|
||||
|
||||
fn wait_for_context_server(
|
||||
context_server_manager: &Entity<ContextServerManager>,
|
||||
context_server_id: Arc<str>,
|
||||
cx: &mut App,
|
||||
) -> Task<Result<(), Arc<str>>> {
|
||||
let (tx, rx) = futures::channel::oneshot::channel();
|
||||
let tx = Arc::new(Mutex::new(Some(tx)));
|
||||
|
||||
let subscription = cx.subscribe(context_server_manager, move |_, event, _cx| match event {
|
||||
context_server::manager::Event::ServerStatusChanged { server_id, status } => match status {
|
||||
Some(ContextServerStatus::Running) => {
|
||||
if server_id == &context_server_id {
|
||||
if let Some(tx) = tx.lock().unwrap().take() {
|
||||
let _ = tx.send(Ok(()));
|
||||
}
|
||||
}
|
||||
}
|
||||
Some(ContextServerStatus::Error(error)) => {
|
||||
if server_id == &context_server_id {
|
||||
if let Some(tx) = tx.lock().unwrap().take() {
|
||||
let _ = tx.send(Err(error.clone()));
|
||||
}
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
},
|
||||
});
|
||||
|
||||
cx.spawn(async move |_cx| {
|
||||
let result = rx.await.unwrap();
|
||||
drop(subscription);
|
||||
result
|
||||
})
|
||||
}
|
||||
|
||||
impl Render for ConfigureContextServerModal {
|
||||
fn render(&mut self, window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
|
||||
let Some(configuration) = self.context_servers_to_setup.first() else {
|
||||
return div().child("No context servers to setup");
|
||||
};
|
||||
|
||||
let focus_handle = self.focus_handle(cx);
|
||||
|
||||
div()
|
||||
.elevation_3(cx)
|
||||
.w(rems(34.))
|
||||
.key_context("ConfigureContextServerModal")
|
||||
.on_action(cx.listener(|this, _: &menu::Confirm, _window, cx| this.confirm(cx)))
|
||||
.on_action(cx.listener(|this, _: &menu::Cancel, _window, cx| this.dismiss(cx)))
|
||||
.capture_any_mouse_down(cx.listener(|this, _, window, cx| {
|
||||
this.focus_handle(cx).focus(window);
|
||||
}))
|
||||
.child(
|
||||
Modal::new("configure-context-server", None)
|
||||
.header(ModalHeader::new().headline(format!("Configure {}", configuration.id)))
|
||||
.section(
|
||||
Section::new()
|
||||
.child(div().pb_2().text_sm().child(MarkdownElement::new(
|
||||
configuration.installation_instructions.clone(),
|
||||
default_markdown_style(window, cx),
|
||||
)))
|
||||
.child(
|
||||
div()
|
||||
.p_2()
|
||||
.rounded_md()
|
||||
.border_1()
|
||||
.border_color(cx.theme().colors().border_variant)
|
||||
.bg(cx.theme().colors().editor_background)
|
||||
.gap_1()
|
||||
.child({
|
||||
let settings = ThemeSettings::get_global(cx);
|
||||
let text_style = TextStyle {
|
||||
color: cx.theme().colors().text,
|
||||
font_family: settings.buffer_font.family.clone(),
|
||||
font_fallbacks: settings.buffer_font.fallbacks.clone(),
|
||||
font_size: settings.buffer_font_size(cx).into(),
|
||||
font_weight: settings.buffer_font.weight,
|
||||
line_height: relative(
|
||||
settings.buffer_line_height.value(),
|
||||
),
|
||||
..Default::default()
|
||||
};
|
||||
EditorElement::new(
|
||||
&configuration.settings_editor,
|
||||
EditorStyle {
|
||||
background: cx.theme().colors().editor_background,
|
||||
local_player: cx.theme().players().local(),
|
||||
text: text_style,
|
||||
syntax: cx.theme().syntax().clone(),
|
||||
..Default::default()
|
||||
},
|
||||
)
|
||||
})
|
||||
.when_some(configuration.last_error.clone(), |this, error| {
|
||||
this.child(
|
||||
h_flex()
|
||||
.gap_2()
|
||||
.px_2()
|
||||
.py_1()
|
||||
.child(
|
||||
Icon::new(IconName::Warning)
|
||||
.size(IconSize::XSmall)
|
||||
.color(Color::Warning),
|
||||
)
|
||||
.child(
|
||||
div().w_full().child(
|
||||
Label::new(error)
|
||||
.size(LabelSize::Small)
|
||||
.color(Color::Muted),
|
||||
),
|
||||
),
|
||||
)
|
||||
}),
|
||||
)
|
||||
.when(configuration.waiting_for_context_server, |this| {
|
||||
this.child(
|
||||
h_flex()
|
||||
.gap_1p5()
|
||||
.child(
|
||||
Icon::new(IconName::ArrowCircle)
|
||||
.size(IconSize::XSmall)
|
||||
.color(Color::Info)
|
||||
.with_animation(
|
||||
"arrow-circle",
|
||||
Animation::new(Duration::from_secs(2)).repeat(),
|
||||
|icon, delta| {
|
||||
icon.transform(Transformation::rotate(
|
||||
percentage(delta),
|
||||
))
|
||||
},
|
||||
)
|
||||
.into_any_element(),
|
||||
)
|
||||
.child(
|
||||
Label::new("Waiting for Context Server")
|
||||
.size(LabelSize::Small)
|
||||
.color(Color::Muted),
|
||||
),
|
||||
)
|
||||
}),
|
||||
)
|
||||
.footer(
|
||||
ModalFooter::new().end_slot(
|
||||
h_flex()
|
||||
.gap_1()
|
||||
.child(
|
||||
Button::new("cancel", "Cancel")
|
||||
.key_binding(
|
||||
KeyBinding::for_action_in(
|
||||
&menu::Cancel,
|
||||
&focus_handle,
|
||||
window,
|
||||
cx,
|
||||
)
|
||||
.map(|kb| kb.size(rems_from_px(12.))),
|
||||
)
|
||||
.on_click(cx.listener(|this, _event, _window, cx| {
|
||||
this.dismiss(cx)
|
||||
})),
|
||||
)
|
||||
.child(
|
||||
Button::new("configure-server", "Configure MCP")
|
||||
.disabled(configuration.waiting_for_context_server)
|
||||
.key_binding(
|
||||
KeyBinding::for_action_in(
|
||||
&menu::Confirm,
|
||||
&focus_handle,
|
||||
window,
|
||||
cx,
|
||||
)
|
||||
.map(|kb| kb.size(rems_from_px(12.))),
|
||||
)
|
||||
.on_click(cx.listener(|this, _event, _window, cx| {
|
||||
this.confirm(cx)
|
||||
})),
|
||||
),
|
||||
),
|
||||
),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn default_markdown_style(window: &Window, cx: &App) -> MarkdownStyle {
|
||||
let theme_settings = ThemeSettings::get_global(cx);
|
||||
let colors = cx.theme().colors();
|
||||
let mut text_style = window.text_style();
|
||||
text_style.refine(&TextStyleRefinement {
|
||||
font_family: Some(theme_settings.ui_font.family.clone()),
|
||||
font_fallbacks: theme_settings.ui_font.fallbacks.clone(),
|
||||
font_features: Some(theme_settings.ui_font.features.clone()),
|
||||
font_size: Some(TextSize::XSmall.rems(cx).into()),
|
||||
color: Some(colors.text_muted),
|
||||
..Default::default()
|
||||
});
|
||||
|
||||
MarkdownStyle {
|
||||
base_text_style: text_style.clone(),
|
||||
selection_background_color: cx.theme().players().local().selection,
|
||||
link: TextStyleRefinement {
|
||||
background_color: Some(colors.editor_foreground.opacity(0.025)),
|
||||
underline: Some(UnderlineStyle {
|
||||
color: Some(colors.text_accent.opacity(0.5)),
|
||||
thickness: px(1.),
|
||||
..Default::default()
|
||||
}),
|
||||
..Default::default()
|
||||
},
|
||||
..Default::default()
|
||||
}
|
||||
}
|
||||
|
||||
impl ModalView for ConfigureContextServerModal {}
|
||||
impl EventEmitter<DismissEvent> for ConfigureContextServerModal {}
|
||||
impl Focusable for ConfigureContextServerModal {
|
||||
fn focus_handle(&self, cx: &App) -> FocusHandle {
|
||||
if let Some(current) = self.context_servers_to_setup.first() {
|
||||
current.settings_editor.read(cx).focus_handle(cx)
|
||||
} else {
|
||||
cx.focus_handle()
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -401,11 +401,12 @@ impl AssistantPanel {
|
||||
}
|
||||
});
|
||||
|
||||
let thread_id = thread.read(cx).id().clone();
|
||||
let history_store = cx.new(|cx| {
|
||||
HistoryStore::new(
|
||||
thread_store.clone(),
|
||||
context_store.clone(),
|
||||
[RecentEntry::Thread(thread.clone())],
|
||||
[RecentEntry::Thread(thread_id, thread.clone())],
|
||||
cx,
|
||||
)
|
||||
});
|
||||
@@ -423,6 +424,7 @@ impl AssistantPanel {
|
||||
ActiveThread::new(
|
||||
thread.clone(),
|
||||
thread_store.clone(),
|
||||
message_editor_context_store.clone(),
|
||||
language_registry.clone(),
|
||||
workspace.clone(),
|
||||
window,
|
||||
@@ -456,7 +458,8 @@ impl AssistantPanel {
|
||||
|
||||
for entry in recently_opened.iter() {
|
||||
let summary = entry.summary(cx);
|
||||
menu = menu.entry_with_end_slot(
|
||||
|
||||
menu = menu.entry_with_end_slot_on_hover(
|
||||
summary,
|
||||
None,
|
||||
{
|
||||
@@ -467,7 +470,7 @@ impl AssistantPanel {
|
||||
.update(cx, {
|
||||
let entry = entry.clone();
|
||||
move |this, cx| match entry {
|
||||
RecentEntry::Thread(thread) => {
|
||||
RecentEntry::Thread(_, thread) => {
|
||||
this.open_thread(thread, window, cx)
|
||||
}
|
||||
RecentEntry::Context(context) => {
|
||||
@@ -625,7 +628,7 @@ impl AssistantPanel {
|
||||
let thread_view = ActiveView::thread(thread.clone(), window, cx);
|
||||
self.set_active_view(thread_view, window, cx);
|
||||
|
||||
let message_editor_context_store = cx.new(|_cx| {
|
||||
let context_store = cx.new(|_cx| {
|
||||
crate::context_store::ContextStore::new(
|
||||
self.project.downgrade(),
|
||||
Some(self.thread_store.downgrade()),
|
||||
@@ -638,7 +641,7 @@ impl AssistantPanel {
|
||||
.update(cx, |this, cx| this.open_thread(&other_thread_id, cx));
|
||||
|
||||
cx.spawn({
|
||||
let context_store = message_editor_context_store.clone();
|
||||
let context_store = context_store.clone();
|
||||
|
||||
async move |_panel, cx| {
|
||||
let other_thread = other_thread_task.await?;
|
||||
@@ -663,6 +666,7 @@ impl AssistantPanel {
|
||||
ActiveThread::new(
|
||||
thread.clone(),
|
||||
self.thread_store.clone(),
|
||||
context_store.clone(),
|
||||
self.language_registry.clone(),
|
||||
self.workspace.clone(),
|
||||
window,
|
||||
@@ -681,7 +685,7 @@ impl AssistantPanel {
|
||||
MessageEditor::new(
|
||||
self.fs.clone(),
|
||||
self.workspace.clone(),
|
||||
message_editor_context_store,
|
||||
context_store,
|
||||
self.prompt_store.clone(),
|
||||
self.thread_store.downgrade(),
|
||||
thread,
|
||||
@@ -842,7 +846,7 @@ impl AssistantPanel {
|
||||
) {
|
||||
let thread_view = ActiveView::thread(thread.clone(), window, cx);
|
||||
self.set_active_view(thread_view, window, cx);
|
||||
let message_editor_context_store = cx.new(|_cx| {
|
||||
let context_store = cx.new(|_cx| {
|
||||
crate::context_store::ContextStore::new(
|
||||
self.project.downgrade(),
|
||||
Some(self.thread_store.downgrade()),
|
||||
@@ -859,6 +863,7 @@ impl AssistantPanel {
|
||||
ActiveThread::new(
|
||||
thread.clone(),
|
||||
self.thread_store.clone(),
|
||||
context_store.clone(),
|
||||
self.language_registry.clone(),
|
||||
self.workspace.clone(),
|
||||
window,
|
||||
@@ -877,7 +882,7 @@ impl AssistantPanel {
|
||||
MessageEditor::new(
|
||||
self.fs.clone(),
|
||||
self.workspace.clone(),
|
||||
message_editor_context_store,
|
||||
context_store,
|
||||
self.prompt_store.clone(),
|
||||
self.thread_store.downgrade(),
|
||||
thread,
|
||||
@@ -1099,7 +1104,8 @@ impl AssistantPanel {
|
||||
ActiveView::Thread { thread, .. } => self.history_store.update(cx, |store, cx| {
|
||||
if let Some(thread) = thread.upgrade() {
|
||||
if thread.read(cx).is_empty() {
|
||||
store.remove_recently_opened_entry(&RecentEntry::Thread(thread), cx);
|
||||
let id = thread.read(cx).id().clone();
|
||||
store.remove_recently_opened_thread(id, cx);
|
||||
}
|
||||
}
|
||||
}),
|
||||
@@ -1109,7 +1115,8 @@ impl AssistantPanel {
|
||||
match &new_view {
|
||||
ActiveView::Thread { thread, .. } => self.history_store.update(cx, |store, cx| {
|
||||
if let Some(thread) = thread.upgrade() {
|
||||
store.push_recently_opened_entry(RecentEntry::Thread(thread), cx);
|
||||
let id = thread.read(cx).id().clone();
|
||||
store.push_recently_opened_entry(RecentEntry::Thread(id, thread), cx);
|
||||
}
|
||||
}),
|
||||
ActiveView::PromptEditor { context_editor, .. } => {
|
||||
|
||||
@@ -3,11 +3,12 @@ use std::hash::{Hash, Hasher};
|
||||
use std::path::PathBuf;
|
||||
use std::{ops::Range, path::Path, sync::Arc};
|
||||
|
||||
use assistant_tool::outline;
|
||||
use collections::HashSet;
|
||||
use futures::future;
|
||||
use futures::{FutureExt, future::Shared};
|
||||
use gpui::{App, AppContext as _, Entity, SharedString, Task};
|
||||
use language::Buffer;
|
||||
use language::{Buffer, ParseStatus};
|
||||
use language_model::{LanguageModelImage, LanguageModelRequestMessage, MessageContent};
|
||||
use project::{Project, ProjectEntryId, ProjectPath, Worktree};
|
||||
use prompt_store::{PromptStore, UserPromptId};
|
||||
@@ -152,6 +153,7 @@ pub struct FileContext {
|
||||
pub handle: FileContextHandle,
|
||||
pub full_path: Arc<Path>,
|
||||
pub text: SharedString,
|
||||
pub is_outline: bool,
|
||||
}
|
||||
|
||||
impl FileContextHandle {
|
||||
@@ -177,14 +179,51 @@ impl FileContextHandle {
|
||||
log::error!("file context missing path");
|
||||
return Task::ready(None);
|
||||
};
|
||||
let full_path = file.full_path(cx);
|
||||
let full_path: Arc<Path> = file.full_path(cx).into();
|
||||
let rope = buffer_ref.as_rope().clone();
|
||||
let buffer = self.buffer.clone();
|
||||
cx.background_spawn(async move {
|
||||
|
||||
cx.spawn(async move |cx| {
|
||||
// For large files, use outline instead of full content
|
||||
if rope.len() > outline::AUTO_OUTLINE_SIZE {
|
||||
// Wait until the buffer has been fully parsed, so we can read its outline
|
||||
if let Ok(mut parse_status) =
|
||||
buffer.read_with(cx, |buffer, _| buffer.parse_status())
|
||||
{
|
||||
while *parse_status.borrow() != ParseStatus::Idle {
|
||||
parse_status.changed().await.log_err();
|
||||
}
|
||||
|
||||
if let Ok(snapshot) = buffer.read_with(cx, |buffer, _| buffer.snapshot()) {
|
||||
if let Some(outline) = snapshot.outline(None) {
|
||||
let items = outline
|
||||
.items
|
||||
.into_iter()
|
||||
.map(|item| item.to_point(&snapshot));
|
||||
|
||||
if let Ok(outline_text) =
|
||||
outline::render_outline(items, None, 0, usize::MAX).await
|
||||
{
|
||||
let context = AgentContext::File(FileContext {
|
||||
handle: self,
|
||||
full_path,
|
||||
text: outline_text.into(),
|
||||
is_outline: true,
|
||||
});
|
||||
return Some((context, vec![buffer]));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback to full content if we couldn't build an outline
|
||||
// (or didn't need to because the file was small enough)
|
||||
let context = AgentContext::File(FileContext {
|
||||
handle: self,
|
||||
full_path: full_path.into(),
|
||||
full_path,
|
||||
text: rope.to_string().into(),
|
||||
is_outline: false,
|
||||
});
|
||||
Some((context, vec![buffer]))
|
||||
})
|
||||
@@ -804,7 +843,7 @@ pub fn load_context(
|
||||
text.push_str(
|
||||
"\n<context>\n\
|
||||
The following items were attached by the user. \
|
||||
You don't need to use other tools to read them.\n\n",
|
||||
They are up-to-date and don't need to be re-read.\n\n",
|
||||
);
|
||||
|
||||
if !file_context.is_empty() {
|
||||
@@ -996,3 +1035,115 @@ impl Hash for AgentContextKey {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use gpui::TestAppContext;
|
||||
use project::{FakeFs, Project};
|
||||
use serde_json::json;
|
||||
use settings::SettingsStore;
|
||||
use util::path;
|
||||
|
||||
fn init_test_settings(cx: &mut TestAppContext) {
|
||||
cx.update(|cx| {
|
||||
let settings_store = SettingsStore::test(cx);
|
||||
cx.set_global(settings_store);
|
||||
language::init(cx);
|
||||
Project::init_settings(cx);
|
||||
});
|
||||
}
|
||||
|
||||
// Helper to create a test project with test files
|
||||
async fn create_test_project(
|
||||
cx: &mut TestAppContext,
|
||||
files: serde_json::Value,
|
||||
) -> Entity<Project> {
|
||||
let fs = FakeFs::new(cx.background_executor.clone());
|
||||
fs.insert_tree(path!("/test"), files).await;
|
||||
Project::test(fs, [path!("/test").as_ref()], cx).await
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_large_file_uses_outline(cx: &mut TestAppContext) {
|
||||
init_test_settings(cx);
|
||||
|
||||
// Create a large file that exceeds AUTO_OUTLINE_SIZE
|
||||
const LINE: &str = "Line with some text\n";
|
||||
let large_content = LINE.repeat(2 * (outline::AUTO_OUTLINE_SIZE / LINE.len()));
|
||||
let content_len = large_content.len();
|
||||
|
||||
assert!(content_len > outline::AUTO_OUTLINE_SIZE);
|
||||
|
||||
let file_context = file_context_for(large_content, cx).await;
|
||||
|
||||
assert!(
|
||||
file_context.is_outline,
|
||||
"Large file should use outline format"
|
||||
);
|
||||
|
||||
assert!(
|
||||
file_context.text.len() < content_len,
|
||||
"Outline should be smaller than original content"
|
||||
);
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_small_file_uses_full_content(cx: &mut TestAppContext) {
|
||||
init_test_settings(cx);
|
||||
|
||||
let small_content = "This is a small file.\n";
|
||||
let content_len = small_content.len();
|
||||
|
||||
assert!(content_len < outline::AUTO_OUTLINE_SIZE);
|
||||
|
||||
let file_context = file_context_for(small_content.to_string(), cx).await;
|
||||
|
||||
assert!(
|
||||
!file_context.is_outline,
|
||||
"Small files should not get an outline"
|
||||
);
|
||||
|
||||
assert_eq!(file_context.text, small_content);
|
||||
}
|
||||
|
||||
async fn file_context_for(content: String, cx: &mut TestAppContext) -> FileContext {
|
||||
// Create a test project with the file
|
||||
let project = create_test_project(
|
||||
cx,
|
||||
json!({
|
||||
"file.txt": content,
|
||||
}),
|
||||
)
|
||||
.await;
|
||||
|
||||
// Open the buffer
|
||||
let buffer_path = project
|
||||
.read_with(cx, |project, cx| project.find_project_path("file.txt", cx))
|
||||
.unwrap();
|
||||
|
||||
let buffer = project
|
||||
.update(cx, |project, cx| project.open_buffer(buffer_path, cx))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let context_handle = AgentContextHandle::File(FileContextHandle {
|
||||
buffer: buffer.clone(),
|
||||
context_id: ContextId::zero(),
|
||||
});
|
||||
|
||||
cx.update(|cx| load_context(vec![context_handle], &project, &None, cx))
|
||||
.await
|
||||
.loaded_context
|
||||
.contexts
|
||||
.into_iter()
|
||||
.find_map(|ctx| {
|
||||
if let AgentContext::File(file_ctx) = ctx {
|
||||
Some(file_ctx)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.expect("Should have found a file context")
|
||||
}
|
||||
}
|
||||
|
||||
120
crates/agent/src/context_server_configuration.rs
Normal file
120
crates/agent/src/context_server_configuration.rs
Normal file
@@ -0,0 +1,120 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use anyhow::Context as _;
|
||||
use context_server::ContextServerDescriptorRegistry;
|
||||
use extension::ExtensionManifest;
|
||||
use language::LanguageRegistry;
|
||||
use ui::prelude::*;
|
||||
use util::ResultExt;
|
||||
use workspace::Workspace;
|
||||
|
||||
use crate::{AssistantPanel, assistant_configuration::ConfigureContextServerModal};
|
||||
|
||||
pub(crate) fn init(language_registry: Arc<LanguageRegistry>, cx: &mut App) {
|
||||
cx.observe_new(move |_: &mut Workspace, window, cx| {
|
||||
let Some(window) = window else {
|
||||
return;
|
||||
};
|
||||
|
||||
if let Some(extension_events) = extension::ExtensionEvents::try_global(cx).as_ref() {
|
||||
cx.subscribe_in(extension_events, window, {
|
||||
let language_registry = language_registry.clone();
|
||||
move |workspace, _, event, window, cx| match event {
|
||||
extension::Event::ExtensionInstalled(manifest) => {
|
||||
show_configure_mcp_modal(
|
||||
language_registry.clone(),
|
||||
manifest,
|
||||
workspace,
|
||||
window,
|
||||
cx,
|
||||
);
|
||||
}
|
||||
extension::Event::ConfigureExtensionRequested(manifest) => {
|
||||
if !manifest.context_servers.is_empty() {
|
||||
show_configure_mcp_modal(
|
||||
language_registry.clone(),
|
||||
manifest,
|
||||
workspace,
|
||||
window,
|
||||
cx,
|
||||
);
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
})
|
||||
.detach();
|
||||
} else {
|
||||
log::info!(
|
||||
"No extension events global found. Skipping context server configuration wizard"
|
||||
);
|
||||
}
|
||||
})
|
||||
.detach();
|
||||
}
|
||||
|
||||
fn show_configure_mcp_modal(
|
||||
language_registry: Arc<LanguageRegistry>,
|
||||
manifest: &Arc<ExtensionManifest>,
|
||||
workspace: &mut Workspace,
|
||||
window: &mut Window,
|
||||
cx: &mut Context<'_, Workspace>,
|
||||
) {
|
||||
let Some(context_server_manager) = workspace.panel::<AssistantPanel>(cx).map(|panel| {
|
||||
panel
|
||||
.read(cx)
|
||||
.thread_store()
|
||||
.read(cx)
|
||||
.context_server_manager()
|
||||
}) else {
|
||||
return;
|
||||
};
|
||||
|
||||
let registry = ContextServerDescriptorRegistry::global(cx).read(cx);
|
||||
let project = workspace.project().clone();
|
||||
let configuration_tasks = manifest
|
||||
.context_servers
|
||||
.keys()
|
||||
.cloned()
|
||||
.filter_map({
|
||||
|key| {
|
||||
let descriptor = registry.context_server_descriptor(&key)?;
|
||||
Some(cx.spawn({
|
||||
let project = project.clone();
|
||||
async move |_, cx| {
|
||||
descriptor
|
||||
.configuration(project, &cx)
|
||||
.await
|
||||
.context("Failed to resolve context server configuration")
|
||||
.log_err()
|
||||
.flatten()
|
||||
.map(|config| (key, config))
|
||||
}
|
||||
}))
|
||||
}
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let jsonc_language = language_registry.language_for_name("jsonc");
|
||||
|
||||
cx.spawn_in(window, async move |this, cx| {
|
||||
let descriptors = futures::future::join_all(configuration_tasks).await;
|
||||
let jsonc_language = jsonc_language.await.ok();
|
||||
|
||||
this.update_in(cx, |this, window, cx| {
|
||||
let modal = ConfigureContextServerModal::new(
|
||||
descriptors.into_iter().flatten(),
|
||||
jsonc_language,
|
||||
context_server_manager,
|
||||
language_registry,
|
||||
cx.entity().downgrade(),
|
||||
window,
|
||||
cx,
|
||||
);
|
||||
if let Some(modal) = modal {
|
||||
this.toggle_modal(window, cx, |_, _| modal);
|
||||
}
|
||||
})
|
||||
})
|
||||
.detach();
|
||||
}
|
||||
@@ -21,7 +21,7 @@ use crate::context::{
|
||||
SymbolContextHandle, ThreadContextHandle,
|
||||
};
|
||||
use crate::context_strip::SuggestedContext;
|
||||
use crate::thread::{Thread, ThreadId};
|
||||
use crate::thread::{MessageId, Thread, ThreadId};
|
||||
|
||||
pub struct ContextStore {
|
||||
project: WeakEntity<Project>,
|
||||
@@ -54,9 +54,14 @@ impl ContextStore {
|
||||
self.context_thread_ids.clear();
|
||||
}
|
||||
|
||||
pub fn new_context_for_thread(&self, thread: &Thread) -> Vec<AgentContextHandle> {
|
||||
pub fn new_context_for_thread(
|
||||
&self,
|
||||
thread: &Thread,
|
||||
exclude_messages_from_id: Option<MessageId>,
|
||||
) -> Vec<AgentContextHandle> {
|
||||
let existing_context = thread
|
||||
.messages()
|
||||
.take_while(|message| exclude_messages_from_id.is_none_or(|id| message.id != id))
|
||||
.flat_map(|message| {
|
||||
message
|
||||
.loaded_context
|
||||
|
||||
@@ -36,16 +36,28 @@ impl HistoryEntry {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, PartialEq, Eq)]
|
||||
#[derive(Clone, Debug)]
|
||||
pub(crate) enum RecentEntry {
|
||||
Thread(Entity<Thread>),
|
||||
Thread(ThreadId, Entity<Thread>),
|
||||
Context(Entity<AssistantContext>),
|
||||
}
|
||||
|
||||
impl PartialEq for RecentEntry {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
match (self, other) {
|
||||
(Self::Thread(l0, _), Self::Thread(r0, _)) => l0 == r0,
|
||||
(Self::Context(l0), Self::Context(r0)) => l0 == r0,
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Eq for RecentEntry {}
|
||||
|
||||
impl RecentEntry {
|
||||
pub(crate) fn summary(&self, cx: &App) -> SharedString {
|
||||
match self {
|
||||
RecentEntry::Thread(thread) => thread.read(cx).summary_or_default(),
|
||||
RecentEntry::Thread(_, thread) => thread.read(cx).summary_or_default(),
|
||||
RecentEntry::Context(context) => context.read(cx).summary_or_default(),
|
||||
}
|
||||
}
|
||||
@@ -93,9 +105,10 @@ impl HistoryStore {
|
||||
.map(|serialized| match serialized {
|
||||
SerializedRecentEntry::Thread(id) => thread_store
|
||||
.update(cx, |thread_store, cx| {
|
||||
let thread_id = ThreadId::from(id.as_str());
|
||||
thread_store
|
||||
.open_thread(&ThreadId::from(id.as_str()), cx)
|
||||
.map_ok(RecentEntry::Thread)
|
||||
.open_thread(&thread_id, cx)
|
||||
.map_ok(|thread| RecentEntry::Thread(thread_id, thread))
|
||||
.boxed()
|
||||
})
|
||||
.unwrap_or_else(|_| async { Err(anyhow!("no thread store")) }.boxed()),
|
||||
@@ -170,9 +183,7 @@ impl HistoryStore {
|
||||
RecentEntry::Context(context) => Some(SerializedRecentEntry::Context(
|
||||
context.read(cx).path()?.to_str()?.to_owned(),
|
||||
)),
|
||||
RecentEntry::Thread(thread) => Some(SerializedRecentEntry::Thread(
|
||||
thread.read(cx).id().to_string(),
|
||||
)),
|
||||
RecentEntry::Thread(id, _) => Some(SerializedRecentEntry::Thread(id.to_string())),
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
@@ -200,6 +211,14 @@ impl HistoryStore {
|
||||
self.save_recently_opened_entries(cx);
|
||||
}
|
||||
|
||||
pub fn remove_recently_opened_thread(&mut self, id: ThreadId, cx: &mut Context<Self>) {
|
||||
self.recently_opened_entries.retain(|entry| match entry {
|
||||
RecentEntry::Thread(thread_id, _) if thread_id == &id => false,
|
||||
_ => true,
|
||||
});
|
||||
self.save_recently_opened_entries(cx);
|
||||
}
|
||||
|
||||
pub fn remove_recently_opened_entry(&mut self, entry: &RecentEntry, cx: &mut Context<Self>) {
|
||||
self.recently_opened_entries
|
||||
.retain(|old_entry| old_entry != entry);
|
||||
|
||||
@@ -4,7 +4,7 @@ use std::sync::Arc;
|
||||
use crate::assistant_model_selector::{AssistantModelSelector, ModelType};
|
||||
use crate::context::{ContextLoadResult, load_context};
|
||||
use crate::tool_compatibility::{IncompatibleToolsState, IncompatibleToolsTooltip};
|
||||
use crate::ui::AnimatedLabel;
|
||||
use crate::ui::{AgentPreview, AnimatedLabel};
|
||||
use buffer_diff::BufferDiff;
|
||||
use collections::HashSet;
|
||||
use editor::actions::{MoveUp, Paste};
|
||||
@@ -42,10 +42,11 @@ use crate::profile_selector::ProfileSelector;
|
||||
use crate::thread::{Thread, TokenUsageRatio};
|
||||
use crate::thread_store::ThreadStore;
|
||||
use crate::{
|
||||
AgentDiff, Chat, ChatMode, ExpandMessageEditor, NewThread, OpenAgentDiff, RemoveAllContext,
|
||||
ToggleContextPicker, ToggleProfileSelector,
|
||||
ActiveThread, AgentDiff, Chat, ExpandMessageEditor, NewThread, OpenAgentDiff, RemoveAllContext,
|
||||
ToggleContextPicker, ToggleProfileSelector, register_agent_preview,
|
||||
};
|
||||
|
||||
#[derive(RegisterComponent)]
|
||||
pub struct MessageEditor {
|
||||
thread: Entity<Thread>,
|
||||
incompatible_tools_state: Entity<IncompatibleToolsState>,
|
||||
@@ -69,6 +70,56 @@ pub struct MessageEditor {
|
||||
|
||||
const MAX_EDITOR_LINES: usize = 8;
|
||||
|
||||
pub(crate) fn create_editor(
|
||||
workspace: WeakEntity<Workspace>,
|
||||
context_store: WeakEntity<ContextStore>,
|
||||
thread_store: WeakEntity<ThreadStore>,
|
||||
window: &mut Window,
|
||||
cx: &mut App,
|
||||
) -> Entity<Editor> {
|
||||
let language = Language::new(
|
||||
language::LanguageConfig {
|
||||
completion_query_characters: HashSet::from_iter(['.', '-', '_', '@']),
|
||||
..Default::default()
|
||||
},
|
||||
None,
|
||||
);
|
||||
|
||||
let editor = cx.new(|cx| {
|
||||
let buffer = cx.new(|cx| Buffer::local("", cx).with_language(Arc::new(language), cx));
|
||||
let buffer = cx.new(|cx| MultiBuffer::singleton(buffer, cx));
|
||||
let mut editor = Editor::new(
|
||||
editor::EditorMode::AutoHeight {
|
||||
max_lines: MAX_EDITOR_LINES,
|
||||
},
|
||||
buffer,
|
||||
None,
|
||||
window,
|
||||
cx,
|
||||
);
|
||||
editor.set_placeholder_text("Ask anything, @ to mention, ↑ to select", cx);
|
||||
editor.set_show_indent_guides(false, cx);
|
||||
editor.set_soft_wrap();
|
||||
editor.set_context_menu_options(ContextMenuOptions {
|
||||
min_entries_visible: 12,
|
||||
max_entries_visible: 12,
|
||||
placement: Some(ContextMenuPlacement::Above),
|
||||
});
|
||||
editor
|
||||
});
|
||||
|
||||
let editor_entity = editor.downgrade();
|
||||
editor.update(cx, |editor, _| {
|
||||
editor.set_completion_provider(Some(Box::new(ContextPickerCompletionProvider::new(
|
||||
workspace,
|
||||
context_store,
|
||||
Some(thread_store),
|
||||
editor_entity,
|
||||
))));
|
||||
});
|
||||
editor
|
||||
}
|
||||
|
||||
impl MessageEditor {
|
||||
pub fn new(
|
||||
fs: Arc<dyn Fs>,
|
||||
@@ -83,47 +134,14 @@ impl MessageEditor {
|
||||
let context_picker_menu_handle = PopoverMenuHandle::default();
|
||||
let model_selector_menu_handle = PopoverMenuHandle::default();
|
||||
|
||||
let language = Language::new(
|
||||
language::LanguageConfig {
|
||||
completion_query_characters: HashSet::from_iter(['.', '-', '_', '@']),
|
||||
..Default::default()
|
||||
},
|
||||
None,
|
||||
let editor = create_editor(
|
||||
workspace.clone(),
|
||||
context_store.downgrade(),
|
||||
thread_store.clone(),
|
||||
window,
|
||||
cx,
|
||||
);
|
||||
|
||||
let editor = cx.new(|cx| {
|
||||
let buffer = cx.new(|cx| Buffer::local("", cx).with_language(Arc::new(language), cx));
|
||||
let buffer = cx.new(|cx| MultiBuffer::singleton(buffer, cx));
|
||||
let mut editor = Editor::new(
|
||||
editor::EditorMode::AutoHeight {
|
||||
max_lines: MAX_EDITOR_LINES,
|
||||
},
|
||||
buffer,
|
||||
None,
|
||||
window,
|
||||
cx,
|
||||
);
|
||||
editor.set_placeholder_text("Ask anything, @ to mention, ↑ to select", cx);
|
||||
editor.set_show_indent_guides(false, cx);
|
||||
editor.set_soft_wrap();
|
||||
editor.set_context_menu_options(ContextMenuOptions {
|
||||
min_entries_visible: 12,
|
||||
max_entries_visible: 12,
|
||||
placement: Some(ContextMenuPlacement::Above),
|
||||
});
|
||||
editor
|
||||
});
|
||||
|
||||
let editor_entity = editor.downgrade();
|
||||
editor.update(cx, |editor, _| {
|
||||
editor.set_completion_provider(Some(Box::new(ContextPickerCompletionProvider::new(
|
||||
workspace.clone(),
|
||||
context_store.downgrade(),
|
||||
Some(thread_store.clone()),
|
||||
editor_entity,
|
||||
))));
|
||||
});
|
||||
|
||||
let context_strip = cx.new(|cx| {
|
||||
ContextStrip::new(
|
||||
context_store.clone(),
|
||||
@@ -189,10 +207,6 @@ impl MessageEditor {
|
||||
&self.context_store
|
||||
}
|
||||
|
||||
fn toggle_chat_mode(&mut self, _: &ChatMode, _window: &mut Window, cx: &mut Context<Self>) {
|
||||
cx.notify();
|
||||
}
|
||||
|
||||
pub fn expand_message_editor(
|
||||
&mut self,
|
||||
_: &ExpandMessageEditor,
|
||||
@@ -415,12 +429,13 @@ impl MessageEditor {
|
||||
Some(
|
||||
IconButton::new("max-mode", IconName::ZedMaxMode)
|
||||
.icon_size(IconSize::Small)
|
||||
.toggle_state(active_completion_mode == Some(CompletionMode::Max))
|
||||
.icon_color(Color::Muted)
|
||||
.toggle_state(active_completion_mode == CompletionMode::Max)
|
||||
.on_click(cx.listener(move |this, _event, _window, cx| {
|
||||
this.thread.update(cx, |thread, _cx| {
|
||||
thread.set_completion_mode(match active_completion_mode {
|
||||
Some(CompletionMode::Max) => Some(CompletionMode::Normal),
|
||||
Some(CompletionMode::Normal) | None => Some(CompletionMode::Max),
|
||||
CompletionMode::Max => CompletionMode::Normal,
|
||||
CompletionMode::Normal => CompletionMode::Max,
|
||||
});
|
||||
});
|
||||
}))
|
||||
@@ -482,7 +497,6 @@ impl MessageEditor {
|
||||
.on_action(cx.listener(Self::toggle_context_picker))
|
||||
.on_action(cx.listener(Self::remove_all_context))
|
||||
.on_action(cx.listener(Self::move_up))
|
||||
.on_action(cx.listener(Self::toggle_chat_mode))
|
||||
.on_action(cx.listener(Self::expand_message_editor))
|
||||
.capture_action(cx.listener(Self::paste))
|
||||
.gap_2()
|
||||
@@ -1041,7 +1055,7 @@ impl MessageEditor {
|
||||
let load_task = cx.spawn(async move |this, cx| {
|
||||
let Ok(load_task) = this.update(cx, |this, cx| {
|
||||
let new_context = this.context_store.read_with(cx, |context_store, cx| {
|
||||
context_store.new_context_for_thread(this.thread.read(cx))
|
||||
context_store.new_context_for_thread(this.thread.read(cx), None)
|
||||
});
|
||||
load_context(new_context, &this.project, &this.prompt_store, cx)
|
||||
}) else {
|
||||
@@ -1189,3 +1203,53 @@ impl Render for MessageEditor {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Component for MessageEditor {
|
||||
fn scope() -> ComponentScope {
|
||||
ComponentScope::Agent
|
||||
}
|
||||
}
|
||||
|
||||
impl AgentPreview for MessageEditor {
|
||||
fn create_preview(
|
||||
workspace: WeakEntity<Workspace>,
|
||||
active_thread: Entity<ActiveThread>,
|
||||
thread_store: WeakEntity<ThreadStore>,
|
||||
window: &mut Window,
|
||||
cx: &mut App,
|
||||
) -> Option<AnyElement> {
|
||||
if let Some(workspace_entity) = workspace.upgrade() {
|
||||
let fs = workspace_entity.read(cx).app_state().fs.clone();
|
||||
let weak_project = workspace_entity.read(cx).project().clone().downgrade();
|
||||
let context_store = cx.new(|_cx| ContextStore::new(weak_project, None));
|
||||
let thread = active_thread.read(cx).thread().clone();
|
||||
|
||||
let example_message_editor = cx.new(|cx| {
|
||||
MessageEditor::new(
|
||||
fs,
|
||||
workspace,
|
||||
context_store,
|
||||
None,
|
||||
thread_store,
|
||||
thread,
|
||||
window,
|
||||
cx,
|
||||
)
|
||||
});
|
||||
|
||||
Some(
|
||||
v_flex()
|
||||
.gap_4()
|
||||
.children(vec![single_example(
|
||||
"Default",
|
||||
example_message_editor.clone().into_any_element(),
|
||||
)])
|
||||
.into_any_element(),
|
||||
)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
register_agent_preview!(MessageEditor);
|
||||
|
||||
@@ -301,6 +301,14 @@ pub enum TokenUsageRatio {
|
||||
Exceeded,
|
||||
}
|
||||
|
||||
fn default_completion_mode(cx: &App) -> CompletionMode {
|
||||
if cx.is_staff() {
|
||||
CompletionMode::Max
|
||||
} else {
|
||||
CompletionMode::Normal
|
||||
}
|
||||
}
|
||||
|
||||
/// A thread of conversation with the LLM.
|
||||
pub struct Thread {
|
||||
id: ThreadId,
|
||||
@@ -310,7 +318,7 @@ pub struct Thread {
|
||||
detailed_summary_task: Task<Option<()>>,
|
||||
detailed_summary_tx: postage::watch::Sender<DetailedSummaryState>,
|
||||
detailed_summary_rx: postage::watch::Receiver<DetailedSummaryState>,
|
||||
completion_mode: Option<CompletionMode>,
|
||||
completion_mode: CompletionMode,
|
||||
messages: Vec<Message>,
|
||||
next_message_id: MessageId,
|
||||
last_prompt_id: PromptId,
|
||||
@@ -366,7 +374,7 @@ impl Thread {
|
||||
detailed_summary_task: Task::ready(None),
|
||||
detailed_summary_tx,
|
||||
detailed_summary_rx,
|
||||
completion_mode: None,
|
||||
completion_mode: default_completion_mode(cx),
|
||||
messages: Vec::new(),
|
||||
next_message_id: MessageId(0),
|
||||
last_prompt_id: PromptId::new(),
|
||||
@@ -440,7 +448,7 @@ impl Thread {
|
||||
detailed_summary_task: Task::ready(None),
|
||||
detailed_summary_tx,
|
||||
detailed_summary_rx,
|
||||
completion_mode: None,
|
||||
completion_mode: default_completion_mode(cx),
|
||||
messages: serialized
|
||||
.messages
|
||||
.into_iter()
|
||||
@@ -569,11 +577,11 @@ impl Thread {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn completion_mode(&self) -> Option<CompletionMode> {
|
||||
pub fn completion_mode(&self) -> CompletionMode {
|
||||
self.completion_mode
|
||||
}
|
||||
|
||||
pub fn set_completion_mode(&mut self, mode: Option<CompletionMode>) {
|
||||
pub fn set_completion_mode(&mut self, mode: CompletionMode) {
|
||||
self.completion_mode = mode;
|
||||
}
|
||||
|
||||
@@ -879,6 +887,7 @@ impl Thread {
|
||||
id: MessageId,
|
||||
new_role: Role,
|
||||
new_segments: Vec<MessageSegment>,
|
||||
loaded_context: Option<LoadedContext>,
|
||||
cx: &mut Context<Self>,
|
||||
) -> bool {
|
||||
let Some(message) = self.messages.iter_mut().find(|message| message.id == id) else {
|
||||
@@ -886,6 +895,9 @@ impl Thread {
|
||||
};
|
||||
message.role = new_role;
|
||||
message.segments = new_segments;
|
||||
if let Some(context) = loaded_context {
|
||||
message.loaded_context = context;
|
||||
}
|
||||
self.touch_updated_at();
|
||||
cx.emit(ThreadEvent::MessageEdited(id));
|
||||
true
|
||||
@@ -1148,9 +1160,9 @@ impl Thread {
|
||||
|
||||
request.tools = available_tools;
|
||||
request.mode = if model.supports_max_mode() {
|
||||
self.completion_mode
|
||||
Some(self.completion_mode)
|
||||
} else {
|
||||
None
|
||||
Some(CompletionMode::Normal)
|
||||
};
|
||||
|
||||
request
|
||||
@@ -2106,7 +2118,7 @@ impl Thread {
|
||||
.map(|repo| {
|
||||
repo.update(cx, |repo, _| {
|
||||
let current_branch =
|
||||
repo.branch.as_ref().map(|branch| branch.name.to_string());
|
||||
repo.branch.as_ref().map(|branch| branch.name().to_owned());
|
||||
repo.send_job(None, |state, _| async move {
|
||||
let RepositoryState::Local { backend, .. } = state else {
|
||||
return GitState {
|
||||
@@ -2505,7 +2517,7 @@ mod tests {
|
||||
let expected_context = format!(
|
||||
r#"
|
||||
<context>
|
||||
The following items were attached by the user. You don't need to use other tools to read them.
|
||||
The following items were attached by the user. They are up-to-date and don't need to be re-read.
|
||||
|
||||
<files>
|
||||
```rs {path_part}
|
||||
@@ -2546,6 +2558,7 @@ fn main() {{
|
||||
"file1.rs": "fn function1() {}\n",
|
||||
"file2.rs": "fn function2() {}\n",
|
||||
"file3.rs": "fn function3() {}\n",
|
||||
"file4.rs": "fn function4() {}\n",
|
||||
}),
|
||||
)
|
||||
.await;
|
||||
@@ -2558,7 +2571,7 @@ fn main() {{
|
||||
.await
|
||||
.unwrap();
|
||||
let new_contexts = context_store.update(cx, |store, cx| {
|
||||
store.new_context_for_thread(thread.read(cx))
|
||||
store.new_context_for_thread(thread.read(cx), None)
|
||||
});
|
||||
assert_eq!(new_contexts.len(), 1);
|
||||
let loaded_context = cx
|
||||
@@ -2573,7 +2586,7 @@ fn main() {{
|
||||
.await
|
||||
.unwrap();
|
||||
let new_contexts = context_store.update(cx, |store, cx| {
|
||||
store.new_context_for_thread(thread.read(cx))
|
||||
store.new_context_for_thread(thread.read(cx), None)
|
||||
});
|
||||
assert_eq!(new_contexts.len(), 1);
|
||||
let loaded_context = cx
|
||||
@@ -2589,7 +2602,7 @@ fn main() {{
|
||||
.await
|
||||
.unwrap();
|
||||
let new_contexts = context_store.update(cx, |store, cx| {
|
||||
store.new_context_for_thread(thread.read(cx))
|
||||
store.new_context_for_thread(thread.read(cx), None)
|
||||
});
|
||||
assert_eq!(new_contexts.len(), 1);
|
||||
let loaded_context = cx
|
||||
@@ -2640,6 +2653,55 @@ fn main() {{
|
||||
assert!(!request.messages[3].string_contents().contains("file1.rs"));
|
||||
assert!(!request.messages[3].string_contents().contains("file2.rs"));
|
||||
assert!(request.messages[3].string_contents().contains("file3.rs"));
|
||||
|
||||
add_file_to_context(&project, &context_store, "test/file4.rs", cx)
|
||||
.await
|
||||
.unwrap();
|
||||
let new_contexts = context_store.update(cx, |store, cx| {
|
||||
store.new_context_for_thread(thread.read(cx), Some(message2_id))
|
||||
});
|
||||
assert_eq!(new_contexts.len(), 3);
|
||||
let loaded_context = cx
|
||||
.update(|cx| load_context(new_contexts, &project, &None, cx))
|
||||
.await
|
||||
.loaded_context;
|
||||
|
||||
assert!(!loaded_context.text.contains("file1.rs"));
|
||||
assert!(loaded_context.text.contains("file2.rs"));
|
||||
assert!(loaded_context.text.contains("file3.rs"));
|
||||
assert!(loaded_context.text.contains("file4.rs"));
|
||||
|
||||
let new_contexts = context_store.update(cx, |store, cx| {
|
||||
// Remove file4.rs
|
||||
store.remove_context(&loaded_context.contexts[2].handle(), cx);
|
||||
store.new_context_for_thread(thread.read(cx), Some(message2_id))
|
||||
});
|
||||
assert_eq!(new_contexts.len(), 2);
|
||||
let loaded_context = cx
|
||||
.update(|cx| load_context(new_contexts, &project, &None, cx))
|
||||
.await
|
||||
.loaded_context;
|
||||
|
||||
assert!(!loaded_context.text.contains("file1.rs"));
|
||||
assert!(loaded_context.text.contains("file2.rs"));
|
||||
assert!(loaded_context.text.contains("file3.rs"));
|
||||
assert!(!loaded_context.text.contains("file4.rs"));
|
||||
|
||||
let new_contexts = context_store.update(cx, |store, cx| {
|
||||
// Remove file3.rs
|
||||
store.remove_context(&loaded_context.contexts[1].handle(), cx);
|
||||
store.new_context_for_thread(thread.read(cx), Some(message2_id))
|
||||
});
|
||||
assert_eq!(new_contexts.len(), 1);
|
||||
let loaded_context = cx
|
||||
.update(|cx| load_context(new_contexts, &project, &None, cx))
|
||||
.await
|
||||
.loaded_context;
|
||||
|
||||
assert!(!loaded_context.text.contains("file1.rs"));
|
||||
assert!(loaded_context.text.contains("file2.rs"));
|
||||
assert!(!loaded_context.text.contains("file3.rs"));
|
||||
assert!(!loaded_context.text.contains("file4.rs"));
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
|
||||
@@ -9,8 +9,8 @@ use assistant_settings::{AgentProfile, AgentProfileId, AssistantSettings};
|
||||
use assistant_tool::{ToolId, ToolSource, ToolWorkingSet};
|
||||
use chrono::{DateTime, Utc};
|
||||
use collections::HashMap;
|
||||
use context_server::manager::ContextServerManager;
|
||||
use context_server::{ContextServerFactoryRegistry, ContextServerTool};
|
||||
use context_server::manager::{ContextServerManager, ContextServerStatus};
|
||||
use context_server::{ContextServerDescriptorRegistry, ContextServerTool};
|
||||
use futures::channel::{mpsc, oneshot};
|
||||
use futures::future::{self, BoxFuture, Shared};
|
||||
use futures::{FutureExt as _, StreamExt as _};
|
||||
@@ -108,7 +108,7 @@ impl ThreadStore {
|
||||
prompt_store: Option<Entity<PromptStore>>,
|
||||
cx: &mut Context<Self>,
|
||||
) -> (Self, oneshot::Receiver<()>) {
|
||||
let context_server_factory_registry = ContextServerFactoryRegistry::default_global(cx);
|
||||
let context_server_factory_registry = ContextServerDescriptorRegistry::default_global(cx);
|
||||
let context_server_manager = cx.new(|cx| {
|
||||
ContextServerManager::new(context_server_factory_registry, project.clone(), cx)
|
||||
});
|
||||
@@ -555,62 +555,68 @@ impl ThreadStore {
|
||||
) {
|
||||
let tool_working_set = self.tools.clone();
|
||||
match event {
|
||||
context_server::manager::Event::ServerStarted { server_id } => {
|
||||
if let Some(server) = context_server_manager.read(cx).get_server(server_id) {
|
||||
let context_server_manager = context_server_manager.clone();
|
||||
cx.spawn({
|
||||
let server = server.clone();
|
||||
let server_id = server_id.clone();
|
||||
async move |this, cx| {
|
||||
let Some(protocol) = server.client() else {
|
||||
return;
|
||||
};
|
||||
context_server::manager::Event::ServerStatusChanged { server_id, status } => {
|
||||
match status {
|
||||
Some(ContextServerStatus::Running) => {
|
||||
if let Some(server) = context_server_manager.read(cx).get_server(server_id)
|
||||
{
|
||||
let context_server_manager = context_server_manager.clone();
|
||||
cx.spawn({
|
||||
let server = server.clone();
|
||||
let server_id = server_id.clone();
|
||||
async move |this, cx| {
|
||||
let Some(protocol) = server.client() else {
|
||||
return;
|
||||
};
|
||||
|
||||
if protocol.capable(context_server::protocol::ServerCapability::Tools) {
|
||||
if let Some(tools) = protocol.list_tools().await.log_err() {
|
||||
let tool_ids = tool_working_set
|
||||
.update(cx, |tool_working_set, _| {
|
||||
tools
|
||||
.tools
|
||||
.into_iter()
|
||||
.map(|tool| {
|
||||
log::info!(
|
||||
"registering context server tool: {:?}",
|
||||
tool.name
|
||||
);
|
||||
tool_working_set.insert(Arc::new(
|
||||
ContextServerTool::new(
|
||||
context_server_manager.clone(),
|
||||
server.id(),
|
||||
tool,
|
||||
),
|
||||
))
|
||||
if protocol.capable(context_server::protocol::ServerCapability::Tools) {
|
||||
if let Some(tools) = protocol.list_tools().await.log_err() {
|
||||
let tool_ids = tool_working_set
|
||||
.update(cx, |tool_working_set, _| {
|
||||
tools
|
||||
.tools
|
||||
.into_iter()
|
||||
.map(|tool| {
|
||||
log::info!(
|
||||
"registering context server tool: {:?}",
|
||||
tool.name
|
||||
);
|
||||
tool_working_set.insert(Arc::new(
|
||||
ContextServerTool::new(
|
||||
context_server_manager.clone(),
|
||||
server.id(),
|
||||
tool,
|
||||
),
|
||||
))
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
})
|
||||
.log_err();
|
||||
.log_err();
|
||||
|
||||
if let Some(tool_ids) = tool_ids {
|
||||
this.update(cx, |this, cx| {
|
||||
this.context_server_tool_ids
|
||||
.insert(server_id, tool_ids);
|
||||
this.load_default_profile(cx);
|
||||
})
|
||||
.log_err();
|
||||
if let Some(tool_ids) = tool_ids {
|
||||
this.update(cx, |this, cx| {
|
||||
this.context_server_tool_ids
|
||||
.insert(server_id, tool_ids);
|
||||
this.load_default_profile(cx);
|
||||
})
|
||||
.log_err();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
.detach();
|
||||
}
|
||||
})
|
||||
.detach();
|
||||
}
|
||||
}
|
||||
context_server::manager::Event::ServerStopped { server_id } => {
|
||||
if let Some(tool_ids) = self.context_server_tool_ids.remove(server_id) {
|
||||
tool_working_set.update(cx, |tool_working_set, _| {
|
||||
tool_working_set.remove(&tool_ids);
|
||||
});
|
||||
self.load_default_profile(cx);
|
||||
}
|
||||
None => {
|
||||
if let Some(tool_ids) = self.context_server_tool_ids.remove(server_id) {
|
||||
tool_working_set.update(cx, |tool_working_set, _| {
|
||||
tool_working_set.remove(&tool_ids);
|
||||
});
|
||||
self.load_default_profile(cx);
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,9 +1,12 @@
|
||||
mod agent_notification;
|
||||
pub mod agent_preview;
|
||||
mod animated_label;
|
||||
mod context_pill;
|
||||
mod upsell;
|
||||
mod usage_banner;
|
||||
|
||||
pub use agent_notification::*;
|
||||
pub use agent_preview::*;
|
||||
pub use animated_label::*;
|
||||
pub use context_pill::*;
|
||||
pub use usage_banner::*;
|
||||
|
||||
99
crates/agent/src/ui/agent_preview.rs
Normal file
99
crates/agent/src/ui/agent_preview.rs
Normal file
@@ -0,0 +1,99 @@
|
||||
use collections::HashMap;
|
||||
use component::ComponentId;
|
||||
use gpui::{App, Entity, WeakEntity};
|
||||
use linkme::distributed_slice;
|
||||
use std::sync::OnceLock;
|
||||
use ui::{AnyElement, Component, Window};
|
||||
use workspace::Workspace;
|
||||
|
||||
use crate::{ActiveThread, ThreadStore};
|
||||
|
||||
/// Function type for creating agent component previews
|
||||
pub type PreviewFn = fn(
|
||||
WeakEntity<Workspace>,
|
||||
Entity<ActiveThread>,
|
||||
WeakEntity<ThreadStore>,
|
||||
&mut Window,
|
||||
&mut App,
|
||||
) -> Option<AnyElement>;
|
||||
|
||||
/// Distributed slice for preview registration functions
|
||||
#[distributed_slice]
|
||||
pub static __ALL_AGENT_PREVIEWS: [fn() -> (ComponentId, PreviewFn)] = [..];
|
||||
|
||||
/// Trait that must be implemented by components that provide agent previews.
|
||||
pub trait AgentPreview: Component {
|
||||
/// Get the ID for this component
|
||||
///
|
||||
/// Eventually this will move to the component trait.
|
||||
fn id() -> ComponentId
|
||||
where
|
||||
Self: Sized,
|
||||
{
|
||||
ComponentId(Self::name())
|
||||
}
|
||||
|
||||
/// Static method to create a preview for this component type
|
||||
fn create_preview(
|
||||
workspace: WeakEntity<Workspace>,
|
||||
active_thread: Entity<ActiveThread>,
|
||||
thread_store: WeakEntity<ThreadStore>,
|
||||
window: &mut Window,
|
||||
cx: &mut App,
|
||||
) -> Option<AnyElement>
|
||||
where
|
||||
Self: Sized;
|
||||
}
|
||||
|
||||
/// Register an agent preview for the given component type
|
||||
#[macro_export]
|
||||
macro_rules! register_agent_preview {
|
||||
($type:ty) => {
|
||||
#[linkme::distributed_slice($crate::ui::agent_preview::__ALL_AGENT_PREVIEWS)]
|
||||
static __REGISTER_AGENT_PREVIEW: fn() -> (
|
||||
component::ComponentId,
|
||||
$crate::ui::agent_preview::PreviewFn,
|
||||
) = || {
|
||||
(
|
||||
<$type as $crate::ui::agent_preview::AgentPreview>::id(),
|
||||
<$type as $crate::ui::agent_preview::AgentPreview>::create_preview,
|
||||
)
|
||||
};
|
||||
};
|
||||
}
|
||||
|
||||
/// Lazy initialized registry of preview functions
|
||||
static AGENT_PREVIEW_REGISTRY: OnceLock<HashMap<ComponentId, PreviewFn>> = OnceLock::new();
|
||||
|
||||
/// Initialize the agent preview registry if needed
|
||||
fn get_or_init_registry() -> &'static HashMap<ComponentId, PreviewFn> {
|
||||
AGENT_PREVIEW_REGISTRY.get_or_init(|| {
|
||||
let mut map = HashMap::default();
|
||||
for register_fn in __ALL_AGENT_PREVIEWS.iter() {
|
||||
let (id, preview_fn) = register_fn();
|
||||
map.insert(id, preview_fn);
|
||||
}
|
||||
map
|
||||
})
|
||||
}
|
||||
|
||||
/// Get a specific agent preview by component ID.
|
||||
pub fn get_agent_preview(
|
||||
id: &ComponentId,
|
||||
workspace: WeakEntity<Workspace>,
|
||||
active_thread: Entity<ActiveThread>,
|
||||
thread_store: WeakEntity<ThreadStore>,
|
||||
window: &mut Window,
|
||||
cx: &mut App,
|
||||
) -> Option<AnyElement> {
|
||||
let registry = get_or_init_registry();
|
||||
registry
|
||||
.get(id)
|
||||
.and_then(|preview_fn| preview_fn(workspace, active_thread, thread_store, window, cx))
|
||||
}
|
||||
|
||||
/// Get all registered agent previews.
|
||||
pub fn all_agent_previews() -> Vec<ComponentId> {
|
||||
let registry = get_or_init_registry();
|
||||
registry.keys().cloned().collect()
|
||||
}
|
||||
@@ -216,9 +216,10 @@ impl RenderOnce for ContextPill {
|
||||
})
|
||||
.when_some(on_click.as_ref(), |element, on_click| {
|
||||
let on_click = on_click.clone();
|
||||
element
|
||||
.cursor_pointer()
|
||||
.on_click(move |event, window, cx| on_click(event, window, cx))
|
||||
element.cursor_pointer().on_click(move |event, window, cx| {
|
||||
on_click(event, window, cx);
|
||||
cx.stop_propagation();
|
||||
})
|
||||
})
|
||||
.into_any_element()
|
||||
}
|
||||
@@ -254,7 +255,10 @@ impl RenderOnce for ContextPill {
|
||||
})
|
||||
.when_some(on_click.as_ref(), |element, on_click| {
|
||||
let on_click = on_click.clone();
|
||||
element.on_click(move |event, window, cx| on_click(event, window, cx))
|
||||
element.on_click(move |event, window, cx| {
|
||||
on_click(event, window, cx);
|
||||
cx.stop_propagation();
|
||||
})
|
||||
})
|
||||
.into_any(),
|
||||
}
|
||||
|
||||
163
crates/agent/src/ui/upsell.rs
Normal file
163
crates/agent/src/ui/upsell.rs
Normal file
@@ -0,0 +1,163 @@
|
||||
use component::{Component, ComponentScope, single_example};
|
||||
use gpui::{
|
||||
AnyElement, App, ClickEvent, IntoElement, ParentElement, RenderOnce, SharedString, Styled,
|
||||
Window,
|
||||
};
|
||||
use theme::ActiveTheme;
|
||||
use ui::{
|
||||
Button, ButtonCommon, ButtonStyle, Checkbox, Clickable, Color, Label, LabelCommon,
|
||||
RegisterComponent, ToggleState, h_flex, v_flex,
|
||||
};
|
||||
|
||||
/// A component that displays an upsell message with a call-to-action button
|
||||
///
|
||||
/// # Example
|
||||
/// ```
|
||||
/// let upsell = Upsell::new(
|
||||
/// "Upgrade to Zed Pro",
|
||||
/// "Get unlimited access to AI features and more",
|
||||
/// "Upgrade Now",
|
||||
/// Box::new(|_, _window, cx| {
|
||||
/// cx.open_url("https://zed.dev/pricing");
|
||||
/// }),
|
||||
/// Box::new(|_, _window, cx| {
|
||||
/// // Handle dismiss
|
||||
/// }),
|
||||
/// Box::new(|checked, window, cx| {
|
||||
/// // Handle don't show again
|
||||
/// }),
|
||||
/// );
|
||||
/// ```
|
||||
#[derive(IntoElement, RegisterComponent)]
|
||||
pub struct Upsell {
|
||||
title: SharedString,
|
||||
message: SharedString,
|
||||
cta_text: SharedString,
|
||||
on_click: Box<dyn Fn(&ClickEvent, &mut Window, &mut App)>,
|
||||
on_dismiss: Box<dyn Fn(&ClickEvent, &mut Window, &mut App)>,
|
||||
on_dont_show_again: Box<dyn Fn(bool, &mut Window, &mut App)>,
|
||||
}
|
||||
|
||||
impl Upsell {
|
||||
/// Create a new upsell component
|
||||
pub fn new(
|
||||
title: impl Into<SharedString>,
|
||||
message: impl Into<SharedString>,
|
||||
cta_text: impl Into<SharedString>,
|
||||
on_click: Box<dyn Fn(&ClickEvent, &mut Window, &mut App)>,
|
||||
on_dismiss: Box<dyn Fn(&ClickEvent, &mut Window, &mut App)>,
|
||||
on_dont_show_again: Box<dyn Fn(bool, &mut Window, &mut App)>,
|
||||
) -> Self {
|
||||
Self {
|
||||
title: title.into(),
|
||||
message: message.into(),
|
||||
cta_text: cta_text.into(),
|
||||
on_click,
|
||||
on_dismiss,
|
||||
on_dont_show_again,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl RenderOnce for Upsell {
|
||||
fn render(self, _window: &mut Window, cx: &mut App) -> impl IntoElement {
|
||||
v_flex()
|
||||
.w_full()
|
||||
.p_4()
|
||||
.gap_3()
|
||||
.bg(cx.theme().colors().surface_background)
|
||||
.rounded_md()
|
||||
.border_1()
|
||||
.border_color(cx.theme().colors().border)
|
||||
.child(
|
||||
v_flex()
|
||||
.gap_1()
|
||||
.child(
|
||||
Label::new(self.title)
|
||||
.size(ui::LabelSize::Large)
|
||||
.weight(gpui::FontWeight::BOLD),
|
||||
)
|
||||
.child(Label::new(self.message).color(Color::Muted)),
|
||||
)
|
||||
.child(
|
||||
h_flex()
|
||||
.w_full()
|
||||
.justify_between()
|
||||
.items_center()
|
||||
.child(
|
||||
h_flex()
|
||||
.items_center()
|
||||
.gap_1()
|
||||
.child(
|
||||
Checkbox::new("dont-show-again", ToggleState::Unselected).on_click(
|
||||
move |_, window, cx| {
|
||||
(self.on_dont_show_again)(true, window, cx);
|
||||
},
|
||||
),
|
||||
)
|
||||
.child(
|
||||
Label::new("Don't show again")
|
||||
.color(Color::Muted)
|
||||
.size(ui::LabelSize::Small),
|
||||
),
|
||||
)
|
||||
.child(
|
||||
h_flex()
|
||||
.gap_2()
|
||||
.child(
|
||||
Button::new("dismiss-button", "Dismiss")
|
||||
.style(ButtonStyle::Subtle)
|
||||
.on_click(self.on_dismiss),
|
||||
)
|
||||
.child(
|
||||
Button::new("cta-button", self.cta_text)
|
||||
.style(ButtonStyle::Filled)
|
||||
.on_click(self.on_click),
|
||||
),
|
||||
),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl Component for Upsell {
|
||||
fn scope() -> ComponentScope {
|
||||
ComponentScope::Agent
|
||||
}
|
||||
|
||||
fn name() -> &'static str {
|
||||
"Upsell"
|
||||
}
|
||||
|
||||
fn description() -> Option<&'static str> {
|
||||
Some("A promotional component that displays a message with a call-to-action.")
|
||||
}
|
||||
|
||||
fn preview(window: &mut Window, cx: &mut App) -> Option<AnyElement> {
|
||||
let examples = vec![
|
||||
single_example(
|
||||
"Default",
|
||||
Upsell::new(
|
||||
"Upgrade to Zed Pro",
|
||||
"Get unlimited access to AI features and more with Zed Pro. Unlock advanced AI capabilities and other premium features.",
|
||||
"Upgrade Now",
|
||||
Box::new(|_, _, _| {}),
|
||||
Box::new(|_, _, _| {}),
|
||||
Box::new(|_, _, _| {}),
|
||||
).render(window, cx).into_any_element(),
|
||||
),
|
||||
single_example(
|
||||
"Short Message",
|
||||
Upsell::new(
|
||||
"Try Zed Pro for free",
|
||||
"Start your 7-day trial today.",
|
||||
"Start Trial",
|
||||
Box::new(|_, _, _| {}),
|
||||
Box::new(|_, _, _| {}),
|
||||
Box::new(|_, _, _| {}),
|
||||
).render(window, cx).into_any_element(),
|
||||
),
|
||||
];
|
||||
|
||||
Some(v_flex().gap_4().children(examples).into_any_element())
|
||||
}
|
||||
}
|
||||
@@ -98,6 +98,10 @@ impl RenderOnce for UsageBanner {
|
||||
}
|
||||
|
||||
impl Component for UsageBanner {
|
||||
fn scope() -> ComponentScope {
|
||||
ComponentScope::Agent
|
||||
}
|
||||
|
||||
fn sort_name() -> &'static str {
|
||||
"AgentUsageBanner"
|
||||
}
|
||||
|
||||
@@ -7,8 +7,8 @@ use assistant_slash_command::{SlashCommandId, SlashCommandWorkingSet};
|
||||
use client::{Client, TypedEnvelope, proto, telemetry::Telemetry};
|
||||
use clock::ReplicaId;
|
||||
use collections::HashMap;
|
||||
use context_server::ContextServerFactoryRegistry;
|
||||
use context_server::manager::ContextServerManager;
|
||||
use context_server::ContextServerDescriptorRegistry;
|
||||
use context_server::manager::{ContextServerManager, ContextServerStatus};
|
||||
use fs::{Fs, RemoveOptions};
|
||||
use futures::StreamExt;
|
||||
use fuzzy::StringMatchCandidate;
|
||||
@@ -99,7 +99,7 @@ impl ContextStore {
|
||||
|
||||
let this = cx.new(|cx: &mut Context<Self>| {
|
||||
let context_server_factory_registry =
|
||||
ContextServerFactoryRegistry::default_global(cx);
|
||||
ContextServerDescriptorRegistry::default_global(cx);
|
||||
let context_server_manager = cx.new(|cx| {
|
||||
ContextServerManager::new(context_server_factory_registry, project.clone(), cx)
|
||||
});
|
||||
@@ -831,54 +831,60 @@ impl ContextStore {
|
||||
) {
|
||||
let slash_command_working_set = self.slash_commands.clone();
|
||||
match event {
|
||||
context_server::manager::Event::ServerStarted { server_id } => {
|
||||
if let Some(server) = context_server_manager.read(cx).get_server(server_id) {
|
||||
let context_server_manager = context_server_manager.clone();
|
||||
cx.spawn({
|
||||
let server = server.clone();
|
||||
let server_id = server_id.clone();
|
||||
async move |this, cx| {
|
||||
let Some(protocol) = server.client() else {
|
||||
return;
|
||||
};
|
||||
context_server::manager::Event::ServerStatusChanged { server_id, status } => {
|
||||
match status {
|
||||
Some(ContextServerStatus::Running) => {
|
||||
if let Some(server) = context_server_manager.read(cx).get_server(server_id)
|
||||
{
|
||||
let context_server_manager = context_server_manager.clone();
|
||||
cx.spawn({
|
||||
let server = server.clone();
|
||||
let server_id = server_id.clone();
|
||||
async move |this, cx| {
|
||||
let Some(protocol) = server.client() else {
|
||||
return;
|
||||
};
|
||||
|
||||
if protocol.capable(context_server::protocol::ServerCapability::Prompts) {
|
||||
if let Some(prompts) = protocol.list_prompts().await.log_err() {
|
||||
let slash_command_ids = prompts
|
||||
.into_iter()
|
||||
.filter(assistant_slash_commands::acceptable_prompt)
|
||||
.map(|prompt| {
|
||||
log::info!(
|
||||
"registering context server command: {:?}",
|
||||
prompt.name
|
||||
);
|
||||
slash_command_working_set.insert(Arc::new(
|
||||
assistant_slash_commands::ContextServerSlashCommand::new(
|
||||
context_server_manager.clone(),
|
||||
&server,
|
||||
prompt,
|
||||
),
|
||||
))
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
if protocol.capable(context_server::protocol::ServerCapability::Prompts) {
|
||||
if let Some(prompts) = protocol.list_prompts().await.log_err() {
|
||||
let slash_command_ids = prompts
|
||||
.into_iter()
|
||||
.filter(assistant_slash_commands::acceptable_prompt)
|
||||
.map(|prompt| {
|
||||
log::info!(
|
||||
"registering context server command: {:?}",
|
||||
prompt.name
|
||||
);
|
||||
slash_command_working_set.insert(Arc::new(
|
||||
assistant_slash_commands::ContextServerSlashCommand::new(
|
||||
context_server_manager.clone(),
|
||||
&server,
|
||||
prompt,
|
||||
),
|
||||
))
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
this.update( cx, |this, _cx| {
|
||||
this.context_server_slash_command_ids
|
||||
.insert(server_id.clone(), slash_command_ids);
|
||||
})
|
||||
.log_err();
|
||||
this.update( cx, |this, _cx| {
|
||||
this.context_server_slash_command_ids
|
||||
.insert(server_id.clone(), slash_command_ids);
|
||||
})
|
||||
.log_err();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
.detach();
|
||||
}
|
||||
})
|
||||
.detach();
|
||||
}
|
||||
}
|
||||
context_server::manager::Event::ServerStopped { server_id } => {
|
||||
if let Some(slash_command_ids) =
|
||||
self.context_server_slash_command_ids.remove(server_id)
|
||||
{
|
||||
slash_command_working_set.remove(&slash_command_ids);
|
||||
}
|
||||
None => {
|
||||
if let Some(slash_command_ids) =
|
||||
self.context_server_slash_command_ids.remove(server_id)
|
||||
{
|
||||
slash_command_working_set.remove(&slash_command_ids);
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -6,7 +6,7 @@ use ::open_ai::Model as OpenAiModel;
|
||||
use anthropic::Model as AnthropicModel;
|
||||
use anyhow::{Result, bail};
|
||||
use deepseek::Model as DeepseekModel;
|
||||
use feature_flags::{Assistant2FeatureFlag, FeatureFlagAppExt};
|
||||
use feature_flags::{AgentStreamEditsFeatureFlag, Assistant2FeatureFlag, FeatureFlagAppExt};
|
||||
use gpui::{App, Pixels};
|
||||
use indexmap::IndexMap;
|
||||
use language_model::{CloudModel, LanguageModel};
|
||||
@@ -87,9 +87,14 @@ pub struct AssistantSettings {
|
||||
pub profiles: IndexMap<AgentProfileId, AgentProfile>,
|
||||
pub always_allow_tool_actions: bool,
|
||||
pub notify_when_agent_waiting: NotifyWhenAgentWaiting,
|
||||
pub stream_edits: bool,
|
||||
}
|
||||
|
||||
impl AssistantSettings {
|
||||
pub fn stream_edits(&self, cx: &App) -> bool {
|
||||
cx.has_flag::<AgentStreamEditsFeatureFlag>() || self.stream_edits
|
||||
}
|
||||
|
||||
pub fn are_live_diffs_enabled(&self, cx: &App) -> bool {
|
||||
if cx.has_flag::<Assistant2FeatureFlag>() {
|
||||
return false;
|
||||
@@ -218,6 +223,7 @@ impl AssistantSettingsContent {
|
||||
profiles: None,
|
||||
always_allow_tool_actions: None,
|
||||
notify_when_agent_waiting: None,
|
||||
stream_edits: None,
|
||||
},
|
||||
VersionedAssistantSettingsContent::V2(ref settings) => settings.clone(),
|
||||
},
|
||||
@@ -245,6 +251,7 @@ impl AssistantSettingsContent {
|
||||
profiles: None,
|
||||
always_allow_tool_actions: None,
|
||||
notify_when_agent_waiting: None,
|
||||
stream_edits: None,
|
||||
},
|
||||
None => AssistantSettingsContentV2::default(),
|
||||
}
|
||||
@@ -495,6 +502,7 @@ impl Default for VersionedAssistantSettingsContent {
|
||||
profiles: None,
|
||||
always_allow_tool_actions: None,
|
||||
notify_when_agent_waiting: None,
|
||||
stream_edits: None,
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -550,6 +558,10 @@ pub struct AssistantSettingsContentV2 {
|
||||
///
|
||||
/// Default: "primary_screen"
|
||||
notify_when_agent_waiting: Option<NotifyWhenAgentWaiting>,
|
||||
/// Whether to stream edits from the agent as they are received.
|
||||
///
|
||||
/// Default: false
|
||||
stream_edits: Option<bool>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema, PartialEq)]
|
||||
@@ -712,6 +724,7 @@ impl Settings for AssistantSettings {
|
||||
&mut settings.notify_when_agent_waiting,
|
||||
value.notify_when_agent_waiting,
|
||||
);
|
||||
merge(&mut settings.stream_edits, value.stream_edits);
|
||||
merge(&mut settings.default_profile, value.default_profile);
|
||||
|
||||
if let Some(profiles) = value.profiles {
|
||||
@@ -843,6 +856,7 @@ mod tests {
|
||||
profiles: None,
|
||||
always_allow_tool_actions: None,
|
||||
notify_when_agent_waiting: None,
|
||||
stream_edits: None,
|
||||
},
|
||||
)),
|
||||
}
|
||||
|
||||
@@ -24,6 +24,7 @@ language.workspace = true
|
||||
language_model.workspace = true
|
||||
parking_lot.workspace = true
|
||||
project.workspace = true
|
||||
regex.workspace = true
|
||||
serde.workspace = true
|
||||
serde_json.workspace = true
|
||||
text.workspace = true
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
mod action_log;
|
||||
pub mod outline;
|
||||
mod tool_registry;
|
||||
mod tool_schema;
|
||||
mod tool_working_set;
|
||||
|
||||
132
crates/assistant_tool/src/outline.rs
Normal file
132
crates/assistant_tool/src/outline.rs
Normal file
@@ -0,0 +1,132 @@
|
||||
use crate::ActionLog;
|
||||
use anyhow::{Result, anyhow};
|
||||
use gpui::{AsyncApp, Entity};
|
||||
use language::{OutlineItem, ParseStatus};
|
||||
use project::Project;
|
||||
use regex::Regex;
|
||||
use std::fmt::Write;
|
||||
use text::Point;
|
||||
|
||||
/// For files over this size, instead of reading them (or including them in context),
|
||||
/// we automatically provide the file's symbol outline instead, with line numbers.
|
||||
pub const AUTO_OUTLINE_SIZE: usize = 16384;
|
||||
|
||||
pub async fn file_outline(
|
||||
project: Entity<Project>,
|
||||
path: String,
|
||||
action_log: Entity<ActionLog>,
|
||||
regex: Option<Regex>,
|
||||
cx: &mut AsyncApp,
|
||||
) -> anyhow::Result<String> {
|
||||
let buffer = {
|
||||
let project_path = project.read_with(cx, |project, cx| {
|
||||
project
|
||||
.find_project_path(&path, cx)
|
||||
.ok_or_else(|| anyhow!("Path {path} not found in project"))
|
||||
})??;
|
||||
|
||||
project
|
||||
.update(cx, |project, cx| project.open_buffer(project_path, cx))?
|
||||
.await?
|
||||
};
|
||||
|
||||
action_log.update(cx, |action_log, cx| {
|
||||
action_log.track_buffer(buffer.clone(), cx);
|
||||
})?;
|
||||
|
||||
// Wait until the buffer has been fully parsed, so that we can read its outline.
|
||||
let mut parse_status = buffer.read_with(cx, |buffer, _| buffer.parse_status())?;
|
||||
while *parse_status.borrow() != ParseStatus::Idle {
|
||||
parse_status.changed().await?;
|
||||
}
|
||||
|
||||
let snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot())?;
|
||||
let Some(outline) = snapshot.outline(None) else {
|
||||
return Err(anyhow!("No outline information available for this file."));
|
||||
};
|
||||
|
||||
render_outline(
|
||||
outline
|
||||
.items
|
||||
.into_iter()
|
||||
.map(|item| item.to_point(&snapshot)),
|
||||
regex,
|
||||
0,
|
||||
usize::MAX,
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
pub async fn render_outline(
|
||||
items: impl IntoIterator<Item = OutlineItem<Point>>,
|
||||
regex: Option<Regex>,
|
||||
offset: usize,
|
||||
results_per_page: usize,
|
||||
) -> Result<String> {
|
||||
let mut items = items.into_iter().skip(offset);
|
||||
|
||||
let entries = items
|
||||
.by_ref()
|
||||
.filter(|item| {
|
||||
regex
|
||||
.as_ref()
|
||||
.is_none_or(|regex| regex.is_match(&item.text))
|
||||
})
|
||||
.take(results_per_page)
|
||||
.collect::<Vec<_>>();
|
||||
let has_more = items.next().is_some();
|
||||
|
||||
let mut output = String::new();
|
||||
let entries_rendered = render_entries(&mut output, entries);
|
||||
|
||||
// Calculate pagination information
|
||||
let page_start = offset + 1;
|
||||
let page_end = offset + entries_rendered;
|
||||
let total_symbols = if has_more {
|
||||
format!("more than {}", page_end)
|
||||
} else {
|
||||
page_end.to_string()
|
||||
};
|
||||
|
||||
// Add pagination information
|
||||
if has_more {
|
||||
writeln!(&mut output, "\nShowing symbols {page_start}-{page_end} (there were more symbols found; use offset: {page_end} to see next page)",
|
||||
)
|
||||
} else {
|
||||
writeln!(
|
||||
&mut output,
|
||||
"\nShowing symbols {page_start}-{page_end} (total symbols: {total_symbols})",
|
||||
)
|
||||
}
|
||||
.ok();
|
||||
|
||||
Ok(output)
|
||||
}
|
||||
|
||||
fn render_entries(
|
||||
output: &mut String,
|
||||
items: impl IntoIterator<Item = OutlineItem<Point>>,
|
||||
) -> usize {
|
||||
let mut entries_rendered = 0;
|
||||
|
||||
for item in items {
|
||||
// Indent based on depth ("" for level 0, " " for level 1, etc.)
|
||||
for _ in 0..item.depth {
|
||||
output.push(' ');
|
||||
}
|
||||
output.push_str(&item.text);
|
||||
|
||||
// Add position information - convert to 1-based line numbers for display
|
||||
let start_line = item.range.start.row + 1;
|
||||
let end_line = item.range.end.row + 1;
|
||||
|
||||
if start_line == end_line {
|
||||
writeln!(output, " [L{}]", start_line).ok();
|
||||
} else {
|
||||
writeln!(output, " [L{}-{}]", start_line, end_line).ok();
|
||||
}
|
||||
entries_rendered += 1;
|
||||
}
|
||||
|
||||
entries_rendered
|
||||
}
|
||||
@@ -11,16 +11,24 @@ workspace = true
|
||||
[lib]
|
||||
path = "src/assistant_tools.rs"
|
||||
|
||||
[features]
|
||||
eval = []
|
||||
|
||||
[dependencies]
|
||||
aho-corasick.workspace = true
|
||||
anyhow.workspace = true
|
||||
assistant_tool.workspace = true
|
||||
assistant_settings.workspace = true
|
||||
buffer_diff.workspace = true
|
||||
chrono.workspace = true
|
||||
collections.workspace = true
|
||||
component.workspace = true
|
||||
editor.workspace = true
|
||||
derive_more.workspace = true
|
||||
feature_flags.workspace = true
|
||||
futures.workspace = true
|
||||
gpui.workspace = true
|
||||
handlebars = { workspace = true, features = ["rust-embed"] }
|
||||
html_to_markdown.workspace = true
|
||||
http_client.workspace = true
|
||||
indoc.workspace = true
|
||||
@@ -31,9 +39,14 @@ linkme.workspace = true
|
||||
open.workspace = true
|
||||
project.workspace = true
|
||||
regex.workspace = true
|
||||
rust-embed.workspace = true
|
||||
schemars.workspace = true
|
||||
serde.workspace = true
|
||||
serde_json.workspace = true
|
||||
settings.workspace = true
|
||||
smallvec.workspace = true
|
||||
streaming_diff.workspace = true
|
||||
strsim.workspace = true
|
||||
task.workspace = true
|
||||
terminal.workspace = true
|
||||
terminal_view.workspace = true
|
||||
@@ -49,12 +62,18 @@ client = { workspace = true, features = ["test-support"] }
|
||||
clock = { workspace = true, features = ["test-support"] }
|
||||
collections = { workspace = true, features = ["test-support"] }
|
||||
gpui = { workspace = true, features = ["test-support"] }
|
||||
gpui_tokio.workspace = true
|
||||
fs = { workspace = true, features = ["test-support"] }
|
||||
language = { workspace = true, features = ["test-support"] }
|
||||
language_model = { workspace = true, features = ["test-support"] }
|
||||
language_models.workspace = true
|
||||
project = { workspace = true, features = ["test-support"] }
|
||||
rand.workspace = true
|
||||
pretty_assertions.workspace = true
|
||||
reqwest_client.workspace = true
|
||||
settings = { workspace = true, features = ["test-support"] }
|
||||
task = { workspace = true, features = ["test-support"]}
|
||||
tempfile.workspace = true
|
||||
tree-sitter-rust.workspace = true
|
||||
workspace = { workspace = true, features = ["test-support"] }
|
||||
unindent.workspace = true
|
||||
|
||||
@@ -7,6 +7,7 @@ mod create_directory_tool;
|
||||
mod create_file_tool;
|
||||
mod delete_path_tool;
|
||||
mod diagnostics_tool;
|
||||
mod edit_agent;
|
||||
mod edit_file_tool;
|
||||
mod fetch_tool;
|
||||
mod find_path_tool;
|
||||
@@ -19,7 +20,9 @@ mod read_file_tool;
|
||||
mod rename_tool;
|
||||
mod replace;
|
||||
mod schema;
|
||||
mod streaming_edit_file_tool;
|
||||
mod symbol_info_tool;
|
||||
mod templates;
|
||||
mod terminal_tool;
|
||||
mod thinking_tool;
|
||||
mod ui;
|
||||
@@ -27,14 +30,19 @@ mod web_search_tool;
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use assistant_settings::AssistantSettings;
|
||||
use assistant_tool::ToolRegistry;
|
||||
use copy_path_tool::CopyPathTool;
|
||||
use feature_flags::{AgentStreamEditsFeatureFlag, FeatureFlagAppExt};
|
||||
use gpui::App;
|
||||
use http_client::HttpClientWithUrl;
|
||||
use language_model::LanguageModelRegistry;
|
||||
use move_path_tool::MovePathTool;
|
||||
use settings::{Settings, SettingsStore};
|
||||
use web_search_tool::WebSearchTool;
|
||||
|
||||
pub(crate) use templates::*;
|
||||
|
||||
use crate::batch_tool::BatchTool;
|
||||
use crate::code_action_tool::CodeActionTool;
|
||||
use crate::code_symbols_tool::CodeSymbolsTool;
|
||||
@@ -52,6 +60,7 @@ use crate::now_tool::NowTool;
|
||||
use crate::open_tool::OpenTool;
|
||||
use crate::read_file_tool::ReadFileTool;
|
||||
use crate::rename_tool::RenameTool;
|
||||
use crate::streaming_edit_file_tool::StreamingEditFileTool;
|
||||
use crate::symbol_info_tool::SymbolInfoTool;
|
||||
use crate::terminal_tool::TerminalTool;
|
||||
use crate::thinking_tool::ThinkingTool;
|
||||
@@ -68,10 +77,8 @@ pub fn init(http_client: Arc<HttpClientWithUrl>, cx: &mut App) {
|
||||
registry.register_tool(TerminalTool);
|
||||
registry.register_tool(BatchTool);
|
||||
registry.register_tool(CreateDirectoryTool);
|
||||
registry.register_tool(CreateFileTool);
|
||||
registry.register_tool(CopyPathTool);
|
||||
registry.register_tool(DeletePathTool);
|
||||
registry.register_tool(EditFileTool);
|
||||
registry.register_tool(SymbolInfoTool);
|
||||
registry.register_tool(CodeActionTool);
|
||||
registry.register_tool(MovePathTool);
|
||||
@@ -88,6 +95,12 @@ pub fn init(http_client: Arc<HttpClientWithUrl>, cx: &mut App) {
|
||||
registry.register_tool(ThinkingTool);
|
||||
registry.register_tool(FetchTool::new(http_client));
|
||||
|
||||
register_edit_file_tool(cx);
|
||||
cx.observe_flag::<AgentStreamEditsFeatureFlag, _>(|_, cx| register_edit_file_tool(cx))
|
||||
.detach();
|
||||
cx.observe_global::<SettingsStore>(register_edit_file_tool)
|
||||
.detach();
|
||||
|
||||
cx.subscribe(
|
||||
&LanguageModelRegistry::global(cx),
|
||||
move |registry, event, cx| match event {
|
||||
@@ -108,6 +121,21 @@ pub fn init(http_client: Arc<HttpClientWithUrl>, cx: &mut App) {
|
||||
.detach();
|
||||
}
|
||||
|
||||
fn register_edit_file_tool(cx: &mut App) {
|
||||
let registry = ToolRegistry::global(cx);
|
||||
|
||||
registry.unregister_tool(CreateFileTool);
|
||||
registry.unregister_tool(EditFileTool);
|
||||
registry.unregister_tool(StreamingEditFileTool);
|
||||
|
||||
if AssistantSettings::get_global(cx).stream_edits(cx) {
|
||||
registry.register_tool(StreamingEditFileTool);
|
||||
} else {
|
||||
registry.register_tool(CreateFileTool);
|
||||
registry.register_tool(EditFileTool);
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
@@ -146,6 +174,7 @@ mod tests {
|
||||
#[gpui::test]
|
||||
fn test_builtin_tool_schema_compatibility(cx: &mut App) {
|
||||
settings::init(cx);
|
||||
AssistantSettings::register(cx);
|
||||
|
||||
let client = Client::new(
|
||||
Arc::new(FakeSystemClock::new()),
|
||||
|
||||
@@ -4,10 +4,10 @@ use std::sync::Arc;
|
||||
|
||||
use crate::schema::json_schema_for;
|
||||
use anyhow::{Result, anyhow};
|
||||
use assistant_tool::outline;
|
||||
use assistant_tool::{ActionLog, Tool, ToolResult};
|
||||
use collections::IndexMap;
|
||||
use gpui::{AnyWindowHandle, App, AsyncApp, Entity, Task};
|
||||
use language::{OutlineItem, ParseStatus, Point};
|
||||
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
|
||||
use project::{Project, Symbol};
|
||||
use regex::{Regex, RegexBuilder};
|
||||
@@ -148,59 +148,13 @@ impl Tool for CodeSymbolsTool {
|
||||
};
|
||||
|
||||
cx.spawn(async move |cx| match input.path {
|
||||
Some(path) => file_outline(project, path, action_log, regex, cx).await,
|
||||
Some(path) => outline::file_outline(project, path, action_log, regex, cx).await,
|
||||
None => project_symbols(project, regex, input.offset, cx).await,
|
||||
})
|
||||
.into()
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn file_outline(
|
||||
project: Entity<Project>,
|
||||
path: String,
|
||||
action_log: Entity<ActionLog>,
|
||||
regex: Option<Regex>,
|
||||
cx: &mut AsyncApp,
|
||||
) -> anyhow::Result<String> {
|
||||
let buffer = {
|
||||
let project_path = project.read_with(cx, |project, cx| {
|
||||
project
|
||||
.find_project_path(&path, cx)
|
||||
.ok_or_else(|| anyhow!("Path {path} not found in project"))
|
||||
})??;
|
||||
|
||||
project
|
||||
.update(cx, |project, cx| project.open_buffer(project_path, cx))?
|
||||
.await?
|
||||
};
|
||||
|
||||
action_log.update(cx, |action_log, cx| {
|
||||
action_log.track_buffer(buffer.clone(), cx);
|
||||
})?;
|
||||
|
||||
// Wait until the buffer has been fully parsed, so that we can read its outline.
|
||||
let mut parse_status = buffer.read_with(cx, |buffer, _| buffer.parse_status())?;
|
||||
while *parse_status.borrow() != ParseStatus::Idle {
|
||||
parse_status.changed().await?;
|
||||
}
|
||||
|
||||
let snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot())?;
|
||||
let Some(outline) = snapshot.outline(None) else {
|
||||
return Err(anyhow!("No outline information available for this file."));
|
||||
};
|
||||
|
||||
render_outline(
|
||||
outline
|
||||
.items
|
||||
.into_iter()
|
||||
.map(|item| item.to_point(&snapshot)),
|
||||
regex,
|
||||
0,
|
||||
usize::MAX,
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
async fn project_symbols(
|
||||
project: Entity<Project>,
|
||||
regex: Option<Regex>,
|
||||
@@ -291,77 +245,3 @@ async fn project_symbols(
|
||||
output
|
||||
})
|
||||
}
|
||||
|
||||
async fn render_outline(
|
||||
items: impl IntoIterator<Item = OutlineItem<Point>>,
|
||||
regex: Option<Regex>,
|
||||
offset: usize,
|
||||
results_per_page: usize,
|
||||
) -> Result<String> {
|
||||
let mut items = items.into_iter().skip(offset);
|
||||
|
||||
let entries = items
|
||||
.by_ref()
|
||||
.filter(|item| {
|
||||
regex
|
||||
.as_ref()
|
||||
.is_none_or(|regex| regex.is_match(&item.text))
|
||||
})
|
||||
.take(results_per_page)
|
||||
.collect::<Vec<_>>();
|
||||
let has_more = items.next().is_some();
|
||||
|
||||
let mut output = String::new();
|
||||
let entries_rendered = render_entries(&mut output, entries);
|
||||
|
||||
// Calculate pagination information
|
||||
let page_start = offset + 1;
|
||||
let page_end = offset + entries_rendered;
|
||||
let total_symbols = if has_more {
|
||||
format!("more than {}", page_end)
|
||||
} else {
|
||||
page_end.to_string()
|
||||
};
|
||||
|
||||
// Add pagination information
|
||||
if has_more {
|
||||
writeln!(&mut output, "\nShowing symbols {page_start}-{page_end} (there were more symbols found; use offset: {page_end} to see next page)",
|
||||
)
|
||||
} else {
|
||||
writeln!(
|
||||
&mut output,
|
||||
"\nShowing symbols {page_start}-{page_end} (total symbols: {total_symbols})",
|
||||
)
|
||||
}
|
||||
.ok();
|
||||
|
||||
Ok(output)
|
||||
}
|
||||
|
||||
fn render_entries(
|
||||
output: &mut String,
|
||||
items: impl IntoIterator<Item = OutlineItem<Point>>,
|
||||
) -> usize {
|
||||
let mut entries_rendered = 0;
|
||||
|
||||
for item in items {
|
||||
// Indent based on depth ("" for level 0, " " for level 1, etc.)
|
||||
for _ in 0..item.depth {
|
||||
output.push(' ');
|
||||
}
|
||||
output.push_str(&item.text);
|
||||
|
||||
// Add position information - convert to 1-based line numbers for display
|
||||
let start_line = item.range.start.row + 1;
|
||||
let end_line = item.range.end.row + 1;
|
||||
|
||||
if start_line == end_line {
|
||||
writeln!(output, " [L{}]", start_line).ok();
|
||||
} else {
|
||||
writeln!(output, " [L{}-{}]", start_line, end_line).ok();
|
||||
}
|
||||
entries_rendered += 1;
|
||||
}
|
||||
|
||||
entries_rendered
|
||||
}
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::{code_symbols_tool::file_outline, schema::json_schema_for};
|
||||
use crate::schema::json_schema_for;
|
||||
use anyhow::{Result, anyhow};
|
||||
use assistant_tool::{ActionLog, Tool, ToolResult};
|
||||
use assistant_tool::{ActionLog, Tool, ToolResult, outline};
|
||||
use gpui::{AnyWindowHandle, App, Entity, Task};
|
||||
use itertools::Itertools;
|
||||
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
|
||||
@@ -14,10 +14,6 @@ use ui::IconName;
|
||||
use util::markdown::MarkdownInlineCode;
|
||||
|
||||
/// If the model requests to read a file whose size exceeds this, then
|
||||
/// the tool will return the file's symbol outline instead of its contents,
|
||||
/// and suggest trying again using line ranges from the outline.
|
||||
const MAX_FILE_SIZE_TO_READ: usize = 16384;
|
||||
|
||||
/// If the model requests to list the entries in a directory with more
|
||||
/// entries than this, then the tool will return a subset of the entries
|
||||
/// and suggest trying again.
|
||||
@@ -218,7 +214,7 @@ impl Tool for ContentsTool {
|
||||
// No line ranges specified, so check file size to see if it's too big.
|
||||
let file_size = buffer.read_with(cx, |buffer, _cx| buffer.text().len())?;
|
||||
|
||||
if file_size <= MAX_FILE_SIZE_TO_READ {
|
||||
if file_size <= outline::AUTO_OUTLINE_SIZE {
|
||||
let result = buffer.read_with(cx, |buffer, _cx| buffer.text())?;
|
||||
|
||||
action_log.update(cx, |log, cx| {
|
||||
@@ -229,7 +225,7 @@ impl Tool for ContentsTool {
|
||||
} else {
|
||||
// File is too big, so return its outline and a suggestion to
|
||||
// read again with a line number range specified.
|
||||
let outline = file_outline(project, file_path, action_log, None, cx).await?;
|
||||
let outline = outline::file_outline(project, file_path, action_log, None, cx).await?;
|
||||
|
||||
Ok(format!("This file was too big to read all at once. Here is an outline of its symbols:\n\n{outline}\n\nUsing the line numbers in this outline, you can call this tool again while specifying the start and end fields to see the implementations of symbols in the outline."))
|
||||
}
|
||||
|
||||
1210
crates/assistant_tools/src/edit_agent.rs
Normal file
1210
crates/assistant_tools/src/edit_agent.rs
Normal file
File diff suppressed because it is too large
Load Diff
408
crates/assistant_tools/src/edit_agent/edit_parser.rs
Normal file
408
crates/assistant_tools/src/edit_agent/edit_parser.rs
Normal file
@@ -0,0 +1,408 @@
|
||||
use derive_more::{Add, AddAssign};
|
||||
use smallvec::SmallVec;
|
||||
use std::{cmp, mem, ops::Range};
|
||||
|
||||
const OLD_TEXT_END_TAG: &str = "</old_text>";
|
||||
const NEW_TEXT_END_TAG: &str = "</new_text>";
|
||||
const END_TAG_LEN: usize = OLD_TEXT_END_TAG.len();
|
||||
const _: () = debug_assert!(OLD_TEXT_END_TAG.len() == NEW_TEXT_END_TAG.len());
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum EditParserEvent {
|
||||
OldText(String),
|
||||
NewTextChunk { chunk: String, done: bool },
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Default, PartialEq, Eq, Add, AddAssign)]
|
||||
pub struct EditParserMetrics {
|
||||
pub tags: usize,
|
||||
pub mismatched_tags: usize,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct EditParser {
|
||||
state: EditParserState,
|
||||
buffer: String,
|
||||
metrics: EditParserMetrics,
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq)]
|
||||
enum EditParserState {
|
||||
Pending,
|
||||
WithinOldText,
|
||||
AfterOldText,
|
||||
WithinNewText { start: bool },
|
||||
}
|
||||
|
||||
impl EditParser {
|
||||
pub fn new() -> Self {
|
||||
EditParser {
|
||||
state: EditParserState::Pending,
|
||||
buffer: String::new(),
|
||||
metrics: EditParserMetrics::default(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn push(&mut self, chunk: &str) -> SmallVec<[EditParserEvent; 1]> {
|
||||
self.buffer.push_str(chunk);
|
||||
|
||||
let mut edit_events = SmallVec::new();
|
||||
loop {
|
||||
match &mut self.state {
|
||||
EditParserState::Pending => {
|
||||
if let Some(start) = self.buffer.find("<old_text>") {
|
||||
self.buffer.drain(..start + "<old_text>".len());
|
||||
self.state = EditParserState::WithinOldText;
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
EditParserState::WithinOldText => {
|
||||
if let Some(tag_range) = self.find_end_tag() {
|
||||
let mut start = 0;
|
||||
if self.buffer.starts_with('\n') {
|
||||
start = 1;
|
||||
}
|
||||
let mut old_text = self.buffer[start..tag_range.start].to_string();
|
||||
if old_text.ends_with('\n') {
|
||||
old_text.pop();
|
||||
}
|
||||
|
||||
self.metrics.tags += 1;
|
||||
if &self.buffer[tag_range.clone()] != OLD_TEXT_END_TAG {
|
||||
self.metrics.mismatched_tags += 1;
|
||||
}
|
||||
|
||||
self.buffer.drain(..tag_range.end);
|
||||
self.state = EditParserState::AfterOldText;
|
||||
edit_events.push(EditParserEvent::OldText(old_text));
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
EditParserState::AfterOldText => {
|
||||
if let Some(start) = self.buffer.find("<new_text>") {
|
||||
self.buffer.drain(..start + "<new_text>".len());
|
||||
self.state = EditParserState::WithinNewText { start: true };
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
EditParserState::WithinNewText { start } => {
|
||||
if !self.buffer.is_empty() {
|
||||
if *start && self.buffer.starts_with('\n') {
|
||||
self.buffer.remove(0);
|
||||
}
|
||||
*start = false;
|
||||
}
|
||||
|
||||
if let Some(tag_range) = self.find_end_tag() {
|
||||
let mut chunk = self.buffer[..tag_range.start].to_string();
|
||||
if chunk.ends_with('\n') {
|
||||
chunk.pop();
|
||||
}
|
||||
|
||||
self.metrics.tags += 1;
|
||||
if &self.buffer[tag_range.clone()] != NEW_TEXT_END_TAG {
|
||||
self.metrics.mismatched_tags += 1;
|
||||
}
|
||||
|
||||
self.buffer.drain(..tag_range.end);
|
||||
self.state = EditParserState::Pending;
|
||||
edit_events.push(EditParserEvent::NewTextChunk { chunk, done: true });
|
||||
} else {
|
||||
let mut end_prefixes = (1..END_TAG_LEN)
|
||||
.flat_map(|i| [&NEW_TEXT_END_TAG[..i], &OLD_TEXT_END_TAG[..i]])
|
||||
.chain(["\n"]);
|
||||
if end_prefixes.all(|prefix| !self.buffer.ends_with(&prefix)) {
|
||||
edit_events.push(EditParserEvent::NewTextChunk {
|
||||
chunk: mem::take(&mut self.buffer),
|
||||
done: false,
|
||||
});
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
edit_events
|
||||
}
|
||||
|
||||
fn find_end_tag(&self) -> Option<Range<usize>> {
|
||||
let old_text_end_tag_ix = self.buffer.find(OLD_TEXT_END_TAG);
|
||||
let new_text_end_tag_ix = self.buffer.find(NEW_TEXT_END_TAG);
|
||||
let start_ix = if let Some((old_text_ix, new_text_ix)) =
|
||||
old_text_end_tag_ix.zip(new_text_end_tag_ix)
|
||||
{
|
||||
cmp::min(old_text_ix, new_text_ix)
|
||||
} else {
|
||||
old_text_end_tag_ix.or(new_text_end_tag_ix)?
|
||||
};
|
||||
Some(start_ix..start_ix + END_TAG_LEN)
|
||||
}
|
||||
|
||||
pub fn finish(self) -> EditParserMetrics {
|
||||
self.metrics
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use indoc::indoc;
|
||||
use rand::prelude::*;
|
||||
use std::cmp;
|
||||
|
||||
#[gpui::test(iterations = 1000)]
|
||||
fn test_single_edit(mut rng: StdRng) {
|
||||
let mut parser = EditParser::new();
|
||||
assert_eq!(
|
||||
parse_random_chunks(
|
||||
"<old_text>original</old_text><new_text>updated</new_text>",
|
||||
&mut parser,
|
||||
&mut rng
|
||||
),
|
||||
vec![Edit {
|
||||
old_text: "original".to_string(),
|
||||
new_text: "updated".to_string(),
|
||||
}]
|
||||
);
|
||||
assert_eq!(
|
||||
parser.finish(),
|
||||
EditParserMetrics {
|
||||
tags: 2,
|
||||
mismatched_tags: 0
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
#[gpui::test(iterations = 1000)]
|
||||
fn test_multiple_edits(mut rng: StdRng) {
|
||||
let mut parser = EditParser::new();
|
||||
assert_eq!(
|
||||
parse_random_chunks(
|
||||
indoc! {"
|
||||
<old_text>
|
||||
first old
|
||||
</old_text><new_text>first new</new_text>
|
||||
<old_text>second old</old_text><new_text>
|
||||
second new
|
||||
</new_text>
|
||||
"},
|
||||
&mut parser,
|
||||
&mut rng
|
||||
),
|
||||
vec![
|
||||
Edit {
|
||||
old_text: "first old".to_string(),
|
||||
new_text: "first new".to_string(),
|
||||
},
|
||||
Edit {
|
||||
old_text: "second old".to_string(),
|
||||
new_text: "second new".to_string(),
|
||||
},
|
||||
]
|
||||
);
|
||||
assert_eq!(
|
||||
parser.finish(),
|
||||
EditParserMetrics {
|
||||
tags: 4,
|
||||
mismatched_tags: 0
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
#[gpui::test(iterations = 1000)]
|
||||
fn test_edits_with_extra_text(mut rng: StdRng) {
|
||||
let mut parser = EditParser::new();
|
||||
assert_eq!(
|
||||
parse_random_chunks(
|
||||
indoc! {"
|
||||
ignore this <old_text>
|
||||
content</old_text>extra stuff<new_text>updated content</new_text>trailing data
|
||||
more text <old_text>second item
|
||||
</old_text>middle text<new_text>modified second item</new_text>end
|
||||
<old_text>third case</old_text><new_text>improved third case</new_text> with trailing text
|
||||
"},
|
||||
&mut parser,
|
||||
&mut rng
|
||||
),
|
||||
vec![
|
||||
Edit {
|
||||
old_text: "content".to_string(),
|
||||
new_text: "updated content".to_string(),
|
||||
},
|
||||
Edit {
|
||||
old_text: "second item".to_string(),
|
||||
new_text: "modified second item".to_string(),
|
||||
},
|
||||
Edit {
|
||||
old_text: "third case".to_string(),
|
||||
new_text: "improved third case".to_string(),
|
||||
},
|
||||
]
|
||||
);
|
||||
assert_eq!(
|
||||
parser.finish(),
|
||||
EditParserMetrics {
|
||||
tags: 6,
|
||||
mismatched_tags: 0
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
#[gpui::test(iterations = 1000)]
|
||||
fn test_nested_tags(mut rng: StdRng) {
|
||||
let mut parser = EditParser::new();
|
||||
assert_eq!(
|
||||
parse_random_chunks(
|
||||
"<old_text>code with <tag>nested</tag> elements</old_text><new_text>new <code>content</code></new_text>",
|
||||
&mut parser,
|
||||
&mut rng
|
||||
),
|
||||
vec![Edit {
|
||||
old_text: "code with <tag>nested</tag> elements".to_string(),
|
||||
new_text: "new <code>content</code>".to_string(),
|
||||
}]
|
||||
);
|
||||
assert_eq!(
|
||||
parser.finish(),
|
||||
EditParserMetrics {
|
||||
tags: 2,
|
||||
mismatched_tags: 0
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
#[gpui::test(iterations = 1000)]
|
||||
fn test_empty_old_and_new_text(mut rng: StdRng) {
|
||||
let mut parser = EditParser::new();
|
||||
assert_eq!(
|
||||
parse_random_chunks(
|
||||
"<old_text></old_text><new_text></new_text>",
|
||||
&mut parser,
|
||||
&mut rng
|
||||
),
|
||||
vec![Edit {
|
||||
old_text: "".to_string(),
|
||||
new_text: "".to_string(),
|
||||
}]
|
||||
);
|
||||
assert_eq!(
|
||||
parser.finish(),
|
||||
EditParserMetrics {
|
||||
tags: 2,
|
||||
mismatched_tags: 0
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
#[gpui::test(iterations = 100)]
|
||||
fn test_multiline_content(mut rng: StdRng) {
|
||||
let mut parser = EditParser::new();
|
||||
assert_eq!(
|
||||
parse_random_chunks(
|
||||
"<old_text>line1\nline2\nline3</old_text><new_text>line1\nmodified line2\nline3</new_text>",
|
||||
&mut parser,
|
||||
&mut rng
|
||||
),
|
||||
vec![Edit {
|
||||
old_text: "line1\nline2\nline3".to_string(),
|
||||
new_text: "line1\nmodified line2\nline3".to_string(),
|
||||
}]
|
||||
);
|
||||
assert_eq!(
|
||||
parser.finish(),
|
||||
EditParserMetrics {
|
||||
tags: 2,
|
||||
mismatched_tags: 0
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
#[gpui::test(iterations = 1000)]
|
||||
fn test_mismatched_tags(mut rng: StdRng) {
|
||||
let mut parser = EditParser::new();
|
||||
assert_eq!(
|
||||
parse_random_chunks(
|
||||
// Reduced from an actual Sonnet 3.7 output
|
||||
indoc! {"
|
||||
<old_text>
|
||||
a
|
||||
b
|
||||
c
|
||||
</new_text>
|
||||
<new_text>
|
||||
a
|
||||
B
|
||||
c
|
||||
</old_text>
|
||||
<old_text>
|
||||
d
|
||||
e
|
||||
f
|
||||
</new_text>
|
||||
<new_text>
|
||||
D
|
||||
e
|
||||
F
|
||||
</old_text>
|
||||
"},
|
||||
&mut parser,
|
||||
&mut rng
|
||||
),
|
||||
vec![
|
||||
Edit {
|
||||
old_text: "a\nb\nc".to_string(),
|
||||
new_text: "a\nB\nc".to_string(),
|
||||
},
|
||||
Edit {
|
||||
old_text: "d\ne\nf".to_string(),
|
||||
new_text: "D\ne\nF".to_string(),
|
||||
}
|
||||
]
|
||||
);
|
||||
assert_eq!(
|
||||
parser.finish(),
|
||||
EditParserMetrics {
|
||||
tags: 4,
|
||||
mismatched_tags: 4
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
#[derive(Default, Debug, PartialEq, Eq)]
|
||||
struct Edit {
|
||||
old_text: String,
|
||||
new_text: String,
|
||||
}
|
||||
|
||||
fn parse_random_chunks(input: &str, parser: &mut EditParser, rng: &mut StdRng) -> Vec<Edit> {
|
||||
let chunk_count = rng.gen_range(1..=cmp::min(input.len(), 50));
|
||||
let mut chunk_indices = (0..input.len()).choose_multiple(rng, chunk_count);
|
||||
chunk_indices.sort();
|
||||
chunk_indices.push(input.len());
|
||||
|
||||
let mut pending_edit = Edit::default();
|
||||
let mut edits = Vec::new();
|
||||
let mut last_ix = 0;
|
||||
for chunk_ix in chunk_indices {
|
||||
for event in parser.push(&input[last_ix..chunk_ix]) {
|
||||
match event {
|
||||
EditParserEvent::OldText(old_text) => {
|
||||
pending_edit.old_text = old_text;
|
||||
}
|
||||
EditParserEvent::NewTextChunk { chunk, done } => {
|
||||
pending_edit.new_text.push_str(&chunk);
|
||||
if done {
|
||||
edits.push(pending_edit);
|
||||
pending_edit = Edit::default();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
last_ix = chunk_ix;
|
||||
}
|
||||
edits
|
||||
}
|
||||
}
|
||||
1064
crates/assistant_tools/src/edit_agent/evals.rs
Normal file
1064
crates/assistant_tools/src/edit_agent/evals.rs
Normal file
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,328 @@
|
||||
use crate::commit::get_messages;
|
||||
use crate::{GitRemote, Oid};
|
||||
use anyhow::{Context as _, Result, anyhow};
|
||||
use collections::{HashMap, HashSet};
|
||||
use futures::AsyncWriteExt;
|
||||
use gpui::SharedString;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::process::Stdio;
|
||||
use std::{ops::Range, path::Path};
|
||||
use text::Rope;
|
||||
use time::OffsetDateTime;
|
||||
use time::UtcOffset;
|
||||
use time::macros::format_description;
|
||||
|
||||
pub use git2 as libgit;
|
||||
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct Blame {
|
||||
pub entries: Vec<BlameEntry>,
|
||||
pub messages: HashMap<Oid, String>,
|
||||
pub remote_url: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Default)]
|
||||
pub struct ParsedCommitMessage {
|
||||
pub message: SharedString,
|
||||
pub permalink: Option<url::Url>,
|
||||
pub pull_request: Option<crate::hosting_provider::PullRequest>,
|
||||
pub remote: Option<GitRemote>,
|
||||
}
|
||||
|
||||
impl Blame {
|
||||
pub async fn for_path(
|
||||
git_binary: &Path,
|
||||
working_directory: &Path,
|
||||
path: &Path,
|
||||
content: &Rope,
|
||||
remote_url: Option<String>,
|
||||
) -> Result<Self> {
|
||||
let output = run_git_blame(git_binary, working_directory, path, content).await?;
|
||||
let mut entries = parse_git_blame(&output)?;
|
||||
entries.sort_unstable_by(|a, b| a.range.start.cmp(&b.range.start));
|
||||
|
||||
let mut unique_shas = HashSet::default();
|
||||
|
||||
for entry in entries.iter_mut() {
|
||||
unique_shas.insert(entry.sha);
|
||||
}
|
||||
|
||||
let shas = unique_shas.into_iter().collect::<Vec<_>>();
|
||||
let messages = get_messages(working_directory, &shas)
|
||||
.await
|
||||
.context("failed to get commit messages")?;
|
||||
|
||||
Ok(Self {
|
||||
entries,
|
||||
messages,
|
||||
remote_url,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
const GIT_BLAME_NO_COMMIT_ERROR: &str = "fatal: no such ref: HEAD";
|
||||
const GIT_BLAME_NO_PATH: &str = "fatal: no such path";
|
||||
|
||||
#[derive(Serialize, Deserialize, Default, Debug, Clone, PartialEq, Eq)]
|
||||
pub struct BlameEntry {
|
||||
pub sha: Oid,
|
||||
|
||||
pub range: Range<u32>,
|
||||
|
||||
pub original_line_number: u32,
|
||||
|
||||
pub author: Option<String>,
|
||||
pub author_mail: Option<String>,
|
||||
pub author_time: Option<i64>,
|
||||
pub author_tz: Option<String>,
|
||||
|
||||
pub committer_name: Option<String>,
|
||||
pub committer_email: Option<String>,
|
||||
pub committer_time: Option<i64>,
|
||||
pub committer_tz: Option<String>,
|
||||
|
||||
pub summary: Option<String>,
|
||||
|
||||
pub previous: Option<String>,
|
||||
pub filename: String,
|
||||
}
|
||||
|
||||
impl BlameEntry {
|
||||
// Returns a BlameEntry by parsing the first line of a `git blame --incremental`
|
||||
// entry. The line MUST have this format:
|
||||
//
|
||||
// <40-byte-hex-sha1> <sourceline> <resultline> <num-lines>
|
||||
fn new_from_blame_line(line: &str) -> Result<BlameEntry> {
|
||||
let mut parts = line.split_whitespace();
|
||||
|
||||
let sha = parts
|
||||
.next()
|
||||
.and_then(|line| line.parse::<Oid>().ok())
|
||||
.ok_or_else(|| anyhow!("failed to parse sha"))?;
|
||||
|
||||
let original_line_number = parts
|
||||
.next()
|
||||
.and_then(|line| line.parse::<u32>().ok())
|
||||
.ok_or_else(|| anyhow!("Failed to parse original line number"))?;
|
||||
let final_line_number = parts
|
||||
.next()
|
||||
.and_then(|line| line.parse::<u32>().ok())
|
||||
.ok_or_else(|| anyhow!("Failed to parse final line number"))?;
|
||||
|
||||
let line_count = parts
|
||||
.next()
|
||||
.and_then(|line| line.parse::<u32>().ok())
|
||||
.ok_or_else(|| anyhow!("Failed to parse final line number"))?;
|
||||
|
||||
let start_line = final_line_number.saturating_sub(1);
|
||||
let end_line = start_line + line_count;
|
||||
let range = start_line..end_line;
|
||||
|
||||
Ok(Self {
|
||||
sha,
|
||||
range,
|
||||
original_line_number,
|
||||
..Default::default()
|
||||
})
|
||||
}
|
||||
|
||||
pub fn author_offset_date_time(&self) -> Result<time::OffsetDateTime> {
|
||||
if let (Some(author_time), Some(author_tz)) = (self.author_time, &self.author_tz) {
|
||||
let format = format_description!("[offset_hour][offset_minute]");
|
||||
let offset = UtcOffset::parse(author_tz, &format)?;
|
||||
let date_time_utc = OffsetDateTime::from_unix_timestamp(author_time)?;
|
||||
|
||||
Ok(date_time_utc.to_offset(offset))
|
||||
} else {
|
||||
// Directly return current time in UTC if there's no committer time or timezone
|
||||
Ok(time::OffsetDateTime::now_utc())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// parse_git_blame parses the output of `git blame --incremental`, which returns
|
||||
// all the blame-entries for a given path incrementally, as it finds them.
|
||||
//
|
||||
// Each entry *always* starts with:
|
||||
//
|
||||
// <40-byte-hex-sha1> <sourceline> <resultline> <num-lines>
|
||||
//
|
||||
// Each entry *always* ends with:
|
||||
//
|
||||
// filename <whitespace-quoted-filename-goes-here>
|
||||
//
|
||||
// Line numbers are 1-indexed.
|
||||
//
|
||||
// A `git blame --incremental` entry looks like this:
|
||||
//
|
||||
// 6ad46b5257ba16d12c5ca9f0d4900320959df7f4 2 2 1
|
||||
// author Joe Schmoe
|
||||
// author-mail <joe.schmoe@example.com>
|
||||
// author-time 1709741400
|
||||
// author-tz +0100
|
||||
// committer Joe Schmoe
|
||||
// committer-mail <joe.schmoe@example.com>
|
||||
// committer-time 1709741400
|
||||
// committer-tz +0100
|
||||
// summary Joe's cool commit
|
||||
// previous 486c2409237a2c627230589e567024a96751d475 index.js
|
||||
// filename index.js
|
||||
//
|
||||
// If the entry has the same SHA as an entry that was already printed then no
|
||||
// signature information is printed:
|
||||
//
|
||||
// 6ad46b5257ba16d12c5ca9f0d4900320959df7f4 3 4 1
|
||||
// previous 486c2409237a2c627230589e567024a96751d475 index.js
|
||||
// filename index.js
|
||||
//
|
||||
// More about `--incremental` output: https://mirrors.edge.kernel.org/pub/software/scm/git/docs/git-blame.html
|
||||
fn parse_git_blame(output: &str) -> Result<Vec<BlameEntry>> {
|
||||
let mut entries: Vec<BlameEntry> = Vec::new();
|
||||
let mut index: HashMap<Oid, usize> = HashMap::default();
|
||||
|
||||
let mut current_entry: Option<BlameEntry> = None;
|
||||
|
||||
for line in output.lines() {
|
||||
let mut done = false;
|
||||
|
||||
match &mut current_entry {
|
||||
None => {
|
||||
let mut new_entry = BlameEntry::new_from_blame_line(line)?;
|
||||
|
||||
if let Some(existing_entry) = index
|
||||
.get(&new_entry.sha)
|
||||
.and_then(|slot| entries.get(*slot))
|
||||
{
|
||||
new_entry.author.clone_from(&existing_entry.author);
|
||||
new_entry
|
||||
.author_mail
|
||||
.clone_from(&existing_entry.author_mail);
|
||||
new_entry.author_time = existing_entry.author_time;
|
||||
new_entry.author_tz.clone_from(&existing_entry.author_tz);
|
||||
new_entry
|
||||
.committer_name
|
||||
.clone_from(&existing_entry.committer_name);
|
||||
new_entry
|
||||
.committer_email
|
||||
.clone_from(&existing_entry.committer_email);
|
||||
new_entry.committer_time = existing_entry.committer_time;
|
||||
new_entry
|
||||
.committer_tz
|
||||
.clone_from(&existing_entry.committer_tz);
|
||||
new_entry.summary.clone_from(&existing_entry.summary);
|
||||
}
|
||||
|
||||
current_entry.replace(new_entry);
|
||||
}
|
||||
Some(entry) => {
|
||||
let Some((key, value)) = line.split_once(' ') else {
|
||||
continue;
|
||||
};
|
||||
let is_committed = !entry.sha.is_zero();
|
||||
match key {
|
||||
"filename" => {
|
||||
entry.filename = value.into();
|
||||
done = true;
|
||||
}
|
||||
"previous" => entry.previous = Some(value.into()),
|
||||
|
||||
"summary" if is_committed => entry.summary = Some(value.into()),
|
||||
"author" if is_committed => entry.author = Some(value.into()),
|
||||
"author-mail" if is_committed => entry.author_mail = Some(value.into()),
|
||||
"author-time" if is_committed => {
|
||||
entry.author_time = Some(value.parse::<i64>()?)
|
||||
}
|
||||
"author-tz" if is_committed => entry.author_tz = Some(value.into()),
|
||||
|
||||
"committer" if is_committed => entry.committer_name = Some(value.into()),
|
||||
"committer-mail" if is_committed => entry.committer_email = Some(value.into()),
|
||||
"committer-time" if is_committed => {
|
||||
entry.committer_time = Some(value.parse::<i64>()?)
|
||||
}
|
||||
"committer-tz" if is_committed => entry.committer_tz = Some(value.into()),
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
if done {
|
||||
if let Some(entry) = current_entry.take() {
|
||||
index.insert(entry.sha, entries.len());
|
||||
|
||||
// We only want annotations that have a commit.
|
||||
if !entry.sha.is_zero() {
|
||||
entries.push(entry);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(entries)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::path::PathBuf;
|
||||
|
||||
use super::BlameEntry;
|
||||
use super::parse_git_blame;
|
||||
|
||||
fn read_test_data(filename: &str) -> String {
|
||||
let mut path = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
|
||||
path.push("test_data");
|
||||
path.push(filename);
|
||||
|
||||
std::fs::read_to_string(&path)
|
||||
.unwrap_or_else(|_| panic!("Could not read test data at {:?}. Is it generated?", path))
|
||||
}
|
||||
|
||||
fn assert_eq_golden(entries: &Vec<BlameEntry>, golden_filename: &str) {
|
||||
let mut path = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
|
||||
path.push("test_data");
|
||||
path.push("golden");
|
||||
path.push(format!("{}.json", golden_filename));
|
||||
|
||||
let mut have_json =
|
||||
serde_json::to_string_pretty(&entries).expect("could not serialize entries to JSON");
|
||||
// We always want to save with a trailing newline.
|
||||
have_json.push('\n');
|
||||
|
||||
let update = std::env::var("UPDATE_GOLDEN")
|
||||
.map(|val| val.eq_ignore_ascii_case("true"))
|
||||
.unwrap_or(false);
|
||||
|
||||
if update {
|
||||
std::fs::create_dir_all(path.parent().unwrap())
|
||||
.expect("could not create golden test data directory");
|
||||
std::fs::write(&path, have_json).expect("could not write out golden data");
|
||||
} else {
|
||||
let want_json =
|
||||
std::fs::read_to_string(&path).unwrap_or_else(|_| {
|
||||
panic!("could not read golden test data file at {:?}. Did you run the test with UPDATE_GOLDEN=true before?", path);
|
||||
}).replace("\r\n", "\n");
|
||||
|
||||
pretty_assertions::assert_eq!(have_json, want_json, "wrong blame entries");
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_git_blame_not_committed() {
|
||||
let output = read_test_data("blame_incremental_not_committed");
|
||||
let entries = parse_git_blame(&output).unwrap();
|
||||
assert_eq_golden(&entries, "blame_incremental_not_committed");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_git_blame_simple() {
|
||||
let output = read_test_data("blame_incremental_simple");
|
||||
let entries = parse_git_blame(&output).unwrap();
|
||||
assert_eq_golden(&entries, "blame_incremental_simple");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_git_blame_complex() {
|
||||
let output = read_test_data("blame_incremental_complex");
|
||||
let entries = parse_git_blame(&output).unwrap();
|
||||
assert_eq_golden(&entries, "blame_incremental_complex");
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,374 @@
|
||||
use crate::commit::get_messages;
|
||||
use crate::{GitRemote, Oid};
|
||||
use anyhow::{Context as _, Result, anyhow};
|
||||
use collections::{HashMap, HashSet};
|
||||
use futures::AsyncWriteExt;
|
||||
use gpui::SharedString;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::process::Stdio;
|
||||
use std::{ops::Range, path::Path};
|
||||
use text::Rope;
|
||||
use time::OffsetDateTime;
|
||||
use time::UtcOffset;
|
||||
use time::macros::format_description;
|
||||
|
||||
pub use git2 as libgit;
|
||||
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct Blame {
|
||||
pub entries: Vec<BlameEntry>,
|
||||
pub messages: HashMap<Oid, String>,
|
||||
pub remote_url: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Default)]
|
||||
pub struct ParsedCommitMessage {
|
||||
pub message: SharedString,
|
||||
pub permalink: Option<url::Url>,
|
||||
pub pull_request: Option<crate::hosting_provider::PullRequest>,
|
||||
pub remote: Option<GitRemote>,
|
||||
}
|
||||
|
||||
impl Blame {
|
||||
pub async fn for_path(
|
||||
git_binary: &Path,
|
||||
working_directory: &Path,
|
||||
path: &Path,
|
||||
content: &Rope,
|
||||
remote_url: Option<String>,
|
||||
) -> Result<Self> {
|
||||
let output = run_git_blame(git_binary, working_directory, path, content).await?;
|
||||
let mut entries = parse_git_blame(&output)?;
|
||||
entries.sort_unstable_by(|a, b| a.range.start.cmp(&b.range.start));
|
||||
|
||||
let mut unique_shas = HashSet::default();
|
||||
|
||||
for entry in entries.iter_mut() {
|
||||
unique_shas.insert(entry.sha);
|
||||
}
|
||||
|
||||
let shas = unique_shas.into_iter().collect::<Vec<_>>();
|
||||
let messages = get_messages(working_directory, &shas)
|
||||
.await
|
||||
.context("failed to get commit messages")?;
|
||||
|
||||
Ok(Self {
|
||||
entries,
|
||||
messages,
|
||||
remote_url,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
const GIT_BLAME_NO_COMMIT_ERROR: &str = "fatal: no such ref: HEAD";
|
||||
const GIT_BLAME_NO_PATH: &str = "fatal: no such path";
|
||||
|
||||
async fn run_git_blame(
|
||||
git_binary: &Path,
|
||||
working_directory: &Path,
|
||||
path: &Path,
|
||||
contents: &Rope,
|
||||
) -> Result<String> {
|
||||
let mut child = util::command::new_smol_command(git_binary)
|
||||
.current_dir(working_directory)
|
||||
.arg("blame")
|
||||
.arg("--incremental")
|
||||
.arg("--contents")
|
||||
.arg("-")
|
||||
.arg(path.as_os_str())
|
||||
.stdin(Stdio::piped())
|
||||
.stdout(Stdio::piped())
|
||||
.stderr(Stdio::piped())
|
||||
.spawn()
|
||||
.map_err(|e| anyhow!("Failed to start git blame process: {}", e))?;
|
||||
|
||||
let stdin = child
|
||||
.stdin
|
||||
.as_mut()
|
||||
.context("failed to get pipe to stdin of git blame command")?;
|
||||
|
||||
for chunk in contents.chunks() {
|
||||
stdin.write_all(chunk.as_bytes()).await?;
|
||||
}
|
||||
stdin.flush().await?;
|
||||
|
||||
let output = child
|
||||
.output()
|
||||
.await
|
||||
.map_err(|e| anyhow!("Failed to read git blame output: {}", e))?;
|
||||
|
||||
if !output.status.success() {
|
||||
let stderr = String::from_utf8_lossy(&output.stderr);
|
||||
let trimmed = stderr.trim();
|
||||
if trimmed == GIT_BLAME_NO_COMMIT_ERROR || trimmed.contains(GIT_BLAME_NO_PATH) {
|
||||
return Ok(String::new());
|
||||
}
|
||||
return Err(anyhow!("git blame process failed: {}", stderr));
|
||||
}
|
||||
|
||||
Ok(String::from_utf8(output.stdout)?)
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Default, Debug, Clone, PartialEq, Eq)]
|
||||
pub struct BlameEntry {
|
||||
pub sha: Oid,
|
||||
|
||||
pub range: Range<u32>,
|
||||
|
||||
pub original_line_number: u32,
|
||||
|
||||
pub author: Option<String>,
|
||||
pub author_mail: Option<String>,
|
||||
pub author_time: Option<i64>,
|
||||
pub author_tz: Option<String>,
|
||||
|
||||
pub committer_name: Option<String>,
|
||||
pub committer_email: Option<String>,
|
||||
pub committer_time: Option<i64>,
|
||||
pub committer_tz: Option<String>,
|
||||
|
||||
pub summary: Option<String>,
|
||||
|
||||
pub previous: Option<String>,
|
||||
pub filename: String,
|
||||
}
|
||||
|
||||
impl BlameEntry {
|
||||
// Returns a BlameEntry by parsing the first line of a `git blame --incremental`
|
||||
// entry. The line MUST have this format:
|
||||
//
|
||||
// <40-byte-hex-sha1> <sourceline> <resultline> <num-lines>
|
||||
fn new_from_blame_line(line: &str) -> Result<BlameEntry> {
|
||||
let mut parts = line.split_whitespace();
|
||||
|
||||
let sha = parts
|
||||
.next()
|
||||
.and_then(|line| line.parse::<Oid>().ok())
|
||||
.ok_or_else(|| anyhow!("failed to parse sha"))?;
|
||||
|
||||
let original_line_number = parts
|
||||
.next()
|
||||
.and_then(|line| line.parse::<u32>().ok())
|
||||
.ok_or_else(|| anyhow!("Failed to parse original line number"))?;
|
||||
let final_line_number = parts
|
||||
.next()
|
||||
.and_then(|line| line.parse::<u32>().ok())
|
||||
.ok_or_else(|| anyhow!("Failed to parse final line number"))?;
|
||||
|
||||
let line_count = parts
|
||||
.next()
|
||||
.and_then(|line| line.parse::<u32>().ok())
|
||||
.ok_or_else(|| anyhow!("Failed to parse final line number"))?;
|
||||
|
||||
let start_line = final_line_number.saturating_sub(1);
|
||||
let end_line = start_line + line_count;
|
||||
let range = start_line..end_line;
|
||||
|
||||
Ok(Self {
|
||||
sha,
|
||||
range,
|
||||
original_line_number,
|
||||
..Default::default()
|
||||
})
|
||||
}
|
||||
|
||||
pub fn author_offset_date_time(&self) -> Result<time::OffsetDateTime> {
|
||||
if let (Some(author_time), Some(author_tz)) = (self.author_time, &self.author_tz) {
|
||||
let format = format_description!("[offset_hour][offset_minute]");
|
||||
let offset = UtcOffset::parse(author_tz, &format)?;
|
||||
let date_time_utc = OffsetDateTime::from_unix_timestamp(author_time)?;
|
||||
|
||||
Ok(date_time_utc.to_offset(offset))
|
||||
} else {
|
||||
// Directly return current time in UTC if there's no committer time or timezone
|
||||
Ok(time::OffsetDateTime::now_utc())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// parse_git_blame parses the output of `git blame --incremental`, which returns
|
||||
// all the blame-entries for a given path incrementally, as it finds them.
|
||||
//
|
||||
// Each entry *always* starts with:
|
||||
//
|
||||
// <40-byte-hex-sha1> <sourceline> <resultline> <num-lines>
|
||||
//
|
||||
// Each entry *always* ends with:
|
||||
//
|
||||
// filename <whitespace-quoted-filename-goes-here>
|
||||
//
|
||||
// Line numbers are 1-indexed.
|
||||
//
|
||||
// A `git blame --incremental` entry looks like this:
|
||||
//
|
||||
// 6ad46b5257ba16d12c5ca9f0d4900320959df7f4 2 2 1
|
||||
// author Joe Schmoe
|
||||
// author-mail <joe.schmoe@example.com>
|
||||
// author-time 1709741400
|
||||
// author-tz +0100
|
||||
// committer Joe Schmoe
|
||||
// committer-mail <joe.schmoe@example.com>
|
||||
// committer-time 1709741400
|
||||
// committer-tz +0100
|
||||
// summary Joe's cool commit
|
||||
// previous 486c2409237a2c627230589e567024a96751d475 index.js
|
||||
// filename index.js
|
||||
//
|
||||
// If the entry has the same SHA as an entry that was already printed then no
|
||||
// signature information is printed:
|
||||
//
|
||||
// 6ad46b5257ba16d12c5ca9f0d4900320959df7f4 3 4 1
|
||||
// previous 486c2409237a2c627230589e567024a96751d475 index.js
|
||||
// filename index.js
|
||||
//
|
||||
// More about `--incremental` output: https://mirrors.edge.kernel.org/pub/software/scm/git/docs/git-blame.html
|
||||
fn parse_git_blame(output: &str) -> Result<Vec<BlameEntry>> {
|
||||
let mut entries: Vec<BlameEntry> = Vec::new();
|
||||
let mut index: HashMap<Oid, usize> = HashMap::default();
|
||||
|
||||
let mut current_entry: Option<BlameEntry> = None;
|
||||
|
||||
for line in output.lines() {
|
||||
let mut done = false;
|
||||
|
||||
match &mut current_entry {
|
||||
None => {
|
||||
let mut new_entry = BlameEntry::new_from_blame_line(line)?;
|
||||
|
||||
if let Some(existing_entry) = index
|
||||
.get(&new_entry.sha)
|
||||
.and_then(|slot| entries.get(*slot))
|
||||
{
|
||||
new_entry.author.clone_from(&existing_entry.author);
|
||||
new_entry
|
||||
.author_mail
|
||||
.clone_from(&existing_entry.author_mail);
|
||||
new_entry.author_time = existing_entry.author_time;
|
||||
new_entry.author_tz.clone_from(&existing_entry.author_tz);
|
||||
new_entry
|
||||
.committer_name
|
||||
.clone_from(&existing_entry.committer_name);
|
||||
new_entry
|
||||
.committer_email
|
||||
.clone_from(&existing_entry.committer_email);
|
||||
new_entry.committer_time = existing_entry.committer_time;
|
||||
new_entry
|
||||
.committer_tz
|
||||
.clone_from(&existing_entry.committer_tz);
|
||||
new_entry.summary.clone_from(&existing_entry.summary);
|
||||
}
|
||||
|
||||
current_entry.replace(new_entry);
|
||||
}
|
||||
Some(entry) => {
|
||||
let Some((key, value)) = line.split_once(' ') else {
|
||||
continue;
|
||||
};
|
||||
let is_committed = !entry.sha.is_zero();
|
||||
match key {
|
||||
"filename" => {
|
||||
entry.filename = value.into();
|
||||
done = true;
|
||||
}
|
||||
"previous" => entry.previous = Some(value.into()),
|
||||
|
||||
"summary" if is_committed => entry.summary = Some(value.into()),
|
||||
"author" if is_committed => entry.author = Some(value.into()),
|
||||
"author-mail" if is_committed => entry.author_mail = Some(value.into()),
|
||||
"author-time" if is_committed => {
|
||||
entry.author_time = Some(value.parse::<i64>()?)
|
||||
}
|
||||
"author-tz" if is_committed => entry.author_tz = Some(value.into()),
|
||||
|
||||
"committer" if is_committed => entry.committer_name = Some(value.into()),
|
||||
"committer-mail" if is_committed => entry.committer_email = Some(value.into()),
|
||||
"committer-time" if is_committed => {
|
||||
entry.committer_time = Some(value.parse::<i64>()?)
|
||||
}
|
||||
"committer-tz" if is_committed => entry.committer_tz = Some(value.into()),
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
if done {
|
||||
if let Some(entry) = current_entry.take() {
|
||||
index.insert(entry.sha, entries.len());
|
||||
|
||||
// We only want annotations that have a commit.
|
||||
if !entry.sha.is_zero() {
|
||||
entries.push(entry);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(entries)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::path::PathBuf;
|
||||
|
||||
use super::BlameEntry;
|
||||
use super::parse_git_blame;
|
||||
|
||||
fn read_test_data(filename: &str) -> String {
|
||||
let mut path = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
|
||||
path.push("test_data");
|
||||
path.push(filename);
|
||||
|
||||
std::fs::read_to_string(&path)
|
||||
.unwrap_or_else(|_| panic!("Could not read test data at {:?}. Is it generated?", path))
|
||||
}
|
||||
|
||||
fn assert_eq_golden(entries: &Vec<BlameEntry>, golden_filename: &str) {
|
||||
let mut path = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
|
||||
path.push("test_data");
|
||||
path.push("golden");
|
||||
path.push(format!("{}.json", golden_filename));
|
||||
|
||||
let mut have_json =
|
||||
serde_json::to_string_pretty(&entries).expect("could not serialize entries to JSON");
|
||||
// We always want to save with a trailing newline.
|
||||
have_json.push('\n');
|
||||
|
||||
let update = std::env::var("UPDATE_GOLDEN")
|
||||
.map(|val| val.eq_ignore_ascii_case("true"))
|
||||
.unwrap_or(false);
|
||||
|
||||
if update {
|
||||
std::fs::create_dir_all(path.parent().unwrap())
|
||||
.expect("could not create golden test data directory");
|
||||
std::fs::write(&path, have_json).expect("could not write out golden data");
|
||||
} else {
|
||||
let want_json =
|
||||
std::fs::read_to_string(&path).unwrap_or_else(|_| {
|
||||
panic!("could not read golden test data file at {:?}. Did you run the test with UPDATE_GOLDEN=true before?", path);
|
||||
}).replace("\r\n", "\n");
|
||||
|
||||
pretty_assertions::assert_eq!(have_json, want_json, "wrong blame entries");
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_git_blame_not_committed() {
|
||||
let output = read_test_data("blame_incremental_not_committed");
|
||||
let entries = parse_git_blame(&output).unwrap();
|
||||
assert_eq_golden(&entries, "blame_incremental_not_committed");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_git_blame_simple() {
|
||||
let output = read_test_data("blame_incremental_simple");
|
||||
let entries = parse_git_blame(&output).unwrap();
|
||||
assert_eq_golden(&entries, "blame_incremental_simple");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_git_blame_complex() {
|
||||
let output = read_test_data("blame_incremental_complex");
|
||||
let entries = parse_git_blame(&output).unwrap();
|
||||
assert_eq_golden(&entries, "blame_incremental_complex");
|
||||
}
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,378 @@
|
||||
use crate::commit::get_messages;
|
||||
use crate::{GitRemote, Oid};
|
||||
use anyhow::{Context as _, Result, anyhow};
|
||||
use collections::{HashMap, HashSet};
|
||||
use futures::AsyncWriteExt;
|
||||
use gpui::SharedString;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::process::Stdio;
|
||||
use std::{ops::Range, path::Path};
|
||||
use text::Rope;
|
||||
use time::OffsetDateTime;
|
||||
use time::UtcOffset;
|
||||
use time::macros::format_description;
|
||||
|
||||
pub use git2 as libgit;
|
||||
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct Blame {
|
||||
pub entries: Vec<BlameEntry>,
|
||||
pub messages: HashMap<Oid, String>,
|
||||
pub remote_url: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Default)]
|
||||
pub struct ParsedCommitMessage {
|
||||
pub message: SharedString,
|
||||
pub permalink: Option<url::Url>,
|
||||
pub pull_request: Option<crate::hosting_provider::PullRequest>,
|
||||
pub remote: Option<GitRemote>,
|
||||
}
|
||||
|
||||
impl Blame {
|
||||
pub async fn for_path(
|
||||
git_binary: &Path,
|
||||
working_directory: &Path,
|
||||
path: &Path,
|
||||
content: &Rope,
|
||||
remote_url: Option<String>,
|
||||
) -> Result<Self> {
|
||||
let output = run_git_blame(git_binary, working_directory, path, content).await?;
|
||||
let mut entries = parse_git_blame(&output)?;
|
||||
entries.sort_unstable_by(|a, b| a.range.start.cmp(&b.range.start));
|
||||
|
||||
let mut unique_shas = HashSet::default();
|
||||
|
||||
for entry in entries.iter_mut() {
|
||||
unique_shas.insert(entry.sha);
|
||||
}
|
||||
|
||||
let shas = unique_shas.into_iter().collect::<Vec<_>>();
|
||||
let messages = get_messages(working_directory, &shas)
|
||||
.await
|
||||
.context("failed to get commit messages")?;
|
||||
|
||||
Ok(Self {
|
||||
entries,
|
||||
messages,
|
||||
remote_url,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
const GIT_BLAME_NO_COMMIT_ERROR: &str = "fatal: no such ref: HEAD";
|
||||
const GIT_BLAME_NO_PATH: &str = "fatal: no such path";
|
||||
|
||||
async fn run_git_blame(
|
||||
git_binary: &Path,
|
||||
working_directory: &Path,
|
||||
path: &Path,
|
||||
contents: &Rope,
|
||||
) -> Result<String> {
|
||||
let mut child = util::command::new_smol_command(git_binary)
|
||||
.current_dir(working_directory)
|
||||
.arg("blame")
|
||||
.arg("--incremental")
|
||||
.arg("--contents")
|
||||
.arg("-")
|
||||
.arg(path.as_os_str())
|
||||
.stdin(Stdio::piped())
|
||||
.stdout(Stdio::piped())
|
||||
.stderr(Stdio::piped())
|
||||
.spawn()
|
||||
.map_err(|e| anyhow!("Failed to start git blame process: {}", e))?;
|
||||
|
||||
let stdin = child
|
||||
.stdin
|
||||
.as_mut()
|
||||
.context("failed to get pipe to stdin of git blame command")?;
|
||||
|
||||
for chunk in contents.chunks() {
|
||||
stdin.write_all(chunk.as_bytes()).await?;
|
||||
}
|
||||
stdin.flush().await?;
|
||||
|
||||
let output = child
|
||||
.output()
|
||||
.await
|
||||
.map_err(|e| anyhow!("Failed to read git blame output: {}", e))?;
|
||||
|
||||
handle_command_output(output)
|
||||
}
|
||||
|
||||
fn handle_command_output(output: std::process::Output) -> Result<String> {
|
||||
if !output.status.success() {
|
||||
let stderr = String::from_utf8_lossy(&output.stderr);
|
||||
let trimmed = stderr.trim();
|
||||
if trimmed == GIT_BLAME_NO_COMMIT_ERROR || trimmed.contains(GIT_BLAME_NO_PATH) {
|
||||
return Ok(String::new());
|
||||
}
|
||||
return Err(anyhow!("git blame process failed: {}", stderr));
|
||||
}
|
||||
|
||||
Ok(String::from_utf8(output.stdout)?)
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Default, Debug, Clone, PartialEq, Eq)]
|
||||
pub struct BlameEntry {
|
||||
pub sha: Oid,
|
||||
|
||||
pub range: Range<u32>,
|
||||
|
||||
pub original_line_number: u32,
|
||||
|
||||
pub author: Option<String>,
|
||||
pub author_mail: Option<String>,
|
||||
pub author_time: Option<i64>,
|
||||
pub author_tz: Option<String>,
|
||||
|
||||
pub committer_name: Option<String>,
|
||||
pub committer_email: Option<String>,
|
||||
pub committer_time: Option<i64>,
|
||||
pub committer_tz: Option<String>,
|
||||
|
||||
pub summary: Option<String>,
|
||||
|
||||
pub previous: Option<String>,
|
||||
pub filename: String,
|
||||
}
|
||||
|
||||
impl BlameEntry {
|
||||
// Returns a BlameEntry by parsing the first line of a `git blame --incremental`
|
||||
// entry. The line MUST have this format:
|
||||
//
|
||||
// <40-byte-hex-sha1> <sourceline> <resultline> <num-lines>
|
||||
fn new_from_blame_line(line: &str) -> Result<BlameEntry> {
|
||||
let mut parts = line.split_whitespace();
|
||||
|
||||
let sha = parts
|
||||
.next()
|
||||
.and_then(|line| line.parse::<Oid>().ok())
|
||||
.ok_or_else(|| anyhow!("failed to parse sha"))?;
|
||||
|
||||
let original_line_number = parts
|
||||
.next()
|
||||
.and_then(|line| line.parse::<u32>().ok())
|
||||
.ok_or_else(|| anyhow!("Failed to parse original line number"))?;
|
||||
let final_line_number = parts
|
||||
.next()
|
||||
.and_then(|line| line.parse::<u32>().ok())
|
||||
.ok_or_else(|| anyhow!("Failed to parse final line number"))?;
|
||||
|
||||
let line_count = parts
|
||||
.next()
|
||||
.and_then(|line| line.parse::<u32>().ok())
|
||||
.ok_or_else(|| anyhow!("Failed to parse final line number"))?;
|
||||
|
||||
let start_line = final_line_number.saturating_sub(1);
|
||||
let end_line = start_line + line_count;
|
||||
let range = start_line..end_line;
|
||||
|
||||
Ok(Self {
|
||||
sha,
|
||||
range,
|
||||
original_line_number,
|
||||
..Default::default()
|
||||
})
|
||||
}
|
||||
|
||||
pub fn author_offset_date_time(&self) -> Result<time::OffsetDateTime> {
|
||||
if let (Some(author_time), Some(author_tz)) = (self.author_time, &self.author_tz) {
|
||||
let format = format_description!("[offset_hour][offset_minute]");
|
||||
let offset = UtcOffset::parse(author_tz, &format)?;
|
||||
let date_time_utc = OffsetDateTime::from_unix_timestamp(author_time)?;
|
||||
|
||||
Ok(date_time_utc.to_offset(offset))
|
||||
} else {
|
||||
// Directly return current time in UTC if there's no committer time or timezone
|
||||
Ok(time::OffsetDateTime::now_utc())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// parse_git_blame parses the output of `git blame --incremental`, which returns
|
||||
// all the blame-entries for a given path incrementally, as it finds them.
|
||||
//
|
||||
// Each entry *always* starts with:
|
||||
//
|
||||
// <40-byte-hex-sha1> <sourceline> <resultline> <num-lines>
|
||||
//
|
||||
// Each entry *always* ends with:
|
||||
//
|
||||
// filename <whitespace-quoted-filename-goes-here>
|
||||
//
|
||||
// Line numbers are 1-indexed.
|
||||
//
|
||||
// A `git blame --incremental` entry looks like this:
|
||||
//
|
||||
// 6ad46b5257ba16d12c5ca9f0d4900320959df7f4 2 2 1
|
||||
// author Joe Schmoe
|
||||
// author-mail <joe.schmoe@example.com>
|
||||
// author-time 1709741400
|
||||
// author-tz +0100
|
||||
// committer Joe Schmoe
|
||||
// committer-mail <joe.schmoe@example.com>
|
||||
// committer-time 1709741400
|
||||
// committer-tz +0100
|
||||
// summary Joe's cool commit
|
||||
// previous 486c2409237a2c627230589e567024a96751d475 index.js
|
||||
// filename index.js
|
||||
//
|
||||
// If the entry has the same SHA as an entry that was already printed then no
|
||||
// signature information is printed:
|
||||
//
|
||||
// 6ad46b5257ba16d12c5ca9f0d4900320959df7f4 3 4 1
|
||||
// previous 486c2409237a2c627230589e567024a96751d475 index.js
|
||||
// filename index.js
|
||||
//
|
||||
// More about `--incremental` output: https://mirrors.edge.kernel.org/pub/software/scm/git/docs/git-blame.html
|
||||
fn parse_git_blame(output: &str) -> Result<Vec<BlameEntry>> {
|
||||
let mut entries: Vec<BlameEntry> = Vec::new();
|
||||
let mut index: HashMap<Oid, usize> = HashMap::default();
|
||||
|
||||
let mut current_entry: Option<BlameEntry> = None;
|
||||
|
||||
for line in output.lines() {
|
||||
let mut done = false;
|
||||
|
||||
match &mut current_entry {
|
||||
None => {
|
||||
let mut new_entry = BlameEntry::new_from_blame_line(line)?;
|
||||
|
||||
if let Some(existing_entry) = index
|
||||
.get(&new_entry.sha)
|
||||
.and_then(|slot| entries.get(*slot))
|
||||
{
|
||||
new_entry.author.clone_from(&existing_entry.author);
|
||||
new_entry
|
||||
.author_mail
|
||||
.clone_from(&existing_entry.author_mail);
|
||||
new_entry.author_time = existing_entry.author_time;
|
||||
new_entry.author_tz.clone_from(&existing_entry.author_tz);
|
||||
new_entry
|
||||
.committer_name
|
||||
.clone_from(&existing_entry.committer_name);
|
||||
new_entry
|
||||
.committer_email
|
||||
.clone_from(&existing_entry.committer_email);
|
||||
new_entry.committer_time = existing_entry.committer_time;
|
||||
new_entry
|
||||
.committer_tz
|
||||
.clone_from(&existing_entry.committer_tz);
|
||||
new_entry.summary.clone_from(&existing_entry.summary);
|
||||
}
|
||||
|
||||
current_entry.replace(new_entry);
|
||||
}
|
||||
Some(entry) => {
|
||||
let Some((key, value)) = line.split_once(' ') else {
|
||||
continue;
|
||||
};
|
||||
let is_committed = !entry.sha.is_zero();
|
||||
match key {
|
||||
"filename" => {
|
||||
entry.filename = value.into();
|
||||
done = true;
|
||||
}
|
||||
"previous" => entry.previous = Some(value.into()),
|
||||
|
||||
"summary" if is_committed => entry.summary = Some(value.into()),
|
||||
"author" if is_committed => entry.author = Some(value.into()),
|
||||
"author-mail" if is_committed => entry.author_mail = Some(value.into()),
|
||||
"author-time" if is_committed => {
|
||||
entry.author_time = Some(value.parse::<i64>()?)
|
||||
}
|
||||
"author-tz" if is_committed => entry.author_tz = Some(value.into()),
|
||||
|
||||
"committer" if is_committed => entry.committer_name = Some(value.into()),
|
||||
"committer-mail" if is_committed => entry.committer_email = Some(value.into()),
|
||||
"committer-time" if is_committed => {
|
||||
entry.committer_time = Some(value.parse::<i64>()?)
|
||||
}
|
||||
"committer-tz" if is_committed => entry.committer_tz = Some(value.into()),
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
if done {
|
||||
if let Some(entry) = current_entry.take() {
|
||||
index.insert(entry.sha, entries.len());
|
||||
|
||||
// We only want annotations that have a commit.
|
||||
if !entry.sha.is_zero() {
|
||||
entries.push(entry);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(entries)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::path::PathBuf;
|
||||
|
||||
use super::BlameEntry;
|
||||
use super::parse_git_blame;
|
||||
|
||||
fn read_test_data(filename: &str) -> String {
|
||||
let mut path = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
|
||||
path.push("test_data");
|
||||
path.push(filename);
|
||||
|
||||
std::fs::read_to_string(&path)
|
||||
.unwrap_or_else(|_| panic!("Could not read test data at {:?}. Is it generated?", path))
|
||||
}
|
||||
|
||||
fn assert_eq_golden(entries: &Vec<BlameEntry>, golden_filename: &str) {
|
||||
let mut path = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
|
||||
path.push("test_data");
|
||||
path.push("golden");
|
||||
path.push(format!("{}.json", golden_filename));
|
||||
|
||||
let mut have_json =
|
||||
serde_json::to_string_pretty(&entries).expect("could not serialize entries to JSON");
|
||||
// We always want to save with a trailing newline.
|
||||
have_json.push('\n');
|
||||
|
||||
let update = std::env::var("UPDATE_GOLDEN")
|
||||
.map(|val| val.eq_ignore_ascii_case("true"))
|
||||
.unwrap_or(false);
|
||||
|
||||
if update {
|
||||
std::fs::create_dir_all(path.parent().unwrap())
|
||||
.expect("could not create golden test data directory");
|
||||
std::fs::write(&path, have_json).expect("could not write out golden data");
|
||||
} else {
|
||||
let want_json =
|
||||
std::fs::read_to_string(&path).unwrap_or_else(|_| {
|
||||
panic!("could not read golden test data file at {:?}. Did you run the test with UPDATE_GOLDEN=true before?", path);
|
||||
}).replace("\r\n", "\n");
|
||||
|
||||
pretty_assertions::assert_eq!(have_json, want_json, "wrong blame entries");
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_git_blame_not_committed() {
|
||||
let output = read_test_data("blame_incremental_not_committed");
|
||||
let entries = parse_git_blame(&output).unwrap();
|
||||
assert_eq_golden(&entries, "blame_incremental_not_committed");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_git_blame_simple() {
|
||||
let output = read_test_data("blame_incremental_simple");
|
||||
let entries = parse_git_blame(&output).unwrap();
|
||||
assert_eq_golden(&entries, "blame_incremental_simple");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_git_blame_complex() {
|
||||
let output = read_test_data("blame_incremental_complex");
|
||||
let entries = parse_git_blame(&output).unwrap();
|
||||
assert_eq_golden(&entries, "blame_incremental_complex");
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,374 @@
|
||||
use crate::commit::get_messages;
|
||||
use crate::{GitRemote, Oid};
|
||||
use anyhow::{Context as _, Result, anyhow};
|
||||
use collections::{HashMap, HashSet};
|
||||
use futures::AsyncWriteExt;
|
||||
use gpui::SharedString;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::process::Stdio;
|
||||
use std::{ops::Range, path::Path};
|
||||
use text::Rope;
|
||||
use time::OffsetDateTime;
|
||||
use time::UtcOffset;
|
||||
use time::macros::format_description;
|
||||
|
||||
pub use git2 as libgit;
|
||||
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct Blame {
|
||||
pub entries: Vec<BlameEntry>,
|
||||
pub messages: HashMap<Oid, String>,
|
||||
pub remote_url: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Default)]
|
||||
pub struct ParsedCommitMessage {
|
||||
pub message: SharedString,
|
||||
pub permalink: Option<url::Url>,
|
||||
pub pull_request: Option<crate::hosting_provider::PullRequest>,
|
||||
pub remote: Option<GitRemote>,
|
||||
}
|
||||
|
||||
impl Blame {
|
||||
pub async fn for_path(
|
||||
git_binary: &Path,
|
||||
working_directory: &Path,
|
||||
path: &Path,
|
||||
content: &Rope,
|
||||
remote_url: Option<String>,
|
||||
) -> Result<Self> {
|
||||
let output = run_git_blame(git_binary, working_directory, path, content).await?;
|
||||
let mut entries = parse_git_blame(&output)?;
|
||||
entries.sort_unstable_by(|a, b| a.range.start.cmp(&b.range.start));
|
||||
|
||||
let mut unique_shas = HashSet::default();
|
||||
|
||||
for entry in entries.iter_mut() {
|
||||
unique_shas.insert(entry.sha);
|
||||
}
|
||||
|
||||
let shas = unique_shas.into_iter().collect::<Vec<_>>();
|
||||
let messages = get_messages(working_directory, &shas)
|
||||
.await
|
||||
.context("failed to get commit messages")?;
|
||||
|
||||
Ok(Self {
|
||||
entries,
|
||||
messages,
|
||||
remote_url,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
const GIT_BLAME_NO_COMMIT_ERROR: &str = "fatal: no such ref: HEAD";
|
||||
const GIT_BLAME_NO_PATH: &str = "fatal: no such path";
|
||||
|
||||
async fn run_git_blame(
|
||||
git_binary: &Path,
|
||||
working_directory: &Path,
|
||||
path: &Path,
|
||||
contents: &Rope,
|
||||
) -> Result<String> {
|
||||
let mut child = util::command::new_smol_command(git_binary)
|
||||
.current_dir(working_directory)
|
||||
.arg("blame")
|
||||
.arg("--incremental")
|
||||
.arg("--contents")
|
||||
.arg("-")
|
||||
.arg(path.as_os_str())
|
||||
.stdin(Stdio::piped())
|
||||
.stdout(Stdio::piped())
|
||||
.stderr(Stdio::piped())
|
||||
.spawn()
|
||||
.map_err(|e| anyhow!("Failed to start git blame process: {}", e))?;
|
||||
|
||||
let stdin = child
|
||||
.stdin
|
||||
.as_mut()
|
||||
.context("failed to get pipe to stdin of git blame command")?;
|
||||
|
||||
for chunk in contents.chunks() {
|
||||
stdin.write_all(chunk.as_bytes()).await?;
|
||||
}
|
||||
stdin.flush().await?;
|
||||
|
||||
let output = child
|
||||
.output()
|
||||
.await
|
||||
.map_err(|e| anyhow!("Failed to read git blame output: {}", e))?;
|
||||
|
||||
if !output.status.success() {
|
||||
let stderr = String::from_utf8_lossy(&output.stderr);
|
||||
let trimmed = stderr.trim();
|
||||
if trimmed == GIT_BLAME_NO_COMMIT_ERROR || trimmed.contains(GIT_BLAME_NO_PATH) {
|
||||
return Ok(String::new());
|
||||
}
|
||||
return Err(anyhow!("git blame process failed: {}", stderr));
|
||||
}
|
||||
|
||||
Ok(String::from_utf8(output.stdout)?)
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Default, Debug, Clone, PartialEq, Eq)]
|
||||
pub struct BlameEntry {
|
||||
pub sha: Oid,
|
||||
|
||||
pub range: Range<u32>,
|
||||
|
||||
pub original_line_number: u32,
|
||||
|
||||
pub author: Option<String>,
|
||||
pub author_mail: Option<String>,
|
||||
pub author_time: Option<i64>,
|
||||
pub author_tz: Option<String>,
|
||||
|
||||
pub committer_name: Option<String>,
|
||||
pub committer_email: Option<String>,
|
||||
pub committer_time: Option<i64>,
|
||||
pub committer_tz: Option<String>,
|
||||
|
||||
pub summary: Option<String>,
|
||||
|
||||
pub previous: Option<String>,
|
||||
pub filename: String,
|
||||
}
|
||||
|
||||
impl BlameEntry {
|
||||
// Returns a BlameEntry by parsing the first line of a `git blame --incremental`
|
||||
// entry. The line MUST have this format:
|
||||
//
|
||||
// <40-byte-hex-sha1> <sourceline> <resultline> <num-lines>
|
||||
fn new_from_blame_line(line: &str) -> Result<BlameEntry> {
|
||||
let mut parts = line.split_whitespace();
|
||||
|
||||
let sha = parts
|
||||
.next()
|
||||
.and_then(|line| line.parse::<Oid>().ok())
|
||||
.ok_or_else(|| anyhow!("failed to parse sha"))?;
|
||||
|
||||
let original_line_number = parts
|
||||
.next()
|
||||
.and_then(|line| line.parse::<u32>().ok())
|
||||
.ok_or_else(|| anyhow!("Failed to parse original line number"))?;
|
||||
let final_line_number = parts
|
||||
.next()
|
||||
.and_then(|line| line.parse::<u32>().ok())
|
||||
.ok_or_else(|| anyhow!("Failed to parse final line number"))?;
|
||||
|
||||
let line_count = parts
|
||||
.next()
|
||||
.and_then(|line| line.parse::<u32>().ok())
|
||||
.ok_or_else(|| anyhow!("Failed to parse final line number"))?;
|
||||
|
||||
let start_line = final_line_number.saturating_sub(1);
|
||||
let end_line = start_line + line_count;
|
||||
let range = start_line..end_line;
|
||||
|
||||
Ok(Self {
|
||||
sha,
|
||||
range,
|
||||
original_line_number,
|
||||
..Default::default()
|
||||
})
|
||||
}
|
||||
|
||||
pub fn author_offset_date_time(&self) -> Result<time::OffsetDateTime> {
|
||||
if let (Some(author_time), Some(author_tz)) = (self.author_time, &self.author_tz) {
|
||||
let format = format_description!("[offset_hour][offset_minute]");
|
||||
let offset = UtcOffset::parse(author_tz, &format)?;
|
||||
let date_time_utc = OffsetDateTime::from_unix_timestamp(author_time)?;
|
||||
|
||||
Ok(date_time_utc.to_offset(offset))
|
||||
} else {
|
||||
// Directly return current time in UTC if there's no committer time or timezone
|
||||
Ok(time::OffsetDateTime::now_utc())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// parse_git_blame parses the output of `git blame --incremental`, which returns
|
||||
// all the blame-entries for a given path incrementally, as it finds them.
|
||||
//
|
||||
// Each entry *always* starts with:
|
||||
//
|
||||
// <40-byte-hex-sha1> <sourceline> <resultline> <num-lines>
|
||||
//
|
||||
// Each entry *always* ends with:
|
||||
//
|
||||
// filename <whitespace-quoted-filename-goes-here>
|
||||
//
|
||||
// Line numbers are 1-indexed.
|
||||
//
|
||||
// A `git blame --incremental` entry looks like this:
|
||||
//
|
||||
// 6ad46b5257ba16d12c5ca9f0d4900320959df7f4 2 2 1
|
||||
// author Joe Schmoe
|
||||
// author-mail <joe.schmoe@example.com>
|
||||
// author-time 1709741400
|
||||
// author-tz +0100
|
||||
// committer Joe Schmoe
|
||||
// committer-mail <joe.schmoe@example.com>
|
||||
// committer-time 1709741400
|
||||
// committer-tz +0100
|
||||
// summary Joe's cool commit
|
||||
// previous 486c2409237a2c627230589e567024a96751d475 index.js
|
||||
// filename index.js
|
||||
//
|
||||
// If the entry has the same SHA as an entry that was already printed then no
|
||||
// signature information is printed:
|
||||
//
|
||||
// 6ad46b5257ba16d12c5ca9f0d4900320959df7f4 3 4 1
|
||||
// previous 486c2409237a2c627230589e567024a96751d475 index.js
|
||||
// filename index.js
|
||||
//
|
||||
// More about `--incremental` output: https://mirrors.edge.kernel.org/pub/software/scm/git/docs/git-blame.html
|
||||
fn parse_git_blame(output: &str) -> Result<Vec<BlameEntry>> {
|
||||
let mut entries: Vec<BlameEntry> = Vec::new();
|
||||
let mut index: HashMap<Oid, usize> = HashMap::default();
|
||||
|
||||
let mut current_entry: Option<BlameEntry> = None;
|
||||
|
||||
for line in output.lines() {
|
||||
let mut done = false;
|
||||
|
||||
match &mut current_entry {
|
||||
None => {
|
||||
let mut new_entry = BlameEntry::new_from_blame_line(line)?;
|
||||
|
||||
if let Some(existing_entry) = index
|
||||
.get(&new_entry.sha)
|
||||
.and_then(|slot| entries.get(*slot))
|
||||
{
|
||||
new_entry.author.clone_from(&existing_entry.author);
|
||||
new_entry
|
||||
.author_mail
|
||||
.clone_from(&existing_entry.author_mail);
|
||||
new_entry.author_time = existing_entry.author_time;
|
||||
new_entry.author_tz.clone_from(&existing_entry.author_tz);
|
||||
new_entry
|
||||
.committer_name
|
||||
.clone_from(&existing_entry.committer_name);
|
||||
new_entry
|
||||
.committer_email
|
||||
.clone_from(&existing_entry.committer_email);
|
||||
new_entry.committer_time = existing_entry.committer_time;
|
||||
new_entry
|
||||
.committer_tz
|
||||
.clone_from(&existing_entry.committer_tz);
|
||||
new_entry.summary.clone_from(&existing_entry.summary);
|
||||
}
|
||||
|
||||
current_entry.replace(new_entry);
|
||||
}
|
||||
Some(entry) => {
|
||||
let Some((key, value)) = line.split_once(' ') else {
|
||||
continue;
|
||||
};
|
||||
let is_committed = !entry.sha.is_zero();
|
||||
match key {
|
||||
"filename" => {
|
||||
entry.filename = value.into();
|
||||
done = true;
|
||||
}
|
||||
"previous" => entry.previous = Some(value.into()),
|
||||
|
||||
"summary" if is_committed => entry.summary = Some(value.into()),
|
||||
"author" if is_committed => entry.author = Some(value.into()),
|
||||
"author-mail" if is_committed => entry.author_mail = Some(value.into()),
|
||||
"author-time" if is_committed => {
|
||||
entry.author_time = Some(value.parse::<i64>()?)
|
||||
}
|
||||
"author-tz" if is_committed => entry.author_tz = Some(value.into()),
|
||||
|
||||
"committer" if is_committed => entry.committer_name = Some(value.into()),
|
||||
"committer-mail" if is_committed => entry.committer_email = Some(value.into()),
|
||||
"committer-time" if is_committed => {
|
||||
entry.committer_time = Some(value.parse::<i64>()?)
|
||||
}
|
||||
"committer-tz" if is_committed => entry.committer_tz = Some(value.into()),
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
if done {
|
||||
if let Some(entry) = current_entry.take() {
|
||||
index.insert(entry.sha, entries.len());
|
||||
|
||||
// We only want annotations that have a commit.
|
||||
if !entry.sha.is_zero() {
|
||||
entries.push(entry);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(entries)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::path::PathBuf;
|
||||
|
||||
use super::BlameEntry;
|
||||
use super::parse_git_blame;
|
||||
|
||||
fn read_test_data(filename: &str) -> String {
|
||||
let mut path = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
|
||||
path.push("test_data");
|
||||
path.push(filename);
|
||||
|
||||
std::fs::read_to_string(&path)
|
||||
.unwrap_or_else(|_| panic!("Could not read test data at {:?}. Is it generated?", path))
|
||||
}
|
||||
|
||||
fn assert_eq_golden(entries: &Vec<BlameEntry>, golden_filename: &str) {
|
||||
let mut path = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
|
||||
path.push("test_data");
|
||||
path.push("golden");
|
||||
path.push(format!("{}.json", golden_filename));
|
||||
|
||||
let mut have_json =
|
||||
serde_json::to_string_pretty(&entries).expect("could not serialize entries to JSON");
|
||||
// We always want to save with a trailing newline.
|
||||
have_json.push('\n');
|
||||
|
||||
let update = std::env::var("UPDATE_GOLDEN")
|
||||
.map(|val| val.eq_ignore_ascii_case("true"))
|
||||
.unwrap_or(false);
|
||||
|
||||
if update {
|
||||
std::fs::create_dir_all(path.parent().unwrap())
|
||||
.expect("could not create golden test data directory");
|
||||
std::fs::write(&path, have_json).expect("could not write out golden data");
|
||||
} else {
|
||||
let want_json =
|
||||
std::fs::read_to_string(&path).unwrap_or_else(|_| {
|
||||
panic!("could not read golden test data file at {:?}. Did you run the test with UPDATE_GOLDEN=true before?", path);
|
||||
}).replace("\r\n", "\n");
|
||||
|
||||
pretty_assertions::assert_eq!(have_json, want_json, "wrong blame entries");
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_git_blame_not_committed() {
|
||||
let output = read_test_data("blame_incremental_not_committed");
|
||||
let entries = parse_git_blame(&output).unwrap();
|
||||
assert_eq_golden(&entries, "blame_incremental_not_committed");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_git_blame_simple() {
|
||||
let output = read_test_data("blame_incremental_simple");
|
||||
let entries = parse_git_blame(&output).unwrap();
|
||||
assert_eq_golden(&entries, "blame_incremental_simple");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_git_blame_complex() {
|
||||
let output = read_test_data("blame_incremental_complex");
|
||||
let entries = parse_git_blame(&output).unwrap();
|
||||
assert_eq_golden(&entries, "blame_incremental_complex");
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,339 @@
|
||||
// font-kit/src/canvas.rs
|
||||
//
|
||||
// Copyright © 2018 The Pathfinder Project Developers.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
|
||||
// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
|
||||
// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
|
||||
// option. This file may not be copied, modified, or distributed
|
||||
// except according to those terms.
|
||||
|
||||
//! An in-memory bitmap surface for glyph rasterization.
|
||||
|
||||
use lazy_static::lazy_static;
|
||||
use pathfinder_geometry::rect::RectI;
|
||||
use pathfinder_geometry::vector::Vector2I;
|
||||
use std::cmp;
|
||||
use std::fmt;
|
||||
|
||||
use crate::utils;
|
||||
|
||||
lazy_static! {
|
||||
static ref BITMAP_1BPP_TO_8BPP_LUT: [[u8; 8]; 256] = {
|
||||
let mut lut = [[0; 8]; 256];
|
||||
for byte in 0..0x100 {
|
||||
let mut value = [0; 8];
|
||||
for bit in 0..8 {
|
||||
if (byte & (0x80 >> bit)) != 0 {
|
||||
value[bit] = 0xff;
|
||||
}
|
||||
}
|
||||
lut[byte] = value
|
||||
}
|
||||
lut
|
||||
};
|
||||
}
|
||||
|
||||
/// An in-memory bitmap surface for glyph rasterization.
|
||||
pub struct Canvas {
|
||||
/// The raw pixel data.
|
||||
pub pixels: Vec<u8>,
|
||||
/// The size of the buffer, in pixels.
|
||||
pub size: Vector2I,
|
||||
/// The number of *bytes* between successive rows.
|
||||
pub stride: usize,
|
||||
/// The image format of the canvas.
|
||||
pub format: Format,
|
||||
}
|
||||
|
||||
impl Canvas {
|
||||
/// Creates a new blank canvas with the given pixel size and format.
|
||||
///
|
||||
/// Stride is automatically calculated from width.
|
||||
///
|
||||
/// The canvas is initialized with transparent black (all values 0).
|
||||
#[inline]
|
||||
pub fn new(size: Vector2I, format: Format) -> Canvas {
|
||||
Canvas::with_stride(
|
||||
size,
|
||||
size.x() as usize * format.bytes_per_pixel() as usize,
|
||||
format,
|
||||
)
|
||||
}
|
||||
|
||||
/// Creates a new blank canvas with the given pixel size, stride (number of bytes between
|
||||
/// successive rows), and format.
|
||||
///
|
||||
/// The canvas is initialized with transparent black (all values 0).
|
||||
pub fn with_stride(size: Vector2I, stride: usize, format: Format) -> Canvas {
|
||||
Canvas {
|
||||
pixels: vec![0; stride * size.y() as usize],
|
||||
size,
|
||||
stride,
|
||||
format,
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub(crate) fn blit_from_canvas(&mut self, src: &Canvas) {
|
||||
self.blit_from(
|
||||
Vector2I::default(),
|
||||
&src.pixels,
|
||||
src.size,
|
||||
src.stride,
|
||||
src.format,
|
||||
)
|
||||
}
|
||||
|
||||
/// Blits to a rectangle with origin at `dst_point` and size according to `src_size`.
|
||||
/// If the target area overlaps the boundaries of the canvas, only the drawable region is blitted.
|
||||
/// `dst_point` and `src_size` are specified in pixels. `src_stride` is specified in bytes.
|
||||
/// `src_stride` must be equal or larger than the actual data length.
|
||||
#[allow(dead_code)]
|
||||
pub(crate) fn blit_from(
|
||||
&mut self,
|
||||
dst_point: Vector2I,
|
||||
src_bytes: &[u8],
|
||||
src_size: Vector2I,
|
||||
src_stride: usize,
|
||||
src_format: Format,
|
||||
) {
|
||||
assert_eq!(
|
||||
src_stride * src_size.y() as usize,
|
||||
src_bytes.len(),
|
||||
"Number of pixels in src_bytes does not match stride and size."
|
||||
);
|
||||
assert!(
|
||||
src_stride >= src_size.x() as usize * src_format.bytes_per_pixel() as usize,
|
||||
"src_stride must be >= than src_size.x()"
|
||||
);
|
||||
|
||||
let dst_rect = RectI::new(dst_point, src_size);
|
||||
let dst_rect = dst_rect.intersection(RectI::new(Vector2I::default(), self.size));
|
||||
let dst_rect = match dst_rect {
|
||||
Some(dst_rect) => dst_rect,
|
||||
None => return,
|
||||
};
|
||||
|
||||
match (self.format, src_format) {
|
||||
(Format::A8, Format::A8)
|
||||
| (Format::Rgb24, Format::Rgb24)
|
||||
| (Format::Rgba32, Format::Rgba32) => {
|
||||
self.blit_from_with::<BlitMemcpy>(dst_rect, src_bytes, src_stride, src_format)
|
||||
}
|
||||
(Format::A8, Format::Rgb24) => {
|
||||
self.blit_from_with::<BlitRgb24ToA8>(dst_rect, src_bytes, src_stride, src_format)
|
||||
}
|
||||
(Format::Rgb24, Format::A8) => {
|
||||
self.blit_from_with::<BlitA8ToRgb24>(dst_rect, src_bytes, src_stride, src_format)
|
||||
}
|
||||
(Format::Rgb24, Format::Rgba32) => self
|
||||
.blit_from_with::<BlitRgba32ToRgb24>(dst_rect, src_bytes, src_stride, src_format),
|
||||
(Format::Rgba32, Format::Rgb24) => self
|
||||
.blit_from_with::<BlitRgb24ToRgba32>(dst_rect, src_bytes, src_stride, src_format),
|
||||
(Format::Rgba32, Format::A8) | (Format::A8, Format::Rgba32) => unimplemented!(),
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub(crate) fn blit_from_bitmap_1bpp(
|
||||
&mut self,
|
||||
dst_point: Vector2I,
|
||||
src_bytes: &[u8],
|
||||
src_size: Vector2I,
|
||||
src_stride: usize,
|
||||
) {
|
||||
if self.format != Format::A8 {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
let dst_rect = RectI::new(dst_point, src_size);
|
||||
let dst_rect = dst_rect.intersection(RectI::new(Vector2I::default(), self.size));
|
||||
let dst_rect = match dst_rect {
|
||||
Some(dst_rect) => dst_rect,
|
||||
None => return,
|
||||
};
|
||||
|
||||
let size = dst_rect.size();
|
||||
|
||||
let dest_bytes_per_pixel = self.format.bytes_per_pixel() as usize;
|
||||
let dest_row_stride = size.x() as usize * dest_bytes_per_pixel;
|
||||
let src_row_stride = utils::div_round_up(size.x() as usize, 8);
|
||||
|
||||
for y in 0..size.y() {
|
||||
let (dest_row_start, src_row_start) = (
|
||||
(y + dst_rect.origin_y()) as usize * self.stride
|
||||
+ dst_rect.origin_x() as usize * dest_bytes_per_pixel,
|
||||
y as usize * src_stride,
|
||||
);
|
||||
let dest_row_end = dest_row_start + dest_row_stride;
|
||||
let src_row_end = src_row_start + src_row_stride;
|
||||
let dest_row_pixels = &mut self.pixels[dest_row_start..dest_row_end];
|
||||
let src_row_pixels = &src_bytes[src_row_start..src_row_end];
|
||||
for x in 0..src_row_stride {
|
||||
let pattern = &BITMAP_1BPP_TO_8BPP_LUT[src_row_pixels[x] as usize];
|
||||
let dest_start = x * 8;
|
||||
let dest_end = cmp::min(dest_start + 8, dest_row_stride);
|
||||
let src = &pattern[0..(dest_end - dest_start)];
|
||||
dest_row_pixels[dest_start..dest_end].clone_from_slice(src);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Blits to area `rect` using the data given in the buffer `src_bytes`.
|
||||
/// `src_stride` must be specified in bytes.
|
||||
/// The dimensions of `rect` must be in pixels.
|
||||
fn blit_from_with<B: Blit>(
|
||||
&mut self,
|
||||
rect: RectI,
|
||||
src_bytes: &[u8],
|
||||
src_stride: usize,
|
||||
src_format: Format,
|
||||
) {
|
||||
let src_bytes_per_pixel = src_format.bytes_per_pixel() as usize;
|
||||
let dest_bytes_per_pixel = self.format.bytes_per_pixel() as usize;
|
||||
|
||||
for y in 0..rect.height() {
|
||||
let (dest_row_start, src_row_start) = (
|
||||
(y + rect.origin_y()) as usize * self.stride
|
||||
+ rect.origin_x() as usize * dest_bytes_per_pixel,
|
||||
y as usize * src_stride,
|
||||
);
|
||||
let dest_row_end = dest_row_start + rect.width() as usize * dest_bytes_per_pixel;
|
||||
let src_row_end = src_row_start + rect.width() as usize * src_bytes_per_pixel;
|
||||
let dest_row_pixels = &mut self.pixels[dest_row_start..dest_row_end];
|
||||
let src_row_pixels = &src_bytes[src_row_start..src_row_end];
|
||||
B::blit(dest_row_pixels, src_row_pixels)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Debug for Canvas {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
f.debug_struct("Canvas")
|
||||
.field("pixels", &self.pixels.len()) // Do not dump a vector content.
|
||||
.field("size", &self.size)
|
||||
.field("stride", &self.stride)
|
||||
.field("format", &self.format)
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
/// The image format for the canvas.
|
||||
#[derive(Clone, Copy, Debug, PartialEq)]
|
||||
pub enum Format {
|
||||
/// Premultiplied R8G8B8A8, little-endian.
|
||||
Rgba32,
|
||||
/// R8G8B8, little-endian.
|
||||
Rgb24,
|
||||
/// A8.
|
||||
A8,
|
||||
}
|
||||
|
||||
impl Format {
|
||||
/// Returns the number of bits per pixel that this image format corresponds to.
|
||||
#[inline]
|
||||
pub fn bits_per_pixel(self) -> u8 {
|
||||
match self {
|
||||
Format::Rgba32 => 32,
|
||||
Format::Rgb24 => 24,
|
||||
Format::A8 => 8,
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the number of color channels per pixel that this image format corresponds to.
|
||||
#[inline]
|
||||
pub fn components_per_pixel(self) -> u8 {
|
||||
match self {
|
||||
Format::Rgba32 => 4,
|
||||
Format::Rgb24 => 3,
|
||||
Format::A8 => 1,
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the number of bits per color channel that this image format contains.
|
||||
#[inline]
|
||||
pub fn bits_per_component(self) -> u8 {
|
||||
self.bits_per_pixel() / self.components_per_pixel()
|
||||
}
|
||||
|
||||
/// Returns the number of bytes per pixel that this image format corresponds to.
|
||||
#[inline]
|
||||
pub fn bytes_per_pixel(self) -> u8 {
|
||||
self.bits_per_pixel() / 8
|
||||
}
|
||||
}
|
||||
|
||||
/// The antialiasing strategy that should be used when rasterizing glyphs.
|
||||
#[derive(Clone, Copy, Debug, PartialEq)]
|
||||
pub enum RasterizationOptions {
|
||||
/// "Black-and-white" rendering. Each pixel is either entirely on or off.
|
||||
Bilevel,
|
||||
/// Grayscale antialiasing. Only one channel is used.
|
||||
GrayscaleAa,
|
||||
/// Subpixel RGB antialiasing, for LCD screens.
|
||||
SubpixelAa,
|
||||
}
|
||||
|
||||
trait Blit {
|
||||
fn blit(dest: &mut [u8], src: &[u8]);
|
||||
}
|
||||
|
||||
struct BlitMemcpy;
|
||||
|
||||
impl Blit for BlitMemcpy {
|
||||
#[inline]
|
||||
fn blit(dest: &mut [u8], src: &[u8]) {
|
||||
dest.clone_from_slice(src)
|
||||
}
|
||||
}
|
||||
|
||||
struct BlitRgb24ToA8;
|
||||
|
||||
impl Blit for BlitRgb24ToA8 {
|
||||
#[inline]
|
||||
fn blit(dest: &mut [u8], src: &[u8]) {
|
||||
// TODO(pcwalton): SIMD.
|
||||
for (dest, src) in dest.iter_mut().zip(src.chunks(3)) {
|
||||
*dest = src[1]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct BlitA8ToRgb24;
|
||||
|
||||
impl Blit for BlitA8ToRgb24 {
|
||||
#[inline]
|
||||
fn blit(dest: &mut [u8], src: &[u8]) {
|
||||
for (dest, src) in dest.chunks_mut(3).zip(src.iter()) {
|
||||
dest[0] = *src;
|
||||
dest[1] = *src;
|
||||
dest[2] = *src;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct BlitRgba32ToRgb24;
|
||||
|
||||
impl Blit for BlitRgba32ToRgb24 {
|
||||
#[inline]
|
||||
fn blit(dest: &mut [u8], src: &[u8]) {
|
||||
// TODO(pcwalton): SIMD.
|
||||
for (dest, src) in dest.chunks_mut(3).zip(src.chunks(4)) {
|
||||
dest.copy_from_slice(&src[0..3])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct BlitRgb24ToRgba32;
|
||||
|
||||
impl Blit for BlitRgb24ToRgba32 {
|
||||
fn blit(dest: &mut [u8], src: &[u8]) {
|
||||
for (dest, src) in dest.chunks_mut(4).zip(src.chunks(3)) {
|
||||
dest[0] = src[0];
|
||||
dest[1] = src[1];
|
||||
dest[2] = src[2];
|
||||
dest[3] = 255;
|
||||
}
|
||||
}
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
2193
crates/assistant_tools/src/edit_agent/evals/fixtures/zode/prompt.md
Normal file
2193
crates/assistant_tools/src/edit_agent/evals/fixtures/zode/prompt.md
Normal file
File diff suppressed because one or more lines are too long
@@ -0,0 +1,14 @@
|
||||
class InputCell:
|
||||
def __init__(self, initial_value):
|
||||
self.value = None
|
||||
|
||||
|
||||
class ComputeCell:
|
||||
def __init__(self, inputs, compute_function):
|
||||
self.value = None
|
||||
|
||||
def add_callback(self, callback):
|
||||
pass
|
||||
|
||||
def remove_callback(self, callback):
|
||||
pass
|
||||
@@ -0,0 +1,271 @@
|
||||
# These tests are auto-generated with test data from:
|
||||
# https://github.com/exercism/problem-specifications/tree/main/exercises/react/canonical-data.json
|
||||
# File last updated on 2023-07-19
|
||||
|
||||
from functools import partial
|
||||
import unittest
|
||||
|
||||
from react import (
|
||||
InputCell,
|
||||
ComputeCell,
|
||||
)
|
||||
|
||||
|
||||
class ReactTest(unittest.TestCase):
|
||||
def test_input_cells_have_a_value(self):
|
||||
input = InputCell(10)
|
||||
self.assertEqual(input.value, 10)
|
||||
|
||||
def test_an_input_cell_s_value_can_be_set(self):
|
||||
input = InputCell(4)
|
||||
input.value = 20
|
||||
self.assertEqual(input.value, 20)
|
||||
|
||||
def test_compute_cells_calculate_initial_value(self):
|
||||
input = InputCell(1)
|
||||
output = ComputeCell(
|
||||
[
|
||||
input,
|
||||
],
|
||||
lambda inputs: inputs[0] + 1,
|
||||
)
|
||||
self.assertEqual(output.value, 2)
|
||||
|
||||
def test_compute_cells_take_inputs_in_the_right_order(self):
|
||||
one = InputCell(1)
|
||||
two = InputCell(2)
|
||||
output = ComputeCell(
|
||||
[
|
||||
one,
|
||||
two,
|
||||
],
|
||||
lambda inputs: inputs[0] + inputs[1] * 10,
|
||||
)
|
||||
self.assertEqual(output.value, 21)
|
||||
|
||||
def test_compute_cells_update_value_when_dependencies_are_changed(self):
|
||||
input = InputCell(1)
|
||||
output = ComputeCell(
|
||||
[
|
||||
input,
|
||||
],
|
||||
lambda inputs: inputs[0] + 1,
|
||||
)
|
||||
input.value = 3
|
||||
self.assertEqual(output.value, 4)
|
||||
|
||||
def test_compute_cells_can_depend_on_other_compute_cells(self):
|
||||
input = InputCell(1)
|
||||
times_two = ComputeCell(
|
||||
[
|
||||
input,
|
||||
],
|
||||
lambda inputs: inputs[0] * 2,
|
||||
)
|
||||
times_thirty = ComputeCell(
|
||||
[
|
||||
input,
|
||||
],
|
||||
lambda inputs: inputs[0] * 30,
|
||||
)
|
||||
output = ComputeCell(
|
||||
[
|
||||
times_two,
|
||||
times_thirty,
|
||||
],
|
||||
lambda inputs: inputs[0] + inputs[1],
|
||||
)
|
||||
self.assertEqual(output.value, 32)
|
||||
input.value = 3
|
||||
self.assertEqual(output.value, 96)
|
||||
|
||||
def test_compute_cells_fire_callbacks(self):
|
||||
input = InputCell(1)
|
||||
output = ComputeCell(
|
||||
[
|
||||
input,
|
||||
],
|
||||
lambda inputs: inputs[0] + 1,
|
||||
)
|
||||
cb1_observer = []
|
||||
callback1 = self.callback_factory(cb1_observer)
|
||||
output.add_callback(callback1)
|
||||
input.value = 3
|
||||
self.assertEqual(cb1_observer[-1], 4)
|
||||
|
||||
def test_callback_cells_only_fire_on_change(self):
|
||||
input = InputCell(1)
|
||||
output = ComputeCell([input], lambda inputs: 111 if inputs[0] < 3 else 222)
|
||||
cb1_observer = []
|
||||
callback1 = self.callback_factory(cb1_observer)
|
||||
output.add_callback(callback1)
|
||||
input.value = 2
|
||||
self.assertEqual(cb1_observer, [])
|
||||
input.value = 4
|
||||
self.assertEqual(cb1_observer[-1], 222)
|
||||
|
||||
def test_callbacks_do_not_report_already_reported_values(self):
|
||||
input = InputCell(1)
|
||||
output = ComputeCell(
|
||||
[
|
||||
input,
|
||||
],
|
||||
lambda inputs: inputs[0] + 1,
|
||||
)
|
||||
cb1_observer = []
|
||||
callback1 = self.callback_factory(cb1_observer)
|
||||
output.add_callback(callback1)
|
||||
input.value = 2
|
||||
self.assertEqual(cb1_observer[-1], 3)
|
||||
input.value = 3
|
||||
self.assertEqual(cb1_observer[-1], 4)
|
||||
|
||||
def test_callbacks_can_fire_from_multiple_cells(self):
|
||||
input = InputCell(1)
|
||||
plus_one = ComputeCell(
|
||||
[
|
||||
input,
|
||||
],
|
||||
lambda inputs: inputs[0] + 1,
|
||||
)
|
||||
minus_one = ComputeCell(
|
||||
[
|
||||
input,
|
||||
],
|
||||
lambda inputs: inputs[0] - 1,
|
||||
)
|
||||
cb1_observer = []
|
||||
cb2_observer = []
|
||||
callback1 = self.callback_factory(cb1_observer)
|
||||
callback2 = self.callback_factory(cb2_observer)
|
||||
plus_one.add_callback(callback1)
|
||||
minus_one.add_callback(callback2)
|
||||
input.value = 10
|
||||
self.assertEqual(cb1_observer[-1], 11)
|
||||
self.assertEqual(cb2_observer[-1], 9)
|
||||
|
||||
def test_callbacks_can_be_added_and_removed(self):
|
||||
input = InputCell(11)
|
||||
output = ComputeCell(
|
||||
[
|
||||
input,
|
||||
],
|
||||
lambda inputs: inputs[0] + 1,
|
||||
)
|
||||
cb1_observer = []
|
||||
cb2_observer = []
|
||||
cb3_observer = []
|
||||
callback1 = self.callback_factory(cb1_observer)
|
||||
callback2 = self.callback_factory(cb2_observer)
|
||||
callback3 = self.callback_factory(cb3_observer)
|
||||
output.add_callback(callback1)
|
||||
output.add_callback(callback2)
|
||||
input.value = 31
|
||||
self.assertEqual(cb1_observer[-1], 32)
|
||||
self.assertEqual(cb2_observer[-1], 32)
|
||||
output.remove_callback(callback1)
|
||||
output.add_callback(callback3)
|
||||
input.value = 41
|
||||
self.assertEqual(len(cb1_observer), 1)
|
||||
self.assertEqual(cb2_observer[-1], 42)
|
||||
self.assertEqual(cb3_observer[-1], 42)
|
||||
|
||||
def test_removing_a_callback_multiple_times_doesn_t_interfere_with_other_callbacks(
|
||||
self,
|
||||
):
|
||||
input = InputCell(1)
|
||||
output = ComputeCell(
|
||||
[
|
||||
input,
|
||||
],
|
||||
lambda inputs: inputs[0] + 1,
|
||||
)
|
||||
cb1_observer = []
|
||||
cb2_observer = []
|
||||
callback1 = self.callback_factory(cb1_observer)
|
||||
callback2 = self.callback_factory(cb2_observer)
|
||||
output.add_callback(callback1)
|
||||
output.add_callback(callback2)
|
||||
output.remove_callback(callback1)
|
||||
output.remove_callback(callback1)
|
||||
output.remove_callback(callback1)
|
||||
input.value = 2
|
||||
self.assertEqual(cb1_observer, [])
|
||||
self.assertEqual(cb2_observer[-1], 3)
|
||||
|
||||
def test_callbacks_should_only_be_called_once_even_if_multiple_dependencies_change(
|
||||
self,
|
||||
):
|
||||
input = InputCell(1)
|
||||
plus_one = ComputeCell(
|
||||
[
|
||||
input,
|
||||
],
|
||||
lambda inputs: inputs[0] + 1,
|
||||
)
|
||||
minus_one1 = ComputeCell(
|
||||
[
|
||||
input,
|
||||
],
|
||||
lambda inputs: inputs[0] - 1,
|
||||
)
|
||||
minus_one2 = ComputeCell(
|
||||
[
|
||||
minus_one1,
|
||||
],
|
||||
lambda inputs: inputs[0] - 1,
|
||||
)
|
||||
output = ComputeCell(
|
||||
[
|
||||
plus_one,
|
||||
minus_one2,
|
||||
],
|
||||
lambda inputs: inputs[0] * inputs[1],
|
||||
)
|
||||
cb1_observer = []
|
||||
callback1 = self.callback_factory(cb1_observer)
|
||||
output.add_callback(callback1)
|
||||
input.value = 4
|
||||
self.assertEqual(cb1_observer[-1], 10)
|
||||
|
||||
def test_callbacks_should_not_be_called_if_dependencies_change_but_output_value_doesn_t_change(
|
||||
self,
|
||||
):
|
||||
input = InputCell(1)
|
||||
plus_one = ComputeCell(
|
||||
[
|
||||
input,
|
||||
],
|
||||
lambda inputs: inputs[0] + 1,
|
||||
)
|
||||
minus_one = ComputeCell(
|
||||
[
|
||||
input,
|
||||
],
|
||||
lambda inputs: inputs[0] - 1,
|
||||
)
|
||||
always_two = ComputeCell(
|
||||
[
|
||||
plus_one,
|
||||
minus_one,
|
||||
],
|
||||
lambda inputs: inputs[0] - inputs[1],
|
||||
)
|
||||
cb1_observer = []
|
||||
callback1 = self.callback_factory(cb1_observer)
|
||||
always_two.add_callback(callback1)
|
||||
input.value = 2
|
||||
self.assertEqual(cb1_observer, [])
|
||||
input.value = 3
|
||||
self.assertEqual(cb1_observer, [])
|
||||
input.value = 4
|
||||
self.assertEqual(cb1_observer, [])
|
||||
input.value = 5
|
||||
self.assertEqual(cb1_observer, [])
|
||||
|
||||
# Utility functions.
|
||||
def callback_factory(self, observer):
|
||||
def callback(observer, value):
|
||||
observer.append(value)
|
||||
|
||||
return partial(callback, observer)
|
||||
@@ -282,7 +282,7 @@ pub struct EditFileToolCard {
|
||||
}
|
||||
|
||||
impl EditFileToolCard {
|
||||
fn new(path: PathBuf, project: Entity<Project>, window: &mut Window, cx: &mut App) -> Self {
|
||||
pub fn new(path: PathBuf, project: Entity<Project>, window: &mut Window, cx: &mut App) -> Self {
|
||||
let multibuffer = cx.new(|_| MultiBuffer::without_headers(Capability::ReadOnly));
|
||||
let editor = cx.new(|cx| {
|
||||
let mut editor = Editor::new(
|
||||
@@ -323,7 +323,7 @@ impl EditFileToolCard {
|
||||
}
|
||||
}
|
||||
|
||||
fn set_diff(
|
||||
pub fn set_diff(
|
||||
&mut self,
|
||||
path: Arc<Path>,
|
||||
old_text: String,
|
||||
@@ -343,6 +343,7 @@ impl EditFileToolCard {
|
||||
.hunks_intersecting_range(Anchor::MIN..Anchor::MAX, &snapshot, cx)
|
||||
.map(|diff_hunk| diff_hunk.buffer_range.to_point(&snapshot))
|
||||
.collect::<Vec<_>>();
|
||||
multibuffer.clear(cx);
|
||||
let (_, is_newly_added) = multibuffer.set_excerpts_for_path(
|
||||
PathKey::for_buffer(&buffer, cx),
|
||||
buffer,
|
||||
@@ -576,13 +577,10 @@ impl ToolCard for EditFileToolCard {
|
||||
card.child(
|
||||
v_flex()
|
||||
.relative()
|
||||
.map(|editor_container| {
|
||||
if self.full_height_expanded {
|
||||
editor_container.h_full()
|
||||
} else {
|
||||
editor_container
|
||||
.h(DEFAULT_COLLAPSED_LINES as f32 * editor_line_height)
|
||||
}
|
||||
.h_full()
|
||||
.when(!self.full_height_expanded, |editor_container| {
|
||||
editor_container
|
||||
.max_h(DEFAULT_COLLAPSED_LINES as f32 * editor_line_height)
|
||||
})
|
||||
.overflow_hidden()
|
||||
.border_t_1()
|
||||
|
||||
@@ -6,7 +6,7 @@ use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat}
|
||||
use project::Project;
|
||||
use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::sync::Arc;
|
||||
use std::{path::PathBuf, sync::Arc};
|
||||
use ui::IconName;
|
||||
use util::markdown::MarkdownEscaped;
|
||||
|
||||
@@ -50,7 +50,7 @@ impl Tool for OpenTool {
|
||||
self: Arc<Self>,
|
||||
input: serde_json::Value,
|
||||
_messages: &[LanguageModelRequestMessage],
|
||||
_project: Entity<Project>,
|
||||
project: Entity<Project>,
|
||||
_action_log: Entity<ActionLog>,
|
||||
_window: Option<AnyWindowHandle>,
|
||||
cx: &mut App,
|
||||
@@ -60,11 +60,107 @@ impl Tool for OpenTool {
|
||||
Err(err) => return Task::ready(Err(anyhow!(err))).into(),
|
||||
};
|
||||
|
||||
// If path_or_url turns out to be a path in the project, make it absolute.
|
||||
let abs_path = to_absolute_path(&input.path_or_url, project, cx);
|
||||
|
||||
cx.background_spawn(async move {
|
||||
open::that(&input.path_or_url).context("Failed to open URL or file path")?;
|
||||
match abs_path {
|
||||
Some(path) => open::that(path),
|
||||
None => open::that(&input.path_or_url),
|
||||
}
|
||||
.context("Failed to open URL or file path")?;
|
||||
|
||||
Ok(format!("Successfully opened {}", input.path_or_url))
|
||||
})
|
||||
.into()
|
||||
}
|
||||
}
|
||||
|
||||
fn to_absolute_path(
|
||||
potential_path: &str,
|
||||
project: Entity<Project>,
|
||||
cx: &mut App,
|
||||
) -> Option<PathBuf> {
|
||||
let project = project.read(cx);
|
||||
project
|
||||
.find_project_path(PathBuf::from(potential_path), cx)
|
||||
.and_then(|project_path| project.absolute_path(&project_path, cx))
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use gpui::TestAppContext;
|
||||
use project::{FakeFs, Project};
|
||||
use settings::SettingsStore;
|
||||
use std::path::Path;
|
||||
use tempfile::TempDir;
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_to_absolute_path(cx: &mut TestAppContext) {
|
||||
init_test(cx);
|
||||
let temp_dir = TempDir::new().expect("Failed to create temp directory");
|
||||
let temp_path = temp_dir.path().to_string_lossy().to_string();
|
||||
|
||||
let fs = FakeFs::new(cx.executor());
|
||||
fs.insert_tree(
|
||||
&temp_path,
|
||||
serde_json::json!({
|
||||
"src": {
|
||||
"main.rs": "fn main() {}",
|
||||
"lib.rs": "pub fn lib_fn() {}"
|
||||
},
|
||||
"docs": {
|
||||
"readme.md": "# Project Documentation"
|
||||
}
|
||||
}),
|
||||
)
|
||||
.await;
|
||||
|
||||
// Use the temp_path as the root directory, not just its filename
|
||||
let project = Project::test(fs.clone(), [temp_dir.path()], cx).await;
|
||||
|
||||
// Test cases where the function should return Some
|
||||
cx.update(|cx| {
|
||||
// Project-relative paths should return Some
|
||||
// Create paths using the last segment of the temp path to simulate a project-relative path
|
||||
let root_dir_name = Path::new(&temp_path)
|
||||
.file_name()
|
||||
.unwrap_or_else(|| std::ffi::OsStr::new("temp"))
|
||||
.to_string_lossy();
|
||||
|
||||
assert!(
|
||||
to_absolute_path(&format!("{root_dir_name}/src/main.rs"), project.clone(), cx)
|
||||
.is_some(),
|
||||
"Failed to resolve main.rs path"
|
||||
);
|
||||
|
||||
assert!(
|
||||
to_absolute_path(
|
||||
&format!("{root_dir_name}/docs/readme.md",),
|
||||
project.clone(),
|
||||
cx,
|
||||
)
|
||||
.is_some(),
|
||||
"Failed to resolve readme.md path"
|
||||
);
|
||||
|
||||
// External URL should return None
|
||||
let result = to_absolute_path("https://example.com", project.clone(), cx);
|
||||
assert_eq!(result, None, "External URLs should return None");
|
||||
|
||||
// Path outside project
|
||||
let result = to_absolute_path("../invalid/path", project.clone(), cx);
|
||||
assert_eq!(result, None, "Paths outside the project should return None");
|
||||
});
|
||||
}
|
||||
|
||||
fn init_test(cx: &mut TestAppContext) {
|
||||
cx.update(|cx| {
|
||||
let settings_store = SettingsStore::test(cx);
|
||||
cx.set_global(settings_store);
|
||||
language::init(cx);
|
||||
Project::init_settings(cx);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
use crate::{code_symbols_tool::file_outline, schema::json_schema_for};
|
||||
use crate::schema::json_schema_for;
|
||||
use anyhow::{Result, anyhow};
|
||||
use assistant_tool::outline;
|
||||
use assistant_tool::{ActionLog, Tool, ToolResult};
|
||||
use gpui::{AnyWindowHandle, App, Entity, Task};
|
||||
|
||||
@@ -14,10 +15,6 @@ use ui::IconName;
|
||||
use util::markdown::MarkdownInlineCode;
|
||||
|
||||
/// If the model requests to read a file whose size exceeds this, then
|
||||
/// the tool will return an error along with the model's symbol outline,
|
||||
/// and suggest trying again using line ranges from the outline.
|
||||
const MAX_FILE_SIZE_TO_READ: usize = 16384;
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
|
||||
pub struct ReadFileToolInput {
|
||||
/// The relative path of the file to read.
|
||||
@@ -125,7 +122,8 @@ impl Tool for ReadFileTool {
|
||||
if input.start_line.is_some() || input.end_line.is_some() {
|
||||
let result = buffer.read_with(cx, |buffer, _cx| {
|
||||
let text = buffer.text();
|
||||
let start = input.start_line.unwrap_or(1);
|
||||
// .max(1) because despite instructions to be 1-indexed, sometimes the model passes 0.
|
||||
let start = input.start_line.unwrap_or(1).max(1);
|
||||
let lines = text.split('\n').skip(start - 1);
|
||||
if let Some(end) = input.end_line {
|
||||
let count = end.saturating_sub(start).saturating_add(1); // Ensure at least 1 line
|
||||
@@ -144,7 +142,7 @@ impl Tool for ReadFileTool {
|
||||
// No line ranges specified, so check file size to see if it's too big.
|
||||
let file_size = buffer.read_with(cx, |buffer, _cx| buffer.text().len())?;
|
||||
|
||||
if file_size <= MAX_FILE_SIZE_TO_READ {
|
||||
if file_size <= outline::AUTO_OUTLINE_SIZE {
|
||||
// File is small enough, so return its contents.
|
||||
let result = buffer.read_with(cx, |buffer, _cx| buffer.text())?;
|
||||
|
||||
@@ -154,9 +152,9 @@ impl Tool for ReadFileTool {
|
||||
|
||||
Ok(result)
|
||||
} else {
|
||||
// File is too big, so return an error with the outline
|
||||
// File is too big, so return the outline
|
||||
// and a suggestion to read again with line numbers.
|
||||
let outline = file_outline(project, file_path, action_log, None, cx).await?;
|
||||
let outline = outline::file_outline(project, file_path, action_log, None, cx).await?;
|
||||
Ok(formatdoc! {"
|
||||
This file was too big to read all at once. Here is an outline of its symbols:
|
||||
|
||||
@@ -332,6 +330,67 @@ mod test {
|
||||
assert_eq!(result.unwrap(), "Line 2\nLine 3\nLine 4");
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_read_file_line_range_edge_cases(cx: &mut TestAppContext) {
|
||||
init_test(cx);
|
||||
|
||||
let fs = FakeFs::new(cx.executor());
|
||||
fs.insert_tree(
|
||||
"/root",
|
||||
json!({
|
||||
"multiline.txt": "Line 1\nLine 2\nLine 3\nLine 4\nLine 5"
|
||||
}),
|
||||
)
|
||||
.await;
|
||||
let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
|
||||
let action_log = cx.new(|_| ActionLog::new(project.clone()));
|
||||
|
||||
// start_line of 0 should be treated as 1
|
||||
let result = cx
|
||||
.update(|cx| {
|
||||
let input = json!({
|
||||
"path": "root/multiline.txt",
|
||||
"start_line": 0,
|
||||
"end_line": 2
|
||||
});
|
||||
Arc::new(ReadFileTool)
|
||||
.run(input, &[], project.clone(), action_log.clone(), None, cx)
|
||||
.output
|
||||
})
|
||||
.await;
|
||||
assert_eq!(result.unwrap(), "Line 1\nLine 2");
|
||||
|
||||
// end_line of 0 should result in at least 1 line
|
||||
let result = cx
|
||||
.update(|cx| {
|
||||
let input = json!({
|
||||
"path": "root/multiline.txt",
|
||||
"start_line": 1,
|
||||
"end_line": 0
|
||||
});
|
||||
Arc::new(ReadFileTool)
|
||||
.run(input, &[], project.clone(), action_log.clone(), None, cx)
|
||||
.output
|
||||
})
|
||||
.await;
|
||||
assert_eq!(result.unwrap(), "Line 1");
|
||||
|
||||
// when start_line > end_line, should still return at least 1 line
|
||||
let result = cx
|
||||
.update(|cx| {
|
||||
let input = json!({
|
||||
"path": "root/multiline.txt",
|
||||
"start_line": 3,
|
||||
"end_line": 2
|
||||
});
|
||||
Arc::new(ReadFileTool)
|
||||
.run(input, &[], project.clone(), action_log, None, cx)
|
||||
.output
|
||||
})
|
||||
.await;
|
||||
assert_eq!(result.unwrap(), "Line 3");
|
||||
}
|
||||
|
||||
fn init_test(cx: &mut TestAppContext) {
|
||||
cx.update(|cx| {
|
||||
let settings_store = SettingsStore::test(cx);
|
||||
|
||||
352
crates/assistant_tools/src/streaming_edit_file_tool.rs
Normal file
352
crates/assistant_tools/src/streaming_edit_file_tool.rs
Normal file
@@ -0,0 +1,352 @@
|
||||
use crate::{
|
||||
Templates,
|
||||
edit_agent::{EditAgent, EditAgentOutputEvent},
|
||||
edit_file_tool::EditFileToolCard,
|
||||
schema::json_schema_for,
|
||||
};
|
||||
use anyhow::{Context as _, Result, anyhow};
|
||||
use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolResult};
|
||||
use futures::StreamExt;
|
||||
use gpui::{AnyWindowHandle, App, AppContext, AsyncApp, Entity, Task};
|
||||
use indoc::formatdoc;
|
||||
use language_model::{
|
||||
LanguageModelRegistry, LanguageModelRequestMessage, LanguageModelToolSchemaFormat,
|
||||
};
|
||||
use project::Project;
|
||||
use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::{path::PathBuf, sync::Arc};
|
||||
use ui::prelude::*;
|
||||
use util::ResultExt;
|
||||
|
||||
pub struct StreamingEditFileTool;
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
|
||||
pub struct StreamingEditFileToolInput {
|
||||
/// A one-line, user-friendly markdown description of the edit. This will be
|
||||
/// shown in the UI and also passed to another model to perform the edit.
|
||||
///
|
||||
/// Be terse, but also descriptive in what you want to achieve with this
|
||||
/// edit. Avoid generic instructions.
|
||||
///
|
||||
/// NEVER mention the file path in this description.
|
||||
///
|
||||
/// <example>Fix API endpoint URLs</example>
|
||||
/// <example>Update copyright year in `page_footer`</example>
|
||||
///
|
||||
/// Make sure to include this field before all the others in the input object
|
||||
/// so that we can display it immediately.
|
||||
pub display_description: String,
|
||||
|
||||
/// The full path of the file to create or modify in the project.
|
||||
///
|
||||
/// WARNING: When specifying which file path need changing, you MUST
|
||||
/// start each path with one of the project's root directories.
|
||||
///
|
||||
/// The following examples assume we have two root directories in the project:
|
||||
/// - backend
|
||||
/// - frontend
|
||||
///
|
||||
/// <example>
|
||||
/// `backend/src/main.rs`
|
||||
///
|
||||
/// Notice how the file path starts with root-1. Without that, the path
|
||||
/// would be ambiguous and the call would fail!
|
||||
/// </example>
|
||||
///
|
||||
/// <example>
|
||||
/// `frontend/db.js`
|
||||
/// </example>
|
||||
pub path: PathBuf,
|
||||
|
||||
/// If true, this tool will recreate the file from scratch.
|
||||
/// If false, this tool will produce granular edits to an existing file.
|
||||
pub create_or_overwrite: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
|
||||
struct PartialInput {
|
||||
#[serde(default)]
|
||||
path: String,
|
||||
#[serde(default)]
|
||||
display_description: String,
|
||||
}
|
||||
|
||||
const DEFAULT_UI_TEXT: &str = "Editing file";
|
||||
|
||||
impl Tool for StreamingEditFileTool {
|
||||
fn name(&self) -> String {
|
||||
"edit_file".into()
|
||||
}
|
||||
|
||||
fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
fn description(&self) -> String {
|
||||
include_str!("streaming_edit_file_tool/description.md").to_string()
|
||||
}
|
||||
|
||||
fn icon(&self) -> IconName {
|
||||
IconName::Pencil
|
||||
}
|
||||
|
||||
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
|
||||
json_schema_for::<StreamingEditFileToolInput>(format)
|
||||
}
|
||||
|
||||
fn ui_text(&self, input: &serde_json::Value) -> String {
|
||||
match serde_json::from_value::<StreamingEditFileToolInput>(input.clone()) {
|
||||
Ok(input) => input.display_description,
|
||||
Err(_) => "Editing file".to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
fn still_streaming_ui_text(&self, input: &serde_json::Value) -> String {
|
||||
if let Some(input) = serde_json::from_value::<PartialInput>(input.clone()).ok() {
|
||||
let description = input.display_description.trim();
|
||||
if !description.is_empty() {
|
||||
return description.to_string();
|
||||
}
|
||||
|
||||
let path = input.path.trim();
|
||||
if !path.is_empty() {
|
||||
return path.to_string();
|
||||
}
|
||||
}
|
||||
|
||||
DEFAULT_UI_TEXT.to_string()
|
||||
}
|
||||
|
||||
fn run(
|
||||
self: Arc<Self>,
|
||||
input: serde_json::Value,
|
||||
messages: &[LanguageModelRequestMessage],
|
||||
project: Entity<Project>,
|
||||
action_log: Entity<ActionLog>,
|
||||
window: Option<AnyWindowHandle>,
|
||||
cx: &mut App,
|
||||
) -> ToolResult {
|
||||
let input = match serde_json::from_value::<StreamingEditFileToolInput>(input) {
|
||||
Ok(input) => input,
|
||||
Err(err) => return Task::ready(Err(anyhow!(err))).into(),
|
||||
};
|
||||
|
||||
let Some(project_path) = project.read(cx).find_project_path(&input.path, cx) else {
|
||||
return Task::ready(Err(anyhow!(
|
||||
"Path {} not found in project",
|
||||
input.path.display()
|
||||
)))
|
||||
.into();
|
||||
};
|
||||
let Some(worktree) = project
|
||||
.read(cx)
|
||||
.worktree_for_id(project_path.worktree_id, cx)
|
||||
else {
|
||||
return Task::ready(Err(anyhow!("Worktree not found for project path"))).into();
|
||||
};
|
||||
let exists = worktree.update(cx, |worktree, cx| {
|
||||
worktree.file_exists(&project_path.path, cx)
|
||||
});
|
||||
|
||||
let card = window.and_then(|window| {
|
||||
window
|
||||
.update(cx, |_, window, cx| {
|
||||
cx.new(|cx| {
|
||||
EditFileToolCard::new(input.path.clone(), project.clone(), window, cx)
|
||||
})
|
||||
})
|
||||
.ok()
|
||||
});
|
||||
|
||||
let card_clone = card.clone();
|
||||
let messages = messages.to_vec();
|
||||
let task = cx.spawn(async move |cx: &mut AsyncApp| {
|
||||
if !input.create_or_overwrite && !exists.await? {
|
||||
return Err(anyhow!("{} not found", input.path.display()));
|
||||
}
|
||||
|
||||
let model = cx
|
||||
.update(|cx| LanguageModelRegistry::read_global(cx).default_model())?
|
||||
.context("default model not set")?
|
||||
.model;
|
||||
let edit_agent = EditAgent::new(model, action_log, Templates::new());
|
||||
|
||||
let buffer = project
|
||||
.update(cx, |project, cx| {
|
||||
project.open_buffer(project_path.clone(), cx)
|
||||
})?
|
||||
.await?;
|
||||
|
||||
let old_snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
|
||||
let old_text = cx
|
||||
.background_spawn({
|
||||
let old_snapshot = old_snapshot.clone();
|
||||
async move { old_snapshot.text() }
|
||||
})
|
||||
.await;
|
||||
|
||||
let (output, mut events) = if input.create_or_overwrite {
|
||||
edit_agent.overwrite(
|
||||
buffer.clone(),
|
||||
input.display_description.clone(),
|
||||
messages,
|
||||
cx,
|
||||
)
|
||||
} else {
|
||||
edit_agent.edit(
|
||||
buffer.clone(),
|
||||
input.display_description.clone(),
|
||||
messages,
|
||||
cx,
|
||||
)
|
||||
};
|
||||
|
||||
let mut hallucinated_old_text = false;
|
||||
while let Some(event) = events.next().await {
|
||||
match event {
|
||||
EditAgentOutputEvent::Edited { position } => {
|
||||
if let Some(card) = card_clone.as_ref() {
|
||||
let new_snapshot =
|
||||
buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
|
||||
let new_text = cx
|
||||
.background_spawn({
|
||||
let new_snapshot = new_snapshot.clone();
|
||||
async move { new_snapshot.text() }
|
||||
})
|
||||
.await;
|
||||
card.update(cx, |card, cx| {
|
||||
card.set_diff(
|
||||
project_path.path.clone(),
|
||||
old_text.clone(),
|
||||
new_text,
|
||||
cx,
|
||||
);
|
||||
})
|
||||
.log_err();
|
||||
}
|
||||
}
|
||||
EditAgentOutputEvent::OldTextNotFound(_) => hallucinated_old_text = true,
|
||||
}
|
||||
}
|
||||
output.await?;
|
||||
|
||||
project
|
||||
.update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))?
|
||||
.await?;
|
||||
|
||||
let new_snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
|
||||
let new_text = cx.background_spawn({
|
||||
let new_snapshot = new_snapshot.clone();
|
||||
async move { new_snapshot.text() }
|
||||
});
|
||||
let diff = cx.background_spawn(async move {
|
||||
language::unified_diff(&old_snapshot.text(), &new_snapshot.text())
|
||||
});
|
||||
let (new_text, diff) = futures::join!(new_text, diff);
|
||||
|
||||
if let Some(card) = card_clone {
|
||||
card.update(cx, |card, cx| {
|
||||
card.set_diff(project_path.path.clone(), old_text, new_text, cx);
|
||||
})
|
||||
.log_err();
|
||||
}
|
||||
|
||||
let input_path = input.path.display();
|
||||
if diff.is_empty() {
|
||||
if hallucinated_old_text {
|
||||
Err(anyhow!(formatdoc! {"
|
||||
Some edits were produced but none of them could be applied.
|
||||
Read the relevant sections of {input_path} again so that
|
||||
I can perform the requested edits.
|
||||
"}))
|
||||
} else {
|
||||
Ok("No edits were made.".to_string())
|
||||
}
|
||||
} else {
|
||||
Ok(format!("Edited {}:\n\n```diff\n{}\n```", input_path, diff))
|
||||
}
|
||||
});
|
||||
|
||||
ToolResult {
|
||||
output: task,
|
||||
card: card.map(AnyToolCard::from),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use serde_json::json;
|
||||
|
||||
#[test]
|
||||
fn still_streaming_ui_text_with_path() {
|
||||
let input = json!({
|
||||
"path": "src/main.rs",
|
||||
"display_description": "",
|
||||
"old_string": "old code",
|
||||
"new_string": "new code"
|
||||
});
|
||||
|
||||
assert_eq!(
|
||||
StreamingEditFileTool.still_streaming_ui_text(&input),
|
||||
"src/main.rs"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn still_streaming_ui_text_with_description() {
|
||||
let input = json!({
|
||||
"path": "",
|
||||
"display_description": "Fix error handling",
|
||||
"old_string": "old code",
|
||||
"new_string": "new code"
|
||||
});
|
||||
|
||||
assert_eq!(
|
||||
StreamingEditFileTool.still_streaming_ui_text(&input),
|
||||
"Fix error handling",
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn still_streaming_ui_text_with_path_and_description() {
|
||||
let input = json!({
|
||||
"path": "src/main.rs",
|
||||
"display_description": "Fix error handling",
|
||||
"old_string": "old code",
|
||||
"new_string": "new code"
|
||||
});
|
||||
|
||||
assert_eq!(
|
||||
StreamingEditFileTool.still_streaming_ui_text(&input),
|
||||
"Fix error handling",
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn still_streaming_ui_text_no_path_or_description() {
|
||||
let input = json!({
|
||||
"path": "",
|
||||
"display_description": "",
|
||||
"old_string": "old code",
|
||||
"new_string": "new code"
|
||||
});
|
||||
|
||||
assert_eq!(
|
||||
StreamingEditFileTool.still_streaming_ui_text(&input),
|
||||
DEFAULT_UI_TEXT,
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn still_streaming_ui_text_with_null() {
|
||||
let input = serde_json::Value::Null;
|
||||
|
||||
assert_eq!(
|
||||
StreamingEditFileTool.still_streaming_ui_text(&input),
|
||||
DEFAULT_UI_TEXT,
|
||||
);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,8 @@
|
||||
This is a tool for creating a new file or editing an existing file. For moving or renaming files, you should generally use the `terminal` tool with the 'mv' command instead.
|
||||
|
||||
Before using this tool:
|
||||
|
||||
1. Use the `read_file` tool to understand the file's contents and context
|
||||
|
||||
2. Verify the directory path is correct (only applicable when creating new files):
|
||||
- Use the `list_directory` tool to verify the parent directory exists and is the correct location
|
||||
32
crates/assistant_tools/src/templates.rs
Normal file
32
crates/assistant_tools/src/templates.rs
Normal file
@@ -0,0 +1,32 @@
|
||||
use anyhow::Result;
|
||||
use handlebars::Handlebars;
|
||||
use rust_embed::RustEmbed;
|
||||
use serde::Serialize;
|
||||
use std::sync::Arc;
|
||||
|
||||
#[derive(RustEmbed)]
|
||||
#[folder = "src/templates"]
|
||||
#[include = "*.hbs"]
|
||||
struct Assets;
|
||||
|
||||
pub struct Templates(Handlebars<'static>);
|
||||
|
||||
impl Templates {
|
||||
pub fn new() -> Arc<Self> {
|
||||
let mut handlebars = Handlebars::new();
|
||||
handlebars.register_embed_templates::<Assets>().unwrap();
|
||||
handlebars.register_escape_fn(|text| text.into());
|
||||
Arc::new(Self(handlebars))
|
||||
}
|
||||
}
|
||||
|
||||
pub trait Template: Sized {
|
||||
const TEMPLATE_NAME: &'static str;
|
||||
|
||||
fn render(&self, templates: &Templates) -> Result<String>
|
||||
where
|
||||
Self: Serialize + Sized,
|
||||
{
|
||||
Ok(templates.0.render(Self::TEMPLATE_NAME, self)?)
|
||||
}
|
||||
}
|
||||
12
crates/assistant_tools/src/templates/create_file_prompt.hbs
Normal file
12
crates/assistant_tools/src/templates/create_file_prompt.hbs
Normal file
@@ -0,0 +1,12 @@
|
||||
You are an expert engineer and your task is to write a new file from scratch.
|
||||
|
||||
<file_to_edit>
|
||||
{{path}}
|
||||
</file_to_edit>
|
||||
|
||||
<edit_description>
|
||||
{{edit_description}}
|
||||
</edit_description>
|
||||
|
||||
You MUST respond directly with the file's content, without explanations, additional text or triple backticks.
|
||||
The text you output will be saved verbatim as the content of the file.
|
||||
23
crates/assistant_tools/src/templates/diff_judge.hbs
Normal file
23
crates/assistant_tools/src/templates/diff_judge.hbs
Normal file
@@ -0,0 +1,23 @@
|
||||
You are an expert coder, and have been tasked with looking at the following diff:
|
||||
|
||||
<diff>
|
||||
{{diff}}
|
||||
</diff>
|
||||
|
||||
Evaluate the following assertions:
|
||||
|
||||
<assertions>
|
||||
{{assertions}}
|
||||
</assertions>
|
||||
|
||||
You must respond with a short analysis and a score between 0 and 100, where:
|
||||
- 0 means no assertions pass
|
||||
- 100 means all the assertions pass perfectly
|
||||
|
||||
<analysis>
|
||||
- Assertion 1: one line describing why the first assertion passes or fails (even partially)
|
||||
- Assertion 2: one line describing why the second assertion passes or fails (even partially)
|
||||
- ...
|
||||
- Assertion N: one line describing why the Nth assertion passes or fails (even partially)
|
||||
</analysis>
|
||||
<score>YOUR FINAL SCORE HERE</score>
|
||||
49
crates/assistant_tools/src/templates/edit_file_prompt.hbs
Normal file
49
crates/assistant_tools/src/templates/edit_file_prompt.hbs
Normal file
@@ -0,0 +1,49 @@
|
||||
You are an expert text editor and your task is to produce a series of edits to a file given a description of the changes you need to make.
|
||||
|
||||
You MUST respond with a series of edits to that one file in the following format:
|
||||
|
||||
```
|
||||
<edits>
|
||||
|
||||
<old_text>
|
||||
OLD TEXT 1 HERE
|
||||
</old_text>
|
||||
<new_text>
|
||||
NEW TEXT 1 HERE
|
||||
</new_text>
|
||||
|
||||
<old_text>
|
||||
OLD TEXT 2 HERE
|
||||
</old_text>
|
||||
<new_text>
|
||||
NEW TEXT 2 HERE
|
||||
</new_text>
|
||||
|
||||
<old_text>
|
||||
OLD TEXT 3 HERE
|
||||
</old_text>
|
||||
<new_text>
|
||||
NEW TEXT 3 HERE
|
||||
</new_text>
|
||||
|
||||
</edits>
|
||||
```
|
||||
|
||||
Rules for editing:
|
||||
|
||||
- `old_text` represents lines in the input file that will be replaced with `new_text`. `old_text` MUST exactly match the existing file content, character for character, including indentation.
|
||||
- Always include enough context around the lines you want to replace in `old_text` such that it's impossible to mistake them for other lines.
|
||||
- If you want to replace many occurrences of the same text, repeat the same `old_text`/`new_text` pair multiple times and I will apply them sequentially, one occurrence at a time.
|
||||
- When reporting multiple edits, each edit assumes the previous one has already been applied! Therefore, you must ensure `old_text` doesn't reference text that has already been modified by a previous edit.
|
||||
- Don't explain the edits, just report them.
|
||||
- Only edit the file specified in `<file_to_edit>` and NEVER include edits to other files!
|
||||
- If you open an <old_text> tag, you MUST close it using </old_text>
|
||||
- If you open an <new_text> tag, you MUST close it using </new_text>
|
||||
|
||||
<file_to_edit>
|
||||
{{path}}
|
||||
</file_to_edit>
|
||||
|
||||
<edit_description>
|
||||
{{edit_description}}
|
||||
</edit_description>
|
||||
@@ -27,7 +27,9 @@ use crate::db::billing_subscription::{
|
||||
StripeCancellationReason, StripeSubscriptionStatus, SubscriptionKind,
|
||||
};
|
||||
use crate::llm::db::subscription_usage_meter::CompletionMode;
|
||||
use crate::llm::{DEFAULT_MAX_MONTHLY_SPEND, FREE_TIER_MONTHLY_SPENDING_LIMIT};
|
||||
use crate::llm::{
|
||||
AGENT_EXTENDED_TRIAL_FEATURE_FLAG, DEFAULT_MAX_MONTHLY_SPEND, FREE_TIER_MONTHLY_SPENDING_LIMIT,
|
||||
};
|
||||
use crate::rpc::{ResultExt as _, Server};
|
||||
use crate::{AppState, Cents, Error, Result};
|
||||
use crate::{db::UserId, llm::db::LlmDatabase};
|
||||
@@ -54,6 +56,10 @@ pub fn router() -> Router {
|
||||
"/billing/subscriptions/manage",
|
||||
post(manage_billing_subscription),
|
||||
)
|
||||
.route(
|
||||
"/billing/subscriptions/migrate",
|
||||
post(migrate_to_new_billing),
|
||||
)
|
||||
.route("/billing/monthly_spend", get(get_monthly_spend))
|
||||
.route("/billing/usage", get(get_current_usage))
|
||||
}
|
||||
@@ -256,6 +262,7 @@ async fn list_billing_subscriptions(
|
||||
enum ProductCode {
|
||||
ZedPro,
|
||||
ZedProTrial,
|
||||
ZedFree,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
@@ -362,12 +369,7 @@ async fn create_billing_subscription(
|
||||
let checkout_session_url = match body.product {
|
||||
Some(ProductCode::ZedPro) => {
|
||||
stripe_billing
|
||||
.checkout_with_price(
|
||||
app.config.zed_pro_price_id()?,
|
||||
customer_id,
|
||||
&user.github_login,
|
||||
&success_url,
|
||||
)
|
||||
.checkout_with_zed_pro(customer_id, &user.github_login, &success_url)
|
||||
.await?
|
||||
}
|
||||
Some(ProductCode::ZedProTrial) => {
|
||||
@@ -384,7 +386,6 @@ async fn create_billing_subscription(
|
||||
|
||||
stripe_billing
|
||||
.checkout_with_zed_pro_trial(
|
||||
app.config.zed_pro_price_id()?,
|
||||
customer_id,
|
||||
&user.github_login,
|
||||
feature_flags,
|
||||
@@ -392,6 +393,11 @@ async fn create_billing_subscription(
|
||||
)
|
||||
.await?
|
||||
}
|
||||
Some(ProductCode::ZedFree) => {
|
||||
stripe_billing
|
||||
.checkout_with_zed_free(customer_id, &user.github_login, &success_url)
|
||||
.await?
|
||||
}
|
||||
None => {
|
||||
let default_model = llm_db.model(
|
||||
zed_llm_client::LanguageModelProvider::Anthropic,
|
||||
@@ -458,6 +464,14 @@ async fn manage_billing_subscription(
|
||||
))?
|
||||
};
|
||||
|
||||
let Some(stripe_billing) = app.stripe_billing.clone() else {
|
||||
log::error!("failed to retrieve Stripe billing object");
|
||||
Err(Error::http(
|
||||
StatusCode::NOT_IMPLEMENTED,
|
||||
"not supported".into(),
|
||||
))?
|
||||
};
|
||||
|
||||
let customer = app
|
||||
.db
|
||||
.get_billing_customer_by_user_id(user.id)
|
||||
@@ -508,8 +522,8 @@ async fn manage_billing_subscription(
|
||||
let flow = match body.intent {
|
||||
ManageSubscriptionIntent::ManageSubscription => None,
|
||||
ManageSubscriptionIntent::UpgradeToPro => {
|
||||
let zed_pro_price_id = app.config.zed_pro_price_id()?;
|
||||
let zed_free_price_id = app.config.zed_free_price_id()?;
|
||||
let zed_pro_price_id = stripe_billing.zed_pro_price_id().await?;
|
||||
let zed_free_price_id = stripe_billing.zed_free_price_id().await?;
|
||||
|
||||
let stripe_subscription =
|
||||
Subscription::retrieve(&stripe_client, &subscription_id, &[]).await?;
|
||||
@@ -602,6 +616,85 @@ async fn manage_billing_subscription(
|
||||
}))
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct MigrateToNewBillingBody {
|
||||
github_user_id: i32,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
struct MigrateToNewBillingResponse {
|
||||
/// The ID of the subscription that was canceled.
|
||||
canceled_subscription_id: String,
|
||||
}
|
||||
|
||||
async fn migrate_to_new_billing(
|
||||
Extension(app): Extension<Arc<AppState>>,
|
||||
extract::Json(body): extract::Json<MigrateToNewBillingBody>,
|
||||
) -> Result<Json<MigrateToNewBillingResponse>> {
|
||||
let Some(stripe_client) = app.stripe_client.clone() else {
|
||||
log::error!("failed to retrieve Stripe client");
|
||||
Err(Error::http(
|
||||
StatusCode::NOT_IMPLEMENTED,
|
||||
"not supported".into(),
|
||||
))?
|
||||
};
|
||||
|
||||
let user = app
|
||||
.db
|
||||
.get_user_by_github_user_id(body.github_user_id)
|
||||
.await?
|
||||
.ok_or_else(|| anyhow!("user not found"))?;
|
||||
|
||||
let old_billing_subscriptions_by_user = app
|
||||
.db
|
||||
.get_active_billing_subscriptions(HashSet::from_iter([user.id]))
|
||||
.await?;
|
||||
|
||||
let Some((_billing_customer, billing_subscription)) =
|
||||
old_billing_subscriptions_by_user.get(&user.id)
|
||||
else {
|
||||
return Err(Error::http(
|
||||
StatusCode::NOT_FOUND,
|
||||
"No active billing subscriptions to migrate".into(),
|
||||
));
|
||||
};
|
||||
|
||||
let stripe_subscription_id = billing_subscription
|
||||
.stripe_subscription_id
|
||||
.parse::<stripe::SubscriptionId>()
|
||||
.context("failed to parse Stripe subscription ID from database")?;
|
||||
|
||||
Subscription::cancel(
|
||||
&stripe_client,
|
||||
&stripe_subscription_id,
|
||||
stripe::CancelSubscription {
|
||||
invoice_now: Some(true),
|
||||
..Default::default()
|
||||
},
|
||||
)
|
||||
.await?;
|
||||
|
||||
let feature_flags = app.db.list_feature_flags().await?;
|
||||
|
||||
for feature_flag in ["new-billing", "assistant2"] {
|
||||
let already_in_feature_flag = feature_flags.iter().any(|flag| flag.flag == feature_flag);
|
||||
if already_in_feature_flag {
|
||||
continue;
|
||||
}
|
||||
|
||||
let feature_flag = feature_flags
|
||||
.iter()
|
||||
.find(|flag| flag.flag == feature_flag)
|
||||
.context("failed to find feature flag: {feature_flag:?}")?;
|
||||
|
||||
app.db.add_user_flag(user.id, feature_flag.id).await?;
|
||||
}
|
||||
|
||||
Ok(Json(MigrateToNewBillingResponse {
|
||||
canceled_subscription_id: stripe_subscription_id.to_string(),
|
||||
}))
|
||||
}
|
||||
|
||||
/// The amount of time we wait in between each poll of Stripe events.
|
||||
///
|
||||
/// This value should strike a balance between:
|
||||
@@ -856,9 +949,11 @@ async fn handle_customer_subscription_event(
|
||||
|
||||
log::info!("handling Stripe {} event: {}", event.type_, event.id);
|
||||
|
||||
let subscription_kind = maybe!({
|
||||
let zed_pro_price_id = app.config.zed_pro_price_id().ok()?;
|
||||
let zed_free_price_id = app.config.zed_free_price_id().ok()?;
|
||||
let subscription_kind = maybe!(async {
|
||||
let stripe_billing = app.stripe_billing.clone()?;
|
||||
|
||||
let zed_pro_price_id = stripe_billing.zed_pro_price_id().await.ok()?;
|
||||
let zed_free_price_id = stripe_billing.zed_free_price_id().await.ok()?;
|
||||
|
||||
subscription.items.data.iter().find_map(|item| {
|
||||
let price = item.price.as_ref()?;
|
||||
@@ -875,7 +970,8 @@ async fn handle_customer_subscription_event(
|
||||
None
|
||||
}
|
||||
})
|
||||
});
|
||||
})
|
||||
.await;
|
||||
|
||||
let billing_customer =
|
||||
find_or_create_billing_customer(app, stripe_client, subscription.customer)
|
||||
@@ -1090,9 +1186,17 @@ struct UsageCounts {
|
||||
pub remaining: Option<i32>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
struct ModelRequestUsage {
|
||||
pub model: String,
|
||||
pub mode: CompletionMode,
|
||||
pub requests: i32,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
struct GetCurrentUsageResponse {
|
||||
pub model_requests: UsageCounts,
|
||||
pub model_request_usage: Vec<ModelRequestUsage>,
|
||||
pub edit_predictions: UsageCounts,
|
||||
}
|
||||
|
||||
@@ -1119,6 +1223,7 @@ async fn get_current_usage(
|
||||
limit: Some(0),
|
||||
remaining: Some(0),
|
||||
},
|
||||
model_request_usage: Vec::new(),
|
||||
edit_predictions: UsageCounts {
|
||||
used: 0,
|
||||
limit: Some(0),
|
||||
@@ -1154,8 +1259,21 @@ async fn get_current_usage(
|
||||
SubscriptionKind::ZedFree => zed_llm_client::Plan::Free,
|
||||
};
|
||||
|
||||
let feature_flags = app.db.get_user_flags(user.id).await?;
|
||||
let has_extended_trial = feature_flags
|
||||
.iter()
|
||||
.any(|flag| flag == AGENT_EXTENDED_TRIAL_FEATURE_FLAG);
|
||||
|
||||
let model_requests_limit = match plan.model_requests_limit() {
|
||||
zed_llm_client::UsageLimit::Limited(limit) => Some(limit),
|
||||
zed_llm_client::UsageLimit::Limited(limit) => {
|
||||
let limit = if plan == zed_llm_client::Plan::ZedProTrial && has_extended_trial {
|
||||
1_000
|
||||
} else {
|
||||
limit
|
||||
};
|
||||
|
||||
Some(limit)
|
||||
}
|
||||
zed_llm_client::UsageLimit::Unlimited => None,
|
||||
};
|
||||
let edit_prediction_limit = match plan.edit_predictions_limit() {
|
||||
@@ -1163,12 +1281,30 @@ async fn get_current_usage(
|
||||
zed_llm_client::UsageLimit::Unlimited => None,
|
||||
};
|
||||
|
||||
let subscription_usage_meters = llm_db
|
||||
.get_current_subscription_usage_meters_for_user(user.id, Utc::now())
|
||||
.await?;
|
||||
|
||||
let model_request_usage = subscription_usage_meters
|
||||
.into_iter()
|
||||
.filter_map(|(usage_meter, _usage)| {
|
||||
let model = llm_db.model_by_id(usage_meter.model_id).ok()?;
|
||||
|
||||
Some(ModelRequestUsage {
|
||||
model: model.name.clone(),
|
||||
mode: usage_meter.mode,
|
||||
requests: usage_meter.requests,
|
||||
})
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
Ok(Json(GetCurrentUsageResponse {
|
||||
model_requests: UsageCounts {
|
||||
used: usage.model_requests,
|
||||
limit: model_requests_limit,
|
||||
remaining: model_requests_limit.map(|limit| (limit - usage.model_requests).max(0)),
|
||||
},
|
||||
model_request_usage,
|
||||
edit_predictions: UsageCounts {
|
||||
used: usage.edit_predictions,
|
||||
limit: edit_prediction_limit,
|
||||
@@ -1402,12 +1538,12 @@ async fn sync_model_request_usage_with_stripe(
|
||||
|
||||
let model = llm_db.model_by_id(usage_meter.model_id)?;
|
||||
|
||||
let (price_id, meter_event_name) = match model.name.as_str() {
|
||||
"claude-3-5-sonnet" => (&claude_3_5_sonnet.id, "claude_3_5_sonnet/requests"),
|
||||
let (price, meter_event_name) = match model.name.as_str() {
|
||||
"claude-3-5-sonnet" => (&claude_3_5_sonnet, "claude_3_5_sonnet/requests"),
|
||||
"claude-3-7-sonnet" => match usage_meter.mode {
|
||||
CompletionMode::Normal => (&claude_3_7_sonnet.id, "claude_3_7_sonnet/requests"),
|
||||
CompletionMode::Normal => (&claude_3_7_sonnet, "claude_3_7_sonnet/requests"),
|
||||
CompletionMode::Max => {
|
||||
(&claude_3_7_sonnet_max.id, "claude_3_7_sonnet/requests/max")
|
||||
(&claude_3_7_sonnet_max, "claude_3_7_sonnet/requests/max")
|
||||
}
|
||||
},
|
||||
model_name => {
|
||||
@@ -1416,7 +1552,7 @@ async fn sync_model_request_usage_with_stripe(
|
||||
};
|
||||
|
||||
stripe_billing
|
||||
.subscribe_to_price(&stripe_subscription_id, price_id)
|
||||
.subscribe_to_price(&stripe_subscription_id, price)
|
||||
.await?;
|
||||
stripe_billing
|
||||
.bill_model_request_usage(
|
||||
|
||||
@@ -180,9 +180,6 @@ pub struct Config {
|
||||
pub slack_panics_webhook: Option<String>,
|
||||
pub auto_join_channel_id: Option<ChannelId>,
|
||||
pub stripe_api_key: Option<String>,
|
||||
pub stripe_zed_pro_price_id: Option<String>,
|
||||
pub stripe_zed_pro_trial_price_id: Option<String>,
|
||||
pub stripe_zed_free_price_id: Option<String>,
|
||||
pub supermaven_admin_api_key: Option<Arc<str>>,
|
||||
pub user_backfiller_github_access_token: Option<Arc<str>>,
|
||||
}
|
||||
@@ -201,22 +198,6 @@ impl Config {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn zed_pro_price_id(&self) -> anyhow::Result<stripe::PriceId> {
|
||||
Self::parse_stripe_price_id("Zed Pro", self.stripe_zed_pro_price_id.as_deref())
|
||||
}
|
||||
|
||||
pub fn zed_free_price_id(&self) -> anyhow::Result<stripe::PriceId> {
|
||||
Self::parse_stripe_price_id("Zed Free", self.stripe_zed_pro_price_id.as_deref())
|
||||
}
|
||||
|
||||
fn parse_stripe_price_id(name: &str, value: Option<&str>) -> anyhow::Result<stripe::PriceId> {
|
||||
use std::str::FromStr as _;
|
||||
|
||||
let price_id = value.ok_or_else(|| anyhow!("{name} price ID not set"))?;
|
||||
|
||||
Ok(stripe::PriceId::from_str(price_id)?)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub fn test() -> Self {
|
||||
Self {
|
||||
@@ -254,9 +235,6 @@ impl Config {
|
||||
migrations_path: None,
|
||||
seed_path: None,
|
||||
stripe_api_key: None,
|
||||
stripe_zed_pro_price_id: None,
|
||||
stripe_zed_pro_trial_price_id: None,
|
||||
stripe_zed_free_price_id: None,
|
||||
supermaven_admin_api_key: None,
|
||||
user_backfiller_github_access_token: None,
|
||||
kinesis_region: None,
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
use crate::db::UserId;
|
||||
use crate::llm::db::queries::subscription_usages::convert_chrono_to_time;
|
||||
|
||||
use super::*;
|
||||
@@ -34,4 +35,38 @@ impl LlmDatabase {
|
||||
})
|
||||
.await
|
||||
}
|
||||
|
||||
/// Returns all current subscription usage meters for the given user as of the given timestamp.
|
||||
pub async fn get_current_subscription_usage_meters_for_user(
|
||||
&self,
|
||||
user_id: UserId,
|
||||
now: DateTimeUtc,
|
||||
) -> Result<Vec<(subscription_usage_meter::Model, subscription_usage::Model)>> {
|
||||
let now = convert_chrono_to_time(now)?;
|
||||
|
||||
self.transaction(|tx| async move {
|
||||
let result = subscription_usage_meter::Entity::find()
|
||||
.inner_join(subscription_usage::Entity)
|
||||
.filter(subscription_usage::Column::UserId.eq(user_id))
|
||||
.filter(
|
||||
subscription_usage::Column::PeriodStartAt
|
||||
.lte(now)
|
||||
.and(subscription_usage::Column::PeriodEndAt.gte(now)),
|
||||
)
|
||||
.select_also(subscription_usage::Entity)
|
||||
.all(&*tx)
|
||||
.await?;
|
||||
|
||||
let result = result
|
||||
.into_iter()
|
||||
.filter_map(|(meter, usage)| {
|
||||
let usage = usage?;
|
||||
Some((meter, usage))
|
||||
})
|
||||
.collect();
|
||||
|
||||
Ok(result)
|
||||
})
|
||||
.await
|
||||
}
|
||||
}
|
||||
|
||||
@@ -32,6 +32,8 @@ pub struct LlmTokenClaims {
|
||||
pub has_llm_subscription: bool,
|
||||
pub max_monthly_spend_in_cents: u32,
|
||||
pub custom_llm_monthly_allowance_in_cents: Option<u32>,
|
||||
#[serde(default)]
|
||||
pub use_new_billing: bool,
|
||||
pub plan: Plan,
|
||||
#[serde(default)]
|
||||
pub has_extended_trial: bool,
|
||||
@@ -90,6 +92,7 @@ impl LlmTokenClaims {
|
||||
custom_llm_monthly_allowance_in_cents: user
|
||||
.custom_llm_monthly_allowance_in_cents
|
||||
.map(|allowance| allowance as u32),
|
||||
use_new_billing: feature_flags.iter().any(|flag| flag == "new-billing"),
|
||||
plan: subscription
|
||||
.as_ref()
|
||||
.and_then(|subscription| subscription.kind)
|
||||
|
||||
@@ -327,6 +327,10 @@ impl Server {
|
||||
.add_request_handler(
|
||||
forward_read_only_project_request::<proto::LspExtSwitchSourceHeader>,
|
||||
)
|
||||
.add_request_handler(forward_read_only_project_request::<proto::LspExtGoToParentModule>)
|
||||
.add_request_handler(forward_read_only_project_request::<proto::LspExtCancelFlycheck>)
|
||||
.add_request_handler(forward_read_only_project_request::<proto::LspExtRunFlycheck>)
|
||||
.add_request_handler(forward_read_only_project_request::<proto::LspExtClearFlycheck>)
|
||||
.add_request_handler(
|
||||
forward_read_only_project_request::<proto::LanguageServerIdForName>,
|
||||
)
|
||||
|
||||
@@ -81,6 +81,24 @@ impl StripeBilling {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn zed_pro_price_id(&self) -> Result<PriceId> {
|
||||
self.find_price_id_by_lookup_key("zed-pro").await
|
||||
}
|
||||
|
||||
pub async fn zed_free_price_id(&self) -> Result<PriceId> {
|
||||
self.find_price_id_by_lookup_key("zed-free").await
|
||||
}
|
||||
|
||||
pub async fn find_price_id_by_lookup_key(&self, lookup_key: &str) -> Result<PriceId> {
|
||||
self.state
|
||||
.read()
|
||||
.await
|
||||
.prices_by_lookup_key
|
||||
.get(lookup_key)
|
||||
.map(|price| price.id.clone())
|
||||
.ok_or_else(|| crate::Error::Internal(anyhow!("no price ID found for {lookup_key:?}")))
|
||||
}
|
||||
|
||||
pub async fn find_price_by_lookup_key(&self, lookup_key: &str) -> Result<stripe::Price> {
|
||||
self.state
|
||||
.read()
|
||||
@@ -88,7 +106,7 @@ impl StripeBilling {
|
||||
.prices_by_lookup_key
|
||||
.get(lookup_key)
|
||||
.cloned()
|
||||
.ok_or_else(|| crate::Error::Internal(anyhow!("no price ID found for {lookup_key:?}")))
|
||||
.ok_or_else(|| crate::Error::Internal(anyhow!("no price found for {lookup_key:?}")))
|
||||
}
|
||||
|
||||
pub async fn register_model_for_token_based_usage(
|
||||
@@ -230,21 +248,29 @@ impl StripeBilling {
|
||||
pub async fn subscribe_to_price(
|
||||
&self,
|
||||
subscription_id: &stripe::SubscriptionId,
|
||||
price_id: &stripe::PriceId,
|
||||
price: &stripe::Price,
|
||||
) -> Result<()> {
|
||||
let subscription =
|
||||
stripe::Subscription::retrieve(&self.client, &subscription_id, &[]).await?;
|
||||
|
||||
if subscription_contains_price(&subscription, price_id) {
|
||||
if subscription_contains_price(&subscription, &price.id) {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
const BILLING_THRESHOLD_IN_CENTS: i64 = 20 * 100;
|
||||
|
||||
let price_per_unit = price.unit_amount.unwrap_or_default();
|
||||
let units_for_billing_threshold = BILLING_THRESHOLD_IN_CENTS / price_per_unit;
|
||||
|
||||
stripe::Subscription::update(
|
||||
&self.client,
|
||||
subscription_id,
|
||||
stripe::UpdateSubscription {
|
||||
items: Some(vec![stripe::UpdateSubscriptionItems {
|
||||
price: Some(price_id.to_string()),
|
||||
price: Some(price.id.to_string()),
|
||||
billing_thresholds: Some(stripe::SubscriptionItemBillingThresholds {
|
||||
usage_gte: Some(units_for_billing_threshold),
|
||||
}),
|
||||
..Default::default()
|
||||
}]),
|
||||
trial_settings: Some(stripe::UpdateSubscriptionTrialSettings {
|
||||
@@ -463,19 +489,20 @@ impl StripeBilling {
|
||||
Ok(session.url.context("no checkout session URL")?)
|
||||
}
|
||||
|
||||
pub async fn checkout_with_price(
|
||||
pub async fn checkout_with_zed_pro(
|
||||
&self,
|
||||
price_id: PriceId,
|
||||
customer_id: stripe::CustomerId,
|
||||
github_login: &str,
|
||||
success_url: &str,
|
||||
) -> Result<String> {
|
||||
let zed_pro_price_id = self.zed_pro_price_id().await?;
|
||||
|
||||
let mut params = stripe::CreateCheckoutSession::new();
|
||||
params.mode = Some(stripe::CheckoutSessionMode::Subscription);
|
||||
params.customer = Some(customer_id);
|
||||
params.client_reference_id = Some(github_login);
|
||||
params.line_items = Some(vec![stripe::CreateCheckoutSessionLineItems {
|
||||
price: Some(price_id.to_string()),
|
||||
price: Some(zed_pro_price_id.to_string()),
|
||||
quantity: Some(1),
|
||||
..Default::default()
|
||||
}]);
|
||||
@@ -487,12 +514,13 @@ impl StripeBilling {
|
||||
|
||||
pub async fn checkout_with_zed_pro_trial(
|
||||
&self,
|
||||
zed_pro_price_id: PriceId,
|
||||
customer_id: stripe::CustomerId,
|
||||
github_login: &str,
|
||||
feature_flags: Vec<String>,
|
||||
success_url: &str,
|
||||
) -> Result<String> {
|
||||
let zed_pro_price_id = self.zed_pro_price_id().await?;
|
||||
|
||||
let eligible_for_extended_trial = feature_flags
|
||||
.iter()
|
||||
.any(|flag| flag == AGENT_EXTENDED_TRIAL_FEATURE_FLAG);
|
||||
@@ -537,6 +565,29 @@ impl StripeBilling {
|
||||
let session = stripe::CheckoutSession::create(&self.client, params).await?;
|
||||
Ok(session.url.context("no checkout session URL")?)
|
||||
}
|
||||
|
||||
pub async fn checkout_with_zed_free(
|
||||
&self,
|
||||
customer_id: stripe::CustomerId,
|
||||
github_login: &str,
|
||||
success_url: &str,
|
||||
) -> Result<String> {
|
||||
let zed_free_price_id = self.zed_free_price_id().await?;
|
||||
|
||||
let mut params = stripe::CreateCheckoutSession::new();
|
||||
params.mode = Some(stripe::CheckoutSessionMode::Subscription);
|
||||
params.customer = Some(customer_id);
|
||||
params.client_reference_id = Some(github_login);
|
||||
params.line_items = Some(vec![stripe::CreateCheckoutSessionLineItems {
|
||||
price: Some(zed_free_price_id.to_string()),
|
||||
quantity: Some(1),
|
||||
..Default::default()
|
||||
}]);
|
||||
params.success_url = Some(success_url);
|
||||
|
||||
let session = stripe::CheckoutSession::create(&self.client, params).await?;
|
||||
Ok(session.url.context("no checkout session URL")?)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
|
||||
@@ -25,7 +25,7 @@ use language::{
|
||||
use project::{
|
||||
ProjectPath, SERVER_PROGRESS_THROTTLE_TIMEOUT,
|
||||
lsp_store::{
|
||||
lsp_ext_command::{ExpandedMacro, LspExpandMacro},
|
||||
lsp_ext_command::{ExpandedMacro, LspExtExpandMacro},
|
||||
rust_analyzer_ext::RUST_ANALYZER_NAME,
|
||||
},
|
||||
project_settings::{InlineBlameSettings, ProjectSettings},
|
||||
@@ -2704,8 +2704,8 @@ async fn test_client_can_query_lsp_ext(cx_a: &mut TestAppContext, cx_b: &mut Tes
|
||||
let fake_language_server = fake_language_servers.next().await.unwrap();
|
||||
|
||||
// host
|
||||
let mut expand_request_a =
|
||||
fake_language_server.set_request_handler::<LspExpandMacro, _, _>(|params, _| async move {
|
||||
let mut expand_request_a = fake_language_server.set_request_handler::<LspExtExpandMacro, _, _>(
|
||||
|params, _| async move {
|
||||
assert_eq!(
|
||||
params.text_document.uri,
|
||||
lsp::Url::from_file_path(path!("/a/main.rs")).unwrap(),
|
||||
@@ -2715,7 +2715,8 @@ async fn test_client_can_query_lsp_ext(cx_a: &mut TestAppContext, cx_b: &mut Tes
|
||||
name: "test_macro_name".to_string(),
|
||||
expansion: "test_macro_expansion on the host".to_string(),
|
||||
}))
|
||||
});
|
||||
},
|
||||
);
|
||||
|
||||
editor_a.update_in(cx_a, |editor, window, cx| {
|
||||
expand_macro_recursively(editor, &ExpandMacroRecursively, window, cx)
|
||||
@@ -2738,8 +2739,8 @@ async fn test_client_can_query_lsp_ext(cx_a: &mut TestAppContext, cx_b: &mut Tes
|
||||
});
|
||||
|
||||
// client
|
||||
let mut expand_request_b =
|
||||
fake_language_server.set_request_handler::<LspExpandMacro, _, _>(|params, _| async move {
|
||||
let mut expand_request_b = fake_language_server.set_request_handler::<LspExtExpandMacro, _, _>(
|
||||
|params, _| async move {
|
||||
assert_eq!(
|
||||
params.text_document.uri,
|
||||
lsp::Url::from_file_path(path!("/a/main.rs")).unwrap(),
|
||||
@@ -2749,7 +2750,8 @@ async fn test_client_can_query_lsp_ext(cx_a: &mut TestAppContext, cx_b: &mut Tes
|
||||
name: "test_macro_name".to_string(),
|
||||
expansion: "test_macro_expansion on the client".to_string(),
|
||||
}))
|
||||
});
|
||||
},
|
||||
);
|
||||
|
||||
editor_b.update_in(cx_b, |editor, window, cx| {
|
||||
expand_macro_recursively(editor, &ExpandMacroRecursively, window, cx)
|
||||
|
||||
@@ -2902,7 +2902,7 @@ async fn test_git_branch_name(
|
||||
.read(cx)
|
||||
.branch
|
||||
.as_ref()
|
||||
.map(|branch| branch.name.to_string()),
|
||||
.map(|branch| branch.name().to_owned()),
|
||||
branch_name
|
||||
)
|
||||
}
|
||||
@@ -6864,7 +6864,7 @@ async fn test_remote_git_branches(
|
||||
|
||||
let branches_b = branches_b
|
||||
.into_iter()
|
||||
.map(|branch| branch.name.to_string())
|
||||
.map(|branch| branch.name().to_string())
|
||||
.collect::<HashSet<_>>();
|
||||
|
||||
assert_eq!(branches_b, branches_set);
|
||||
@@ -6895,7 +6895,7 @@ async fn test_remote_git_branches(
|
||||
})
|
||||
});
|
||||
|
||||
assert_eq!(host_branch.name, branches[2]);
|
||||
assert_eq!(host_branch.name(), branches[2]);
|
||||
|
||||
// Also try creating a new branch
|
||||
cx_b.update(|cx| {
|
||||
@@ -6933,5 +6933,5 @@ async fn test_remote_git_branches(
|
||||
})
|
||||
});
|
||||
|
||||
assert_eq!(host_branch.name, "totally-new-branch");
|
||||
assert_eq!(host_branch.name(), "totally-new-branch");
|
||||
}
|
||||
|
||||
@@ -293,7 +293,7 @@ async fn test_ssh_collaboration_git_branches(
|
||||
|
||||
let branches_b = branches_b
|
||||
.into_iter()
|
||||
.map(|branch| branch.name.to_string())
|
||||
.map(|branch| branch.name().to_string())
|
||||
.collect::<HashSet<_>>();
|
||||
|
||||
assert_eq!(&branches_b, &branches_set);
|
||||
@@ -326,7 +326,7 @@ async fn test_ssh_collaboration_git_branches(
|
||||
})
|
||||
});
|
||||
|
||||
assert_eq!(server_branch.name, branches[2]);
|
||||
assert_eq!(server_branch.name(), branches[2]);
|
||||
|
||||
// Also try creating a new branch
|
||||
cx_b.update(|cx| {
|
||||
@@ -366,7 +366,7 @@ async fn test_ssh_collaboration_git_branches(
|
||||
})
|
||||
});
|
||||
|
||||
assert_eq!(server_branch.name, "totally-new-branch");
|
||||
assert_eq!(server_branch.name(), "totally-new-branch");
|
||||
|
||||
// Remove the git repository and check that all participants get the update.
|
||||
remote_fs
|
||||
|
||||
@@ -554,9 +554,6 @@ impl TestServer {
|
||||
migrations_path: None,
|
||||
seed_path: None,
|
||||
stripe_api_key: None,
|
||||
stripe_zed_pro_price_id: None,
|
||||
stripe_zed_pro_trial_price_id: None,
|
||||
stripe_zed_free_price_id: None,
|
||||
supermaven_admin_api_key: None,
|
||||
user_backfiller_github_access_token: None,
|
||||
kinesis_region: None,
|
||||
|
||||
@@ -15,18 +15,21 @@ path = "src/component_preview.rs"
|
||||
default = []
|
||||
|
||||
[dependencies]
|
||||
agent.workspace = true
|
||||
anyhow.workspace = true
|
||||
client.workspace = true
|
||||
collections.workspace = true
|
||||
component.workspace = true
|
||||
db.workspace = true
|
||||
gpui.workspace = true
|
||||
languages.workspace = true
|
||||
notifications.workspace = true
|
||||
log.workspace = true
|
||||
notifications.workspace = true
|
||||
project.workspace = true
|
||||
prompt_store.workspace = true
|
||||
serde.workspace = true
|
||||
ui.workspace = true
|
||||
ui_input.workspace = true
|
||||
workspace-hack.workspace = true
|
||||
workspace.workspace = true
|
||||
db.workspace = true
|
||||
anyhow.workspace = true
|
||||
serde.workspace = true
|
||||
assistant_tool.workspace = true
|
||||
|
||||
@@ -3,10 +3,12 @@
|
||||
//! A view for exploring Zed components.
|
||||
|
||||
mod persistence;
|
||||
mod preview_support;
|
||||
|
||||
use std::iter::Iterator;
|
||||
use std::sync::Arc;
|
||||
|
||||
use agent::{ActiveThread, ThreadStore};
|
||||
use client::UserStore;
|
||||
use component::{ComponentId, ComponentMetadata, components};
|
||||
use gpui::{
|
||||
@@ -19,6 +21,7 @@ use gpui::{ListState, ScrollHandle, ScrollStrategy, UniformListScrollHandle};
|
||||
use languages::LanguageRegistry;
|
||||
use notifications::status_toast::{StatusToast, ToastIcon};
|
||||
use persistence::COMPONENT_PREVIEW_DB;
|
||||
use preview_support::active_thread::{load_preview_thread_store, static_active_thread};
|
||||
use project::Project;
|
||||
use ui::{Divider, HighlightedLabel, ListItem, ListSubHeader, prelude::*};
|
||||
|
||||
@@ -33,6 +36,7 @@ pub fn init(app_state: Arc<AppState>, cx: &mut App) {
|
||||
|
||||
cx.observe_new(move |workspace: &mut Workspace, _window, cx| {
|
||||
let app_state = app_state.clone();
|
||||
let project = workspace.project().clone();
|
||||
let weak_workspace = cx.entity().downgrade();
|
||||
|
||||
workspace.register_action(
|
||||
@@ -45,6 +49,7 @@ pub fn init(app_state: Arc<AppState>, cx: &mut App) {
|
||||
let component_preview = cx.new(|cx| {
|
||||
ComponentPreview::new(
|
||||
weak_workspace.clone(),
|
||||
project.clone(),
|
||||
language_registry,
|
||||
user_store,
|
||||
None,
|
||||
@@ -52,6 +57,7 @@ pub fn init(app_state: Arc<AppState>, cx: &mut App) {
|
||||
window,
|
||||
cx,
|
||||
)
|
||||
.expect("Failed to create component preview")
|
||||
});
|
||||
|
||||
workspace.add_item_to_active_pane(
|
||||
@@ -69,6 +75,7 @@ pub fn init(app_state: Arc<AppState>, cx: &mut App) {
|
||||
|
||||
enum PreviewEntry {
|
||||
AllComponents,
|
||||
ActiveThread,
|
||||
Separator,
|
||||
Component(ComponentMetadata, Option<Vec<usize>>),
|
||||
SectionHeader(SharedString),
|
||||
@@ -91,6 +98,7 @@ enum PreviewPage {
|
||||
#[default]
|
||||
AllComponents,
|
||||
Component(ComponentId),
|
||||
ActiveThread,
|
||||
}
|
||||
|
||||
struct ComponentPreview {
|
||||
@@ -102,24 +110,63 @@ struct ComponentPreview {
|
||||
active_page: PreviewPage,
|
||||
components: Vec<ComponentMetadata>,
|
||||
component_list: ListState,
|
||||
agent_previews: Vec<
|
||||
Box<
|
||||
dyn Fn(
|
||||
&Self,
|
||||
WeakEntity<Workspace>,
|
||||
Entity<ActiveThread>,
|
||||
WeakEntity<ThreadStore>,
|
||||
&mut Window,
|
||||
&mut App,
|
||||
) -> Option<AnyElement>,
|
||||
>,
|
||||
>,
|
||||
cursor_index: usize,
|
||||
language_registry: Arc<LanguageRegistry>,
|
||||
workspace: WeakEntity<Workspace>,
|
||||
project: Entity<Project>,
|
||||
user_store: Entity<UserStore>,
|
||||
filter_editor: Entity<SingleLineInput>,
|
||||
filter_text: String,
|
||||
|
||||
// preview support
|
||||
thread_store: Option<Entity<ThreadStore>>,
|
||||
active_thread: Option<Entity<ActiveThread>>,
|
||||
}
|
||||
|
||||
impl ComponentPreview {
|
||||
pub fn new(
|
||||
workspace: WeakEntity<Workspace>,
|
||||
project: Entity<Project>,
|
||||
language_registry: Arc<LanguageRegistry>,
|
||||
user_store: Entity<UserStore>,
|
||||
selected_index: impl Into<Option<usize>>,
|
||||
active_page: Option<PreviewPage>,
|
||||
window: &mut Window,
|
||||
cx: &mut Context<Self>,
|
||||
) -> Self {
|
||||
) -> anyhow::Result<Self> {
|
||||
let workspace_clone = workspace.clone();
|
||||
let project_clone = project.clone();
|
||||
|
||||
let entity = cx.weak_entity();
|
||||
window
|
||||
.spawn(cx, async move |cx| {
|
||||
let thread_store_task =
|
||||
load_preview_thread_store(workspace_clone.clone(), project_clone.clone(), cx)
|
||||
.await;
|
||||
|
||||
if let Ok(thread_store) = thread_store_task.await {
|
||||
entity
|
||||
.update_in(cx, |this, window, cx| {
|
||||
this.thread_store = Some(thread_store.clone());
|
||||
this.create_active_thread(window, cx);
|
||||
})
|
||||
.ok();
|
||||
}
|
||||
})
|
||||
.detach();
|
||||
|
||||
let sorted_components = components().all_sorted();
|
||||
let selected_index = selected_index.into().unwrap_or(0);
|
||||
let active_page = active_page.unwrap_or(PreviewPage::AllComponents);
|
||||
@@ -143,6 +190,40 @@ impl ComponentPreview {
|
||||
},
|
||||
);
|
||||
|
||||
// Initialize agent previews
|
||||
let agent_previews = agent::all_agent_previews()
|
||||
.into_iter()
|
||||
.map(|id| {
|
||||
Box::new(
|
||||
move |_self: &ComponentPreview,
|
||||
workspace: WeakEntity<Workspace>,
|
||||
active_thread: Entity<ActiveThread>,
|
||||
thread_store: WeakEntity<ThreadStore>,
|
||||
window: &mut Window,
|
||||
cx: &mut App| {
|
||||
agent::get_agent_preview(
|
||||
&id,
|
||||
workspace,
|
||||
active_thread,
|
||||
thread_store,
|
||||
window,
|
||||
cx,
|
||||
)
|
||||
},
|
||||
)
|
||||
as Box<
|
||||
dyn Fn(
|
||||
&ComponentPreview,
|
||||
WeakEntity<Workspace>,
|
||||
Entity<ActiveThread>,
|
||||
WeakEntity<ThreadStore>,
|
||||
&mut Window,
|
||||
&mut App,
|
||||
) -> Option<AnyElement>,
|
||||
>
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let mut component_preview = Self {
|
||||
workspace_id: None,
|
||||
focus_handle: cx.focus_handle(),
|
||||
@@ -151,13 +232,17 @@ impl ComponentPreview {
|
||||
language_registry,
|
||||
user_store,
|
||||
workspace,
|
||||
project,
|
||||
active_page,
|
||||
component_map: components().0,
|
||||
components: sorted_components,
|
||||
component_list,
|
||||
agent_previews,
|
||||
cursor_index: selected_index,
|
||||
filter_editor,
|
||||
filter_text: String::new(),
|
||||
thread_store: None,
|
||||
active_thread: None,
|
||||
};
|
||||
|
||||
if component_preview.cursor_index > 0 {
|
||||
@@ -169,13 +254,41 @@ impl ComponentPreview {
|
||||
let focus_handle = component_preview.filter_editor.read(cx).focus_handle(cx);
|
||||
window.focus(&focus_handle);
|
||||
|
||||
component_preview
|
||||
Ok(component_preview)
|
||||
}
|
||||
|
||||
pub fn create_active_thread(
|
||||
&mut self,
|
||||
window: &mut Window,
|
||||
cx: &mut Context<Self>,
|
||||
) -> &mut Self {
|
||||
let workspace = self.workspace.clone();
|
||||
let language_registry = self.language_registry.clone();
|
||||
let weak_handle = self.workspace.clone();
|
||||
if let Some(workspace) = workspace.upgrade() {
|
||||
let project = workspace.read(cx).project().clone();
|
||||
if let Some(thread_store) = self.thread_store.clone() {
|
||||
let active_thread = static_active_thread(
|
||||
weak_handle,
|
||||
project,
|
||||
language_registry,
|
||||
thread_store,
|
||||
window,
|
||||
cx,
|
||||
);
|
||||
self.active_thread = Some(active_thread);
|
||||
cx.notify();
|
||||
}
|
||||
}
|
||||
|
||||
self
|
||||
}
|
||||
|
||||
pub fn active_page_id(&self, _cx: &App) -> ActivePageId {
|
||||
match &self.active_page {
|
||||
PreviewPage::AllComponents => ActivePageId::default(),
|
||||
PreviewPage::Component(component_id) => ActivePageId(component_id.0.to_string()),
|
||||
PreviewPage::ActiveThread => ActivePageId("active_thread".to_string()),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -289,6 +402,7 @@ impl ComponentPreview {
|
||||
|
||||
// Always show all components first
|
||||
entries.push(PreviewEntry::AllComponents);
|
||||
entries.push(PreviewEntry::ActiveThread);
|
||||
entries.push(PreviewEntry::Separator);
|
||||
|
||||
let mut scopes: Vec<_> = scope_groups
|
||||
@@ -389,6 +503,19 @@ impl ComponentPreview {
|
||||
}))
|
||||
.into_any_element()
|
||||
}
|
||||
PreviewEntry::ActiveThread => {
|
||||
let selected = self.active_page == PreviewPage::ActiveThread;
|
||||
|
||||
ListItem::new(ix)
|
||||
.child(Label::new("Active Thread").color(Color::Default))
|
||||
.selectable(true)
|
||||
.toggle_state(selected)
|
||||
.inset(true)
|
||||
.on_click(cx.listener(move |this, _, _, cx| {
|
||||
this.set_active_page(PreviewPage::ActiveThread, cx);
|
||||
}))
|
||||
.into_any_element()
|
||||
}
|
||||
PreviewEntry::Separator => ListItem::new(ix)
|
||||
.child(
|
||||
h_flex()
|
||||
@@ -471,6 +598,7 @@ impl ComponentPreview {
|
||||
.render_scope_header(ix, shared_string.clone(), window, cx)
|
||||
.into_any_element(),
|
||||
PreviewEntry::AllComponents => div().w_full().h_0().into_any_element(),
|
||||
PreviewEntry::ActiveThread => div().w_full().h_0().into_any_element(),
|
||||
PreviewEntry::Separator => div().w_full().h_0().into_any_element(),
|
||||
})
|
||||
.unwrap()
|
||||
@@ -595,6 +723,41 @@ impl ComponentPreview {
|
||||
}
|
||||
}
|
||||
|
||||
fn render_active_thread(
|
||||
&self,
|
||||
window: &mut Window,
|
||||
cx: &mut Context<Self>,
|
||||
) -> impl IntoElement {
|
||||
v_flex()
|
||||
.id("render-active-thread")
|
||||
.size_full()
|
||||
.child(
|
||||
v_flex().children(self.agent_previews.iter().filter_map(|preview_fn| {
|
||||
if let (Some(thread_store), Some(active_thread)) = (
|
||||
self.thread_store.as_ref().map(|ts| ts.downgrade()),
|
||||
self.active_thread.clone(),
|
||||
) {
|
||||
preview_fn(
|
||||
self,
|
||||
self.workspace.clone(),
|
||||
active_thread,
|
||||
thread_store,
|
||||
window,
|
||||
cx,
|
||||
)
|
||||
.map(|element| div().child(element))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})),
|
||||
)
|
||||
.children(self.active_thread.clone().map(|thread| thread.clone()))
|
||||
.when_none(&self.active_thread.clone(), |this| {
|
||||
this.child("No active thread")
|
||||
})
|
||||
.into_any_element()
|
||||
}
|
||||
|
||||
fn test_status_toast(&self, cx: &mut Context<Self>) {
|
||||
if let Some(workspace) = self.workspace.upgrade() {
|
||||
workspace.update(cx, |workspace, cx| {
|
||||
@@ -704,6 +867,9 @@ impl Render for ComponentPreview {
|
||||
PreviewPage::Component(id) => self
|
||||
.render_component_page(&id, window, cx)
|
||||
.into_any_element(),
|
||||
PreviewPage::ActiveThread => {
|
||||
self.render_active_thread(window, cx).into_any_element()
|
||||
}
|
||||
}),
|
||||
)
|
||||
}
|
||||
@@ -759,20 +925,28 @@ impl Item for ComponentPreview {
|
||||
let language_registry = self.language_registry.clone();
|
||||
let user_store = self.user_store.clone();
|
||||
let weak_workspace = self.workspace.clone();
|
||||
let project = self.project.clone();
|
||||
let selected_index = self.cursor_index;
|
||||
let active_page = self.active_page.clone();
|
||||
|
||||
Some(cx.new(|cx| {
|
||||
Self::new(
|
||||
weak_workspace,
|
||||
language_registry,
|
||||
user_store,
|
||||
selected_index,
|
||||
Some(active_page),
|
||||
window,
|
||||
cx,
|
||||
)
|
||||
}))
|
||||
let self_result = Self::new(
|
||||
weak_workspace,
|
||||
project,
|
||||
language_registry,
|
||||
user_store,
|
||||
selected_index,
|
||||
Some(active_page),
|
||||
window,
|
||||
cx,
|
||||
);
|
||||
|
||||
match self_result {
|
||||
Ok(preview) => Some(cx.new(|_cx| preview)),
|
||||
Err(e) => {
|
||||
log::error!("Failed to clone component preview: {}", e);
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn to_item_events(event: &Self::Event, mut f: impl FnMut(workspace::item::ItemEvent)) {
|
||||
@@ -838,10 +1012,12 @@ impl SerializableItem for ComponentPreview {
|
||||
let user_store = user_store.clone();
|
||||
let language_registry = language_registry.clone();
|
||||
let weak_workspace = workspace.clone();
|
||||
let project = project.clone();
|
||||
cx.update(move |window, cx| {
|
||||
Ok(cx.new(|cx| {
|
||||
ComponentPreview::new(
|
||||
weak_workspace,
|
||||
project,
|
||||
language_registry,
|
||||
user_store,
|
||||
None,
|
||||
@@ -849,6 +1025,7 @@ impl SerializableItem for ComponentPreview {
|
||||
window,
|
||||
cx,
|
||||
)
|
||||
.expect("Failed to create component preview")
|
||||
}))
|
||||
})?
|
||||
})
|
||||
|
||||
1
crates/component_preview/src/preview_support.rs
Normal file
1
crates/component_preview/src/preview_support.rs
Normal file
@@ -0,0 +1 @@
|
||||
pub mod active_thread;
|
||||
@@ -0,0 +1,69 @@
|
||||
use languages::LanguageRegistry;
|
||||
use project::Project;
|
||||
use std::sync::Arc;
|
||||
|
||||
use agent::{ActiveThread, ContextStore, MessageSegment, ThreadStore};
|
||||
use assistant_tool::ToolWorkingSet;
|
||||
use gpui::{AppContext, AsyncApp, Entity, Task, WeakEntity};
|
||||
use prompt_store::PromptBuilder;
|
||||
use ui::{App, Window};
|
||||
use workspace::Workspace;
|
||||
|
||||
pub async fn load_preview_thread_store(
|
||||
workspace: WeakEntity<Workspace>,
|
||||
project: Entity<Project>,
|
||||
cx: &mut AsyncApp,
|
||||
) -> Task<anyhow::Result<Entity<ThreadStore>>> {
|
||||
cx.spawn(async move |cx| {
|
||||
workspace
|
||||
.update(cx, |_, cx| {
|
||||
ThreadStore::load(
|
||||
project.clone(),
|
||||
cx.new(|_| ToolWorkingSet::default()),
|
||||
None,
|
||||
Arc::new(PromptBuilder::new(None).unwrap()),
|
||||
cx,
|
||||
)
|
||||
})?
|
||||
.await
|
||||
})
|
||||
}
|
||||
|
||||
pub fn static_active_thread(
|
||||
workspace: WeakEntity<Workspace>,
|
||||
project: Entity<Project>,
|
||||
language_registry: Arc<LanguageRegistry>,
|
||||
thread_store: Entity<ThreadStore>,
|
||||
window: &mut Window,
|
||||
cx: &mut App,
|
||||
) -> Entity<ActiveThread> {
|
||||
let context_store =
|
||||
cx.new(|_| ContextStore::new(project.downgrade(), Some(thread_store.downgrade())));
|
||||
|
||||
let thread = thread_store.update(cx, |thread_store, cx| thread_store.create_thread(cx));
|
||||
thread.update(cx, |thread, cx| {
|
||||
thread.insert_assistant_message(vec![
|
||||
MessageSegment::Text("I'll help you fix the lifetime error in your `cx.spawn` call. When working with async operations in GPUI, there are specific patterns to follow for proper lifetime management.".to_string()),
|
||||
MessageSegment::Text("\n\nLet's look at what's happening in your code:".to_string()),
|
||||
MessageSegment::Text("\n\n---\n\nLet's check the current state of the active_thread.rs file to understand what might have changed:".to_string()),
|
||||
MessageSegment::Text("\n\n---\n\nLooking at the implementation of `load_preview_thread_store` and understanding GPUI's async patterns, here's the issue:".to_string()),
|
||||
MessageSegment::Text("\n\n1. `load_preview_thread_store` returns a `Task<anyhow::Result<Entity<ThreadStore>>>`, which means it's already a task".to_string()),
|
||||
MessageSegment::Text("\n2. When you call this function inside another `spawn` call, you're nesting tasks incorrectly".to_string()),
|
||||
MessageSegment::Text("\n3. The `this` parameter you're trying to use in your closure has the wrong context".to_string()),
|
||||
MessageSegment::Text("\n\nHere's the correct way to implement this:".to_string()),
|
||||
MessageSegment::Text("\n\n---\n\nThe problem is in how you're setting up the async closure and trying to reference variables like `window` and `language_registry` that aren't accessible in that scope.".to_string()),
|
||||
MessageSegment::Text("\n\nHere's how to fix it:".to_string()),
|
||||
], cx);
|
||||
});
|
||||
cx.new(|cx| {
|
||||
ActiveThread::new(
|
||||
thread,
|
||||
thread_store,
|
||||
context_store,
|
||||
language_registry,
|
||||
workspace.clone(),
|
||||
window,
|
||||
cx,
|
||||
)
|
||||
})
|
||||
}
|
||||
@@ -34,3 +34,7 @@ smol.workspace = true
|
||||
url = { workspace = true, features = ["serde"] }
|
||||
util.workspace = true
|
||||
workspace-hack.workspace = true
|
||||
|
||||
[dev-dependencies]
|
||||
gpui = { workspace = true, features = ["test-support"] }
|
||||
project = { workspace = true, features = ["test-support"] }
|
||||
|
||||
@@ -140,7 +140,7 @@ impl Client {
|
||||
/// This function initializes a new Client by spawning a child process for the context server,
|
||||
/// setting up communication channels, and initializing handlers for input/output operations.
|
||||
/// It takes a server ID, binary information, and an async app context as input.
|
||||
pub fn new(
|
||||
pub fn stdio(
|
||||
server_id: ContextServerId,
|
||||
binary: ModelContextServerBinary,
|
||||
cx: AsyncApp,
|
||||
@@ -158,7 +158,16 @@ impl Client {
|
||||
.unwrap_or_else(String::new);
|
||||
|
||||
let transport = Arc::new(StdioTransport::new(binary, &cx)?);
|
||||
Self::new(server_id, server_name.into(), transport, cx)
|
||||
}
|
||||
|
||||
/// Creates a new Client instance for a context server.
|
||||
pub fn new(
|
||||
server_id: ContextServerId,
|
||||
server_name: Arc<str>,
|
||||
transport: Arc<dyn Transport>,
|
||||
cx: AsyncApp,
|
||||
) -> Result<Self> {
|
||||
let (outbound_tx, outbound_rx) = channel::unbounded::<String>();
|
||||
let (output_done_tx, output_done_rx) = barrier::channel();
|
||||
|
||||
@@ -167,7 +176,7 @@ impl Client {
|
||||
let response_handlers =
|
||||
Arc::new(Mutex::new(Some(HashMap::<_, ResponseHandler>::default())));
|
||||
|
||||
let stdout_input_task = cx.spawn({
|
||||
let receive_input_task = cx.spawn({
|
||||
let notification_handlers = notification_handlers.clone();
|
||||
let response_handlers = response_handlers.clone();
|
||||
let transport = transport.clone();
|
||||
@@ -177,13 +186,13 @@ impl Client {
|
||||
.await
|
||||
}
|
||||
});
|
||||
let stderr_input_task = cx.spawn({
|
||||
let receive_err_task = cx.spawn({
|
||||
let transport = transport.clone();
|
||||
async move |_| Self::handle_stderr(transport).log_err().await
|
||||
async move |_| Self::handle_err(transport).log_err().await
|
||||
});
|
||||
let input_task = cx.spawn(async move |_| {
|
||||
let (stdout, stderr) = futures::join!(stdout_input_task, stderr_input_task);
|
||||
stdout.or(stderr)
|
||||
let (input, err) = futures::join!(receive_input_task, receive_err_task);
|
||||
input.or(err)
|
||||
});
|
||||
|
||||
let output_task = cx.background_spawn({
|
||||
@@ -201,7 +210,7 @@ impl Client {
|
||||
server_id,
|
||||
notification_handlers,
|
||||
response_handlers,
|
||||
name: server_name.into(),
|
||||
name: server_name,
|
||||
next_id: Default::default(),
|
||||
outbound_tx,
|
||||
executor: cx.background_executor().clone(),
|
||||
@@ -247,7 +256,7 @@ impl Client {
|
||||
|
||||
/// Handles the stderr output from the context server.
|
||||
/// Continuously reads and logs any error messages from the server.
|
||||
async fn handle_stderr(transport: Arc<dyn Transport>) -> anyhow::Result<()> {
|
||||
async fn handle_err(transport: Arc<dyn Transport>) -> anyhow::Result<()> {
|
||||
while let Some(err) = transport.receive_err().next().await {
|
||||
log::warn!("context server stderr: {}", err.trim());
|
||||
}
|
||||
|
||||
@@ -12,7 +12,7 @@ pub use context_server_settings::{ContextServerSettings, ServerCommand, ServerCo
|
||||
use gpui::{App, actions};
|
||||
|
||||
pub use crate::context_server_tool::ContextServerTool;
|
||||
pub use crate::registry::ContextServerFactoryRegistry;
|
||||
pub use crate::registry::ContextServerDescriptorRegistry;
|
||||
|
||||
actions!(context_servers, [Restart]);
|
||||
|
||||
@@ -21,7 +21,7 @@ pub const CONTEXT_SERVERS_NAMESPACE: &'static str = "context_servers";
|
||||
|
||||
pub fn init(cx: &mut App) {
|
||||
context_server_settings::init(cx);
|
||||
ContextServerFactoryRegistry::default_global(cx);
|
||||
ContextServerDescriptorRegistry::default_global(cx);
|
||||
extension_context_server::init(cx);
|
||||
|
||||
CommandPaletteFilter::update_global(cx, |filter, _cx| {
|
||||
|
||||
@@ -1,9 +1,21 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use extension::{Extension, ExtensionContextServerProxy, ExtensionHostProxy, ProjectDelegate};
|
||||
use gpui::{App, Entity};
|
||||
use anyhow::Result;
|
||||
use extension::{
|
||||
ContextServerConfiguration, Extension, ExtensionContextServerProxy, ExtensionHostProxy,
|
||||
ProjectDelegate,
|
||||
};
|
||||
use gpui::{App, AsyncApp, Entity, Task};
|
||||
use project::Project;
|
||||
|
||||
use crate::{ContextServerFactoryRegistry, ServerCommand};
|
||||
use crate::{ContextServerDescriptorRegistry, ServerCommand, registry};
|
||||
|
||||
pub fn init(cx: &mut App) {
|
||||
let proxy = ExtensionHostProxy::default_global(cx);
|
||||
proxy.register_context_server_proxy(ContextServerDescriptorRegistryProxy {
|
||||
context_server_factory_registry: ContextServerDescriptorRegistry::global(cx),
|
||||
});
|
||||
}
|
||||
|
||||
struct ExtensionProject {
|
||||
worktree_ids: Vec<u64>,
|
||||
@@ -15,60 +27,78 @@ impl ProjectDelegate for ExtensionProject {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn init(cx: &mut App) {
|
||||
let proxy = ExtensionHostProxy::default_global(cx);
|
||||
proxy.register_context_server_proxy(ContextServerFactoryRegistryProxy {
|
||||
context_server_factory_registry: ContextServerFactoryRegistry::global(cx),
|
||||
});
|
||||
struct ContextServerDescriptor {
|
||||
id: Arc<str>,
|
||||
extension: Arc<dyn Extension>,
|
||||
}
|
||||
|
||||
struct ContextServerFactoryRegistryProxy {
|
||||
context_server_factory_registry: Entity<ContextServerFactoryRegistry>,
|
||||
fn extension_project(project: Entity<Project>, cx: &mut AsyncApp) -> Result<Arc<ExtensionProject>> {
|
||||
project.update(cx, |project, cx| {
|
||||
Arc::new(ExtensionProject {
|
||||
worktree_ids: project
|
||||
.visible_worktrees(cx)
|
||||
.map(|worktree| worktree.read(cx).id().to_proto())
|
||||
.collect(),
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
impl ExtensionContextServerProxy for ContextServerFactoryRegistryProxy {
|
||||
impl registry::ContextServerDescriptor for ContextServerDescriptor {
|
||||
fn command(&self, project: Entity<Project>, cx: &AsyncApp) -> Task<Result<ServerCommand>> {
|
||||
let id = self.id.clone();
|
||||
let extension = self.extension.clone();
|
||||
cx.spawn(async move |cx| {
|
||||
let extension_project = extension_project(project, cx)?;
|
||||
let mut command = extension
|
||||
.context_server_command(id.clone(), extension_project.clone())
|
||||
.await?;
|
||||
command.command = extension
|
||||
.path_from_extension(command.command.as_ref())
|
||||
.to_string_lossy()
|
||||
.to_string();
|
||||
|
||||
log::info!("loaded command for context server {id}: {command:?}");
|
||||
|
||||
Ok(ServerCommand {
|
||||
path: command.command,
|
||||
args: command.args,
|
||||
env: Some(command.env.into_iter().collect()),
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
fn configuration(
|
||||
&self,
|
||||
project: Entity<Project>,
|
||||
cx: &AsyncApp,
|
||||
) -> Task<Result<Option<ContextServerConfiguration>>> {
|
||||
let id = self.id.clone();
|
||||
let extension = self.extension.clone();
|
||||
cx.spawn(async move |cx| {
|
||||
let extension_project = extension_project(project, cx)?;
|
||||
let configuration = extension
|
||||
.context_server_configuration(id.clone(), extension_project)
|
||||
.await?;
|
||||
|
||||
log::debug!("loaded configuration for context server {id}: {configuration:?}");
|
||||
|
||||
Ok(configuration)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
struct ContextServerDescriptorRegistryProxy {
|
||||
context_server_factory_registry: Entity<ContextServerDescriptorRegistry>,
|
||||
}
|
||||
|
||||
impl ExtensionContextServerProxy for ContextServerDescriptorRegistryProxy {
|
||||
fn register_context_server(&self, extension: Arc<dyn Extension>, id: Arc<str>, cx: &mut App) {
|
||||
self.context_server_factory_registry
|
||||
.update(cx, |registry, _| {
|
||||
registry.register_server_factory(
|
||||
registry.register_context_server_descriptor(
|
||||
id.clone(),
|
||||
Arc::new({
|
||||
move |project, cx| {
|
||||
log::info!(
|
||||
"loading command for context server {id} from extension {}",
|
||||
extension.manifest().id
|
||||
);
|
||||
|
||||
let id = id.clone();
|
||||
let extension = extension.clone();
|
||||
cx.spawn(async move |cx| {
|
||||
let extension_project = project.update(cx, |project, cx| {
|
||||
Arc::new(ExtensionProject {
|
||||
worktree_ids: project
|
||||
.visible_worktrees(cx)
|
||||
.map(|worktree| worktree.read(cx).id().to_proto())
|
||||
.collect(),
|
||||
})
|
||||
})?;
|
||||
|
||||
let mut command = extension
|
||||
.context_server_command(id.clone(), extension_project)
|
||||
.await?;
|
||||
command.command = extension
|
||||
.path_from_extension(command.command.as_ref())
|
||||
.to_string_lossy()
|
||||
.to_string();
|
||||
|
||||
log::info!("loaded command for context server {id}: {command:?}");
|
||||
|
||||
Ok(ServerCommand {
|
||||
path: command.command,
|
||||
args: command.args,
|
||||
env: Some(command.env.into_iter().collect()),
|
||||
})
|
||||
})
|
||||
}
|
||||
}),
|
||||
Arc::new(ContextServerDescriptor { id, extension })
|
||||
as Arc<dyn registry::ContextServerDescriptor>,
|
||||
)
|
||||
});
|
||||
}
|
||||
|
||||
@@ -27,18 +27,27 @@ use project::Project;
|
||||
use settings::{Settings, SettingsStore};
|
||||
use util::ResultExt as _;
|
||||
|
||||
use crate::transport::Transport;
|
||||
use crate::{ContextServerSettings, ServerConfig};
|
||||
|
||||
use crate::{
|
||||
CONTEXT_SERVERS_NAMESPACE, ContextServerFactoryRegistry,
|
||||
CONTEXT_SERVERS_NAMESPACE, ContextServerDescriptorRegistry,
|
||||
client::{self, Client},
|
||||
types,
|
||||
};
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
|
||||
pub enum ContextServerStatus {
|
||||
Starting,
|
||||
Running,
|
||||
Error(Arc<str>),
|
||||
}
|
||||
|
||||
pub struct ContextServer {
|
||||
pub id: Arc<str>,
|
||||
pub config: Arc<ServerConfig>,
|
||||
pub client: RwLock<Option<Arc<crate::protocol::InitializedContextServerProtocol>>>,
|
||||
transport: Option<Arc<dyn Transport>>,
|
||||
}
|
||||
|
||||
impl ContextServer {
|
||||
@@ -47,9 +56,20 @@ impl ContextServer {
|
||||
id,
|
||||
config,
|
||||
client: RwLock::new(None),
|
||||
transport: None,
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(any(test, feature = "test-support"))]
|
||||
pub fn test(id: Arc<str>, transport: Arc<dyn crate::transport::Transport>) -> Arc<Self> {
|
||||
Arc::new(Self {
|
||||
id,
|
||||
client: RwLock::new(None),
|
||||
config: Arc::new(ServerConfig::default()),
|
||||
transport: Some(transport),
|
||||
})
|
||||
}
|
||||
|
||||
pub fn id(&self) -> Arc<str> {
|
||||
self.id.clone()
|
||||
}
|
||||
@@ -63,20 +83,32 @@ impl ContextServer {
|
||||
}
|
||||
|
||||
pub async fn start(self: Arc<Self>, cx: &AsyncApp) -> Result<()> {
|
||||
log::info!("starting context server {}", self.id);
|
||||
let Some(command) = &self.config.command else {
|
||||
bail!("no command specified for server {}", self.id);
|
||||
let client = if let Some(transport) = self.transport.clone() {
|
||||
Client::new(
|
||||
client::ContextServerId(self.id.clone()),
|
||||
self.id(),
|
||||
transport,
|
||||
cx.clone(),
|
||||
)?
|
||||
} else {
|
||||
let Some(command) = &self.config.command else {
|
||||
bail!("no command specified for server {}", self.id);
|
||||
};
|
||||
Client::stdio(
|
||||
client::ContextServerId(self.id.clone()),
|
||||
client::ModelContextServerBinary {
|
||||
executable: Path::new(&command.path).to_path_buf(),
|
||||
args: command.args.clone(),
|
||||
env: command.env.clone(),
|
||||
},
|
||||
cx.clone(),
|
||||
)?
|
||||
};
|
||||
let client = Client::new(
|
||||
client::ContextServerId(self.id.clone()),
|
||||
client::ModelContextServerBinary {
|
||||
executable: Path::new(&command.path).to_path_buf(),
|
||||
args: command.args.clone(),
|
||||
env: command.env.clone(),
|
||||
},
|
||||
cx.clone(),
|
||||
)?;
|
||||
self.initialize(client).await
|
||||
}
|
||||
|
||||
async fn initialize(&self, client: Client) -> Result<()> {
|
||||
log::info!("starting context server {}", self.id);
|
||||
let protocol = crate::protocol::ModelContextProtocol::new(client);
|
||||
let client_info = types::Implementation {
|
||||
name: "Zed".to_string(),
|
||||
@@ -105,23 +137,26 @@ impl ContextServer {
|
||||
|
||||
pub struct ContextServerManager {
|
||||
servers: HashMap<Arc<str>, Arc<ContextServer>>,
|
||||
server_status: HashMap<Arc<str>, ContextServerStatus>,
|
||||
project: Entity<Project>,
|
||||
registry: Entity<ContextServerFactoryRegistry>,
|
||||
registry: Entity<ContextServerDescriptorRegistry>,
|
||||
update_servers_task: Option<Task<Result<()>>>,
|
||||
needs_server_update: bool,
|
||||
_subscriptions: Vec<Subscription>,
|
||||
}
|
||||
|
||||
pub enum Event {
|
||||
ServerStarted { server_id: Arc<str> },
|
||||
ServerStopped { server_id: Arc<str> },
|
||||
ServerStatusChanged {
|
||||
server_id: Arc<str>,
|
||||
status: Option<ContextServerStatus>,
|
||||
},
|
||||
}
|
||||
|
||||
impl EventEmitter<Event> for ContextServerManager {}
|
||||
|
||||
impl ContextServerManager {
|
||||
pub fn new(
|
||||
registry: Entity<ContextServerFactoryRegistry>,
|
||||
registry: Entity<ContextServerDescriptorRegistry>,
|
||||
project: Entity<Project>,
|
||||
cx: &mut Context<Self>,
|
||||
) -> Self {
|
||||
@@ -138,6 +173,7 @@ impl ContextServerManager {
|
||||
registry,
|
||||
needs_server_update: false,
|
||||
servers: HashMap::default(),
|
||||
server_status: HashMap::default(),
|
||||
update_servers_task: None,
|
||||
};
|
||||
this.available_context_servers_changed(cx);
|
||||
@@ -153,7 +189,9 @@ impl ContextServerManager {
|
||||
this.needs_server_update = false;
|
||||
})?;
|
||||
|
||||
Self::maintain_servers(this.clone(), cx).await?;
|
||||
if let Err(err) = Self::maintain_servers(this.clone(), cx).await {
|
||||
log::error!("Error maintaining context servers: {}", err);
|
||||
}
|
||||
|
||||
this.update(cx, |this, cx| {
|
||||
let has_any_context_servers = !this.running_servers().is_empty();
|
||||
@@ -181,52 +219,37 @@ impl ContextServerManager {
|
||||
.cloned()
|
||||
}
|
||||
|
||||
pub fn status_for_server(&self, id: &str) -> Option<ContextServerStatus> {
|
||||
self.server_status.get(id).cloned()
|
||||
}
|
||||
|
||||
pub fn start_server(
|
||||
&self,
|
||||
server: Arc<ContextServer>,
|
||||
cx: &mut Context<Self>,
|
||||
) -> Task<anyhow::Result<()>> {
|
||||
cx.spawn(async move |this, cx| {
|
||||
let id = server.id.clone();
|
||||
server.start(&cx).await?;
|
||||
this.update(cx, |_, cx| cx.emit(Event::ServerStarted { server_id: id }))?;
|
||||
Ok(())
|
||||
})
|
||||
) -> Task<Result<()>> {
|
||||
cx.spawn(async move |this, cx| Self::run_server(this, server, cx).await)
|
||||
}
|
||||
|
||||
pub fn stop_server(
|
||||
&self,
|
||||
&mut self,
|
||||
server: Arc<ContextServer>,
|
||||
cx: &mut Context<Self>,
|
||||
) -> anyhow::Result<()> {
|
||||
server.stop()?;
|
||||
cx.emit(Event::ServerStopped {
|
||||
server_id: server.id(),
|
||||
});
|
||||
) -> Result<()> {
|
||||
server.stop().log_err();
|
||||
self.update_server_status(server.id().clone(), None, cx);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn restart_server(
|
||||
&mut self,
|
||||
id: &Arc<str>,
|
||||
cx: &mut Context<Self>,
|
||||
) -> Task<anyhow::Result<()>> {
|
||||
pub fn restart_server(&mut self, id: &Arc<str>, cx: &mut Context<Self>) -> Task<Result<()>> {
|
||||
let id = id.clone();
|
||||
cx.spawn(async move |this, cx| {
|
||||
if let Some(server) = this.update(cx, |this, _cx| this.servers.remove(&id))? {
|
||||
server.stop()?;
|
||||
let config = server.config();
|
||||
|
||||
this.update(cx, |this, cx| this.stop_server(server, cx))??;
|
||||
let new_server = Arc::new(ContextServer::new(id.clone(), config));
|
||||
new_server.clone().start(&cx).await?;
|
||||
this.update(cx, |this, cx| {
|
||||
this.servers.insert(id.clone(), new_server);
|
||||
cx.emit(Event::ServerStopped {
|
||||
server_id: id.clone(),
|
||||
});
|
||||
cx.emit(Event::ServerStarted {
|
||||
server_id: id.clone(),
|
||||
});
|
||||
})?;
|
||||
Self::run_server(this, new_server, cx).await?;
|
||||
}
|
||||
Ok(())
|
||||
})
|
||||
@@ -263,12 +286,14 @@ impl ContextServerManager {
|
||||
(this.registry.clone(), this.project.clone())
|
||||
})?;
|
||||
|
||||
for (id, factory) in
|
||||
registry.read_with(cx, |registry, _| registry.context_server_factories())?
|
||||
for (id, descriptor) in
|
||||
registry.read_with(cx, |registry, _| registry.context_server_descriptors())?
|
||||
{
|
||||
let config = desired_servers.entry(id).or_default();
|
||||
if config.command.is_none() {
|
||||
if let Some(extension_command) = factory(project.clone(), &cx).await.log_err() {
|
||||
if let Some(extension_command) =
|
||||
descriptor.command(project.clone(), &cx).await.log_err()
|
||||
{
|
||||
config.command = Some(extension_command);
|
||||
}
|
||||
}
|
||||
@@ -290,28 +315,270 @@ impl ContextServerManager {
|
||||
for (id, config) in desired_servers {
|
||||
let existing_config = this.servers.get(&id).map(|server| server.config());
|
||||
if existing_config.as_deref() != Some(&config) {
|
||||
let config = Arc::new(config);
|
||||
let server = Arc::new(ContextServer::new(id.clone(), config));
|
||||
let server = Arc::new(ContextServer::new(id.clone(), Arc::new(config)));
|
||||
servers_to_start.insert(id.clone(), server.clone());
|
||||
let old_server = this.servers.insert(id.clone(), server);
|
||||
if let Some(old_server) = old_server {
|
||||
if let Some(old_server) = this.servers.remove(&id) {
|
||||
servers_to_stop.insert(id, old_server);
|
||||
}
|
||||
}
|
||||
}
|
||||
})?;
|
||||
|
||||
for (id, server) in servers_to_stop {
|
||||
server.stop().log_err();
|
||||
this.update(cx, |_, cx| cx.emit(Event::ServerStopped { server_id: id }))?;
|
||||
for (_, server) in servers_to_stop {
|
||||
this.update(cx, |this, cx| this.stop_server(server, cx).ok())?;
|
||||
}
|
||||
|
||||
for (id, server) in servers_to_start {
|
||||
if server.start(&cx).await.log_err().is_some() {
|
||||
this.update(cx, |_, cx| cx.emit(Event::ServerStarted { server_id: id }))?;
|
||||
}
|
||||
for (_, server) in servers_to_start {
|
||||
Self::run_server(this.clone(), server, cx).await.ok();
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn run_server(
|
||||
this: WeakEntity<Self>,
|
||||
server: Arc<ContextServer>,
|
||||
cx: &mut AsyncApp,
|
||||
) -> Result<()> {
|
||||
let id = server.id();
|
||||
|
||||
this.update(cx, |this, cx| {
|
||||
this.update_server_status(id.clone(), Some(ContextServerStatus::Starting), cx);
|
||||
this.servers.insert(id.clone(), server.clone());
|
||||
})?;
|
||||
|
||||
match server.start(&cx).await {
|
||||
Ok(_) => {
|
||||
log::debug!("`{}` context server started", id);
|
||||
this.update(cx, |this, cx| {
|
||||
this.update_server_status(id.clone(), Some(ContextServerStatus::Running), cx)
|
||||
})?;
|
||||
Ok(())
|
||||
}
|
||||
Err(err) => {
|
||||
log::error!("`{}` context server failed to start\n{}", id, err);
|
||||
this.update(cx, |this, cx| {
|
||||
this.update_server_status(
|
||||
id.clone(),
|
||||
Some(ContextServerStatus::Error(err.to_string().into())),
|
||||
cx,
|
||||
)
|
||||
})?;
|
||||
Err(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn update_server_status(
|
||||
&mut self,
|
||||
id: Arc<str>,
|
||||
status: Option<ContextServerStatus>,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
if let Some(status) = status.clone() {
|
||||
self.server_status.insert(id.clone(), status);
|
||||
} else {
|
||||
self.server_status.remove(&id);
|
||||
}
|
||||
|
||||
cx.emit(Event::ServerStatusChanged {
|
||||
server_id: id,
|
||||
status,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::pin::Pin;
|
||||
|
||||
use crate::types::{
|
||||
Implementation, InitializeResponse, ProtocolVersion, RequestType, ServerCapabilities,
|
||||
};
|
||||
|
||||
use super::*;
|
||||
use futures::{Stream, StreamExt as _, lock::Mutex};
|
||||
use gpui::{AppContext as _, TestAppContext};
|
||||
use project::FakeFs;
|
||||
use serde_json::json;
|
||||
use util::path;
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_context_server_status(cx: &mut TestAppContext) {
|
||||
init_test_settings(cx);
|
||||
let project = create_test_project(cx, json!({"code.rs": ""})).await;
|
||||
|
||||
let registry = cx.new(|_| ContextServerDescriptorRegistry::new());
|
||||
let manager = cx.new(|cx| ContextServerManager::new(registry.clone(), project, cx));
|
||||
|
||||
let server_1_id: Arc<str> = "mcp-1".into();
|
||||
let server_2_id: Arc<str> = "mcp-2".into();
|
||||
|
||||
let transport_1 = Arc::new(FakeTransport::new(
|
||||
|_, request_type, _| match request_type {
|
||||
Some(RequestType::Initialize) => {
|
||||
Some(create_initialize_response("mcp-1".to_string()))
|
||||
}
|
||||
_ => None,
|
||||
},
|
||||
));
|
||||
|
||||
let transport_2 = Arc::new(FakeTransport::new(
|
||||
|_, request_type, _| match request_type {
|
||||
Some(RequestType::Initialize) => {
|
||||
Some(create_initialize_response("mcp-2".to_string()))
|
||||
}
|
||||
_ => None,
|
||||
},
|
||||
));
|
||||
|
||||
let server_1 = ContextServer::test(server_1_id.clone(), transport_1.clone());
|
||||
let server_2 = ContextServer::test(server_2_id.clone(), transport_2.clone());
|
||||
|
||||
manager
|
||||
.update(cx, |manager, cx| manager.start_server(server_1, cx))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
cx.update(|cx| {
|
||||
assert_eq!(
|
||||
manager.read(cx).status_for_server(&server_1_id),
|
||||
Some(ContextServerStatus::Running)
|
||||
);
|
||||
assert_eq!(manager.read(cx).status_for_server(&server_2_id), None);
|
||||
});
|
||||
|
||||
manager
|
||||
.update(cx, |manager, cx| manager.start_server(server_2.clone(), cx))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
cx.update(|cx| {
|
||||
assert_eq!(
|
||||
manager.read(cx).status_for_server(&server_1_id),
|
||||
Some(ContextServerStatus::Running)
|
||||
);
|
||||
assert_eq!(
|
||||
manager.read(cx).status_for_server(&server_2_id),
|
||||
Some(ContextServerStatus::Running)
|
||||
);
|
||||
});
|
||||
|
||||
manager
|
||||
.update(cx, |manager, cx| manager.stop_server(server_2, cx))
|
||||
.unwrap();
|
||||
|
||||
cx.update(|cx| {
|
||||
assert_eq!(
|
||||
manager.read(cx).status_for_server(&server_1_id),
|
||||
Some(ContextServerStatus::Running)
|
||||
);
|
||||
assert_eq!(manager.read(cx).status_for_server(&server_2_id), None);
|
||||
});
|
||||
}
|
||||
|
||||
async fn create_test_project(
|
||||
cx: &mut TestAppContext,
|
||||
files: serde_json::Value,
|
||||
) -> Entity<Project> {
|
||||
let fs = FakeFs::new(cx.executor());
|
||||
fs.insert_tree(path!("/test"), files).await;
|
||||
Project::test(fs, [path!("/test").as_ref()], cx).await
|
||||
}
|
||||
|
||||
fn init_test_settings(cx: &mut TestAppContext) {
|
||||
cx.update(|cx| {
|
||||
let settings_store = SettingsStore::test(cx);
|
||||
cx.set_global(settings_store);
|
||||
Project::init_settings(cx);
|
||||
ContextServerSettings::register(cx);
|
||||
});
|
||||
}
|
||||
|
||||
fn create_initialize_response(server_name: String) -> serde_json::Value {
|
||||
serde_json::to_value(&InitializeResponse {
|
||||
protocol_version: ProtocolVersion(types::LATEST_PROTOCOL_VERSION.to_string()),
|
||||
server_info: Implementation {
|
||||
name: server_name,
|
||||
version: "1.0.0".to_string(),
|
||||
},
|
||||
capabilities: ServerCapabilities::default(),
|
||||
meta: None,
|
||||
})
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
struct FakeTransport {
|
||||
on_request: Arc<
|
||||
dyn Fn(u64, Option<RequestType>, serde_json::Value) -> Option<serde_json::Value>
|
||||
+ Send
|
||||
+ Sync,
|
||||
>,
|
||||
tx: futures::channel::mpsc::UnboundedSender<String>,
|
||||
rx: Arc<Mutex<futures::channel::mpsc::UnboundedReceiver<String>>>,
|
||||
}
|
||||
|
||||
impl FakeTransport {
|
||||
fn new(
|
||||
on_request: impl Fn(
|
||||
u64,
|
||||
Option<RequestType>,
|
||||
serde_json::Value,
|
||||
) -> Option<serde_json::Value>
|
||||
+ 'static
|
||||
+ Send
|
||||
+ Sync,
|
||||
) -> Self {
|
||||
let (tx, rx) = futures::channel::mpsc::unbounded();
|
||||
Self {
|
||||
on_request: Arc::new(on_request),
|
||||
tx,
|
||||
rx: Arc::new(Mutex::new(rx)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl Transport for FakeTransport {
|
||||
async fn send(&self, message: String) -> Result<()> {
|
||||
if let Ok(msg) = serde_json::from_str::<serde_json::Value>(&message) {
|
||||
let id = msg.get("id").and_then(|id| id.as_u64()).unwrap_or(0);
|
||||
|
||||
if let Some(method) = msg.get("method") {
|
||||
let request_type = method
|
||||
.as_str()
|
||||
.and_then(|method| types::RequestType::try_from(method).ok());
|
||||
if let Some(payload) = (self.on_request.as_ref())(id, request_type, msg) {
|
||||
let response = serde_json::json!({
|
||||
"jsonrpc": "2.0",
|
||||
"id": id,
|
||||
"result": payload
|
||||
});
|
||||
|
||||
self.tx
|
||||
.unbounded_send(response.to_string())
|
||||
.map_err(|e| anyhow::anyhow!("Failed to send message: {}", e))?;
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn receive(&self) -> Pin<Box<dyn Stream<Item = String> + Send>> {
|
||||
let rx = self.rx.clone();
|
||||
Box::pin(futures::stream::unfold(rx, |rx| async move {
|
||||
let mut rx_guard = rx.lock().await;
|
||||
if let Some(message) = rx_guard.next().await {
|
||||
drop(rx_guard);
|
||||
Some((message, rx))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}))
|
||||
}
|
||||
|
||||
fn receive_err(&self) -> Pin<Box<dyn Stream<Item = String> + Send>> {
|
||||
Box::pin(futures::stream::empty())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,38 +2,47 @@ use std::sync::Arc;
|
||||
|
||||
use anyhow::Result;
|
||||
use collections::HashMap;
|
||||
use extension::ContextServerConfiguration;
|
||||
use gpui::{App, AppContext as _, AsyncApp, Entity, Global, ReadGlobal, Task};
|
||||
use project::Project;
|
||||
|
||||
use crate::ServerCommand;
|
||||
|
||||
pub type ContextServerFactory =
|
||||
Arc<dyn Fn(Entity<Project>, &AsyncApp) -> Task<Result<ServerCommand>> + Send + Sync + 'static>;
|
||||
|
||||
struct GlobalContextServerFactoryRegistry(Entity<ContextServerFactoryRegistry>);
|
||||
|
||||
impl Global for GlobalContextServerFactoryRegistry {}
|
||||
|
||||
#[derive(Default)]
|
||||
pub struct ContextServerFactoryRegistry {
|
||||
context_servers: HashMap<Arc<str>, ContextServerFactory>,
|
||||
pub trait ContextServerDescriptor {
|
||||
fn command(&self, project: Entity<Project>, cx: &AsyncApp) -> Task<Result<ServerCommand>>;
|
||||
fn configuration(
|
||||
&self,
|
||||
project: Entity<Project>,
|
||||
cx: &AsyncApp,
|
||||
) -> Task<Result<Option<ContextServerConfiguration>>>;
|
||||
}
|
||||
|
||||
impl ContextServerFactoryRegistry {
|
||||
/// Returns the global [`ContextServerFactoryRegistry`].
|
||||
struct GlobalContextServerDescriptorRegistry(Entity<ContextServerDescriptorRegistry>);
|
||||
|
||||
impl Global for GlobalContextServerDescriptorRegistry {}
|
||||
|
||||
#[derive(Default)]
|
||||
pub struct ContextServerDescriptorRegistry {
|
||||
context_servers: HashMap<Arc<str>, Arc<dyn ContextServerDescriptor>>,
|
||||
}
|
||||
|
||||
impl ContextServerDescriptorRegistry {
|
||||
/// Returns the global [`ContextServerDescriptorRegistry`].
|
||||
pub fn global(cx: &App) -> Entity<Self> {
|
||||
GlobalContextServerFactoryRegistry::global(cx).0.clone()
|
||||
GlobalContextServerDescriptorRegistry::global(cx).0.clone()
|
||||
}
|
||||
|
||||
/// Returns the global [`ContextServerFactoryRegistry`].
|
||||
/// Returns the global [`ContextServerDescriptorRegistry`].
|
||||
///
|
||||
/// Inserts a default [`ContextServerFactoryRegistry`] if one does not yet exist.
|
||||
/// Inserts a default [`ContextServerDescriptorRegistry`] if one does not yet exist.
|
||||
pub fn default_global(cx: &mut App) -> Entity<Self> {
|
||||
if !cx.has_global::<GlobalContextServerFactoryRegistry>() {
|
||||
if !cx.has_global::<GlobalContextServerDescriptorRegistry>() {
|
||||
let registry = cx.new(|_| Self::new());
|
||||
cx.set_global(GlobalContextServerFactoryRegistry(registry));
|
||||
cx.set_global(GlobalContextServerDescriptorRegistry(registry));
|
||||
}
|
||||
cx.global::<GlobalContextServerFactoryRegistry>().0.clone()
|
||||
cx.global::<GlobalContextServerDescriptorRegistry>()
|
||||
.0
|
||||
.clone()
|
||||
}
|
||||
|
||||
pub fn new() -> Self {
|
||||
@@ -42,20 +51,28 @@ impl ContextServerFactoryRegistry {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn context_server_factories(&self) -> Vec<(Arc<str>, ContextServerFactory)> {
|
||||
pub fn context_server_descriptors(&self) -> Vec<(Arc<str>, Arc<dyn ContextServerDescriptor>)> {
|
||||
self.context_servers
|
||||
.iter()
|
||||
.map(|(id, factory)| (id.clone(), factory.clone()))
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Registers the provided [`ContextServerFactory`].
|
||||
pub fn register_server_factory(&mut self, id: Arc<str>, factory: ContextServerFactory) {
|
||||
self.context_servers.insert(id, factory);
|
||||
pub fn context_server_descriptor(&self, id: &str) -> Option<Arc<dyn ContextServerDescriptor>> {
|
||||
self.context_servers.get(id).cloned()
|
||||
}
|
||||
|
||||
/// Unregisters the [`ContextServerFactory`] for the server with the given ID.
|
||||
pub fn unregister_server_factory_by_id(&mut self, server_id: &str) {
|
||||
/// Registers the provided [`ContextServerDescriptor`].
|
||||
pub fn register_context_server_descriptor(
|
||||
&mut self,
|
||||
id: Arc<str>,
|
||||
descriptor: Arc<dyn ContextServerDescriptor>,
|
||||
) {
|
||||
self.context_servers.insert(id, descriptor);
|
||||
}
|
||||
|
||||
/// Unregisters the [`ContextServerDescriptor`] for the server with the given ID.
|
||||
pub fn unregister_context_server_descriptor_by_id(&mut self, server_id: &str) {
|
||||
self.context_servers.remove(server_id);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -42,6 +42,30 @@ impl RequestType {
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFrom<&str> for RequestType {
|
||||
type Error = ();
|
||||
|
||||
fn try_from(s: &str) -> Result<Self, Self::Error> {
|
||||
match s {
|
||||
"initialize" => Ok(RequestType::Initialize),
|
||||
"tools/call" => Ok(RequestType::CallTool),
|
||||
"resources/unsubscribe" => Ok(RequestType::ResourcesUnsubscribe),
|
||||
"resources/subscribe" => Ok(RequestType::ResourcesSubscribe),
|
||||
"resources/read" => Ok(RequestType::ResourcesRead),
|
||||
"resources/list" => Ok(RequestType::ResourcesList),
|
||||
"logging/setLevel" => Ok(RequestType::LoggingSetLevel),
|
||||
"prompts/get" => Ok(RequestType::PromptsGet),
|
||||
"prompts/list" => Ok(RequestType::PromptsList),
|
||||
"completion/complete" => Ok(RequestType::CompletionComplete),
|
||||
"ping" => Ok(RequestType::Ping),
|
||||
"tools/list" => Ok(RequestType::ListTools),
|
||||
"resources/templates/list" => Ok(RequestType::ListResourceTemplates),
|
||||
"roots/list" => Ok(RequestType::ListRoots),
|
||||
_ => Err(()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[serde(transparent)]
|
||||
pub struct ProtocolVersion(pub String);
|
||||
@@ -154,7 +178,7 @@ pub struct CompletionArgument {
|
||||
pub value: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct InitializeResponse {
|
||||
pub protocol_version: ProtocolVersion,
|
||||
@@ -343,7 +367,7 @@ pub struct ClientCapabilities {
|
||||
pub roots: Option<RootsCapabilities>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[derive(Default, Debug, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct ServerCapabilities {
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
|
||||
@@ -444,10 +444,7 @@ pub trait DebugAdapter: 'static + Send + Sync {
|
||||
log::info!("Getting latest version of debug adapter {}", self.name());
|
||||
delegate.update_status(self.name(), DapStatus::CheckingForUpdate);
|
||||
if let Some(version) = self.fetch_latest_adapter_version(delegate).await.log_err() {
|
||||
log::info!(
|
||||
"Installiing latest version of debug adapter {}",
|
||||
self.name()
|
||||
);
|
||||
log::info!("Installing latest version of debug adapter {}", self.name());
|
||||
delegate.update_status(self.name(), DapStatus::Downloading);
|
||||
match self.install_binary(version, delegate).await {
|
||||
Ok(_) => {
|
||||
|
||||
@@ -7,6 +7,7 @@ use crate::{
|
||||
};
|
||||
use crate::{new_session_modal::NewSessionModal, session::DebugSession};
|
||||
use anyhow::{Result, anyhow};
|
||||
use collections::{HashMap, HashSet};
|
||||
use command_palette_hooks::CommandPaletteFilter;
|
||||
use dap::DebugRequest;
|
||||
use dap::{
|
||||
@@ -26,6 +27,7 @@ use project::{Project, debugger::session::ThreadStatus};
|
||||
use rpc::proto::{self};
|
||||
use settings::Settings;
|
||||
use std::any::TypeId;
|
||||
use std::path::PathBuf;
|
||||
use task::{DebugScenario, TaskContext};
|
||||
use ui::{ContextMenu, Divider, DropdownMenu, Tooltip, prelude::*};
|
||||
use workspace::SplitDirection;
|
||||
@@ -403,7 +405,6 @@ impl DebugPanel {
|
||||
pub fn resolve_scenario(
|
||||
&self,
|
||||
scenario: DebugScenario,
|
||||
|
||||
task_context: TaskContext,
|
||||
buffer: Option<Entity<Buffer>>,
|
||||
window: &Window,
|
||||
@@ -424,8 +425,60 @@ impl DebugPanel {
|
||||
stop_on_entry,
|
||||
} = scenario;
|
||||
let request = if let Some(mut request) = request {
|
||||
// Resolve task variables within the request.
|
||||
if let DebugRequest::Launch(_) = &mut request {}
|
||||
if let DebugRequest::Launch(launch_config) = &mut request {
|
||||
let mut variable_names = HashMap::default();
|
||||
let mut substituted_variables = HashSet::default();
|
||||
let task_variables = task_context
|
||||
.task_variables
|
||||
.iter()
|
||||
.map(|(key, value)| {
|
||||
let key_string = key.to_string();
|
||||
if !variable_names.contains_key(&key_string) {
|
||||
variable_names.insert(key_string.clone(), key.clone());
|
||||
}
|
||||
(key_string, value.as_str())
|
||||
})
|
||||
.collect::<HashMap<_, _>>();
|
||||
|
||||
let cwd = launch_config
|
||||
.cwd
|
||||
.as_ref()
|
||||
.and_then(|cwd| cwd.to_str())
|
||||
.and_then(|cwd| {
|
||||
task::substitute_all_template_variables_in_str(
|
||||
cwd,
|
||||
&task_variables,
|
||||
&variable_names,
|
||||
&mut substituted_variables,
|
||||
)
|
||||
});
|
||||
|
||||
if let Some(cwd) = cwd {
|
||||
launch_config.cwd = Some(PathBuf::from(cwd))
|
||||
}
|
||||
|
||||
if let Some(program) = task::substitute_all_template_variables_in_str(
|
||||
&launch_config.program,
|
||||
&task_variables,
|
||||
&variable_names,
|
||||
&mut substituted_variables,
|
||||
) {
|
||||
launch_config.program = program;
|
||||
}
|
||||
|
||||
for arg in launch_config.args.iter_mut() {
|
||||
if let Some(substituted_arg) =
|
||||
task::substitute_all_template_variables_in_str(
|
||||
&arg,
|
||||
&task_variables,
|
||||
&variable_names,
|
||||
&mut substituted_variables,
|
||||
)
|
||||
{
|
||||
*arg = substituted_arg;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
request
|
||||
} else if let Some(build) = build {
|
||||
@@ -944,6 +997,7 @@ impl DebugPanel {
|
||||
past_debug_definition,
|
||||
weak_panel,
|
||||
workspace,
|
||||
None,
|
||||
window,
|
||||
cx,
|
||||
)
|
||||
|
||||
@@ -158,6 +158,7 @@ pub fn init(cx: &mut App) {
|
||||
debug_panel.read(cx).past_debug_definition.clone(),
|
||||
weak_panel,
|
||||
weak_workspace,
|
||||
None,
|
||||
window,
|
||||
cx,
|
||||
)
|
||||
@@ -166,14 +167,22 @@ pub fn init(cx: &mut App) {
|
||||
},
|
||||
)
|
||||
.register_action(|workspace: &mut Workspace, _: &Start, window, cx| {
|
||||
tasks_ui::toggle_modal(
|
||||
workspace,
|
||||
None,
|
||||
task::TaskModal::DebugModal,
|
||||
window,
|
||||
cx,
|
||||
)
|
||||
.detach();
|
||||
if let Some(debug_panel) = workspace.panel::<DebugPanel>(cx) {
|
||||
let weak_panel = debug_panel.downgrade();
|
||||
let weak_workspace = cx.weak_entity();
|
||||
let task_store = workspace.project().read(cx).task_store().clone();
|
||||
|
||||
workspace.toggle_modal(window, cx, |window, cx| {
|
||||
NewSessionModal::new(
|
||||
debug_panel.read(cx).past_debug_definition.clone(),
|
||||
weak_panel,
|
||||
weak_workspace,
|
||||
Some(task_store),
|
||||
window,
|
||||
cx,
|
||||
)
|
||||
});
|
||||
}
|
||||
});
|
||||
})
|
||||
})
|
||||
|
||||
@@ -6,19 +6,25 @@ use std::{
|
||||
|
||||
use dap::{DapRegistry, DebugRequest, adapters::DebugTaskDefinition};
|
||||
use editor::{Editor, EditorElement, EditorStyle};
|
||||
use fuzzy::{StringMatch, StringMatchCandidate};
|
||||
use gpui::{
|
||||
App, AppContext, DismissEvent, Entity, EventEmitter, FocusHandle, Focusable, Render, TextStyle,
|
||||
WeakEntity,
|
||||
App, AppContext, DismissEvent, Entity, EventEmitter, FocusHandle, Focusable, Render,
|
||||
Subscription, TextStyle, WeakEntity,
|
||||
};
|
||||
use picker::{Picker, PickerDelegate, highlighted_match_with_paths::HighlightedMatch};
|
||||
use project::{TaskSourceKind, task_store::TaskStore};
|
||||
use session_modes::{AttachMode, DebugScenarioDelegate, LaunchMode};
|
||||
use settings::Settings;
|
||||
use task::{DebugScenario, LaunchRequest, TaskContext};
|
||||
use task::{DebugScenario, LaunchRequest};
|
||||
use theme::ThemeSettings;
|
||||
use ui::{
|
||||
ActiveTheme, Button, ButtonCommon, ButtonSize, CheckboxWithLabel, Clickable, Color, Context,
|
||||
ContextMenu, Disableable, DropdownMenu, FluentBuilder, InteractiveElement, IntoElement, Label,
|
||||
LabelCommon as _, ParentElement, RenderOnce, SharedString, Styled, StyledExt, ToggleButton,
|
||||
ToggleState, Toggleable, Window, div, h_flex, relative, rems, v_flex,
|
||||
ContextMenu, Disableable, DropdownMenu, FluentBuilder, Icon, IconName, InteractiveElement,
|
||||
IntoElement, Label, LabelCommon as _, ListItem, ListItemSpacing, ParentElement, RenderOnce,
|
||||
SharedString, Styled, StyledExt, ToggleButton, ToggleState, Toggleable, Window, div, h_flex,
|
||||
relative, rems, v_flex,
|
||||
};
|
||||
use util::ResultExt;
|
||||
use workspace::{ModalView, Workspace};
|
||||
|
||||
use crate::{attach_modal::AttachModal, debugger_panel::DebugPanel};
|
||||
@@ -57,6 +63,7 @@ impl NewSessionModal {
|
||||
past_debug_definition: Option<DebugTaskDefinition>,
|
||||
debug_panel: WeakEntity<DebugPanel>,
|
||||
workspace: WeakEntity<Workspace>,
|
||||
task_store: Option<Entity<TaskStore>>,
|
||||
window: &mut Window,
|
||||
cx: &mut Context<Self>,
|
||||
) -> Self {
|
||||
@@ -73,6 +80,18 @@ impl NewSessionModal {
|
||||
_ => None,
|
||||
};
|
||||
|
||||
if let Some(task_store) = task_store {
|
||||
cx.defer_in(window, |this, window, cx| {
|
||||
this.mode = NewSessionMode::scenario(
|
||||
this.debug_panel.clone(),
|
||||
this.workspace.clone(),
|
||||
task_store,
|
||||
window,
|
||||
cx,
|
||||
);
|
||||
});
|
||||
};
|
||||
|
||||
Self {
|
||||
workspace: workspace.clone(),
|
||||
debugger,
|
||||
@@ -86,10 +105,10 @@ impl NewSessionModal {
|
||||
}
|
||||
}
|
||||
|
||||
fn debug_config(&self, cx: &App, debugger: &str) -> DebugScenario {
|
||||
let request = self.mode.debug_task(cx);
|
||||
fn debug_config(&self, cx: &App, debugger: &str) -> Option<DebugScenario> {
|
||||
let request = self.mode.debug_task(cx)?;
|
||||
let label = suggested_label(&request, debugger);
|
||||
DebugScenario {
|
||||
Some(DebugScenario {
|
||||
adapter: debugger.to_owned().into(),
|
||||
label,
|
||||
request: Some(request),
|
||||
@@ -100,21 +119,42 @@ impl NewSessionModal {
|
||||
_ => None,
|
||||
},
|
||||
build: None,
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
fn start_new_session(&self, window: &mut Window, cx: &mut Context<Self>) {
|
||||
let Some(debugger) = self.debugger.as_ref() else {
|
||||
// todo: show in UI.
|
||||
// todo(debugger): show in UI.
|
||||
log::error!("No debugger selected");
|
||||
return;
|
||||
};
|
||||
let config = self.debug_config(cx, debugger);
|
||||
|
||||
if let NewSessionMode::Scenario(picker) = &self.mode {
|
||||
picker.update(cx, |picker, cx| {
|
||||
picker.delegate.confirm(false, window, cx);
|
||||
});
|
||||
return;
|
||||
}
|
||||
|
||||
let Some(config) = self.debug_config(cx, debugger) else {
|
||||
log::error!("debug config not found in mode: {}", self.mode);
|
||||
return;
|
||||
};
|
||||
|
||||
let debug_panel = self.debug_panel.clone();
|
||||
let workspace = self.workspace.clone();
|
||||
|
||||
cx.spawn_in(window, async move |this, cx| {
|
||||
let task_contexts = workspace
|
||||
.update_in(cx, |workspace, window, cx| {
|
||||
tasks_ui::task_contexts(workspace, window, cx)
|
||||
})?
|
||||
.await;
|
||||
|
||||
let task_context = task_contexts.active_context().cloned().unwrap_or_default();
|
||||
|
||||
debug_panel.update_in(cx, |debug_panel, window, cx| {
|
||||
debug_panel.start_session(config, TaskContext::default(), None, window, cx)
|
||||
debug_panel.start_session(config, task_context, None, window, cx)
|
||||
})?;
|
||||
this.update(cx, |_, cx| {
|
||||
cx.emit(DismissEvent);
|
||||
@@ -256,9 +296,14 @@ impl NewSessionModal {
|
||||
.iter()
|
||||
.flat_map(|task_inventory| {
|
||||
task_inventory.read(cx).list_debug_scenarios(
|
||||
worktree.as_ref().map(|worktree| worktree.read(cx).id()),
|
||||
worktree
|
||||
.as_ref()
|
||||
.map(|worktree| worktree.read(cx).id())
|
||||
.iter()
|
||||
.copied(),
|
||||
)
|
||||
})
|
||||
.map(|(_source_kind, scenario)| scenario)
|
||||
.collect()
|
||||
})
|
||||
.ok()
|
||||
@@ -277,102 +322,22 @@ impl NewSessionModal {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct LaunchMode {
|
||||
program: Entity<Editor>,
|
||||
cwd: Entity<Editor>,
|
||||
}
|
||||
|
||||
impl LaunchMode {
|
||||
fn new(
|
||||
past_launch_config: Option<LaunchRequest>,
|
||||
window: &mut Window,
|
||||
cx: &mut App,
|
||||
) -> Entity<Self> {
|
||||
let (past_program, past_cwd) = past_launch_config
|
||||
.map(|config| (Some(config.program), config.cwd))
|
||||
.unwrap_or_else(|| (None, None));
|
||||
|
||||
let program = cx.new(|cx| Editor::single_line(window, cx));
|
||||
program.update(cx, |this, cx| {
|
||||
this.set_placeholder_text("Program path", cx);
|
||||
|
||||
if let Some(past_program) = past_program {
|
||||
this.set_text(past_program, window, cx);
|
||||
};
|
||||
});
|
||||
let cwd = cx.new(|cx| Editor::single_line(window, cx));
|
||||
cwd.update(cx, |this, cx| {
|
||||
this.set_placeholder_text("Working Directory", cx);
|
||||
if let Some(past_cwd) = past_cwd {
|
||||
this.set_text(past_cwd.to_string_lossy(), window, cx);
|
||||
};
|
||||
});
|
||||
cx.new(|_| Self { program, cwd })
|
||||
}
|
||||
|
||||
fn debug_task(&self, cx: &App) -> task::LaunchRequest {
|
||||
let path = self.cwd.read(cx).text(cx);
|
||||
task::LaunchRequest {
|
||||
program: self.program.read(cx).text(cx),
|
||||
cwd: path.is_empty().not().then(|| PathBuf::from(path)),
|
||||
args: Default::default(),
|
||||
env: Default::default(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct AttachMode {
|
||||
definition: DebugTaskDefinition,
|
||||
attach_picker: Entity<AttachModal>,
|
||||
}
|
||||
|
||||
impl AttachMode {
|
||||
fn new(
|
||||
debugger: Option<SharedString>,
|
||||
workspace: Entity<Workspace>,
|
||||
window: &mut Window,
|
||||
cx: &mut Context<NewSessionModal>,
|
||||
) -> Entity<Self> {
|
||||
let definition = DebugTaskDefinition {
|
||||
adapter: debugger.clone().unwrap_or_default(),
|
||||
label: "Attach New Session Setup".into(),
|
||||
request: dap::DebugRequest::Attach(task::AttachRequest { process_id: None }),
|
||||
initialize_args: None,
|
||||
tcp_connection: None,
|
||||
stop_on_entry: Some(false),
|
||||
};
|
||||
let attach_picker = cx.new(|cx| {
|
||||
let modal = AttachModal::new(definition.clone(), workspace, false, window, cx);
|
||||
window.focus(&modal.focus_handle(cx));
|
||||
|
||||
modal
|
||||
});
|
||||
cx.new(|_| Self {
|
||||
definition,
|
||||
attach_picker,
|
||||
})
|
||||
}
|
||||
fn debug_task(&self) -> task::AttachRequest {
|
||||
task::AttachRequest { process_id: None }
|
||||
}
|
||||
}
|
||||
|
||||
static SELECT_DEBUGGER_LABEL: SharedString = SharedString::new_static("Select Debugger");
|
||||
static SELECT_SCENARIO_LABEL: SharedString = SharedString::new_static("Select Profile");
|
||||
|
||||
#[derive(Clone)]
|
||||
enum NewSessionMode {
|
||||
Launch(Entity<LaunchMode>),
|
||||
Scenario(Entity<Picker<DebugScenarioDelegate>>),
|
||||
Attach(Entity<AttachMode>),
|
||||
}
|
||||
|
||||
impl NewSessionMode {
|
||||
fn debug_task(&self, cx: &App) -> DebugRequest {
|
||||
fn debug_task(&self, cx: &App) -> Option<DebugRequest> {
|
||||
match self {
|
||||
NewSessionMode::Launch(entity) => entity.read(cx).debug_task(cx).into(),
|
||||
NewSessionMode::Attach(entity) => entity.read(cx).debug_task().into(),
|
||||
NewSessionMode::Launch(entity) => Some(entity.read(cx).debug_task(cx).into()),
|
||||
NewSessionMode::Attach(entity) => Some(entity.read(cx).debug_task().into()),
|
||||
NewSessionMode::Scenario(_) => None,
|
||||
}
|
||||
}
|
||||
fn as_attach(&self) -> Option<&Entity<AttachMode>> {
|
||||
@@ -382,6 +347,78 @@ impl NewSessionMode {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
fn scenario(
|
||||
debug_panel: WeakEntity<DebugPanel>,
|
||||
workspace: WeakEntity<Workspace>,
|
||||
task_store: Entity<TaskStore>,
|
||||
window: &mut Window,
|
||||
cx: &mut Context<NewSessionModal>,
|
||||
) -> NewSessionMode {
|
||||
let picker = cx.new(|cx| {
|
||||
Picker::uniform_list(
|
||||
DebugScenarioDelegate::new(debug_panel, workspace, task_store),
|
||||
window,
|
||||
cx,
|
||||
)
|
||||
.modal(false)
|
||||
});
|
||||
|
||||
cx.subscribe(&picker, |_, _, _, cx| {
|
||||
cx.emit(DismissEvent);
|
||||
})
|
||||
.detach();
|
||||
|
||||
picker.focus_handle(cx).focus(window);
|
||||
NewSessionMode::Scenario(picker)
|
||||
}
|
||||
|
||||
fn attach(
|
||||
debugger: Option<SharedString>,
|
||||
workspace: Entity<Workspace>,
|
||||
window: &mut Window,
|
||||
cx: &mut Context<NewSessionModal>,
|
||||
) -> Self {
|
||||
Self::Attach(AttachMode::new(debugger, workspace, window, cx))
|
||||
}
|
||||
|
||||
fn launch(
|
||||
past_launch_config: Option<LaunchRequest>,
|
||||
window: &mut Window,
|
||||
cx: &mut Context<NewSessionModal>,
|
||||
) -> Self {
|
||||
Self::Launch(LaunchMode::new(past_launch_config, window, cx))
|
||||
}
|
||||
|
||||
fn has_match(&self, cx: &App) -> bool {
|
||||
match self {
|
||||
NewSessionMode::Scenario(picker) => picker.read(cx).delegate.match_count() > 0,
|
||||
NewSessionMode::Attach(picker) => {
|
||||
picker
|
||||
.read(cx)
|
||||
.attach_picker
|
||||
.read(cx)
|
||||
.picker
|
||||
.read(cx)
|
||||
.delegate
|
||||
.match_count()
|
||||
> 0
|
||||
}
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Display for NewSessionMode {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
let mode = match self {
|
||||
NewSessionMode::Launch(_) => "launch".to_owned(),
|
||||
NewSessionMode::Attach(_) => "attach".to_owned(),
|
||||
NewSessionMode::Scenario(_) => "scenario picker".to_owned(),
|
||||
};
|
||||
|
||||
write!(f, "{}", mode)
|
||||
}
|
||||
}
|
||||
|
||||
impl Focusable for NewSessionMode {
|
||||
@@ -389,6 +426,7 @@ impl Focusable for NewSessionMode {
|
||||
match &self {
|
||||
NewSessionMode::Launch(entity) => entity.read(cx).program.focus_handle(cx),
|
||||
NewSessionMode::Attach(entity) => entity.read(cx).attach_picker.focus_handle(cx),
|
||||
NewSessionMode::Scenario(entity) => entity.read(cx).focus_handle(cx),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -437,27 +475,14 @@ impl RenderOnce for NewSessionMode {
|
||||
NewSessionMode::Attach(entity) => entity.update(cx, |this, cx| {
|
||||
this.clone().render(window, cx).into_any_element()
|
||||
}),
|
||||
NewSessionMode::Scenario(entity) => v_flex()
|
||||
.w(rems(34.))
|
||||
.child(entity.clone())
|
||||
.into_any_element(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl NewSessionMode {
|
||||
fn attach(
|
||||
debugger: Option<SharedString>,
|
||||
workspace: Entity<Workspace>,
|
||||
window: &mut Window,
|
||||
cx: &mut Context<NewSessionModal>,
|
||||
) -> Self {
|
||||
Self::Attach(AttachMode::new(debugger, workspace, window, cx))
|
||||
}
|
||||
fn launch(
|
||||
past_launch_config: Option<LaunchRequest>,
|
||||
window: &mut Window,
|
||||
cx: &mut Context<NewSessionModal>,
|
||||
) -> Self {
|
||||
Self::Launch(LaunchMode::new(past_launch_config, window, cx))
|
||||
}
|
||||
}
|
||||
fn render_editor(editor: &Entity<Editor>, window: &mut Window, cx: &App) -> impl IntoElement {
|
||||
let settings = ThemeSettings::get_global(cx);
|
||||
let theme = cx.theme();
|
||||
@@ -519,6 +544,34 @@ impl Render for NewSessionModal {
|
||||
h_flex()
|
||||
.justify_start()
|
||||
.w_full()
|
||||
.child(
|
||||
ToggleButton::new("debugger-session-ui-picker-button", "Scenarios")
|
||||
.size(ButtonSize::Default)
|
||||
.style(ui::ButtonStyle::Subtle)
|
||||
.toggle_state(matches!(self.mode, NewSessionMode::Scenario(_)))
|
||||
.on_click(cx.listener(|this, _, window, cx| {
|
||||
let Some(task_store) = this
|
||||
.workspace
|
||||
.update(cx, |workspace, cx| {
|
||||
workspace.project().read(cx).task_store().clone()
|
||||
})
|
||||
.ok()
|
||||
else {
|
||||
return;
|
||||
};
|
||||
|
||||
this.mode = NewSessionMode::scenario(
|
||||
this.debug_panel.clone(),
|
||||
this.workspace.clone(),
|
||||
task_store,
|
||||
window,
|
||||
cx,
|
||||
);
|
||||
|
||||
cx.notify();
|
||||
}))
|
||||
.first(),
|
||||
)
|
||||
.child(
|
||||
ToggleButton::new(
|
||||
"debugger-session-ui-launch-button",
|
||||
@@ -532,7 +585,7 @@ impl Render for NewSessionModal {
|
||||
this.mode.focus_handle(cx).focus(window);
|
||||
cx.notify();
|
||||
}))
|
||||
.first(),
|
||||
.middle(),
|
||||
)
|
||||
.child(
|
||||
ToggleButton::new(
|
||||
@@ -601,10 +654,21 @@ impl Render for NewSessionModal {
|
||||
})
|
||||
.child(
|
||||
Button::new("debugger-spawn", "Start")
|
||||
.on_click(cx.listener(|this, _, window, cx| {
|
||||
this.start_new_session(window, cx);
|
||||
.on_click(cx.listener(|this, _, window, cx| match &this.mode {
|
||||
NewSessionMode::Scenario(picker) => {
|
||||
picker.update(cx, |picker, cx| {
|
||||
picker.delegate.confirm(true, window, cx)
|
||||
})
|
||||
}
|
||||
_ => this.start_new_session(window, cx),
|
||||
}))
|
||||
.disabled(self.debugger.is_none()),
|
||||
.disabled(match self.mode {
|
||||
NewSessionMode::Scenario(_) => !self.mode.has_match(cx),
|
||||
NewSessionMode::Attach(_) => {
|
||||
self.debugger.is_none() || !self.mode.has_match(cx)
|
||||
}
|
||||
NewSessionMode::Launch(_) => self.debugger.is_none(),
|
||||
}),
|
||||
),
|
||||
),
|
||||
)
|
||||
@@ -619,3 +683,319 @@ impl Focusable for NewSessionModal {
|
||||
}
|
||||
|
||||
impl ModalView for NewSessionModal {}
|
||||
|
||||
// This module makes sure that the modes setup the correct subscriptions whenever they're created
|
||||
mod session_modes {
|
||||
use std::rc::Rc;
|
||||
|
||||
use super::*;
|
||||
|
||||
#[derive(Clone)]
|
||||
#[non_exhaustive]
|
||||
pub(super) struct LaunchMode {
|
||||
pub(super) program: Entity<Editor>,
|
||||
pub(super) cwd: Entity<Editor>,
|
||||
}
|
||||
|
||||
impl LaunchMode {
|
||||
pub(super) fn new(
|
||||
past_launch_config: Option<LaunchRequest>,
|
||||
window: &mut Window,
|
||||
cx: &mut App,
|
||||
) -> Entity<Self> {
|
||||
let (past_program, past_cwd) = past_launch_config
|
||||
.map(|config| (Some(config.program), config.cwd))
|
||||
.unwrap_or_else(|| (None, None));
|
||||
|
||||
let program = cx.new(|cx| Editor::single_line(window, cx));
|
||||
program.update(cx, |this, cx| {
|
||||
this.set_placeholder_text("Program path", cx);
|
||||
|
||||
if let Some(past_program) = past_program {
|
||||
this.set_text(past_program, window, cx);
|
||||
};
|
||||
});
|
||||
let cwd = cx.new(|cx| Editor::single_line(window, cx));
|
||||
cwd.update(cx, |this, cx| {
|
||||
this.set_placeholder_text("Working Directory", cx);
|
||||
if let Some(past_cwd) = past_cwd {
|
||||
this.set_text(past_cwd.to_string_lossy(), window, cx);
|
||||
};
|
||||
});
|
||||
cx.new(|_| Self { program, cwd })
|
||||
}
|
||||
|
||||
pub(super) fn debug_task(&self, cx: &App) -> task::LaunchRequest {
|
||||
let path = self.cwd.read(cx).text(cx);
|
||||
task::LaunchRequest {
|
||||
program: self.program.read(cx).text(cx),
|
||||
cwd: path.is_empty().not().then(|| PathBuf::from(path)),
|
||||
args: Default::default(),
|
||||
env: Default::default(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub(super) struct AttachMode {
|
||||
pub(super) definition: DebugTaskDefinition,
|
||||
pub(super) attach_picker: Entity<AttachModal>,
|
||||
_subscription: Rc<Subscription>,
|
||||
}
|
||||
|
||||
impl AttachMode {
|
||||
pub(super) fn new(
|
||||
debugger: Option<SharedString>,
|
||||
workspace: Entity<Workspace>,
|
||||
window: &mut Window,
|
||||
cx: &mut Context<NewSessionModal>,
|
||||
) -> Entity<Self> {
|
||||
let definition = DebugTaskDefinition {
|
||||
adapter: debugger.clone().unwrap_or_default(),
|
||||
label: "Attach New Session Setup".into(),
|
||||
request: dap::DebugRequest::Attach(task::AttachRequest { process_id: None }),
|
||||
initialize_args: None,
|
||||
tcp_connection: None,
|
||||
stop_on_entry: Some(false),
|
||||
};
|
||||
let attach_picker = cx.new(|cx| {
|
||||
let modal = AttachModal::new(definition.clone(), workspace, false, window, cx);
|
||||
window.focus(&modal.focus_handle(cx));
|
||||
|
||||
modal
|
||||
});
|
||||
|
||||
let subscription = cx.subscribe(&attach_picker, |_, _, _, cx| {
|
||||
cx.emit(DismissEvent);
|
||||
});
|
||||
|
||||
cx.new(|_| Self {
|
||||
definition,
|
||||
attach_picker,
|
||||
_subscription: Rc::new(subscription),
|
||||
})
|
||||
}
|
||||
pub(super) fn debug_task(&self) -> task::AttachRequest {
|
||||
task::AttachRequest { process_id: None }
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) struct DebugScenarioDelegate {
|
||||
task_store: Entity<TaskStore>,
|
||||
candidates: Option<Vec<(TaskSourceKind, DebugScenario)>>,
|
||||
selected_index: usize,
|
||||
matches: Vec<StringMatch>,
|
||||
prompt: String,
|
||||
debug_panel: WeakEntity<DebugPanel>,
|
||||
workspace: WeakEntity<Workspace>,
|
||||
}
|
||||
|
||||
impl DebugScenarioDelegate {
|
||||
pub(super) fn new(
|
||||
debug_panel: WeakEntity<DebugPanel>,
|
||||
workspace: WeakEntity<Workspace>,
|
||||
task_store: Entity<TaskStore>,
|
||||
) -> Self {
|
||||
Self {
|
||||
task_store,
|
||||
candidates: None,
|
||||
selected_index: 0,
|
||||
matches: Vec::new(),
|
||||
prompt: String::new(),
|
||||
debug_panel,
|
||||
workspace,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl PickerDelegate for DebugScenarioDelegate {
|
||||
type ListItem = ui::ListItem;
|
||||
|
||||
fn match_count(&self) -> usize {
|
||||
self.matches.len()
|
||||
}
|
||||
|
||||
fn selected_index(&self) -> usize {
|
||||
self.selected_index
|
||||
}
|
||||
|
||||
fn set_selected_index(
|
||||
&mut self,
|
||||
ix: usize,
|
||||
_window: &mut Window,
|
||||
_cx: &mut Context<picker::Picker<Self>>,
|
||||
) {
|
||||
self.selected_index = ix;
|
||||
}
|
||||
|
||||
fn placeholder_text(&self, _window: &mut Window, _cx: &mut App) -> std::sync::Arc<str> {
|
||||
"".into()
|
||||
}
|
||||
|
||||
fn update_matches(
|
||||
&mut self,
|
||||
query: String,
|
||||
window: &mut Window,
|
||||
cx: &mut Context<picker::Picker<Self>>,
|
||||
) -> gpui::Task<()> {
|
||||
let candidates: Vec<_> = match &self.candidates {
|
||||
Some(candidates) => candidates
|
||||
.into_iter()
|
||||
.enumerate()
|
||||
.map(|(index, (_, candidate))| {
|
||||
StringMatchCandidate::new(index, candidate.label.as_ref())
|
||||
})
|
||||
.collect(),
|
||||
None => {
|
||||
let worktree_ids: Vec<_> = self
|
||||
.workspace
|
||||
.update(cx, |this, cx| {
|
||||
this.visible_worktrees(cx)
|
||||
.map(|tree| tree.read(cx).id())
|
||||
.collect()
|
||||
})
|
||||
.ok()
|
||||
.unwrap_or_default();
|
||||
|
||||
let scenarios: Vec<_> = self
|
||||
.task_store
|
||||
.read(cx)
|
||||
.task_inventory()
|
||||
.map(|item| item.read(cx).list_debug_scenarios(worktree_ids.into_iter()))
|
||||
.unwrap_or_default();
|
||||
|
||||
self.candidates = Some(scenarios.clone());
|
||||
|
||||
scenarios
|
||||
.into_iter()
|
||||
.enumerate()
|
||||
.map(|(index, (_, candidate))| {
|
||||
StringMatchCandidate::new(index, candidate.label.as_ref())
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
};
|
||||
|
||||
cx.spawn_in(window, async move |picker, cx| {
|
||||
let matches = fuzzy::match_strings(
|
||||
&candidates,
|
||||
&query,
|
||||
true,
|
||||
1000,
|
||||
&Default::default(),
|
||||
cx.background_executor().clone(),
|
||||
)
|
||||
.await;
|
||||
|
||||
picker
|
||||
.update(cx, |picker, _| {
|
||||
let delegate = &mut picker.delegate;
|
||||
|
||||
delegate.matches = matches;
|
||||
delegate.prompt = query;
|
||||
|
||||
if delegate.matches.is_empty() {
|
||||
delegate.selected_index = 0;
|
||||
} else {
|
||||
delegate.selected_index =
|
||||
delegate.selected_index.min(delegate.matches.len() - 1);
|
||||
}
|
||||
})
|
||||
.log_err();
|
||||
})
|
||||
}
|
||||
|
||||
fn confirm(
|
||||
&mut self,
|
||||
_: bool,
|
||||
window: &mut Window,
|
||||
cx: &mut Context<picker::Picker<Self>>,
|
||||
) {
|
||||
let debug_scenario =
|
||||
self.matches
|
||||
.get(self.selected_index())
|
||||
.and_then(|match_candidate| {
|
||||
self.candidates
|
||||
.as_ref()
|
||||
.map(|candidates| candidates[match_candidate.candidate_id].clone())
|
||||
});
|
||||
|
||||
let Some((task_source_kind, debug_scenario)) = debug_scenario else {
|
||||
return;
|
||||
};
|
||||
|
||||
let task_context = if let TaskSourceKind::Worktree {
|
||||
id: worktree_id,
|
||||
directory_in_worktree: _,
|
||||
id_base: _,
|
||||
} = task_source_kind
|
||||
{
|
||||
let workspace = self.workspace.clone();
|
||||
|
||||
cx.spawn_in(window, async move |_, cx| {
|
||||
workspace
|
||||
.update_in(cx, |workspace, window, cx| {
|
||||
tasks_ui::task_contexts(workspace, window, cx)
|
||||
})
|
||||
.ok()?
|
||||
.await
|
||||
.task_context_for_worktree_id(worktree_id)
|
||||
.cloned()
|
||||
})
|
||||
} else {
|
||||
gpui::Task::ready(None)
|
||||
};
|
||||
|
||||
cx.spawn_in(window, async move |this, cx| {
|
||||
let task_context = task_context.await.unwrap_or_default();
|
||||
|
||||
this.update_in(cx, |this, window, cx| {
|
||||
this.delegate
|
||||
.debug_panel
|
||||
.update(cx, |panel, cx| {
|
||||
panel.start_session(debug_scenario, task_context, None, window, cx);
|
||||
})
|
||||
.ok();
|
||||
|
||||
cx.emit(DismissEvent);
|
||||
})
|
||||
.ok();
|
||||
})
|
||||
.detach();
|
||||
}
|
||||
|
||||
fn dismissed(&mut self, _: &mut Window, cx: &mut Context<picker::Picker<Self>>) {
|
||||
cx.emit(DismissEvent);
|
||||
}
|
||||
|
||||
fn render_match(
|
||||
&self,
|
||||
ix: usize,
|
||||
selected: bool,
|
||||
window: &mut Window,
|
||||
cx: &mut Context<picker::Picker<Self>>,
|
||||
) -> Option<Self::ListItem> {
|
||||
let hit = &self.matches[ix];
|
||||
|
||||
let highlighted_location = HighlightedMatch {
|
||||
text: hit.string.clone(),
|
||||
highlight_positions: hit.positions.clone(),
|
||||
char_count: hit.string.chars().count(),
|
||||
color: Color::Default,
|
||||
};
|
||||
|
||||
let icon = Icon::new(IconName::FileTree)
|
||||
.color(Color::Muted)
|
||||
.size(ui::IconSize::Small);
|
||||
|
||||
Some(
|
||||
ListItem::new(SharedString::from(format!("debug-scenario-selection-{ix}")))
|
||||
.inset(true)
|
||||
.start_slot::<Icon>(icon)
|
||||
.spacing(ListItemSpacing::Sparse)
|
||||
.toggle_state(selected)
|
||||
.child(highlighted_location.render(window, cx)),
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -188,7 +188,7 @@ impl Render for SubView {
|
||||
cx.notify();
|
||||
}))
|
||||
.size_full()
|
||||
// Add border uncoditionally to prevent layout shifts on focus changes.
|
||||
// Add border unconditionally to prevent layout shifts on focus changes.
|
||||
.border_1()
|
||||
.when(self.pane_focus_handle.contains_focused(window, cx), |el| {
|
||||
el.border_color(cx.theme().colors().pane_focused_border)
|
||||
|
||||
@@ -19,6 +19,7 @@ component.workspace = true
|
||||
ctor.workspace = true
|
||||
editor.workspace = true
|
||||
env_logger.workspace = true
|
||||
futures.workspace = true
|
||||
gpui.workspace = true
|
||||
indoc.workspace = true
|
||||
language.workspace = true
|
||||
@@ -29,6 +30,7 @@ markdown.workspace = true
|
||||
project.workspace = true
|
||||
rand.workspace = true
|
||||
serde.workspace = true
|
||||
serde_json.workspace = true
|
||||
settings.workspace = true
|
||||
text.workspace = true
|
||||
theme.workspace = true
|
||||
|
||||
@@ -14,6 +14,7 @@ use editor::{
|
||||
display_map::{BlockPlacement, BlockProperties, BlockStyle, CustomBlockId},
|
||||
scroll::Autoscroll,
|
||||
};
|
||||
use futures::future::join_all;
|
||||
use gpui::{
|
||||
AnyElement, AnyView, App, AsyncApp, Context, Entity, EventEmitter, FocusHandle, Focusable,
|
||||
Global, InteractiveElement, IntoElement, ParentElement, Render, SharedString, Styled,
|
||||
@@ -23,13 +24,15 @@ use language::{
|
||||
Bias, Buffer, BufferRow, BufferSnapshot, DiagnosticEntry, Point, ToTreeSitterPoint,
|
||||
};
|
||||
use lsp::DiagnosticSeverity;
|
||||
|
||||
use project::{DiagnosticSummary, Project, ProjectPath, project_settings::ProjectSettings};
|
||||
use project::{
|
||||
DiagnosticSummary, Project, ProjectPath,
|
||||
lsp_store::rust_analyzer_ext::{cancel_flycheck, run_flycheck},
|
||||
project_settings::ProjectSettings,
|
||||
};
|
||||
use settings::Settings;
|
||||
use std::{
|
||||
any::{Any, TypeId},
|
||||
cmp,
|
||||
cmp::Ordering,
|
||||
cmp::{self, Ordering},
|
||||
ops::{Range, RangeInclusive},
|
||||
sync::Arc,
|
||||
time::Duration,
|
||||
@@ -45,7 +48,10 @@ use workspace::{
|
||||
searchable::SearchableItemHandle,
|
||||
};
|
||||
|
||||
actions!(diagnostics, [Deploy, ToggleWarnings]);
|
||||
actions!(
|
||||
diagnostics,
|
||||
[Deploy, ToggleWarnings, ToggleDiagnosticsRefresh]
|
||||
);
|
||||
|
||||
#[derive(Default)]
|
||||
pub(crate) struct IncludeWarnings(bool);
|
||||
@@ -68,9 +74,16 @@ pub(crate) struct ProjectDiagnosticsEditor {
|
||||
paths_to_update: BTreeSet<ProjectPath>,
|
||||
include_warnings: bool,
|
||||
update_excerpts_task: Option<Task<Result<()>>>,
|
||||
cargo_diagnostics_fetch: CargoDiagnosticsFetchState,
|
||||
_subscription: Subscription,
|
||||
}
|
||||
|
||||
struct CargoDiagnosticsFetchState {
|
||||
fetch_task: Option<Task<()>>,
|
||||
cancel_task: Option<Task<()>>,
|
||||
diagnostic_sources: Arc<Vec<ProjectPath>>,
|
||||
}
|
||||
|
||||
impl EventEmitter<EditorEvent> for ProjectDiagnosticsEditor {}
|
||||
|
||||
const DIAGNOSTICS_UPDATE_DELAY: Duration = Duration::from_millis(50);
|
||||
@@ -126,6 +139,7 @@ impl Render for ProjectDiagnosticsEditor {
|
||||
.track_focus(&self.focus_handle(cx))
|
||||
.size_full()
|
||||
.on_action(cx.listener(Self::toggle_warnings))
|
||||
.on_action(cx.listener(Self::toggle_diagnostics_refresh))
|
||||
.child(child)
|
||||
}
|
||||
}
|
||||
@@ -212,7 +226,11 @@ impl ProjectDiagnosticsEditor {
|
||||
cx.observe_global_in::<IncludeWarnings>(window, |this, window, cx| {
|
||||
this.include_warnings = cx.global::<IncludeWarnings>().0;
|
||||
this.diagnostics.clear();
|
||||
this.update_all_excerpts(window, cx);
|
||||
this.update_all_diagnostics(window, cx);
|
||||
})
|
||||
.detach();
|
||||
cx.observe_release(&cx.entity(), |editor, _, cx| {
|
||||
editor.stop_cargo_diagnostics_fetch(cx);
|
||||
})
|
||||
.detach();
|
||||
|
||||
@@ -229,9 +247,14 @@ impl ProjectDiagnosticsEditor {
|
||||
editor,
|
||||
paths_to_update: Default::default(),
|
||||
update_excerpts_task: None,
|
||||
cargo_diagnostics_fetch: CargoDiagnosticsFetchState {
|
||||
fetch_task: None,
|
||||
cancel_task: None,
|
||||
diagnostic_sources: Arc::new(Vec::new()),
|
||||
},
|
||||
_subscription: project_event_subscription,
|
||||
};
|
||||
this.update_all_excerpts(window, cx);
|
||||
this.update_all_diagnostics(window, cx);
|
||||
this
|
||||
}
|
||||
|
||||
@@ -239,15 +262,17 @@ impl ProjectDiagnosticsEditor {
|
||||
if self.update_excerpts_task.is_some() {
|
||||
return;
|
||||
}
|
||||
|
||||
let project_handle = self.project.clone();
|
||||
self.update_excerpts_task = Some(cx.spawn_in(window, async move |this, cx| {
|
||||
cx.background_executor()
|
||||
.timer(DIAGNOSTICS_UPDATE_DELAY)
|
||||
.await;
|
||||
loop {
|
||||
let Some(path) = this.update(cx, |this, _| {
|
||||
let Some(path) = this.update(cx, |this, cx| {
|
||||
let Some(path) = this.paths_to_update.pop_first() else {
|
||||
this.update_excerpts_task.take();
|
||||
this.update_excerpts_task = None;
|
||||
cx.notify();
|
||||
return None;
|
||||
};
|
||||
Some(path)
|
||||
@@ -307,6 +332,32 @@ impl ProjectDiagnosticsEditor {
|
||||
cx.set_global(IncludeWarnings(!self.include_warnings));
|
||||
}
|
||||
|
||||
fn toggle_diagnostics_refresh(
|
||||
&mut self,
|
||||
_: &ToggleDiagnosticsRefresh,
|
||||
window: &mut Window,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
let fetch_cargo_diagnostics = ProjectSettings::get_global(cx)
|
||||
.diagnostics
|
||||
.fetch_cargo_diagnostics();
|
||||
|
||||
if fetch_cargo_diagnostics {
|
||||
if self.cargo_diagnostics_fetch.fetch_task.is_some() {
|
||||
self.stop_cargo_diagnostics_fetch(cx);
|
||||
} else {
|
||||
self.update_all_diagnostics(window, cx);
|
||||
}
|
||||
} else {
|
||||
if self.update_excerpts_task.is_some() {
|
||||
self.update_excerpts_task = None;
|
||||
} else {
|
||||
self.update_all_diagnostics(window, cx);
|
||||
}
|
||||
}
|
||||
cx.notify();
|
||||
}
|
||||
|
||||
fn focus_in(&mut self, window: &mut Window, cx: &mut Context<Self>) {
|
||||
if self.focus_handle.is_focused(window) && !self.multibuffer.read(cx).is_empty() {
|
||||
self.editor.focus_handle(cx).focus(window)
|
||||
@@ -320,6 +371,66 @@ impl ProjectDiagnosticsEditor {
|
||||
}
|
||||
}
|
||||
|
||||
fn update_all_diagnostics(&mut self, window: &mut Window, cx: &mut Context<Self>) {
|
||||
let cargo_diagnostics_sources = self.cargo_diagnostics_sources(cx);
|
||||
if cargo_diagnostics_sources.is_empty() {
|
||||
self.update_all_excerpts(window, cx);
|
||||
} else {
|
||||
self.fetch_cargo_diagnostics(Arc::new(cargo_diagnostics_sources), cx);
|
||||
}
|
||||
}
|
||||
|
||||
fn fetch_cargo_diagnostics(
|
||||
&mut self,
|
||||
diagnostics_sources: Arc<Vec<ProjectPath>>,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
let project = self.project.clone();
|
||||
self.cargo_diagnostics_fetch.cancel_task = None;
|
||||
self.cargo_diagnostics_fetch.fetch_task = None;
|
||||
self.cargo_diagnostics_fetch.diagnostic_sources = diagnostics_sources.clone();
|
||||
if self.cargo_diagnostics_fetch.diagnostic_sources.is_empty() {
|
||||
return;
|
||||
}
|
||||
|
||||
self.cargo_diagnostics_fetch.fetch_task = Some(cx.spawn(async move |editor, cx| {
|
||||
let mut fetch_tasks = Vec::new();
|
||||
for buffer_path in diagnostics_sources.iter().cloned() {
|
||||
if cx
|
||||
.update(|cx| {
|
||||
fetch_tasks.push(run_flycheck(project.clone(), buffer_path, cx));
|
||||
})
|
||||
.is_err()
|
||||
{
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
let _ = join_all(fetch_tasks).await;
|
||||
editor
|
||||
.update(cx, |editor, _| {
|
||||
editor.cargo_diagnostics_fetch.fetch_task = None;
|
||||
})
|
||||
.ok();
|
||||
}));
|
||||
}
|
||||
|
||||
fn stop_cargo_diagnostics_fetch(&mut self, cx: &mut App) {
|
||||
self.cargo_diagnostics_fetch.fetch_task = None;
|
||||
let mut cancel_gasks = Vec::new();
|
||||
for buffer_path in std::mem::take(&mut self.cargo_diagnostics_fetch.diagnostic_sources)
|
||||
.iter()
|
||||
.cloned()
|
||||
{
|
||||
cancel_gasks.push(cancel_flycheck(self.project.clone(), buffer_path, cx));
|
||||
}
|
||||
|
||||
self.cargo_diagnostics_fetch.cancel_task = Some(cx.background_spawn(async move {
|
||||
let _ = join_all(cancel_gasks).await;
|
||||
log::info!("Finished fetching cargo diagnostics");
|
||||
}));
|
||||
}
|
||||
|
||||
/// Enqueue an update of all excerpts. Updates all paths that either
|
||||
/// currently have diagnostics or are currently present in this view.
|
||||
fn update_all_excerpts(&mut self, window: &mut Window, cx: &mut Context<Self>) {
|
||||
@@ -422,20 +533,17 @@ impl ProjectDiagnosticsEditor {
|
||||
})?;
|
||||
|
||||
for item in more {
|
||||
let insert_pos = blocks
|
||||
.binary_search_by(|existing| {
|
||||
match existing.initial_range.start.cmp(&item.initial_range.start) {
|
||||
Ordering::Equal => item
|
||||
.initial_range
|
||||
.end
|
||||
.cmp(&existing.initial_range.end)
|
||||
.reverse(),
|
||||
other => other,
|
||||
}
|
||||
let i = blocks
|
||||
.binary_search_by(|probe| {
|
||||
probe
|
||||
.initial_range
|
||||
.start
|
||||
.cmp(&item.initial_range.start)
|
||||
.then(probe.initial_range.end.cmp(&item.initial_range.end))
|
||||
.then(Ordering::Greater)
|
||||
})
|
||||
.unwrap_or_else(|pos| pos);
|
||||
|
||||
blocks.insert(insert_pos, item);
|
||||
.unwrap_or_else(|i| i);
|
||||
blocks.insert(i, item);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -448,10 +556,25 @@ impl ProjectDiagnosticsEditor {
|
||||
&mut cx,
|
||||
)
|
||||
.await;
|
||||
excerpt_ranges.push(ExcerptRange {
|
||||
context: excerpt_range,
|
||||
primary: b.initial_range.clone(),
|
||||
})
|
||||
let i = excerpt_ranges
|
||||
.binary_search_by(|probe| {
|
||||
probe
|
||||
.context
|
||||
.start
|
||||
.cmp(&excerpt_range.start)
|
||||
.then(probe.context.end.cmp(&excerpt_range.end))
|
||||
.then(probe.primary.start.cmp(&b.initial_range.start))
|
||||
.then(probe.primary.end.cmp(&b.initial_range.end))
|
||||
.then(cmp::Ordering::Greater)
|
||||
})
|
||||
.unwrap_or_else(|i| i);
|
||||
excerpt_ranges.insert(
|
||||
i,
|
||||
ExcerptRange {
|
||||
context: excerpt_range,
|
||||
primary: b.initial_range.clone(),
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
this.update_in(cx, |this, window, cx| {
|
||||
@@ -534,6 +657,30 @@ impl ProjectDiagnosticsEditor {
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
pub fn cargo_diagnostics_sources(&self, cx: &App) -> Vec<ProjectPath> {
|
||||
let fetch_cargo_diagnostics = ProjectSettings::get_global(cx)
|
||||
.diagnostics
|
||||
.fetch_cargo_diagnostics();
|
||||
if !fetch_cargo_diagnostics {
|
||||
return Vec::new();
|
||||
}
|
||||
self.project
|
||||
.read(cx)
|
||||
.worktrees(cx)
|
||||
.filter_map(|worktree| {
|
||||
let _cargo_toml_entry = worktree.read(cx).entry_for_path("Cargo.toml")?;
|
||||
let rust_file_entry = worktree.read(cx).entries(false, 0).find(|entry| {
|
||||
entry
|
||||
.path
|
||||
.extension()
|
||||
.and_then(|extension| extension.to_str())
|
||||
== Some("rs")
|
||||
})?;
|
||||
self.project.read(cx).path_for_entry(rust_file_entry.id, cx)
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
}
|
||||
|
||||
impl Focusable for ProjectDiagnosticsEditor {
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
use crate::ProjectDiagnosticsEditor;
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::{ProjectDiagnosticsEditor, ToggleDiagnosticsRefresh};
|
||||
use gpui::{Context, Entity, EventEmitter, ParentElement, Render, WeakEntity, Window};
|
||||
use ui::prelude::*;
|
||||
use ui::{IconButton, IconButtonShape, IconName, Tooltip};
|
||||
@@ -13,18 +15,26 @@ impl Render for ToolbarControls {
|
||||
let mut include_warnings = false;
|
||||
let mut has_stale_excerpts = false;
|
||||
let mut is_updating = false;
|
||||
let cargo_diagnostics_sources = Arc::new(self.diagnostics().map_or(Vec::new(), |editor| {
|
||||
editor.read(cx).cargo_diagnostics_sources(cx)
|
||||
}));
|
||||
let fetch_cargo_diagnostics = !cargo_diagnostics_sources.is_empty();
|
||||
|
||||
if let Some(editor) = self.diagnostics() {
|
||||
let diagnostics = editor.read(cx);
|
||||
include_warnings = diagnostics.include_warnings;
|
||||
has_stale_excerpts = !diagnostics.paths_to_update.is_empty();
|
||||
is_updating = diagnostics.update_excerpts_task.is_some()
|
||||
|| diagnostics
|
||||
.project
|
||||
.read(cx)
|
||||
.language_servers_running_disk_based_diagnostics(cx)
|
||||
.next()
|
||||
.is_some();
|
||||
is_updating = if fetch_cargo_diagnostics {
|
||||
diagnostics.cargo_diagnostics_fetch.fetch_task.is_some()
|
||||
} else {
|
||||
diagnostics.update_excerpts_task.is_some()
|
||||
|| diagnostics
|
||||
.project
|
||||
.read(cx)
|
||||
.language_servers_running_disk_based_diagnostics(cx)
|
||||
.next()
|
||||
.is_some()
|
||||
};
|
||||
}
|
||||
|
||||
let tooltip = if include_warnings {
|
||||
@@ -41,21 +51,56 @@ impl Render for ToolbarControls {
|
||||
|
||||
h_flex()
|
||||
.gap_1()
|
||||
.when(has_stale_excerpts, |div| {
|
||||
div.child(
|
||||
IconButton::new("update-excerpts", IconName::Update)
|
||||
.icon_color(Color::Info)
|
||||
.shape(IconButtonShape::Square)
|
||||
.disabled(is_updating)
|
||||
.tooltip(Tooltip::text("Update excerpts"))
|
||||
.on_click(cx.listener(|this, _, window, cx| {
|
||||
if let Some(diagnostics) = this.diagnostics() {
|
||||
diagnostics.update(cx, |diagnostics, cx| {
|
||||
diagnostics.update_all_excerpts(window, cx);
|
||||
});
|
||||
}
|
||||
})),
|
||||
)
|
||||
.map(|div| {
|
||||
if is_updating {
|
||||
div.child(
|
||||
IconButton::new("stop-updating", IconName::StopFilled)
|
||||
.icon_color(Color::Info)
|
||||
.shape(IconButtonShape::Square)
|
||||
.tooltip(Tooltip::for_action_title(
|
||||
"Stop diagnostics update",
|
||||
&ToggleDiagnosticsRefresh,
|
||||
))
|
||||
.on_click(cx.listener(move |toolbar_controls, _, _, cx| {
|
||||
if let Some(diagnostics) = toolbar_controls.diagnostics() {
|
||||
diagnostics.update(cx, |diagnostics, cx| {
|
||||
diagnostics.stop_cargo_diagnostics_fetch(cx);
|
||||
diagnostics.update_excerpts_task = None;
|
||||
cx.notify();
|
||||
});
|
||||
}
|
||||
})),
|
||||
)
|
||||
} else {
|
||||
div.child(
|
||||
IconButton::new("refresh-diagnostics", IconName::Update)
|
||||
.icon_color(Color::Info)
|
||||
.shape(IconButtonShape::Square)
|
||||
.disabled(!has_stale_excerpts && !fetch_cargo_diagnostics)
|
||||
.tooltip(Tooltip::for_action_title(
|
||||
"Refresh diagnostics",
|
||||
&ToggleDiagnosticsRefresh,
|
||||
))
|
||||
.on_click(cx.listener({
|
||||
move |toolbar_controls, _, window, cx| {
|
||||
if let Some(diagnostics) = toolbar_controls.diagnostics() {
|
||||
let cargo_diagnostics_sources =
|
||||
Arc::clone(&cargo_diagnostics_sources);
|
||||
diagnostics.update(cx, move |diagnostics, cx| {
|
||||
if fetch_cargo_diagnostics {
|
||||
diagnostics.fetch_cargo_diagnostics(
|
||||
cargo_diagnostics_sources,
|
||||
cx,
|
||||
);
|
||||
} else {
|
||||
diagnostics.update_all_excerpts(window, cx);
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
})),
|
||||
)
|
||||
}
|
||||
})
|
||||
.child(
|
||||
IconButton::new("toggle-warnings", IconName::Warning)
|
||||
|
||||
@@ -249,7 +249,9 @@ actions!(
|
||||
ApplyDiffHunk,
|
||||
Backspace,
|
||||
Cancel,
|
||||
CancelFlycheck,
|
||||
CancelLanguageServerWork,
|
||||
ClearFlycheck,
|
||||
ConfirmRename,
|
||||
ConfirmCompletionInsert,
|
||||
ConfirmCompletionReplace,
|
||||
@@ -308,6 +310,7 @@ actions!(
|
||||
GoToImplementation,
|
||||
GoToImplementationSplit,
|
||||
GoToNextChange,
|
||||
GoToParentModule,
|
||||
GoToPreviousChange,
|
||||
GoToPreviousDiagnostic,
|
||||
GoToTypeDefinition,
|
||||
@@ -371,6 +374,7 @@ actions!(
|
||||
RevertFile,
|
||||
ReloadFile,
|
||||
Rewrap,
|
||||
RunFlycheck,
|
||||
ScrollCursorBottom,
|
||||
ScrollCursorCenter,
|
||||
ScrollCursorCenterTopBottom,
|
||||
|
||||
@@ -5005,11 +5005,11 @@ impl Editor {
|
||||
range
|
||||
};
|
||||
|
||||
ranges.push(range);
|
||||
ranges.push(range.clone());
|
||||
|
||||
if !self.linked_edit_ranges.is_empty() {
|
||||
let start_anchor = snapshot.anchor_before(selection.head());
|
||||
let end_anchor = snapshot.anchor_after(selection.tail());
|
||||
let start_anchor = snapshot.anchor_before(range.start);
|
||||
let end_anchor = snapshot.anchor_after(range.end);
|
||||
if let Some(ranges) = self
|
||||
.linked_editing_ranges_for(start_anchor.text_anchor..end_anchor.text_anchor, cx)
|
||||
{
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
use super::*;
|
||||
use crate::{
|
||||
JoinLines,
|
||||
linked_editing_ranges::LinkedEditingRanges,
|
||||
scroll::scroll_amount::ScrollAmount,
|
||||
test::{
|
||||
assert_text_with_selections, build_editor,
|
||||
@@ -19559,6 +19560,146 @@ async fn test_hide_mouse_context_menu_on_modal_opened(cx: &mut TestAppContext) {
|
||||
});
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_html_linked_edits_on_completion(cx: &mut TestAppContext) {
|
||||
init_test(cx, |_| {});
|
||||
|
||||
let fs = FakeFs::new(cx.executor());
|
||||
fs.insert_file(path!("/file.html"), Default::default())
|
||||
.await;
|
||||
|
||||
let project = Project::test(fs, [path!("/").as_ref()], cx).await;
|
||||
|
||||
let language_registry = project.read_with(cx, |project, _| project.languages().clone());
|
||||
let html_language = Arc::new(Language::new(
|
||||
LanguageConfig {
|
||||
name: "HTML".into(),
|
||||
matcher: LanguageMatcher {
|
||||
path_suffixes: vec!["html".to_string()],
|
||||
..LanguageMatcher::default()
|
||||
},
|
||||
brackets: BracketPairConfig {
|
||||
pairs: vec![BracketPair {
|
||||
start: "<".into(),
|
||||
end: ">".into(),
|
||||
close: true,
|
||||
..Default::default()
|
||||
}],
|
||||
..Default::default()
|
||||
},
|
||||
..Default::default()
|
||||
},
|
||||
Some(tree_sitter_html::LANGUAGE.into()),
|
||||
));
|
||||
language_registry.add(html_language);
|
||||
let mut fake_servers = language_registry.register_fake_lsp(
|
||||
"HTML",
|
||||
FakeLspAdapter {
|
||||
capabilities: lsp::ServerCapabilities {
|
||||
completion_provider: Some(lsp::CompletionOptions {
|
||||
resolve_provider: Some(true),
|
||||
..Default::default()
|
||||
}),
|
||||
..Default::default()
|
||||
},
|
||||
..Default::default()
|
||||
},
|
||||
);
|
||||
|
||||
let workspace = cx.add_window(|window, cx| Workspace::test_new(project.clone(), window, cx));
|
||||
let cx = &mut VisualTestContext::from_window(*workspace, cx);
|
||||
|
||||
let worktree_id = workspace
|
||||
.update(cx, |workspace, _window, cx| {
|
||||
workspace.project().update(cx, |project, cx| {
|
||||
project.worktrees(cx).next().unwrap().read(cx).id()
|
||||
})
|
||||
})
|
||||
.unwrap();
|
||||
project
|
||||
.update(cx, |project, cx| {
|
||||
project.open_local_buffer_with_lsp(path!("/file.html"), cx)
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
let editor = workspace
|
||||
.update(cx, |workspace, window, cx| {
|
||||
workspace.open_path((worktree_id, "file.html"), None, true, window, cx)
|
||||
})
|
||||
.unwrap()
|
||||
.await
|
||||
.unwrap()
|
||||
.downcast::<Editor>()
|
||||
.unwrap();
|
||||
|
||||
let fake_server = fake_servers.next().await.unwrap();
|
||||
editor.update_in(cx, |editor, window, cx| {
|
||||
editor.set_text("<ad></ad>", window, cx);
|
||||
editor.change_selections(None, window, cx, |selections| {
|
||||
selections.select_ranges([Point::new(0, 3)..Point::new(0, 3)]);
|
||||
});
|
||||
let Some((buffer, _)) = editor
|
||||
.buffer
|
||||
.read(cx)
|
||||
.text_anchor_for_position(editor.selections.newest_anchor().start, cx)
|
||||
else {
|
||||
panic!("Failed to get buffer for selection position");
|
||||
};
|
||||
let buffer = buffer.read(cx);
|
||||
let buffer_id = buffer.remote_id();
|
||||
let opening_range =
|
||||
buffer.anchor_before(Point::new(0, 1))..buffer.anchor_after(Point::new(0, 3));
|
||||
let closing_range =
|
||||
buffer.anchor_before(Point::new(0, 6))..buffer.anchor_after(Point::new(0, 8));
|
||||
let mut linked_ranges = HashMap::default();
|
||||
linked_ranges.insert(
|
||||
buffer_id,
|
||||
vec![(opening_range.clone(), vec![closing_range.clone()])],
|
||||
);
|
||||
editor.linked_edit_ranges = LinkedEditingRanges(linked_ranges);
|
||||
});
|
||||
let mut completion_handle =
|
||||
fake_server.set_request_handler::<lsp::request::Completion, _, _>(move |_, _| async move {
|
||||
Ok(Some(lsp::CompletionResponse::Array(vec![
|
||||
lsp::CompletionItem {
|
||||
label: "head".to_string(),
|
||||
text_edit: Some(lsp::CompletionTextEdit::InsertAndReplace(
|
||||
lsp::InsertReplaceEdit {
|
||||
new_text: "head".to_string(),
|
||||
insert: lsp::Range::new(
|
||||
lsp::Position::new(0, 1),
|
||||
lsp::Position::new(0, 3),
|
||||
),
|
||||
replace: lsp::Range::new(
|
||||
lsp::Position::new(0, 1),
|
||||
lsp::Position::new(0, 3),
|
||||
),
|
||||
},
|
||||
)),
|
||||
..Default::default()
|
||||
},
|
||||
])))
|
||||
});
|
||||
editor.update_in(cx, |editor, window, cx| {
|
||||
editor.show_completions(&ShowCompletions { trigger: None }, window, cx);
|
||||
});
|
||||
cx.run_until_parked();
|
||||
completion_handle.next().await.unwrap();
|
||||
editor.update(cx, |editor, _| {
|
||||
assert!(
|
||||
editor.context_menu_visible(),
|
||||
"Completion menu should be visible"
|
||||
);
|
||||
});
|
||||
editor.update_in(cx, |editor, window, cx| {
|
||||
editor.confirm_completion(&ConfirmCompletion::default(), window, cx)
|
||||
});
|
||||
cx.executor().run_until_parked();
|
||||
editor.update(cx, |editor, cx| {
|
||||
assert_eq!(editor.text(cx), "<head></head>");
|
||||
});
|
||||
}
|
||||
|
||||
fn empty_range(row: usize, column: usize) -> Range<DisplayPoint> {
|
||||
let point = DisplayPoint::new(DisplayRow(row as u32), column as u32);
|
||||
point..point
|
||||
|
||||
@@ -2276,6 +2276,9 @@ impl EditorElement {
|
||||
}
|
||||
|
||||
let display_row = multibuffer_point.to_display_point(snapshot).row();
|
||||
if !range.contains(&display_row) {
|
||||
return None;
|
||||
}
|
||||
if row_infos
|
||||
.get((display_row - range.start).0 as usize)
|
||||
.is_some_and(|row_info| row_info.expand_info.is_some())
|
||||
@@ -6871,7 +6874,12 @@ impl Element for EditorElement {
|
||||
// The max scroll position for the top of the window
|
||||
let max_scroll_top = if matches!(
|
||||
snapshot.mode,
|
||||
EditorMode::AutoHeight { .. } | EditorMode::SingleLine { .. }
|
||||
EditorMode::SingleLine { .. }
|
||||
| EditorMode::AutoHeight { .. }
|
||||
| EditorMode::Full {
|
||||
sized_by_content: true,
|
||||
..
|
||||
}
|
||||
) {
|
||||
(max_row - height_in_lines + 1.).max(0.)
|
||||
} else {
|
||||
|
||||
@@ -233,7 +233,7 @@ pub fn deploy_context_menu(
|
||||
.separator()
|
||||
.action("Cut", Box::new(Cut))
|
||||
.action("Copy", Box::new(Copy))
|
||||
.action("Copy and trim", Box::new(CopyAndTrim))
|
||||
.action("Copy and Trim", Box::new(CopyAndTrim))
|
||||
.action("Paste", Box::new(Paste))
|
||||
.separator()
|
||||
.map(|builder| {
|
||||
|
||||
@@ -4,15 +4,20 @@ use anyhow::Context as _;
|
||||
use gpui::{App, AppContext as _, Context, Entity, Window};
|
||||
use language::{Capability, Language, proto::serialize_anchor};
|
||||
use multi_buffer::MultiBuffer;
|
||||
use project::lsp_store::{
|
||||
lsp_ext_command::{DocsUrls, ExpandMacro, ExpandedMacro},
|
||||
rust_analyzer_ext::RUST_ANALYZER_NAME,
|
||||
use project::{
|
||||
ProjectItem,
|
||||
lsp_command::location_link_from_proto,
|
||||
lsp_store::{
|
||||
lsp_ext_command::{DocsUrls, ExpandMacro, ExpandedMacro},
|
||||
rust_analyzer_ext::{RUST_ANALYZER_NAME, cancel_flycheck, clear_flycheck, run_flycheck},
|
||||
},
|
||||
};
|
||||
use rpc::proto;
|
||||
use text::ToPointUtf16;
|
||||
|
||||
use crate::{
|
||||
Editor, ExpandMacroRecursively, OpenDocs, element::register_action,
|
||||
CancelFlycheck, ClearFlycheck, Editor, ExpandMacroRecursively, GoToParentModule,
|
||||
GotoDefinitionKind, OpenDocs, RunFlycheck, element::register_action, hover_links::HoverLink,
|
||||
lsp_ext::find_specific_language_server_in_selection,
|
||||
};
|
||||
|
||||
@@ -30,11 +35,97 @@ pub fn apply_related_actions(editor: &Entity<Editor>, window: &mut Window, cx: &
|
||||
.filter_map(|buffer| buffer.read(cx).language())
|
||||
.any(|language| is_rust_language(language))
|
||||
{
|
||||
register_action(&editor, window, go_to_parent_module);
|
||||
register_action(&editor, window, expand_macro_recursively);
|
||||
register_action(&editor, window, open_docs);
|
||||
register_action(&editor, window, cancel_flycheck_action);
|
||||
register_action(&editor, window, run_flycheck_action);
|
||||
register_action(&editor, window, clear_flycheck_action);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn go_to_parent_module(
|
||||
editor: &mut Editor,
|
||||
_: &GoToParentModule,
|
||||
window: &mut Window,
|
||||
cx: &mut Context<Editor>,
|
||||
) {
|
||||
if editor.selections.count() == 0 {
|
||||
return;
|
||||
}
|
||||
let Some(project) = &editor.project else {
|
||||
return;
|
||||
};
|
||||
|
||||
let server_lookup = find_specific_language_server_in_selection(
|
||||
editor,
|
||||
cx,
|
||||
is_rust_language,
|
||||
RUST_ANALYZER_NAME,
|
||||
);
|
||||
|
||||
let project = project.clone();
|
||||
let lsp_store = project.read(cx).lsp_store();
|
||||
let upstream_client = lsp_store.read(cx).upstream_client();
|
||||
cx.spawn_in(window, async move |editor, cx| {
|
||||
let Some((trigger_anchor, _, server_to_query, buffer)) = server_lookup.await else {
|
||||
return anyhow::Ok(());
|
||||
};
|
||||
|
||||
let location_links = if let Some((client, project_id)) = upstream_client {
|
||||
let buffer_id = buffer.update(cx, |buffer, _| buffer.remote_id())?;
|
||||
|
||||
let request = proto::LspExtGoToParentModule {
|
||||
project_id,
|
||||
buffer_id: buffer_id.to_proto(),
|
||||
position: Some(serialize_anchor(&trigger_anchor.text_anchor)),
|
||||
};
|
||||
let response = client
|
||||
.request(request)
|
||||
.await
|
||||
.context("lsp ext go to parent module proto request")?;
|
||||
futures::future::join_all(
|
||||
response
|
||||
.links
|
||||
.into_iter()
|
||||
.map(|link| location_link_from_proto(link, lsp_store.clone(), cx)),
|
||||
)
|
||||
.await
|
||||
.into_iter()
|
||||
.collect::<anyhow::Result<_>>()
|
||||
.context("go to parent module via collab")?
|
||||
} else {
|
||||
let buffer_snapshot = buffer.update(cx, |buffer, _| buffer.snapshot())?;
|
||||
let position = trigger_anchor.text_anchor.to_point_utf16(&buffer_snapshot);
|
||||
project
|
||||
.update(cx, |project, cx| {
|
||||
project.request_lsp(
|
||||
buffer,
|
||||
project::LanguageServerToQuery::Other(server_to_query),
|
||||
project::lsp_store::lsp_ext_command::GoToParentModule { position },
|
||||
cx,
|
||||
)
|
||||
})?
|
||||
.await
|
||||
.context("go to parent module")?
|
||||
};
|
||||
|
||||
editor
|
||||
.update_in(cx, |editor, window, cx| {
|
||||
editor.navigate_to_hover_links(
|
||||
Some(GotoDefinitionKind::Declaration),
|
||||
location_links.into_iter().map(HoverLink::Text).collect(),
|
||||
false,
|
||||
window,
|
||||
cx,
|
||||
)
|
||||
})?
|
||||
.await?;
|
||||
Ok(())
|
||||
})
|
||||
.detach_and_log_err(cx);
|
||||
}
|
||||
|
||||
pub fn expand_macro_recursively(
|
||||
editor: &mut Editor,
|
||||
_: &ExpandMacroRecursively,
|
||||
@@ -213,3 +304,87 @@ pub fn open_docs(editor: &mut Editor, _: &OpenDocs, window: &mut Window, cx: &mu
|
||||
})
|
||||
.detach_and_log_err(cx);
|
||||
}
|
||||
|
||||
fn cancel_flycheck_action(
|
||||
editor: &mut Editor,
|
||||
_: &CancelFlycheck,
|
||||
_: &mut Window,
|
||||
cx: &mut Context<Editor>,
|
||||
) {
|
||||
let Some(project) = &editor.project else {
|
||||
return;
|
||||
};
|
||||
let Some(buffer_id) = editor
|
||||
.selections
|
||||
.disjoint_anchors()
|
||||
.iter()
|
||||
.find_map(|selection| {
|
||||
let buffer_id = selection.start.buffer_id.or(selection.end.buffer_id)?;
|
||||
let project = project.read(cx);
|
||||
let entry_id = project
|
||||
.buffer_for_id(buffer_id, cx)?
|
||||
.read(cx)
|
||||
.entry_id(cx)?;
|
||||
project.path_for_entry(entry_id, cx)
|
||||
})
|
||||
else {
|
||||
return;
|
||||
};
|
||||
cancel_flycheck(project.clone(), buffer_id, cx).detach_and_log_err(cx);
|
||||
}
|
||||
|
||||
fn run_flycheck_action(
|
||||
editor: &mut Editor,
|
||||
_: &RunFlycheck,
|
||||
_: &mut Window,
|
||||
cx: &mut Context<Editor>,
|
||||
) {
|
||||
let Some(project) = &editor.project else {
|
||||
return;
|
||||
};
|
||||
let Some(buffer_id) = editor
|
||||
.selections
|
||||
.disjoint_anchors()
|
||||
.iter()
|
||||
.find_map(|selection| {
|
||||
let buffer_id = selection.start.buffer_id.or(selection.end.buffer_id)?;
|
||||
let project = project.read(cx);
|
||||
let entry_id = project
|
||||
.buffer_for_id(buffer_id, cx)?
|
||||
.read(cx)
|
||||
.entry_id(cx)?;
|
||||
project.path_for_entry(entry_id, cx)
|
||||
})
|
||||
else {
|
||||
return;
|
||||
};
|
||||
run_flycheck(project.clone(), buffer_id, cx).detach_and_log_err(cx);
|
||||
}
|
||||
|
||||
fn clear_flycheck_action(
|
||||
editor: &mut Editor,
|
||||
_: &ClearFlycheck,
|
||||
_: &mut Window,
|
||||
cx: &mut Context<Editor>,
|
||||
) {
|
||||
let Some(project) = &editor.project else {
|
||||
return;
|
||||
};
|
||||
let Some(buffer_id) = editor
|
||||
.selections
|
||||
.disjoint_anchors()
|
||||
.iter()
|
||||
.find_map(|selection| {
|
||||
let buffer_id = selection.start.buffer_id.or(selection.end.buffer_id)?;
|
||||
let project = project.read(cx);
|
||||
let entry_id = project
|
||||
.buffer_for_id(buffer_id, cx)?
|
||||
.read(cx)
|
||||
.entry_id(cx)?;
|
||||
project.path_for_entry(entry_id, cx)
|
||||
})
|
||||
else {
|
||||
return;
|
||||
};
|
||||
clear_flycheck(project.clone(), buffer_id, cx).detach_and_log_err(cx);
|
||||
}
|
||||
|
||||
@@ -48,6 +48,7 @@ markdown.workspace = true
|
||||
node_runtime.workspace = true
|
||||
pathdiff.workspace = true
|
||||
paths.workspace = true
|
||||
pretty_assertions.workspace = true
|
||||
project.workspace = true
|
||||
prompt_store.workspace = true
|
||||
regex.workspace = true
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
{
|
||||
"assistant": {
|
||||
"always_allow_tool_actions": true,
|
||||
"stream_edits": true,
|
||||
"version": "2"
|
||||
}
|
||||
}
|
||||
|
||||
@@ -52,10 +52,10 @@ struct Args {
|
||||
#[arg(long, value_delimiter = ',', default_value = "rs,ts")]
|
||||
languages: Vec<String>,
|
||||
/// How many times to run each example.
|
||||
#[arg(long, default_value = "1")]
|
||||
#[arg(long, default_value = "8")]
|
||||
repetitions: usize,
|
||||
/// Maximum number of examples to run concurrently.
|
||||
#[arg(long, default_value = "10")]
|
||||
#[arg(long, default_value = "4")]
|
||||
concurrency: usize,
|
||||
}
|
||||
|
||||
@@ -420,12 +420,18 @@ pub fn init(cx: &mut App) -> Arc<AgentAppState> {
|
||||
language_model::init(client.clone(), cx);
|
||||
language_models::init(user_store.clone(), client.clone(), fs.clone(), cx);
|
||||
languages::init(languages.clone(), node_runtime.clone(), cx);
|
||||
assistant_tools::init(client.http_client(), cx);
|
||||
context_server::init(cx);
|
||||
prompt_store::init(cx);
|
||||
let stdout_is_a_pty = false;
|
||||
let prompt_builder = PromptBuilder::load(fs.clone(), stdout_is_a_pty, cx);
|
||||
agent::init(fs.clone(), client.clone(), prompt_builder.clone(), cx);
|
||||
agent::init(
|
||||
fs.clone(),
|
||||
client.clone(),
|
||||
prompt_builder.clone(),
|
||||
languages.clone(),
|
||||
cx,
|
||||
);
|
||||
assistant_tools::init(client.http_client(), cx);
|
||||
|
||||
SettingsStore::update_global(cx, |store, cx| {
|
||||
store.set_user_settings(include_str!("../runner_settings.json"), cx)
|
||||
|
||||
@@ -160,7 +160,11 @@ impl ExampleContext {
|
||||
if left == right {
|
||||
Ok(())
|
||||
} else {
|
||||
println!("{}{:#?} != {:#?}", self.log_prefix, left, right);
|
||||
println!(
|
||||
"{}{}",
|
||||
self.log_prefix,
|
||||
pretty_assertions::Comparison::new(&left, &right)
|
||||
);
|
||||
Err(anyhow::Error::from(FailedAssertion(message.clone())))
|
||||
},
|
||||
message,
|
||||
@@ -334,8 +338,8 @@ impl ExampleContext {
|
||||
}
|
||||
|
||||
pub fn edits(&self) -> HashMap<Arc<Path>, FileEdits> {
|
||||
self.app
|
||||
.read_entity(&self.agent_thread, |thread, cx| {
|
||||
self.agent_thread
|
||||
.read_with(&self.app, |thread, cx| {
|
||||
let action_log = thread.action_log().read(cx);
|
||||
HashMap::from_iter(action_log.changed_buffers(cx).into_iter().map(
|
||||
|(buffer, diff)| {
|
||||
@@ -503,16 +507,16 @@ impl ToolUse {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
#[derive(Debug, Eq, PartialEq)]
|
||||
pub struct FileEdits {
|
||||
hunks: Vec<FileEditHunk>,
|
||||
pub hunks: Vec<FileEditHunk>,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct FileEditHunk {
|
||||
base_text: String,
|
||||
text: String,
|
||||
status: DiffHunkStatus,
|
||||
#[derive(Debug, Eq, PartialEq)]
|
||||
pub struct FileEditHunk {
|
||||
pub base_text: String,
|
||||
pub text: String,
|
||||
pub status: DiffHunkStatus,
|
||||
}
|
||||
|
||||
impl FileEdits {
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user