Compare commits
24 Commits
from-str-f
...
follow-age
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
7ab62666c3 | ||
|
|
35539847a4 | ||
|
|
f619d5f02a | ||
|
|
ba59305510 | ||
|
|
672a1dd553 | ||
|
|
93cc4946d8 | ||
|
|
0c0a4ed866 | ||
|
|
51f1998107 | ||
|
|
1ffedf4a08 | ||
|
|
d25da9728b | ||
|
|
e1e3f2e423 | ||
|
|
92b9ecd7d2 | ||
|
|
758d260cec | ||
|
|
8d4d3badf3 | ||
|
|
7c23d13773 | ||
|
|
ad87c545c7 | ||
|
|
23fbab15ee | ||
|
|
d7e181576e | ||
|
|
9788aff4b1 | ||
|
|
2a319efade | ||
|
|
50ec26c163 | ||
|
|
39dd133b1c | ||
|
|
24eb039752 | ||
|
|
bffa53d706 |
156
Cargo.lock
generated
156
Cargo.lock
generated
@@ -68,6 +68,7 @@ dependencies = [
|
||||
"convert_case 0.8.0",
|
||||
"db",
|
||||
"editor",
|
||||
"extension",
|
||||
"feature_flags",
|
||||
"file_icons",
|
||||
"fs",
|
||||
@@ -81,6 +82,7 @@ dependencies = [
|
||||
"indexmap",
|
||||
"indoc",
|
||||
"itertools 0.14.0",
|
||||
"jsonschema",
|
||||
"language",
|
||||
"language_model",
|
||||
"language_model_selector",
|
||||
@@ -90,6 +92,7 @@ dependencies = [
|
||||
"markdown",
|
||||
"menu",
|
||||
"multi_buffer",
|
||||
"notifications",
|
||||
"ordered-float 2.10.1",
|
||||
"parking_lot",
|
||||
"paths",
|
||||
@@ -106,6 +109,7 @@ dependencies = [
|
||||
"schemars",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"serde_json_lenient",
|
||||
"settings",
|
||||
"smallvec",
|
||||
"smol",
|
||||
@@ -148,7 +152,9 @@ checksum = "e89da841a80418a9b391ebaea17f5c112ffaaa96f621d2c285b5174da76b9011"
|
||||
dependencies = [
|
||||
"cfg-if",
|
||||
"const-random",
|
||||
"getrandom 0.2.15",
|
||||
"once_cell",
|
||||
"serde",
|
||||
"version_check",
|
||||
"zerocopy 0.7.35",
|
||||
]
|
||||
@@ -2186,6 +2192,12 @@ dependencies = [
|
||||
"piper",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "borrow-or-share"
|
||||
version = "0.2.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "3eeab4423108c5d7c744f4d234de88d18d636100093ae04caf4825134b9c3a32"
|
||||
|
||||
[[package]]
|
||||
name = "borsh"
|
||||
version = "1.5.7"
|
||||
@@ -2301,6 +2313,12 @@ dependencies = [
|
||||
"syn 1.0.109",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "bytecount"
|
||||
version = "0.6.8"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "5ce89b21cab1437276d2650d57e971f9d548a2d9037cc231abdc0562b97498ce"
|
||||
|
||||
[[package]]
|
||||
name = "bytemuck"
|
||||
version = "1.22.0"
|
||||
@@ -3221,7 +3239,9 @@ dependencies = [
|
||||
name = "component_preview"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"agent",
|
||||
"anyhow",
|
||||
"assistant_tool",
|
||||
"client",
|
||||
"collections",
|
||||
"component",
|
||||
@@ -3231,6 +3251,7 @@ dependencies = [
|
||||
"log",
|
||||
"notifications",
|
||||
"project",
|
||||
"prompt_store",
|
||||
"serde",
|
||||
"ui",
|
||||
"ui_input",
|
||||
@@ -4378,7 +4399,6 @@ name = "diagnostics"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"cargo_metadata",
|
||||
"client",
|
||||
"collections",
|
||||
"component",
|
||||
@@ -4388,7 +4408,6 @@ dependencies = [
|
||||
"futures 0.3.31",
|
||||
"gpui",
|
||||
"indoc",
|
||||
"itertools 0.14.0",
|
||||
"language",
|
||||
"linkme",
|
||||
"log",
|
||||
@@ -4400,7 +4419,6 @@ dependencies = [
|
||||
"serde",
|
||||
"serde_json",
|
||||
"settings",
|
||||
"smol",
|
||||
"text",
|
||||
"theme",
|
||||
"ui",
|
||||
@@ -4783,6 +4801,15 @@ dependencies = [
|
||||
"zeroize",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "email_address"
|
||||
version = "0.2.9"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e079f19b08ca6239f47f8ba8509c11cf3ea30095831f7fed61441475edd8c449"
|
||||
dependencies = [
|
||||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "embed-resource"
|
||||
version = "3.0.2"
|
||||
@@ -5198,6 +5225,7 @@ dependencies = [
|
||||
"collections",
|
||||
"db",
|
||||
"editor",
|
||||
"extension",
|
||||
"extension_host",
|
||||
"fs",
|
||||
"fuzzy",
|
||||
@@ -5430,6 +5458,17 @@ version = "1.0.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8bf7cc16383c4b8d58b9905a8509f02926ce3058053c056376248d958c9df1e8"
|
||||
|
||||
[[package]]
|
||||
name = "fluent-uri"
|
||||
version = "0.3.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1918b65d96df47d3591bed19c5cca17e3fa5d0707318e4b5ef2eae01764df7e5"
|
||||
dependencies = [
|
||||
"borrow-or-share",
|
||||
"ref-cast",
|
||||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "flume"
|
||||
version = "0.11.1"
|
||||
@@ -5584,6 +5623,16 @@ dependencies = [
|
||||
"percent-encoding",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "fraction"
|
||||
version = "0.15.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0f158e3ff0a1b334408dc9fb811cd99b446986f4d8b741bb08f9df1604085ae7"
|
||||
dependencies = [
|
||||
"lazy_static",
|
||||
"num",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "freetype-sys"
|
||||
version = "0.20.1"
|
||||
@@ -7587,6 +7636,33 @@ dependencies = [
|
||||
"wasm-bindgen",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "jsonschema"
|
||||
version = "0.30.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f1b46a0365a611fbf1d2143104dcf910aada96fafd295bab16c60b802bf6fa1d"
|
||||
dependencies = [
|
||||
"ahash 0.8.11",
|
||||
"base64 0.22.1",
|
||||
"bytecount",
|
||||
"email_address",
|
||||
"fancy-regex 0.14.0",
|
||||
"fraction",
|
||||
"idna",
|
||||
"itoa",
|
||||
"num-cmp",
|
||||
"num-traits",
|
||||
"once_cell",
|
||||
"percent-encoding",
|
||||
"referencing",
|
||||
"regex",
|
||||
"regex-syntax 0.8.5",
|
||||
"reqwest 0.12.15 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"uuid-simd",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "jsonwebtoken"
|
||||
version = "9.3.1"
|
||||
@@ -8271,7 +8347,7 @@ dependencies = [
|
||||
"prost 0.9.0",
|
||||
"prost-build 0.9.0",
|
||||
"prost-types 0.9.0",
|
||||
"reqwest 0.12.15",
|
||||
"reqwest 0.12.15 (git+https://github.com/zed-industries/reqwest.git?rev=951c770a32f1998d6e999cef3e59e0013e6c4415)",
|
||||
"serde",
|
||||
"workspace-hack",
|
||||
]
|
||||
@@ -9181,6 +9257,12 @@ dependencies = [
|
||||
"zeroize",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "num-cmp"
|
||||
version = "0.1.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "63335b2e2c34fae2fb0aa2cecfd9f0832a1e24b3b32ecec612c3426d46dc8aaa"
|
||||
|
||||
[[package]]
|
||||
name = "num-complex"
|
||||
version = "0.4.6"
|
||||
@@ -11774,6 +11856,20 @@ dependencies = [
|
||||
"syn 2.0.100",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "referencing"
|
||||
version = "0.30.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c8eff4fa778b5c2a57e85c5f2fe3a709c52f0e60d23146e2151cbef5893f420e"
|
||||
dependencies = [
|
||||
"ahash 0.8.11",
|
||||
"fluent-uri",
|
||||
"once_cell",
|
||||
"parking_lot",
|
||||
"percent-encoding",
|
||||
"serde_json",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "refineable"
|
||||
version = "0.1.0"
|
||||
@@ -12043,6 +12139,43 @@ dependencies = [
|
||||
"winreg 0.50.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "reqwest"
|
||||
version = "0.12.15"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d19c46a6fdd48bc4dab94b6103fccc55d34c67cc0ad04653aad4ea2a07cd7bbb"
|
||||
dependencies = [
|
||||
"base64 0.22.1",
|
||||
"bytes 1.10.1",
|
||||
"futures-channel",
|
||||
"futures-core",
|
||||
"futures-util",
|
||||
"http 1.3.1",
|
||||
"http-body 1.0.1",
|
||||
"http-body-util",
|
||||
"hyper 1.6.0",
|
||||
"hyper-util",
|
||||
"ipnet",
|
||||
"js-sys",
|
||||
"log",
|
||||
"mime",
|
||||
"once_cell",
|
||||
"percent-encoding",
|
||||
"pin-project-lite",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"serde_urlencoded",
|
||||
"sync_wrapper 1.0.2",
|
||||
"tokio",
|
||||
"tower 0.5.2",
|
||||
"tower-service",
|
||||
"url",
|
||||
"wasm-bindgen",
|
||||
"wasm-bindgen-futures",
|
||||
"web-sys",
|
||||
"windows-registry 0.4.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "reqwest"
|
||||
version = "0.12.15"
|
||||
@@ -12103,7 +12236,7 @@ dependencies = [
|
||||
"http_client_tls",
|
||||
"log",
|
||||
"regex",
|
||||
"reqwest 0.12.15",
|
||||
"reqwest 0.12.15 (git+https://github.com/zed-industries/reqwest.git?rev=951c770a32f1998d6e999cef3e59e0013e6c4415)",
|
||||
"serde",
|
||||
"smol",
|
||||
"tokio",
|
||||
@@ -15954,6 +16087,17 @@ dependencies = [
|
||||
"sha1_smol",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "uuid-simd"
|
||||
version = "0.8.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "23b082222b4f6619906941c17eb2297fff4c2fb96cb60164170522942a200bd8"
|
||||
dependencies = [
|
||||
"outref",
|
||||
"uuid",
|
||||
"vsimd",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "v_frame"
|
||||
version = "0.3.8"
|
||||
@@ -18054,6 +18198,7 @@ dependencies = [
|
||||
"hmac",
|
||||
"hyper 0.14.32",
|
||||
"hyper-rustls 0.27.5",
|
||||
"idna",
|
||||
"indexmap",
|
||||
"inout",
|
||||
"itertools 0.12.1",
|
||||
@@ -18077,6 +18222,7 @@ dependencies = [
|
||||
"num-bigint-dig",
|
||||
"num-integer",
|
||||
"num-iter",
|
||||
"num-rational",
|
||||
"num-traits",
|
||||
"object",
|
||||
"once_cell",
|
||||
|
||||
@@ -462,6 +462,7 @@ indexmap = { version = "2.7.0", features = ["serde"] }
|
||||
indoc = "2"
|
||||
inventory = "0.3.19"
|
||||
itertools = "0.14.0"
|
||||
jsonschema = "0.30.0"
|
||||
jsonwebtoken = "9.3"
|
||||
jupyter-protocol = { git = "https://github.com/ConradIrwin/runtimed", rev = "7130c804216b6914355d15d0b91ea91f6babd734" }
|
||||
jupyter-websocket-client = { git = "https://github.com/ConradIrwin/runtimed" ,rev = "7130c804216b6914355d15d0b91ea91f6babd734" }
|
||||
|
||||
1
assets/icons/hammer.svg
Normal file
1
assets/icons/hammer.svg
Normal file
@@ -0,0 +1 @@
|
||||
<svg xmlns="http://www.w3.org/2000/svg" width="24" height="24" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round" class="lucide lucide-hammer-icon lucide-hammer"><path d="m15 12-8.373 8.373a1 1 0 1 1-3-3L12 9"/><path d="m18 15 4-4"/><path d="m21.5 11.5-1.914-1.914A2 2 0 0 1 19 8.172V7l-2.26-2.26a6 6 0 0 0-4.202-1.756L9 2.96l.92.82A6.18 6.18 0 0 1 12 8.4V10l2 2h1.172a2 2 0 0 1 1.414.586L18.5 14.5"/></svg>
|
||||
|
After Width: | Height: | Size: 475 B |
@@ -248,7 +248,6 @@
|
||||
"ctrl-shift-o": "agent::ToggleNavigationMenu",
|
||||
"ctrl-shift-i": "agent::ToggleOptionsMenu",
|
||||
"shift-escape": "agent::ExpandMessageEditor",
|
||||
"ctrl-e": "agent::ChatMode",
|
||||
"ctrl-alt-e": "agent::RemoveAllContext"
|
||||
}
|
||||
},
|
||||
@@ -963,6 +962,14 @@
|
||||
"escape": "menu::Cancel"
|
||||
}
|
||||
},
|
||||
{
|
||||
"context": "ConfigureContextServerModal > Editor",
|
||||
"bindings": {
|
||||
"escape": "menu::Cancel",
|
||||
"enter": "editor::Newline",
|
||||
"ctrl-enter": "menu::Confirm"
|
||||
}
|
||||
},
|
||||
{
|
||||
"context": "Diagnostics",
|
||||
"use_key_equivalents": true,
|
||||
|
||||
@@ -293,7 +293,6 @@
|
||||
"cmd-shift-o": "agent::ToggleNavigationMenu",
|
||||
"cmd-shift-i": "agent::ToggleOptionsMenu",
|
||||
"shift-escape": "agent::ExpandMessageEditor",
|
||||
"cmd-e": "agent::ChatMode",
|
||||
"cmd-alt-e": "agent::RemoveAllContext"
|
||||
}
|
||||
},
|
||||
@@ -1069,6 +1068,15 @@
|
||||
"escape": "menu::Cancel"
|
||||
}
|
||||
},
|
||||
{
|
||||
"context": "ConfigureContextServerModal > Editor",
|
||||
"use_key_equivalents": true,
|
||||
"bindings": {
|
||||
"escape": "menu::Cancel",
|
||||
"enter": "editor::Newline",
|
||||
"cmd-enter": "menu::Confirm"
|
||||
}
|
||||
},
|
||||
{
|
||||
"context": "Diagnostics",
|
||||
"use_key_equivalents": true,
|
||||
|
||||
@@ -3,11 +3,12 @@ You are a highly skilled software engineer with extensive knowledge in many prog
|
||||
## Communication
|
||||
|
||||
1. Be conversational but professional.
|
||||
2. Refer to the USER in the second person and yourself in the first person.
|
||||
2. Refer to the user in the second person and yourself in the first person.
|
||||
3. Format your responses in markdown. Use backticks to format file, directory, function, and class names.
|
||||
4. NEVER lie or make things up.
|
||||
5. Refrain from apologizing all the time when results are unexpected. Instead, just try your best to proceed or explain the circumstances to the user without apologizing.
|
||||
|
||||
{{#if has_tools}}
|
||||
## Tool Use
|
||||
|
||||
1. Make sure to adhere to the tools schema.
|
||||
@@ -22,6 +23,7 @@ You are a highly skilled software engineer with extensive knowledge in many prog
|
||||
If you are unsure how to fulfill the user's request, gather more information with tool calls and/or clarifying questions.
|
||||
|
||||
{{! TODO: If there are files, we should mention it but otherwise omit that fact }}
|
||||
{{#if has_tools}}
|
||||
If appropriate, use tool calls to explore the current project, which contains the following root directories:
|
||||
|
||||
{{#each worktrees}}
|
||||
@@ -36,6 +38,14 @@ If appropriate, use tool calls to explore the current project, which contains th
|
||||
- As you learn about the structure of the project, use that information to scope `grep` searches to targeted subtrees of the project.
|
||||
- The user might specify a partial file path. If you don't know the full path, use `find_path` (not `grep`) before you read the file.
|
||||
{{/if}}
|
||||
{{/if}}
|
||||
{{else}}
|
||||
You are being tasked with providing a response, but you have no ability to use tools or to read or write any aspect of the user's system (other than any context the user might have provided to you).
|
||||
|
||||
As such, if you need the user to perform any actions for you, you must request them explicitly. Bias towards giving a response to the best of your ability, and then making requests for the user to take action (e.g. to give you more context) only optionally.
|
||||
|
||||
The one exception to this is if the user references something you don't know about - for example, the name of a source code file, function, type, or other piece of code that you have no awareness of. In this case, you MUST NOT MAKE SOMETHING UP, or assume you know what that thing is or how it works. Instead, you must ask the user for clarification rather than giving a response.
|
||||
{{/if}}
|
||||
|
||||
## Code Block Formatting
|
||||
|
||||
@@ -111,6 +121,8 @@ In Markdown, hash marks signify headings. For example:
|
||||
```
|
||||
</bad_example_do_not_do_this>
|
||||
This example is unacceptable because the path is in the wrong place. The path must be directly after the opening backticks.
|
||||
|
||||
{{#if has_tools}}
|
||||
## Fixing Diagnostics
|
||||
|
||||
1. Make 1-2 attempts at fixing diagnostics, then defer to the user.
|
||||
@@ -124,10 +136,11 @@ Otherwise, follow debugging best practices:
|
||||
2. Add descriptive logging statements and error messages to track variable and code state.
|
||||
3. Add test functions and statements to isolate the problem.
|
||||
|
||||
{{/if}}
|
||||
## Calling External APIs
|
||||
|
||||
1. Unless explicitly requested by the user, use the best suited external APIs and packages to solve the task. There is no need to ask the user for permission.
|
||||
2. When selecting which version of an API or package to use, choose one that is compatible with the user's dependency management file. If no such file exists or if the package is not present, use the latest version that is in your training data.
|
||||
2. When selecting which version of an API or package to use, choose one that is compatible with the user's dependency management file(s). If no such file exists or if the package is not present, use the latest version that is in your training data.
|
||||
3. If an external API requires an API Key, be sure to point this out to the user. Adhere to best security practices (e.g. DO NOT hardcode an API key in a place where it can be exposed)
|
||||
|
||||
## System Information
|
||||
@@ -135,10 +148,10 @@ Otherwise, follow debugging best practices:
|
||||
Operating System: {{os}}
|
||||
Default Shell: {{shell}}
|
||||
|
||||
{{#if (or has_rules has_default_user_rules)}}
|
||||
{{#if (or has_rules has_user_rules)}}
|
||||
## User's Custom Instructions
|
||||
|
||||
The following additional instructions are provided by the user, and should be followed to the best of your ability without interfering with the tool use guidelines.
|
||||
The following additional instructions are provided by the user, and should be followed to the best of your ability{{#if has_tools}} without interfering with the tool use guidelines{{/if}}.
|
||||
|
||||
{{#if has_rules}}
|
||||
There are project rules that apply to these root directories:
|
||||
|
||||
@@ -935,22 +935,9 @@
|
||||
"max_severity": null
|
||||
},
|
||||
"rust": {
|
||||
// When enabled, Zed runs `cargo check --message-format=json`-based commands and
|
||||
// collect cargo diagnostics instead of rust-analyzer.
|
||||
"fetch_cargo_diagnostics": false,
|
||||
// A command override for fetching the cargo diagnostics.
|
||||
// First argument is the command, followed by the arguments.
|
||||
"diagnostics_fetch_command": [
|
||||
"cargo",
|
||||
"check",
|
||||
"--quiet",
|
||||
"--workspace",
|
||||
"--message-format=json",
|
||||
"--all-targets",
|
||||
"--keep-going"
|
||||
],
|
||||
// Extra environment variables to pass to the diagnostics fetch command.
|
||||
"env": {}
|
||||
// When enabled, Zed disables rust-analyzer's check on save and starts to query
|
||||
// Cargo diagnostics separately.
|
||||
"fetch_cargo_diagnostics": false
|
||||
}
|
||||
},
|
||||
// Files or globs of files that will be excluded by Zed entirely. They will be skipped during file
|
||||
|
||||
@@ -35,6 +35,7 @@ context_server.workspace = true
|
||||
convert_case.workspace = true
|
||||
db.workspace = true
|
||||
editor.workspace = true
|
||||
extension.workspace = true
|
||||
feature_flags.workspace = true
|
||||
file_icons.workspace = true
|
||||
fs.workspace = true
|
||||
@@ -47,6 +48,7 @@ html_to_markdown.workspace = true
|
||||
http_client.workspace = true
|
||||
indexmap.workspace = true
|
||||
itertools.workspace = true
|
||||
jsonschema.workspace = true
|
||||
language.workspace = true
|
||||
language_model.workspace = true
|
||||
language_model_selector.workspace = true
|
||||
@@ -56,6 +58,7 @@ lsp.workspace = true
|
||||
markdown.workspace = true
|
||||
menu.workspace = true
|
||||
multi_buffer.workspace = true
|
||||
notifications.workspace = true
|
||||
ordered-float.workspace = true
|
||||
parking_lot.workspace = true
|
||||
paths.workspace = true
|
||||
@@ -71,6 +74,7 @@ rope.workspace = true
|
||||
schemars.workspace = true
|
||||
serde.workspace = true
|
||||
serde_json.workspace = true
|
||||
serde_json_lenient.workspace = true
|
||||
settings.workspace = true
|
||||
smallvec.workspace = true
|
||||
smol.workspace = true
|
||||
|
||||
@@ -6,6 +6,7 @@ mod assistant_panel;
|
||||
mod buffer_codegen;
|
||||
mod context;
|
||||
mod context_picker;
|
||||
mod context_server_configuration;
|
||||
mod context_store;
|
||||
mod context_strip;
|
||||
mod history_store;
|
||||
@@ -30,6 +31,7 @@ use command_palette_hooks::CommandPaletteFilter;
|
||||
use feature_flags::{Assistant2FeatureFlag, FeatureFlagAppExt};
|
||||
use fs::Fs;
|
||||
use gpui::{App, actions, impl_actions};
|
||||
use language::LanguageRegistry;
|
||||
use prompt_store::PromptBuilder;
|
||||
use schemars::JsonSchema;
|
||||
use serde::Deserialize;
|
||||
@@ -44,6 +46,8 @@ pub use crate::inline_assistant::InlineAssistant;
|
||||
pub use crate::thread::{Message, MessageSegment, Thread, ThreadEvent};
|
||||
pub use crate::thread_store::ThreadStore;
|
||||
pub use agent_diff::{AgentDiff, AgentDiffToolbar};
|
||||
pub use context_store::ContextStore;
|
||||
pub use ui::{all_agent_previews, get_agent_preview};
|
||||
|
||||
actions!(
|
||||
agent,
|
||||
@@ -60,7 +64,6 @@ actions!(
|
||||
AddContextServer,
|
||||
RemoveSelectedThread,
|
||||
Chat,
|
||||
ChatMode,
|
||||
CycleNextInlineAssist,
|
||||
CyclePreviousInlineAssist,
|
||||
FocusUp,
|
||||
@@ -107,11 +110,13 @@ pub fn init(
|
||||
fs: Arc<dyn Fs>,
|
||||
client: Arc<Client>,
|
||||
prompt_builder: Arc<PromptBuilder>,
|
||||
language_registry: Arc<LanguageRegistry>,
|
||||
cx: &mut App,
|
||||
) {
|
||||
AssistantSettings::register(cx);
|
||||
thread_store::init(cx);
|
||||
assistant_panel::init(cx);
|
||||
context_server_configuration::init(language_registry, cx);
|
||||
|
||||
inline_assistant::init(
|
||||
fs.clone(),
|
||||
|
||||
@@ -1,16 +1,18 @@
|
||||
mod add_context_server_modal;
|
||||
mod configure_context_server_modal;
|
||||
mod manage_profiles_modal;
|
||||
mod tool_picker;
|
||||
|
||||
use std::sync::Arc;
|
||||
use std::{sync::Arc, time::Duration};
|
||||
|
||||
use assistant_settings::AssistantSettings;
|
||||
use assistant_tool::{ToolSource, ToolWorkingSet};
|
||||
use collections::HashMap;
|
||||
use context_server::manager::ContextServerManager;
|
||||
use context_server::manager::{ContextServer, ContextServerManager, ContextServerStatus};
|
||||
use fs::Fs;
|
||||
use gpui::{
|
||||
Action, AnyView, App, Entity, EventEmitter, FocusHandle, Focusable, ScrollHandle, Subscription,
|
||||
Action, Animation, AnimationExt as _, AnyView, App, Entity, EventEmitter, FocusHandle,
|
||||
Focusable, ScrollHandle, Subscription, pulsating_between,
|
||||
};
|
||||
use language_model::{LanguageModelProvider, LanguageModelProviderId, LanguageModelRegistry};
|
||||
use settings::{Settings, update_settings_file};
|
||||
@@ -22,6 +24,7 @@ use util::ResultExt as _;
|
||||
use zed_actions::ExtensionCategoryFilter;
|
||||
|
||||
pub(crate) use add_context_server_modal::AddContextServerModal;
|
||||
pub(crate) use configure_context_server_modal::ConfigureContextServerModal;
|
||||
pub(crate) use manage_profiles_modal::ManageProfilesModal;
|
||||
|
||||
use crate::AddContextServer;
|
||||
@@ -254,10 +257,12 @@ impl AssistantConfiguration {
|
||||
)
|
||||
}
|
||||
|
||||
fn render_context_servers_section(&mut self, cx: &mut Context<Self>) -> impl IntoElement {
|
||||
fn render_context_servers_section(
|
||||
&mut self,
|
||||
window: &mut Window,
|
||||
cx: &mut Context<Self>,
|
||||
) -> impl IntoElement {
|
||||
let context_servers = self.context_server_manager.read(cx).all_servers().clone();
|
||||
let tools_by_source = self.tools.read(cx).tools_by_source(cx);
|
||||
let empty = Vec::new();
|
||||
|
||||
const SUBHEADING: &str = "Connect to context servers via the Model Context Protocol either via Zed extensions or directly.";
|
||||
|
||||
@@ -272,136 +277,11 @@ impl AssistantConfiguration {
|
||||
.child(Headline::new("Model Context Protocol (MCP) Servers"))
|
||||
.child(Label::new(SUBHEADING).color(Color::Muted)),
|
||||
)
|
||||
.children(context_servers.into_iter().map(|context_server| {
|
||||
let is_running = context_server.client().is_some();
|
||||
let are_tools_expanded = self
|
||||
.expanded_context_server_tools
|
||||
.get(&context_server.id())
|
||||
.copied()
|
||||
.unwrap_or_default();
|
||||
|
||||
let tools = tools_by_source
|
||||
.get(&ToolSource::ContextServer {
|
||||
id: context_server.id().into(),
|
||||
})
|
||||
.unwrap_or_else(|| &empty);
|
||||
let tool_count = tools.len();
|
||||
|
||||
v_flex()
|
||||
.id(SharedString::from(context_server.id()))
|
||||
.border_1()
|
||||
.rounded_md()
|
||||
.border_color(cx.theme().colors().border)
|
||||
.bg(cx.theme().colors().background.opacity(0.25))
|
||||
.child(
|
||||
h_flex()
|
||||
.p_1()
|
||||
.justify_between()
|
||||
.when(are_tools_expanded && tool_count > 1, |element| {
|
||||
element
|
||||
.border_b_1()
|
||||
.border_color(cx.theme().colors().border)
|
||||
})
|
||||
.child(
|
||||
h_flex()
|
||||
.gap_2()
|
||||
.child(
|
||||
Disclosure::new("tool-list-disclosure", are_tools_expanded)
|
||||
.disabled(tool_count == 0)
|
||||
.on_click(cx.listener({
|
||||
let context_server_id = context_server.id();
|
||||
move |this, _event, _window, _cx| {
|
||||
let is_open = this
|
||||
.expanded_context_server_tools
|
||||
.entry(context_server_id.clone())
|
||||
.or_insert(false);
|
||||
|
||||
*is_open = !*is_open;
|
||||
}
|
||||
})),
|
||||
)
|
||||
.child(Indicator::dot().color(if is_running {
|
||||
Color::Success
|
||||
} else {
|
||||
Color::Error
|
||||
}))
|
||||
.child(Label::new(context_server.id()))
|
||||
.child(
|
||||
Label::new(format!("{tool_count} tools"))
|
||||
.color(Color::Muted)
|
||||
.size(LabelSize::Small),
|
||||
),
|
||||
)
|
||||
.child(
|
||||
Switch::new("context-server-switch", is_running.into())
|
||||
.color(SwitchColor::Accent)
|
||||
.on_click({
|
||||
let context_server_manager =
|
||||
self.context_server_manager.clone();
|
||||
let context_server = context_server.clone();
|
||||
move |state, _window, cx| match state {
|
||||
ToggleState::Unselected
|
||||
| ToggleState::Indeterminate => {
|
||||
context_server_manager.update(cx, |this, cx| {
|
||||
this.stop_server(context_server.clone(), cx)
|
||||
.log_err();
|
||||
});
|
||||
}
|
||||
ToggleState::Selected => {
|
||||
cx.spawn({
|
||||
let context_server_manager =
|
||||
context_server_manager.clone();
|
||||
let context_server = context_server.clone();
|
||||
async move |cx| {
|
||||
if let Some(start_server_task) =
|
||||
context_server_manager
|
||||
.update(cx, |this, cx| {
|
||||
this.start_server(
|
||||
context_server,
|
||||
cx,
|
||||
)
|
||||
})
|
||||
.log_err()
|
||||
{
|
||||
start_server_task.await.log_err();
|
||||
}
|
||||
}
|
||||
})
|
||||
.detach();
|
||||
}
|
||||
}
|
||||
}),
|
||||
),
|
||||
)
|
||||
.map(|parent| {
|
||||
if !are_tools_expanded {
|
||||
return parent;
|
||||
}
|
||||
|
||||
parent.child(v_flex().py_1p5().px_1().gap_1().children(
|
||||
tools.into_iter().enumerate().map(|(ix, tool)| {
|
||||
h_flex()
|
||||
.id(("tool-item", ix))
|
||||
.px_1()
|
||||
.gap_2()
|
||||
.justify_between()
|
||||
.hover(|style| style.bg(cx.theme().colors().element_hover))
|
||||
.rounded_sm()
|
||||
.child(
|
||||
Label::new(tool.name())
|
||||
.buffer_font(cx)
|
||||
.size(LabelSize::Small),
|
||||
)
|
||||
.child(
|
||||
Icon::new(IconName::Info)
|
||||
.size(IconSize::Small)
|
||||
.color(Color::Ignored),
|
||||
)
|
||||
.tooltip(Tooltip::text(tool.description()))
|
||||
}),
|
||||
))
|
||||
})
|
||||
}))
|
||||
.children(
|
||||
context_servers
|
||||
.into_iter()
|
||||
.map(|context_server| self.render_context_server(context_server, window, cx)),
|
||||
)
|
||||
.child(
|
||||
h_flex()
|
||||
.justify_between()
|
||||
@@ -429,7 +309,7 @@ impl AssistantConfiguration {
|
||||
.style(ButtonStyle::Filled)
|
||||
.layer(ElevationIndex::ModalSurface)
|
||||
.full_width()
|
||||
.icon(IconName::DatabaseZap)
|
||||
.icon(IconName::Hammer)
|
||||
.icon_size(IconSize::Small)
|
||||
.icon_position(IconPosition::Start)
|
||||
.on_click(|_event, window, cx| {
|
||||
@@ -447,10 +327,214 @@ impl AssistantConfiguration {
|
||||
),
|
||||
)
|
||||
}
|
||||
|
||||
fn render_context_server(
|
||||
&self,
|
||||
context_server: Arc<ContextServer>,
|
||||
window: &mut Window,
|
||||
cx: &mut Context<Self>,
|
||||
) -> impl use<> + IntoElement {
|
||||
let tools_by_source = self.tools.read(cx).tools_by_source(cx);
|
||||
let server_status = self
|
||||
.context_server_manager
|
||||
.read(cx)
|
||||
.status_for_server(&context_server.id());
|
||||
|
||||
let is_running = matches!(server_status, Some(ContextServerStatus::Running));
|
||||
|
||||
let error = if let Some(ContextServerStatus::Error(error)) = server_status.clone() {
|
||||
Some(error)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let are_tools_expanded = self
|
||||
.expanded_context_server_tools
|
||||
.get(&context_server.id())
|
||||
.copied()
|
||||
.unwrap_or_default();
|
||||
|
||||
let tools = tools_by_source
|
||||
.get(&ToolSource::ContextServer {
|
||||
id: context_server.id().into(),
|
||||
})
|
||||
.map_or([].as_slice(), |tools| tools.as_slice());
|
||||
let tool_count = tools.len();
|
||||
|
||||
let border_color = cx.theme().colors().border.opacity(0.6);
|
||||
|
||||
v_flex()
|
||||
.id(SharedString::from(context_server.id()))
|
||||
.border_1()
|
||||
.rounded_md()
|
||||
.border_color(border_color)
|
||||
.bg(cx.theme().colors().background.opacity(0.2))
|
||||
.overflow_hidden()
|
||||
.child(
|
||||
h_flex()
|
||||
.p_1()
|
||||
.justify_between()
|
||||
.when(
|
||||
error.is_some() || are_tools_expanded && tool_count > 1,
|
||||
|element| element.border_b_1().border_color(border_color),
|
||||
)
|
||||
.child(
|
||||
h_flex()
|
||||
.gap_1p5()
|
||||
.child(
|
||||
Disclosure::new(
|
||||
"tool-list-disclosure",
|
||||
are_tools_expanded || error.is_some(),
|
||||
)
|
||||
.disabled(tool_count == 0)
|
||||
.on_click(cx.listener({
|
||||
let context_server_id = context_server.id();
|
||||
move |this, _event, _window, _cx| {
|
||||
let is_open = this
|
||||
.expanded_context_server_tools
|
||||
.entry(context_server_id.clone())
|
||||
.or_insert(false);
|
||||
|
||||
*is_open = !*is_open;
|
||||
}
|
||||
})),
|
||||
)
|
||||
.child(match server_status {
|
||||
Some(ContextServerStatus::Starting) => {
|
||||
let color = Color::Success.color(cx);
|
||||
Indicator::dot()
|
||||
.color(Color::Success)
|
||||
.with_animation(
|
||||
SharedString::from(format!(
|
||||
"{}-starting",
|
||||
context_server.id(),
|
||||
)),
|
||||
Animation::new(Duration::from_secs(2))
|
||||
.repeat()
|
||||
.with_easing(pulsating_between(0.4, 1.)),
|
||||
move |this, delta| {
|
||||
this.color(color.alpha(delta).into())
|
||||
},
|
||||
)
|
||||
.into_any_element()
|
||||
}
|
||||
Some(ContextServerStatus::Running) => {
|
||||
Indicator::dot().color(Color::Success).into_any_element()
|
||||
}
|
||||
Some(ContextServerStatus::Error(_)) => {
|
||||
Indicator::dot().color(Color::Error).into_any_element()
|
||||
}
|
||||
None => Indicator::dot().color(Color::Muted).into_any_element(),
|
||||
})
|
||||
.child(Label::new(context_server.id()).ml_0p5())
|
||||
.when(is_running, |this| {
|
||||
this.child(
|
||||
Label::new(if tool_count == 1 {
|
||||
SharedString::from("1 tool")
|
||||
} else {
|
||||
SharedString::from(format!("{} tools", tool_count))
|
||||
})
|
||||
.color(Color::Muted)
|
||||
.size(LabelSize::Small),
|
||||
)
|
||||
}),
|
||||
)
|
||||
.child(
|
||||
Switch::new("context-server-switch", is_running.into())
|
||||
.color(SwitchColor::Accent)
|
||||
.on_click({
|
||||
let context_server_manager = self.context_server_manager.clone();
|
||||
let context_server = context_server.clone();
|
||||
move |state, _window, cx| match state {
|
||||
ToggleState::Unselected | ToggleState::Indeterminate => {
|
||||
context_server_manager.update(cx, |this, cx| {
|
||||
this.stop_server(context_server.clone(), cx).log_err();
|
||||
});
|
||||
}
|
||||
ToggleState::Selected => {
|
||||
cx.spawn({
|
||||
let context_server_manager =
|
||||
context_server_manager.clone();
|
||||
let context_server = context_server.clone();
|
||||
async move |cx| {
|
||||
if let Some(start_server_task) =
|
||||
context_server_manager
|
||||
.update(cx, |this, cx| {
|
||||
this.start_server(context_server, cx)
|
||||
})
|
||||
.log_err()
|
||||
{
|
||||
start_server_task.await.log_err();
|
||||
}
|
||||
}
|
||||
})
|
||||
.detach();
|
||||
}
|
||||
}
|
||||
}),
|
||||
),
|
||||
)
|
||||
.map(|parent| {
|
||||
if let Some(error) = error {
|
||||
return parent.child(
|
||||
h_flex()
|
||||
.p_2()
|
||||
.gap_2()
|
||||
.items_start()
|
||||
.child(
|
||||
h_flex()
|
||||
.flex_none()
|
||||
.h(window.line_height() / 1.6_f32)
|
||||
.justify_center()
|
||||
.child(
|
||||
Icon::new(IconName::XCircle)
|
||||
.size(IconSize::XSmall)
|
||||
.color(Color::Error),
|
||||
),
|
||||
)
|
||||
.child(
|
||||
div().w_full().child(
|
||||
Label::new(error)
|
||||
.buffer_font(cx)
|
||||
.color(Color::Muted)
|
||||
.size(LabelSize::Small),
|
||||
),
|
||||
),
|
||||
);
|
||||
}
|
||||
|
||||
if !are_tools_expanded || tools.is_empty() {
|
||||
return parent;
|
||||
}
|
||||
|
||||
parent.child(v_flex().py_1p5().px_1().gap_1().children(
|
||||
tools.into_iter().enumerate().map(|(ix, tool)| {
|
||||
h_flex()
|
||||
.id(("tool-item", ix))
|
||||
.px_1()
|
||||
.gap_2()
|
||||
.justify_between()
|
||||
.hover(|style| style.bg(cx.theme().colors().element_hover))
|
||||
.rounded_sm()
|
||||
.child(
|
||||
Label::new(tool.name())
|
||||
.buffer_font(cx)
|
||||
.size(LabelSize::Small),
|
||||
)
|
||||
.child(
|
||||
Icon::new(IconName::Info)
|
||||
.size(IconSize::Small)
|
||||
.color(Color::Ignored),
|
||||
)
|
||||
.tooltip(Tooltip::text(tool.description()))
|
||||
}),
|
||||
))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Render for AssistantConfiguration {
|
||||
fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
|
||||
fn render(&mut self, window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
|
||||
v_flex()
|
||||
.id("assistant-configuration")
|
||||
.key_context("AgentConfiguration")
|
||||
@@ -467,7 +551,7 @@ impl Render for AssistantConfiguration {
|
||||
.overflow_y_scroll()
|
||||
.child(self.render_command_permission(cx))
|
||||
.child(Divider::horizontal().color(DividerColor::Border))
|
||||
.child(self.render_context_servers_section(cx))
|
||||
.child(self.render_context_servers_section(window, cx))
|
||||
.child(Divider::horizontal().color(DividerColor::Border))
|
||||
.child(self.render_provider_configuration_section(cx)),
|
||||
)
|
||||
|
||||
@@ -0,0 +1,443 @@
|
||||
use std::{
|
||||
sync::{Arc, Mutex},
|
||||
time::Duration,
|
||||
};
|
||||
|
||||
use anyhow::Context as _;
|
||||
use context_server::manager::{ContextServerManager, ContextServerStatus};
|
||||
use editor::{Editor, EditorElement, EditorStyle};
|
||||
use extension::ContextServerConfiguration;
|
||||
use gpui::{
|
||||
Animation, AnimationExt, App, DismissEvent, Entity, EventEmitter, FocusHandle, Focusable, Task,
|
||||
TextStyle, TextStyleRefinement, Transformation, UnderlineStyle, WeakEntity, percentage,
|
||||
};
|
||||
use language::{Language, LanguageRegistry};
|
||||
use markdown::{Markdown, MarkdownElement, MarkdownStyle};
|
||||
use notifications::status_toast::{StatusToast, ToastIcon};
|
||||
use settings::{Settings as _, update_settings_file};
|
||||
use theme::ThemeSettings;
|
||||
use ui::{KeyBinding, Modal, ModalFooter, ModalHeader, Section, prelude::*};
|
||||
use util::ResultExt;
|
||||
use workspace::{ModalView, Workspace};
|
||||
|
||||
pub(crate) struct ConfigureContextServerModal {
|
||||
workspace: WeakEntity<Workspace>,
|
||||
context_servers_to_setup: Vec<ConfigureContextServer>,
|
||||
context_server_manager: Entity<ContextServerManager>,
|
||||
}
|
||||
|
||||
struct ConfigureContextServer {
|
||||
id: Arc<str>,
|
||||
installation_instructions: Entity<markdown::Markdown>,
|
||||
settings_validator: Option<jsonschema::Validator>,
|
||||
settings_editor: Entity<Editor>,
|
||||
last_error: Option<SharedString>,
|
||||
waiting_for_context_server: bool,
|
||||
}
|
||||
|
||||
impl ConfigureContextServerModal {
|
||||
pub fn new(
|
||||
configurations: impl Iterator<Item = (Arc<str>, ContextServerConfiguration)>,
|
||||
jsonc_language: Option<Arc<Language>>,
|
||||
context_server_manager: Entity<ContextServerManager>,
|
||||
language_registry: Arc<LanguageRegistry>,
|
||||
workspace: WeakEntity<Workspace>,
|
||||
window: &mut Window,
|
||||
cx: &mut App,
|
||||
) -> Option<Self> {
|
||||
let context_servers_to_setup = configurations
|
||||
.map(|(id, manifest)| {
|
||||
let jsonc_language = jsonc_language.clone();
|
||||
let settings_validator = jsonschema::validator_for(&manifest.settings_schema)
|
||||
.context("Failed to load JSON schema for context server settings")
|
||||
.log_err();
|
||||
ConfigureContextServer {
|
||||
id: id.clone(),
|
||||
installation_instructions: cx.new(|cx| {
|
||||
Markdown::new(
|
||||
manifest.installation_instructions.clone().into(),
|
||||
Some(language_registry.clone()),
|
||||
None,
|
||||
cx,
|
||||
)
|
||||
}),
|
||||
settings_validator,
|
||||
settings_editor: cx.new(|cx| {
|
||||
let mut editor = Editor::auto_height(16, window, cx);
|
||||
editor.set_text(manifest.default_settings.trim(), window, cx);
|
||||
if let Some(buffer) = editor.buffer().read(cx).as_singleton() {
|
||||
buffer.update(cx, |buffer, cx| buffer.set_language(jsonc_language, cx))
|
||||
}
|
||||
editor
|
||||
}),
|
||||
waiting_for_context_server: false,
|
||||
last_error: None,
|
||||
}
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
if context_servers_to_setup.is_empty() {
|
||||
return None;
|
||||
}
|
||||
|
||||
Some(Self {
|
||||
workspace,
|
||||
context_servers_to_setup,
|
||||
context_server_manager,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl ConfigureContextServerModal {
|
||||
pub fn confirm(&mut self, cx: &mut Context<Self>) {
|
||||
if self.context_servers_to_setup.is_empty() {
|
||||
return;
|
||||
}
|
||||
|
||||
let Some(workspace) = self.workspace.upgrade() else {
|
||||
return;
|
||||
};
|
||||
|
||||
let configuration = &mut self.context_servers_to_setup[0];
|
||||
if configuration.waiting_for_context_server {
|
||||
return;
|
||||
}
|
||||
|
||||
let settings_value = match serde_json_lenient::from_str::<serde_json::Value>(
|
||||
&configuration.settings_editor.read(cx).text(cx),
|
||||
) {
|
||||
Ok(value) => value,
|
||||
Err(error) => {
|
||||
configuration.last_error = Some(error.to_string().into());
|
||||
cx.notify();
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
if let Some(validator) = configuration.settings_validator.as_ref() {
|
||||
if let Err(error) = validator.validate(&settings_value) {
|
||||
configuration.last_error = Some(error.to_string().into());
|
||||
cx.notify();
|
||||
return;
|
||||
}
|
||||
}
|
||||
let id = configuration.id.clone();
|
||||
|
||||
let settings_changed = context_server::ContextServerSettings::get_global(cx)
|
||||
.context_servers
|
||||
.get(&id)
|
||||
.map_or(true, |config| {
|
||||
config.settings.as_ref() != Some(&settings_value)
|
||||
});
|
||||
|
||||
let is_running = self.context_server_manager.read(cx).status_for_server(&id)
|
||||
== Some(ContextServerStatus::Running);
|
||||
|
||||
if !settings_changed && is_running {
|
||||
self.complete_setup(id, cx);
|
||||
return;
|
||||
}
|
||||
|
||||
configuration.waiting_for_context_server = true;
|
||||
|
||||
let task = wait_for_context_server(&self.context_server_manager, id.clone(), cx);
|
||||
cx.spawn({
|
||||
let id = id.clone();
|
||||
async move |this, cx| {
|
||||
let result = task.await;
|
||||
this.update(cx, |this, cx| match result {
|
||||
Ok(_) => {
|
||||
this.complete_setup(id, cx);
|
||||
}
|
||||
Err(err) => {
|
||||
if let Some(configuration) = this.context_servers_to_setup.get_mut(0) {
|
||||
configuration.last_error = Some(err.into());
|
||||
configuration.waiting_for_context_server = false;
|
||||
} else {
|
||||
this.dismiss(cx);
|
||||
}
|
||||
cx.notify();
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
.detach();
|
||||
|
||||
// When we write the settings to the file, the context server will be restarted.
|
||||
update_settings_file::<context_server::ContextServerSettings>(
|
||||
workspace.read(cx).app_state().fs.clone(),
|
||||
cx,
|
||||
{
|
||||
let id = id.clone();
|
||||
|settings, _| {
|
||||
if let Some(server_config) = settings.context_servers.get_mut(&id) {
|
||||
server_config.settings = Some(settings_value);
|
||||
} else {
|
||||
settings.context_servers.insert(
|
||||
id,
|
||||
context_server::ServerConfig {
|
||||
settings: Some(settings_value),
|
||||
..Default::default()
|
||||
},
|
||||
);
|
||||
}
|
||||
}
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
fn complete_setup(&mut self, id: Arc<str>, cx: &mut Context<Self>) {
|
||||
self.context_servers_to_setup.remove(0);
|
||||
cx.notify();
|
||||
|
||||
if !self.context_servers_to_setup.is_empty() {
|
||||
return;
|
||||
}
|
||||
|
||||
self.workspace
|
||||
.update(cx, {
|
||||
|workspace, cx| {
|
||||
let status_toast = StatusToast::new(
|
||||
format!("{} configured successfully.", id),
|
||||
cx,
|
||||
|this, _cx| {
|
||||
this.icon(ToastIcon::new(IconName::Hammer).color(Color::Muted))
|
||||
.action("Dismiss", |_, _| {})
|
||||
},
|
||||
);
|
||||
|
||||
workspace.toggle_status_toast(status_toast, cx);
|
||||
}
|
||||
})
|
||||
.log_err();
|
||||
|
||||
self.dismiss(cx);
|
||||
}
|
||||
|
||||
fn dismiss(&self, cx: &mut Context<Self>) {
|
||||
cx.emit(DismissEvent);
|
||||
}
|
||||
}
|
||||
|
||||
fn wait_for_context_server(
|
||||
context_server_manager: &Entity<ContextServerManager>,
|
||||
context_server_id: Arc<str>,
|
||||
cx: &mut App,
|
||||
) -> Task<Result<(), Arc<str>>> {
|
||||
let (tx, rx) = futures::channel::oneshot::channel();
|
||||
let tx = Arc::new(Mutex::new(Some(tx)));
|
||||
|
||||
let subscription = cx.subscribe(context_server_manager, move |_, event, _cx| match event {
|
||||
context_server::manager::Event::ServerStatusChanged { server_id, status } => match status {
|
||||
Some(ContextServerStatus::Running) => {
|
||||
if server_id == &context_server_id {
|
||||
if let Some(tx) = tx.lock().unwrap().take() {
|
||||
let _ = tx.send(Ok(()));
|
||||
}
|
||||
}
|
||||
}
|
||||
Some(ContextServerStatus::Error(error)) => {
|
||||
if server_id == &context_server_id {
|
||||
if let Some(tx) = tx.lock().unwrap().take() {
|
||||
let _ = tx.send(Err(error.clone()));
|
||||
}
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
},
|
||||
});
|
||||
|
||||
cx.spawn(async move |_cx| {
|
||||
let result = rx.await.unwrap();
|
||||
drop(subscription);
|
||||
result
|
||||
})
|
||||
}
|
||||
|
||||
impl Render for ConfigureContextServerModal {
|
||||
fn render(&mut self, window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
|
||||
let Some(configuration) = self.context_servers_to_setup.first() else {
|
||||
return div().child("No context servers to setup");
|
||||
};
|
||||
|
||||
let focus_handle = self.focus_handle(cx);
|
||||
|
||||
div()
|
||||
.elevation_3(cx)
|
||||
.w(rems(34.))
|
||||
.key_context("ConfigureContextServerModal")
|
||||
.on_action(cx.listener(|this, _: &menu::Confirm, _window, cx| this.confirm(cx)))
|
||||
.on_action(cx.listener(|this, _: &menu::Cancel, _window, cx| this.dismiss(cx)))
|
||||
.capture_any_mouse_down(cx.listener(|this, _, window, cx| {
|
||||
this.focus_handle(cx).focus(window);
|
||||
}))
|
||||
.child(
|
||||
Modal::new("configure-context-server", None)
|
||||
.header(ModalHeader::new().headline(format!("Configure {}", configuration.id)))
|
||||
.section(
|
||||
Section::new()
|
||||
.child(div().pb_2().text_sm().child(MarkdownElement::new(
|
||||
configuration.installation_instructions.clone(),
|
||||
default_markdown_style(window, cx),
|
||||
)))
|
||||
.child(
|
||||
div()
|
||||
.p_2()
|
||||
.rounded_md()
|
||||
.border_1()
|
||||
.border_color(cx.theme().colors().border_variant)
|
||||
.bg(cx.theme().colors().editor_background)
|
||||
.gap_1()
|
||||
.child({
|
||||
let settings = ThemeSettings::get_global(cx);
|
||||
let text_style = TextStyle {
|
||||
color: cx.theme().colors().text,
|
||||
font_family: settings.buffer_font.family.clone(),
|
||||
font_fallbacks: settings.buffer_font.fallbacks.clone(),
|
||||
font_size: settings.buffer_font_size(cx).into(),
|
||||
font_weight: settings.buffer_font.weight,
|
||||
line_height: relative(
|
||||
settings.buffer_line_height.value(),
|
||||
),
|
||||
..Default::default()
|
||||
};
|
||||
EditorElement::new(
|
||||
&configuration.settings_editor,
|
||||
EditorStyle {
|
||||
background: cx.theme().colors().editor_background,
|
||||
local_player: cx.theme().players().local(),
|
||||
text: text_style,
|
||||
syntax: cx.theme().syntax().clone(),
|
||||
..Default::default()
|
||||
},
|
||||
)
|
||||
})
|
||||
.when_some(configuration.last_error.clone(), |this, error| {
|
||||
this.child(
|
||||
h_flex()
|
||||
.gap_2()
|
||||
.px_2()
|
||||
.py_1()
|
||||
.child(
|
||||
Icon::new(IconName::Warning)
|
||||
.size(IconSize::XSmall)
|
||||
.color(Color::Warning),
|
||||
)
|
||||
.child(
|
||||
div().w_full().child(
|
||||
Label::new(error)
|
||||
.size(LabelSize::Small)
|
||||
.color(Color::Muted),
|
||||
),
|
||||
),
|
||||
)
|
||||
}),
|
||||
)
|
||||
.when(configuration.waiting_for_context_server, |this| {
|
||||
this.child(
|
||||
h_flex()
|
||||
.gap_1p5()
|
||||
.child(
|
||||
Icon::new(IconName::ArrowCircle)
|
||||
.size(IconSize::XSmall)
|
||||
.color(Color::Info)
|
||||
.with_animation(
|
||||
"arrow-circle",
|
||||
Animation::new(Duration::from_secs(2)).repeat(),
|
||||
|icon, delta| {
|
||||
icon.transform(Transformation::rotate(
|
||||
percentage(delta),
|
||||
))
|
||||
},
|
||||
)
|
||||
.into_any_element(),
|
||||
)
|
||||
.child(
|
||||
Label::new("Waiting for Context Server")
|
||||
.size(LabelSize::Small)
|
||||
.color(Color::Muted),
|
||||
),
|
||||
)
|
||||
}),
|
||||
)
|
||||
.footer(
|
||||
ModalFooter::new().end_slot(
|
||||
h_flex()
|
||||
.gap_1()
|
||||
.child(
|
||||
Button::new("cancel", "Cancel")
|
||||
.key_binding(
|
||||
KeyBinding::for_action_in(
|
||||
&menu::Cancel,
|
||||
&focus_handle,
|
||||
window,
|
||||
cx,
|
||||
)
|
||||
.map(|kb| kb.size(rems_from_px(12.))),
|
||||
)
|
||||
.on_click(cx.listener(|this, _event, _window, cx| {
|
||||
this.dismiss(cx)
|
||||
})),
|
||||
)
|
||||
.child(
|
||||
Button::new("configure-server", "Configure MCP")
|
||||
.disabled(configuration.waiting_for_context_server)
|
||||
.key_binding(
|
||||
KeyBinding::for_action_in(
|
||||
&menu::Confirm,
|
||||
&focus_handle,
|
||||
window,
|
||||
cx,
|
||||
)
|
||||
.map(|kb| kb.size(rems_from_px(12.))),
|
||||
)
|
||||
.on_click(cx.listener(|this, _event, _window, cx| {
|
||||
this.confirm(cx)
|
||||
})),
|
||||
),
|
||||
),
|
||||
),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn default_markdown_style(window: &Window, cx: &App) -> MarkdownStyle {
|
||||
let theme_settings = ThemeSettings::get_global(cx);
|
||||
let colors = cx.theme().colors();
|
||||
let mut text_style = window.text_style();
|
||||
text_style.refine(&TextStyleRefinement {
|
||||
font_family: Some(theme_settings.ui_font.family.clone()),
|
||||
font_fallbacks: theme_settings.ui_font.fallbacks.clone(),
|
||||
font_features: Some(theme_settings.ui_font.features.clone()),
|
||||
font_size: Some(TextSize::XSmall.rems(cx).into()),
|
||||
color: Some(colors.text_muted),
|
||||
..Default::default()
|
||||
});
|
||||
|
||||
MarkdownStyle {
|
||||
base_text_style: text_style.clone(),
|
||||
selection_background_color: cx.theme().players().local().selection,
|
||||
link: TextStyleRefinement {
|
||||
background_color: Some(colors.editor_foreground.opacity(0.025)),
|
||||
underline: Some(UnderlineStyle {
|
||||
color: Some(colors.text_accent.opacity(0.5)),
|
||||
thickness: px(1.),
|
||||
..Default::default()
|
||||
}),
|
||||
..Default::default()
|
||||
},
|
||||
..Default::default()
|
||||
}
|
||||
}
|
||||
|
||||
impl ModalView for ConfigureContextServerModal {}
|
||||
impl EventEmitter<DismissEvent> for ConfigureContextServerModal {}
|
||||
impl Focusable for ConfigureContextServerModal {
|
||||
fn focus_handle(&self, cx: &App) -> FocusHandle {
|
||||
if let Some(current) = self.context_servers_to_setup.first() {
|
||||
current.settings_editor.read(cx).focus_handle(cx)
|
||||
} else {
|
||||
cx.focus_handle()
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -843,7 +843,7 @@ pub fn load_context(
|
||||
text.push_str(
|
||||
"\n<context>\n\
|
||||
The following items were attached by the user. \
|
||||
You don't need to use other tools to read them.\n\n",
|
||||
They are up-to-date and don't need to be re-read.\n\n",
|
||||
);
|
||||
|
||||
if !file_context.is_empty() {
|
||||
|
||||
120
crates/agent/src/context_server_configuration.rs
Normal file
120
crates/agent/src/context_server_configuration.rs
Normal file
@@ -0,0 +1,120 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use anyhow::Context as _;
|
||||
use context_server::ContextServerDescriptorRegistry;
|
||||
use extension::ExtensionManifest;
|
||||
use language::LanguageRegistry;
|
||||
use ui::prelude::*;
|
||||
use util::ResultExt;
|
||||
use workspace::Workspace;
|
||||
|
||||
use crate::{AssistantPanel, assistant_configuration::ConfigureContextServerModal};
|
||||
|
||||
pub(crate) fn init(language_registry: Arc<LanguageRegistry>, cx: &mut App) {
|
||||
cx.observe_new(move |_: &mut Workspace, window, cx| {
|
||||
let Some(window) = window else {
|
||||
return;
|
||||
};
|
||||
|
||||
if let Some(extension_events) = extension::ExtensionEvents::try_global(cx).as_ref() {
|
||||
cx.subscribe_in(extension_events, window, {
|
||||
let language_registry = language_registry.clone();
|
||||
move |workspace, _, event, window, cx| match event {
|
||||
extension::Event::ExtensionInstalled(manifest) => {
|
||||
show_configure_mcp_modal(
|
||||
language_registry.clone(),
|
||||
manifest,
|
||||
workspace,
|
||||
window,
|
||||
cx,
|
||||
);
|
||||
}
|
||||
extension::Event::ConfigureExtensionRequested(manifest) => {
|
||||
if !manifest.context_servers.is_empty() {
|
||||
show_configure_mcp_modal(
|
||||
language_registry.clone(),
|
||||
manifest,
|
||||
workspace,
|
||||
window,
|
||||
cx,
|
||||
);
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
})
|
||||
.detach();
|
||||
} else {
|
||||
log::info!(
|
||||
"No extension events global found. Skipping context server configuration wizard"
|
||||
);
|
||||
}
|
||||
})
|
||||
.detach();
|
||||
}
|
||||
|
||||
fn show_configure_mcp_modal(
|
||||
language_registry: Arc<LanguageRegistry>,
|
||||
manifest: &Arc<ExtensionManifest>,
|
||||
workspace: &mut Workspace,
|
||||
window: &mut Window,
|
||||
cx: &mut Context<'_, Workspace>,
|
||||
) {
|
||||
let Some(context_server_manager) = workspace.panel::<AssistantPanel>(cx).map(|panel| {
|
||||
panel
|
||||
.read(cx)
|
||||
.thread_store()
|
||||
.read(cx)
|
||||
.context_server_manager()
|
||||
}) else {
|
||||
return;
|
||||
};
|
||||
|
||||
let registry = ContextServerDescriptorRegistry::global(cx).read(cx);
|
||||
let project = workspace.project().clone();
|
||||
let configuration_tasks = manifest
|
||||
.context_servers
|
||||
.keys()
|
||||
.cloned()
|
||||
.filter_map({
|
||||
|key| {
|
||||
let descriptor = registry.context_server_descriptor(&key)?;
|
||||
Some(cx.spawn({
|
||||
let project = project.clone();
|
||||
async move |_, cx| {
|
||||
descriptor
|
||||
.configuration(project, &cx)
|
||||
.await
|
||||
.context("Failed to resolve context server configuration")
|
||||
.log_err()
|
||||
.flatten()
|
||||
.map(|config| (key, config))
|
||||
}
|
||||
}))
|
||||
}
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let jsonc_language = language_registry.language_for_name("jsonc");
|
||||
|
||||
cx.spawn_in(window, async move |this, cx| {
|
||||
let descriptors = futures::future::join_all(configuration_tasks).await;
|
||||
let jsonc_language = jsonc_language.await.ok();
|
||||
|
||||
this.update_in(cx, |this, window, cx| {
|
||||
let modal = ConfigureContextServerModal::new(
|
||||
descriptors.into_iter().flatten(),
|
||||
jsonc_language,
|
||||
context_server_manager,
|
||||
language_registry,
|
||||
cx.entity().downgrade(),
|
||||
window,
|
||||
cx,
|
||||
);
|
||||
if let Some(modal) = modal {
|
||||
this.toggle_modal(window, cx, |_, _| modal);
|
||||
}
|
||||
})
|
||||
})
|
||||
.detach();
|
||||
}
|
||||
@@ -4,7 +4,7 @@ use std::sync::Arc;
|
||||
use crate::assistant_model_selector::{AssistantModelSelector, ModelType};
|
||||
use crate::context::{ContextLoadResult, load_context};
|
||||
use crate::tool_compatibility::{IncompatibleToolsState, IncompatibleToolsTooltip};
|
||||
use crate::ui::AnimatedLabel;
|
||||
use crate::ui::{AgentPreview, AnimatedLabel};
|
||||
use buffer_diff::BufferDiff;
|
||||
use collections::HashSet;
|
||||
use editor::actions::{MoveUp, Paste};
|
||||
@@ -42,10 +42,11 @@ use crate::profile_selector::ProfileSelector;
|
||||
use crate::thread::{Thread, TokenUsageRatio};
|
||||
use crate::thread_store::ThreadStore;
|
||||
use crate::{
|
||||
AgentDiff, Chat, ChatMode, ExpandMessageEditor, NewThread, OpenAgentDiff, RemoveAllContext,
|
||||
ToggleContextPicker, ToggleProfileSelector,
|
||||
ActiveThread, AgentDiff, Chat, ExpandMessageEditor, NewThread, OpenAgentDiff, RemoveAllContext,
|
||||
ToggleContextPicker, ToggleProfileSelector, register_agent_preview,
|
||||
};
|
||||
|
||||
#[derive(RegisterComponent)]
|
||||
pub struct MessageEditor {
|
||||
thread: Entity<Thread>,
|
||||
incompatible_tools_state: Entity<IncompatibleToolsState>,
|
||||
@@ -206,10 +207,6 @@ impl MessageEditor {
|
||||
&self.context_store
|
||||
}
|
||||
|
||||
fn toggle_chat_mode(&mut self, _: &ChatMode, _window: &mut Window, cx: &mut Context<Self>) {
|
||||
cx.notify();
|
||||
}
|
||||
|
||||
pub fn expand_message_editor(
|
||||
&mut self,
|
||||
_: &ExpandMessageEditor,
|
||||
@@ -432,12 +429,13 @@ impl MessageEditor {
|
||||
Some(
|
||||
IconButton::new("max-mode", IconName::ZedMaxMode)
|
||||
.icon_size(IconSize::Small)
|
||||
.toggle_state(active_completion_mode == Some(CompletionMode::Max))
|
||||
.icon_color(Color::Muted)
|
||||
.toggle_state(active_completion_mode == CompletionMode::Max)
|
||||
.on_click(cx.listener(move |this, _event, _window, cx| {
|
||||
this.thread.update(cx, |thread, _cx| {
|
||||
thread.set_completion_mode(match active_completion_mode {
|
||||
Some(CompletionMode::Max) => Some(CompletionMode::Normal),
|
||||
Some(CompletionMode::Normal) | None => Some(CompletionMode::Max),
|
||||
CompletionMode::Max => CompletionMode::Normal,
|
||||
CompletionMode::Normal => CompletionMode::Max,
|
||||
});
|
||||
});
|
||||
}))
|
||||
@@ -499,7 +497,6 @@ impl MessageEditor {
|
||||
.on_action(cx.listener(Self::toggle_context_picker))
|
||||
.on_action(cx.listener(Self::remove_all_context))
|
||||
.on_action(cx.listener(Self::move_up))
|
||||
.on_action(cx.listener(Self::toggle_chat_mode))
|
||||
.on_action(cx.listener(Self::expand_message_editor))
|
||||
.capture_action(cx.listener(Self::paste))
|
||||
.gap_2()
|
||||
@@ -1206,3 +1203,53 @@ impl Render for MessageEditor {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Component for MessageEditor {
|
||||
fn scope() -> ComponentScope {
|
||||
ComponentScope::Agent
|
||||
}
|
||||
}
|
||||
|
||||
impl AgentPreview for MessageEditor {
|
||||
fn create_preview(
|
||||
workspace: WeakEntity<Workspace>,
|
||||
active_thread: Entity<ActiveThread>,
|
||||
thread_store: WeakEntity<ThreadStore>,
|
||||
window: &mut Window,
|
||||
cx: &mut App,
|
||||
) -> Option<AnyElement> {
|
||||
if let Some(workspace_entity) = workspace.upgrade() {
|
||||
let fs = workspace_entity.read(cx).app_state().fs.clone();
|
||||
let weak_project = workspace_entity.read(cx).project().clone().downgrade();
|
||||
let context_store = cx.new(|_cx| ContextStore::new(weak_project, None));
|
||||
let thread = active_thread.read(cx).thread().clone();
|
||||
|
||||
let example_message_editor = cx.new(|cx| {
|
||||
MessageEditor::new(
|
||||
fs,
|
||||
workspace,
|
||||
context_store,
|
||||
None,
|
||||
thread_store,
|
||||
thread,
|
||||
window,
|
||||
cx,
|
||||
)
|
||||
});
|
||||
|
||||
Some(
|
||||
v_flex()
|
||||
.gap_4()
|
||||
.children(vec![single_example(
|
||||
"Default",
|
||||
example_message_editor.clone().into_any_element(),
|
||||
)])
|
||||
.into_any_element(),
|
||||
)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
register_agent_preview!(MessageEditor);
|
||||
|
||||
@@ -301,6 +301,14 @@ pub enum TokenUsageRatio {
|
||||
Exceeded,
|
||||
}
|
||||
|
||||
fn default_completion_mode(cx: &App) -> CompletionMode {
|
||||
if cx.is_staff() {
|
||||
CompletionMode::Max
|
||||
} else {
|
||||
CompletionMode::Normal
|
||||
}
|
||||
}
|
||||
|
||||
/// A thread of conversation with the LLM.
|
||||
pub struct Thread {
|
||||
id: ThreadId,
|
||||
@@ -310,7 +318,7 @@ pub struct Thread {
|
||||
detailed_summary_task: Task<Option<()>>,
|
||||
detailed_summary_tx: postage::watch::Sender<DetailedSummaryState>,
|
||||
detailed_summary_rx: postage::watch::Receiver<DetailedSummaryState>,
|
||||
completion_mode: Option<CompletionMode>,
|
||||
completion_mode: CompletionMode,
|
||||
messages: Vec<Message>,
|
||||
next_message_id: MessageId,
|
||||
last_prompt_id: PromptId,
|
||||
@@ -366,7 +374,7 @@ impl Thread {
|
||||
detailed_summary_task: Task::ready(None),
|
||||
detailed_summary_tx,
|
||||
detailed_summary_rx,
|
||||
completion_mode: None,
|
||||
completion_mode: default_completion_mode(cx),
|
||||
messages: Vec::new(),
|
||||
next_message_id: MessageId(0),
|
||||
last_prompt_id: PromptId::new(),
|
||||
@@ -440,7 +448,7 @@ impl Thread {
|
||||
detailed_summary_task: Task::ready(None),
|
||||
detailed_summary_tx,
|
||||
detailed_summary_rx,
|
||||
completion_mode: None,
|
||||
completion_mode: default_completion_mode(cx),
|
||||
messages: serialized
|
||||
.messages
|
||||
.into_iter()
|
||||
@@ -569,11 +577,11 @@ impl Thread {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn completion_mode(&self) -> Option<CompletionMode> {
|
||||
pub fn completion_mode(&self) -> CompletionMode {
|
||||
self.completion_mode
|
||||
}
|
||||
|
||||
pub fn set_completion_mode(&mut self, mode: Option<CompletionMode>) {
|
||||
pub fn set_completion_mode(&mut self, mode: CompletionMode) {
|
||||
self.completion_mode = mode;
|
||||
}
|
||||
|
||||
@@ -1152,9 +1160,9 @@ impl Thread {
|
||||
|
||||
request.tools = available_tools;
|
||||
request.mode = if model.supports_max_mode() {
|
||||
self.completion_mode
|
||||
Some(self.completion_mode)
|
||||
} else {
|
||||
None
|
||||
Some(CompletionMode::Normal)
|
||||
};
|
||||
|
||||
request
|
||||
@@ -2110,7 +2118,7 @@ impl Thread {
|
||||
.map(|repo| {
|
||||
repo.update(cx, |repo, _| {
|
||||
let current_branch =
|
||||
repo.branch.as_ref().map(|branch| branch.name.to_string());
|
||||
repo.branch.as_ref().map(|branch| branch.name().to_owned());
|
||||
repo.send_job(None, |state, _| async move {
|
||||
let RepositoryState::Local { backend, .. } = state else {
|
||||
return GitState {
|
||||
@@ -2509,7 +2517,7 @@ mod tests {
|
||||
let expected_context = format!(
|
||||
r#"
|
||||
<context>
|
||||
The following items were attached by the user. You don't need to use other tools to read them.
|
||||
The following items were attached by the user. They are up-to-date and don't need to be re-read.
|
||||
|
||||
<files>
|
||||
```rs {path_part}
|
||||
|
||||
@@ -9,8 +9,8 @@ use assistant_settings::{AgentProfile, AgentProfileId, AssistantSettings};
|
||||
use assistant_tool::{ToolId, ToolSource, ToolWorkingSet};
|
||||
use chrono::{DateTime, Utc};
|
||||
use collections::HashMap;
|
||||
use context_server::manager::ContextServerManager;
|
||||
use context_server::{ContextServerFactoryRegistry, ContextServerTool};
|
||||
use context_server::manager::{ContextServerManager, ContextServerStatus};
|
||||
use context_server::{ContextServerDescriptorRegistry, ContextServerTool};
|
||||
use futures::channel::{mpsc, oneshot};
|
||||
use futures::future::{self, BoxFuture, Shared};
|
||||
use futures::{FutureExt as _, StreamExt as _};
|
||||
@@ -108,7 +108,7 @@ impl ThreadStore {
|
||||
prompt_store: Option<Entity<PromptStore>>,
|
||||
cx: &mut Context<Self>,
|
||||
) -> (Self, oneshot::Receiver<()>) {
|
||||
let context_server_factory_registry = ContextServerFactoryRegistry::default_global(cx);
|
||||
let context_server_factory_registry = ContextServerDescriptorRegistry::default_global(cx);
|
||||
let context_server_manager = cx.new(|cx| {
|
||||
ContextServerManager::new(context_server_factory_registry, project.clone(), cx)
|
||||
});
|
||||
@@ -555,62 +555,68 @@ impl ThreadStore {
|
||||
) {
|
||||
let tool_working_set = self.tools.clone();
|
||||
match event {
|
||||
context_server::manager::Event::ServerStarted { server_id } => {
|
||||
if let Some(server) = context_server_manager.read(cx).get_server(server_id) {
|
||||
let context_server_manager = context_server_manager.clone();
|
||||
cx.spawn({
|
||||
let server = server.clone();
|
||||
let server_id = server_id.clone();
|
||||
async move |this, cx| {
|
||||
let Some(protocol) = server.client() else {
|
||||
return;
|
||||
};
|
||||
context_server::manager::Event::ServerStatusChanged { server_id, status } => {
|
||||
match status {
|
||||
Some(ContextServerStatus::Running) => {
|
||||
if let Some(server) = context_server_manager.read(cx).get_server(server_id)
|
||||
{
|
||||
let context_server_manager = context_server_manager.clone();
|
||||
cx.spawn({
|
||||
let server = server.clone();
|
||||
let server_id = server_id.clone();
|
||||
async move |this, cx| {
|
||||
let Some(protocol) = server.client() else {
|
||||
return;
|
||||
};
|
||||
|
||||
if protocol.capable(context_server::protocol::ServerCapability::Tools) {
|
||||
if let Some(tools) = protocol.list_tools().await.log_err() {
|
||||
let tool_ids = tool_working_set
|
||||
.update(cx, |tool_working_set, _| {
|
||||
tools
|
||||
.tools
|
||||
.into_iter()
|
||||
.map(|tool| {
|
||||
log::info!(
|
||||
"registering context server tool: {:?}",
|
||||
tool.name
|
||||
);
|
||||
tool_working_set.insert(Arc::new(
|
||||
ContextServerTool::new(
|
||||
context_server_manager.clone(),
|
||||
server.id(),
|
||||
tool,
|
||||
),
|
||||
))
|
||||
if protocol.capable(context_server::protocol::ServerCapability::Tools) {
|
||||
if let Some(tools) = protocol.list_tools().await.log_err() {
|
||||
let tool_ids = tool_working_set
|
||||
.update(cx, |tool_working_set, _| {
|
||||
tools
|
||||
.tools
|
||||
.into_iter()
|
||||
.map(|tool| {
|
||||
log::info!(
|
||||
"registering context server tool: {:?}",
|
||||
tool.name
|
||||
);
|
||||
tool_working_set.insert(Arc::new(
|
||||
ContextServerTool::new(
|
||||
context_server_manager.clone(),
|
||||
server.id(),
|
||||
tool,
|
||||
),
|
||||
))
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
})
|
||||
.log_err();
|
||||
.log_err();
|
||||
|
||||
if let Some(tool_ids) = tool_ids {
|
||||
this.update(cx, |this, cx| {
|
||||
this.context_server_tool_ids
|
||||
.insert(server_id, tool_ids);
|
||||
this.load_default_profile(cx);
|
||||
})
|
||||
.log_err();
|
||||
if let Some(tool_ids) = tool_ids {
|
||||
this.update(cx, |this, cx| {
|
||||
this.context_server_tool_ids
|
||||
.insert(server_id, tool_ids);
|
||||
this.load_default_profile(cx);
|
||||
})
|
||||
.log_err();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
.detach();
|
||||
}
|
||||
})
|
||||
.detach();
|
||||
}
|
||||
}
|
||||
context_server::manager::Event::ServerStopped { server_id } => {
|
||||
if let Some(tool_ids) = self.context_server_tool_ids.remove(server_id) {
|
||||
tool_working_set.update(cx, |tool_working_set, _| {
|
||||
tool_working_set.remove(&tool_ids);
|
||||
});
|
||||
self.load_default_profile(cx);
|
||||
}
|
||||
None => {
|
||||
if let Some(tool_ids) = self.context_server_tool_ids.remove(server_id) {
|
||||
tool_working_set.update(cx, |tool_working_set, _| {
|
||||
tool_working_set.remove(&tool_ids);
|
||||
});
|
||||
self.load_default_profile(cx);
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,9 +1,12 @@
|
||||
mod agent_notification;
|
||||
pub mod agent_preview;
|
||||
mod animated_label;
|
||||
mod context_pill;
|
||||
mod upsell;
|
||||
mod usage_banner;
|
||||
|
||||
pub use agent_notification::*;
|
||||
pub use agent_preview::*;
|
||||
pub use animated_label::*;
|
||||
pub use context_pill::*;
|
||||
pub use usage_banner::*;
|
||||
|
||||
99
crates/agent/src/ui/agent_preview.rs
Normal file
99
crates/agent/src/ui/agent_preview.rs
Normal file
@@ -0,0 +1,99 @@
|
||||
use collections::HashMap;
|
||||
use component::ComponentId;
|
||||
use gpui::{App, Entity, WeakEntity};
|
||||
use linkme::distributed_slice;
|
||||
use std::sync::OnceLock;
|
||||
use ui::{AnyElement, Component, Window};
|
||||
use workspace::Workspace;
|
||||
|
||||
use crate::{ActiveThread, ThreadStore};
|
||||
|
||||
/// Function type for creating agent component previews
|
||||
pub type PreviewFn = fn(
|
||||
WeakEntity<Workspace>,
|
||||
Entity<ActiveThread>,
|
||||
WeakEntity<ThreadStore>,
|
||||
&mut Window,
|
||||
&mut App,
|
||||
) -> Option<AnyElement>;
|
||||
|
||||
/// Distributed slice for preview registration functions
|
||||
#[distributed_slice]
|
||||
pub static __ALL_AGENT_PREVIEWS: [fn() -> (ComponentId, PreviewFn)] = [..];
|
||||
|
||||
/// Trait that must be implemented by components that provide agent previews.
|
||||
pub trait AgentPreview: Component {
|
||||
/// Get the ID for this component
|
||||
///
|
||||
/// Eventually this will move to the component trait.
|
||||
fn id() -> ComponentId
|
||||
where
|
||||
Self: Sized,
|
||||
{
|
||||
ComponentId(Self::name())
|
||||
}
|
||||
|
||||
/// Static method to create a preview for this component type
|
||||
fn create_preview(
|
||||
workspace: WeakEntity<Workspace>,
|
||||
active_thread: Entity<ActiveThread>,
|
||||
thread_store: WeakEntity<ThreadStore>,
|
||||
window: &mut Window,
|
||||
cx: &mut App,
|
||||
) -> Option<AnyElement>
|
||||
where
|
||||
Self: Sized;
|
||||
}
|
||||
|
||||
/// Register an agent preview for the given component type
|
||||
#[macro_export]
|
||||
macro_rules! register_agent_preview {
|
||||
($type:ty) => {
|
||||
#[linkme::distributed_slice($crate::ui::agent_preview::__ALL_AGENT_PREVIEWS)]
|
||||
static __REGISTER_AGENT_PREVIEW: fn() -> (
|
||||
component::ComponentId,
|
||||
$crate::ui::agent_preview::PreviewFn,
|
||||
) = || {
|
||||
(
|
||||
<$type as $crate::ui::agent_preview::AgentPreview>::id(),
|
||||
<$type as $crate::ui::agent_preview::AgentPreview>::create_preview,
|
||||
)
|
||||
};
|
||||
};
|
||||
}
|
||||
|
||||
/// Lazy initialized registry of preview functions
|
||||
static AGENT_PREVIEW_REGISTRY: OnceLock<HashMap<ComponentId, PreviewFn>> = OnceLock::new();
|
||||
|
||||
/// Initialize the agent preview registry if needed
|
||||
fn get_or_init_registry() -> &'static HashMap<ComponentId, PreviewFn> {
|
||||
AGENT_PREVIEW_REGISTRY.get_or_init(|| {
|
||||
let mut map = HashMap::default();
|
||||
for register_fn in __ALL_AGENT_PREVIEWS.iter() {
|
||||
let (id, preview_fn) = register_fn();
|
||||
map.insert(id, preview_fn);
|
||||
}
|
||||
map
|
||||
})
|
||||
}
|
||||
|
||||
/// Get a specific agent preview by component ID.
|
||||
pub fn get_agent_preview(
|
||||
id: &ComponentId,
|
||||
workspace: WeakEntity<Workspace>,
|
||||
active_thread: Entity<ActiveThread>,
|
||||
thread_store: WeakEntity<ThreadStore>,
|
||||
window: &mut Window,
|
||||
cx: &mut App,
|
||||
) -> Option<AnyElement> {
|
||||
let registry = get_or_init_registry();
|
||||
registry
|
||||
.get(id)
|
||||
.and_then(|preview_fn| preview_fn(workspace, active_thread, thread_store, window, cx))
|
||||
}
|
||||
|
||||
/// Get all registered agent previews.
|
||||
pub fn all_agent_previews() -> Vec<ComponentId> {
|
||||
let registry = get_or_init_registry();
|
||||
registry.keys().cloned().collect()
|
||||
}
|
||||
@@ -216,9 +216,10 @@ impl RenderOnce for ContextPill {
|
||||
})
|
||||
.when_some(on_click.as_ref(), |element, on_click| {
|
||||
let on_click = on_click.clone();
|
||||
element
|
||||
.cursor_pointer()
|
||||
.on_click(move |event, window, cx| on_click(event, window, cx))
|
||||
element.cursor_pointer().on_click(move |event, window, cx| {
|
||||
on_click(event, window, cx);
|
||||
cx.stop_propagation();
|
||||
})
|
||||
})
|
||||
.into_any_element()
|
||||
}
|
||||
@@ -254,7 +255,10 @@ impl RenderOnce for ContextPill {
|
||||
})
|
||||
.when_some(on_click.as_ref(), |element, on_click| {
|
||||
let on_click = on_click.clone();
|
||||
element.on_click(move |event, window, cx| on_click(event, window, cx))
|
||||
element.on_click(move |event, window, cx| {
|
||||
on_click(event, window, cx);
|
||||
cx.stop_propagation();
|
||||
})
|
||||
})
|
||||
.into_any(),
|
||||
}
|
||||
|
||||
163
crates/agent/src/ui/upsell.rs
Normal file
163
crates/agent/src/ui/upsell.rs
Normal file
@@ -0,0 +1,163 @@
|
||||
use component::{Component, ComponentScope, single_example};
|
||||
use gpui::{
|
||||
AnyElement, App, ClickEvent, IntoElement, ParentElement, RenderOnce, SharedString, Styled,
|
||||
Window,
|
||||
};
|
||||
use theme::ActiveTheme;
|
||||
use ui::{
|
||||
Button, ButtonCommon, ButtonStyle, Checkbox, Clickable, Color, Label, LabelCommon,
|
||||
RegisterComponent, ToggleState, h_flex, v_flex,
|
||||
};
|
||||
|
||||
/// A component that displays an upsell message with a call-to-action button
|
||||
///
|
||||
/// # Example
|
||||
/// ```
|
||||
/// let upsell = Upsell::new(
|
||||
/// "Upgrade to Zed Pro",
|
||||
/// "Get unlimited access to AI features and more",
|
||||
/// "Upgrade Now",
|
||||
/// Box::new(|_, _window, cx| {
|
||||
/// cx.open_url("https://zed.dev/pricing");
|
||||
/// }),
|
||||
/// Box::new(|_, _window, cx| {
|
||||
/// // Handle dismiss
|
||||
/// }),
|
||||
/// Box::new(|checked, window, cx| {
|
||||
/// // Handle don't show again
|
||||
/// }),
|
||||
/// );
|
||||
/// ```
|
||||
#[derive(IntoElement, RegisterComponent)]
|
||||
pub struct Upsell {
|
||||
title: SharedString,
|
||||
message: SharedString,
|
||||
cta_text: SharedString,
|
||||
on_click: Box<dyn Fn(&ClickEvent, &mut Window, &mut App)>,
|
||||
on_dismiss: Box<dyn Fn(&ClickEvent, &mut Window, &mut App)>,
|
||||
on_dont_show_again: Box<dyn Fn(bool, &mut Window, &mut App)>,
|
||||
}
|
||||
|
||||
impl Upsell {
|
||||
/// Create a new upsell component
|
||||
pub fn new(
|
||||
title: impl Into<SharedString>,
|
||||
message: impl Into<SharedString>,
|
||||
cta_text: impl Into<SharedString>,
|
||||
on_click: Box<dyn Fn(&ClickEvent, &mut Window, &mut App)>,
|
||||
on_dismiss: Box<dyn Fn(&ClickEvent, &mut Window, &mut App)>,
|
||||
on_dont_show_again: Box<dyn Fn(bool, &mut Window, &mut App)>,
|
||||
) -> Self {
|
||||
Self {
|
||||
title: title.into(),
|
||||
message: message.into(),
|
||||
cta_text: cta_text.into(),
|
||||
on_click,
|
||||
on_dismiss,
|
||||
on_dont_show_again,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl RenderOnce for Upsell {
|
||||
fn render(self, _window: &mut Window, cx: &mut App) -> impl IntoElement {
|
||||
v_flex()
|
||||
.w_full()
|
||||
.p_4()
|
||||
.gap_3()
|
||||
.bg(cx.theme().colors().surface_background)
|
||||
.rounded_md()
|
||||
.border_1()
|
||||
.border_color(cx.theme().colors().border)
|
||||
.child(
|
||||
v_flex()
|
||||
.gap_1()
|
||||
.child(
|
||||
Label::new(self.title)
|
||||
.size(ui::LabelSize::Large)
|
||||
.weight(gpui::FontWeight::BOLD),
|
||||
)
|
||||
.child(Label::new(self.message).color(Color::Muted)),
|
||||
)
|
||||
.child(
|
||||
h_flex()
|
||||
.w_full()
|
||||
.justify_between()
|
||||
.items_center()
|
||||
.child(
|
||||
h_flex()
|
||||
.items_center()
|
||||
.gap_1()
|
||||
.child(
|
||||
Checkbox::new("dont-show-again", ToggleState::Unselected).on_click(
|
||||
move |_, window, cx| {
|
||||
(self.on_dont_show_again)(true, window, cx);
|
||||
},
|
||||
),
|
||||
)
|
||||
.child(
|
||||
Label::new("Don't show again")
|
||||
.color(Color::Muted)
|
||||
.size(ui::LabelSize::Small),
|
||||
),
|
||||
)
|
||||
.child(
|
||||
h_flex()
|
||||
.gap_2()
|
||||
.child(
|
||||
Button::new("dismiss-button", "Dismiss")
|
||||
.style(ButtonStyle::Subtle)
|
||||
.on_click(self.on_dismiss),
|
||||
)
|
||||
.child(
|
||||
Button::new("cta-button", self.cta_text)
|
||||
.style(ButtonStyle::Filled)
|
||||
.on_click(self.on_click),
|
||||
),
|
||||
),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl Component for Upsell {
|
||||
fn scope() -> ComponentScope {
|
||||
ComponentScope::Agent
|
||||
}
|
||||
|
||||
fn name() -> &'static str {
|
||||
"Upsell"
|
||||
}
|
||||
|
||||
fn description() -> Option<&'static str> {
|
||||
Some("A promotional component that displays a message with a call-to-action.")
|
||||
}
|
||||
|
||||
fn preview(window: &mut Window, cx: &mut App) -> Option<AnyElement> {
|
||||
let examples = vec![
|
||||
single_example(
|
||||
"Default",
|
||||
Upsell::new(
|
||||
"Upgrade to Zed Pro",
|
||||
"Get unlimited access to AI features and more with Zed Pro. Unlock advanced AI capabilities and other premium features.",
|
||||
"Upgrade Now",
|
||||
Box::new(|_, _, _| {}),
|
||||
Box::new(|_, _, _| {}),
|
||||
Box::new(|_, _, _| {}),
|
||||
).render(window, cx).into_any_element(),
|
||||
),
|
||||
single_example(
|
||||
"Short Message",
|
||||
Upsell::new(
|
||||
"Try Zed Pro for free",
|
||||
"Start your 7-day trial today.",
|
||||
"Start Trial",
|
||||
Box::new(|_, _, _| {}),
|
||||
Box::new(|_, _, _| {}),
|
||||
Box::new(|_, _, _| {}),
|
||||
).render(window, cx).into_any_element(),
|
||||
),
|
||||
];
|
||||
|
||||
Some(v_flex().gap_4().children(examples).into_any_element())
|
||||
}
|
||||
}
|
||||
@@ -98,6 +98,10 @@ impl RenderOnce for UsageBanner {
|
||||
}
|
||||
|
||||
impl Component for UsageBanner {
|
||||
fn scope() -> ComponentScope {
|
||||
ComponentScope::Agent
|
||||
}
|
||||
|
||||
fn sort_name() -> &'static str {
|
||||
"AgentUsageBanner"
|
||||
}
|
||||
|
||||
@@ -7,8 +7,8 @@ use assistant_slash_command::{SlashCommandId, SlashCommandWorkingSet};
|
||||
use client::{Client, TypedEnvelope, proto, telemetry::Telemetry};
|
||||
use clock::ReplicaId;
|
||||
use collections::HashMap;
|
||||
use context_server::ContextServerFactoryRegistry;
|
||||
use context_server::manager::ContextServerManager;
|
||||
use context_server::ContextServerDescriptorRegistry;
|
||||
use context_server::manager::{ContextServerManager, ContextServerStatus};
|
||||
use fs::{Fs, RemoveOptions};
|
||||
use futures::StreamExt;
|
||||
use fuzzy::StringMatchCandidate;
|
||||
@@ -99,7 +99,7 @@ impl ContextStore {
|
||||
|
||||
let this = cx.new(|cx: &mut Context<Self>| {
|
||||
let context_server_factory_registry =
|
||||
ContextServerFactoryRegistry::default_global(cx);
|
||||
ContextServerDescriptorRegistry::default_global(cx);
|
||||
let context_server_manager = cx.new(|cx| {
|
||||
ContextServerManager::new(context_server_factory_registry, project.clone(), cx)
|
||||
});
|
||||
@@ -831,54 +831,60 @@ impl ContextStore {
|
||||
) {
|
||||
let slash_command_working_set = self.slash_commands.clone();
|
||||
match event {
|
||||
context_server::manager::Event::ServerStarted { server_id } => {
|
||||
if let Some(server) = context_server_manager.read(cx).get_server(server_id) {
|
||||
let context_server_manager = context_server_manager.clone();
|
||||
cx.spawn({
|
||||
let server = server.clone();
|
||||
let server_id = server_id.clone();
|
||||
async move |this, cx| {
|
||||
let Some(protocol) = server.client() else {
|
||||
return;
|
||||
};
|
||||
context_server::manager::Event::ServerStatusChanged { server_id, status } => {
|
||||
match status {
|
||||
Some(ContextServerStatus::Running) => {
|
||||
if let Some(server) = context_server_manager.read(cx).get_server(server_id)
|
||||
{
|
||||
let context_server_manager = context_server_manager.clone();
|
||||
cx.spawn({
|
||||
let server = server.clone();
|
||||
let server_id = server_id.clone();
|
||||
async move |this, cx| {
|
||||
let Some(protocol) = server.client() else {
|
||||
return;
|
||||
};
|
||||
|
||||
if protocol.capable(context_server::protocol::ServerCapability::Prompts) {
|
||||
if let Some(prompts) = protocol.list_prompts().await.log_err() {
|
||||
let slash_command_ids = prompts
|
||||
.into_iter()
|
||||
.filter(assistant_slash_commands::acceptable_prompt)
|
||||
.map(|prompt| {
|
||||
log::info!(
|
||||
"registering context server command: {:?}",
|
||||
prompt.name
|
||||
);
|
||||
slash_command_working_set.insert(Arc::new(
|
||||
assistant_slash_commands::ContextServerSlashCommand::new(
|
||||
context_server_manager.clone(),
|
||||
&server,
|
||||
prompt,
|
||||
),
|
||||
))
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
if protocol.capable(context_server::protocol::ServerCapability::Prompts) {
|
||||
if let Some(prompts) = protocol.list_prompts().await.log_err() {
|
||||
let slash_command_ids = prompts
|
||||
.into_iter()
|
||||
.filter(assistant_slash_commands::acceptable_prompt)
|
||||
.map(|prompt| {
|
||||
log::info!(
|
||||
"registering context server command: {:?}",
|
||||
prompt.name
|
||||
);
|
||||
slash_command_working_set.insert(Arc::new(
|
||||
assistant_slash_commands::ContextServerSlashCommand::new(
|
||||
context_server_manager.clone(),
|
||||
&server,
|
||||
prompt,
|
||||
),
|
||||
))
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
this.update( cx, |this, _cx| {
|
||||
this.context_server_slash_command_ids
|
||||
.insert(server_id.clone(), slash_command_ids);
|
||||
})
|
||||
.log_err();
|
||||
this.update( cx, |this, _cx| {
|
||||
this.context_server_slash_command_ids
|
||||
.insert(server_id.clone(), slash_command_ids);
|
||||
})
|
||||
.log_err();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
.detach();
|
||||
}
|
||||
})
|
||||
.detach();
|
||||
}
|
||||
}
|
||||
context_server::manager::Event::ServerStopped { server_id } => {
|
||||
if let Some(slash_command_ids) =
|
||||
self.context_server_slash_command_ids.remove(server_id)
|
||||
{
|
||||
slash_command_working_set.remove(&slash_command_ids);
|
||||
}
|
||||
None => {
|
||||
if let Some(slash_command_ids) =
|
||||
self.context_server_slash_command_ids.remove(server_id)
|
||||
{
|
||||
slash_command_working_set.remove(&slash_command_ids);
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -77,7 +77,6 @@ pub fn init(http_client: Arc<HttpClientWithUrl>, cx: &mut App) {
|
||||
registry.register_tool(TerminalTool);
|
||||
registry.register_tool(BatchTool);
|
||||
registry.register_tool(CreateDirectoryTool);
|
||||
registry.register_tool(CreateFileTool);
|
||||
registry.register_tool(CopyPathTool);
|
||||
registry.register_tool(DeletePathTool);
|
||||
registry.register_tool(SymbolInfoTool);
|
||||
@@ -125,12 +124,14 @@ pub fn init(http_client: Arc<HttpClientWithUrl>, cx: &mut App) {
|
||||
fn register_edit_file_tool(cx: &mut App) {
|
||||
let registry = ToolRegistry::global(cx);
|
||||
|
||||
registry.unregister_tool(CreateFileTool);
|
||||
registry.unregister_tool(EditFileTool);
|
||||
registry.unregister_tool(StreamingEditFileTool);
|
||||
|
||||
if AssistantSettings::get_global(cx).stream_edits(cx) {
|
||||
registry.register_tool(StreamingEditFileTool);
|
||||
} else {
|
||||
registry.register_tool(CreateFileTool);
|
||||
registry.register_tool(EditFileTool);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -10,10 +10,11 @@ use edit_parser::{EditParser, EditParserEvent, EditParserMetrics};
|
||||
use futures::{
|
||||
Stream, StreamExt,
|
||||
channel::mpsc::{self, UnboundedReceiver},
|
||||
pin_mut,
|
||||
stream::BoxStream,
|
||||
};
|
||||
use gpui::{AppContext, AsyncApp, Entity, SharedString, Task};
|
||||
use language::{Bias, Buffer, BufferSnapshot, LineIndent, Point};
|
||||
use language::{Anchor, Bias, Buffer, BufferSnapshot, LineIndent, Point};
|
||||
use language_model::{
|
||||
LanguageModel, LanguageModelCompletionError, LanguageModelRequest, LanguageModelRequestMessage,
|
||||
MessageContent, Role,
|
||||
@@ -23,19 +24,29 @@ use std::{cmp, iter, mem, ops::Range, path::PathBuf, sync::Arc, task::Poll};
|
||||
use streaming_diff::{CharOperation, StreamingDiff};
|
||||
|
||||
#[derive(Serialize)]
|
||||
pub struct EditAgentTemplate {
|
||||
struct CreateFilePromptTemplate {
|
||||
path: Option<PathBuf>,
|
||||
edit_description: String,
|
||||
}
|
||||
|
||||
impl Template for EditAgentTemplate {
|
||||
const TEMPLATE_NAME: &'static str = "edit_agent.hbs";
|
||||
impl Template for CreateFilePromptTemplate {
|
||||
const TEMPLATE_NAME: &'static str = "create_file_prompt.hbs";
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct EditFilePromptTemplate {
|
||||
path: Option<PathBuf>,
|
||||
edit_description: String,
|
||||
}
|
||||
|
||||
impl Template for EditFilePromptTemplate {
|
||||
const TEMPLATE_NAME: &'static str = "edit_file_prompt.hbs";
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, PartialEq, Eq)]
|
||||
pub enum EditAgentOutputEvent {
|
||||
Edited,
|
||||
HallucinatedOldText(SharedString),
|
||||
Edited { position: Anchor },
|
||||
OldTextNotFound(SharedString),
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
@@ -64,6 +75,84 @@ impl EditAgent {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn overwrite(
|
||||
&self,
|
||||
buffer: Entity<Buffer>,
|
||||
edit_description: String,
|
||||
previous_messages: Vec<LanguageModelRequestMessage>,
|
||||
cx: &mut AsyncApp,
|
||||
) -> (
|
||||
Task<Result<EditAgentOutput>>,
|
||||
mpsc::UnboundedReceiver<EditAgentOutputEvent>,
|
||||
) {
|
||||
let this = self.clone();
|
||||
let (events_tx, events_rx) = mpsc::unbounded();
|
||||
let output = cx.spawn(async move |cx| {
|
||||
let snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot())?;
|
||||
let path = cx.update(|cx| snapshot.resolve_file_path(cx, true))?;
|
||||
let prompt = CreateFilePromptTemplate {
|
||||
path,
|
||||
edit_description,
|
||||
}
|
||||
.render(&this.templates)?;
|
||||
let new_chunks = this.request(previous_messages, prompt, cx).await?;
|
||||
|
||||
let (output, mut inner_events) = this.replace_text_with_chunks(buffer, new_chunks, cx);
|
||||
while let Some(event) = inner_events.next().await {
|
||||
events_tx.unbounded_send(event).ok();
|
||||
}
|
||||
output.await
|
||||
});
|
||||
(output, events_rx)
|
||||
}
|
||||
|
||||
fn replace_text_with_chunks(
|
||||
&self,
|
||||
buffer: Entity<Buffer>,
|
||||
edit_chunks: impl 'static + Send + Stream<Item = Result<String, LanguageModelCompletionError>>,
|
||||
cx: &mut AsyncApp,
|
||||
) -> (
|
||||
Task<Result<EditAgentOutput>>,
|
||||
mpsc::UnboundedReceiver<EditAgentOutputEvent>,
|
||||
) {
|
||||
let (output_events_tx, output_events_rx) = mpsc::unbounded();
|
||||
let this = self.clone();
|
||||
let task = cx.spawn(async move |cx| {
|
||||
// Ensure the buffer is tracked by the action log.
|
||||
this.action_log
|
||||
.update(cx, |log, cx| log.track_buffer(buffer.clone(), cx))?;
|
||||
|
||||
cx.update(|cx| {
|
||||
buffer.update(cx, |buffer, cx| buffer.set_text("", cx));
|
||||
this.action_log
|
||||
.update(cx, |log, cx| log.buffer_edited(buffer.clone(), cx));
|
||||
})?;
|
||||
|
||||
let mut raw_edits = String::new();
|
||||
pin_mut!(edit_chunks);
|
||||
while let Some(chunk) = edit_chunks.next().await {
|
||||
let chunk = chunk?;
|
||||
raw_edits.push_str(&chunk);
|
||||
cx.update(|cx| {
|
||||
buffer.update(cx, |buffer, cx| buffer.append(chunk, cx));
|
||||
this.action_log
|
||||
.update(cx, |log, cx| log.buffer_edited(buffer.clone(), cx));
|
||||
})?;
|
||||
output_events_tx
|
||||
.unbounded_send(EditAgentOutputEvent::Edited {
|
||||
position: Anchor::MAX,
|
||||
})
|
||||
.ok();
|
||||
}
|
||||
|
||||
Ok(EditAgentOutput {
|
||||
_raw_edits: raw_edits,
|
||||
_parser_metrics: EditParserMetrics::default(),
|
||||
})
|
||||
});
|
||||
(task, output_events_rx)
|
||||
}
|
||||
|
||||
pub fn edit(
|
||||
&self,
|
||||
buffer: Entity<Buffer>,
|
||||
@@ -78,10 +167,15 @@ impl EditAgent {
|
||||
let (events_tx, events_rx) = mpsc::unbounded();
|
||||
let output = cx.spawn(async move |cx| {
|
||||
let snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot())?;
|
||||
let edit_chunks = this
|
||||
.request_edits(snapshot, edit_description, previous_messages, cx)
|
||||
.await?;
|
||||
let (output, mut inner_events) = this.apply_edits(buffer, edit_chunks, cx);
|
||||
let path = cx.update(|cx| snapshot.resolve_file_path(cx, true))?;
|
||||
let prompt = EditFilePromptTemplate {
|
||||
path,
|
||||
edit_description,
|
||||
}
|
||||
.render(&this.templates)?;
|
||||
let edit_chunks = this.request(previous_messages, prompt, cx).await?;
|
||||
|
||||
let (output, mut inner_events) = this.apply_edit_chunks(buffer, edit_chunks, cx);
|
||||
while let Some(event) = inner_events.next().await {
|
||||
events_tx.unbounded_send(event).ok();
|
||||
}
|
||||
@@ -90,7 +184,7 @@ impl EditAgent {
|
||||
(output, events_rx)
|
||||
}
|
||||
|
||||
fn apply_edits(
|
||||
fn apply_edit_chunks(
|
||||
&self,
|
||||
buffer: Entity<Buffer>,
|
||||
edit_chunks: impl 'static + Send + Stream<Item = Result<String, LanguageModelCompletionError>>,
|
||||
@@ -138,7 +232,7 @@ impl EditAgent {
|
||||
let Some(old_range) = old_range else {
|
||||
// We couldn't find the old text in the buffer. Report the error.
|
||||
output_events
|
||||
.unbounded_send(EditAgentOutputEvent::HallucinatedOldText(old_text_query))
|
||||
.unbounded_send(EditAgentOutputEvent::OldTextNotFound(old_text_query))
|
||||
.ok();
|
||||
continue;
|
||||
};
|
||||
@@ -183,14 +277,15 @@ impl EditAgent {
|
||||
match op {
|
||||
CharOperation::Insert { text } => {
|
||||
let edit_start = snapshot.anchor_after(edit_start);
|
||||
edits_tx.unbounded_send((edit_start..edit_start, text))?;
|
||||
edits_tx
|
||||
.unbounded_send((edit_start..edit_start, Arc::from(text)))?;
|
||||
}
|
||||
CharOperation::Delete { bytes } => {
|
||||
let edit_end = edit_start + bytes;
|
||||
let edit_range = snapshot.anchor_after(edit_start)
|
||||
..snapshot.anchor_before(edit_end);
|
||||
edit_start = edit_end;
|
||||
edits_tx.unbounded_send((edit_range, String::new()))?;
|
||||
edits_tx.unbounded_send((edit_range, Arc::from("")))?;
|
||||
}
|
||||
CharOperation::Keep { bytes } => edit_start += bytes,
|
||||
}
|
||||
@@ -204,16 +299,32 @@ impl EditAgent {
|
||||
// TODO: group all edits into one transaction
|
||||
let mut edits_rx = edits_rx.ready_chunks(32);
|
||||
while let Some(edits) = edits_rx.next().await {
|
||||
if edits.is_empty() {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Edit the buffer and report edits to the action log as part of the
|
||||
// same effect cycle, otherwise the edit will be reported as if the
|
||||
// user made it.
|
||||
cx.update(|cx| {
|
||||
buffer.update(cx, |buffer, cx| buffer.edit(edits, None, cx));
|
||||
let max_edit_end = cx.update(|cx| {
|
||||
let max_edit_end = buffer.update(cx, |buffer, cx| {
|
||||
buffer.edit(edits.iter().cloned(), None, cx);
|
||||
let max_edit_end = buffer
|
||||
.summaries_for_anchors::<Point, _>(
|
||||
edits.iter().map(|(range, _)| &range.end),
|
||||
)
|
||||
.max()
|
||||
.unwrap();
|
||||
buffer.anchor_before(max_edit_end)
|
||||
});
|
||||
self.action_log
|
||||
.update(cx, |log, cx| log.buffer_edited(buffer.clone(), cx))
|
||||
.update(cx, |log, cx| log.buffer_edited(buffer.clone(), cx));
|
||||
max_edit_end
|
||||
})?;
|
||||
output_events
|
||||
.unbounded_send(EditAgentOutputEvent::Edited)
|
||||
.unbounded_send(EditAgentOutputEvent::Edited {
|
||||
position: max_edit_end,
|
||||
})
|
||||
.ok();
|
||||
}
|
||||
|
||||
@@ -232,7 +343,7 @@ impl EditAgent {
|
||||
) {
|
||||
let (tx, rx) = mpsc::unbounded();
|
||||
let output = cx.background_spawn(async move {
|
||||
futures::pin_mut!(chunks);
|
||||
pin_mut!(chunks);
|
||||
|
||||
let mut parser = EditParser::new();
|
||||
let mut raw_edits = String::new();
|
||||
@@ -336,20 +447,12 @@ impl EditAgent {
|
||||
})
|
||||
}
|
||||
|
||||
async fn request_edits(
|
||||
async fn request(
|
||||
&self,
|
||||
snapshot: BufferSnapshot,
|
||||
edit_description: String,
|
||||
mut messages: Vec<LanguageModelRequestMessage>,
|
||||
prompt: String,
|
||||
cx: &mut AsyncApp,
|
||||
) -> Result<BoxStream<'static, Result<String, LanguageModelCompletionError>>> {
|
||||
let path = cx.update(|cx| snapshot.resolve_file_path(cx, true))?;
|
||||
let prompt = EditAgentTemplate {
|
||||
path,
|
||||
edit_description,
|
||||
}
|
||||
.render(&self.templates)?;
|
||||
|
||||
let mut message_content = Vec::new();
|
||||
if let Some(last_message) = messages.last_mut() {
|
||||
if last_message.role == Role::Assistant {
|
||||
@@ -611,7 +714,8 @@ mod tests {
|
||||
&mut rng,
|
||||
cx,
|
||||
);
|
||||
let (apply, _events) = agent.apply_edits(buffer.clone(), raw_edits, &mut cx.to_async());
|
||||
let (apply, _events) =
|
||||
agent.apply_edit_chunks(buffer.clone(), raw_edits, &mut cx.to_async());
|
||||
apply.await.unwrap();
|
||||
pretty_assertions::assert_eq!(
|
||||
buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
|
||||
@@ -648,7 +752,8 @@ mod tests {
|
||||
&mut rng,
|
||||
cx,
|
||||
);
|
||||
let (apply, _events) = agent.apply_edits(buffer.clone(), raw_edits, &mut cx.to_async());
|
||||
let (apply, _events) =
|
||||
agent.apply_edit_chunks(buffer.clone(), raw_edits, &mut cx.to_async());
|
||||
apply.await.unwrap();
|
||||
assert_eq!(
|
||||
buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
|
||||
@@ -679,7 +784,8 @@ mod tests {
|
||||
&mut rng,
|
||||
cx,
|
||||
);
|
||||
let (apply, _events) = agent.apply_edits(buffer.clone(), raw_edits, &mut cx.to_async());
|
||||
let (apply, _events) =
|
||||
agent.apply_edit_chunks(buffer.clone(), raw_edits, &mut cx.to_async());
|
||||
apply.await.unwrap();
|
||||
assert_eq!(
|
||||
buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
|
||||
@@ -692,7 +798,7 @@ mod tests {
|
||||
let agent = init_test(cx).await;
|
||||
let buffer = cx.new(|cx| Buffer::local("abc\ndef\nghi", cx));
|
||||
let (chunks_tx, chunks_rx) = mpsc::unbounded();
|
||||
let (apply, mut events) = agent.apply_edits(
|
||||
let (apply, mut events) = agent.apply_edit_chunks(
|
||||
buffer.clone(),
|
||||
chunks_rx.map(|chunk: &str| Ok(chunk.to_string())),
|
||||
&mut cx.to_async(),
|
||||
@@ -716,7 +822,12 @@ mod tests {
|
||||
|
||||
chunks_tx.unbounded_send("<new_text>abX").unwrap();
|
||||
cx.run_until_parked();
|
||||
assert_eq!(drain_events(&mut events), [EditAgentOutputEvent::Edited]);
|
||||
assert_eq!(
|
||||
drain_events(&mut events),
|
||||
[EditAgentOutputEvent::Edited {
|
||||
position: buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(0, 3)))
|
||||
}]
|
||||
);
|
||||
assert_eq!(
|
||||
buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
|
||||
"abXc\ndef\nghi"
|
||||
@@ -724,7 +835,12 @@ mod tests {
|
||||
|
||||
chunks_tx.unbounded_send("cY").unwrap();
|
||||
cx.run_until_parked();
|
||||
assert_eq!(drain_events(&mut events), [EditAgentOutputEvent::Edited]);
|
||||
assert_eq!(
|
||||
drain_events(&mut events),
|
||||
[EditAgentOutputEvent::Edited {
|
||||
position: buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(0, 5)))
|
||||
}]
|
||||
);
|
||||
assert_eq!(
|
||||
buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
|
||||
"abXcY\ndef\nghi"
|
||||
@@ -744,7 +860,7 @@ mod tests {
|
||||
cx.run_until_parked();
|
||||
assert_eq!(
|
||||
drain_events(&mut events),
|
||||
vec![EditAgentOutputEvent::HallucinatedOldText(
|
||||
vec![EditAgentOutputEvent::OldTextNotFound(
|
||||
"hallucinated old".into()
|
||||
)]
|
||||
);
|
||||
@@ -776,7 +892,9 @@ mod tests {
|
||||
cx.run_until_parked();
|
||||
assert_eq!(
|
||||
drain_events(&mut events),
|
||||
vec![EditAgentOutputEvent::Edited]
|
||||
vec![EditAgentOutputEvent::Edited {
|
||||
position: buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(2, 3)))
|
||||
}]
|
||||
);
|
||||
assert_eq!(
|
||||
buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
|
||||
|
||||
@@ -4,10 +4,11 @@ use crate::{
|
||||
streaming_edit_file_tool::StreamingEditFileToolInput,
|
||||
};
|
||||
use Role::*;
|
||||
use anyhow::{Context, anyhow};
|
||||
use anyhow::anyhow;
|
||||
use client::{Client, UserStore};
|
||||
use collections::HashMap;
|
||||
use fs::FakeFs;
|
||||
use futures::{FutureExt, future::LocalBoxFuture};
|
||||
use gpui::{AppContext, TestAppContext};
|
||||
use indoc::indoc;
|
||||
use language_model::{
|
||||
@@ -71,14 +72,15 @@ fn eval_extract_handle_command_output() {
|
||||
StreamingEditFileToolInput {
|
||||
display_description: edit_description.into(),
|
||||
path: input_file_path.into(),
|
||||
create_or_overwrite: false,
|
||||
},
|
||||
)],
|
||||
),
|
||||
],
|
||||
input_path: input_file_path.into(),
|
||||
input_content: input_file_content.into(),
|
||||
input_content: Some(input_file_content.into()),
|
||||
edit_description: edit_description.into(),
|
||||
assertion: EvalAssertion::AssertEqual(output_file_content.into()),
|
||||
assertion: EvalAssertion::assert_eq(output_file_content),
|
||||
},
|
||||
);
|
||||
}
|
||||
@@ -126,14 +128,15 @@ fn eval_delete_run_git_blame() {
|
||||
StreamingEditFileToolInput {
|
||||
display_description: edit_description.into(),
|
||||
path: input_file_path.into(),
|
||||
create_or_overwrite: false,
|
||||
},
|
||||
)],
|
||||
),
|
||||
],
|
||||
input_path: input_file_path.into(),
|
||||
input_content: input_file_content.into(),
|
||||
input_content: Some(input_file_content.into()),
|
||||
edit_description: edit_description.into(),
|
||||
assertion: EvalAssertion::AssertEqual(output_file_content.into()),
|
||||
assertion: EvalAssertion::assert_eq(output_file_content),
|
||||
},
|
||||
);
|
||||
}
|
||||
@@ -240,14 +243,15 @@ fn eval_use_wasi_sdk_in_compile_parser_to_wasm() {
|
||||
StreamingEditFileToolInput {
|
||||
display_description: edit_description.into(),
|
||||
path: input_file_path.into(),
|
||||
create_or_overwrite: false,
|
||||
},
|
||||
)],
|
||||
),
|
||||
],
|
||||
input_path: input_file_path.into(),
|
||||
input_content: input_file_content.into(),
|
||||
input_content: Some(input_file_content.into()),
|
||||
edit_description: edit_description.into(),
|
||||
assertion: EvalAssertion::JudgeDiff(indoc! {"
|
||||
assertion: EvalAssertion::judge_diff(indoc! {"
|
||||
- The compile_parser_to_wasm method has been changed to use wasi-sdk
|
||||
- ureq is used to download the SDK for current platform and architecture
|
||||
"}),
|
||||
@@ -315,14 +319,15 @@ fn eval_disable_cursor_blinking() {
|
||||
StreamingEditFileToolInput {
|
||||
display_description: edit_description.into(),
|
||||
path: input_file_path.into(),
|
||||
create_or_overwrite: false,
|
||||
},
|
||||
)],
|
||||
),
|
||||
],
|
||||
input_path: input_file_path.into(),
|
||||
input_content: input_file_content.into(),
|
||||
input_content: Some(input_file_content.into()),
|
||||
edit_description: edit_description.into(),
|
||||
assertion: EvalAssertion::AssertEqual(output_file_content.into()),
|
||||
assertion: EvalAssertion::assert_eq(output_file_content),
|
||||
},
|
||||
);
|
||||
}
|
||||
@@ -504,14 +509,15 @@ fn eval_from_pixels_constructor() {
|
||||
StreamingEditFileToolInput {
|
||||
display_description: edit_description.into(),
|
||||
path: input_file_path.into(),
|
||||
create_or_overwrite: false,
|
||||
},
|
||||
)],
|
||||
),
|
||||
],
|
||||
input_path: input_file_path.into(),
|
||||
input_content: input_file_content.into(),
|
||||
input_content: Some(input_file_content.into()),
|
||||
edit_description: edit_description.into(),
|
||||
assertion: EvalAssertion::JudgeDiff(indoc! {"
|
||||
assertion: EvalAssertion::assert_eq(indoc! {"
|
||||
- The diff contains a new `from_pixels` constructor
|
||||
- The diff contains new tests for the `from_pixels` constructor
|
||||
"}),
|
||||
@@ -519,6 +525,104 @@ fn eval_from_pixels_constructor() {
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[cfg_attr(not(feature = "eval"), ignore)]
|
||||
fn eval_zode() {
|
||||
let input_file_path = "root/zode.py";
|
||||
let edit_description = "Create the main Zode CLI script";
|
||||
eval(
|
||||
200,
|
||||
1.,
|
||||
EvalInput {
|
||||
conversation: vec![
|
||||
message(User, [text(include_str!("evals/fixtures/zode/prompt.md"))]),
|
||||
message(
|
||||
Assistant,
|
||||
[
|
||||
tool_use(
|
||||
"tool_1",
|
||||
"read_file",
|
||||
ReadFileToolInput {
|
||||
path: "root/eval/react.py".into(),
|
||||
start_line: None,
|
||||
end_line: None,
|
||||
},
|
||||
),
|
||||
tool_use(
|
||||
"tool_2",
|
||||
"read_file",
|
||||
ReadFileToolInput {
|
||||
path: "root/eval/react_test.py".into(),
|
||||
start_line: None,
|
||||
end_line: None,
|
||||
},
|
||||
),
|
||||
],
|
||||
),
|
||||
message(
|
||||
User,
|
||||
[
|
||||
tool_result(
|
||||
"tool_1",
|
||||
"read_file",
|
||||
include_str!("evals/fixtures/zode/react.py"),
|
||||
),
|
||||
tool_result(
|
||||
"tool_2",
|
||||
"read_file",
|
||||
include_str!("evals/fixtures/zode/react_test.py"),
|
||||
),
|
||||
],
|
||||
),
|
||||
message(
|
||||
Assistant,
|
||||
[
|
||||
text(
|
||||
"Now that I understand what we need to build, I'll create the main Python script:",
|
||||
),
|
||||
tool_use(
|
||||
"tool_3",
|
||||
"edit_file",
|
||||
StreamingEditFileToolInput {
|
||||
display_description: edit_description.into(),
|
||||
path: input_file_path.into(),
|
||||
create_or_overwrite: true,
|
||||
},
|
||||
),
|
||||
],
|
||||
),
|
||||
],
|
||||
input_path: input_file_path.into(),
|
||||
input_content: None,
|
||||
edit_description: edit_description.into(),
|
||||
assertion: EvalAssertion::new(async move |sample, _, _cx| {
|
||||
let invalid_starts = [' ', '`', '\n'];
|
||||
let mut message = String::new();
|
||||
for start in invalid_starts {
|
||||
if sample.text.starts_with(start) {
|
||||
message.push_str(&format!("The sample starts with a {:?}\n", start));
|
||||
break;
|
||||
}
|
||||
}
|
||||
// Remove trailing newline.
|
||||
message.pop();
|
||||
|
||||
if message.is_empty() {
|
||||
Ok(EvalAssertionOutcome {
|
||||
score: 100,
|
||||
message: None,
|
||||
})
|
||||
} else {
|
||||
Ok(EvalAssertionOutcome {
|
||||
score: 0,
|
||||
message: Some(message),
|
||||
})
|
||||
}
|
||||
}),
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
fn message(
|
||||
role: Role,
|
||||
contents: impl IntoIterator<Item = MessageContent>,
|
||||
@@ -574,11 +678,135 @@ fn tool_result(
|
||||
struct EvalInput {
|
||||
conversation: Vec<LanguageModelRequestMessage>,
|
||||
input_path: PathBuf,
|
||||
input_content: String,
|
||||
input_content: Option<String>,
|
||||
edit_description: String,
|
||||
assertion: EvalAssertion,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct EvalSample {
|
||||
text: String,
|
||||
edit_output: EditAgentOutput,
|
||||
diff: String,
|
||||
}
|
||||
|
||||
trait AssertionFn: 'static + Send + Sync {
|
||||
fn assert<'a>(
|
||||
&'a self,
|
||||
sample: &'a EvalSample,
|
||||
judge_model: Arc<dyn LanguageModel>,
|
||||
cx: &'a mut TestAppContext,
|
||||
) -> LocalBoxFuture<'a, Result<EvalAssertionOutcome>>;
|
||||
}
|
||||
|
||||
impl<F> AssertionFn for F
|
||||
where
|
||||
F: 'static
|
||||
+ Send
|
||||
+ Sync
|
||||
+ AsyncFn(
|
||||
&EvalSample,
|
||||
Arc<dyn LanguageModel>,
|
||||
&mut TestAppContext,
|
||||
) -> Result<EvalAssertionOutcome>,
|
||||
{
|
||||
fn assert<'a>(
|
||||
&'a self,
|
||||
sample: &'a EvalSample,
|
||||
judge_model: Arc<dyn LanguageModel>,
|
||||
cx: &'a mut TestAppContext,
|
||||
) -> LocalBoxFuture<'a, Result<EvalAssertionOutcome>> {
|
||||
(self)(sample, judge_model, cx).boxed_local()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct EvalAssertion(Arc<dyn AssertionFn>);
|
||||
|
||||
impl EvalAssertion {
|
||||
fn new<F>(f: F) -> Self
|
||||
where
|
||||
F: 'static
|
||||
+ Send
|
||||
+ Sync
|
||||
+ AsyncFn(
|
||||
&EvalSample,
|
||||
Arc<dyn LanguageModel>,
|
||||
&mut TestAppContext,
|
||||
) -> Result<EvalAssertionOutcome>,
|
||||
{
|
||||
EvalAssertion(Arc::new(f))
|
||||
}
|
||||
|
||||
fn assert_eq(expected: impl Into<String>) -> Self {
|
||||
let expected = expected.into();
|
||||
Self::new(async move |sample, _judge, _cx| {
|
||||
Ok(EvalAssertionOutcome {
|
||||
score: if strip_empty_lines(&sample.text) == strip_empty_lines(&expected) {
|
||||
100
|
||||
} else {
|
||||
0
|
||||
},
|
||||
message: None,
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
fn judge_diff(assertions: &'static str) -> Self {
|
||||
Self::new(async move |sample, judge, cx| {
|
||||
let prompt = DiffJudgeTemplate {
|
||||
diff: sample.diff.clone(),
|
||||
assertions,
|
||||
}
|
||||
.render(&Templates::new())
|
||||
.unwrap();
|
||||
|
||||
let request = LanguageModelRequest {
|
||||
messages: vec![LanguageModelRequestMessage {
|
||||
role: Role::User,
|
||||
content: vec![prompt.into()],
|
||||
cache: false,
|
||||
}],
|
||||
..Default::default()
|
||||
};
|
||||
let mut response = judge
|
||||
.stream_completion_text(request, &cx.to_async())
|
||||
.await?;
|
||||
let mut output = String::new();
|
||||
while let Some(chunk) = response.stream.next().await {
|
||||
let chunk = chunk?;
|
||||
output.push_str(&chunk);
|
||||
}
|
||||
|
||||
// Parse the score from the response
|
||||
let re = regex::Regex::new(r"<score>(\d+)</score>").unwrap();
|
||||
if let Some(captures) = re.captures(&output) {
|
||||
if let Some(score_match) = captures.get(1) {
|
||||
let score = score_match.as_str().parse().unwrap_or(0);
|
||||
return Ok(EvalAssertionOutcome {
|
||||
score,
|
||||
message: Some(output),
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
Err(anyhow!(
|
||||
"No score found in response. Raw output: {}",
|
||||
output
|
||||
))
|
||||
})
|
||||
}
|
||||
|
||||
async fn run(
|
||||
&self,
|
||||
input: &EvalSample,
|
||||
judge_model: Arc<dyn LanguageModel>,
|
||||
cx: &mut TestAppContext,
|
||||
) -> Result<EvalAssertionOutcome> {
|
||||
self.0.assert(input, judge_model, cx).await
|
||||
}
|
||||
}
|
||||
|
||||
fn eval(iterations: usize, expected_pass_ratio: f32, mut eval: EvalInput) {
|
||||
let mut evaluated_count = 0;
|
||||
report_progress(evaluated_count, iterations);
|
||||
@@ -606,12 +834,12 @@ fn eval(iterations: usize, expected_pass_ratio: f32, mut eval: EvalInput) {
|
||||
while let Ok(output) = rx.recv() {
|
||||
match output {
|
||||
Ok(output) => {
|
||||
cumulative_parser_metrics += output.edit_output._parser_metrics.clone();
|
||||
cumulative_parser_metrics += output.sample.edit_output._parser_metrics.clone();
|
||||
eval_outputs.push(output.clone());
|
||||
if output.assertion.score < 80 {
|
||||
failed_count += 1;
|
||||
failed_evals
|
||||
.entry(output.buffer_text.clone())
|
||||
.entry(output.sample.text.clone())
|
||||
.or_insert(Vec::new())
|
||||
.push(output);
|
||||
}
|
||||
@@ -671,10 +899,8 @@ fn run_eval(eval: EvalInput, tx: mpsc::Sender<Result<EvalOutput>>) {
|
||||
|
||||
#[derive(Clone)]
|
||||
struct EvalOutput {
|
||||
assertion: EvalAssertionResult,
|
||||
buffer_text: String,
|
||||
edit_output: EditAgentOutput,
|
||||
diff: String,
|
||||
sample: EvalSample,
|
||||
assertion: EvalAssertionOutcome,
|
||||
}
|
||||
|
||||
impl Display for EvalOutput {
|
||||
@@ -684,14 +910,14 @@ impl Display for EvalOutput {
|
||||
writeln!(f, "Message: {}", message)?;
|
||||
}
|
||||
|
||||
writeln!(f, "Diff:\n{}", self.diff)?;
|
||||
writeln!(f, "Diff:\n{}", self.sample.diff)?;
|
||||
|
||||
writeln!(
|
||||
f,
|
||||
"Parser Metrics:\n{:#?}",
|
||||
self.edit_output._parser_metrics
|
||||
self.sample.edit_output._parser_metrics
|
||||
)?;
|
||||
writeln!(f, "Raw Edits:\n{}", self.edit_output._raw_edits)?;
|
||||
writeln!(f, "Raw Edits:\n{}", self.sample.edit_output._raw_edits)?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
@@ -777,96 +1003,45 @@ impl EditAgentTest {
|
||||
.update(cx, |project, cx| project.open_buffer(path, cx))
|
||||
.await
|
||||
.unwrap();
|
||||
buffer.update(cx, |buffer, cx| {
|
||||
buffer.set_text(eval.input_content.clone(), cx)
|
||||
});
|
||||
let (edit_output, _events) = self.agent.edit(
|
||||
buffer.clone(),
|
||||
eval.edit_description,
|
||||
eval.conversation,
|
||||
&mut cx.to_async(),
|
||||
);
|
||||
let edit_output = edit_output.await?;
|
||||
let edit_output = if let Some(input_content) = eval.input_content.as_deref() {
|
||||
buffer.update(cx, |buffer, cx| buffer.set_text(input_content, cx));
|
||||
let (edit_output, _) = self.agent.edit(
|
||||
buffer.clone(),
|
||||
eval.edit_description,
|
||||
eval.conversation,
|
||||
&mut cx.to_async(),
|
||||
);
|
||||
edit_output.await?
|
||||
} else {
|
||||
let (edit_output, _) = self.agent.overwrite(
|
||||
buffer.clone(),
|
||||
eval.edit_description,
|
||||
eval.conversation,
|
||||
&mut cx.to_async(),
|
||||
);
|
||||
edit_output.await?
|
||||
};
|
||||
|
||||
let buffer_text = buffer.read_with(cx, |buffer, _| buffer.text());
|
||||
let actual_diff = language::unified_diff(&eval.input_content, &buffer_text);
|
||||
let assertion = match eval.assertion {
|
||||
EvalAssertion::AssertEqual(expected_output) => EvalAssertionResult {
|
||||
score: if strip_empty_lines(&buffer_text) == strip_empty_lines(&expected_output) {
|
||||
100
|
||||
} else {
|
||||
0
|
||||
},
|
||||
message: None,
|
||||
},
|
||||
EvalAssertion::JudgeDiff(assertions) => self
|
||||
.judge_diff(&actual_diff, assertions, &cx.to_async())
|
||||
.await
|
||||
.context("failed comparing diffs")?,
|
||||
};
|
||||
|
||||
Ok(EvalOutput {
|
||||
assertion,
|
||||
diff: actual_diff,
|
||||
buffer_text,
|
||||
let sample = EvalSample {
|
||||
edit_output,
|
||||
})
|
||||
}
|
||||
|
||||
async fn judge_diff(
|
||||
&self,
|
||||
diff: &str,
|
||||
assertions: &'static str,
|
||||
cx: &AsyncApp,
|
||||
) -> Result<EvalAssertionResult> {
|
||||
let prompt = DiffJudgeTemplate {
|
||||
diff: diff.to_string(),
|
||||
assertions,
|
||||
}
|
||||
.render(&self.agent.templates)
|
||||
.unwrap();
|
||||
|
||||
let request = LanguageModelRequest {
|
||||
messages: vec![LanguageModelRequestMessage {
|
||||
role: Role::User,
|
||||
content: vec![prompt.into()],
|
||||
cache: false,
|
||||
}],
|
||||
..Default::default()
|
||||
diff: language::unified_diff(
|
||||
eval.input_content.as_deref().unwrap_or_default(),
|
||||
&buffer_text,
|
||||
),
|
||||
text: buffer_text,
|
||||
};
|
||||
let mut response = self.judge_model.stream_completion_text(request, cx).await?;
|
||||
let mut output = String::new();
|
||||
while let Some(chunk) = response.stream.next().await {
|
||||
let chunk = chunk?;
|
||||
output.push_str(&chunk);
|
||||
}
|
||||
let assertion = eval
|
||||
.assertion
|
||||
.run(&sample, self.judge_model.clone(), cx)
|
||||
.await?;
|
||||
|
||||
// Parse the score from the response
|
||||
let re = regex::Regex::new(r"<score>(\d+)</score>").unwrap();
|
||||
if let Some(captures) = re.captures(&output) {
|
||||
if let Some(score_match) = captures.get(1) {
|
||||
let score = score_match.as_str().parse().unwrap_or(0);
|
||||
return Ok(EvalAssertionResult {
|
||||
score,
|
||||
message: Some(output),
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
Err(anyhow!(
|
||||
"No score found in response. Raw output: {}",
|
||||
output
|
||||
))
|
||||
Ok(EvalOutput { assertion, sample })
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Eq, PartialEq, Hash)]
|
||||
enum EvalAssertion {
|
||||
AssertEqual(String),
|
||||
JudgeDiff(&'static str),
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Eq, PartialEq, Hash)]
|
||||
struct EvalAssertionResult {
|
||||
struct EvalAssertionOutcome {
|
||||
score: usize,
|
||||
message: Option<String>,
|
||||
}
|
||||
|
||||
2193
crates/assistant_tools/src/edit_agent/evals/fixtures/zode/prompt.md
Normal file
2193
crates/assistant_tools/src/edit_agent/evals/fixtures/zode/prompt.md
Normal file
File diff suppressed because one or more lines are too long
@@ -0,0 +1,14 @@
|
||||
class InputCell:
|
||||
def __init__(self, initial_value):
|
||||
self.value = None
|
||||
|
||||
|
||||
class ComputeCell:
|
||||
def __init__(self, inputs, compute_function):
|
||||
self.value = None
|
||||
|
||||
def add_callback(self, callback):
|
||||
pass
|
||||
|
||||
def remove_callback(self, callback):
|
||||
pass
|
||||
@@ -0,0 +1,271 @@
|
||||
# These tests are auto-generated with test data from:
|
||||
# https://github.com/exercism/problem-specifications/tree/main/exercises/react/canonical-data.json
|
||||
# File last updated on 2023-07-19
|
||||
|
||||
from functools import partial
|
||||
import unittest
|
||||
|
||||
from react import (
|
||||
InputCell,
|
||||
ComputeCell,
|
||||
)
|
||||
|
||||
|
||||
class ReactTest(unittest.TestCase):
|
||||
def test_input_cells_have_a_value(self):
|
||||
input = InputCell(10)
|
||||
self.assertEqual(input.value, 10)
|
||||
|
||||
def test_an_input_cell_s_value_can_be_set(self):
|
||||
input = InputCell(4)
|
||||
input.value = 20
|
||||
self.assertEqual(input.value, 20)
|
||||
|
||||
def test_compute_cells_calculate_initial_value(self):
|
||||
input = InputCell(1)
|
||||
output = ComputeCell(
|
||||
[
|
||||
input,
|
||||
],
|
||||
lambda inputs: inputs[0] + 1,
|
||||
)
|
||||
self.assertEqual(output.value, 2)
|
||||
|
||||
def test_compute_cells_take_inputs_in_the_right_order(self):
|
||||
one = InputCell(1)
|
||||
two = InputCell(2)
|
||||
output = ComputeCell(
|
||||
[
|
||||
one,
|
||||
two,
|
||||
],
|
||||
lambda inputs: inputs[0] + inputs[1] * 10,
|
||||
)
|
||||
self.assertEqual(output.value, 21)
|
||||
|
||||
def test_compute_cells_update_value_when_dependencies_are_changed(self):
|
||||
input = InputCell(1)
|
||||
output = ComputeCell(
|
||||
[
|
||||
input,
|
||||
],
|
||||
lambda inputs: inputs[0] + 1,
|
||||
)
|
||||
input.value = 3
|
||||
self.assertEqual(output.value, 4)
|
||||
|
||||
def test_compute_cells_can_depend_on_other_compute_cells(self):
|
||||
input = InputCell(1)
|
||||
times_two = ComputeCell(
|
||||
[
|
||||
input,
|
||||
],
|
||||
lambda inputs: inputs[0] * 2,
|
||||
)
|
||||
times_thirty = ComputeCell(
|
||||
[
|
||||
input,
|
||||
],
|
||||
lambda inputs: inputs[0] * 30,
|
||||
)
|
||||
output = ComputeCell(
|
||||
[
|
||||
times_two,
|
||||
times_thirty,
|
||||
],
|
||||
lambda inputs: inputs[0] + inputs[1],
|
||||
)
|
||||
self.assertEqual(output.value, 32)
|
||||
input.value = 3
|
||||
self.assertEqual(output.value, 96)
|
||||
|
||||
def test_compute_cells_fire_callbacks(self):
|
||||
input = InputCell(1)
|
||||
output = ComputeCell(
|
||||
[
|
||||
input,
|
||||
],
|
||||
lambda inputs: inputs[0] + 1,
|
||||
)
|
||||
cb1_observer = []
|
||||
callback1 = self.callback_factory(cb1_observer)
|
||||
output.add_callback(callback1)
|
||||
input.value = 3
|
||||
self.assertEqual(cb1_observer[-1], 4)
|
||||
|
||||
def test_callback_cells_only_fire_on_change(self):
|
||||
input = InputCell(1)
|
||||
output = ComputeCell([input], lambda inputs: 111 if inputs[0] < 3 else 222)
|
||||
cb1_observer = []
|
||||
callback1 = self.callback_factory(cb1_observer)
|
||||
output.add_callback(callback1)
|
||||
input.value = 2
|
||||
self.assertEqual(cb1_observer, [])
|
||||
input.value = 4
|
||||
self.assertEqual(cb1_observer[-1], 222)
|
||||
|
||||
def test_callbacks_do_not_report_already_reported_values(self):
|
||||
input = InputCell(1)
|
||||
output = ComputeCell(
|
||||
[
|
||||
input,
|
||||
],
|
||||
lambda inputs: inputs[0] + 1,
|
||||
)
|
||||
cb1_observer = []
|
||||
callback1 = self.callback_factory(cb1_observer)
|
||||
output.add_callback(callback1)
|
||||
input.value = 2
|
||||
self.assertEqual(cb1_observer[-1], 3)
|
||||
input.value = 3
|
||||
self.assertEqual(cb1_observer[-1], 4)
|
||||
|
||||
def test_callbacks_can_fire_from_multiple_cells(self):
|
||||
input = InputCell(1)
|
||||
plus_one = ComputeCell(
|
||||
[
|
||||
input,
|
||||
],
|
||||
lambda inputs: inputs[0] + 1,
|
||||
)
|
||||
minus_one = ComputeCell(
|
||||
[
|
||||
input,
|
||||
],
|
||||
lambda inputs: inputs[0] - 1,
|
||||
)
|
||||
cb1_observer = []
|
||||
cb2_observer = []
|
||||
callback1 = self.callback_factory(cb1_observer)
|
||||
callback2 = self.callback_factory(cb2_observer)
|
||||
plus_one.add_callback(callback1)
|
||||
minus_one.add_callback(callback2)
|
||||
input.value = 10
|
||||
self.assertEqual(cb1_observer[-1], 11)
|
||||
self.assertEqual(cb2_observer[-1], 9)
|
||||
|
||||
def test_callbacks_can_be_added_and_removed(self):
|
||||
input = InputCell(11)
|
||||
output = ComputeCell(
|
||||
[
|
||||
input,
|
||||
],
|
||||
lambda inputs: inputs[0] + 1,
|
||||
)
|
||||
cb1_observer = []
|
||||
cb2_observer = []
|
||||
cb3_observer = []
|
||||
callback1 = self.callback_factory(cb1_observer)
|
||||
callback2 = self.callback_factory(cb2_observer)
|
||||
callback3 = self.callback_factory(cb3_observer)
|
||||
output.add_callback(callback1)
|
||||
output.add_callback(callback2)
|
||||
input.value = 31
|
||||
self.assertEqual(cb1_observer[-1], 32)
|
||||
self.assertEqual(cb2_observer[-1], 32)
|
||||
output.remove_callback(callback1)
|
||||
output.add_callback(callback3)
|
||||
input.value = 41
|
||||
self.assertEqual(len(cb1_observer), 1)
|
||||
self.assertEqual(cb2_observer[-1], 42)
|
||||
self.assertEqual(cb3_observer[-1], 42)
|
||||
|
||||
def test_removing_a_callback_multiple_times_doesn_t_interfere_with_other_callbacks(
|
||||
self,
|
||||
):
|
||||
input = InputCell(1)
|
||||
output = ComputeCell(
|
||||
[
|
||||
input,
|
||||
],
|
||||
lambda inputs: inputs[0] + 1,
|
||||
)
|
||||
cb1_observer = []
|
||||
cb2_observer = []
|
||||
callback1 = self.callback_factory(cb1_observer)
|
||||
callback2 = self.callback_factory(cb2_observer)
|
||||
output.add_callback(callback1)
|
||||
output.add_callback(callback2)
|
||||
output.remove_callback(callback1)
|
||||
output.remove_callback(callback1)
|
||||
output.remove_callback(callback1)
|
||||
input.value = 2
|
||||
self.assertEqual(cb1_observer, [])
|
||||
self.assertEqual(cb2_observer[-1], 3)
|
||||
|
||||
def test_callbacks_should_only_be_called_once_even_if_multiple_dependencies_change(
|
||||
self,
|
||||
):
|
||||
input = InputCell(1)
|
||||
plus_one = ComputeCell(
|
||||
[
|
||||
input,
|
||||
],
|
||||
lambda inputs: inputs[0] + 1,
|
||||
)
|
||||
minus_one1 = ComputeCell(
|
||||
[
|
||||
input,
|
||||
],
|
||||
lambda inputs: inputs[0] - 1,
|
||||
)
|
||||
minus_one2 = ComputeCell(
|
||||
[
|
||||
minus_one1,
|
||||
],
|
||||
lambda inputs: inputs[0] - 1,
|
||||
)
|
||||
output = ComputeCell(
|
||||
[
|
||||
plus_one,
|
||||
minus_one2,
|
||||
],
|
||||
lambda inputs: inputs[0] * inputs[1],
|
||||
)
|
||||
cb1_observer = []
|
||||
callback1 = self.callback_factory(cb1_observer)
|
||||
output.add_callback(callback1)
|
||||
input.value = 4
|
||||
self.assertEqual(cb1_observer[-1], 10)
|
||||
|
||||
def test_callbacks_should_not_be_called_if_dependencies_change_but_output_value_doesn_t_change(
|
||||
self,
|
||||
):
|
||||
input = InputCell(1)
|
||||
plus_one = ComputeCell(
|
||||
[
|
||||
input,
|
||||
],
|
||||
lambda inputs: inputs[0] + 1,
|
||||
)
|
||||
minus_one = ComputeCell(
|
||||
[
|
||||
input,
|
||||
],
|
||||
lambda inputs: inputs[0] - 1,
|
||||
)
|
||||
always_two = ComputeCell(
|
||||
[
|
||||
plus_one,
|
||||
minus_one,
|
||||
],
|
||||
lambda inputs: inputs[0] - inputs[1],
|
||||
)
|
||||
cb1_observer = []
|
||||
callback1 = self.callback_factory(cb1_observer)
|
||||
always_two.add_callback(callback1)
|
||||
input.value = 2
|
||||
self.assertEqual(cb1_observer, [])
|
||||
input.value = 3
|
||||
self.assertEqual(cb1_observer, [])
|
||||
input.value = 4
|
||||
self.assertEqual(cb1_observer, [])
|
||||
input.value = 5
|
||||
self.assertEqual(cb1_observer, [])
|
||||
|
||||
# Utility functions.
|
||||
def callback_factory(self, observer):
|
||||
def callback(observer, value):
|
||||
observer.append(value)
|
||||
|
||||
return partial(callback, observer)
|
||||
@@ -577,13 +577,10 @@ impl ToolCard for EditFileToolCard {
|
||||
card.child(
|
||||
v_flex()
|
||||
.relative()
|
||||
.map(|editor_container| {
|
||||
if self.full_height_expanded {
|
||||
editor_container.h_full()
|
||||
} else {
|
||||
editor_container
|
||||
.h(DEFAULT_COLLAPSED_LINES as f32 * editor_line_height)
|
||||
}
|
||||
.h_full()
|
||||
.when(!self.full_height_expanded, |editor_container| {
|
||||
editor_container
|
||||
.max_h(DEFAULT_COLLAPSED_LINES as f32 * editor_line_height)
|
||||
})
|
||||
.overflow_hidden()
|
||||
.border_t_1()
|
||||
|
||||
@@ -38,7 +38,7 @@ pub struct StreamingEditFileToolInput {
|
||||
/// so that we can display it immediately.
|
||||
pub display_description: String,
|
||||
|
||||
/// The full path of the file to modify in the project.
|
||||
/// The full path of the file to create or modify in the project.
|
||||
///
|
||||
/// WARNING: When specifying which file path need changing, you MUST
|
||||
/// start each path with one of the project's root directories.
|
||||
@@ -58,6 +58,10 @@ pub struct StreamingEditFileToolInput {
|
||||
/// `frontend/db.js`
|
||||
/// </example>
|
||||
pub path: PathBuf,
|
||||
|
||||
/// If true, this tool will recreate the file from scratch.
|
||||
/// If false, this tool will produce granular edits to an existing file.
|
||||
pub create_or_overwrite: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
|
||||
@@ -158,7 +162,7 @@ impl Tool for StreamingEditFileTool {
|
||||
let card_clone = card.clone();
|
||||
let messages = messages.to_vec();
|
||||
let task = cx.spawn(async move |cx: &mut AsyncApp| {
|
||||
if !exists.await? {
|
||||
if !input.create_or_overwrite && !exists.await? {
|
||||
return Err(anyhow!("{} not found", input.path.display()));
|
||||
}
|
||||
|
||||
@@ -182,17 +186,26 @@ impl Tool for StreamingEditFileTool {
|
||||
})
|
||||
.await;
|
||||
|
||||
let (output, mut events) = edit_agent.edit(
|
||||
buffer.clone(),
|
||||
input.display_description.clone(),
|
||||
messages,
|
||||
cx,
|
||||
);
|
||||
let (output, mut events) = if input.create_or_overwrite {
|
||||
edit_agent.overwrite(
|
||||
buffer.clone(),
|
||||
input.display_description.clone(),
|
||||
messages,
|
||||
cx,
|
||||
)
|
||||
} else {
|
||||
edit_agent.edit(
|
||||
buffer.clone(),
|
||||
input.display_description.clone(),
|
||||
messages,
|
||||
cx,
|
||||
)
|
||||
};
|
||||
|
||||
let mut hallucinated_old_text = false;
|
||||
while let Some(event) = events.next().await {
|
||||
match event {
|
||||
EditAgentOutputEvent::Edited => {
|
||||
EditAgentOutputEvent::Edited { position } => {
|
||||
if let Some(card) = card_clone.as_ref() {
|
||||
let new_snapshot =
|
||||
buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
|
||||
@@ -213,7 +226,7 @@ impl Tool for StreamingEditFileTool {
|
||||
.log_err();
|
||||
}
|
||||
}
|
||||
EditAgentOutputEvent::HallucinatedOldText(_) => hallucinated_old_text = true,
|
||||
EditAgentOutputEvent::OldTextNotFound(_) => hallucinated_old_text = true,
|
||||
}
|
||||
}
|
||||
output.await?;
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
This is a tool for editing files. For moving or renaming files, you should generally use the `terminal` tool with the 'mv' command instead. For larger edits, use the `create_file` tool to overwrite files.
|
||||
This is a tool for creating a new file or editing an existing file. For moving or renaming files, you should generally use the `terminal` tool with the 'mv' command instead.
|
||||
|
||||
Before using this tool:
|
||||
|
||||
|
||||
12
crates/assistant_tools/src/templates/create_file_prompt.hbs
Normal file
12
crates/assistant_tools/src/templates/create_file_prompt.hbs
Normal file
@@ -0,0 +1,12 @@
|
||||
You are an expert engineer and your task is to write a new file from scratch.
|
||||
|
||||
<file_to_edit>
|
||||
{{path}}
|
||||
</file_to_edit>
|
||||
|
||||
<edit_description>
|
||||
{{edit_description}}
|
||||
</edit_description>
|
||||
|
||||
You MUST respond directly with the file's content, without explanations, additional text or triple backticks.
|
||||
The text you output will be saved verbatim as the content of the file.
|
||||
@@ -27,7 +27,9 @@ use crate::db::billing_subscription::{
|
||||
StripeCancellationReason, StripeSubscriptionStatus, SubscriptionKind,
|
||||
};
|
||||
use crate::llm::db::subscription_usage_meter::CompletionMode;
|
||||
use crate::llm::{DEFAULT_MAX_MONTHLY_SPEND, FREE_TIER_MONTHLY_SPENDING_LIMIT};
|
||||
use crate::llm::{
|
||||
AGENT_EXTENDED_TRIAL_FEATURE_FLAG, DEFAULT_MAX_MONTHLY_SPEND, FREE_TIER_MONTHLY_SPENDING_LIMIT,
|
||||
};
|
||||
use crate::rpc::{ResultExt as _, Server};
|
||||
use crate::{AppState, Cents, Error, Result};
|
||||
use crate::{db::UserId, llm::db::LlmDatabase};
|
||||
@@ -54,6 +56,10 @@ pub fn router() -> Router {
|
||||
"/billing/subscriptions/manage",
|
||||
post(manage_billing_subscription),
|
||||
)
|
||||
.route(
|
||||
"/billing/subscriptions/migrate",
|
||||
post(migrate_to_new_billing),
|
||||
)
|
||||
.route("/billing/monthly_spend", get(get_monthly_spend))
|
||||
.route("/billing/usage", get(get_current_usage))
|
||||
}
|
||||
@@ -256,6 +262,7 @@ async fn list_billing_subscriptions(
|
||||
enum ProductCode {
|
||||
ZedPro,
|
||||
ZedProTrial,
|
||||
ZedFree,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
@@ -386,6 +393,11 @@ async fn create_billing_subscription(
|
||||
)
|
||||
.await?
|
||||
}
|
||||
Some(ProductCode::ZedFree) => {
|
||||
stripe_billing
|
||||
.checkout_with_zed_free(customer_id, &user.github_login, &success_url)
|
||||
.await?
|
||||
}
|
||||
None => {
|
||||
let default_model = llm_db.model(
|
||||
zed_llm_client::LanguageModelProvider::Anthropic,
|
||||
@@ -604,6 +616,85 @@ async fn manage_billing_subscription(
|
||||
}))
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct MigrateToNewBillingBody {
|
||||
github_user_id: i32,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
struct MigrateToNewBillingResponse {
|
||||
/// The ID of the subscription that was canceled.
|
||||
canceled_subscription_id: String,
|
||||
}
|
||||
|
||||
async fn migrate_to_new_billing(
|
||||
Extension(app): Extension<Arc<AppState>>,
|
||||
extract::Json(body): extract::Json<MigrateToNewBillingBody>,
|
||||
) -> Result<Json<MigrateToNewBillingResponse>> {
|
||||
let Some(stripe_client) = app.stripe_client.clone() else {
|
||||
log::error!("failed to retrieve Stripe client");
|
||||
Err(Error::http(
|
||||
StatusCode::NOT_IMPLEMENTED,
|
||||
"not supported".into(),
|
||||
))?
|
||||
};
|
||||
|
||||
let user = app
|
||||
.db
|
||||
.get_user_by_github_user_id(body.github_user_id)
|
||||
.await?
|
||||
.ok_or_else(|| anyhow!("user not found"))?;
|
||||
|
||||
let old_billing_subscriptions_by_user = app
|
||||
.db
|
||||
.get_active_billing_subscriptions(HashSet::from_iter([user.id]))
|
||||
.await?;
|
||||
|
||||
let Some((_billing_customer, billing_subscription)) =
|
||||
old_billing_subscriptions_by_user.get(&user.id)
|
||||
else {
|
||||
return Err(Error::http(
|
||||
StatusCode::NOT_FOUND,
|
||||
"No active billing subscriptions to migrate".into(),
|
||||
));
|
||||
};
|
||||
|
||||
let stripe_subscription_id = billing_subscription
|
||||
.stripe_subscription_id
|
||||
.parse::<stripe::SubscriptionId>()
|
||||
.context("failed to parse Stripe subscription ID from database")?;
|
||||
|
||||
Subscription::cancel(
|
||||
&stripe_client,
|
||||
&stripe_subscription_id,
|
||||
stripe::CancelSubscription {
|
||||
invoice_now: Some(true),
|
||||
..Default::default()
|
||||
},
|
||||
)
|
||||
.await?;
|
||||
|
||||
let feature_flags = app.db.list_feature_flags().await?;
|
||||
|
||||
for feature_flag in ["new-billing", "assistant2"] {
|
||||
let already_in_feature_flag = feature_flags.iter().any(|flag| flag.flag == feature_flag);
|
||||
if already_in_feature_flag {
|
||||
continue;
|
||||
}
|
||||
|
||||
let feature_flag = feature_flags
|
||||
.iter()
|
||||
.find(|flag| flag.flag == feature_flag)
|
||||
.context("failed to find feature flag: {feature_flag:?}")?;
|
||||
|
||||
app.db.add_user_flag(user.id, feature_flag.id).await?;
|
||||
}
|
||||
|
||||
Ok(Json(MigrateToNewBillingResponse {
|
||||
canceled_subscription_id: stripe_subscription_id.to_string(),
|
||||
}))
|
||||
}
|
||||
|
||||
/// The amount of time we wait in between each poll of Stripe events.
|
||||
///
|
||||
/// This value should strike a balance between:
|
||||
@@ -1168,8 +1259,21 @@ async fn get_current_usage(
|
||||
SubscriptionKind::ZedFree => zed_llm_client::Plan::Free,
|
||||
};
|
||||
|
||||
let feature_flags = app.db.get_user_flags(user.id).await?;
|
||||
let has_extended_trial = feature_flags
|
||||
.iter()
|
||||
.any(|flag| flag == AGENT_EXTENDED_TRIAL_FEATURE_FLAG);
|
||||
|
||||
let model_requests_limit = match plan.model_requests_limit() {
|
||||
zed_llm_client::UsageLimit::Limited(limit) => Some(limit),
|
||||
zed_llm_client::UsageLimit::Limited(limit) => {
|
||||
let limit = if plan == zed_llm_client::Plan::ZedProTrial && has_extended_trial {
|
||||
1_000
|
||||
} else {
|
||||
limit
|
||||
};
|
||||
|
||||
Some(limit)
|
||||
}
|
||||
zed_llm_client::UsageLimit::Unlimited => None,
|
||||
};
|
||||
let edit_prediction_limit = match plan.edit_predictions_limit() {
|
||||
|
||||
@@ -327,6 +327,10 @@ impl Server {
|
||||
.add_request_handler(
|
||||
forward_read_only_project_request::<proto::LspExtSwitchSourceHeader>,
|
||||
)
|
||||
.add_request_handler(forward_read_only_project_request::<proto::LspExtGoToParentModule>)
|
||||
.add_request_handler(forward_read_only_project_request::<proto::LspExtCancelFlycheck>)
|
||||
.add_request_handler(forward_read_only_project_request::<proto::LspExtRunFlycheck>)
|
||||
.add_request_handler(forward_read_only_project_request::<proto::LspExtClearFlycheck>)
|
||||
.add_request_handler(
|
||||
forward_read_only_project_request::<proto::LanguageServerIdForName>,
|
||||
)
|
||||
|
||||
@@ -565,6 +565,29 @@ impl StripeBilling {
|
||||
let session = stripe::CheckoutSession::create(&self.client, params).await?;
|
||||
Ok(session.url.context("no checkout session URL")?)
|
||||
}
|
||||
|
||||
pub async fn checkout_with_zed_free(
|
||||
&self,
|
||||
customer_id: stripe::CustomerId,
|
||||
github_login: &str,
|
||||
success_url: &str,
|
||||
) -> Result<String> {
|
||||
let zed_free_price_id = self.zed_free_price_id().await?;
|
||||
|
||||
let mut params = stripe::CreateCheckoutSession::new();
|
||||
params.mode = Some(stripe::CheckoutSessionMode::Subscription);
|
||||
params.customer = Some(customer_id);
|
||||
params.client_reference_id = Some(github_login);
|
||||
params.line_items = Some(vec![stripe::CreateCheckoutSessionLineItems {
|
||||
price: Some(zed_free_price_id.to_string()),
|
||||
quantity: Some(1),
|
||||
..Default::default()
|
||||
}]);
|
||||
params.success_url = Some(success_url);
|
||||
|
||||
let session = stripe::CheckoutSession::create(&self.client, params).await?;
|
||||
Ok(session.url.context("no checkout session URL")?)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
|
||||
@@ -25,7 +25,7 @@ use language::{
|
||||
use project::{
|
||||
ProjectPath, SERVER_PROGRESS_THROTTLE_TIMEOUT,
|
||||
lsp_store::{
|
||||
lsp_ext_command::{ExpandedMacro, LspExpandMacro},
|
||||
lsp_ext_command::{ExpandedMacro, LspExtExpandMacro},
|
||||
rust_analyzer_ext::RUST_ANALYZER_NAME,
|
||||
},
|
||||
project_settings::{InlineBlameSettings, ProjectSettings},
|
||||
@@ -2704,8 +2704,8 @@ async fn test_client_can_query_lsp_ext(cx_a: &mut TestAppContext, cx_b: &mut Tes
|
||||
let fake_language_server = fake_language_servers.next().await.unwrap();
|
||||
|
||||
// host
|
||||
let mut expand_request_a =
|
||||
fake_language_server.set_request_handler::<LspExpandMacro, _, _>(|params, _| async move {
|
||||
let mut expand_request_a = fake_language_server.set_request_handler::<LspExtExpandMacro, _, _>(
|
||||
|params, _| async move {
|
||||
assert_eq!(
|
||||
params.text_document.uri,
|
||||
lsp::Url::from_file_path(path!("/a/main.rs")).unwrap(),
|
||||
@@ -2715,7 +2715,8 @@ async fn test_client_can_query_lsp_ext(cx_a: &mut TestAppContext, cx_b: &mut Tes
|
||||
name: "test_macro_name".to_string(),
|
||||
expansion: "test_macro_expansion on the host".to_string(),
|
||||
}))
|
||||
});
|
||||
},
|
||||
);
|
||||
|
||||
editor_a.update_in(cx_a, |editor, window, cx| {
|
||||
expand_macro_recursively(editor, &ExpandMacroRecursively, window, cx)
|
||||
@@ -2738,8 +2739,8 @@ async fn test_client_can_query_lsp_ext(cx_a: &mut TestAppContext, cx_b: &mut Tes
|
||||
});
|
||||
|
||||
// client
|
||||
let mut expand_request_b =
|
||||
fake_language_server.set_request_handler::<LspExpandMacro, _, _>(|params, _| async move {
|
||||
let mut expand_request_b = fake_language_server.set_request_handler::<LspExtExpandMacro, _, _>(
|
||||
|params, _| async move {
|
||||
assert_eq!(
|
||||
params.text_document.uri,
|
||||
lsp::Url::from_file_path(path!("/a/main.rs")).unwrap(),
|
||||
@@ -2749,7 +2750,8 @@ async fn test_client_can_query_lsp_ext(cx_a: &mut TestAppContext, cx_b: &mut Tes
|
||||
name: "test_macro_name".to_string(),
|
||||
expansion: "test_macro_expansion on the client".to_string(),
|
||||
}))
|
||||
});
|
||||
},
|
||||
);
|
||||
|
||||
editor_b.update_in(cx_b, |editor, window, cx| {
|
||||
expand_macro_recursively(editor, &ExpandMacroRecursively, window, cx)
|
||||
|
||||
@@ -2902,7 +2902,7 @@ async fn test_git_branch_name(
|
||||
.read(cx)
|
||||
.branch
|
||||
.as_ref()
|
||||
.map(|branch| branch.name.to_string()),
|
||||
.map(|branch| branch.name().to_owned()),
|
||||
branch_name
|
||||
)
|
||||
}
|
||||
@@ -6864,7 +6864,7 @@ async fn test_remote_git_branches(
|
||||
|
||||
let branches_b = branches_b
|
||||
.into_iter()
|
||||
.map(|branch| branch.name.to_string())
|
||||
.map(|branch| branch.name().to_string())
|
||||
.collect::<HashSet<_>>();
|
||||
|
||||
assert_eq!(branches_b, branches_set);
|
||||
@@ -6895,7 +6895,7 @@ async fn test_remote_git_branches(
|
||||
})
|
||||
});
|
||||
|
||||
assert_eq!(host_branch.name, branches[2]);
|
||||
assert_eq!(host_branch.name(), branches[2]);
|
||||
|
||||
// Also try creating a new branch
|
||||
cx_b.update(|cx| {
|
||||
@@ -6933,5 +6933,5 @@ async fn test_remote_git_branches(
|
||||
})
|
||||
});
|
||||
|
||||
assert_eq!(host_branch.name, "totally-new-branch");
|
||||
assert_eq!(host_branch.name(), "totally-new-branch");
|
||||
}
|
||||
|
||||
@@ -293,7 +293,7 @@ async fn test_ssh_collaboration_git_branches(
|
||||
|
||||
let branches_b = branches_b
|
||||
.into_iter()
|
||||
.map(|branch| branch.name.to_string())
|
||||
.map(|branch| branch.name().to_string())
|
||||
.collect::<HashSet<_>>();
|
||||
|
||||
assert_eq!(&branches_b, &branches_set);
|
||||
@@ -326,7 +326,7 @@ async fn test_ssh_collaboration_git_branches(
|
||||
})
|
||||
});
|
||||
|
||||
assert_eq!(server_branch.name, branches[2]);
|
||||
assert_eq!(server_branch.name(), branches[2]);
|
||||
|
||||
// Also try creating a new branch
|
||||
cx_b.update(|cx| {
|
||||
@@ -366,7 +366,7 @@ async fn test_ssh_collaboration_git_branches(
|
||||
})
|
||||
});
|
||||
|
||||
assert_eq!(server_branch.name, "totally-new-branch");
|
||||
assert_eq!(server_branch.name(), "totally-new-branch");
|
||||
|
||||
// Remove the git repository and check that all participants get the update.
|
||||
remote_fs
|
||||
|
||||
@@ -15,18 +15,21 @@ path = "src/component_preview.rs"
|
||||
default = []
|
||||
|
||||
[dependencies]
|
||||
agent.workspace = true
|
||||
anyhow.workspace = true
|
||||
client.workspace = true
|
||||
collections.workspace = true
|
||||
component.workspace = true
|
||||
db.workspace = true
|
||||
gpui.workspace = true
|
||||
languages.workspace = true
|
||||
notifications.workspace = true
|
||||
log.workspace = true
|
||||
notifications.workspace = true
|
||||
project.workspace = true
|
||||
prompt_store.workspace = true
|
||||
serde.workspace = true
|
||||
ui.workspace = true
|
||||
ui_input.workspace = true
|
||||
workspace-hack.workspace = true
|
||||
workspace.workspace = true
|
||||
db.workspace = true
|
||||
anyhow.workspace = true
|
||||
serde.workspace = true
|
||||
assistant_tool.workspace = true
|
||||
|
||||
@@ -3,10 +3,12 @@
|
||||
//! A view for exploring Zed components.
|
||||
|
||||
mod persistence;
|
||||
mod preview_support;
|
||||
|
||||
use std::iter::Iterator;
|
||||
use std::sync::Arc;
|
||||
|
||||
use agent::{ActiveThread, ThreadStore};
|
||||
use client::UserStore;
|
||||
use component::{ComponentId, ComponentMetadata, components};
|
||||
use gpui::{
|
||||
@@ -19,6 +21,7 @@ use gpui::{ListState, ScrollHandle, ScrollStrategy, UniformListScrollHandle};
|
||||
use languages::LanguageRegistry;
|
||||
use notifications::status_toast::{StatusToast, ToastIcon};
|
||||
use persistence::COMPONENT_PREVIEW_DB;
|
||||
use preview_support::active_thread::{load_preview_thread_store, static_active_thread};
|
||||
use project::Project;
|
||||
use ui::{Divider, HighlightedLabel, ListItem, ListSubHeader, prelude::*};
|
||||
|
||||
@@ -33,6 +36,7 @@ pub fn init(app_state: Arc<AppState>, cx: &mut App) {
|
||||
|
||||
cx.observe_new(move |workspace: &mut Workspace, _window, cx| {
|
||||
let app_state = app_state.clone();
|
||||
let project = workspace.project().clone();
|
||||
let weak_workspace = cx.entity().downgrade();
|
||||
|
||||
workspace.register_action(
|
||||
@@ -45,6 +49,7 @@ pub fn init(app_state: Arc<AppState>, cx: &mut App) {
|
||||
let component_preview = cx.new(|cx| {
|
||||
ComponentPreview::new(
|
||||
weak_workspace.clone(),
|
||||
project.clone(),
|
||||
language_registry,
|
||||
user_store,
|
||||
None,
|
||||
@@ -52,6 +57,7 @@ pub fn init(app_state: Arc<AppState>, cx: &mut App) {
|
||||
window,
|
||||
cx,
|
||||
)
|
||||
.expect("Failed to create component preview")
|
||||
});
|
||||
|
||||
workspace.add_item_to_active_pane(
|
||||
@@ -69,6 +75,7 @@ pub fn init(app_state: Arc<AppState>, cx: &mut App) {
|
||||
|
||||
enum PreviewEntry {
|
||||
AllComponents,
|
||||
ActiveThread,
|
||||
Separator,
|
||||
Component(ComponentMetadata, Option<Vec<usize>>),
|
||||
SectionHeader(SharedString),
|
||||
@@ -91,6 +98,7 @@ enum PreviewPage {
|
||||
#[default]
|
||||
AllComponents,
|
||||
Component(ComponentId),
|
||||
ActiveThread,
|
||||
}
|
||||
|
||||
struct ComponentPreview {
|
||||
@@ -102,24 +110,63 @@ struct ComponentPreview {
|
||||
active_page: PreviewPage,
|
||||
components: Vec<ComponentMetadata>,
|
||||
component_list: ListState,
|
||||
agent_previews: Vec<
|
||||
Box<
|
||||
dyn Fn(
|
||||
&Self,
|
||||
WeakEntity<Workspace>,
|
||||
Entity<ActiveThread>,
|
||||
WeakEntity<ThreadStore>,
|
||||
&mut Window,
|
||||
&mut App,
|
||||
) -> Option<AnyElement>,
|
||||
>,
|
||||
>,
|
||||
cursor_index: usize,
|
||||
language_registry: Arc<LanguageRegistry>,
|
||||
workspace: WeakEntity<Workspace>,
|
||||
project: Entity<Project>,
|
||||
user_store: Entity<UserStore>,
|
||||
filter_editor: Entity<SingleLineInput>,
|
||||
filter_text: String,
|
||||
|
||||
// preview support
|
||||
thread_store: Option<Entity<ThreadStore>>,
|
||||
active_thread: Option<Entity<ActiveThread>>,
|
||||
}
|
||||
|
||||
impl ComponentPreview {
|
||||
pub fn new(
|
||||
workspace: WeakEntity<Workspace>,
|
||||
project: Entity<Project>,
|
||||
language_registry: Arc<LanguageRegistry>,
|
||||
user_store: Entity<UserStore>,
|
||||
selected_index: impl Into<Option<usize>>,
|
||||
active_page: Option<PreviewPage>,
|
||||
window: &mut Window,
|
||||
cx: &mut Context<Self>,
|
||||
) -> Self {
|
||||
) -> anyhow::Result<Self> {
|
||||
let workspace_clone = workspace.clone();
|
||||
let project_clone = project.clone();
|
||||
|
||||
let entity = cx.weak_entity();
|
||||
window
|
||||
.spawn(cx, async move |cx| {
|
||||
let thread_store_task =
|
||||
load_preview_thread_store(workspace_clone.clone(), project_clone.clone(), cx)
|
||||
.await;
|
||||
|
||||
if let Ok(thread_store) = thread_store_task.await {
|
||||
entity
|
||||
.update_in(cx, |this, window, cx| {
|
||||
this.thread_store = Some(thread_store.clone());
|
||||
this.create_active_thread(window, cx);
|
||||
})
|
||||
.ok();
|
||||
}
|
||||
})
|
||||
.detach();
|
||||
|
||||
let sorted_components = components().all_sorted();
|
||||
let selected_index = selected_index.into().unwrap_or(0);
|
||||
let active_page = active_page.unwrap_or(PreviewPage::AllComponents);
|
||||
@@ -143,6 +190,40 @@ impl ComponentPreview {
|
||||
},
|
||||
);
|
||||
|
||||
// Initialize agent previews
|
||||
let agent_previews = agent::all_agent_previews()
|
||||
.into_iter()
|
||||
.map(|id| {
|
||||
Box::new(
|
||||
move |_self: &ComponentPreview,
|
||||
workspace: WeakEntity<Workspace>,
|
||||
active_thread: Entity<ActiveThread>,
|
||||
thread_store: WeakEntity<ThreadStore>,
|
||||
window: &mut Window,
|
||||
cx: &mut App| {
|
||||
agent::get_agent_preview(
|
||||
&id,
|
||||
workspace,
|
||||
active_thread,
|
||||
thread_store,
|
||||
window,
|
||||
cx,
|
||||
)
|
||||
},
|
||||
)
|
||||
as Box<
|
||||
dyn Fn(
|
||||
&ComponentPreview,
|
||||
WeakEntity<Workspace>,
|
||||
Entity<ActiveThread>,
|
||||
WeakEntity<ThreadStore>,
|
||||
&mut Window,
|
||||
&mut App,
|
||||
) -> Option<AnyElement>,
|
||||
>
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let mut component_preview = Self {
|
||||
workspace_id: None,
|
||||
focus_handle: cx.focus_handle(),
|
||||
@@ -151,13 +232,17 @@ impl ComponentPreview {
|
||||
language_registry,
|
||||
user_store,
|
||||
workspace,
|
||||
project,
|
||||
active_page,
|
||||
component_map: components().0,
|
||||
components: sorted_components,
|
||||
component_list,
|
||||
agent_previews,
|
||||
cursor_index: selected_index,
|
||||
filter_editor,
|
||||
filter_text: String::new(),
|
||||
thread_store: None,
|
||||
active_thread: None,
|
||||
};
|
||||
|
||||
if component_preview.cursor_index > 0 {
|
||||
@@ -169,13 +254,41 @@ impl ComponentPreview {
|
||||
let focus_handle = component_preview.filter_editor.read(cx).focus_handle(cx);
|
||||
window.focus(&focus_handle);
|
||||
|
||||
component_preview
|
||||
Ok(component_preview)
|
||||
}
|
||||
|
||||
pub fn create_active_thread(
|
||||
&mut self,
|
||||
window: &mut Window,
|
||||
cx: &mut Context<Self>,
|
||||
) -> &mut Self {
|
||||
let workspace = self.workspace.clone();
|
||||
let language_registry = self.language_registry.clone();
|
||||
let weak_handle = self.workspace.clone();
|
||||
if let Some(workspace) = workspace.upgrade() {
|
||||
let project = workspace.read(cx).project().clone();
|
||||
if let Some(thread_store) = self.thread_store.clone() {
|
||||
let active_thread = static_active_thread(
|
||||
weak_handle,
|
||||
project,
|
||||
language_registry,
|
||||
thread_store,
|
||||
window,
|
||||
cx,
|
||||
);
|
||||
self.active_thread = Some(active_thread);
|
||||
cx.notify();
|
||||
}
|
||||
}
|
||||
|
||||
self
|
||||
}
|
||||
|
||||
pub fn active_page_id(&self, _cx: &App) -> ActivePageId {
|
||||
match &self.active_page {
|
||||
PreviewPage::AllComponents => ActivePageId::default(),
|
||||
PreviewPage::Component(component_id) => ActivePageId(component_id.0.to_string()),
|
||||
PreviewPage::ActiveThread => ActivePageId("active_thread".to_string()),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -289,6 +402,7 @@ impl ComponentPreview {
|
||||
|
||||
// Always show all components first
|
||||
entries.push(PreviewEntry::AllComponents);
|
||||
entries.push(PreviewEntry::ActiveThread);
|
||||
entries.push(PreviewEntry::Separator);
|
||||
|
||||
let mut scopes: Vec<_> = scope_groups
|
||||
@@ -389,6 +503,19 @@ impl ComponentPreview {
|
||||
}))
|
||||
.into_any_element()
|
||||
}
|
||||
PreviewEntry::ActiveThread => {
|
||||
let selected = self.active_page == PreviewPage::ActiveThread;
|
||||
|
||||
ListItem::new(ix)
|
||||
.child(Label::new("Active Thread").color(Color::Default))
|
||||
.selectable(true)
|
||||
.toggle_state(selected)
|
||||
.inset(true)
|
||||
.on_click(cx.listener(move |this, _, _, cx| {
|
||||
this.set_active_page(PreviewPage::ActiveThread, cx);
|
||||
}))
|
||||
.into_any_element()
|
||||
}
|
||||
PreviewEntry::Separator => ListItem::new(ix)
|
||||
.child(
|
||||
h_flex()
|
||||
@@ -471,6 +598,7 @@ impl ComponentPreview {
|
||||
.render_scope_header(ix, shared_string.clone(), window, cx)
|
||||
.into_any_element(),
|
||||
PreviewEntry::AllComponents => div().w_full().h_0().into_any_element(),
|
||||
PreviewEntry::ActiveThread => div().w_full().h_0().into_any_element(),
|
||||
PreviewEntry::Separator => div().w_full().h_0().into_any_element(),
|
||||
})
|
||||
.unwrap()
|
||||
@@ -595,6 +723,41 @@ impl ComponentPreview {
|
||||
}
|
||||
}
|
||||
|
||||
fn render_active_thread(
|
||||
&self,
|
||||
window: &mut Window,
|
||||
cx: &mut Context<Self>,
|
||||
) -> impl IntoElement {
|
||||
v_flex()
|
||||
.id("render-active-thread")
|
||||
.size_full()
|
||||
.child(
|
||||
v_flex().children(self.agent_previews.iter().filter_map(|preview_fn| {
|
||||
if let (Some(thread_store), Some(active_thread)) = (
|
||||
self.thread_store.as_ref().map(|ts| ts.downgrade()),
|
||||
self.active_thread.clone(),
|
||||
) {
|
||||
preview_fn(
|
||||
self,
|
||||
self.workspace.clone(),
|
||||
active_thread,
|
||||
thread_store,
|
||||
window,
|
||||
cx,
|
||||
)
|
||||
.map(|element| div().child(element))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})),
|
||||
)
|
||||
.children(self.active_thread.clone().map(|thread| thread.clone()))
|
||||
.when_none(&self.active_thread.clone(), |this| {
|
||||
this.child("No active thread")
|
||||
})
|
||||
.into_any_element()
|
||||
}
|
||||
|
||||
fn test_status_toast(&self, cx: &mut Context<Self>) {
|
||||
if let Some(workspace) = self.workspace.upgrade() {
|
||||
workspace.update(cx, |workspace, cx| {
|
||||
@@ -704,6 +867,9 @@ impl Render for ComponentPreview {
|
||||
PreviewPage::Component(id) => self
|
||||
.render_component_page(&id, window, cx)
|
||||
.into_any_element(),
|
||||
PreviewPage::ActiveThread => {
|
||||
self.render_active_thread(window, cx).into_any_element()
|
||||
}
|
||||
}),
|
||||
)
|
||||
}
|
||||
@@ -759,20 +925,28 @@ impl Item for ComponentPreview {
|
||||
let language_registry = self.language_registry.clone();
|
||||
let user_store = self.user_store.clone();
|
||||
let weak_workspace = self.workspace.clone();
|
||||
let project = self.project.clone();
|
||||
let selected_index = self.cursor_index;
|
||||
let active_page = self.active_page.clone();
|
||||
|
||||
Some(cx.new(|cx| {
|
||||
Self::new(
|
||||
weak_workspace,
|
||||
language_registry,
|
||||
user_store,
|
||||
selected_index,
|
||||
Some(active_page),
|
||||
window,
|
||||
cx,
|
||||
)
|
||||
}))
|
||||
let self_result = Self::new(
|
||||
weak_workspace,
|
||||
project,
|
||||
language_registry,
|
||||
user_store,
|
||||
selected_index,
|
||||
Some(active_page),
|
||||
window,
|
||||
cx,
|
||||
);
|
||||
|
||||
match self_result {
|
||||
Ok(preview) => Some(cx.new(|_cx| preview)),
|
||||
Err(e) => {
|
||||
log::error!("Failed to clone component preview: {}", e);
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn to_item_events(event: &Self::Event, mut f: impl FnMut(workspace::item::ItemEvent)) {
|
||||
@@ -838,10 +1012,12 @@ impl SerializableItem for ComponentPreview {
|
||||
let user_store = user_store.clone();
|
||||
let language_registry = language_registry.clone();
|
||||
let weak_workspace = workspace.clone();
|
||||
let project = project.clone();
|
||||
cx.update(move |window, cx| {
|
||||
Ok(cx.new(|cx| {
|
||||
ComponentPreview::new(
|
||||
weak_workspace,
|
||||
project,
|
||||
language_registry,
|
||||
user_store,
|
||||
None,
|
||||
@@ -849,6 +1025,7 @@ impl SerializableItem for ComponentPreview {
|
||||
window,
|
||||
cx,
|
||||
)
|
||||
.expect("Failed to create component preview")
|
||||
}))
|
||||
})?
|
||||
})
|
||||
|
||||
1
crates/component_preview/src/preview_support.rs
Normal file
1
crates/component_preview/src/preview_support.rs
Normal file
@@ -0,0 +1 @@
|
||||
pub mod active_thread;
|
||||
@@ -0,0 +1,69 @@
|
||||
use languages::LanguageRegistry;
|
||||
use project::Project;
|
||||
use std::sync::Arc;
|
||||
|
||||
use agent::{ActiveThread, ContextStore, MessageSegment, ThreadStore};
|
||||
use assistant_tool::ToolWorkingSet;
|
||||
use gpui::{AppContext, AsyncApp, Entity, Task, WeakEntity};
|
||||
use prompt_store::PromptBuilder;
|
||||
use ui::{App, Window};
|
||||
use workspace::Workspace;
|
||||
|
||||
pub async fn load_preview_thread_store(
|
||||
workspace: WeakEntity<Workspace>,
|
||||
project: Entity<Project>,
|
||||
cx: &mut AsyncApp,
|
||||
) -> Task<anyhow::Result<Entity<ThreadStore>>> {
|
||||
cx.spawn(async move |cx| {
|
||||
workspace
|
||||
.update(cx, |_, cx| {
|
||||
ThreadStore::load(
|
||||
project.clone(),
|
||||
cx.new(|_| ToolWorkingSet::default()),
|
||||
None,
|
||||
Arc::new(PromptBuilder::new(None).unwrap()),
|
||||
cx,
|
||||
)
|
||||
})?
|
||||
.await
|
||||
})
|
||||
}
|
||||
|
||||
pub fn static_active_thread(
|
||||
workspace: WeakEntity<Workspace>,
|
||||
project: Entity<Project>,
|
||||
language_registry: Arc<LanguageRegistry>,
|
||||
thread_store: Entity<ThreadStore>,
|
||||
window: &mut Window,
|
||||
cx: &mut App,
|
||||
) -> Entity<ActiveThread> {
|
||||
let context_store =
|
||||
cx.new(|_| ContextStore::new(project.downgrade(), Some(thread_store.downgrade())));
|
||||
|
||||
let thread = thread_store.update(cx, |thread_store, cx| thread_store.create_thread(cx));
|
||||
thread.update(cx, |thread, cx| {
|
||||
thread.insert_assistant_message(vec![
|
||||
MessageSegment::Text("I'll help you fix the lifetime error in your `cx.spawn` call. When working with async operations in GPUI, there are specific patterns to follow for proper lifetime management.".to_string()),
|
||||
MessageSegment::Text("\n\nLet's look at what's happening in your code:".to_string()),
|
||||
MessageSegment::Text("\n\n---\n\nLet's check the current state of the active_thread.rs file to understand what might have changed:".to_string()),
|
||||
MessageSegment::Text("\n\n---\n\nLooking at the implementation of `load_preview_thread_store` and understanding GPUI's async patterns, here's the issue:".to_string()),
|
||||
MessageSegment::Text("\n\n1. `load_preview_thread_store` returns a `Task<anyhow::Result<Entity<ThreadStore>>>`, which means it's already a task".to_string()),
|
||||
MessageSegment::Text("\n2. When you call this function inside another `spawn` call, you're nesting tasks incorrectly".to_string()),
|
||||
MessageSegment::Text("\n3. The `this` parameter you're trying to use in your closure has the wrong context".to_string()),
|
||||
MessageSegment::Text("\n\nHere's the correct way to implement this:".to_string()),
|
||||
MessageSegment::Text("\n\n---\n\nThe problem is in how you're setting up the async closure and trying to reference variables like `window` and `language_registry` that aren't accessible in that scope.".to_string()),
|
||||
MessageSegment::Text("\n\nHere's how to fix it:".to_string()),
|
||||
], cx);
|
||||
});
|
||||
cx.new(|cx| {
|
||||
ActiveThread::new(
|
||||
thread,
|
||||
thread_store,
|
||||
context_store,
|
||||
language_registry,
|
||||
workspace.clone(),
|
||||
window,
|
||||
cx,
|
||||
)
|
||||
})
|
||||
}
|
||||
@@ -34,3 +34,7 @@ smol.workspace = true
|
||||
url = { workspace = true, features = ["serde"] }
|
||||
util.workspace = true
|
||||
workspace-hack.workspace = true
|
||||
|
||||
[dev-dependencies]
|
||||
gpui = { workspace = true, features = ["test-support"] }
|
||||
project = { workspace = true, features = ["test-support"] }
|
||||
|
||||
@@ -140,7 +140,7 @@ impl Client {
|
||||
/// This function initializes a new Client by spawning a child process for the context server,
|
||||
/// setting up communication channels, and initializing handlers for input/output operations.
|
||||
/// It takes a server ID, binary information, and an async app context as input.
|
||||
pub fn new(
|
||||
pub fn stdio(
|
||||
server_id: ContextServerId,
|
||||
binary: ModelContextServerBinary,
|
||||
cx: AsyncApp,
|
||||
@@ -158,7 +158,16 @@ impl Client {
|
||||
.unwrap_or_else(String::new);
|
||||
|
||||
let transport = Arc::new(StdioTransport::new(binary, &cx)?);
|
||||
Self::new(server_id, server_name.into(), transport, cx)
|
||||
}
|
||||
|
||||
/// Creates a new Client instance for a context server.
|
||||
pub fn new(
|
||||
server_id: ContextServerId,
|
||||
server_name: Arc<str>,
|
||||
transport: Arc<dyn Transport>,
|
||||
cx: AsyncApp,
|
||||
) -> Result<Self> {
|
||||
let (outbound_tx, outbound_rx) = channel::unbounded::<String>();
|
||||
let (output_done_tx, output_done_rx) = barrier::channel();
|
||||
|
||||
@@ -167,7 +176,7 @@ impl Client {
|
||||
let response_handlers =
|
||||
Arc::new(Mutex::new(Some(HashMap::<_, ResponseHandler>::default())));
|
||||
|
||||
let stdout_input_task = cx.spawn({
|
||||
let receive_input_task = cx.spawn({
|
||||
let notification_handlers = notification_handlers.clone();
|
||||
let response_handlers = response_handlers.clone();
|
||||
let transport = transport.clone();
|
||||
@@ -177,13 +186,13 @@ impl Client {
|
||||
.await
|
||||
}
|
||||
});
|
||||
let stderr_input_task = cx.spawn({
|
||||
let receive_err_task = cx.spawn({
|
||||
let transport = transport.clone();
|
||||
async move |_| Self::handle_stderr(transport).log_err().await
|
||||
async move |_| Self::handle_err(transport).log_err().await
|
||||
});
|
||||
let input_task = cx.spawn(async move |_| {
|
||||
let (stdout, stderr) = futures::join!(stdout_input_task, stderr_input_task);
|
||||
stdout.or(stderr)
|
||||
let (input, err) = futures::join!(receive_input_task, receive_err_task);
|
||||
input.or(err)
|
||||
});
|
||||
|
||||
let output_task = cx.background_spawn({
|
||||
@@ -201,7 +210,7 @@ impl Client {
|
||||
server_id,
|
||||
notification_handlers,
|
||||
response_handlers,
|
||||
name: server_name.into(),
|
||||
name: server_name,
|
||||
next_id: Default::default(),
|
||||
outbound_tx,
|
||||
executor: cx.background_executor().clone(),
|
||||
@@ -247,7 +256,7 @@ impl Client {
|
||||
|
||||
/// Handles the stderr output from the context server.
|
||||
/// Continuously reads and logs any error messages from the server.
|
||||
async fn handle_stderr(transport: Arc<dyn Transport>) -> anyhow::Result<()> {
|
||||
async fn handle_err(transport: Arc<dyn Transport>) -> anyhow::Result<()> {
|
||||
while let Some(err) = transport.receive_err().next().await {
|
||||
log::warn!("context server stderr: {}", err.trim());
|
||||
}
|
||||
|
||||
@@ -12,7 +12,7 @@ pub use context_server_settings::{ContextServerSettings, ServerCommand, ServerCo
|
||||
use gpui::{App, actions};
|
||||
|
||||
pub use crate::context_server_tool::ContextServerTool;
|
||||
pub use crate::registry::ContextServerFactoryRegistry;
|
||||
pub use crate::registry::ContextServerDescriptorRegistry;
|
||||
|
||||
actions!(context_servers, [Restart]);
|
||||
|
||||
@@ -21,7 +21,7 @@ pub const CONTEXT_SERVERS_NAMESPACE: &'static str = "context_servers";
|
||||
|
||||
pub fn init(cx: &mut App) {
|
||||
context_server_settings::init(cx);
|
||||
ContextServerFactoryRegistry::default_global(cx);
|
||||
ContextServerDescriptorRegistry::default_global(cx);
|
||||
extension_context_server::init(cx);
|
||||
|
||||
CommandPaletteFilter::update_global(cx, |filter, _cx| {
|
||||
|
||||
@@ -1,9 +1,21 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use extension::{Extension, ExtensionContextServerProxy, ExtensionHostProxy, ProjectDelegate};
|
||||
use gpui::{App, Entity};
|
||||
use anyhow::Result;
|
||||
use extension::{
|
||||
ContextServerConfiguration, Extension, ExtensionContextServerProxy, ExtensionHostProxy,
|
||||
ProjectDelegate,
|
||||
};
|
||||
use gpui::{App, AsyncApp, Entity, Task};
|
||||
use project::Project;
|
||||
|
||||
use crate::{ContextServerFactoryRegistry, ServerCommand};
|
||||
use crate::{ContextServerDescriptorRegistry, ServerCommand, registry};
|
||||
|
||||
pub fn init(cx: &mut App) {
|
||||
let proxy = ExtensionHostProxy::default_global(cx);
|
||||
proxy.register_context_server_proxy(ContextServerDescriptorRegistryProxy {
|
||||
context_server_factory_registry: ContextServerDescriptorRegistry::global(cx),
|
||||
});
|
||||
}
|
||||
|
||||
struct ExtensionProject {
|
||||
worktree_ids: Vec<u64>,
|
||||
@@ -15,60 +27,78 @@ impl ProjectDelegate for ExtensionProject {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn init(cx: &mut App) {
|
||||
let proxy = ExtensionHostProxy::default_global(cx);
|
||||
proxy.register_context_server_proxy(ContextServerFactoryRegistryProxy {
|
||||
context_server_factory_registry: ContextServerFactoryRegistry::global(cx),
|
||||
});
|
||||
struct ContextServerDescriptor {
|
||||
id: Arc<str>,
|
||||
extension: Arc<dyn Extension>,
|
||||
}
|
||||
|
||||
struct ContextServerFactoryRegistryProxy {
|
||||
context_server_factory_registry: Entity<ContextServerFactoryRegistry>,
|
||||
fn extension_project(project: Entity<Project>, cx: &mut AsyncApp) -> Result<Arc<ExtensionProject>> {
|
||||
project.update(cx, |project, cx| {
|
||||
Arc::new(ExtensionProject {
|
||||
worktree_ids: project
|
||||
.visible_worktrees(cx)
|
||||
.map(|worktree| worktree.read(cx).id().to_proto())
|
||||
.collect(),
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
impl ExtensionContextServerProxy for ContextServerFactoryRegistryProxy {
|
||||
impl registry::ContextServerDescriptor for ContextServerDescriptor {
|
||||
fn command(&self, project: Entity<Project>, cx: &AsyncApp) -> Task<Result<ServerCommand>> {
|
||||
let id = self.id.clone();
|
||||
let extension = self.extension.clone();
|
||||
cx.spawn(async move |cx| {
|
||||
let extension_project = extension_project(project, cx)?;
|
||||
let mut command = extension
|
||||
.context_server_command(id.clone(), extension_project.clone())
|
||||
.await?;
|
||||
command.command = extension
|
||||
.path_from_extension(command.command.as_ref())
|
||||
.to_string_lossy()
|
||||
.to_string();
|
||||
|
||||
log::info!("loaded command for context server {id}: {command:?}");
|
||||
|
||||
Ok(ServerCommand {
|
||||
path: command.command,
|
||||
args: command.args,
|
||||
env: Some(command.env.into_iter().collect()),
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
fn configuration(
|
||||
&self,
|
||||
project: Entity<Project>,
|
||||
cx: &AsyncApp,
|
||||
) -> Task<Result<Option<ContextServerConfiguration>>> {
|
||||
let id = self.id.clone();
|
||||
let extension = self.extension.clone();
|
||||
cx.spawn(async move |cx| {
|
||||
let extension_project = extension_project(project, cx)?;
|
||||
let configuration = extension
|
||||
.context_server_configuration(id.clone(), extension_project)
|
||||
.await?;
|
||||
|
||||
log::debug!("loaded configuration for context server {id}: {configuration:?}");
|
||||
|
||||
Ok(configuration)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
struct ContextServerDescriptorRegistryProxy {
|
||||
context_server_factory_registry: Entity<ContextServerDescriptorRegistry>,
|
||||
}
|
||||
|
||||
impl ExtensionContextServerProxy for ContextServerDescriptorRegistryProxy {
|
||||
fn register_context_server(&self, extension: Arc<dyn Extension>, id: Arc<str>, cx: &mut App) {
|
||||
self.context_server_factory_registry
|
||||
.update(cx, |registry, _| {
|
||||
registry.register_server_factory(
|
||||
registry.register_context_server_descriptor(
|
||||
id.clone(),
|
||||
Arc::new({
|
||||
move |project, cx| {
|
||||
log::info!(
|
||||
"loading command for context server {id} from extension {}",
|
||||
extension.manifest().id
|
||||
);
|
||||
|
||||
let id = id.clone();
|
||||
let extension = extension.clone();
|
||||
cx.spawn(async move |cx| {
|
||||
let extension_project = project.update(cx, |project, cx| {
|
||||
Arc::new(ExtensionProject {
|
||||
worktree_ids: project
|
||||
.visible_worktrees(cx)
|
||||
.map(|worktree| worktree.read(cx).id().to_proto())
|
||||
.collect(),
|
||||
})
|
||||
})?;
|
||||
|
||||
let mut command = extension
|
||||
.context_server_command(id.clone(), extension_project)
|
||||
.await?;
|
||||
command.command = extension
|
||||
.path_from_extension(command.command.as_ref())
|
||||
.to_string_lossy()
|
||||
.to_string();
|
||||
|
||||
log::info!("loaded command for context server {id}: {command:?}");
|
||||
|
||||
Ok(ServerCommand {
|
||||
path: command.command,
|
||||
args: command.args,
|
||||
env: Some(command.env.into_iter().collect()),
|
||||
})
|
||||
})
|
||||
}
|
||||
}),
|
||||
Arc::new(ContextServerDescriptor { id, extension })
|
||||
as Arc<dyn registry::ContextServerDescriptor>,
|
||||
)
|
||||
});
|
||||
}
|
||||
|
||||
@@ -27,18 +27,27 @@ use project::Project;
|
||||
use settings::{Settings, SettingsStore};
|
||||
use util::ResultExt as _;
|
||||
|
||||
use crate::transport::Transport;
|
||||
use crate::{ContextServerSettings, ServerConfig};
|
||||
|
||||
use crate::{
|
||||
CONTEXT_SERVERS_NAMESPACE, ContextServerFactoryRegistry,
|
||||
CONTEXT_SERVERS_NAMESPACE, ContextServerDescriptorRegistry,
|
||||
client::{self, Client},
|
||||
types,
|
||||
};
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
|
||||
pub enum ContextServerStatus {
|
||||
Starting,
|
||||
Running,
|
||||
Error(Arc<str>),
|
||||
}
|
||||
|
||||
pub struct ContextServer {
|
||||
pub id: Arc<str>,
|
||||
pub config: Arc<ServerConfig>,
|
||||
pub client: RwLock<Option<Arc<crate::protocol::InitializedContextServerProtocol>>>,
|
||||
transport: Option<Arc<dyn Transport>>,
|
||||
}
|
||||
|
||||
impl ContextServer {
|
||||
@@ -47,9 +56,20 @@ impl ContextServer {
|
||||
id,
|
||||
config,
|
||||
client: RwLock::new(None),
|
||||
transport: None,
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(any(test, feature = "test-support"))]
|
||||
pub fn test(id: Arc<str>, transport: Arc<dyn crate::transport::Transport>) -> Arc<Self> {
|
||||
Arc::new(Self {
|
||||
id,
|
||||
client: RwLock::new(None),
|
||||
config: Arc::new(ServerConfig::default()),
|
||||
transport: Some(transport),
|
||||
})
|
||||
}
|
||||
|
||||
pub fn id(&self) -> Arc<str> {
|
||||
self.id.clone()
|
||||
}
|
||||
@@ -63,20 +83,32 @@ impl ContextServer {
|
||||
}
|
||||
|
||||
pub async fn start(self: Arc<Self>, cx: &AsyncApp) -> Result<()> {
|
||||
log::info!("starting context server {}", self.id);
|
||||
let Some(command) = &self.config.command else {
|
||||
bail!("no command specified for server {}", self.id);
|
||||
let client = if let Some(transport) = self.transport.clone() {
|
||||
Client::new(
|
||||
client::ContextServerId(self.id.clone()),
|
||||
self.id(),
|
||||
transport,
|
||||
cx.clone(),
|
||||
)?
|
||||
} else {
|
||||
let Some(command) = &self.config.command else {
|
||||
bail!("no command specified for server {}", self.id);
|
||||
};
|
||||
Client::stdio(
|
||||
client::ContextServerId(self.id.clone()),
|
||||
client::ModelContextServerBinary {
|
||||
executable: Path::new(&command.path).to_path_buf(),
|
||||
args: command.args.clone(),
|
||||
env: command.env.clone(),
|
||||
},
|
||||
cx.clone(),
|
||||
)?
|
||||
};
|
||||
let client = Client::new(
|
||||
client::ContextServerId(self.id.clone()),
|
||||
client::ModelContextServerBinary {
|
||||
executable: Path::new(&command.path).to_path_buf(),
|
||||
args: command.args.clone(),
|
||||
env: command.env.clone(),
|
||||
},
|
||||
cx.clone(),
|
||||
)?;
|
||||
self.initialize(client).await
|
||||
}
|
||||
|
||||
async fn initialize(&self, client: Client) -> Result<()> {
|
||||
log::info!("starting context server {}", self.id);
|
||||
let protocol = crate::protocol::ModelContextProtocol::new(client);
|
||||
let client_info = types::Implementation {
|
||||
name: "Zed".to_string(),
|
||||
@@ -105,23 +137,26 @@ impl ContextServer {
|
||||
|
||||
pub struct ContextServerManager {
|
||||
servers: HashMap<Arc<str>, Arc<ContextServer>>,
|
||||
server_status: HashMap<Arc<str>, ContextServerStatus>,
|
||||
project: Entity<Project>,
|
||||
registry: Entity<ContextServerFactoryRegistry>,
|
||||
registry: Entity<ContextServerDescriptorRegistry>,
|
||||
update_servers_task: Option<Task<Result<()>>>,
|
||||
needs_server_update: bool,
|
||||
_subscriptions: Vec<Subscription>,
|
||||
}
|
||||
|
||||
pub enum Event {
|
||||
ServerStarted { server_id: Arc<str> },
|
||||
ServerStopped { server_id: Arc<str> },
|
||||
ServerStatusChanged {
|
||||
server_id: Arc<str>,
|
||||
status: Option<ContextServerStatus>,
|
||||
},
|
||||
}
|
||||
|
||||
impl EventEmitter<Event> for ContextServerManager {}
|
||||
|
||||
impl ContextServerManager {
|
||||
pub fn new(
|
||||
registry: Entity<ContextServerFactoryRegistry>,
|
||||
registry: Entity<ContextServerDescriptorRegistry>,
|
||||
project: Entity<Project>,
|
||||
cx: &mut Context<Self>,
|
||||
) -> Self {
|
||||
@@ -138,6 +173,7 @@ impl ContextServerManager {
|
||||
registry,
|
||||
needs_server_update: false,
|
||||
servers: HashMap::default(),
|
||||
server_status: HashMap::default(),
|
||||
update_servers_task: None,
|
||||
};
|
||||
this.available_context_servers_changed(cx);
|
||||
@@ -153,7 +189,9 @@ impl ContextServerManager {
|
||||
this.needs_server_update = false;
|
||||
})?;
|
||||
|
||||
Self::maintain_servers(this.clone(), cx).await?;
|
||||
if let Err(err) = Self::maintain_servers(this.clone(), cx).await {
|
||||
log::error!("Error maintaining context servers: {}", err);
|
||||
}
|
||||
|
||||
this.update(cx, |this, cx| {
|
||||
let has_any_context_servers = !this.running_servers().is_empty();
|
||||
@@ -181,52 +219,37 @@ impl ContextServerManager {
|
||||
.cloned()
|
||||
}
|
||||
|
||||
pub fn status_for_server(&self, id: &str) -> Option<ContextServerStatus> {
|
||||
self.server_status.get(id).cloned()
|
||||
}
|
||||
|
||||
pub fn start_server(
|
||||
&self,
|
||||
server: Arc<ContextServer>,
|
||||
cx: &mut Context<Self>,
|
||||
) -> Task<anyhow::Result<()>> {
|
||||
cx.spawn(async move |this, cx| {
|
||||
let id = server.id.clone();
|
||||
server.start(&cx).await?;
|
||||
this.update(cx, |_, cx| cx.emit(Event::ServerStarted { server_id: id }))?;
|
||||
Ok(())
|
||||
})
|
||||
) -> Task<Result<()>> {
|
||||
cx.spawn(async move |this, cx| Self::run_server(this, server, cx).await)
|
||||
}
|
||||
|
||||
pub fn stop_server(
|
||||
&self,
|
||||
&mut self,
|
||||
server: Arc<ContextServer>,
|
||||
cx: &mut Context<Self>,
|
||||
) -> anyhow::Result<()> {
|
||||
server.stop()?;
|
||||
cx.emit(Event::ServerStopped {
|
||||
server_id: server.id(),
|
||||
});
|
||||
) -> Result<()> {
|
||||
server.stop().log_err();
|
||||
self.update_server_status(server.id().clone(), None, cx);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn restart_server(
|
||||
&mut self,
|
||||
id: &Arc<str>,
|
||||
cx: &mut Context<Self>,
|
||||
) -> Task<anyhow::Result<()>> {
|
||||
pub fn restart_server(&mut self, id: &Arc<str>, cx: &mut Context<Self>) -> Task<Result<()>> {
|
||||
let id = id.clone();
|
||||
cx.spawn(async move |this, cx| {
|
||||
if let Some(server) = this.update(cx, |this, _cx| this.servers.remove(&id))? {
|
||||
server.stop()?;
|
||||
let config = server.config();
|
||||
|
||||
this.update(cx, |this, cx| this.stop_server(server, cx))??;
|
||||
let new_server = Arc::new(ContextServer::new(id.clone(), config));
|
||||
new_server.clone().start(&cx).await?;
|
||||
this.update(cx, |this, cx| {
|
||||
this.servers.insert(id.clone(), new_server);
|
||||
cx.emit(Event::ServerStopped {
|
||||
server_id: id.clone(),
|
||||
});
|
||||
cx.emit(Event::ServerStarted {
|
||||
server_id: id.clone(),
|
||||
});
|
||||
})?;
|
||||
Self::run_server(this, new_server, cx).await?;
|
||||
}
|
||||
Ok(())
|
||||
})
|
||||
@@ -263,12 +286,14 @@ impl ContextServerManager {
|
||||
(this.registry.clone(), this.project.clone())
|
||||
})?;
|
||||
|
||||
for (id, factory) in
|
||||
registry.read_with(cx, |registry, _| registry.context_server_factories())?
|
||||
for (id, descriptor) in
|
||||
registry.read_with(cx, |registry, _| registry.context_server_descriptors())?
|
||||
{
|
||||
let config = desired_servers.entry(id).or_default();
|
||||
if config.command.is_none() {
|
||||
if let Some(extension_command) = factory(project.clone(), &cx).await.log_err() {
|
||||
if let Some(extension_command) =
|
||||
descriptor.command(project.clone(), &cx).await.log_err()
|
||||
{
|
||||
config.command = Some(extension_command);
|
||||
}
|
||||
}
|
||||
@@ -290,28 +315,270 @@ impl ContextServerManager {
|
||||
for (id, config) in desired_servers {
|
||||
let existing_config = this.servers.get(&id).map(|server| server.config());
|
||||
if existing_config.as_deref() != Some(&config) {
|
||||
let config = Arc::new(config);
|
||||
let server = Arc::new(ContextServer::new(id.clone(), config));
|
||||
let server = Arc::new(ContextServer::new(id.clone(), Arc::new(config)));
|
||||
servers_to_start.insert(id.clone(), server.clone());
|
||||
let old_server = this.servers.insert(id.clone(), server);
|
||||
if let Some(old_server) = old_server {
|
||||
if let Some(old_server) = this.servers.remove(&id) {
|
||||
servers_to_stop.insert(id, old_server);
|
||||
}
|
||||
}
|
||||
}
|
||||
})?;
|
||||
|
||||
for (id, server) in servers_to_stop {
|
||||
server.stop().log_err();
|
||||
this.update(cx, |_, cx| cx.emit(Event::ServerStopped { server_id: id }))?;
|
||||
for (_, server) in servers_to_stop {
|
||||
this.update(cx, |this, cx| this.stop_server(server, cx).ok())?;
|
||||
}
|
||||
|
||||
for (id, server) in servers_to_start {
|
||||
if server.start(&cx).await.log_err().is_some() {
|
||||
this.update(cx, |_, cx| cx.emit(Event::ServerStarted { server_id: id }))?;
|
||||
}
|
||||
for (_, server) in servers_to_start {
|
||||
Self::run_server(this.clone(), server, cx).await.ok();
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn run_server(
|
||||
this: WeakEntity<Self>,
|
||||
server: Arc<ContextServer>,
|
||||
cx: &mut AsyncApp,
|
||||
) -> Result<()> {
|
||||
let id = server.id();
|
||||
|
||||
this.update(cx, |this, cx| {
|
||||
this.update_server_status(id.clone(), Some(ContextServerStatus::Starting), cx);
|
||||
this.servers.insert(id.clone(), server.clone());
|
||||
})?;
|
||||
|
||||
match server.start(&cx).await {
|
||||
Ok(_) => {
|
||||
log::debug!("`{}` context server started", id);
|
||||
this.update(cx, |this, cx| {
|
||||
this.update_server_status(id.clone(), Some(ContextServerStatus::Running), cx)
|
||||
})?;
|
||||
Ok(())
|
||||
}
|
||||
Err(err) => {
|
||||
log::error!("`{}` context server failed to start\n{}", id, err);
|
||||
this.update(cx, |this, cx| {
|
||||
this.update_server_status(
|
||||
id.clone(),
|
||||
Some(ContextServerStatus::Error(err.to_string().into())),
|
||||
cx,
|
||||
)
|
||||
})?;
|
||||
Err(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn update_server_status(
|
||||
&mut self,
|
||||
id: Arc<str>,
|
||||
status: Option<ContextServerStatus>,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
if let Some(status) = status.clone() {
|
||||
self.server_status.insert(id.clone(), status);
|
||||
} else {
|
||||
self.server_status.remove(&id);
|
||||
}
|
||||
|
||||
cx.emit(Event::ServerStatusChanged {
|
||||
server_id: id,
|
||||
status,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::pin::Pin;
|
||||
|
||||
use crate::types::{
|
||||
Implementation, InitializeResponse, ProtocolVersion, RequestType, ServerCapabilities,
|
||||
};
|
||||
|
||||
use super::*;
|
||||
use futures::{Stream, StreamExt as _, lock::Mutex};
|
||||
use gpui::{AppContext as _, TestAppContext};
|
||||
use project::FakeFs;
|
||||
use serde_json::json;
|
||||
use util::path;
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_context_server_status(cx: &mut TestAppContext) {
|
||||
init_test_settings(cx);
|
||||
let project = create_test_project(cx, json!({"code.rs": ""})).await;
|
||||
|
||||
let registry = cx.new(|_| ContextServerDescriptorRegistry::new());
|
||||
let manager = cx.new(|cx| ContextServerManager::new(registry.clone(), project, cx));
|
||||
|
||||
let server_1_id: Arc<str> = "mcp-1".into();
|
||||
let server_2_id: Arc<str> = "mcp-2".into();
|
||||
|
||||
let transport_1 = Arc::new(FakeTransport::new(
|
||||
|_, request_type, _| match request_type {
|
||||
Some(RequestType::Initialize) => {
|
||||
Some(create_initialize_response("mcp-1".to_string()))
|
||||
}
|
||||
_ => None,
|
||||
},
|
||||
));
|
||||
|
||||
let transport_2 = Arc::new(FakeTransport::new(
|
||||
|_, request_type, _| match request_type {
|
||||
Some(RequestType::Initialize) => {
|
||||
Some(create_initialize_response("mcp-2".to_string()))
|
||||
}
|
||||
_ => None,
|
||||
},
|
||||
));
|
||||
|
||||
let server_1 = ContextServer::test(server_1_id.clone(), transport_1.clone());
|
||||
let server_2 = ContextServer::test(server_2_id.clone(), transport_2.clone());
|
||||
|
||||
manager
|
||||
.update(cx, |manager, cx| manager.start_server(server_1, cx))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
cx.update(|cx| {
|
||||
assert_eq!(
|
||||
manager.read(cx).status_for_server(&server_1_id),
|
||||
Some(ContextServerStatus::Running)
|
||||
);
|
||||
assert_eq!(manager.read(cx).status_for_server(&server_2_id), None);
|
||||
});
|
||||
|
||||
manager
|
||||
.update(cx, |manager, cx| manager.start_server(server_2.clone(), cx))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
cx.update(|cx| {
|
||||
assert_eq!(
|
||||
manager.read(cx).status_for_server(&server_1_id),
|
||||
Some(ContextServerStatus::Running)
|
||||
);
|
||||
assert_eq!(
|
||||
manager.read(cx).status_for_server(&server_2_id),
|
||||
Some(ContextServerStatus::Running)
|
||||
);
|
||||
});
|
||||
|
||||
manager
|
||||
.update(cx, |manager, cx| manager.stop_server(server_2, cx))
|
||||
.unwrap();
|
||||
|
||||
cx.update(|cx| {
|
||||
assert_eq!(
|
||||
manager.read(cx).status_for_server(&server_1_id),
|
||||
Some(ContextServerStatus::Running)
|
||||
);
|
||||
assert_eq!(manager.read(cx).status_for_server(&server_2_id), None);
|
||||
});
|
||||
}
|
||||
|
||||
async fn create_test_project(
|
||||
cx: &mut TestAppContext,
|
||||
files: serde_json::Value,
|
||||
) -> Entity<Project> {
|
||||
let fs = FakeFs::new(cx.executor());
|
||||
fs.insert_tree(path!("/test"), files).await;
|
||||
Project::test(fs, [path!("/test").as_ref()], cx).await
|
||||
}
|
||||
|
||||
fn init_test_settings(cx: &mut TestAppContext) {
|
||||
cx.update(|cx| {
|
||||
let settings_store = SettingsStore::test(cx);
|
||||
cx.set_global(settings_store);
|
||||
Project::init_settings(cx);
|
||||
ContextServerSettings::register(cx);
|
||||
});
|
||||
}
|
||||
|
||||
fn create_initialize_response(server_name: String) -> serde_json::Value {
|
||||
serde_json::to_value(&InitializeResponse {
|
||||
protocol_version: ProtocolVersion(types::LATEST_PROTOCOL_VERSION.to_string()),
|
||||
server_info: Implementation {
|
||||
name: server_name,
|
||||
version: "1.0.0".to_string(),
|
||||
},
|
||||
capabilities: ServerCapabilities::default(),
|
||||
meta: None,
|
||||
})
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
struct FakeTransport {
|
||||
on_request: Arc<
|
||||
dyn Fn(u64, Option<RequestType>, serde_json::Value) -> Option<serde_json::Value>
|
||||
+ Send
|
||||
+ Sync,
|
||||
>,
|
||||
tx: futures::channel::mpsc::UnboundedSender<String>,
|
||||
rx: Arc<Mutex<futures::channel::mpsc::UnboundedReceiver<String>>>,
|
||||
}
|
||||
|
||||
impl FakeTransport {
|
||||
fn new(
|
||||
on_request: impl Fn(
|
||||
u64,
|
||||
Option<RequestType>,
|
||||
serde_json::Value,
|
||||
) -> Option<serde_json::Value>
|
||||
+ 'static
|
||||
+ Send
|
||||
+ Sync,
|
||||
) -> Self {
|
||||
let (tx, rx) = futures::channel::mpsc::unbounded();
|
||||
Self {
|
||||
on_request: Arc::new(on_request),
|
||||
tx,
|
||||
rx: Arc::new(Mutex::new(rx)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl Transport for FakeTransport {
|
||||
async fn send(&self, message: String) -> Result<()> {
|
||||
if let Ok(msg) = serde_json::from_str::<serde_json::Value>(&message) {
|
||||
let id = msg.get("id").and_then(|id| id.as_u64()).unwrap_or(0);
|
||||
|
||||
if let Some(method) = msg.get("method") {
|
||||
let request_type = method
|
||||
.as_str()
|
||||
.and_then(|method| types::RequestType::try_from(method).ok());
|
||||
if let Some(payload) = (self.on_request.as_ref())(id, request_type, msg) {
|
||||
let response = serde_json::json!({
|
||||
"jsonrpc": "2.0",
|
||||
"id": id,
|
||||
"result": payload
|
||||
});
|
||||
|
||||
self.tx
|
||||
.unbounded_send(response.to_string())
|
||||
.map_err(|e| anyhow::anyhow!("Failed to send message: {}", e))?;
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn receive(&self) -> Pin<Box<dyn Stream<Item = String> + Send>> {
|
||||
let rx = self.rx.clone();
|
||||
Box::pin(futures::stream::unfold(rx, |rx| async move {
|
||||
let mut rx_guard = rx.lock().await;
|
||||
if let Some(message) = rx_guard.next().await {
|
||||
drop(rx_guard);
|
||||
Some((message, rx))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}))
|
||||
}
|
||||
|
||||
fn receive_err(&self) -> Pin<Box<dyn Stream<Item = String> + Send>> {
|
||||
Box::pin(futures::stream::empty())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,38 +2,47 @@ use std::sync::Arc;
|
||||
|
||||
use anyhow::Result;
|
||||
use collections::HashMap;
|
||||
use extension::ContextServerConfiguration;
|
||||
use gpui::{App, AppContext as _, AsyncApp, Entity, Global, ReadGlobal, Task};
|
||||
use project::Project;
|
||||
|
||||
use crate::ServerCommand;
|
||||
|
||||
pub type ContextServerFactory =
|
||||
Arc<dyn Fn(Entity<Project>, &AsyncApp) -> Task<Result<ServerCommand>> + Send + Sync + 'static>;
|
||||
|
||||
struct GlobalContextServerFactoryRegistry(Entity<ContextServerFactoryRegistry>);
|
||||
|
||||
impl Global for GlobalContextServerFactoryRegistry {}
|
||||
|
||||
#[derive(Default)]
|
||||
pub struct ContextServerFactoryRegistry {
|
||||
context_servers: HashMap<Arc<str>, ContextServerFactory>,
|
||||
pub trait ContextServerDescriptor {
|
||||
fn command(&self, project: Entity<Project>, cx: &AsyncApp) -> Task<Result<ServerCommand>>;
|
||||
fn configuration(
|
||||
&self,
|
||||
project: Entity<Project>,
|
||||
cx: &AsyncApp,
|
||||
) -> Task<Result<Option<ContextServerConfiguration>>>;
|
||||
}
|
||||
|
||||
impl ContextServerFactoryRegistry {
|
||||
/// Returns the global [`ContextServerFactoryRegistry`].
|
||||
struct GlobalContextServerDescriptorRegistry(Entity<ContextServerDescriptorRegistry>);
|
||||
|
||||
impl Global for GlobalContextServerDescriptorRegistry {}
|
||||
|
||||
#[derive(Default)]
|
||||
pub struct ContextServerDescriptorRegistry {
|
||||
context_servers: HashMap<Arc<str>, Arc<dyn ContextServerDescriptor>>,
|
||||
}
|
||||
|
||||
impl ContextServerDescriptorRegistry {
|
||||
/// Returns the global [`ContextServerDescriptorRegistry`].
|
||||
pub fn global(cx: &App) -> Entity<Self> {
|
||||
GlobalContextServerFactoryRegistry::global(cx).0.clone()
|
||||
GlobalContextServerDescriptorRegistry::global(cx).0.clone()
|
||||
}
|
||||
|
||||
/// Returns the global [`ContextServerFactoryRegistry`].
|
||||
/// Returns the global [`ContextServerDescriptorRegistry`].
|
||||
///
|
||||
/// Inserts a default [`ContextServerFactoryRegistry`] if one does not yet exist.
|
||||
/// Inserts a default [`ContextServerDescriptorRegistry`] if one does not yet exist.
|
||||
pub fn default_global(cx: &mut App) -> Entity<Self> {
|
||||
if !cx.has_global::<GlobalContextServerFactoryRegistry>() {
|
||||
if !cx.has_global::<GlobalContextServerDescriptorRegistry>() {
|
||||
let registry = cx.new(|_| Self::new());
|
||||
cx.set_global(GlobalContextServerFactoryRegistry(registry));
|
||||
cx.set_global(GlobalContextServerDescriptorRegistry(registry));
|
||||
}
|
||||
cx.global::<GlobalContextServerFactoryRegistry>().0.clone()
|
||||
cx.global::<GlobalContextServerDescriptorRegistry>()
|
||||
.0
|
||||
.clone()
|
||||
}
|
||||
|
||||
pub fn new() -> Self {
|
||||
@@ -42,20 +51,28 @@ impl ContextServerFactoryRegistry {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn context_server_factories(&self) -> Vec<(Arc<str>, ContextServerFactory)> {
|
||||
pub fn context_server_descriptors(&self) -> Vec<(Arc<str>, Arc<dyn ContextServerDescriptor>)> {
|
||||
self.context_servers
|
||||
.iter()
|
||||
.map(|(id, factory)| (id.clone(), factory.clone()))
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Registers the provided [`ContextServerFactory`].
|
||||
pub fn register_server_factory(&mut self, id: Arc<str>, factory: ContextServerFactory) {
|
||||
self.context_servers.insert(id, factory);
|
||||
pub fn context_server_descriptor(&self, id: &str) -> Option<Arc<dyn ContextServerDescriptor>> {
|
||||
self.context_servers.get(id).cloned()
|
||||
}
|
||||
|
||||
/// Unregisters the [`ContextServerFactory`] for the server with the given ID.
|
||||
pub fn unregister_server_factory_by_id(&mut self, server_id: &str) {
|
||||
/// Registers the provided [`ContextServerDescriptor`].
|
||||
pub fn register_context_server_descriptor(
|
||||
&mut self,
|
||||
id: Arc<str>,
|
||||
descriptor: Arc<dyn ContextServerDescriptor>,
|
||||
) {
|
||||
self.context_servers.insert(id, descriptor);
|
||||
}
|
||||
|
||||
/// Unregisters the [`ContextServerDescriptor`] for the server with the given ID.
|
||||
pub fn unregister_context_server_descriptor_by_id(&mut self, server_id: &str) {
|
||||
self.context_servers.remove(server_id);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -42,6 +42,30 @@ impl RequestType {
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFrom<&str> for RequestType {
|
||||
type Error = ();
|
||||
|
||||
fn try_from(s: &str) -> Result<Self, Self::Error> {
|
||||
match s {
|
||||
"initialize" => Ok(RequestType::Initialize),
|
||||
"tools/call" => Ok(RequestType::CallTool),
|
||||
"resources/unsubscribe" => Ok(RequestType::ResourcesUnsubscribe),
|
||||
"resources/subscribe" => Ok(RequestType::ResourcesSubscribe),
|
||||
"resources/read" => Ok(RequestType::ResourcesRead),
|
||||
"resources/list" => Ok(RequestType::ResourcesList),
|
||||
"logging/setLevel" => Ok(RequestType::LoggingSetLevel),
|
||||
"prompts/get" => Ok(RequestType::PromptsGet),
|
||||
"prompts/list" => Ok(RequestType::PromptsList),
|
||||
"completion/complete" => Ok(RequestType::CompletionComplete),
|
||||
"ping" => Ok(RequestType::Ping),
|
||||
"tools/list" => Ok(RequestType::ListTools),
|
||||
"resources/templates/list" => Ok(RequestType::ListResourceTemplates),
|
||||
"roots/list" => Ok(RequestType::ListRoots),
|
||||
_ => Err(()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[serde(transparent)]
|
||||
pub struct ProtocolVersion(pub String);
|
||||
@@ -154,7 +178,7 @@ pub struct CompletionArgument {
|
||||
pub value: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct InitializeResponse {
|
||||
pub protocol_version: ProtocolVersion,
|
||||
@@ -343,7 +367,7 @@ pub struct ClientCapabilities {
|
||||
pub roots: Option<RootsCapabilities>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[derive(Default, Debug, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct ServerCapabilities {
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
|
||||
@@ -444,10 +444,7 @@ pub trait DebugAdapter: 'static + Send + Sync {
|
||||
log::info!("Getting latest version of debug adapter {}", self.name());
|
||||
delegate.update_status(self.name(), DapStatus::CheckingForUpdate);
|
||||
if let Some(version) = self.fetch_latest_adapter_version(delegate).await.log_err() {
|
||||
log::info!(
|
||||
"Installiing latest version of debug adapter {}",
|
||||
self.name()
|
||||
);
|
||||
log::info!("Installing latest version of debug adapter {}", self.name());
|
||||
delegate.update_status(self.name(), DapStatus::Downloading);
|
||||
match self.install_binary(version, delegate).await {
|
||||
Ok(_) => {
|
||||
|
||||
@@ -7,6 +7,7 @@ use crate::{
|
||||
};
|
||||
use crate::{new_session_modal::NewSessionModal, session::DebugSession};
|
||||
use anyhow::{Result, anyhow};
|
||||
use collections::{HashMap, HashSet};
|
||||
use command_palette_hooks::CommandPaletteFilter;
|
||||
use dap::DebugRequest;
|
||||
use dap::{
|
||||
@@ -26,6 +27,7 @@ use project::{Project, debugger::session::ThreadStatus};
|
||||
use rpc::proto::{self};
|
||||
use settings::Settings;
|
||||
use std::any::TypeId;
|
||||
use std::path::PathBuf;
|
||||
use task::{DebugScenario, TaskContext};
|
||||
use ui::{ContextMenu, Divider, DropdownMenu, Tooltip, prelude::*};
|
||||
use workspace::SplitDirection;
|
||||
@@ -403,7 +405,6 @@ impl DebugPanel {
|
||||
pub fn resolve_scenario(
|
||||
&self,
|
||||
scenario: DebugScenario,
|
||||
|
||||
task_context: TaskContext,
|
||||
buffer: Option<Entity<Buffer>>,
|
||||
window: &Window,
|
||||
@@ -424,8 +425,60 @@ impl DebugPanel {
|
||||
stop_on_entry,
|
||||
} = scenario;
|
||||
let request = if let Some(mut request) = request {
|
||||
// Resolve task variables within the request.
|
||||
if let DebugRequest::Launch(_) = &mut request {}
|
||||
if let DebugRequest::Launch(launch_config) = &mut request {
|
||||
let mut variable_names = HashMap::default();
|
||||
let mut substituted_variables = HashSet::default();
|
||||
let task_variables = task_context
|
||||
.task_variables
|
||||
.iter()
|
||||
.map(|(key, value)| {
|
||||
let key_string = key.to_string();
|
||||
if !variable_names.contains_key(&key_string) {
|
||||
variable_names.insert(key_string.clone(), key.clone());
|
||||
}
|
||||
(key_string, value.as_str())
|
||||
})
|
||||
.collect::<HashMap<_, _>>();
|
||||
|
||||
let cwd = launch_config
|
||||
.cwd
|
||||
.as_ref()
|
||||
.and_then(|cwd| cwd.to_str())
|
||||
.and_then(|cwd| {
|
||||
task::substitute_all_template_variables_in_str(
|
||||
cwd,
|
||||
&task_variables,
|
||||
&variable_names,
|
||||
&mut substituted_variables,
|
||||
)
|
||||
});
|
||||
|
||||
if let Some(cwd) = cwd {
|
||||
launch_config.cwd = Some(PathBuf::from(cwd))
|
||||
}
|
||||
|
||||
if let Some(program) = task::substitute_all_template_variables_in_str(
|
||||
&launch_config.program,
|
||||
&task_variables,
|
||||
&variable_names,
|
||||
&mut substituted_variables,
|
||||
) {
|
||||
launch_config.program = program;
|
||||
}
|
||||
|
||||
for arg in launch_config.args.iter_mut() {
|
||||
if let Some(substituted_arg) =
|
||||
task::substitute_all_template_variables_in_str(
|
||||
&arg,
|
||||
&task_variables,
|
||||
&variable_names,
|
||||
&mut substituted_variables,
|
||||
)
|
||||
{
|
||||
*arg = substituted_arg;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
request
|
||||
} else if let Some(build) = build {
|
||||
@@ -944,6 +997,7 @@ impl DebugPanel {
|
||||
past_debug_definition,
|
||||
weak_panel,
|
||||
workspace,
|
||||
None,
|
||||
window,
|
||||
cx,
|
||||
)
|
||||
|
||||
@@ -158,6 +158,7 @@ pub fn init(cx: &mut App) {
|
||||
debug_panel.read(cx).past_debug_definition.clone(),
|
||||
weak_panel,
|
||||
weak_workspace,
|
||||
None,
|
||||
window,
|
||||
cx,
|
||||
)
|
||||
@@ -166,14 +167,22 @@ pub fn init(cx: &mut App) {
|
||||
},
|
||||
)
|
||||
.register_action(|workspace: &mut Workspace, _: &Start, window, cx| {
|
||||
tasks_ui::toggle_modal(
|
||||
workspace,
|
||||
None,
|
||||
task::TaskModal::DebugModal,
|
||||
window,
|
||||
cx,
|
||||
)
|
||||
.detach();
|
||||
if let Some(debug_panel) = workspace.panel::<DebugPanel>(cx) {
|
||||
let weak_panel = debug_panel.downgrade();
|
||||
let weak_workspace = cx.weak_entity();
|
||||
let task_store = workspace.project().read(cx).task_store().clone();
|
||||
|
||||
workspace.toggle_modal(window, cx, |window, cx| {
|
||||
NewSessionModal::new(
|
||||
debug_panel.read(cx).past_debug_definition.clone(),
|
||||
weak_panel,
|
||||
weak_workspace,
|
||||
Some(task_store),
|
||||
window,
|
||||
cx,
|
||||
)
|
||||
});
|
||||
}
|
||||
});
|
||||
})
|
||||
})
|
||||
|
||||
@@ -6,19 +6,25 @@ use std::{
|
||||
|
||||
use dap::{DapRegistry, DebugRequest, adapters::DebugTaskDefinition};
|
||||
use editor::{Editor, EditorElement, EditorStyle};
|
||||
use fuzzy::{StringMatch, StringMatchCandidate};
|
||||
use gpui::{
|
||||
App, AppContext, DismissEvent, Entity, EventEmitter, FocusHandle, Focusable, Render, TextStyle,
|
||||
WeakEntity,
|
||||
App, AppContext, DismissEvent, Entity, EventEmitter, FocusHandle, Focusable, Render,
|
||||
Subscription, TextStyle, WeakEntity,
|
||||
};
|
||||
use picker::{Picker, PickerDelegate, highlighted_match_with_paths::HighlightedMatch};
|
||||
use project::{TaskSourceKind, task_store::TaskStore};
|
||||
use session_modes::{AttachMode, DebugScenarioDelegate, LaunchMode};
|
||||
use settings::Settings;
|
||||
use task::{DebugScenario, LaunchRequest, TaskContext};
|
||||
use task::{DebugScenario, LaunchRequest};
|
||||
use theme::ThemeSettings;
|
||||
use ui::{
|
||||
ActiveTheme, Button, ButtonCommon, ButtonSize, CheckboxWithLabel, Clickable, Color, Context,
|
||||
ContextMenu, Disableable, DropdownMenu, FluentBuilder, InteractiveElement, IntoElement, Label,
|
||||
LabelCommon as _, ParentElement, RenderOnce, SharedString, Styled, StyledExt, ToggleButton,
|
||||
ToggleState, Toggleable, Window, div, h_flex, relative, rems, v_flex,
|
||||
ContextMenu, Disableable, DropdownMenu, FluentBuilder, Icon, IconName, InteractiveElement,
|
||||
IntoElement, Label, LabelCommon as _, ListItem, ListItemSpacing, ParentElement, RenderOnce,
|
||||
SharedString, Styled, StyledExt, ToggleButton, ToggleState, Toggleable, Window, div, h_flex,
|
||||
relative, rems, v_flex,
|
||||
};
|
||||
use util::ResultExt;
|
||||
use workspace::{ModalView, Workspace};
|
||||
|
||||
use crate::{attach_modal::AttachModal, debugger_panel::DebugPanel};
|
||||
@@ -57,6 +63,7 @@ impl NewSessionModal {
|
||||
past_debug_definition: Option<DebugTaskDefinition>,
|
||||
debug_panel: WeakEntity<DebugPanel>,
|
||||
workspace: WeakEntity<Workspace>,
|
||||
task_store: Option<Entity<TaskStore>>,
|
||||
window: &mut Window,
|
||||
cx: &mut Context<Self>,
|
||||
) -> Self {
|
||||
@@ -73,6 +80,18 @@ impl NewSessionModal {
|
||||
_ => None,
|
||||
};
|
||||
|
||||
if let Some(task_store) = task_store {
|
||||
cx.defer_in(window, |this, window, cx| {
|
||||
this.mode = NewSessionMode::scenario(
|
||||
this.debug_panel.clone(),
|
||||
this.workspace.clone(),
|
||||
task_store,
|
||||
window,
|
||||
cx,
|
||||
);
|
||||
});
|
||||
};
|
||||
|
||||
Self {
|
||||
workspace: workspace.clone(),
|
||||
debugger,
|
||||
@@ -86,10 +105,10 @@ impl NewSessionModal {
|
||||
}
|
||||
}
|
||||
|
||||
fn debug_config(&self, cx: &App, debugger: &str) -> DebugScenario {
|
||||
let request = self.mode.debug_task(cx);
|
||||
fn debug_config(&self, cx: &App, debugger: &str) -> Option<DebugScenario> {
|
||||
let request = self.mode.debug_task(cx)?;
|
||||
let label = suggested_label(&request, debugger);
|
||||
DebugScenario {
|
||||
Some(DebugScenario {
|
||||
adapter: debugger.to_owned().into(),
|
||||
label,
|
||||
request: Some(request),
|
||||
@@ -100,21 +119,42 @@ impl NewSessionModal {
|
||||
_ => None,
|
||||
},
|
||||
build: None,
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
fn start_new_session(&self, window: &mut Window, cx: &mut Context<Self>) {
|
||||
let Some(debugger) = self.debugger.as_ref() else {
|
||||
// todo: show in UI.
|
||||
// todo(debugger): show in UI.
|
||||
log::error!("No debugger selected");
|
||||
return;
|
||||
};
|
||||
let config = self.debug_config(cx, debugger);
|
||||
|
||||
if let NewSessionMode::Scenario(picker) = &self.mode {
|
||||
picker.update(cx, |picker, cx| {
|
||||
picker.delegate.confirm(false, window, cx);
|
||||
});
|
||||
return;
|
||||
}
|
||||
|
||||
let Some(config) = self.debug_config(cx, debugger) else {
|
||||
log::error!("debug config not found in mode: {}", self.mode);
|
||||
return;
|
||||
};
|
||||
|
||||
let debug_panel = self.debug_panel.clone();
|
||||
let workspace = self.workspace.clone();
|
||||
|
||||
cx.spawn_in(window, async move |this, cx| {
|
||||
let task_contexts = workspace
|
||||
.update_in(cx, |workspace, window, cx| {
|
||||
tasks_ui::task_contexts(workspace, window, cx)
|
||||
})?
|
||||
.await;
|
||||
|
||||
let task_context = task_contexts.active_context().cloned().unwrap_or_default();
|
||||
|
||||
debug_panel.update_in(cx, |debug_panel, window, cx| {
|
||||
debug_panel.start_session(config, TaskContext::default(), None, window, cx)
|
||||
debug_panel.start_session(config, task_context, None, window, cx)
|
||||
})?;
|
||||
this.update(cx, |_, cx| {
|
||||
cx.emit(DismissEvent);
|
||||
@@ -256,9 +296,14 @@ impl NewSessionModal {
|
||||
.iter()
|
||||
.flat_map(|task_inventory| {
|
||||
task_inventory.read(cx).list_debug_scenarios(
|
||||
worktree.as_ref().map(|worktree| worktree.read(cx).id()),
|
||||
worktree
|
||||
.as_ref()
|
||||
.map(|worktree| worktree.read(cx).id())
|
||||
.iter()
|
||||
.copied(),
|
||||
)
|
||||
})
|
||||
.map(|(_source_kind, scenario)| scenario)
|
||||
.collect()
|
||||
})
|
||||
.ok()
|
||||
@@ -277,102 +322,22 @@ impl NewSessionModal {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct LaunchMode {
|
||||
program: Entity<Editor>,
|
||||
cwd: Entity<Editor>,
|
||||
}
|
||||
|
||||
impl LaunchMode {
|
||||
fn new(
|
||||
past_launch_config: Option<LaunchRequest>,
|
||||
window: &mut Window,
|
||||
cx: &mut App,
|
||||
) -> Entity<Self> {
|
||||
let (past_program, past_cwd) = past_launch_config
|
||||
.map(|config| (Some(config.program), config.cwd))
|
||||
.unwrap_or_else(|| (None, None));
|
||||
|
||||
let program = cx.new(|cx| Editor::single_line(window, cx));
|
||||
program.update(cx, |this, cx| {
|
||||
this.set_placeholder_text("Program path", cx);
|
||||
|
||||
if let Some(past_program) = past_program {
|
||||
this.set_text(past_program, window, cx);
|
||||
};
|
||||
});
|
||||
let cwd = cx.new(|cx| Editor::single_line(window, cx));
|
||||
cwd.update(cx, |this, cx| {
|
||||
this.set_placeholder_text("Working Directory", cx);
|
||||
if let Some(past_cwd) = past_cwd {
|
||||
this.set_text(past_cwd.to_string_lossy(), window, cx);
|
||||
};
|
||||
});
|
||||
cx.new(|_| Self { program, cwd })
|
||||
}
|
||||
|
||||
fn debug_task(&self, cx: &App) -> task::LaunchRequest {
|
||||
let path = self.cwd.read(cx).text(cx);
|
||||
task::LaunchRequest {
|
||||
program: self.program.read(cx).text(cx),
|
||||
cwd: path.is_empty().not().then(|| PathBuf::from(path)),
|
||||
args: Default::default(),
|
||||
env: Default::default(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct AttachMode {
|
||||
definition: DebugTaskDefinition,
|
||||
attach_picker: Entity<AttachModal>,
|
||||
}
|
||||
|
||||
impl AttachMode {
|
||||
fn new(
|
||||
debugger: Option<SharedString>,
|
||||
workspace: Entity<Workspace>,
|
||||
window: &mut Window,
|
||||
cx: &mut Context<NewSessionModal>,
|
||||
) -> Entity<Self> {
|
||||
let definition = DebugTaskDefinition {
|
||||
adapter: debugger.clone().unwrap_or_default(),
|
||||
label: "Attach New Session Setup".into(),
|
||||
request: dap::DebugRequest::Attach(task::AttachRequest { process_id: None }),
|
||||
initialize_args: None,
|
||||
tcp_connection: None,
|
||||
stop_on_entry: Some(false),
|
||||
};
|
||||
let attach_picker = cx.new(|cx| {
|
||||
let modal = AttachModal::new(definition.clone(), workspace, false, window, cx);
|
||||
window.focus(&modal.focus_handle(cx));
|
||||
|
||||
modal
|
||||
});
|
||||
cx.new(|_| Self {
|
||||
definition,
|
||||
attach_picker,
|
||||
})
|
||||
}
|
||||
fn debug_task(&self) -> task::AttachRequest {
|
||||
task::AttachRequest { process_id: None }
|
||||
}
|
||||
}
|
||||
|
||||
static SELECT_DEBUGGER_LABEL: SharedString = SharedString::new_static("Select Debugger");
|
||||
static SELECT_SCENARIO_LABEL: SharedString = SharedString::new_static("Select Profile");
|
||||
|
||||
#[derive(Clone)]
|
||||
enum NewSessionMode {
|
||||
Launch(Entity<LaunchMode>),
|
||||
Scenario(Entity<Picker<DebugScenarioDelegate>>),
|
||||
Attach(Entity<AttachMode>),
|
||||
}
|
||||
|
||||
impl NewSessionMode {
|
||||
fn debug_task(&self, cx: &App) -> DebugRequest {
|
||||
fn debug_task(&self, cx: &App) -> Option<DebugRequest> {
|
||||
match self {
|
||||
NewSessionMode::Launch(entity) => entity.read(cx).debug_task(cx).into(),
|
||||
NewSessionMode::Attach(entity) => entity.read(cx).debug_task().into(),
|
||||
NewSessionMode::Launch(entity) => Some(entity.read(cx).debug_task(cx).into()),
|
||||
NewSessionMode::Attach(entity) => Some(entity.read(cx).debug_task().into()),
|
||||
NewSessionMode::Scenario(_) => None,
|
||||
}
|
||||
}
|
||||
fn as_attach(&self) -> Option<&Entity<AttachMode>> {
|
||||
@@ -382,6 +347,78 @@ impl NewSessionMode {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
fn scenario(
|
||||
debug_panel: WeakEntity<DebugPanel>,
|
||||
workspace: WeakEntity<Workspace>,
|
||||
task_store: Entity<TaskStore>,
|
||||
window: &mut Window,
|
||||
cx: &mut Context<NewSessionModal>,
|
||||
) -> NewSessionMode {
|
||||
let picker = cx.new(|cx| {
|
||||
Picker::uniform_list(
|
||||
DebugScenarioDelegate::new(debug_panel, workspace, task_store),
|
||||
window,
|
||||
cx,
|
||||
)
|
||||
.modal(false)
|
||||
});
|
||||
|
||||
cx.subscribe(&picker, |_, _, _, cx| {
|
||||
cx.emit(DismissEvent);
|
||||
})
|
||||
.detach();
|
||||
|
||||
picker.focus_handle(cx).focus(window);
|
||||
NewSessionMode::Scenario(picker)
|
||||
}
|
||||
|
||||
fn attach(
|
||||
debugger: Option<SharedString>,
|
||||
workspace: Entity<Workspace>,
|
||||
window: &mut Window,
|
||||
cx: &mut Context<NewSessionModal>,
|
||||
) -> Self {
|
||||
Self::Attach(AttachMode::new(debugger, workspace, window, cx))
|
||||
}
|
||||
|
||||
fn launch(
|
||||
past_launch_config: Option<LaunchRequest>,
|
||||
window: &mut Window,
|
||||
cx: &mut Context<NewSessionModal>,
|
||||
) -> Self {
|
||||
Self::Launch(LaunchMode::new(past_launch_config, window, cx))
|
||||
}
|
||||
|
||||
fn has_match(&self, cx: &App) -> bool {
|
||||
match self {
|
||||
NewSessionMode::Scenario(picker) => picker.read(cx).delegate.match_count() > 0,
|
||||
NewSessionMode::Attach(picker) => {
|
||||
picker
|
||||
.read(cx)
|
||||
.attach_picker
|
||||
.read(cx)
|
||||
.picker
|
||||
.read(cx)
|
||||
.delegate
|
||||
.match_count()
|
||||
> 0
|
||||
}
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Display for NewSessionMode {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
let mode = match self {
|
||||
NewSessionMode::Launch(_) => "launch".to_owned(),
|
||||
NewSessionMode::Attach(_) => "attach".to_owned(),
|
||||
NewSessionMode::Scenario(_) => "scenario picker".to_owned(),
|
||||
};
|
||||
|
||||
write!(f, "{}", mode)
|
||||
}
|
||||
}
|
||||
|
||||
impl Focusable for NewSessionMode {
|
||||
@@ -389,6 +426,7 @@ impl Focusable for NewSessionMode {
|
||||
match &self {
|
||||
NewSessionMode::Launch(entity) => entity.read(cx).program.focus_handle(cx),
|
||||
NewSessionMode::Attach(entity) => entity.read(cx).attach_picker.focus_handle(cx),
|
||||
NewSessionMode::Scenario(entity) => entity.read(cx).focus_handle(cx),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -437,27 +475,14 @@ impl RenderOnce for NewSessionMode {
|
||||
NewSessionMode::Attach(entity) => entity.update(cx, |this, cx| {
|
||||
this.clone().render(window, cx).into_any_element()
|
||||
}),
|
||||
NewSessionMode::Scenario(entity) => v_flex()
|
||||
.w(rems(34.))
|
||||
.child(entity.clone())
|
||||
.into_any_element(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl NewSessionMode {
|
||||
fn attach(
|
||||
debugger: Option<SharedString>,
|
||||
workspace: Entity<Workspace>,
|
||||
window: &mut Window,
|
||||
cx: &mut Context<NewSessionModal>,
|
||||
) -> Self {
|
||||
Self::Attach(AttachMode::new(debugger, workspace, window, cx))
|
||||
}
|
||||
fn launch(
|
||||
past_launch_config: Option<LaunchRequest>,
|
||||
window: &mut Window,
|
||||
cx: &mut Context<NewSessionModal>,
|
||||
) -> Self {
|
||||
Self::Launch(LaunchMode::new(past_launch_config, window, cx))
|
||||
}
|
||||
}
|
||||
fn render_editor(editor: &Entity<Editor>, window: &mut Window, cx: &App) -> impl IntoElement {
|
||||
let settings = ThemeSettings::get_global(cx);
|
||||
let theme = cx.theme();
|
||||
@@ -519,6 +544,34 @@ impl Render for NewSessionModal {
|
||||
h_flex()
|
||||
.justify_start()
|
||||
.w_full()
|
||||
.child(
|
||||
ToggleButton::new("debugger-session-ui-picker-button", "Scenarios")
|
||||
.size(ButtonSize::Default)
|
||||
.style(ui::ButtonStyle::Subtle)
|
||||
.toggle_state(matches!(self.mode, NewSessionMode::Scenario(_)))
|
||||
.on_click(cx.listener(|this, _, window, cx| {
|
||||
let Some(task_store) = this
|
||||
.workspace
|
||||
.update(cx, |workspace, cx| {
|
||||
workspace.project().read(cx).task_store().clone()
|
||||
})
|
||||
.ok()
|
||||
else {
|
||||
return;
|
||||
};
|
||||
|
||||
this.mode = NewSessionMode::scenario(
|
||||
this.debug_panel.clone(),
|
||||
this.workspace.clone(),
|
||||
task_store,
|
||||
window,
|
||||
cx,
|
||||
);
|
||||
|
||||
cx.notify();
|
||||
}))
|
||||
.first(),
|
||||
)
|
||||
.child(
|
||||
ToggleButton::new(
|
||||
"debugger-session-ui-launch-button",
|
||||
@@ -532,7 +585,7 @@ impl Render for NewSessionModal {
|
||||
this.mode.focus_handle(cx).focus(window);
|
||||
cx.notify();
|
||||
}))
|
||||
.first(),
|
||||
.middle(),
|
||||
)
|
||||
.child(
|
||||
ToggleButton::new(
|
||||
@@ -601,10 +654,21 @@ impl Render for NewSessionModal {
|
||||
})
|
||||
.child(
|
||||
Button::new("debugger-spawn", "Start")
|
||||
.on_click(cx.listener(|this, _, window, cx| {
|
||||
this.start_new_session(window, cx);
|
||||
.on_click(cx.listener(|this, _, window, cx| match &this.mode {
|
||||
NewSessionMode::Scenario(picker) => {
|
||||
picker.update(cx, |picker, cx| {
|
||||
picker.delegate.confirm(true, window, cx)
|
||||
})
|
||||
}
|
||||
_ => this.start_new_session(window, cx),
|
||||
}))
|
||||
.disabled(self.debugger.is_none()),
|
||||
.disabled(match self.mode {
|
||||
NewSessionMode::Scenario(_) => !self.mode.has_match(cx),
|
||||
NewSessionMode::Attach(_) => {
|
||||
self.debugger.is_none() || !self.mode.has_match(cx)
|
||||
}
|
||||
NewSessionMode::Launch(_) => self.debugger.is_none(),
|
||||
}),
|
||||
),
|
||||
),
|
||||
)
|
||||
@@ -619,3 +683,319 @@ impl Focusable for NewSessionModal {
|
||||
}
|
||||
|
||||
impl ModalView for NewSessionModal {}
|
||||
|
||||
// This module makes sure that the modes setup the correct subscriptions whenever they're created
|
||||
mod session_modes {
|
||||
use std::rc::Rc;
|
||||
|
||||
use super::*;
|
||||
|
||||
#[derive(Clone)]
|
||||
#[non_exhaustive]
|
||||
pub(super) struct LaunchMode {
|
||||
pub(super) program: Entity<Editor>,
|
||||
pub(super) cwd: Entity<Editor>,
|
||||
}
|
||||
|
||||
impl LaunchMode {
|
||||
pub(super) fn new(
|
||||
past_launch_config: Option<LaunchRequest>,
|
||||
window: &mut Window,
|
||||
cx: &mut App,
|
||||
) -> Entity<Self> {
|
||||
let (past_program, past_cwd) = past_launch_config
|
||||
.map(|config| (Some(config.program), config.cwd))
|
||||
.unwrap_or_else(|| (None, None));
|
||||
|
||||
let program = cx.new(|cx| Editor::single_line(window, cx));
|
||||
program.update(cx, |this, cx| {
|
||||
this.set_placeholder_text("Program path", cx);
|
||||
|
||||
if let Some(past_program) = past_program {
|
||||
this.set_text(past_program, window, cx);
|
||||
};
|
||||
});
|
||||
let cwd = cx.new(|cx| Editor::single_line(window, cx));
|
||||
cwd.update(cx, |this, cx| {
|
||||
this.set_placeholder_text("Working Directory", cx);
|
||||
if let Some(past_cwd) = past_cwd {
|
||||
this.set_text(past_cwd.to_string_lossy(), window, cx);
|
||||
};
|
||||
});
|
||||
cx.new(|_| Self { program, cwd })
|
||||
}
|
||||
|
||||
pub(super) fn debug_task(&self, cx: &App) -> task::LaunchRequest {
|
||||
let path = self.cwd.read(cx).text(cx);
|
||||
task::LaunchRequest {
|
||||
program: self.program.read(cx).text(cx),
|
||||
cwd: path.is_empty().not().then(|| PathBuf::from(path)),
|
||||
args: Default::default(),
|
||||
env: Default::default(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub(super) struct AttachMode {
|
||||
pub(super) definition: DebugTaskDefinition,
|
||||
pub(super) attach_picker: Entity<AttachModal>,
|
||||
_subscription: Rc<Subscription>,
|
||||
}
|
||||
|
||||
impl AttachMode {
|
||||
pub(super) fn new(
|
||||
debugger: Option<SharedString>,
|
||||
workspace: Entity<Workspace>,
|
||||
window: &mut Window,
|
||||
cx: &mut Context<NewSessionModal>,
|
||||
) -> Entity<Self> {
|
||||
let definition = DebugTaskDefinition {
|
||||
adapter: debugger.clone().unwrap_or_default(),
|
||||
label: "Attach New Session Setup".into(),
|
||||
request: dap::DebugRequest::Attach(task::AttachRequest { process_id: None }),
|
||||
initialize_args: None,
|
||||
tcp_connection: None,
|
||||
stop_on_entry: Some(false),
|
||||
};
|
||||
let attach_picker = cx.new(|cx| {
|
||||
let modal = AttachModal::new(definition.clone(), workspace, false, window, cx);
|
||||
window.focus(&modal.focus_handle(cx));
|
||||
|
||||
modal
|
||||
});
|
||||
|
||||
let subscription = cx.subscribe(&attach_picker, |_, _, _, cx| {
|
||||
cx.emit(DismissEvent);
|
||||
});
|
||||
|
||||
cx.new(|_| Self {
|
||||
definition,
|
||||
attach_picker,
|
||||
_subscription: Rc::new(subscription),
|
||||
})
|
||||
}
|
||||
pub(super) fn debug_task(&self) -> task::AttachRequest {
|
||||
task::AttachRequest { process_id: None }
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) struct DebugScenarioDelegate {
|
||||
task_store: Entity<TaskStore>,
|
||||
candidates: Option<Vec<(TaskSourceKind, DebugScenario)>>,
|
||||
selected_index: usize,
|
||||
matches: Vec<StringMatch>,
|
||||
prompt: String,
|
||||
debug_panel: WeakEntity<DebugPanel>,
|
||||
workspace: WeakEntity<Workspace>,
|
||||
}
|
||||
|
||||
impl DebugScenarioDelegate {
|
||||
pub(super) fn new(
|
||||
debug_panel: WeakEntity<DebugPanel>,
|
||||
workspace: WeakEntity<Workspace>,
|
||||
task_store: Entity<TaskStore>,
|
||||
) -> Self {
|
||||
Self {
|
||||
task_store,
|
||||
candidates: None,
|
||||
selected_index: 0,
|
||||
matches: Vec::new(),
|
||||
prompt: String::new(),
|
||||
debug_panel,
|
||||
workspace,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl PickerDelegate for DebugScenarioDelegate {
|
||||
type ListItem = ui::ListItem;
|
||||
|
||||
fn match_count(&self) -> usize {
|
||||
self.matches.len()
|
||||
}
|
||||
|
||||
fn selected_index(&self) -> usize {
|
||||
self.selected_index
|
||||
}
|
||||
|
||||
fn set_selected_index(
|
||||
&mut self,
|
||||
ix: usize,
|
||||
_window: &mut Window,
|
||||
_cx: &mut Context<picker::Picker<Self>>,
|
||||
) {
|
||||
self.selected_index = ix;
|
||||
}
|
||||
|
||||
fn placeholder_text(&self, _window: &mut Window, _cx: &mut App) -> std::sync::Arc<str> {
|
||||
"".into()
|
||||
}
|
||||
|
||||
fn update_matches(
|
||||
&mut self,
|
||||
query: String,
|
||||
window: &mut Window,
|
||||
cx: &mut Context<picker::Picker<Self>>,
|
||||
) -> gpui::Task<()> {
|
||||
let candidates: Vec<_> = match &self.candidates {
|
||||
Some(candidates) => candidates
|
||||
.into_iter()
|
||||
.enumerate()
|
||||
.map(|(index, (_, candidate))| {
|
||||
StringMatchCandidate::new(index, candidate.label.as_ref())
|
||||
})
|
||||
.collect(),
|
||||
None => {
|
||||
let worktree_ids: Vec<_> = self
|
||||
.workspace
|
||||
.update(cx, |this, cx| {
|
||||
this.visible_worktrees(cx)
|
||||
.map(|tree| tree.read(cx).id())
|
||||
.collect()
|
||||
})
|
||||
.ok()
|
||||
.unwrap_or_default();
|
||||
|
||||
let scenarios: Vec<_> = self
|
||||
.task_store
|
||||
.read(cx)
|
||||
.task_inventory()
|
||||
.map(|item| item.read(cx).list_debug_scenarios(worktree_ids.into_iter()))
|
||||
.unwrap_or_default();
|
||||
|
||||
self.candidates = Some(scenarios.clone());
|
||||
|
||||
scenarios
|
||||
.into_iter()
|
||||
.enumerate()
|
||||
.map(|(index, (_, candidate))| {
|
||||
StringMatchCandidate::new(index, candidate.label.as_ref())
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
};
|
||||
|
||||
cx.spawn_in(window, async move |picker, cx| {
|
||||
let matches = fuzzy::match_strings(
|
||||
&candidates,
|
||||
&query,
|
||||
true,
|
||||
1000,
|
||||
&Default::default(),
|
||||
cx.background_executor().clone(),
|
||||
)
|
||||
.await;
|
||||
|
||||
picker
|
||||
.update(cx, |picker, _| {
|
||||
let delegate = &mut picker.delegate;
|
||||
|
||||
delegate.matches = matches;
|
||||
delegate.prompt = query;
|
||||
|
||||
if delegate.matches.is_empty() {
|
||||
delegate.selected_index = 0;
|
||||
} else {
|
||||
delegate.selected_index =
|
||||
delegate.selected_index.min(delegate.matches.len() - 1);
|
||||
}
|
||||
})
|
||||
.log_err();
|
||||
})
|
||||
}
|
||||
|
||||
fn confirm(
|
||||
&mut self,
|
||||
_: bool,
|
||||
window: &mut Window,
|
||||
cx: &mut Context<picker::Picker<Self>>,
|
||||
) {
|
||||
let debug_scenario =
|
||||
self.matches
|
||||
.get(self.selected_index())
|
||||
.and_then(|match_candidate| {
|
||||
self.candidates
|
||||
.as_ref()
|
||||
.map(|candidates| candidates[match_candidate.candidate_id].clone())
|
||||
});
|
||||
|
||||
let Some((task_source_kind, debug_scenario)) = debug_scenario else {
|
||||
return;
|
||||
};
|
||||
|
||||
let task_context = if let TaskSourceKind::Worktree {
|
||||
id: worktree_id,
|
||||
directory_in_worktree: _,
|
||||
id_base: _,
|
||||
} = task_source_kind
|
||||
{
|
||||
let workspace = self.workspace.clone();
|
||||
|
||||
cx.spawn_in(window, async move |_, cx| {
|
||||
workspace
|
||||
.update_in(cx, |workspace, window, cx| {
|
||||
tasks_ui::task_contexts(workspace, window, cx)
|
||||
})
|
||||
.ok()?
|
||||
.await
|
||||
.task_context_for_worktree_id(worktree_id)
|
||||
.cloned()
|
||||
})
|
||||
} else {
|
||||
gpui::Task::ready(None)
|
||||
};
|
||||
|
||||
cx.spawn_in(window, async move |this, cx| {
|
||||
let task_context = task_context.await.unwrap_or_default();
|
||||
|
||||
this.update_in(cx, |this, window, cx| {
|
||||
this.delegate
|
||||
.debug_panel
|
||||
.update(cx, |panel, cx| {
|
||||
panel.start_session(debug_scenario, task_context, None, window, cx);
|
||||
})
|
||||
.ok();
|
||||
|
||||
cx.emit(DismissEvent);
|
||||
})
|
||||
.ok();
|
||||
})
|
||||
.detach();
|
||||
}
|
||||
|
||||
fn dismissed(&mut self, _: &mut Window, cx: &mut Context<picker::Picker<Self>>) {
|
||||
cx.emit(DismissEvent);
|
||||
}
|
||||
|
||||
fn render_match(
|
||||
&self,
|
||||
ix: usize,
|
||||
selected: bool,
|
||||
window: &mut Window,
|
||||
cx: &mut Context<picker::Picker<Self>>,
|
||||
) -> Option<Self::ListItem> {
|
||||
let hit = &self.matches[ix];
|
||||
|
||||
let highlighted_location = HighlightedMatch {
|
||||
text: hit.string.clone(),
|
||||
highlight_positions: hit.positions.clone(),
|
||||
char_count: hit.string.chars().count(),
|
||||
color: Color::Default,
|
||||
};
|
||||
|
||||
let icon = Icon::new(IconName::FileTree)
|
||||
.color(Color::Muted)
|
||||
.size(ui::IconSize::Small);
|
||||
|
||||
Some(
|
||||
ListItem::new(SharedString::from(format!("debug-scenario-selection-{ix}")))
|
||||
.inset(true)
|
||||
.start_slot::<Icon>(icon)
|
||||
.spacing(ListItemSpacing::Sparse)
|
||||
.toggle_state(selected)
|
||||
.child(highlighted_location.render(window, cx)),
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -188,7 +188,7 @@ impl Render for SubView {
|
||||
cx.notify();
|
||||
}))
|
||||
.size_full()
|
||||
// Add border uncoditionally to prevent layout shifts on focus changes.
|
||||
// Add border unconditionally to prevent layout shifts on focus changes.
|
||||
.border_1()
|
||||
.when(self.pane_focus_handle.contains_focused(window, cx), |el| {
|
||||
el.border_color(cx.theme().colors().pane_focused_border)
|
||||
|
||||
@@ -14,7 +14,6 @@ doctest = false
|
||||
|
||||
[dependencies]
|
||||
anyhow.workspace = true
|
||||
cargo_metadata.workspace = true
|
||||
collections.workspace = true
|
||||
component.workspace = true
|
||||
ctor.workspace = true
|
||||
@@ -23,7 +22,6 @@ env_logger.workspace = true
|
||||
futures.workspace = true
|
||||
gpui.workspace = true
|
||||
indoc.workspace = true
|
||||
itertools.workspace = true
|
||||
language.workspace = true
|
||||
linkme.workspace = true
|
||||
log.workspace = true
|
||||
@@ -34,7 +32,6 @@ rand.workspace = true
|
||||
serde.workspace = true
|
||||
serde_json.workspace = true
|
||||
settings.workspace = true
|
||||
smol.workspace = true
|
||||
text.workspace = true
|
||||
theme.workspace = true
|
||||
ui.workspace = true
|
||||
|
||||
@@ -1,603 +0,0 @@
|
||||
use std::{
|
||||
path::{Component, Path, Prefix},
|
||||
process::Stdio,
|
||||
sync::atomic::{self, AtomicUsize},
|
||||
};
|
||||
|
||||
use cargo_metadata::{
|
||||
Message,
|
||||
diagnostic::{Applicability, Diagnostic as CargoDiagnostic, DiagnosticLevel, DiagnosticSpan},
|
||||
};
|
||||
use collections::HashMap;
|
||||
use gpui::{AppContext, Entity, Task};
|
||||
use itertools::Itertools as _;
|
||||
use language::Diagnostic;
|
||||
use project::{
|
||||
Worktree, lsp_store::rust_analyzer_ext::CARGO_DIAGNOSTICS_SOURCE_NAME,
|
||||
project_settings::ProjectSettings,
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use settings::Settings;
|
||||
use smol::{
|
||||
channel::Receiver,
|
||||
io::{AsyncBufReadExt, BufReader},
|
||||
process::Command,
|
||||
};
|
||||
use ui::App;
|
||||
use util::ResultExt;
|
||||
|
||||
use crate::ProjectDiagnosticsEditor;
|
||||
|
||||
#[derive(Debug, serde::Deserialize)]
|
||||
#[serde(untagged)]
|
||||
enum CargoMessage {
|
||||
Cargo(Message),
|
||||
Rustc(CargoDiagnostic),
|
||||
}
|
||||
|
||||
/// Appends formatted string to a `String`.
|
||||
macro_rules! format_to {
|
||||
($buf:expr) => ();
|
||||
($buf:expr, $lit:literal $($arg:tt)*) => {
|
||||
{
|
||||
use ::std::fmt::Write as _;
|
||||
// We can't do ::std::fmt::Write::write_fmt($buf, format_args!($lit $($arg)*))
|
||||
// unfortunately, as that loses out on autoref behavior.
|
||||
_ = $buf.write_fmt(format_args!($lit $($arg)*))
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
pub fn cargo_diagnostics_sources(
|
||||
editor: &ProjectDiagnosticsEditor,
|
||||
cx: &App,
|
||||
) -> Vec<Entity<Worktree>> {
|
||||
let fetch_cargo_diagnostics = ProjectSettings::get_global(cx)
|
||||
.diagnostics
|
||||
.fetch_cargo_diagnostics();
|
||||
if !fetch_cargo_diagnostics {
|
||||
return Vec::new();
|
||||
}
|
||||
editor
|
||||
.project
|
||||
.read(cx)
|
||||
.worktrees(cx)
|
||||
.filter(|worktree| worktree.read(cx).entry_for_path("Cargo.toml").is_some())
|
||||
.collect()
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum FetchUpdate {
|
||||
Diagnostic(CargoDiagnostic),
|
||||
Progress(String),
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum FetchStatus {
|
||||
Started,
|
||||
Progress { message: String },
|
||||
Finished,
|
||||
}
|
||||
|
||||
pub fn fetch_worktree_diagnostics(
|
||||
worktree_root: &Path,
|
||||
cx: &App,
|
||||
) -> Option<(Task<()>, Receiver<FetchUpdate>)> {
|
||||
let diagnostics_settings = ProjectSettings::get_global(cx)
|
||||
.diagnostics
|
||||
.cargo
|
||||
.as_ref()
|
||||
.filter(|cargo_diagnostics| cargo_diagnostics.fetch_cargo_diagnostics)?;
|
||||
let command_string = diagnostics_settings
|
||||
.diagnostics_fetch_command
|
||||
.iter()
|
||||
.join(" ");
|
||||
let mut command_parts = diagnostics_settings.diagnostics_fetch_command.iter();
|
||||
let mut command = Command::new(command_parts.next()?)
|
||||
.args(command_parts)
|
||||
.envs(diagnostics_settings.env.clone())
|
||||
.current_dir(worktree_root)
|
||||
.stdout(Stdio::piped())
|
||||
.stderr(Stdio::null())
|
||||
.kill_on_drop(true)
|
||||
.spawn()
|
||||
.log_err()?;
|
||||
|
||||
let stdout = command.stdout.take()?;
|
||||
let mut reader = BufReader::new(stdout);
|
||||
let (tx, rx) = smol::channel::unbounded();
|
||||
let error_threshold = 10;
|
||||
|
||||
let cargo_diagnostics_fetch_task = cx.background_spawn(async move {
|
||||
let _command = command;
|
||||
let mut errors = 0;
|
||||
loop {
|
||||
let mut line = String::new();
|
||||
match reader.read_line(&mut line).await {
|
||||
Ok(0) => {
|
||||
return;
|
||||
},
|
||||
Ok(_) => {
|
||||
errors = 0;
|
||||
let mut deserializer = serde_json::Deserializer::from_str(&line);
|
||||
deserializer.disable_recursion_limit();
|
||||
let send_result = match CargoMessage::deserialize(&mut deserializer) {
|
||||
Ok(CargoMessage::Cargo(Message::CompilerMessage(message))) => tx.send(FetchUpdate::Diagnostic(message.message)).await,
|
||||
Ok(CargoMessage::Cargo(Message::CompilerArtifact(artifact))) => tx.send(FetchUpdate::Progress(format!("Compiled {:?}", artifact.manifest_path.parent().unwrap_or(&artifact.manifest_path)))).await,
|
||||
Ok(CargoMessage::Cargo(_)) => Ok(()),
|
||||
Ok(CargoMessage::Rustc(rustc_message)) => tx.send(FetchUpdate::Diagnostic(rustc_message)).await,
|
||||
Err(_) => {
|
||||
log::debug!("Failed to parse cargo diagnostics from line '{line}'");
|
||||
Ok(())
|
||||
},
|
||||
};
|
||||
if send_result.is_err() {
|
||||
return;
|
||||
}
|
||||
},
|
||||
Err(e) => {
|
||||
log::error!("Failed to read line from {command_string} command output when fetching cargo diagnostics: {e}");
|
||||
errors += 1;
|
||||
if errors >= error_threshold {
|
||||
log::error!("Failed {error_threshold} times, aborting the diagnostics fetch");
|
||||
return;
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
Some((cargo_diagnostics_fetch_task, rx))
|
||||
}
|
||||
|
||||
static CARGO_DIAGNOSTICS_FETCH_GENERATION: AtomicUsize = AtomicUsize::new(0);
|
||||
|
||||
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
|
||||
struct CargoFetchDiagnosticData {
|
||||
generation: usize,
|
||||
}
|
||||
|
||||
pub fn next_cargo_fetch_generation() {
|
||||
CARGO_DIAGNOSTICS_FETCH_GENERATION.fetch_add(1, atomic::Ordering::Release);
|
||||
}
|
||||
|
||||
pub fn is_outdated_cargo_fetch_diagnostic(diagnostic: &Diagnostic) -> bool {
|
||||
if let Some(data) = diagnostic
|
||||
.data
|
||||
.clone()
|
||||
.and_then(|data| serde_json::from_value::<CargoFetchDiagnosticData>(data).ok())
|
||||
{
|
||||
let current_generation = CARGO_DIAGNOSTICS_FETCH_GENERATION.load(atomic::Ordering::Acquire);
|
||||
data.generation < current_generation
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
/// Converts a Rust root diagnostic to LSP form
|
||||
///
|
||||
/// This flattens the Rust diagnostic by:
|
||||
///
|
||||
/// 1. Creating a LSP diagnostic with the root message and primary span.
|
||||
/// 2. Adding any labelled secondary spans to `relatedInformation`
|
||||
/// 3. Categorising child diagnostics as either `SuggestedFix`es,
|
||||
/// `relatedInformation` or additional message lines.
|
||||
///
|
||||
/// If the diagnostic has no primary span this will return `None`
|
||||
///
|
||||
/// Taken from https://github.com/rust-lang/rust-analyzer/blob/fe7b4f2ad96f7c13cc571f45edc2c578b35dddb4/crates/rust-analyzer/src/diagnostics/to_proto.rs#L275-L285
|
||||
pub(crate) fn map_rust_diagnostic_to_lsp(
|
||||
worktree_root: &Path,
|
||||
cargo_diagnostic: &CargoDiagnostic,
|
||||
) -> Vec<(lsp::Url, lsp::Diagnostic)> {
|
||||
let primary_spans: Vec<&DiagnosticSpan> = cargo_diagnostic
|
||||
.spans
|
||||
.iter()
|
||||
.filter(|s| s.is_primary)
|
||||
.collect();
|
||||
if primary_spans.is_empty() {
|
||||
return Vec::new();
|
||||
}
|
||||
|
||||
let severity = diagnostic_severity(cargo_diagnostic.level);
|
||||
|
||||
let mut source = String::from(CARGO_DIAGNOSTICS_SOURCE_NAME);
|
||||
let mut code = cargo_diagnostic.code.as_ref().map(|c| c.code.clone());
|
||||
|
||||
if let Some(code_val) = &code {
|
||||
// See if this is an RFC #2103 scoped lint (e.g. from Clippy)
|
||||
let scoped_code: Vec<&str> = code_val.split("::").collect();
|
||||
if scoped_code.len() == 2 {
|
||||
source = String::from(scoped_code[0]);
|
||||
code = Some(String::from(scoped_code[1]));
|
||||
}
|
||||
}
|
||||
|
||||
let mut needs_primary_span_label = true;
|
||||
let mut subdiagnostics = Vec::new();
|
||||
let mut tags = Vec::new();
|
||||
|
||||
for secondary_span in cargo_diagnostic.spans.iter().filter(|s| !s.is_primary) {
|
||||
if let Some(label) = secondary_span.label.clone() {
|
||||
subdiagnostics.push(lsp::DiagnosticRelatedInformation {
|
||||
location: location(worktree_root, secondary_span),
|
||||
message: label,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
let mut message = cargo_diagnostic.message.clone();
|
||||
for child in &cargo_diagnostic.children {
|
||||
let child = map_rust_child_diagnostic(worktree_root, child);
|
||||
match child {
|
||||
MappedRustChildDiagnostic::SubDiagnostic(sub) => {
|
||||
subdiagnostics.push(sub);
|
||||
}
|
||||
MappedRustChildDiagnostic::MessageLine(message_line) => {
|
||||
format_to!(message, "\n{message_line}");
|
||||
|
||||
// These secondary messages usually duplicate the content of the
|
||||
// primary span label.
|
||||
needs_primary_span_label = false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(code) = &cargo_diagnostic.code {
|
||||
let code = code.code.as_str();
|
||||
if matches!(
|
||||
code,
|
||||
"dead_code"
|
||||
| "unknown_lints"
|
||||
| "unreachable_code"
|
||||
| "unused_attributes"
|
||||
| "unused_imports"
|
||||
| "unused_macros"
|
||||
| "unused_variables"
|
||||
) {
|
||||
tags.push(lsp::DiagnosticTag::UNNECESSARY);
|
||||
}
|
||||
|
||||
if matches!(code, "deprecated") {
|
||||
tags.push(lsp::DiagnosticTag::DEPRECATED);
|
||||
}
|
||||
}
|
||||
|
||||
let code_description = match source.as_str() {
|
||||
"rustc" => rustc_code_description(code.as_deref()),
|
||||
"clippy" => clippy_code_description(code.as_deref()),
|
||||
_ => None,
|
||||
};
|
||||
|
||||
let generation = CARGO_DIAGNOSTICS_FETCH_GENERATION.load(atomic::Ordering::Acquire);
|
||||
let data = Some(
|
||||
serde_json::to_value(CargoFetchDiagnosticData { generation })
|
||||
.expect("Serializing a regular Rust struct"),
|
||||
);
|
||||
|
||||
primary_spans
|
||||
.iter()
|
||||
.flat_map(|primary_span| {
|
||||
let primary_location = primary_location(worktree_root, primary_span);
|
||||
let message = {
|
||||
let mut message = message.clone();
|
||||
if needs_primary_span_label {
|
||||
if let Some(primary_span_label) = &primary_span.label {
|
||||
format_to!(message, "\n{primary_span_label}");
|
||||
}
|
||||
}
|
||||
message
|
||||
};
|
||||
// Each primary diagnostic span may result in multiple LSP diagnostics.
|
||||
let mut diagnostics = Vec::new();
|
||||
|
||||
let mut related_info_macro_calls = vec![];
|
||||
|
||||
// If error occurs from macro expansion, add related info pointing to
|
||||
// where the error originated
|
||||
// Also, we would generate an additional diagnostic, so that exact place of macro
|
||||
// will be highlighted in the error origin place.
|
||||
let span_stack = std::iter::successors(Some(*primary_span), |span| {
|
||||
Some(&span.expansion.as_ref()?.span)
|
||||
});
|
||||
for (i, span) in span_stack.enumerate() {
|
||||
if is_dummy_macro_file(&span.file_name) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// First span is the original diagnostic, others are macro call locations that
|
||||
// generated that code.
|
||||
let is_in_macro_call = i != 0;
|
||||
|
||||
let secondary_location = location(worktree_root, span);
|
||||
if secondary_location == primary_location {
|
||||
continue;
|
||||
}
|
||||
related_info_macro_calls.push(lsp::DiagnosticRelatedInformation {
|
||||
location: secondary_location.clone(),
|
||||
message: if is_in_macro_call {
|
||||
"Error originated from macro call here".to_owned()
|
||||
} else {
|
||||
"Actual error occurred here".to_owned()
|
||||
},
|
||||
});
|
||||
// For the additional in-macro diagnostic we add the inverse message pointing to the error location in code.
|
||||
let information_for_additional_diagnostic =
|
||||
vec![lsp::DiagnosticRelatedInformation {
|
||||
location: primary_location.clone(),
|
||||
message: "Exact error occurred here".to_owned(),
|
||||
}];
|
||||
|
||||
let diagnostic = lsp::Diagnostic {
|
||||
range: secondary_location.range,
|
||||
// downgrade to hint if we're pointing at the macro
|
||||
severity: Some(lsp::DiagnosticSeverity::HINT),
|
||||
code: code.clone().map(lsp::NumberOrString::String),
|
||||
code_description: code_description.clone(),
|
||||
source: Some(source.clone()),
|
||||
message: message.clone(),
|
||||
related_information: Some(information_for_additional_diagnostic),
|
||||
tags: if tags.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(tags.clone())
|
||||
},
|
||||
data: data.clone(),
|
||||
};
|
||||
diagnostics.push((secondary_location.uri, diagnostic));
|
||||
}
|
||||
|
||||
// Emit the primary diagnostic.
|
||||
diagnostics.push((
|
||||
primary_location.uri.clone(),
|
||||
lsp::Diagnostic {
|
||||
range: primary_location.range,
|
||||
severity,
|
||||
code: code.clone().map(lsp::NumberOrString::String),
|
||||
code_description: code_description.clone(),
|
||||
source: Some(source.clone()),
|
||||
message,
|
||||
related_information: {
|
||||
let info = related_info_macro_calls
|
||||
.iter()
|
||||
.cloned()
|
||||
.chain(subdiagnostics.iter().cloned())
|
||||
.collect::<Vec<_>>();
|
||||
if info.is_empty() { None } else { Some(info) }
|
||||
},
|
||||
tags: if tags.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(tags.clone())
|
||||
},
|
||||
data: data.clone(),
|
||||
},
|
||||
));
|
||||
|
||||
// Emit hint-level diagnostics for all `related_information` entries such as "help"s.
|
||||
// This is useful because they will show up in the user's editor, unlike
|
||||
// `related_information`, which just produces hard-to-read links, at least in VS Code.
|
||||
let back_ref = lsp::DiagnosticRelatedInformation {
|
||||
location: primary_location,
|
||||
message: "original diagnostic".to_owned(),
|
||||
};
|
||||
for sub in &subdiagnostics {
|
||||
diagnostics.push((
|
||||
sub.location.uri.clone(),
|
||||
lsp::Diagnostic {
|
||||
range: sub.location.range,
|
||||
severity: Some(lsp::DiagnosticSeverity::HINT),
|
||||
code: code.clone().map(lsp::NumberOrString::String),
|
||||
code_description: code_description.clone(),
|
||||
source: Some(source.clone()),
|
||||
message: sub.message.clone(),
|
||||
related_information: Some(vec![back_ref.clone()]),
|
||||
tags: None, // don't apply modifiers again
|
||||
data: data.clone(),
|
||||
},
|
||||
));
|
||||
}
|
||||
|
||||
diagnostics
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn rustc_code_description(code: Option<&str>) -> Option<lsp::CodeDescription> {
|
||||
code.filter(|code| {
|
||||
let mut chars = code.chars();
|
||||
chars.next() == Some('E')
|
||||
&& chars.by_ref().take(4).all(|c| c.is_ascii_digit())
|
||||
&& chars.next().is_none()
|
||||
})
|
||||
.and_then(|code| {
|
||||
lsp::Url::parse(&format!(
|
||||
"https://doc.rust-lang.org/error-index.html#{code}"
|
||||
))
|
||||
.ok()
|
||||
.map(|href| lsp::CodeDescription { href })
|
||||
})
|
||||
}
|
||||
|
||||
fn clippy_code_description(code: Option<&str>) -> Option<lsp::CodeDescription> {
|
||||
code.and_then(|code| {
|
||||
lsp::Url::parse(&format!(
|
||||
"https://rust-lang.github.io/rust-clippy/master/index.html#{code}"
|
||||
))
|
||||
.ok()
|
||||
.map(|href| lsp::CodeDescription { href })
|
||||
})
|
||||
}
|
||||
|
||||
/// Determines the LSP severity from a diagnostic
|
||||
fn diagnostic_severity(level: DiagnosticLevel) -> Option<lsp::DiagnosticSeverity> {
|
||||
let res = match level {
|
||||
DiagnosticLevel::Ice => lsp::DiagnosticSeverity::ERROR,
|
||||
DiagnosticLevel::Error => lsp::DiagnosticSeverity::ERROR,
|
||||
DiagnosticLevel::Warning => lsp::DiagnosticSeverity::WARNING,
|
||||
DiagnosticLevel::Note => lsp::DiagnosticSeverity::INFORMATION,
|
||||
DiagnosticLevel::Help => lsp::DiagnosticSeverity::HINT,
|
||||
_ => return None,
|
||||
};
|
||||
Some(res)
|
||||
}
|
||||
|
||||
enum MappedRustChildDiagnostic {
|
||||
SubDiagnostic(lsp::DiagnosticRelatedInformation),
|
||||
MessageLine(String),
|
||||
}
|
||||
|
||||
fn map_rust_child_diagnostic(
|
||||
worktree_root: &Path,
|
||||
cargo_diagnostic: &CargoDiagnostic,
|
||||
) -> MappedRustChildDiagnostic {
|
||||
let spans: Vec<&DiagnosticSpan> = cargo_diagnostic
|
||||
.spans
|
||||
.iter()
|
||||
.filter(|s| s.is_primary)
|
||||
.collect();
|
||||
if spans.is_empty() {
|
||||
// `rustc` uses these spanless children as a way to print multi-line
|
||||
// messages
|
||||
return MappedRustChildDiagnostic::MessageLine(cargo_diagnostic.message.clone());
|
||||
}
|
||||
|
||||
let mut edit_map: HashMap<lsp::Url, Vec<lsp::TextEdit>> = HashMap::default();
|
||||
let mut suggested_replacements = Vec::new();
|
||||
for &span in &spans {
|
||||
if let Some(suggested_replacement) = &span.suggested_replacement {
|
||||
if !suggested_replacement.is_empty() {
|
||||
suggested_replacements.push(suggested_replacement);
|
||||
}
|
||||
let location = location(worktree_root, span);
|
||||
let edit = lsp::TextEdit::new(location.range, suggested_replacement.clone());
|
||||
|
||||
// Only actually emit a quickfix if the suggestion is "valid enough".
|
||||
// We accept both "MaybeIncorrect" and "MachineApplicable". "MaybeIncorrect" means that
|
||||
// the suggestion is *complete* (contains no placeholders where code needs to be
|
||||
// inserted), but might not be what the user wants, or might need minor adjustments.
|
||||
if matches!(
|
||||
span.suggestion_applicability,
|
||||
None | Some(Applicability::MaybeIncorrect | Applicability::MachineApplicable)
|
||||
) {
|
||||
edit_map.entry(location.uri).or_default().push(edit);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// rustc renders suggestion diagnostics by appending the suggested replacement, so do the same
|
||||
// here, otherwise the diagnostic text is missing useful information.
|
||||
let mut message = cargo_diagnostic.message.clone();
|
||||
if !suggested_replacements.is_empty() {
|
||||
message.push_str(": ");
|
||||
let suggestions = suggested_replacements
|
||||
.iter()
|
||||
.map(|suggestion| format!("`{suggestion}`"))
|
||||
.join(", ");
|
||||
message.push_str(&suggestions);
|
||||
}
|
||||
|
||||
MappedRustChildDiagnostic::SubDiagnostic(lsp::DiagnosticRelatedInformation {
|
||||
location: location(worktree_root, spans[0]),
|
||||
message,
|
||||
})
|
||||
}
|
||||
|
||||
/// Converts a Rust span to a LSP location
|
||||
fn location(worktree_root: &Path, span: &DiagnosticSpan) -> lsp::Location {
|
||||
let file_name = worktree_root.join(&span.file_name);
|
||||
let uri = url_from_abs_path(&file_name);
|
||||
|
||||
let range = {
|
||||
lsp::Range::new(
|
||||
position(span, span.line_start, span.column_start.saturating_sub(1)),
|
||||
position(span, span.line_end, span.column_end.saturating_sub(1)),
|
||||
)
|
||||
};
|
||||
lsp::Location::new(uri, range)
|
||||
}
|
||||
|
||||
/// Returns a `Url` object from a given path, will lowercase drive letters if present.
|
||||
/// This will only happen when processing windows paths.
|
||||
///
|
||||
/// When processing non-windows path, this is essentially the same as `Url::from_file_path`.
|
||||
pub(crate) fn url_from_abs_path(path: &Path) -> lsp::Url {
|
||||
let url = lsp::Url::from_file_path(path).unwrap();
|
||||
match path.components().next() {
|
||||
Some(Component::Prefix(prefix))
|
||||
if matches!(prefix.kind(), Prefix::Disk(_) | Prefix::VerbatimDisk(_)) =>
|
||||
{
|
||||
// Need to lowercase driver letter
|
||||
}
|
||||
_ => return url,
|
||||
}
|
||||
|
||||
let driver_letter_range = {
|
||||
let (scheme, drive_letter, _rest) = match url.as_str().splitn(3, ':').collect_tuple() {
|
||||
Some(it) => it,
|
||||
None => return url,
|
||||
};
|
||||
let start = scheme.len() + ':'.len_utf8();
|
||||
start..(start + drive_letter.len())
|
||||
};
|
||||
|
||||
// Note: lowercasing the `path` itself doesn't help, the `Url::parse`
|
||||
// machinery *also* canonicalizes the drive letter. So, just massage the
|
||||
// string in place.
|
||||
let mut url: String = url.into();
|
||||
url[driver_letter_range].make_ascii_lowercase();
|
||||
lsp::Url::parse(&url).unwrap()
|
||||
}
|
||||
|
||||
fn position(
|
||||
span: &DiagnosticSpan,
|
||||
line_number: usize,
|
||||
column_offset_utf32: usize,
|
||||
) -> lsp::Position {
|
||||
let line_index = line_number - span.line_start;
|
||||
|
||||
let column_offset_encoded = match span.text.get(line_index) {
|
||||
// Fast path.
|
||||
Some(line) if line.text.is_ascii() => column_offset_utf32,
|
||||
Some(line) => {
|
||||
let line_prefix_len = line
|
||||
.text
|
||||
.char_indices()
|
||||
.take(column_offset_utf32)
|
||||
.last()
|
||||
.map(|(pos, c)| pos + c.len_utf8())
|
||||
.unwrap_or(0);
|
||||
let line_prefix = &line.text[..line_prefix_len];
|
||||
line_prefix.len()
|
||||
}
|
||||
None => column_offset_utf32,
|
||||
};
|
||||
|
||||
lsp::Position {
|
||||
line: (line_number as u32).saturating_sub(1),
|
||||
character: column_offset_encoded as u32,
|
||||
}
|
||||
}
|
||||
|
||||
/// Checks whether a file name is from macro invocation and does not refer to an actual file.
|
||||
fn is_dummy_macro_file(file_name: &str) -> bool {
|
||||
file_name.starts_with('<') && file_name.ends_with('>')
|
||||
}
|
||||
|
||||
/// Extracts a suitable "primary" location from a rustc diagnostic.
|
||||
///
|
||||
/// This takes locations pointing into the standard library, or generally outside the current
|
||||
/// workspace into account and tries to avoid those, in case macros are involved.
|
||||
fn primary_location(worktree_root: &Path, span: &DiagnosticSpan) -> lsp::Location {
|
||||
let span_stack = std::iter::successors(Some(span), |span| Some(&span.expansion.as_ref()?.span));
|
||||
for span in span_stack.clone() {
|
||||
let abs_path = worktree_root.join(&span.file_name);
|
||||
if !is_dummy_macro_file(&span.file_name) && abs_path.starts_with(worktree_root) {
|
||||
return location(worktree_root, span);
|
||||
}
|
||||
}
|
||||
|
||||
// Fall back to the outermost macro invocation if no suitable span comes up.
|
||||
let last_span = span_stack.last().unwrap();
|
||||
location(worktree_root, last_span)
|
||||
}
|
||||
@@ -1,4 +1,3 @@
|
||||
mod cargo;
|
||||
pub mod items;
|
||||
mod toolbar_controls;
|
||||
|
||||
@@ -8,18 +7,14 @@ mod diagnostic_renderer;
|
||||
mod diagnostics_tests;
|
||||
|
||||
use anyhow::Result;
|
||||
use cargo::{
|
||||
FetchStatus, FetchUpdate, cargo_diagnostics_sources, fetch_worktree_diagnostics,
|
||||
is_outdated_cargo_fetch_diagnostic, map_rust_diagnostic_to_lsp, next_cargo_fetch_generation,
|
||||
url_from_abs_path,
|
||||
};
|
||||
use collections::{BTreeSet, HashMap, HashSet};
|
||||
use collections::{BTreeSet, HashMap};
|
||||
use diagnostic_renderer::DiagnosticBlock;
|
||||
use editor::{
|
||||
DEFAULT_MULTIBUFFER_CONTEXT, Editor, EditorEvent, ExcerptRange, MultiBuffer, PathKey,
|
||||
display_map::{BlockPlacement, BlockProperties, BlockStyle, CustomBlockId},
|
||||
scroll::Autoscroll,
|
||||
};
|
||||
use futures::future::join_all;
|
||||
use gpui::{
|
||||
AnyElement, AnyView, App, AsyncApp, Context, Entity, EventEmitter, FocusHandle, Focusable,
|
||||
Global, InteractiveElement, IntoElement, ParentElement, Render, SharedString, Styled,
|
||||
@@ -28,10 +23,10 @@ use gpui::{
|
||||
use language::{
|
||||
Bias, Buffer, BufferRow, BufferSnapshot, DiagnosticEntry, Point, ToTreeSitterPoint,
|
||||
};
|
||||
use lsp::{DiagnosticSeverity, LanguageServerId};
|
||||
use lsp::DiagnosticSeverity;
|
||||
use project::{
|
||||
DiagnosticSummary, Project, ProjectPath, Worktree,
|
||||
lsp_store::rust_analyzer_ext::{CARGO_DIAGNOSTICS_SOURCE_NAME, RUST_ANALYZER_NAME},
|
||||
DiagnosticSummary, Project, ProjectPath,
|
||||
lsp_store::rust_analyzer_ext::{cancel_flycheck, run_flycheck},
|
||||
project_settings::ProjectSettings,
|
||||
};
|
||||
use settings::Settings;
|
||||
@@ -84,8 +79,9 @@ pub(crate) struct ProjectDiagnosticsEditor {
|
||||
}
|
||||
|
||||
struct CargoDiagnosticsFetchState {
|
||||
task: Option<Task<()>>,
|
||||
rust_analyzer: Option<LanguageServerId>,
|
||||
fetch_task: Option<Task<()>>,
|
||||
cancel_task: Option<Task<()>>,
|
||||
diagnostic_sources: Arc<Vec<ProjectPath>>,
|
||||
}
|
||||
|
||||
impl EventEmitter<EditorEvent> for ProjectDiagnosticsEditor {}
|
||||
@@ -252,8 +248,9 @@ impl ProjectDiagnosticsEditor {
|
||||
paths_to_update: Default::default(),
|
||||
update_excerpts_task: None,
|
||||
cargo_diagnostics_fetch: CargoDiagnosticsFetchState {
|
||||
task: None,
|
||||
rust_analyzer: None,
|
||||
fetch_task: None,
|
||||
cancel_task: None,
|
||||
diagnostic_sources: Arc::new(Vec::new()),
|
||||
},
|
||||
_subscription: project_event_subscription,
|
||||
};
|
||||
@@ -346,7 +343,7 @@ impl ProjectDiagnosticsEditor {
|
||||
.fetch_cargo_diagnostics();
|
||||
|
||||
if fetch_cargo_diagnostics {
|
||||
if self.cargo_diagnostics_fetch.task.is_some() {
|
||||
if self.cargo_diagnostics_fetch.fetch_task.is_some() {
|
||||
self.stop_cargo_diagnostics_fetch(cx);
|
||||
} else {
|
||||
self.update_all_diagnostics(window, cx);
|
||||
@@ -375,300 +372,63 @@ impl ProjectDiagnosticsEditor {
|
||||
}
|
||||
|
||||
fn update_all_diagnostics(&mut self, window: &mut Window, cx: &mut Context<Self>) {
|
||||
let cargo_diagnostics_sources = cargo_diagnostics_sources(self, cx);
|
||||
let cargo_diagnostics_sources = self.cargo_diagnostics_sources(cx);
|
||||
if cargo_diagnostics_sources.is_empty() {
|
||||
self.update_all_excerpts(window, cx);
|
||||
} else {
|
||||
self.fetch_cargo_diagnostics(Arc::new(cargo_diagnostics_sources), window, cx);
|
||||
self.fetch_cargo_diagnostics(Arc::new(cargo_diagnostics_sources), cx);
|
||||
}
|
||||
}
|
||||
|
||||
fn fetch_cargo_diagnostics(
|
||||
&mut self,
|
||||
diagnostics_sources: Arc<Vec<Entity<Worktree>>>,
|
||||
window: &mut Window,
|
||||
diagnostics_sources: Arc<Vec<ProjectPath>>,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
self.cargo_diagnostics_fetch.task = Some(cx.spawn_in(window, async move |editor, cx| {
|
||||
let rust_analyzer_server = editor
|
||||
.update(cx, |editor, cx| {
|
||||
editor
|
||||
.project
|
||||
.read(cx)
|
||||
.language_server_with_name(RUST_ANALYZER_NAME, cx)
|
||||
})
|
||||
.ok();
|
||||
let rust_analyzer_server = match rust_analyzer_server {
|
||||
Some(rust_analyzer_server) => rust_analyzer_server.await,
|
||||
None => None,
|
||||
};
|
||||
let project = self.project.clone();
|
||||
self.cargo_diagnostics_fetch.cancel_task = None;
|
||||
self.cargo_diagnostics_fetch.fetch_task = None;
|
||||
self.cargo_diagnostics_fetch.diagnostic_sources = diagnostics_sources.clone();
|
||||
if self.cargo_diagnostics_fetch.diagnostic_sources.is_empty() {
|
||||
return;
|
||||
}
|
||||
|
||||
let mut worktree_diagnostics_tasks = Vec::new();
|
||||
let mut paths_with_reported_cargo_diagnostics = HashSet::default();
|
||||
if let Some(rust_analyzer_server) = rust_analyzer_server {
|
||||
let can_continue = editor
|
||||
.update(cx, |editor, cx| {
|
||||
editor.cargo_diagnostics_fetch.rust_analyzer = Some(rust_analyzer_server);
|
||||
let status_inserted =
|
||||
editor
|
||||
.project
|
||||
.read(cx)
|
||||
.lsp_store()
|
||||
.update(cx, |lsp_store, cx| {
|
||||
if let Some(rust_analyzer_status) = lsp_store
|
||||
.language_server_statuses
|
||||
.get_mut(&rust_analyzer_server)
|
||||
{
|
||||
rust_analyzer_status
|
||||
.progress_tokens
|
||||
.insert(fetch_cargo_diagnostics_token());
|
||||
paths_with_reported_cargo_diagnostics.extend(editor.diagnostics.iter().filter_map(|(buffer_id, diagnostics)| {
|
||||
if diagnostics.iter().any(|d| d.diagnostic.source.as_deref() == Some(CARGO_DIAGNOSTICS_SOURCE_NAME)) {
|
||||
Some(*buffer_id)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}).filter_map(|buffer_id| {
|
||||
let buffer = lsp_store.buffer_store().read(cx).get(buffer_id)?;
|
||||
let path = buffer.read(cx).file()?.as_local()?.abs_path(cx);
|
||||
Some(url_from_abs_path(&path))
|
||||
}));
|
||||
true
|
||||
} else {
|
||||
false
|
||||
}
|
||||
});
|
||||
if status_inserted {
|
||||
editor.update_cargo_fetch_status(FetchStatus::Started, cx);
|
||||
next_cargo_fetch_generation();
|
||||
true
|
||||
} else {
|
||||
false
|
||||
}
|
||||
self.cargo_diagnostics_fetch.fetch_task = Some(cx.spawn(async move |editor, cx| {
|
||||
let mut fetch_tasks = Vec::new();
|
||||
for buffer_path in diagnostics_sources.iter().cloned() {
|
||||
if cx
|
||||
.update(|cx| {
|
||||
fetch_tasks.push(run_flycheck(project.clone(), buffer_path, cx));
|
||||
})
|
||||
.unwrap_or(false);
|
||||
|
||||
if can_continue {
|
||||
for worktree in diagnostics_sources.iter() {
|
||||
if let Some(((_task, worktree_diagnostics), worktree_root)) = cx
|
||||
.update(|_, cx| {
|
||||
let worktree_root = worktree.read(cx).abs_path();
|
||||
log::info!("Fetching cargo diagnostics for {worktree_root:?}");
|
||||
fetch_worktree_diagnostics(&worktree_root, cx)
|
||||
.zip(Some(worktree_root))
|
||||
})
|
||||
.ok()
|
||||
.flatten()
|
||||
{
|
||||
let editor = editor.clone();
|
||||
worktree_diagnostics_tasks.push(cx.spawn(async move |cx| {
|
||||
let _task = _task;
|
||||
let mut file_diagnostics = HashMap::default();
|
||||
let mut diagnostics_total = 0;
|
||||
let mut updated_urls = HashSet::default();
|
||||
while let Ok(fetch_update) = worktree_diagnostics.recv().await {
|
||||
match fetch_update {
|
||||
FetchUpdate::Diagnostic(diagnostic) => {
|
||||
for (url, diagnostic) in map_rust_diagnostic_to_lsp(
|
||||
&worktree_root,
|
||||
&diagnostic,
|
||||
) {
|
||||
let file_diagnostics = file_diagnostics
|
||||
.entry(url)
|
||||
.or_insert_with(Vec::<lsp::Diagnostic>::new);
|
||||
let i = file_diagnostics
|
||||
.binary_search_by(|probe| {
|
||||
probe.range.start.cmp(&diagnostic.range.start)
|
||||
.then(probe.range.end.cmp(&diagnostic.range.end))
|
||||
.then(Ordering::Greater)
|
||||
})
|
||||
.unwrap_or_else(|i| i);
|
||||
file_diagnostics.insert(i, diagnostic);
|
||||
}
|
||||
|
||||
let file_changed = file_diagnostics.len() > 1;
|
||||
if file_changed {
|
||||
if editor
|
||||
.update_in(cx, |editor, window, cx| {
|
||||
editor
|
||||
.project
|
||||
.read(cx)
|
||||
.lsp_store()
|
||||
.update(cx, |lsp_store, cx| {
|
||||
for (uri, mut diagnostics) in
|
||||
file_diagnostics.drain()
|
||||
{
|
||||
diagnostics.dedup();
|
||||
diagnostics_total += diagnostics.len();
|
||||
updated_urls.insert(uri.clone());
|
||||
|
||||
lsp_store.merge_diagnostics(
|
||||
rust_analyzer_server,
|
||||
lsp::PublishDiagnosticsParams {
|
||||
uri,
|
||||
diagnostics,
|
||||
version: None,
|
||||
},
|
||||
&[],
|
||||
|diagnostic, _| {
|
||||
!is_outdated_cargo_fetch_diagnostic(diagnostic)
|
||||
},
|
||||
cx,
|
||||
)?;
|
||||
}
|
||||
anyhow::Ok(())
|
||||
})?;
|
||||
editor.update_all_excerpts(window, cx);
|
||||
anyhow::Ok(())
|
||||
})
|
||||
.ok()
|
||||
.transpose()
|
||||
.ok()
|
||||
.flatten()
|
||||
.is_none()
|
||||
{
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
FetchUpdate::Progress(message) => {
|
||||
if editor
|
||||
.update(cx, |editor, cx| {
|
||||
editor.update_cargo_fetch_status(
|
||||
FetchStatus::Progress { message },
|
||||
cx,
|
||||
);
|
||||
})
|
||||
.is_err()
|
||||
{
|
||||
return updated_urls;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
editor
|
||||
.update_in(cx, |editor, window, cx| {
|
||||
editor
|
||||
.project
|
||||
.read(cx)
|
||||
.lsp_store()
|
||||
.update(cx, |lsp_store, cx| {
|
||||
for (uri, mut diagnostics) in
|
||||
file_diagnostics.drain()
|
||||
{
|
||||
diagnostics.dedup();
|
||||
diagnostics_total += diagnostics.len();
|
||||
updated_urls.insert(uri.clone());
|
||||
|
||||
lsp_store.merge_diagnostics(
|
||||
rust_analyzer_server,
|
||||
lsp::PublishDiagnosticsParams {
|
||||
uri,
|
||||
diagnostics,
|
||||
version: None,
|
||||
},
|
||||
&[],
|
||||
|diagnostic, _| {
|
||||
!is_outdated_cargo_fetch_diagnostic(diagnostic)
|
||||
},
|
||||
cx,
|
||||
)?;
|
||||
}
|
||||
anyhow::Ok(())
|
||||
})?;
|
||||
editor.update_all_excerpts(window, cx);
|
||||
anyhow::Ok(())
|
||||
})
|
||||
.ok();
|
||||
log::info!("Fetched {diagnostics_total} cargo diagnostics for worktree {worktree_root:?}");
|
||||
updated_urls
|
||||
}));
|
||||
}
|
||||
}
|
||||
} else {
|
||||
log::info!(
|
||||
"No rust-analyzer language server found, skipping diagnostics fetch"
|
||||
);
|
||||
.is_err()
|
||||
{
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
let updated_urls = futures::future::join_all(worktree_diagnostics_tasks).await.into_iter().flatten().collect();
|
||||
if let Some(rust_analyzer_server) = rust_analyzer_server {
|
||||
let _ = join_all(fetch_tasks).await;
|
||||
editor
|
||||
.update_in(cx, |editor, window, cx| {
|
||||
editor
|
||||
.project
|
||||
.read(cx)
|
||||
.lsp_store()
|
||||
.update(cx, |lsp_store, cx| {
|
||||
for uri_to_cleanup in paths_with_reported_cargo_diagnostics.difference(&updated_urls).cloned() {
|
||||
lsp_store.merge_diagnostics(
|
||||
rust_analyzer_server,
|
||||
lsp::PublishDiagnosticsParams {
|
||||
uri: uri_to_cleanup,
|
||||
diagnostics: Vec::new(),
|
||||
version: None,
|
||||
},
|
||||
&[],
|
||||
|diagnostic, _| {
|
||||
!is_outdated_cargo_fetch_diagnostic(diagnostic)
|
||||
},
|
||||
cx,
|
||||
).ok();
|
||||
}
|
||||
});
|
||||
editor.update_all_excerpts(window, cx);
|
||||
|
||||
editor.stop_cargo_diagnostics_fetch(cx);
|
||||
cx.notify();
|
||||
.update(cx, |editor, _| {
|
||||
editor.cargo_diagnostics_fetch.fetch_task = None;
|
||||
})
|
||||
.ok();
|
||||
}
|
||||
}));
|
||||
}
|
||||
|
||||
fn update_cargo_fetch_status(&self, status: FetchStatus, cx: &mut App) {
|
||||
let Some(rust_analyzer) = self.cargo_diagnostics_fetch.rust_analyzer else {
|
||||
return;
|
||||
};
|
||||
|
||||
let work_done = match status {
|
||||
FetchStatus::Started => lsp::WorkDoneProgress::Begin(lsp::WorkDoneProgressBegin {
|
||||
title: "cargo".to_string(),
|
||||
cancellable: None,
|
||||
message: Some("Fetching cargo diagnostics".to_string()),
|
||||
percentage: None,
|
||||
}),
|
||||
FetchStatus::Progress { message } => {
|
||||
lsp::WorkDoneProgress::Report(lsp::WorkDoneProgressReport {
|
||||
message: Some(message),
|
||||
cancellable: None,
|
||||
percentage: None,
|
||||
})
|
||||
}
|
||||
FetchStatus::Finished => {
|
||||
lsp::WorkDoneProgress::End(lsp::WorkDoneProgressEnd { message: None })
|
||||
}
|
||||
};
|
||||
let progress = lsp::ProgressParams {
|
||||
token: lsp::NumberOrString::String(fetch_cargo_diagnostics_token()),
|
||||
value: lsp::ProgressParamsValue::WorkDone(work_done),
|
||||
};
|
||||
|
||||
self.project
|
||||
.read(cx)
|
||||
.lsp_store()
|
||||
.update(cx, |lsp_store, cx| {
|
||||
lsp_store.on_lsp_progress(progress, rust_analyzer, None, cx)
|
||||
});
|
||||
}
|
||||
|
||||
fn stop_cargo_diagnostics_fetch(&mut self, cx: &mut App) {
|
||||
self.update_cargo_fetch_status(FetchStatus::Finished, cx);
|
||||
self.cargo_diagnostics_fetch.task = None;
|
||||
log::info!("Finished fetching cargo diagnostics");
|
||||
self.cargo_diagnostics_fetch.fetch_task = None;
|
||||
let mut cancel_gasks = Vec::new();
|
||||
for buffer_path in std::mem::take(&mut self.cargo_diagnostics_fetch.diagnostic_sources)
|
||||
.iter()
|
||||
.cloned()
|
||||
{
|
||||
cancel_gasks.push(cancel_flycheck(self.project.clone(), buffer_path, cx));
|
||||
}
|
||||
|
||||
self.cargo_diagnostics_fetch.cancel_task = Some(cx.background_spawn(async move {
|
||||
let _ = join_all(cancel_gasks).await;
|
||||
log::info!("Finished fetching cargo diagnostics");
|
||||
}));
|
||||
}
|
||||
|
||||
/// Enqueue an update of all excerpts. Updates all paths that either
|
||||
@@ -897,6 +657,30 @@ impl ProjectDiagnosticsEditor {
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
pub fn cargo_diagnostics_sources(&self, cx: &App) -> Vec<ProjectPath> {
|
||||
let fetch_cargo_diagnostics = ProjectSettings::get_global(cx)
|
||||
.diagnostics
|
||||
.fetch_cargo_diagnostics();
|
||||
if !fetch_cargo_diagnostics {
|
||||
return Vec::new();
|
||||
}
|
||||
self.project
|
||||
.read(cx)
|
||||
.worktrees(cx)
|
||||
.filter_map(|worktree| {
|
||||
let _cargo_toml_entry = worktree.read(cx).entry_for_path("Cargo.toml")?;
|
||||
let rust_file_entry = worktree.read(cx).entries(false, 0).find(|entry| {
|
||||
entry
|
||||
.path
|
||||
.extension()
|
||||
.and_then(|extension| extension.to_str())
|
||||
== Some("rs")
|
||||
})?;
|
||||
self.project.read(cx).path_for_entry(rust_file_entry.id, cx)
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
}
|
||||
|
||||
impl Focusable for ProjectDiagnosticsEditor {
|
||||
@@ -1286,7 +1070,3 @@ fn is_line_blank_or_indented_less(
|
||||
let line_indent = snapshot.line_indent_for_row(row);
|
||||
line_indent.is_line_blank() || line_indent.len(tab_size) < indent_level
|
||||
}
|
||||
|
||||
fn fetch_cargo_diagnostics_token() -> String {
|
||||
"fetch_cargo_diagnostics".to_string()
|
||||
}
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::cargo::cargo_diagnostics_sources;
|
||||
use crate::{ProjectDiagnosticsEditor, ToggleDiagnosticsRefresh};
|
||||
use gpui::{Context, Entity, EventEmitter, ParentElement, Render, WeakEntity, Window};
|
||||
use ui::prelude::*;
|
||||
@@ -16,11 +15,9 @@ impl Render for ToolbarControls {
|
||||
let mut include_warnings = false;
|
||||
let mut has_stale_excerpts = false;
|
||||
let mut is_updating = false;
|
||||
let cargo_diagnostics_sources = Arc::new(
|
||||
self.diagnostics()
|
||||
.map(|editor| cargo_diagnostics_sources(editor.read(cx), cx))
|
||||
.unwrap_or_default(),
|
||||
);
|
||||
let cargo_diagnostics_sources = Arc::new(self.diagnostics().map_or(Vec::new(), |editor| {
|
||||
editor.read(cx).cargo_diagnostics_sources(cx)
|
||||
}));
|
||||
let fetch_cargo_diagnostics = !cargo_diagnostics_sources.is_empty();
|
||||
|
||||
if let Some(editor) = self.diagnostics() {
|
||||
@@ -28,7 +25,7 @@ impl Render for ToolbarControls {
|
||||
include_warnings = diagnostics.include_warnings;
|
||||
has_stale_excerpts = !diagnostics.paths_to_update.is_empty();
|
||||
is_updating = if fetch_cargo_diagnostics {
|
||||
diagnostics.cargo_diagnostics_fetch.task.is_some()
|
||||
diagnostics.cargo_diagnostics_fetch.fetch_task.is_some()
|
||||
} else {
|
||||
diagnostics.update_excerpts_task.is_some()
|
||||
|| diagnostics
|
||||
@@ -93,7 +90,6 @@ impl Render for ToolbarControls {
|
||||
if fetch_cargo_diagnostics {
|
||||
diagnostics.fetch_cargo_diagnostics(
|
||||
cargo_diagnostics_sources,
|
||||
window,
|
||||
cx,
|
||||
);
|
||||
} else {
|
||||
|
||||
@@ -249,7 +249,9 @@ actions!(
|
||||
ApplyDiffHunk,
|
||||
Backspace,
|
||||
Cancel,
|
||||
CancelFlycheck,
|
||||
CancelLanguageServerWork,
|
||||
ClearFlycheck,
|
||||
ConfirmRename,
|
||||
ConfirmCompletionInsert,
|
||||
ConfirmCompletionReplace,
|
||||
@@ -308,6 +310,7 @@ actions!(
|
||||
GoToImplementation,
|
||||
GoToImplementationSplit,
|
||||
GoToNextChange,
|
||||
GoToParentModule,
|
||||
GoToPreviousChange,
|
||||
GoToPreviousDiagnostic,
|
||||
GoToTypeDefinition,
|
||||
@@ -371,6 +374,7 @@ actions!(
|
||||
RevertFile,
|
||||
ReloadFile,
|
||||
Rewrap,
|
||||
RunFlycheck,
|
||||
ScrollCursorBottom,
|
||||
ScrollCursorCenter,
|
||||
ScrollCursorCenterTopBottom,
|
||||
|
||||
@@ -6874,7 +6874,12 @@ impl Element for EditorElement {
|
||||
// The max scroll position for the top of the window
|
||||
let max_scroll_top = if matches!(
|
||||
snapshot.mode,
|
||||
EditorMode::AutoHeight { .. } | EditorMode::SingleLine { .. }
|
||||
EditorMode::SingleLine { .. }
|
||||
| EditorMode::AutoHeight { .. }
|
||||
| EditorMode::Full {
|
||||
sized_by_content: true,
|
||||
..
|
||||
}
|
||||
) {
|
||||
(max_row - height_in_lines + 1.).max(0.)
|
||||
} else {
|
||||
|
||||
@@ -4,15 +4,20 @@ use anyhow::Context as _;
|
||||
use gpui::{App, AppContext as _, Context, Entity, Window};
|
||||
use language::{Capability, Language, proto::serialize_anchor};
|
||||
use multi_buffer::MultiBuffer;
|
||||
use project::lsp_store::{
|
||||
lsp_ext_command::{DocsUrls, ExpandMacro, ExpandedMacro},
|
||||
rust_analyzer_ext::RUST_ANALYZER_NAME,
|
||||
use project::{
|
||||
ProjectItem,
|
||||
lsp_command::location_link_from_proto,
|
||||
lsp_store::{
|
||||
lsp_ext_command::{DocsUrls, ExpandMacro, ExpandedMacro},
|
||||
rust_analyzer_ext::{RUST_ANALYZER_NAME, cancel_flycheck, clear_flycheck, run_flycheck},
|
||||
},
|
||||
};
|
||||
use rpc::proto;
|
||||
use text::ToPointUtf16;
|
||||
|
||||
use crate::{
|
||||
Editor, ExpandMacroRecursively, OpenDocs, element::register_action,
|
||||
CancelFlycheck, ClearFlycheck, Editor, ExpandMacroRecursively, GoToParentModule,
|
||||
GotoDefinitionKind, OpenDocs, RunFlycheck, element::register_action, hover_links::HoverLink,
|
||||
lsp_ext::find_specific_language_server_in_selection,
|
||||
};
|
||||
|
||||
@@ -30,11 +35,97 @@ pub fn apply_related_actions(editor: &Entity<Editor>, window: &mut Window, cx: &
|
||||
.filter_map(|buffer| buffer.read(cx).language())
|
||||
.any(|language| is_rust_language(language))
|
||||
{
|
||||
register_action(&editor, window, go_to_parent_module);
|
||||
register_action(&editor, window, expand_macro_recursively);
|
||||
register_action(&editor, window, open_docs);
|
||||
register_action(&editor, window, cancel_flycheck_action);
|
||||
register_action(&editor, window, run_flycheck_action);
|
||||
register_action(&editor, window, clear_flycheck_action);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn go_to_parent_module(
|
||||
editor: &mut Editor,
|
||||
_: &GoToParentModule,
|
||||
window: &mut Window,
|
||||
cx: &mut Context<Editor>,
|
||||
) {
|
||||
if editor.selections.count() == 0 {
|
||||
return;
|
||||
}
|
||||
let Some(project) = &editor.project else {
|
||||
return;
|
||||
};
|
||||
|
||||
let server_lookup = find_specific_language_server_in_selection(
|
||||
editor,
|
||||
cx,
|
||||
is_rust_language,
|
||||
RUST_ANALYZER_NAME,
|
||||
);
|
||||
|
||||
let project = project.clone();
|
||||
let lsp_store = project.read(cx).lsp_store();
|
||||
let upstream_client = lsp_store.read(cx).upstream_client();
|
||||
cx.spawn_in(window, async move |editor, cx| {
|
||||
let Some((trigger_anchor, _, server_to_query, buffer)) = server_lookup.await else {
|
||||
return anyhow::Ok(());
|
||||
};
|
||||
|
||||
let location_links = if let Some((client, project_id)) = upstream_client {
|
||||
let buffer_id = buffer.update(cx, |buffer, _| buffer.remote_id())?;
|
||||
|
||||
let request = proto::LspExtGoToParentModule {
|
||||
project_id,
|
||||
buffer_id: buffer_id.to_proto(),
|
||||
position: Some(serialize_anchor(&trigger_anchor.text_anchor)),
|
||||
};
|
||||
let response = client
|
||||
.request(request)
|
||||
.await
|
||||
.context("lsp ext go to parent module proto request")?;
|
||||
futures::future::join_all(
|
||||
response
|
||||
.links
|
||||
.into_iter()
|
||||
.map(|link| location_link_from_proto(link, lsp_store.clone(), cx)),
|
||||
)
|
||||
.await
|
||||
.into_iter()
|
||||
.collect::<anyhow::Result<_>>()
|
||||
.context("go to parent module via collab")?
|
||||
} else {
|
||||
let buffer_snapshot = buffer.update(cx, |buffer, _| buffer.snapshot())?;
|
||||
let position = trigger_anchor.text_anchor.to_point_utf16(&buffer_snapshot);
|
||||
project
|
||||
.update(cx, |project, cx| {
|
||||
project.request_lsp(
|
||||
buffer,
|
||||
project::LanguageServerToQuery::Other(server_to_query),
|
||||
project::lsp_store::lsp_ext_command::GoToParentModule { position },
|
||||
cx,
|
||||
)
|
||||
})?
|
||||
.await
|
||||
.context("go to parent module")?
|
||||
};
|
||||
|
||||
editor
|
||||
.update_in(cx, |editor, window, cx| {
|
||||
editor.navigate_to_hover_links(
|
||||
Some(GotoDefinitionKind::Declaration),
|
||||
location_links.into_iter().map(HoverLink::Text).collect(),
|
||||
false,
|
||||
window,
|
||||
cx,
|
||||
)
|
||||
})?
|
||||
.await?;
|
||||
Ok(())
|
||||
})
|
||||
.detach_and_log_err(cx);
|
||||
}
|
||||
|
||||
pub fn expand_macro_recursively(
|
||||
editor: &mut Editor,
|
||||
_: &ExpandMacroRecursively,
|
||||
@@ -213,3 +304,87 @@ pub fn open_docs(editor: &mut Editor, _: &OpenDocs, window: &mut Window, cx: &mu
|
||||
})
|
||||
.detach_and_log_err(cx);
|
||||
}
|
||||
|
||||
fn cancel_flycheck_action(
|
||||
editor: &mut Editor,
|
||||
_: &CancelFlycheck,
|
||||
_: &mut Window,
|
||||
cx: &mut Context<Editor>,
|
||||
) {
|
||||
let Some(project) = &editor.project else {
|
||||
return;
|
||||
};
|
||||
let Some(buffer_id) = editor
|
||||
.selections
|
||||
.disjoint_anchors()
|
||||
.iter()
|
||||
.find_map(|selection| {
|
||||
let buffer_id = selection.start.buffer_id.or(selection.end.buffer_id)?;
|
||||
let project = project.read(cx);
|
||||
let entry_id = project
|
||||
.buffer_for_id(buffer_id, cx)?
|
||||
.read(cx)
|
||||
.entry_id(cx)?;
|
||||
project.path_for_entry(entry_id, cx)
|
||||
})
|
||||
else {
|
||||
return;
|
||||
};
|
||||
cancel_flycheck(project.clone(), buffer_id, cx).detach_and_log_err(cx);
|
||||
}
|
||||
|
||||
fn run_flycheck_action(
|
||||
editor: &mut Editor,
|
||||
_: &RunFlycheck,
|
||||
_: &mut Window,
|
||||
cx: &mut Context<Editor>,
|
||||
) {
|
||||
let Some(project) = &editor.project else {
|
||||
return;
|
||||
};
|
||||
let Some(buffer_id) = editor
|
||||
.selections
|
||||
.disjoint_anchors()
|
||||
.iter()
|
||||
.find_map(|selection| {
|
||||
let buffer_id = selection.start.buffer_id.or(selection.end.buffer_id)?;
|
||||
let project = project.read(cx);
|
||||
let entry_id = project
|
||||
.buffer_for_id(buffer_id, cx)?
|
||||
.read(cx)
|
||||
.entry_id(cx)?;
|
||||
project.path_for_entry(entry_id, cx)
|
||||
})
|
||||
else {
|
||||
return;
|
||||
};
|
||||
run_flycheck(project.clone(), buffer_id, cx).detach_and_log_err(cx);
|
||||
}
|
||||
|
||||
fn clear_flycheck_action(
|
||||
editor: &mut Editor,
|
||||
_: &ClearFlycheck,
|
||||
_: &mut Window,
|
||||
cx: &mut Context<Editor>,
|
||||
) {
|
||||
let Some(project) = &editor.project else {
|
||||
return;
|
||||
};
|
||||
let Some(buffer_id) = editor
|
||||
.selections
|
||||
.disjoint_anchors()
|
||||
.iter()
|
||||
.find_map(|selection| {
|
||||
let buffer_id = selection.start.buffer_id.or(selection.end.buffer_id)?;
|
||||
let project = project.read(cx);
|
||||
let entry_id = project
|
||||
.buffer_for_id(buffer_id, cx)?
|
||||
.read(cx)
|
||||
.entry_id(cx)?;
|
||||
project.path_for_entry(entry_id, cx)
|
||||
})
|
||||
else {
|
||||
return;
|
||||
};
|
||||
clear_flycheck(project.clone(), buffer_id, cx).detach_and_log_err(cx);
|
||||
}
|
||||
|
||||
@@ -424,7 +424,13 @@ pub fn init(cx: &mut App) -> Arc<AgentAppState> {
|
||||
prompt_store::init(cx);
|
||||
let stdout_is_a_pty = false;
|
||||
let prompt_builder = PromptBuilder::load(fs.clone(), stdout_is_a_pty, cx);
|
||||
agent::init(fs.clone(), client.clone(), prompt_builder.clone(), cx);
|
||||
agent::init(
|
||||
fs.clone(),
|
||||
client.clone(),
|
||||
prompt_builder.clone(),
|
||||
languages.clone(),
|
||||
cx,
|
||||
);
|
||||
assistant_tools::init(client.http_client(), cx);
|
||||
|
||||
SettingsStore::update_global(cx, |store, cx| {
|
||||
|
||||
@@ -121,6 +121,12 @@ pub trait Extension: Send + Sync + 'static {
|
||||
project: Arc<dyn ProjectDelegate>,
|
||||
) -> Result<Command>;
|
||||
|
||||
async fn context_server_configuration(
|
||||
&self,
|
||||
context_server_id: Arc<str>,
|
||||
project: Arc<dyn ProjectDelegate>,
|
||||
) -> Result<Option<ContextServerConfiguration>>;
|
||||
|
||||
async fn suggest_docs_packages(&self, provider: Arc<str>) -> Result<Vec<String>>;
|
||||
|
||||
async fn index_docs(
|
||||
|
||||
@@ -1,5 +1,9 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use gpui::{App, AppContext as _, Context, Entity, EventEmitter, Global};
|
||||
|
||||
use crate::ExtensionManifest;
|
||||
|
||||
pub fn init(cx: &mut App) {
|
||||
let extension_events = cx.new(ExtensionEvents::new);
|
||||
cx.set_global(GlobalExtensionEvents(extension_events));
|
||||
@@ -31,7 +35,9 @@ impl ExtensionEvents {
|
||||
|
||||
#[derive(Clone)]
|
||||
pub enum Event {
|
||||
ExtensionInstalled(Arc<ExtensionManifest>),
|
||||
ExtensionsInstalledChanged,
|
||||
ConfigureExtensionRequested(Arc<ExtensionManifest>),
|
||||
}
|
||||
|
||||
impl EventEmitter<Event> for ExtensionEvents {}
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
mod context_server;
|
||||
mod lsp;
|
||||
mod slash_command;
|
||||
|
||||
use std::ops::Range;
|
||||
|
||||
pub use context_server::*;
|
||||
pub use lsp::*;
|
||||
pub use slash_command::*;
|
||||
|
||||
|
||||
10
crates/extension/src/types/context_server.rs
Normal file
10
crates/extension/src/types/context_server.rs
Normal file
@@ -0,0 +1,10 @@
|
||||
/// Configuration for a context server.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ContextServerConfiguration {
|
||||
/// Installation instructions for the user.
|
||||
pub installation_instructions: String,
|
||||
/// Default settings for the context server.
|
||||
pub default_settings: String,
|
||||
/// JSON schema describing server settings.
|
||||
pub settings_schema: serde_json::Value,
|
||||
}
|
||||
@@ -18,6 +18,7 @@ pub use wit::{
|
||||
CodeLabel, CodeLabelSpan, CodeLabelSpanLiteral, Command, DownloadedFileType, EnvVars,
|
||||
KeyValueStore, LanguageServerInstallationStatus, Project, Range, Worktree, download_file,
|
||||
make_file_executable,
|
||||
zed::extension::context_server::ContextServerConfiguration,
|
||||
zed::extension::github::{
|
||||
GithubRelease, GithubReleaseAsset, GithubReleaseOptions, github_release_by_tag_name,
|
||||
latest_github_release,
|
||||
@@ -159,6 +160,15 @@ pub trait Extension: Send + Sync {
|
||||
Err("`context_server_command` not implemented".to_string())
|
||||
}
|
||||
|
||||
/// Returns the configuration options for the specified context server.
|
||||
fn context_server_configuration(
|
||||
&mut self,
|
||||
_context_server_id: &ContextServerId,
|
||||
_project: &Project,
|
||||
) -> Result<Option<ContextServerConfiguration>> {
|
||||
Ok(None)
|
||||
}
|
||||
|
||||
/// Returns a list of package names as suggestions to be included in the
|
||||
/// search results of the `/docs` slash command.
|
||||
///
|
||||
@@ -342,6 +352,14 @@ impl wit::Guest for Component {
|
||||
extension().context_server_command(&context_server_id, project)
|
||||
}
|
||||
|
||||
fn context_server_configuration(
|
||||
context_server_id: String,
|
||||
project: &Project,
|
||||
) -> Result<Option<ContextServerConfiguration>, String> {
|
||||
let context_server_id = ContextServerId(context_server_id);
|
||||
extension().context_server_configuration(&context_server_id, project)
|
||||
}
|
||||
|
||||
fn suggest_docs_packages(provider: String) -> Result<Vec<String>, String> {
|
||||
extension().suggest_docs_packages(provider)
|
||||
}
|
||||
|
||||
11
crates/extension_api/wit/since_v0.5.0/context-server.wit
Normal file
11
crates/extension_api/wit/since_v0.5.0/context-server.wit
Normal file
@@ -0,0 +1,11 @@
|
||||
interface context-server {
|
||||
///
|
||||
record context-server-configuration {
|
||||
///
|
||||
installation-instructions: string,
|
||||
///
|
||||
settings-schema: string,
|
||||
///
|
||||
default-settings: string,
|
||||
}
|
||||
}
|
||||
@@ -1,6 +1,7 @@
|
||||
package zed:extension;
|
||||
|
||||
world extension {
|
||||
import context-server;
|
||||
import github;
|
||||
import http-client;
|
||||
import platform;
|
||||
@@ -8,6 +9,7 @@ world extension {
|
||||
import nodejs;
|
||||
|
||||
use common.{env-vars, range};
|
||||
use context-server.{context-server-configuration};
|
||||
use lsp.{completion, symbol};
|
||||
use process.{command};
|
||||
use slash-command.{slash-command, slash-command-argument-completion, slash-command-output};
|
||||
@@ -139,6 +141,9 @@ world extension {
|
||||
/// Returns the command used to start up a context server.
|
||||
export context-server-command: func(context-server-id: string, project: borrow<project>) -> result<command, string>;
|
||||
|
||||
/// Returns the configuration for a context server.
|
||||
export context-server-configuration: func(context-server-id: string, project: borrow<project>) -> result<option<context-server-configuration>, string>;
|
||||
|
||||
/// Returns a list of packages as suggestions to be included in the `/docs`
|
||||
/// search results.
|
||||
///
|
||||
|
||||
@@ -431,6 +431,13 @@ impl ExtensionStore {
|
||||
.filter_map(|extension| extension.dev.then_some(&extension.manifest))
|
||||
}
|
||||
|
||||
pub fn extension_manifest_for_id(&self, extension_id: &str) -> Option<&Arc<ExtensionManifest>> {
|
||||
self.extension_index
|
||||
.extensions
|
||||
.get(extension_id)
|
||||
.map(|extension| &extension.manifest)
|
||||
}
|
||||
|
||||
/// Returns the names of themes provided by extensions.
|
||||
pub fn extension_themes<'a>(
|
||||
&'a self,
|
||||
@@ -744,8 +751,18 @@ impl ExtensionStore {
|
||||
.await;
|
||||
|
||||
if let ExtensionOperation::Install = operation {
|
||||
this.update( cx, |_, cx| {
|
||||
cx.emit(Event::ExtensionInstalled(extension_id));
|
||||
this.update( cx, |this, cx| {
|
||||
cx.emit(Event::ExtensionInstalled(extension_id.clone()));
|
||||
if let Some(events) = ExtensionEvents::try_global(cx) {
|
||||
if let Some(manifest) = this.extension_manifest_for_id(&extension_id) {
|
||||
events.update(cx, |this, cx| {
|
||||
this.emit(
|
||||
extension::Event::ExtensionInstalled(manifest.clone()),
|
||||
cx,
|
||||
)
|
||||
});
|
||||
}
|
||||
}
|
||||
})
|
||||
.ok();
|
||||
}
|
||||
@@ -935,6 +952,17 @@ impl ExtensionStore {
|
||||
.await?;
|
||||
|
||||
this.update(cx, |this, cx| this.reload(None, cx))?.await;
|
||||
this.update(cx, |this, cx| {
|
||||
cx.emit(Event::ExtensionInstalled(extension_id.clone()));
|
||||
if let Some(events) = ExtensionEvents::try_global(cx) {
|
||||
if let Some(manifest) = this.extension_manifest_for_id(&extension_id) {
|
||||
events.update(cx, |this, cx| {
|
||||
this.emit(extension::Event::ExtensionInstalled(manifest.clone()), cx)
|
||||
});
|
||||
}
|
||||
}
|
||||
})?;
|
||||
|
||||
Ok(())
|
||||
})
|
||||
}
|
||||
|
||||
@@ -4,8 +4,9 @@ use crate::ExtensionManifest;
|
||||
use anyhow::{Context as _, Result, anyhow, bail};
|
||||
use async_trait::async_trait;
|
||||
use extension::{
|
||||
CodeLabel, Command, Completion, ExtensionHostProxy, KeyValueStoreDelegate, ProjectDelegate,
|
||||
SlashCommand, SlashCommandArgumentCompletion, SlashCommandOutput, Symbol, WorktreeDelegate,
|
||||
CodeLabel, Command, Completion, ContextServerConfiguration, ExtensionHostProxy,
|
||||
KeyValueStoreDelegate, ProjectDelegate, SlashCommand, SlashCommandArgumentCompletion,
|
||||
SlashCommandOutput, Symbol, WorktreeDelegate,
|
||||
};
|
||||
use fs::{Fs, normalize_path};
|
||||
use futures::future::LocalBoxFuture;
|
||||
@@ -306,6 +307,33 @@ impl extension::Extension for WasmExtension {
|
||||
.await
|
||||
}
|
||||
|
||||
async fn context_server_configuration(
|
||||
&self,
|
||||
context_server_id: Arc<str>,
|
||||
project: Arc<dyn ProjectDelegate>,
|
||||
) -> Result<Option<ContextServerConfiguration>> {
|
||||
self.call(|extension, store| {
|
||||
async move {
|
||||
let project_resource = store.data_mut().table().push(project)?;
|
||||
let Some(configuration) = extension
|
||||
.call_context_server_configuration(
|
||||
store,
|
||||
context_server_id.clone(),
|
||||
project_resource,
|
||||
)
|
||||
.await?
|
||||
.map_err(|err| anyhow!("{err}"))?
|
||||
else {
|
||||
return Ok(None);
|
||||
};
|
||||
|
||||
Ok(Some(configuration.try_into()?))
|
||||
}
|
||||
.boxed()
|
||||
})
|
||||
.await
|
||||
}
|
||||
|
||||
async fn suggest_docs_packages(&self, provider: Arc<str>) -> Result<Vec<String>> {
|
||||
self.call(|extension, store| {
|
||||
async move {
|
||||
|
||||
@@ -25,6 +25,7 @@ use wasmtime::{
|
||||
pub use latest::CodeLabelSpanLiteral;
|
||||
pub use latest::{
|
||||
CodeLabel, CodeLabelSpan, Command, ExtensionProject, Range, SlashCommand,
|
||||
zed::extension::context_server::ContextServerConfiguration,
|
||||
zed::extension::lsp::{
|
||||
Completion, CompletionKind, CompletionLabelDetails, InsertTextFormat, Symbol, SymbolKind,
|
||||
},
|
||||
@@ -726,6 +727,29 @@ impl Extension {
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn call_context_server_configuration(
|
||||
&self,
|
||||
store: &mut Store<WasmState>,
|
||||
context_server_id: Arc<str>,
|
||||
project: Resource<ExtensionProject>,
|
||||
) -> Result<Result<Option<ContextServerConfiguration>, String>> {
|
||||
match self {
|
||||
Extension::V0_5_0(ext) => {
|
||||
ext.call_context_server_configuration(store, &context_server_id, project)
|
||||
.await
|
||||
}
|
||||
Extension::V0_0_1(_)
|
||||
| Extension::V0_0_4(_)
|
||||
| Extension::V0_0_6(_)
|
||||
| Extension::V0_1_0(_)
|
||||
| Extension::V0_2_0(_)
|
||||
| Extension::V0_3_0(_)
|
||||
| Extension::V0_4_0(_) => Err(anyhow!(
|
||||
"`context_server_configuration` not available prior to v0.5.0"
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn call_suggest_docs_packages(
|
||||
&self,
|
||||
store: &mut Store<WasmState>,
|
||||
|
||||
@@ -247,6 +247,21 @@ impl From<SlashCommandArgumentCompletion> for extension::SlashCommandArgumentCom
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFrom<ContextServerConfiguration> for extension::ContextServerConfiguration {
|
||||
type Error = anyhow::Error;
|
||||
|
||||
fn try_from(value: ContextServerConfiguration) -> Result<Self, Self::Error> {
|
||||
let settings_schema: serde_json::Value = serde_json::from_str(&value.settings_schema)
|
||||
.context("Failed to parse settings_schema")?;
|
||||
|
||||
Ok(Self {
|
||||
installation_instructions: value.installation_instructions,
|
||||
default_settings: value.default_settings,
|
||||
settings_schema,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl HostKeyValueStore for WasmState {
|
||||
async fn insert(
|
||||
&mut self,
|
||||
@@ -610,6 +625,9 @@ impl process::Host for WasmState {
|
||||
#[async_trait]
|
||||
impl slash_command::Host for WasmState {}
|
||||
|
||||
#[async_trait]
|
||||
impl context_server::Host for WasmState {}
|
||||
|
||||
impl ExtensionImports for WasmState {
|
||||
async fn get_settings(
|
||||
&mut self,
|
||||
|
||||
@@ -17,6 +17,7 @@ client.workspace = true
|
||||
collections.workspace = true
|
||||
db.workspace = true
|
||||
editor.workspace = true
|
||||
extension.workspace = true
|
||||
extension_host.workspace = true
|
||||
fs.workspace = true
|
||||
fuzzy.workspace = true
|
||||
|
||||
@@ -246,6 +246,12 @@ fn keywords_by_feature() -> &'static BTreeMap<Feature, Vec<&'static str>> {
|
||||
})
|
||||
}
|
||||
|
||||
struct ExtensionCardButtons {
|
||||
install_or_uninstall: Button,
|
||||
upgrade: Option<Button>,
|
||||
configure: Option<Button>,
|
||||
}
|
||||
|
||||
pub struct ExtensionsPage {
|
||||
workspace: WeakEntity<Workspace>,
|
||||
list: UniformListScrollHandle,
|
||||
@@ -522,6 +528,8 @@ impl ExtensionsPage {
|
||||
|
||||
let repository_url = extension.repository.clone();
|
||||
|
||||
let can_configure = !extension.context_servers.is_empty();
|
||||
|
||||
ExtensionCard::new()
|
||||
.child(
|
||||
h_flex()
|
||||
@@ -568,7 +576,36 @@ impl ExtensionsPage {
|
||||
})
|
||||
.color(Color::Accent)
|
||||
.disabled(matches!(status, ExtensionStatus::Removing)),
|
||||
),
|
||||
)
|
||||
.when(can_configure, |this| {
|
||||
this.child(
|
||||
Button::new(
|
||||
SharedString::from(format!("configure-{}", extension.id)),
|
||||
"Configure",
|
||||
)
|
||||
|
||||
|
||||
.on_click({
|
||||
let manifest = Arc::new(extension.clone());
|
||||
move |_, _, cx| {
|
||||
if let Some(events) =
|
||||
extension::ExtensionEvents::try_global(cx)
|
||||
{
|
||||
events.update(cx, |this, cx| {
|
||||
this.emit(
|
||||
extension::Event::ConfigureExtensionRequested(
|
||||
manifest.clone(),
|
||||
),
|
||||
cx,
|
||||
)
|
||||
});
|
||||
}
|
||||
}
|
||||
})
|
||||
.color(Color::Accent)
|
||||
.disabled(matches!(status, ExtensionStatus::Installing)),
|
||||
)
|
||||
}),
|
||||
),
|
||||
)
|
||||
.child(
|
||||
@@ -629,8 +666,7 @@ impl ExtensionsPage {
|
||||
let has_dev_extension = Self::dev_extension_exists(&extension.id, cx);
|
||||
|
||||
let extension_id = extension.id.clone();
|
||||
let (install_or_uninstall_button, upgrade_button) =
|
||||
self.buttons_for_entry(extension, &status, has_dev_extension, cx);
|
||||
let buttons = self.buttons_for_entry(extension, &status, has_dev_extension, cx);
|
||||
let version = extension.manifest.version.clone();
|
||||
let repository_url = extension.manifest.repository.clone();
|
||||
let authors = extension.manifest.authors.clone();
|
||||
@@ -695,8 +731,9 @@ impl ExtensionsPage {
|
||||
h_flex()
|
||||
.gap_2()
|
||||
.justify_between()
|
||||
.children(upgrade_button)
|
||||
.child(install_or_uninstall_button),
|
||||
.children(buttons.upgrade)
|
||||
.children(buttons.configure)
|
||||
.child(buttons.install_or_uninstall),
|
||||
),
|
||||
)
|
||||
.child(
|
||||
@@ -861,22 +898,35 @@ impl ExtensionsPage {
|
||||
status: &ExtensionStatus,
|
||||
has_dev_extension: bool,
|
||||
cx: &mut Context<Self>,
|
||||
) -> (Button, Option<Button>) {
|
||||
) -> ExtensionCardButtons {
|
||||
let is_compatible =
|
||||
extension_host::is_version_compatible(ReleaseChannel::global(cx), extension);
|
||||
|
||||
if has_dev_extension {
|
||||
// If we have a dev extension for the given extension, just treat it as uninstalled.
|
||||
// The button here is a placeholder, as it won't be interactable anyways.
|
||||
return (
|
||||
Button::new(SharedString::from(extension.id.clone()), "Install"),
|
||||
None,
|
||||
);
|
||||
return ExtensionCardButtons {
|
||||
install_or_uninstall: Button::new(
|
||||
SharedString::from(extension.id.clone()),
|
||||
"Install",
|
||||
),
|
||||
configure: None,
|
||||
upgrade: None,
|
||||
};
|
||||
}
|
||||
|
||||
let is_configurable = extension
|
||||
.manifest
|
||||
.provides
|
||||
.contains(&ExtensionProvides::ContextServers);
|
||||
|
||||
match status.clone() {
|
||||
ExtensionStatus::NotInstalled => (
|
||||
Button::new(SharedString::from(extension.id.clone()), "Install").on_click({
|
||||
ExtensionStatus::NotInstalled => ExtensionCardButtons {
|
||||
install_or_uninstall: Button::new(
|
||||
SharedString::from(extension.id.clone()),
|
||||
"Install",
|
||||
)
|
||||
.on_click({
|
||||
let extension_id = extension.id.clone();
|
||||
move |_, _, cx| {
|
||||
telemetry::event!("Extension Installed");
|
||||
@@ -885,20 +935,41 @@ impl ExtensionsPage {
|
||||
});
|
||||
}
|
||||
}),
|
||||
None,
|
||||
),
|
||||
ExtensionStatus::Installing => (
|
||||
Button::new(SharedString::from(extension.id.clone()), "Install").disabled(true),
|
||||
None,
|
||||
),
|
||||
ExtensionStatus::Upgrading => (
|
||||
Button::new(SharedString::from(extension.id.clone()), "Uninstall").disabled(true),
|
||||
Some(
|
||||
configure: None,
|
||||
upgrade: None,
|
||||
},
|
||||
ExtensionStatus::Installing => ExtensionCardButtons {
|
||||
install_or_uninstall: Button::new(
|
||||
SharedString::from(extension.id.clone()),
|
||||
"Install",
|
||||
)
|
||||
.disabled(true),
|
||||
configure: None,
|
||||
upgrade: None,
|
||||
},
|
||||
ExtensionStatus::Upgrading => ExtensionCardButtons {
|
||||
install_or_uninstall: Button::new(
|
||||
SharedString::from(extension.id.clone()),
|
||||
"Uninstall",
|
||||
)
|
||||
.disabled(true),
|
||||
configure: is_configurable.then(|| {
|
||||
Button::new(
|
||||
SharedString::from(format!("configure-{}", extension.id.clone())),
|
||||
"Configure",
|
||||
)
|
||||
.disabled(true)
|
||||
}),
|
||||
upgrade: Some(
|
||||
Button::new(SharedString::from(extension.id.clone()), "Upgrade").disabled(true),
|
||||
),
|
||||
),
|
||||
ExtensionStatus::Installed(installed_version) => (
|
||||
Button::new(SharedString::from(extension.id.clone()), "Uninstall").on_click({
|
||||
},
|
||||
ExtensionStatus::Installed(installed_version) => ExtensionCardButtons {
|
||||
install_or_uninstall: Button::new(
|
||||
SharedString::from(extension.id.clone()),
|
||||
"Uninstall",
|
||||
)
|
||||
.on_click({
|
||||
let extension_id = extension.id.clone();
|
||||
move |_, _, cx| {
|
||||
telemetry::event!("Extension Uninstalled", extension_id);
|
||||
@@ -907,7 +978,32 @@ impl ExtensionsPage {
|
||||
});
|
||||
}
|
||||
}),
|
||||
if installed_version == extension.manifest.version {
|
||||
configure: is_configurable.then(|| {
|
||||
Button::new(
|
||||
SharedString::from(format!("configure-{}", extension.id.clone())),
|
||||
"Configure",
|
||||
)
|
||||
.on_click({
|
||||
let extension_id = extension.id.clone();
|
||||
move |_, _, cx| {
|
||||
if let Some(manifest) = ExtensionStore::global(cx)
|
||||
.read(cx)
|
||||
.extension_manifest_for_id(&extension_id)
|
||||
.cloned()
|
||||
{
|
||||
if let Some(events) = extension::ExtensionEvents::try_global(cx) {
|
||||
events.update(cx, |this, cx| {
|
||||
this.emit(
|
||||
extension::Event::ConfigureExtensionRequested(manifest),
|
||||
cx,
|
||||
)
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}),
|
||||
upgrade: if installed_version == extension.manifest.version {
|
||||
None
|
||||
} else {
|
||||
Some(
|
||||
@@ -944,11 +1040,22 @@ impl ExtensionsPage {
|
||||
}),
|
||||
)
|
||||
},
|
||||
),
|
||||
ExtensionStatus::Removing => (
|
||||
Button::new(SharedString::from(extension.id.clone()), "Uninstall").disabled(true),
|
||||
None,
|
||||
),
|
||||
},
|
||||
ExtensionStatus::Removing => ExtensionCardButtons {
|
||||
install_or_uninstall: Button::new(
|
||||
SharedString::from(extension.id.clone()),
|
||||
"Uninstall",
|
||||
)
|
||||
.disabled(true),
|
||||
configure: is_configurable.then(|| {
|
||||
Button::new(
|
||||
SharedString::from(format!("configure-{}", extension.id.clone())),
|
||||
"Configure",
|
||||
)
|
||||
.disabled(true)
|
||||
}),
|
||||
upgrade: None,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -322,7 +322,7 @@ impl GitRepository for FakeGitRepository {
|
||||
.iter()
|
||||
.map(|branch_name| Branch {
|
||||
is_head: Some(branch_name) == current_branch.as_ref(),
|
||||
name: branch_name.into(),
|
||||
ref_name: branch_name.into(),
|
||||
most_recent_commit: None,
|
||||
upstream: None,
|
||||
})
|
||||
|
||||
@@ -37,12 +37,24 @@ pub const REMOTE_CANCELLED_BY_USER: &str = "Operation cancelled by user";
|
||||
#[derive(Clone, Debug, Hash, PartialEq, Eq)]
|
||||
pub struct Branch {
|
||||
pub is_head: bool,
|
||||
pub name: SharedString,
|
||||
pub ref_name: SharedString,
|
||||
pub upstream: Option<Upstream>,
|
||||
pub most_recent_commit: Option<CommitSummary>,
|
||||
}
|
||||
|
||||
impl Branch {
|
||||
pub fn name(&self) -> &str {
|
||||
self.ref_name
|
||||
.as_ref()
|
||||
.strip_prefix("refs/heads/")
|
||||
.or_else(|| self.ref_name.as_ref().strip_prefix("refs/remotes/"))
|
||||
.unwrap_or(self.ref_name.as_ref())
|
||||
}
|
||||
|
||||
pub fn is_remote(&self) -> bool {
|
||||
self.ref_name.starts_with("refs/remotes/")
|
||||
}
|
||||
|
||||
pub fn tracking_status(&self) -> Option<UpstreamTrackingStatus> {
|
||||
self.upstream
|
||||
.as_ref()
|
||||
@@ -71,6 +83,10 @@ impl Upstream {
|
||||
.strip_prefix("refs/remotes/")
|
||||
.and_then(|stripped| stripped.split("/").next())
|
||||
}
|
||||
|
||||
pub fn stripped_ref_name(&self) -> Option<&str> {
|
||||
self.ref_name.strip_prefix("refs/remotes/")
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, Default)]
|
||||
@@ -803,68 +819,69 @@ impl GitRepository for RealGitRepository {
|
||||
fn branches(&self) -> BoxFuture<Result<Vec<Branch>>> {
|
||||
let working_directory = self.working_directory();
|
||||
let git_binary_path = self.git_binary_path.clone();
|
||||
async move {
|
||||
let fields = [
|
||||
"%(HEAD)",
|
||||
"%(objectname)",
|
||||
"%(parent)",
|
||||
"%(refname)",
|
||||
"%(upstream)",
|
||||
"%(upstream:track)",
|
||||
"%(committerdate:unix)",
|
||||
"%(contents:subject)",
|
||||
]
|
||||
.join("%00");
|
||||
let args = vec![
|
||||
"for-each-ref",
|
||||
"refs/heads/**/*",
|
||||
"refs/remotes/**/*",
|
||||
"--format",
|
||||
&fields,
|
||||
];
|
||||
let working_directory = working_directory?;
|
||||
let output = new_smol_command(&git_binary_path)
|
||||
.current_dir(&working_directory)
|
||||
.args(args)
|
||||
.output()
|
||||
.await?;
|
||||
|
||||
if !output.status.success() {
|
||||
return Err(anyhow!(
|
||||
"Failed to git git branches:\n{}",
|
||||
String::from_utf8_lossy(&output.stderr)
|
||||
));
|
||||
}
|
||||
|
||||
let input = String::from_utf8_lossy(&output.stdout);
|
||||
|
||||
let mut branches = parse_branch_input(&input)?;
|
||||
if branches.is_empty() {
|
||||
let args = vec!["symbolic-ref", "--quiet", "--short", "HEAD"];
|
||||
|
||||
self.executor
|
||||
.spawn(async move {
|
||||
let fields = [
|
||||
"%(HEAD)",
|
||||
"%(objectname)",
|
||||
"%(parent)",
|
||||
"%(refname)",
|
||||
"%(upstream)",
|
||||
"%(upstream:track)",
|
||||
"%(committerdate:unix)",
|
||||
"%(contents:subject)",
|
||||
]
|
||||
.join("%00");
|
||||
let args = vec![
|
||||
"for-each-ref",
|
||||
"refs/heads/**/*",
|
||||
"refs/remotes/**/*",
|
||||
"--format",
|
||||
&fields,
|
||||
];
|
||||
let working_directory = working_directory?;
|
||||
let output = new_smol_command(&git_binary_path)
|
||||
.current_dir(&working_directory)
|
||||
.args(args)
|
||||
.output()
|
||||
.await?;
|
||||
|
||||
// git symbolic-ref returns a non-0 exit code if HEAD points
|
||||
// to something other than a branch
|
||||
if output.status.success() {
|
||||
let name = String::from_utf8_lossy(&output.stdout).trim().to_string();
|
||||
|
||||
branches.push(Branch {
|
||||
name: name.into(),
|
||||
is_head: true,
|
||||
upstream: None,
|
||||
most_recent_commit: None,
|
||||
});
|
||||
if !output.status.success() {
|
||||
return Err(anyhow!(
|
||||
"Failed to git git branches:\n{}",
|
||||
String::from_utf8_lossy(&output.stderr)
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
Ok(branches)
|
||||
}
|
||||
.boxed()
|
||||
let input = String::from_utf8_lossy(&output.stdout);
|
||||
|
||||
let mut branches = parse_branch_input(&input)?;
|
||||
if branches.is_empty() {
|
||||
let args = vec!["symbolic-ref", "--quiet", "HEAD"];
|
||||
|
||||
let output = new_smol_command(&git_binary_path)
|
||||
.current_dir(&working_directory)
|
||||
.args(args)
|
||||
.output()
|
||||
.await?;
|
||||
|
||||
// git symbolic-ref returns a non-0 exit code if HEAD points
|
||||
// to something other than a branch
|
||||
if output.status.success() {
|
||||
let name = String::from_utf8_lossy(&output.stdout).trim().to_string();
|
||||
|
||||
branches.push(Branch {
|
||||
ref_name: name.into(),
|
||||
is_head: true,
|
||||
upstream: None,
|
||||
most_recent_commit: None,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
Ok(branches)
|
||||
})
|
||||
.boxed()
|
||||
}
|
||||
|
||||
fn change_branch(&self, name: String) -> BoxFuture<Result<()>> {
|
||||
@@ -1691,15 +1708,7 @@ fn parse_branch_input(input: &str) -> Result<Vec<Branch>> {
|
||||
let is_current_branch = fields.next().context("no HEAD")? == "*";
|
||||
let head_sha: SharedString = fields.next().context("no objectname")?.to_string().into();
|
||||
let parent_sha: SharedString = fields.next().context("no parent")?.to_string().into();
|
||||
let raw_ref_name = fields.next().context("no refname")?;
|
||||
let ref_name: SharedString =
|
||||
if let Some(ref_name) = raw_ref_name.strip_prefix("refs/heads/") {
|
||||
ref_name.to_string().into()
|
||||
} else if let Some(ref_name) = raw_ref_name.strip_prefix("refs/remotes/") {
|
||||
ref_name.to_string().into()
|
||||
} else {
|
||||
return Err(anyhow!("unexpected format for refname"));
|
||||
};
|
||||
let ref_name = fields.next().context("no refname")?.to_string().into();
|
||||
let upstream_name = fields.next().context("no upstream")?.to_string();
|
||||
let upstream_tracking = parse_upstream_track(fields.next().context("no upstream:track")?)?;
|
||||
let commiterdate = fields.next().context("no committerdate")?.parse::<i64>()?;
|
||||
@@ -1711,7 +1720,7 @@ fn parse_branch_input(input: &str) -> Result<Vec<Branch>> {
|
||||
|
||||
branches.push(Branch {
|
||||
is_head: is_current_branch,
|
||||
name: ref_name,
|
||||
ref_name: ref_name,
|
||||
most_recent_commit: Some(CommitSummary {
|
||||
sha: head_sha,
|
||||
subject,
|
||||
@@ -1974,7 +1983,7 @@ mod tests {
|
||||
parse_branch_input(&input).unwrap(),
|
||||
vec![Branch {
|
||||
is_head: true,
|
||||
name: "zed-patches".into(),
|
||||
ref_name: "refs/heads/zed-patches".into(),
|
||||
upstream: Some(Upstream {
|
||||
ref_name: "refs/remotes/origin/zed-patches".into(),
|
||||
tracking: UpstreamTracking::Tracked(UpstreamTrackingStatus {
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
use anyhow::{Context as _, anyhow};
|
||||
use fuzzy::StringMatchCandidate;
|
||||
|
||||
use collections::HashSet;
|
||||
use git::repository::Branch;
|
||||
use gpui::{
|
||||
App, Context, DismissEvent, Entity, EventEmitter, FocusHandle, Focusable, InteractiveElement,
|
||||
@@ -95,12 +96,28 @@ impl BranchList {
|
||||
.context("No active repository")?
|
||||
.await??;
|
||||
|
||||
all_branches.sort_by_key(|branch| {
|
||||
branch
|
||||
.most_recent_commit
|
||||
.as_ref()
|
||||
.map(|commit| 0 - commit.commit_timestamp)
|
||||
});
|
||||
let all_branches = cx
|
||||
.background_spawn(async move {
|
||||
let upstreams: HashSet<_> = all_branches
|
||||
.iter()
|
||||
.filter_map(|branch| {
|
||||
let upstream = branch.upstream.as_ref()?;
|
||||
Some(upstream.ref_name.clone())
|
||||
})
|
||||
.collect();
|
||||
|
||||
all_branches.retain(|branch| !upstreams.contains(&branch.ref_name));
|
||||
|
||||
all_branches.sort_by_key(|branch| {
|
||||
branch
|
||||
.most_recent_commit
|
||||
.as_ref()
|
||||
.map(|commit| 0 - commit.commit_timestamp)
|
||||
});
|
||||
|
||||
all_branches
|
||||
})
|
||||
.await;
|
||||
|
||||
this.update_in(cx, |this, window, cx| {
|
||||
this.picker.update(cx, |picker, cx| {
|
||||
@@ -266,6 +283,7 @@ impl PickerDelegate for BranchListDelegate {
|
||||
let mut matches: Vec<BranchEntry> = if query.is_empty() {
|
||||
all_branches
|
||||
.into_iter()
|
||||
.filter(|branch| !branch.is_remote())
|
||||
.take(RECENT_BRANCHES_COUNT)
|
||||
.map(|branch| BranchEntry {
|
||||
branch,
|
||||
@@ -277,7 +295,7 @@ impl PickerDelegate for BranchListDelegate {
|
||||
let candidates = all_branches
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(ix, command)| StringMatchCandidate::new(ix, &command.name.clone()))
|
||||
.map(|(ix, branch)| StringMatchCandidate::new(ix, branch.name()))
|
||||
.collect::<Vec<StringMatchCandidate>>();
|
||||
fuzzy::match_strings(
|
||||
&candidates,
|
||||
@@ -303,11 +321,11 @@ impl PickerDelegate for BranchListDelegate {
|
||||
if !query.is_empty()
|
||||
&& !matches
|
||||
.first()
|
||||
.is_some_and(|entry| entry.branch.name == query)
|
||||
.is_some_and(|entry| entry.branch.name() == query)
|
||||
{
|
||||
matches.push(BranchEntry {
|
||||
branch: Branch {
|
||||
name: query.clone().into(),
|
||||
ref_name: format!("refs/heads/{query}").into(),
|
||||
is_head: false,
|
||||
upstream: None,
|
||||
most_recent_commit: None,
|
||||
@@ -335,19 +353,19 @@ impl PickerDelegate for BranchListDelegate {
|
||||
return;
|
||||
};
|
||||
if entry.is_new {
|
||||
self.create_branch(entry.branch.name.clone(), window, cx);
|
||||
self.create_branch(entry.branch.name().to_owned().into(), window, cx);
|
||||
return;
|
||||
}
|
||||
|
||||
let current_branch = self.repo.as_ref().map(|repo| {
|
||||
repo.update(cx, |repo, _| {
|
||||
repo.branch.as_ref().map(|branch| branch.name.clone())
|
||||
repo.branch.as_ref().map(|branch| branch.ref_name.clone())
|
||||
})
|
||||
});
|
||||
|
||||
if current_branch
|
||||
.flatten()
|
||||
.is_some_and(|current_branch| current_branch == entry.branch.name)
|
||||
.is_some_and(|current_branch| current_branch == entry.branch.ref_name)
|
||||
{
|
||||
cx.emit(DismissEvent);
|
||||
return;
|
||||
@@ -368,7 +386,7 @@ impl PickerDelegate for BranchListDelegate {
|
||||
|
||||
anyhow::Ok(async move {
|
||||
repo.update(&mut cx, |repo, _| {
|
||||
repo.change_branch(branch.name.to_string())
|
||||
repo.change_branch(branch.name().to_string())
|
||||
})?
|
||||
.await?
|
||||
})
|
||||
@@ -443,13 +461,13 @@ impl PickerDelegate for BranchListDelegate {
|
||||
if entry.is_new {
|
||||
Label::new(format!(
|
||||
"Create branch \"{}\"…",
|
||||
entry.branch.name
|
||||
entry.branch.name()
|
||||
))
|
||||
.single_line()
|
||||
.into_any_element()
|
||||
} else {
|
||||
HighlightedLabel::new(
|
||||
entry.branch.name.clone(),
|
||||
entry.branch.name().to_owned(),
|
||||
entry.positions.clone(),
|
||||
)
|
||||
.truncate()
|
||||
@@ -470,7 +488,7 @@ impl PickerDelegate for BranchListDelegate {
|
||||
let message = if entry.is_new {
|
||||
if let Some(current_branch) =
|
||||
self.repo.as_ref().and_then(|repo| {
|
||||
repo.read(cx).branch.as_ref().map(|b| b.name.clone())
|
||||
repo.read(cx).branch.as_ref().map(|b| b.name())
|
||||
})
|
||||
{
|
||||
format!("based off {}", current_branch)
|
||||
|
||||
@@ -321,8 +321,8 @@ impl CommitModal {
|
||||
let branch = active_repo
|
||||
.as_ref()
|
||||
.and_then(|repo| repo.read(cx).branch.as_ref())
|
||||
.map(|b| b.name.clone())
|
||||
.unwrap_or_else(|| "<no branch>".into());
|
||||
.map(|b| b.name().to_owned())
|
||||
.unwrap_or_else(|| "<no branch>".to_owned());
|
||||
|
||||
let branch_picker_button = panel_button(branch)
|
||||
.icon(IconName::GitBranch)
|
||||
|
||||
@@ -1953,7 +1953,12 @@ impl GitPanel {
|
||||
})?;
|
||||
|
||||
let pull = repo.update(cx, |repo, cx| {
|
||||
repo.pull(branch.name.clone(), remote.name.clone(), askpass, cx)
|
||||
repo.pull(
|
||||
branch.name().to_owned().into(),
|
||||
remote.name.clone(),
|
||||
askpass,
|
||||
cx,
|
||||
)
|
||||
})?;
|
||||
|
||||
let remote_message = pull.await?;
|
||||
@@ -2020,7 +2025,7 @@ impl GitPanel {
|
||||
|
||||
let push = repo.update(cx, |repo, cx| {
|
||||
repo.push(
|
||||
branch.name.clone(),
|
||||
branch.name().to_owned().into(),
|
||||
remote.name.clone(),
|
||||
options,
|
||||
askpass_delegate,
|
||||
@@ -2030,7 +2035,7 @@ impl GitPanel {
|
||||
|
||||
let remote_output = push.await?;
|
||||
|
||||
let action = RemoteAction::Push(branch.name, remote);
|
||||
let action = RemoteAction::Push(branch.name().to_owned().into(), remote);
|
||||
this.update(cx, |this, cx| match remote_output {
|
||||
Ok(remote_message) => this.show_remote_output(action, remote_message, cx),
|
||||
Err(e) => {
|
||||
@@ -2092,7 +2097,7 @@ impl GitPanel {
|
||||
return Err(anyhow::anyhow!("No active branch"));
|
||||
};
|
||||
|
||||
Ok(repo.get_remotes(Some(current_branch.name.to_string())))
|
||||
Ok(repo.get_remotes(Some(current_branch.name().to_string())))
|
||||
})??
|
||||
.await??;
|
||||
|
||||
@@ -4363,19 +4368,17 @@ impl RenderOnce for PanelRepoFooter {
|
||||
let branch_name = self
|
||||
.branch
|
||||
.as_ref()
|
||||
.map(|branch| branch.name.clone())
|
||||
.map(|branch| branch.name().to_owned())
|
||||
.or_else(|| {
|
||||
self.head_commit.as_ref().map(|commit| {
|
||||
SharedString::from(
|
||||
commit
|
||||
.sha
|
||||
.chars()
|
||||
.take(MAX_SHORT_SHA_LEN)
|
||||
.collect::<String>(),
|
||||
)
|
||||
commit
|
||||
.sha
|
||||
.chars()
|
||||
.take(MAX_SHORT_SHA_LEN)
|
||||
.collect::<String>()
|
||||
})
|
||||
})
|
||||
.unwrap_or_else(|| SharedString::from(" (no branch)"));
|
||||
.unwrap_or_else(|| " (no branch)".to_owned());
|
||||
let show_separator = self.branch.is_some() || self.head_commit.is_some();
|
||||
|
||||
let active_repo_name = self.active_repository.clone();
|
||||
@@ -4542,7 +4545,7 @@ impl Component for PanelRepoFooter {
|
||||
fn branch(upstream: Option<UpstreamTracking>) -> Branch {
|
||||
Branch {
|
||||
is_head: true,
|
||||
name: "some-branch".into(),
|
||||
ref_name: "some-branch".into(),
|
||||
upstream: upstream.map(|tracking| Upstream {
|
||||
ref_name: "origin/some-branch".into(),
|
||||
tracking,
|
||||
@@ -4559,7 +4562,7 @@ impl Component for PanelRepoFooter {
|
||||
fn custom(branch_name: &str, upstream: Option<UpstreamTracking>) -> Branch {
|
||||
Branch {
|
||||
is_head: true,
|
||||
name: branch_name.to_string().into(),
|
||||
ref_name: branch_name.to_string().into(),
|
||||
upstream: upstream.map(|tracking| Upstream {
|
||||
ref_name: format!("zed/{}", branch_name).into(),
|
||||
tracking,
|
||||
|
||||
@@ -1099,7 +1099,7 @@ impl RenderOnce for ProjectDiffEmptyState {
|
||||
v_flex()
|
||||
.child(Headline::new(ahead_string).size(HeadlineSize::Small))
|
||||
.child(
|
||||
Label::new(format!("Push your changes to {}", branch.name))
|
||||
Label::new(format!("Push your changes to {}", branch.name()))
|
||||
.color(Color::Muted),
|
||||
),
|
||||
)
|
||||
@@ -1113,7 +1113,7 @@ impl RenderOnce for ProjectDiffEmptyState {
|
||||
v_flex()
|
||||
.child(Headline::new("Publish Branch").size(HeadlineSize::Small))
|
||||
.child(
|
||||
Label::new(format!("Create {} on remote", branch.name))
|
||||
Label::new(format!("Create {} on remote", branch.name()))
|
||||
.color(Color::Muted),
|
||||
),
|
||||
)
|
||||
@@ -1183,7 +1183,7 @@ mod preview {
|
||||
fn branch(upstream: Option<UpstreamTracking>) -> Branch {
|
||||
Branch {
|
||||
is_head: true,
|
||||
name: "some-branch".into(),
|
||||
ref_name: "some-branch".into(),
|
||||
upstream: upstream.map(|tracking| Upstream {
|
||||
ref_name: "origin/some-branch".into(),
|
||||
tracking,
|
||||
|
||||
@@ -137,6 +137,7 @@ pub enum IconName {
|
||||
GitBranchSmall,
|
||||
Github,
|
||||
Globe,
|
||||
Hammer,
|
||||
Hash,
|
||||
HistoryRerun,
|
||||
Image,
|
||||
|
||||
@@ -2141,6 +2141,14 @@ impl Buffer {
|
||||
self.edit([(0..self.len(), text)], None, cx)
|
||||
}
|
||||
|
||||
/// Appends the given text to the end of the buffer.
|
||||
pub fn append<T>(&mut self, text: T, cx: &mut Context<Self>) -> Option<clock::Lamport>
|
||||
where
|
||||
T: Into<Arc<str>>,
|
||||
{
|
||||
self.edit([(self.len()..self.len(), text)], None, cx)
|
||||
}
|
||||
|
||||
/// Applies the given edits to the buffer. Each edit is specified as a range of text to
|
||||
/// delete, and a string of text to insert at that location.
|
||||
///
|
||||
|
||||
@@ -260,15 +260,6 @@ impl LspAdapter for RustLspAdapter {
|
||||
Some("rust-analyzer/flycheck".into())
|
||||
}
|
||||
|
||||
fn retain_old_diagnostic(&self, previous_diagnostic: &Diagnostic, cx: &App) -> bool {
|
||||
let zed_provides_cargo_diagnostics = ProjectSettings::get_global(cx)
|
||||
.diagnostics
|
||||
.fetch_cargo_diagnostics();
|
||||
// Zed manages the lifecycle of cargo diagnostics when configured so.
|
||||
zed_provides_cargo_diagnostics
|
||||
&& previous_diagnostic.source.as_deref() == Some(CARGO_DIAGNOSTICS_SOURCE_NAME)
|
||||
}
|
||||
|
||||
fn process_diagnostics(
|
||||
&self,
|
||||
params: &mut lsp::PublishDiagnosticsParams,
|
||||
@@ -516,10 +507,10 @@ impl LspAdapter for RustLspAdapter {
|
||||
}
|
||||
}
|
||||
|
||||
let zed_provides_cargo_diagnostics = ProjectSettings::get_global(cx)
|
||||
let cargo_diagnostics_fetched_separately = ProjectSettings::get_global(cx)
|
||||
.diagnostics
|
||||
.fetch_cargo_diagnostics();
|
||||
if zed_provides_cargo_diagnostics {
|
||||
if cargo_diagnostics_fetched_separately {
|
||||
let disable_check_on_save = json!({
|
||||
"checkOnSave": false,
|
||||
});
|
||||
|
||||
@@ -646,7 +646,7 @@ struct ActiveItem {
|
||||
item_handle: Box<dyn WeakItemHandle>,
|
||||
active_editor: WeakEntity<Editor>,
|
||||
_buffer_search_subscription: Subscription,
|
||||
_editor_subscrpiption: Subscription,
|
||||
_editor_subscription: Subscription,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
@@ -2962,7 +2962,7 @@ impl OutlinePanel {
|
||||
);
|
||||
self.active_item = Some(ActiveItem {
|
||||
_buffer_search_subscription: buffer_search_subscription,
|
||||
_editor_subscrpiption: subscribe_for_editor_events(&new_active_editor, window, cx),
|
||||
_editor_subscription: subscribe_for_editor_events(&new_active_editor, window, cx),
|
||||
item_handle: new_active_item.downgrade_item(),
|
||||
active_editor: new_active_editor.downgrade(),
|
||||
});
|
||||
|
||||
@@ -588,7 +588,9 @@ impl<D: PickerDelegate> Picker<D> {
|
||||
self.update_matches(query, window, cx);
|
||||
}
|
||||
editor::EditorEvent::Blurred => {
|
||||
self.cancel(&menu::Cancel, window, cx);
|
||||
if self.is_modal {
|
||||
self.cancel(&menu::Cancel, window, cx);
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
|
||||
@@ -3790,13 +3790,9 @@ impl Repository {
|
||||
|
||||
pub fn branches(&mut self) -> oneshot::Receiver<Result<Vec<Branch>>> {
|
||||
let id = self.id;
|
||||
self.send_job(None, move |repo, cx| async move {
|
||||
self.send_job(None, move |repo, _| async move {
|
||||
match repo {
|
||||
RepositoryState::Local { backend, .. } => {
|
||||
let backend = backend.clone();
|
||||
cx.background_spawn(async move { backend.branches().await })
|
||||
.await
|
||||
}
|
||||
RepositoryState::Local { backend, .. } => backend.branches().await,
|
||||
RepositoryState::Remote { project_id, client } => {
|
||||
let response = client
|
||||
.request(proto::GitGetBranches {
|
||||
@@ -4460,7 +4456,7 @@ fn deserialize_blame_buffer_response(
|
||||
fn branch_to_proto(branch: &git::repository::Branch) -> proto::Branch {
|
||||
proto::Branch {
|
||||
is_head: branch.is_head,
|
||||
name: branch.name.to_string(),
|
||||
ref_name: branch.ref_name.to_string(),
|
||||
unix_timestamp: branch
|
||||
.most_recent_commit
|
||||
.as_ref()
|
||||
@@ -4489,7 +4485,7 @@ fn branch_to_proto(branch: &git::repository::Branch) -> proto::Branch {
|
||||
fn proto_to_branch(proto: &proto::Branch) -> git::repository::Branch {
|
||||
git::repository::Branch {
|
||||
is_head: proto.is_head,
|
||||
name: proto.name.clone().into(),
|
||||
ref_name: proto.ref_name.clone().into(),
|
||||
upstream: proto
|
||||
.upstream
|
||||
.as_ref()
|
||||
|
||||
@@ -13,7 +13,7 @@ use client::proto::{self, PeerId};
|
||||
use clock::Global;
|
||||
use collections::HashSet;
|
||||
use futures::future;
|
||||
use gpui::{App, AsyncApp, Entity};
|
||||
use gpui::{App, AsyncApp, Entity, Task};
|
||||
use language::{
|
||||
Anchor, Bias, Buffer, BufferSnapshot, CachedLspAdapter, CharKind, OffsetRangeExt, PointUtf16,
|
||||
ToOffset, ToPointUtf16, Transaction, Unclipped,
|
||||
@@ -966,7 +966,7 @@ fn language_server_for_buffer(
|
||||
.ok_or_else(|| anyhow!("no language server found for buffer"))
|
||||
}
|
||||
|
||||
async fn location_links_from_proto(
|
||||
pub async fn location_links_from_proto(
|
||||
proto_links: Vec<proto::LocationLink>,
|
||||
lsp_store: Entity<LspStore>,
|
||||
mut cx: AsyncApp,
|
||||
@@ -974,70 +974,72 @@ async fn location_links_from_proto(
|
||||
let mut links = Vec::new();
|
||||
|
||||
for link in proto_links {
|
||||
links.push(location_link_from_proto(link, &lsp_store, &mut cx).await?)
|
||||
links.push(location_link_from_proto(link, lsp_store.clone(), &mut cx).await?)
|
||||
}
|
||||
|
||||
Ok(links)
|
||||
}
|
||||
|
||||
pub async fn location_link_from_proto(
|
||||
pub fn location_link_from_proto(
|
||||
link: proto::LocationLink,
|
||||
lsp_store: &Entity<LspStore>,
|
||||
lsp_store: Entity<LspStore>,
|
||||
cx: &mut AsyncApp,
|
||||
) -> Result<LocationLink> {
|
||||
let origin = match link.origin {
|
||||
Some(origin) => {
|
||||
let buffer_id = BufferId::new(origin.buffer_id)?;
|
||||
let buffer = lsp_store
|
||||
.update(cx, |lsp_store, cx| {
|
||||
lsp_store.wait_for_remote_buffer(buffer_id, cx)
|
||||
})?
|
||||
.await?;
|
||||
let start = origin
|
||||
.start
|
||||
.and_then(deserialize_anchor)
|
||||
.ok_or_else(|| anyhow!("missing origin start"))?;
|
||||
let end = origin
|
||||
.end
|
||||
.and_then(deserialize_anchor)
|
||||
.ok_or_else(|| anyhow!("missing origin end"))?;
|
||||
buffer
|
||||
.update(cx, |buffer, _| buffer.wait_for_anchors([start, end]))?
|
||||
.await?;
|
||||
Some(Location {
|
||||
buffer,
|
||||
range: start..end,
|
||||
})
|
||||
}
|
||||
None => None,
|
||||
};
|
||||
) -> Task<Result<LocationLink>> {
|
||||
cx.spawn(async move |cx| {
|
||||
let origin = match link.origin {
|
||||
Some(origin) => {
|
||||
let buffer_id = BufferId::new(origin.buffer_id)?;
|
||||
let buffer = lsp_store
|
||||
.update(cx, |lsp_store, cx| {
|
||||
lsp_store.wait_for_remote_buffer(buffer_id, cx)
|
||||
})?
|
||||
.await?;
|
||||
let start = origin
|
||||
.start
|
||||
.and_then(deserialize_anchor)
|
||||
.ok_or_else(|| anyhow!("missing origin start"))?;
|
||||
let end = origin
|
||||
.end
|
||||
.and_then(deserialize_anchor)
|
||||
.ok_or_else(|| anyhow!("missing origin end"))?;
|
||||
buffer
|
||||
.update(cx, |buffer, _| buffer.wait_for_anchors([start, end]))?
|
||||
.await?;
|
||||
Some(Location {
|
||||
buffer,
|
||||
range: start..end,
|
||||
})
|
||||
}
|
||||
None => None,
|
||||
};
|
||||
|
||||
let target = link.target.ok_or_else(|| anyhow!("missing target"))?;
|
||||
let buffer_id = BufferId::new(target.buffer_id)?;
|
||||
let buffer = lsp_store
|
||||
.update(cx, |lsp_store, cx| {
|
||||
lsp_store.wait_for_remote_buffer(buffer_id, cx)
|
||||
})?
|
||||
.await?;
|
||||
let start = target
|
||||
.start
|
||||
.and_then(deserialize_anchor)
|
||||
.ok_or_else(|| anyhow!("missing target start"))?;
|
||||
let end = target
|
||||
.end
|
||||
.and_then(deserialize_anchor)
|
||||
.ok_or_else(|| anyhow!("missing target end"))?;
|
||||
buffer
|
||||
.update(cx, |buffer, _| buffer.wait_for_anchors([start, end]))?
|
||||
.await?;
|
||||
let target = Location {
|
||||
buffer,
|
||||
range: start..end,
|
||||
};
|
||||
Ok(LocationLink { origin, target })
|
||||
let target = link.target.ok_or_else(|| anyhow!("missing target"))?;
|
||||
let buffer_id = BufferId::new(target.buffer_id)?;
|
||||
let buffer = lsp_store
|
||||
.update(cx, |lsp_store, cx| {
|
||||
lsp_store.wait_for_remote_buffer(buffer_id, cx)
|
||||
})?
|
||||
.await?;
|
||||
let start = target
|
||||
.start
|
||||
.and_then(deserialize_anchor)
|
||||
.ok_or_else(|| anyhow!("missing target start"))?;
|
||||
let end = target
|
||||
.end
|
||||
.and_then(deserialize_anchor)
|
||||
.ok_or_else(|| anyhow!("missing target end"))?;
|
||||
buffer
|
||||
.update(cx, |buffer, _| buffer.wait_for_anchors([start, end]))?
|
||||
.await?;
|
||||
let target = Location {
|
||||
buffer,
|
||||
range: start..end,
|
||||
};
|
||||
Ok(LocationLink { origin, target })
|
||||
})
|
||||
}
|
||||
|
||||
async fn location_links_from_lsp(
|
||||
pub async fn location_links_from_lsp(
|
||||
message: Option<lsp::GotoDefinitionResponse>,
|
||||
lsp_store: Entity<LspStore>,
|
||||
buffer: Entity<Buffer>,
|
||||
@@ -1178,7 +1180,7 @@ pub async fn location_link_from_lsp(
|
||||
})
|
||||
}
|
||||
|
||||
fn location_links_to_proto(
|
||||
pub fn location_links_to_proto(
|
||||
links: Vec<LocationLink>,
|
||||
lsp_store: &mut LspStore,
|
||||
peer_id: PeerId,
|
||||
|
||||
@@ -8,6 +8,7 @@ use crate::{
|
||||
buffer_store::{BufferStore, BufferStoreEvent},
|
||||
environment::ProjectEnvironment,
|
||||
lsp_command::{self, *},
|
||||
lsp_store,
|
||||
manifest_tree::{AdapterQuery, LanguageServerTree, LaunchDisposition, ManifestTree},
|
||||
prettier_store::{self, PrettierStore, PrettierStoreEvent},
|
||||
project_settings::{LspSettings, ProjectSettings},
|
||||
@@ -3396,7 +3397,7 @@ pub struct LanguageServerStatus {
|
||||
pub name: String,
|
||||
pub pending_work: BTreeMap<String, LanguageServerProgress>,
|
||||
pub has_pending_diagnostic_updates: bool,
|
||||
pub progress_tokens: HashSet<String>,
|
||||
progress_tokens: HashSet<String>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
@@ -3449,8 +3450,14 @@ impl LspStore {
|
||||
client.add_entity_request_handler(Self::handle_lsp_command::<PerformRename>);
|
||||
client.add_entity_request_handler(Self::handle_lsp_command::<LinkedEditingRange>);
|
||||
|
||||
client.add_entity_request_handler(Self::handle_lsp_ext_cancel_flycheck);
|
||||
client.add_entity_request_handler(Self::handle_lsp_ext_run_flycheck);
|
||||
client.add_entity_request_handler(Self::handle_lsp_ext_clear_flycheck);
|
||||
client.add_entity_request_handler(Self::handle_lsp_command::<lsp_ext_command::ExpandMacro>);
|
||||
client.add_entity_request_handler(Self::handle_lsp_command::<lsp_ext_command::OpenDocs>);
|
||||
client.add_entity_request_handler(
|
||||
Self::handle_lsp_command::<lsp_ext_command::GoToParentModule>,
|
||||
);
|
||||
client.add_entity_request_handler(
|
||||
Self::handle_lsp_command::<lsp_ext_command::GetLspRunnables>,
|
||||
);
|
||||
@@ -3791,13 +3798,11 @@ impl LspStore {
|
||||
evt: &extension::Event,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
#[expect(
|
||||
irrefutable_let_patterns,
|
||||
reason = "Make sure to handle new event types in extension properly"
|
||||
)]
|
||||
let extension::Event::ExtensionsInstalledChanged = evt else {
|
||||
return;
|
||||
};
|
||||
match evt {
|
||||
extension::Event::ExtensionInstalled(_)
|
||||
| extension::Event::ConfigureExtensionRequested(_) => return,
|
||||
extension::Event::ExtensionsInstalledChanged => {}
|
||||
}
|
||||
if self.as_local().is_none() {
|
||||
return;
|
||||
}
|
||||
@@ -6238,13 +6243,6 @@ impl LspStore {
|
||||
})
|
||||
}
|
||||
|
||||
pub fn language_server_with_name(&self, name: &str, cx: &App) -> Option<LanguageServerId> {
|
||||
self.as_local()?
|
||||
.lsp_tree
|
||||
.read(cx)
|
||||
.server_id_for_name(&LanguageServerName::from(name))
|
||||
}
|
||||
|
||||
pub fn language_servers_for_local_buffer<'a>(
|
||||
&'a self,
|
||||
buffer: &Buffer,
|
||||
@@ -7030,37 +7028,26 @@ impl LspStore {
|
||||
mut cx: AsyncApp,
|
||||
) -> Result<proto::LanguageServerIdForNameResponse> {
|
||||
let name = &envelope.payload.name;
|
||||
match envelope.payload.buffer_id {
|
||||
Some(buffer_id) => {
|
||||
let buffer_id = BufferId::new(buffer_id)?;
|
||||
lsp_store
|
||||
.update(&mut cx, |lsp_store, cx| {
|
||||
let buffer = lsp_store.buffer_store.read(cx).get_existing(buffer_id)?;
|
||||
let server_id = buffer.update(cx, |buffer, cx| {
|
||||
lsp_store
|
||||
.language_servers_for_local_buffer(buffer, cx)
|
||||
.find_map(|(adapter, server)| {
|
||||
if adapter.name.0.as_ref() == name {
|
||||
Some(server.server_id())
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
});
|
||||
Ok(server_id)
|
||||
})?
|
||||
.map(|server_id| proto::LanguageServerIdForNameResponse {
|
||||
server_id: server_id.map(|id| id.to_proto()),
|
||||
})
|
||||
}
|
||||
None => lsp_store.update(&mut cx, |lsp_store, cx| {
|
||||
proto::LanguageServerIdForNameResponse {
|
||||
server_id: lsp_store
|
||||
.language_server_with_name(name, cx)
|
||||
.map(|id| id.to_proto()),
|
||||
}
|
||||
}),
|
||||
}
|
||||
let buffer_id = BufferId::new(envelope.payload.buffer_id)?;
|
||||
lsp_store
|
||||
.update(&mut cx, |lsp_store, cx| {
|
||||
let buffer = lsp_store.buffer_store.read(cx).get_existing(buffer_id)?;
|
||||
let server_id = buffer.update(cx, |buffer, cx| {
|
||||
lsp_store
|
||||
.language_servers_for_local_buffer(buffer, cx)
|
||||
.find_map(|(adapter, server)| {
|
||||
if adapter.name.0.as_ref() == name {
|
||||
Some(server.server_id())
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
});
|
||||
Ok(server_id)
|
||||
})?
|
||||
.map(|server_id| proto::LanguageServerIdForNameResponse {
|
||||
server_id: server_id.map(|id| id.to_proto()),
|
||||
})
|
||||
}
|
||||
|
||||
async fn handle_rename_project_entry(
|
||||
@@ -7284,6 +7271,77 @@ impl LspStore {
|
||||
})
|
||||
}
|
||||
|
||||
async fn handle_lsp_ext_cancel_flycheck(
|
||||
lsp_store: Entity<Self>,
|
||||
envelope: TypedEnvelope<proto::LspExtCancelFlycheck>,
|
||||
mut cx: AsyncApp,
|
||||
) -> Result<proto::Ack> {
|
||||
let server_id = LanguageServerId(envelope.payload.language_server_id as usize);
|
||||
lsp_store.update(&mut cx, |lsp_store, _| {
|
||||
if let Some(server) = lsp_store.language_server_for_id(server_id) {
|
||||
server
|
||||
.notify::<lsp_store::lsp_ext_command::LspExtCancelFlycheck>(&())
|
||||
.context("handling lsp ext cancel flycheck")
|
||||
} else {
|
||||
anyhow::Ok(())
|
||||
}
|
||||
})??;
|
||||
|
||||
Ok(proto::Ack {})
|
||||
}
|
||||
|
||||
async fn handle_lsp_ext_run_flycheck(
|
||||
lsp_store: Entity<Self>,
|
||||
envelope: TypedEnvelope<proto::LspExtRunFlycheck>,
|
||||
mut cx: AsyncApp,
|
||||
) -> Result<proto::Ack> {
|
||||
let server_id = LanguageServerId(envelope.payload.language_server_id as usize);
|
||||
lsp_store.update(&mut cx, |lsp_store, cx| {
|
||||
if let Some(server) = lsp_store.language_server_for_id(server_id) {
|
||||
let text_document = if envelope.payload.current_file_only {
|
||||
let buffer_id = BufferId::new(envelope.payload.buffer_id)?;
|
||||
lsp_store
|
||||
.buffer_store()
|
||||
.read(cx)
|
||||
.get(buffer_id)
|
||||
.and_then(|buffer| Some(buffer.read(cx).file()?.as_local()?.abs_path(cx)))
|
||||
.map(|path| make_text_document_identifier(&path))
|
||||
.transpose()?
|
||||
} else {
|
||||
None
|
||||
};
|
||||
server
|
||||
.notify::<lsp_store::lsp_ext_command::LspExtRunFlycheck>(
|
||||
&lsp_store::lsp_ext_command::RunFlycheckParams { text_document },
|
||||
)
|
||||
.context("handling lsp ext run flycheck")
|
||||
} else {
|
||||
anyhow::Ok(())
|
||||
}
|
||||
})??;
|
||||
|
||||
Ok(proto::Ack {})
|
||||
}
|
||||
|
||||
async fn handle_lsp_ext_clear_flycheck(
|
||||
lsp_store: Entity<Self>,
|
||||
envelope: TypedEnvelope<proto::LspExtClearFlycheck>,
|
||||
mut cx: AsyncApp,
|
||||
) -> Result<proto::Ack> {
|
||||
let server_id = LanguageServerId(envelope.payload.language_server_id as usize);
|
||||
lsp_store.update(&mut cx, |lsp_store, _| {
|
||||
if let Some(server) = lsp_store.language_server_for_id(server_id) {
|
||||
server
|
||||
.notify::<lsp_store::lsp_ext_command::LspExtClearFlycheck>(&())
|
||||
.context("handling lsp ext clear flycheck")
|
||||
} else {
|
||||
anyhow::Ok(())
|
||||
}
|
||||
})??;
|
||||
|
||||
Ok(proto::Ack {})
|
||||
}
|
||||
|
||||
pub fn disk_based_diagnostics_started(
|
||||
&mut self,
|
||||
language_server_id: LanguageServerId,
|
||||
@@ -7536,7 +7594,7 @@ impl LspStore {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn on_lsp_progress(
|
||||
fn on_lsp_progress(
|
||||
&mut self,
|
||||
progress: lsp::ProgressParams,
|
||||
language_server_id: LanguageServerId,
|
||||
|
||||
@@ -2,9 +2,10 @@ use crate::{
|
||||
LocationLink,
|
||||
lsp_command::{
|
||||
LspCommand, location_link_from_lsp, location_link_from_proto, location_link_to_proto,
|
||||
location_links_from_lsp, location_links_from_proto, location_links_to_proto,
|
||||
},
|
||||
lsp_store::LspStore,
|
||||
make_text_document_identifier,
|
||||
make_lsp_text_document_position, make_text_document_identifier,
|
||||
};
|
||||
use anyhow::{Context as _, Result};
|
||||
use async_trait::async_trait;
|
||||
@@ -24,9 +25,9 @@ use std::{
|
||||
use task::TaskTemplate;
|
||||
use text::{BufferId, PointUtf16, ToPointUtf16};
|
||||
|
||||
pub enum LspExpandMacro {}
|
||||
pub enum LspExtExpandMacro {}
|
||||
|
||||
impl lsp::request::Request for LspExpandMacro {
|
||||
impl lsp::request::Request for LspExtExpandMacro {
|
||||
type Params = ExpandMacroParams;
|
||||
type Result = Option<ExpandedMacro>;
|
||||
const METHOD: &'static str = "rust-analyzer/expandMacro";
|
||||
@@ -59,7 +60,7 @@ pub struct ExpandMacro {
|
||||
#[async_trait(?Send)]
|
||||
impl LspCommand for ExpandMacro {
|
||||
type Response = ExpandedMacro;
|
||||
type LspRequest = LspExpandMacro;
|
||||
type LspRequest = LspExtExpandMacro;
|
||||
type ProtoRequest = proto::LspExtExpandMacro;
|
||||
|
||||
fn display_name(&self) -> &str {
|
||||
@@ -301,6 +302,19 @@ pub struct SwitchSourceHeaderResult(pub String);
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct SwitchSourceHeader;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct GoToParentModule {
|
||||
pub position: PointUtf16,
|
||||
}
|
||||
|
||||
pub struct LspGoToParentModule {}
|
||||
|
||||
impl lsp::request::Request for LspGoToParentModule {
|
||||
type Params = lsp::TextDocumentPositionParams;
|
||||
type Result = Option<Vec<lsp::LocationLink>>;
|
||||
const METHOD: &'static str = "experimental/parentModule";
|
||||
}
|
||||
|
||||
#[async_trait(?Send)]
|
||||
impl LspCommand for SwitchSourceHeader {
|
||||
type Response = SwitchSourceHeaderResult;
|
||||
@@ -379,6 +393,96 @@ impl LspCommand for SwitchSourceHeader {
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait(?Send)]
|
||||
impl LspCommand for GoToParentModule {
|
||||
type Response = Vec<LocationLink>;
|
||||
type LspRequest = LspGoToParentModule;
|
||||
type ProtoRequest = proto::LspExtGoToParentModule;
|
||||
|
||||
fn display_name(&self) -> &str {
|
||||
"Go to parent module"
|
||||
}
|
||||
|
||||
fn to_lsp(
|
||||
&self,
|
||||
path: &Path,
|
||||
_: &Buffer,
|
||||
_: &Arc<LanguageServer>,
|
||||
_: &App,
|
||||
) -> Result<lsp::TextDocumentPositionParams> {
|
||||
make_lsp_text_document_position(path, self.position)
|
||||
}
|
||||
|
||||
async fn response_from_lsp(
|
||||
self,
|
||||
links: Option<Vec<lsp::LocationLink>>,
|
||||
lsp_store: Entity<LspStore>,
|
||||
buffer: Entity<Buffer>,
|
||||
server_id: LanguageServerId,
|
||||
cx: AsyncApp,
|
||||
) -> anyhow::Result<Vec<LocationLink>> {
|
||||
location_links_from_lsp(
|
||||
links.map(lsp::GotoDefinitionResponse::Link),
|
||||
lsp_store,
|
||||
buffer,
|
||||
server_id,
|
||||
cx,
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
fn to_proto(&self, project_id: u64, buffer: &Buffer) -> proto::LspExtGoToParentModule {
|
||||
proto::LspExtGoToParentModule {
|
||||
project_id,
|
||||
buffer_id: buffer.remote_id().to_proto(),
|
||||
position: Some(language::proto::serialize_anchor(
|
||||
&buffer.anchor_before(self.position),
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
async fn from_proto(
|
||||
request: Self::ProtoRequest,
|
||||
_: Entity<LspStore>,
|
||||
buffer: Entity<Buffer>,
|
||||
mut cx: AsyncApp,
|
||||
) -> anyhow::Result<Self> {
|
||||
let position = request
|
||||
.position
|
||||
.and_then(deserialize_anchor)
|
||||
.context("bad request with bad position")?;
|
||||
Ok(Self {
|
||||
position: buffer.update(&mut cx, |buffer, _| position.to_point_utf16(buffer))?,
|
||||
})
|
||||
}
|
||||
|
||||
fn response_to_proto(
|
||||
links: Vec<LocationLink>,
|
||||
lsp_store: &mut LspStore,
|
||||
peer_id: PeerId,
|
||||
_: &clock::Global,
|
||||
cx: &mut App,
|
||||
) -> proto::LspExtGoToParentModuleResponse {
|
||||
proto::LspExtGoToParentModuleResponse {
|
||||
links: location_links_to_proto(links, lsp_store, peer_id, cx),
|
||||
}
|
||||
}
|
||||
|
||||
async fn response_from_proto(
|
||||
self,
|
||||
message: proto::LspExtGoToParentModuleResponse,
|
||||
lsp_store: Entity<LspStore>,
|
||||
_: Entity<Buffer>,
|
||||
cx: AsyncApp,
|
||||
) -> anyhow::Result<Vec<LocationLink>> {
|
||||
location_links_from_proto(message.links, lsp_store, cx).await
|
||||
}
|
||||
|
||||
fn buffer_id_from_proto(message: &proto::LspExtGoToParentModule) -> Result<BufferId> {
|
||||
BufferId::new(message.buffer_id)
|
||||
}
|
||||
}
|
||||
|
||||
// https://rust-analyzer.github.io/book/contributing/lsp-extensions.html#runnables
|
||||
// Taken from https://github.com/rust-lang/rust-analyzer/blob/a73a37a757a58b43a796d3eb86a1f7dfd0036659/crates/rust-analyzer/src/lsp/ext.rs#L425-L489
|
||||
pub enum Runnables {}
|
||||
@@ -633,7 +737,7 @@ impl LspCommand for GetLspRunnables {
|
||||
for lsp_runnable in message.runnables {
|
||||
let location = match lsp_runnable.location {
|
||||
Some(location) => {
|
||||
Some(location_link_from_proto(location, &lsp_store, &mut cx).await?)
|
||||
Some(location_link_from_proto(location, lsp_store.clone(), &mut cx).await?)
|
||||
}
|
||||
None => None,
|
||||
};
|
||||
@@ -649,3 +753,33 @@ impl LspCommand for GetLspRunnables {
|
||||
BufferId::new(message.buffer_id)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct LspExtCancelFlycheck {}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct LspExtRunFlycheck {}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct LspExtClearFlycheck {}
|
||||
|
||||
impl lsp::notification::Notification for LspExtCancelFlycheck {
|
||||
type Params = ();
|
||||
const METHOD: &'static str = "rust-analyzer/cancelFlycheck";
|
||||
}
|
||||
|
||||
impl lsp::notification::Notification for LspExtRunFlycheck {
|
||||
type Params = RunFlycheckParams;
|
||||
const METHOD: &'static str = "rust-analyzer/runFlycheck";
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Serialize, Debug)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct RunFlycheckParams {
|
||||
pub text_document: Option<lsp::TextDocumentIdentifier>,
|
||||
}
|
||||
|
||||
impl lsp::notification::Notification for LspExtClearFlycheck {
|
||||
type Params = ();
|
||||
const METHOD: &'static str = "rust-analyzer/clearFlycheck";
|
||||
}
|
||||
|
||||
@@ -1,8 +1,12 @@
|
||||
use ::serde::{Deserialize, Serialize};
|
||||
use gpui::{PromptLevel, WeakEntity};
|
||||
use anyhow::Context as _;
|
||||
use gpui::{App, Entity, PromptLevel, Task, WeakEntity};
|
||||
use lsp::LanguageServer;
|
||||
use rpc::proto;
|
||||
|
||||
use crate::{LanguageServerPromptRequest, LspStore, LspStoreEvent};
|
||||
use crate::{
|
||||
LanguageServerPromptRequest, LspStore, LspStoreEvent, Project, ProjectPath, lsp_store,
|
||||
};
|
||||
|
||||
pub const RUST_ANALYZER_NAME: &str = "rust-analyzer";
|
||||
pub const CARGO_DIAGNOSTICS_SOURCE_NAME: &str = "rustc";
|
||||
@@ -79,3 +83,161 @@ pub fn register_notifications(lsp_store: WeakEntity<LspStore>, language_server:
|
||||
})
|
||||
.detach();
|
||||
}
|
||||
|
||||
pub fn cancel_flycheck(
|
||||
project: Entity<Project>,
|
||||
buffer_path: ProjectPath,
|
||||
cx: &mut App,
|
||||
) -> Task<anyhow::Result<()>> {
|
||||
let upstream_client = project.read(cx).lsp_store().read(cx).upstream_client();
|
||||
let lsp_store = project.read(cx).lsp_store();
|
||||
let buffer = project.update(cx, |project, cx| {
|
||||
project.buffer_store().update(cx, |buffer_store, cx| {
|
||||
buffer_store.open_buffer(buffer_path, cx)
|
||||
})
|
||||
});
|
||||
|
||||
cx.spawn(async move |cx| {
|
||||
let buffer = buffer.await?;
|
||||
let Some(rust_analyzer_server) = project
|
||||
.update(cx, |project, cx| {
|
||||
buffer.update(cx, |buffer, cx| {
|
||||
project.language_server_id_for_name(buffer, RUST_ANALYZER_NAME, cx)
|
||||
})
|
||||
})?
|
||||
.await
|
||||
else {
|
||||
return Ok(());
|
||||
};
|
||||
let buffer_id = buffer.update(cx, |buffer, _| buffer.remote_id().to_proto())?;
|
||||
|
||||
if let Some((client, project_id)) = upstream_client {
|
||||
let request = proto::LspExtCancelFlycheck {
|
||||
project_id,
|
||||
buffer_id,
|
||||
language_server_id: rust_analyzer_server.to_proto(),
|
||||
};
|
||||
client
|
||||
.request(request)
|
||||
.await
|
||||
.context("lsp ext cancel flycheck proto request")?;
|
||||
} else {
|
||||
lsp_store
|
||||
.update(cx, |lsp_store, _| {
|
||||
if let Some(server) = lsp_store.language_server_for_id(rust_analyzer_server) {
|
||||
server.notify::<lsp_store::lsp_ext_command::LspExtCancelFlycheck>(&())?;
|
||||
}
|
||||
anyhow::Ok(())
|
||||
})?
|
||||
.context("lsp ext cancel flycheck")?;
|
||||
};
|
||||
anyhow::Ok(())
|
||||
})
|
||||
}
|
||||
|
||||
pub fn run_flycheck(
|
||||
project: Entity<Project>,
|
||||
buffer_path: ProjectPath,
|
||||
cx: &mut App,
|
||||
) -> Task<anyhow::Result<()>> {
|
||||
let upstream_client = project.read(cx).lsp_store().read(cx).upstream_client();
|
||||
let lsp_store = project.read(cx).lsp_store();
|
||||
let buffer = project.update(cx, |project, cx| {
|
||||
project.buffer_store().update(cx, |buffer_store, cx| {
|
||||
buffer_store.open_buffer(buffer_path, cx)
|
||||
})
|
||||
});
|
||||
|
||||
cx.spawn(async move |cx| {
|
||||
let buffer = buffer.await?;
|
||||
let Some(rust_analyzer_server) = project
|
||||
.update(cx, |project, cx| {
|
||||
buffer.update(cx, |buffer, cx| {
|
||||
project.language_server_id_for_name(buffer, RUST_ANALYZER_NAME, cx)
|
||||
})
|
||||
})?
|
||||
.await
|
||||
else {
|
||||
return Ok(());
|
||||
};
|
||||
let buffer_id = buffer.update(cx, |buffer, _| buffer.remote_id().to_proto())?;
|
||||
|
||||
if let Some((client, project_id)) = upstream_client {
|
||||
let request = proto::LspExtRunFlycheck {
|
||||
project_id,
|
||||
buffer_id,
|
||||
language_server_id: rust_analyzer_server.to_proto(),
|
||||
current_file_only: false,
|
||||
};
|
||||
client
|
||||
.request(request)
|
||||
.await
|
||||
.context("lsp ext run flycheck proto request")?;
|
||||
} else {
|
||||
lsp_store
|
||||
.update(cx, |lsp_store, _| {
|
||||
if let Some(server) = lsp_store.language_server_for_id(rust_analyzer_server) {
|
||||
server.notify::<lsp_store::lsp_ext_command::LspExtRunFlycheck>(
|
||||
&lsp_store::lsp_ext_command::RunFlycheckParams {
|
||||
text_document: None,
|
||||
},
|
||||
)?;
|
||||
}
|
||||
anyhow::Ok(())
|
||||
})?
|
||||
.context("lsp ext run flycheck")?;
|
||||
};
|
||||
anyhow::Ok(())
|
||||
})
|
||||
}
|
||||
|
||||
pub fn clear_flycheck(
|
||||
project: Entity<Project>,
|
||||
buffer_path: ProjectPath,
|
||||
cx: &mut App,
|
||||
) -> Task<anyhow::Result<()>> {
|
||||
let upstream_client = project.read(cx).lsp_store().read(cx).upstream_client();
|
||||
let lsp_store = project.read(cx).lsp_store();
|
||||
let buffer = project.update(cx, |project, cx| {
|
||||
project.buffer_store().update(cx, |buffer_store, cx| {
|
||||
buffer_store.open_buffer(buffer_path, cx)
|
||||
})
|
||||
});
|
||||
|
||||
cx.spawn(async move |cx| {
|
||||
let buffer = buffer.await?;
|
||||
let Some(rust_analyzer_server) = project
|
||||
.update(cx, |project, cx| {
|
||||
buffer.update(cx, |buffer, cx| {
|
||||
project.language_server_id_for_name(buffer, RUST_ANALYZER_NAME, cx)
|
||||
})
|
||||
})?
|
||||
.await
|
||||
else {
|
||||
return Ok(());
|
||||
};
|
||||
let buffer_id = buffer.update(cx, |buffer, _| buffer.remote_id().to_proto())?;
|
||||
|
||||
if let Some((client, project_id)) = upstream_client {
|
||||
let request = proto::LspExtClearFlycheck {
|
||||
project_id,
|
||||
buffer_id,
|
||||
language_server_id: rust_analyzer_server.to_proto(),
|
||||
};
|
||||
client
|
||||
.request(request)
|
||||
.await
|
||||
.context("lsp ext clear flycheck proto request")?;
|
||||
} else {
|
||||
lsp_store
|
||||
.update(cx, |lsp_store, _| {
|
||||
if let Some(server) = lsp_store.language_server_for_id(rust_analyzer_server) {
|
||||
server.notify::<lsp_store::lsp_ext_command::LspExtClearFlycheck>(&())?;
|
||||
}
|
||||
anyhow::Ok(())
|
||||
})?
|
||||
.context("lsp ext clear flycheck")?;
|
||||
};
|
||||
anyhow::Ok(())
|
||||
})
|
||||
}
|
||||
|
||||
@@ -247,20 +247,6 @@ impl LanguageServerTree {
|
||||
self.languages.adapter_for_name(name)
|
||||
}
|
||||
|
||||
pub fn server_id_for_name(&self, name: &LanguageServerName) -> Option<LanguageServerId> {
|
||||
self.instances
|
||||
.values()
|
||||
.flat_map(|instance| instance.roots.values())
|
||||
.flatten()
|
||||
.find_map(|(server_name, (data, _))| {
|
||||
if server_name == name {
|
||||
data.id.get().copied()
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
fn adapters_for_language(
|
||||
&self,
|
||||
settings_location: SettingsLocation,
|
||||
|
||||
@@ -4748,42 +4748,6 @@ impl Project {
|
||||
})
|
||||
}
|
||||
|
||||
pub fn language_server_with_name(
|
||||
&self,
|
||||
name: &str,
|
||||
cx: &App,
|
||||
) -> Task<Option<LanguageServerId>> {
|
||||
if self.is_local() {
|
||||
Task::ready(self.lsp_store.read(cx).language_server_with_name(name, cx))
|
||||
} else if let Some(project_id) = self.remote_id() {
|
||||
let request = self.client.request(proto::LanguageServerIdForName {
|
||||
project_id,
|
||||
buffer_id: None,
|
||||
name: name.to_string(),
|
||||
});
|
||||
cx.background_spawn(async move {
|
||||
let response = request.await.log_err()?;
|
||||
response.server_id.map(LanguageServerId::from_proto)
|
||||
})
|
||||
} else if let Some(ssh_client) = self.ssh_client.as_ref() {
|
||||
let request =
|
||||
ssh_client
|
||||
.read(cx)
|
||||
.proto_client()
|
||||
.request(proto::LanguageServerIdForName {
|
||||
project_id: SSH_PROJECT_ID,
|
||||
buffer_id: None,
|
||||
name: name.to_string(),
|
||||
});
|
||||
cx.background_spawn(async move {
|
||||
let response = request.await.log_err()?;
|
||||
response.server_id.map(LanguageServerId::from_proto)
|
||||
})
|
||||
} else {
|
||||
Task::ready(None)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn language_server_id_for_name(
|
||||
&self,
|
||||
buffer: &Buffer,
|
||||
@@ -4805,7 +4769,7 @@ impl Project {
|
||||
} else if let Some(project_id) = self.remote_id() {
|
||||
let request = self.client.request(proto::LanguageServerIdForName {
|
||||
project_id,
|
||||
buffer_id: Some(buffer.remote_id().to_proto()),
|
||||
buffer_id: buffer.remote_id().to_proto(),
|
||||
name: name.to_string(),
|
||||
});
|
||||
cx.background_spawn(async move {
|
||||
@@ -4819,7 +4783,7 @@ impl Project {
|
||||
.proto_client()
|
||||
.request(proto::LanguageServerIdForName {
|
||||
project_id: SSH_PROJECT_ID,
|
||||
buffer_id: Some(buffer.remote_id().to_proto()),
|
||||
buffer_id: buffer.remote_id().to_proto(),
|
||||
name: name.to_string(),
|
||||
});
|
||||
cx.background_spawn(async move {
|
||||
|
||||
@@ -129,7 +129,7 @@ pub struct InlineDiagnosticsSettings {
|
||||
/// Default: false
|
||||
#[serde(default)]
|
||||
pub enabled: bool,
|
||||
/// Whether to only show the inline diaganostics after a delay after the
|
||||
/// Whether to only show the inline diagnostics after a delay after the
|
||||
/// last editor event.
|
||||
///
|
||||
/// Default: 150
|
||||
@@ -155,37 +155,12 @@ pub struct InlineDiagnosticsSettings {
|
||||
|
||||
#[derive(Clone, Debug, Default, Serialize, Deserialize, JsonSchema)]
|
||||
pub struct CargoDiagnosticsSettings {
|
||||
/// When enabled, Zed runs `cargo check --message-format=json`-based commands and
|
||||
/// collect cargo diagnostics instead of rust-analyzer.
|
||||
/// When enabled, Zed disables rust-analyzer's check on save and starts to query
|
||||
/// Cargo diagnostics separately.
|
||||
///
|
||||
/// Default: false
|
||||
#[serde(default)]
|
||||
pub fetch_cargo_diagnostics: bool,
|
||||
|
||||
/// A command override for fetching the cargo diagnostics.
|
||||
/// First argument is the command, followed by the arguments.
|
||||
///
|
||||
/// Default: ["cargo", "check", "--quiet", "--workspace", "--message-format=json", "--all-targets", "--keep-going"]
|
||||
#[serde(default = "default_diagnostics_fetch_command")]
|
||||
pub diagnostics_fetch_command: Vec<String>,
|
||||
|
||||
/// Extra environment variables to pass to the diagnostics fetch command.
|
||||
///
|
||||
/// Default: {}
|
||||
#[serde(default)]
|
||||
pub env: HashMap<String, String>,
|
||||
}
|
||||
|
||||
fn default_diagnostics_fetch_command() -> Vec<String> {
|
||||
vec![
|
||||
"cargo".to_string(),
|
||||
"check".to_string(),
|
||||
"--quiet".to_string(),
|
||||
"--workspace".to_string(),
|
||||
"--message-format=json".to_string(),
|
||||
"--all-targets".to_string(),
|
||||
"--keep-going".to_string(),
|
||||
]
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, Debug, Serialize, Deserialize, JsonSchema)]
|
||||
|
||||
@@ -179,6 +179,14 @@ impl TaskContexts {
|
||||
})
|
||||
.copied()
|
||||
}
|
||||
|
||||
pub fn task_context_for_worktree_id(&self, worktree_id: WorktreeId) -> Option<&TaskContext> {
|
||||
self.active_worktree_context
|
||||
.iter()
|
||||
.chain(self.other_worktree_contexts.iter())
|
||||
.find(|(id, _)| *id == worktree_id)
|
||||
.map(|(_, context)| context)
|
||||
}
|
||||
}
|
||||
|
||||
impl TaskSourceKind {
|
||||
@@ -206,13 +214,15 @@ impl Inventory {
|
||||
cx.new(|_| Self::default())
|
||||
}
|
||||
|
||||
pub fn list_debug_scenarios(&self, worktree: Option<WorktreeId>) -> Vec<DebugScenario> {
|
||||
pub fn list_debug_scenarios(
|
||||
&self,
|
||||
worktrees: impl Iterator<Item = WorktreeId>,
|
||||
) -> Vec<(TaskSourceKind, DebugScenario)> {
|
||||
let global_scenarios = self.global_debug_scenarios_from_settings();
|
||||
let worktree_scenarios = self.worktree_scenarios_from_settings(worktree);
|
||||
|
||||
worktree_scenarios
|
||||
worktrees
|
||||
.flat_map(|tree_id| self.worktree_scenarios_from_settings(Some(tree_id)))
|
||||
.chain(global_scenarios)
|
||||
.map(|(_, scenario)| scenario)
|
||||
.collect()
|
||||
}
|
||||
|
||||
|
||||
@@ -60,6 +60,8 @@ struct PromptTemplateContext {
|
||||
|
||||
#[serde(flatten)]
|
||||
model: ModelContext,
|
||||
|
||||
has_tools: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
@@ -328,6 +330,7 @@ impl PromptBuilder {
|
||||
let template_context = PromptTemplateContext {
|
||||
project: context.clone(),
|
||||
model: model_context.clone(),
|
||||
has_tools: !model_context.available_tools.is_empty(),
|
||||
};
|
||||
|
||||
self.handlebars
|
||||
|
||||
@@ -75,7 +75,7 @@ message GetPermalinkToLineResponse {
|
||||
|
||||
message Branch {
|
||||
bool is_head = 1;
|
||||
string name = 2;
|
||||
string ref_name = 2;
|
||||
optional uint64 unix_timestamp = 3;
|
||||
optional GitUpstream upstream = 4;
|
||||
optional CommitSummary most_recent_commit = 5;
|
||||
|
||||
@@ -182,6 +182,16 @@ message LspExtSwitchSourceHeaderResponse {
|
||||
string target_file = 1;
|
||||
}
|
||||
|
||||
message LspExtGoToParentModule {
|
||||
uint64 project_id = 1;
|
||||
uint64 buffer_id = 2;
|
||||
Anchor position = 3;
|
||||
}
|
||||
|
||||
message LspExtGoToParentModuleResponse {
|
||||
repeated LocationLink links = 1;
|
||||
}
|
||||
|
||||
message GetCompletionsResponse {
|
||||
repeated Completion completions = 1;
|
||||
repeated VectorClockEntry version = 2;
|
||||
@@ -696,7 +706,7 @@ message LspResponse {
|
||||
|
||||
message LanguageServerIdForName {
|
||||
uint64 project_id = 1;
|
||||
optional uint64 buffer_id = 2;
|
||||
uint64 buffer_id = 2;
|
||||
string name = 3;
|
||||
}
|
||||
|
||||
@@ -718,3 +728,22 @@ message LspRunnable {
|
||||
bytes task_template = 1;
|
||||
optional LocationLink location = 2;
|
||||
}
|
||||
|
||||
message LspExtCancelFlycheck {
|
||||
uint64 project_id = 1;
|
||||
uint64 buffer_id = 2;
|
||||
uint64 language_server_id = 3;
|
||||
}
|
||||
|
||||
message LspExtRunFlycheck {
|
||||
uint64 project_id = 1;
|
||||
uint64 buffer_id = 2;
|
||||
uint64 language_server_id = 3;
|
||||
bool current_file_only = 4;
|
||||
}
|
||||
|
||||
message LspExtClearFlycheck {
|
||||
uint64 project_id = 1;
|
||||
uint64 buffer_id = 2;
|
||||
uint64 language_server_id = 3;
|
||||
}
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user