Compare commits

..

30 Commits

Author SHA1 Message Date
Antonio Scandurra
7ab62666c3 Emit agent's position when streaming edits
Co-authored-by: Ben Brandt <benjamin.j.brandt@gmail.com>
Co-authored-by: Oleksiy Syvokon <oleksiy.syvokon@gmail.com>
2025-05-02 12:11:29 +02:00
Antonio Scandurra
35539847a4 Allow StreamingEditFileTool to also create files (#29785)
Refs #29733 

This pull request introduces a new field to the `StreamingEditFileTool`
that lets the model create or overwrite a file in a streaming way. When
one of the `assistant.stream_edits` setting / `agent-stream-edits`
feature flag is enabled, we are going to disable the `CreateFileTool` so
that the agent model can only use `StreamingEditFileTool` for file
creation.

Release Notes:

- N/A

---------

Co-authored-by: Ben Brandt <benjamin.j.brandt@gmail.com>
Co-authored-by: Oleksiy Syvokon <oleksiy.syvokon@gmail.com>
2025-05-02 09:57:04 +00:00
Anthony Eid
f619d5f02a debugger: Add debug task picker to new session modal (#29702)
## Preview 

![image](https://github.com/user-attachments/assets/203a577f-3b38-4017-9571-de1234415162)


### TODO
- [x] Add scenario picker to new session modal
- [x] Make debugger start action open new session modal instead of task
modal
- [x] Fix `esc` not clearing the cancelling the new session modal while
it's in scenario or attach mode
- [x] Resolve debug scenario's correctly

Release Notes:

- N/A
2025-05-02 08:38:29 +00:00
Kirill Bulatov
ba59305510 Use rust-analyzer's flycheck as source of cargo diagnostics (#29779)
Follow-up of https://github.com/zed-industries/zed/pull/29706

Instead of doing `cargo check` manually, use rust-analyzer's flycheck:
at the cost of more sophisticated check command configuration, we keep
much less code in Zed, and get a proper progress report.

User-facing UI does not change except `diagnostics_fetch_command` and
`env` settings removed from the diagnostics settings.

Release Notes:

- N/A
2025-05-02 10:07:51 +03:00
Nate Butler
672a1dd553 Add Agent Preview trait (#29760)
Like the title says

Release Notes:

- N/A
2025-05-01 23:03:06 -04:00
Marshall Bowers
93cc4946d8 agent: Make thread completion mode non-optional (#29772)
This PR makes the thread completion mode non-optional.

Release Notes:

- N/A
2025-05-02 02:41:54 +00:00
Marshall Bowers
0c0a4ed866 collab: Return increased limit for extended trials from GET /billing/usage (#29771)
This PR updates the `GET /billing/usage` endpoint to return the
increased usage limit for users in the extended trial.

Release Notes:

- N/A
2025-05-02 02:31:30 +00:00
Marshall Bowers
51f1998107 Fix typo in typos.toml (#29770)
This PR fixes a typo in `typos.toml`. How ironic.

Release Notes:

- N/A
2025-05-02 02:01:07 +00:00
Marshall Bowers
1ffedf4a08 collab: Add endpoint for migrating users to new billing (#29769)
This PR adds a new `POST /billing/subscriptions/migrate` endpoint for
migrating users to the new billing system.

When called with a GitHub user ID this endpoint will:

1. Find the active billing subscription for this user (if they have one)
2. Cancel the subscription and send a final invoice
3. Ensure the user is in the `new-billing` and `assistant2` feature
flags

Release Notes:

- N/A
2025-05-02 01:47:09 +00:00
Cole Miller
d25da9728b Run additional checks from script/clippy if local (#29768)
Should cut down on the number of CI cycles if you're forgetful like I
am!

Release Notes:

- N/A
2025-05-02 01:26:12 +00:00
Cole Miller
e1e3f2e423 Improve handling of remote-tracking branches in the picker (#29744)
Release Notes:

- Changed the git branch picker to make remote-tracking branches less
prominent

---------

Co-authored-by: Anthony Eid <hello@anthonyeid.me>
2025-05-01 21:24:26 -04:00
Finn Evers
92b9ecd7d2 agent: Do not render unnecessary lines in edit file tool card (#29766)
This PR prevents any unnecessary lines from being rendered in the edit
file tool card in the case of small diffs.

I think this (hopefully) addresses the last remaining task from
https://github.com/zed-industries/zed/pull/29448.

| `main` | This PR |
| --- | --- |
| <img width="634" alt="main"
src="https://github.com/user-attachments/assets/7c06394e-957a-4d36-a484-5974687041e9"
/> | <img width="634" alt="PR"
src="https://github.com/user-attachments/assets/84206d5a-a93a-4a42-99ca-7cdebb0d91bb"
/> |

(The last empty line in the second image is an empty line present in the
file itself)

---

n the second commit I also preemtively disabled vertical overscrolling
for full mode editors which are sized by content. This is basically the
same fix as in https://github.com/zed-industries/zed/pull/28471.
Strictly speaking, this is not needed for the fix here, but I thought it
might be nice to have for the future to prevent any issues from occuring
due to overscroll.

Release Notes:

- agent: Improved rendering of small diffs for the edit file tool card.
2025-05-01 20:40:12 -03:00
Marshall Bowers
758d260cec collab: Add ability to initiate a checkout session for the Zed Free plan (#29767)
This PR adds the ability to initiate a checkout session for the Zed Free
plan.

Release Notes:

- N/A
2025-05-01 23:35:23 +00:00
Danilo Leal
8d4d3badf3 agent: Add design adjustments to MCP config flow (#29765)
Mostly somewhat small UI tweaks around the MCP extension config flow and
the settings section.

Release Notes:

- N/A
2025-05-01 19:29:59 -03:00
Marshall Bowers
7c23d13773 agent: Render the max mode toggle using a muted color (#29763)
This PR updates the max mode toggle to use the muted color.

This makes it fit in more with the rest of the controls.

<img width="243" alt="Screenshot 2025-05-01 at 5 24 01 PM"
src="https://github.com/user-attachments/assets/57267d29-3c7b-4ea9-b6b9-81c42f6b7e1c"
/>

Release Notes:

- agent: Adjusted the color of the max mode toggle.
2025-05-01 21:40:10 +00:00
Richard Feldman
ad87c545c7 Make context pills clickable while editing (#29740)
Release Notes:

- Fixed a bug where clicking context pills switched into the "editing
message" state instead of clicking the pill.

Co-authored-by: Michael <michael@zed.dev>
Co-authored-by: Ben <ben@zed.dev>
2025-05-01 20:28:54 +00:00
Richard Feldman
23fbab15ee Manual no tool calls (#29745)
Now instead of the model hallucinating tool calls, we get requests for
more context:

<img width="620" alt="Screenshot 2025-05-01 at 12 45 49 PM"
src="https://github.com/user-attachments/assets/847d5c14-82f6-4234-b85a-8cd2bc7ab11d"
/>

It still knows how to answer general questions:
<img width="624" alt="Screenshot 2025-05-01 at 12 47 44 PM"
src="https://github.com/user-attachments/assets/43ab0fc3-4cc8-452f-b26b-474b5d31919f"
/>

Release Notes:

- Fixed the model still trying to do tool calls when no tools selected
(e.g. in `Manual` profile).

---------

Co-authored-by: Ben <ben@zed.dev>
Co-authored-by: Michael <michael@zed.dev>
2025-05-01 16:11:13 -04:00
Richard Feldman
d7e181576e Respect cursor_pointer when a ButtonLike is disabled (#29737)
This is desirable for when we want to use a `ButtonLike` to show a
tooltip over an icon, and we don't want it to show the "not allowed"
cursor on hover.

Release Notes:

- N/A
2025-05-01 15:34:40 -04:00
Eva Pace
9788aff4b1 Fix license symlinks (#29758)
Closes #29527

It looks funny in the diff, but the symlinks are indeed correct:

-
https://github.com/evaporei/zed/blob/fix/license-symlinks/crates/askpass/LICENSE-GPL
-
https://github.com/evaporei/zed/blob/fix/license-symlinks/crates/ui_macros/LICENSE-GPL

I did check all ~170 crates, these were the only inconsistent ones.

Release Notes:

- N/A
2025-05-01 19:24:14 +00:00
Kirill Bulatov
2a319efade Add editor::GoToParentModule for rust-analyzer backed projects (#29755)
Support rust-analyzer's "go to parent module" action


https://rust-analyzer.github.io/book/contributing/lsp-extensions.html#parent-module

Release Notes:

- Added `editor::GoToParentModule` for rust-analyzer backed projects

---------

Co-authored-by: Julia Ryan <juliaryan3.14@gmail.com>
2025-05-01 18:28:05 +00:00
Jonathan LEI
50ec26c163 Fix user rules ignored by agent (#29754)
Closes #29753

The template contains an error: `has_default_user_rules` is always
undefined and should be `has_user_rules` instead.

Release Notes:

- Fixed default user rules ignored during prompt building.
2025-05-01 18:22:48 +00:00
Danilo Leal
39dd133b1c agent: Remove unused agent: chat mode command palette action (#29741)
We weren't using this one anymore. We used to use it for the switch that
toggled tools on, which doesn't exist anymore.

Release Notes:

- N/A

---------

Co-authored-by: Joseph T. Lyons <josephtlyons@gmail.com>
2025-05-01 15:09:14 -03:00
Bennet Bo Fenner
24eb039752 context servers: Show configuration modal when extension is installed (#29309)
WIP

Release Notes:

- N/A

---------

Co-authored-by: Danilo Leal <67129314+danilo-leal@users.noreply.github.com>
Co-authored-by: Danilo Leal <daniloleal09@gmail.com>
Co-authored-by: Marshall Bowers <git@maxdeviant.com>
Co-authored-by: Cole Miller <m@cole-miller.net>
Co-authored-by: Antonio Scandurra <me@as-cii.com>
Co-authored-by: Oleksiy Syvokon <oleksiy.syvokon@gmail.com>
2025-05-01 20:02:14 +02:00
Peter Tripp
bffa53d706 docs: Reorder macOS development documentation (#29751)
Release Notes:

- N/A
2025-05-01 17:34:17 +00:00
Bennet Bo Fenner
0e5e8f9f8d Allow MIT-0 license in checks (#29748)
Part of #29309

The license is on par with other licenses in the list:
https://github.com/aws/mit-0

Release Notes:

- N/A
2025-05-01 17:30:16 +00:00
Danilo Leal
96d785cb45 git: Improve co-author button (#29742)
This PR changes the tooltip label to say "Remove" when you have the
button toggled on and collaborators in the list.

Release Notes:

- N/A

Co-authored-by: Joseph T. Lyons <josephtlyons@gmail.com>
2025-05-01 14:12:52 -03:00
Marshall Bowers
57610c9935 collab: Add billing thresholds to request overage subscription items (#29738)
This PR adds billing thresholds of the unit equivalent of $20 for model
request overages.

Release Notes:

- N/A
2025-05-01 16:10:06 +00:00
Marshall Bowers
5bf1b4f0a8 collab: Add use_new_billing to LlmTokenClaims (#29739)
This PR adds a `use_new_billing` field to the LLM token claims, based on
the `new-billing` feature flag.

Release Notes:

- N/A
2025-05-01 15:43:53 +00:00
Antonio Scandurra
f891dfb358 Introduce a new StreamingEditFileTool (#29733)
This pull request introduces a new tool for streaming edits. The
short-term goal is for this tool to replace the existing `EditFileTool`,
but we want to get this out the door as soon as possible so that we can
start testing it.

`StreamingEditFileTool` is mutually exclusive with `EditFileTool`. It
will be enabled by default for anyone who has the `agent-stream-edits`
feature flag, as well as people that set `assistant.stream_edits` to
`true` in their settings.

### Implementation

Streaming is achieved by requesting a completion while the `edit_file`
tool gets called. We invoke the model by taking the existing
conversation with the agent and appending a prompt specifically tailored
for editing. In that prompt, we ask the model to produce a stream of
`<old_text>`/`<new_text>` tags. As the model streams text in, we
incrementally parse it and start editing as soon as we can.

### Evals

Note that, as part of this pull request, I also defined some new evals
that I used to drive the behavior of the recursive LLM call. To run
them, use this command:

```bash
cargo test --package=assistant_tools --features eval -- eval_extract_handle_command_output
```

Or comment out the `#[cfg_attr(not(feature = "eval"), ignore)]` macro.

I recommend running them one at a time, because right now we don't
really have a way of orchestrating of all these evals. I think we should
invest into that effort once the new agent panel goes live.

Release Notes:

- N/A

---------

Co-authored-by: Nathan Sobo <nathan@zed.dev>
Co-authored-by: Bennet Bo Fenner <bennetbo@gmx.de>
Co-authored-by: Oleksiy Syvokon <oleksiy.syvokon@gmail.com>
2025-05-01 17:37:43 +02:00
Ben Kunkle
e3a2d52472 zlog: Fall back to printing module path instead of *unknown* or just crate name (#29691)
Closes #ISSUE

Release Notes:

- N/A *or* Added/Fixed/Improved ...
2025-05-01 10:59:51 -04:00
142 changed files with 56279 additions and 2082 deletions

172
Cargo.lock generated
View File

@@ -68,6 +68,7 @@ dependencies = [
"convert_case 0.8.0",
"db",
"editor",
"extension",
"feature_flags",
"file_icons",
"fs",
@@ -81,6 +82,7 @@ dependencies = [
"indexmap",
"indoc",
"itertools 0.14.0",
"jsonschema",
"language",
"language_model",
"language_model_selector",
@@ -90,6 +92,7 @@ dependencies = [
"markdown",
"menu",
"multi_buffer",
"notifications",
"ordered-float 2.10.1",
"parking_lot",
"paths",
@@ -106,6 +109,7 @@ dependencies = [
"schemars",
"serde",
"serde_json",
"serde_json_lenient",
"settings",
"smallvec",
"smol",
@@ -148,7 +152,9 @@ checksum = "e89da841a80418a9b391ebaea17f5c112ffaaa96f621d2c285b5174da76b9011"
dependencies = [
"cfg-if",
"const-random",
"getrandom 0.2.15",
"once_cell",
"serde",
"version_check",
"zerocopy 0.7.35",
]
@@ -704,7 +710,9 @@ dependencies = [
name = "assistant_tools"
version = "0.1.0"
dependencies = [
"aho-corasick",
"anyhow",
"assistant_settings",
"assistant_tool",
"buffer_diff",
"chrono",
@@ -712,25 +720,36 @@ dependencies = [
"clock",
"collections",
"component",
"derive_more",
"editor",
"feature_flags",
"fs",
"futures 0.3.31",
"gpui",
"gpui_tokio",
"handlebars 4.5.0",
"html_to_markdown",
"http_client",
"indoc",
"itertools 0.14.0",
"language",
"language_model",
"language_models",
"linkme",
"open",
"pretty_assertions",
"project",
"rand 0.8.5",
"regex",
"reqwest_client",
"rust-embed",
"schemars",
"serde",
"serde_json",
"settings",
"smallvec",
"streaming_diff",
"strsim",
"task",
"tempfile",
"terminal",
@@ -2173,6 +2192,12 @@ dependencies = [
"piper",
]
[[package]]
name = "borrow-or-share"
version = "0.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3eeab4423108c5d7c744f4d234de88d18d636100093ae04caf4825134b9c3a32"
[[package]]
name = "borsh"
version = "1.5.7"
@@ -2288,6 +2313,12 @@ dependencies = [
"syn 1.0.109",
]
[[package]]
name = "bytecount"
version = "0.6.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5ce89b21cab1437276d2650d57e971f9d548a2d9037cc231abdc0562b97498ce"
[[package]]
name = "bytemuck"
version = "1.22.0"
@@ -3208,7 +3239,9 @@ dependencies = [
name = "component_preview"
version = "0.1.0"
dependencies = [
"agent",
"anyhow",
"assistant_tool",
"client",
"collections",
"component",
@@ -3218,6 +3251,7 @@ dependencies = [
"log",
"notifications",
"project",
"prompt_store",
"serde",
"ui",
"ui_input",
@@ -4365,7 +4399,6 @@ name = "diagnostics"
version = "0.1.0"
dependencies = [
"anyhow",
"cargo_metadata",
"client",
"collections",
"component",
@@ -4375,7 +4408,6 @@ dependencies = [
"futures 0.3.31",
"gpui",
"indoc",
"itertools 0.14.0",
"language",
"linkme",
"log",
@@ -4387,7 +4419,6 @@ dependencies = [
"serde",
"serde_json",
"settings",
"smol",
"text",
"theme",
"ui",
@@ -4770,6 +4801,15 @@ dependencies = [
"zeroize",
]
[[package]]
name = "email_address"
version = "0.2.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e079f19b08ca6239f47f8ba8509c11cf3ea30095831f7fed61441475edd8c449"
dependencies = [
"serde",
]
[[package]]
name = "embed-resource"
version = "3.0.2"
@@ -5003,6 +5043,7 @@ dependencies = [
"node_runtime",
"pathdiff",
"paths",
"pretty_assertions",
"project",
"prompt_store",
"regex",
@@ -5184,6 +5225,7 @@ dependencies = [
"collections",
"db",
"editor",
"extension",
"extension_host",
"fs",
"fuzzy",
@@ -5416,6 +5458,17 @@ version = "1.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8bf7cc16383c4b8d58b9905a8509f02926ce3058053c056376248d958c9df1e8"
[[package]]
name = "fluent-uri"
version = "0.3.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1918b65d96df47d3591bed19c5cca17e3fa5d0707318e4b5ef2eae01764df7e5"
dependencies = [
"borrow-or-share",
"ref-cast",
"serde",
]
[[package]]
name = "flume"
version = "0.11.1"
@@ -5570,6 +5623,16 @@ dependencies = [
"percent-encoding",
]
[[package]]
name = "fraction"
version = "0.15.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0f158e3ff0a1b334408dc9fb811cd99b446986f4d8b741bb08f9df1604085ae7"
dependencies = [
"lazy_static",
"num",
]
[[package]]
name = "freetype-sys"
version = "0.20.1"
@@ -6367,6 +6430,7 @@ dependencies = [
"log",
"pest",
"pest_derive",
"rust-embed",
"serde",
"serde_json",
"thiserror 1.0.69",
@@ -7572,6 +7636,33 @@ dependencies = [
"wasm-bindgen",
]
[[package]]
name = "jsonschema"
version = "0.30.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f1b46a0365a611fbf1d2143104dcf910aada96fafd295bab16c60b802bf6fa1d"
dependencies = [
"ahash 0.8.11",
"base64 0.22.1",
"bytecount",
"email_address",
"fancy-regex 0.14.0",
"fraction",
"idna",
"itoa",
"num-cmp",
"num-traits",
"once_cell",
"percent-encoding",
"referencing",
"regex",
"regex-syntax 0.8.5",
"reqwest 0.12.15 (registry+https://github.com/rust-lang/crates.io-index)",
"serde",
"serde_json",
"uuid-simd",
]
[[package]]
name = "jsonwebtoken"
version = "9.3.1"
@@ -8256,7 +8347,7 @@ dependencies = [
"prost 0.9.0",
"prost-build 0.9.0",
"prost-types 0.9.0",
"reqwest 0.12.15",
"reqwest 0.12.15 (git+https://github.com/zed-industries/reqwest.git?rev=951c770a32f1998d6e999cef3e59e0013e6c4415)",
"serde",
"workspace-hack",
]
@@ -9166,6 +9257,12 @@ dependencies = [
"zeroize",
]
[[package]]
name = "num-cmp"
version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "63335b2e2c34fae2fb0aa2cecfd9f0832a1e24b3b32ecec612c3426d46dc8aaa"
[[package]]
name = "num-complex"
version = "0.4.6"
@@ -11759,6 +11856,20 @@ dependencies = [
"syn 2.0.100",
]
[[package]]
name = "referencing"
version = "0.30.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c8eff4fa778b5c2a57e85c5f2fe3a709c52f0e60d23146e2151cbef5893f420e"
dependencies = [
"ahash 0.8.11",
"fluent-uri",
"once_cell",
"parking_lot",
"percent-encoding",
"serde_json",
]
[[package]]
name = "refineable"
version = "0.1.0"
@@ -12028,6 +12139,43 @@ dependencies = [
"winreg 0.50.0",
]
[[package]]
name = "reqwest"
version = "0.12.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d19c46a6fdd48bc4dab94b6103fccc55d34c67cc0ad04653aad4ea2a07cd7bbb"
dependencies = [
"base64 0.22.1",
"bytes 1.10.1",
"futures-channel",
"futures-core",
"futures-util",
"http 1.3.1",
"http-body 1.0.1",
"http-body-util",
"hyper 1.6.0",
"hyper-util",
"ipnet",
"js-sys",
"log",
"mime",
"once_cell",
"percent-encoding",
"pin-project-lite",
"serde",
"serde_json",
"serde_urlencoded",
"sync_wrapper 1.0.2",
"tokio",
"tower 0.5.2",
"tower-service",
"url",
"wasm-bindgen",
"wasm-bindgen-futures",
"web-sys",
"windows-registry 0.4.0",
]
[[package]]
name = "reqwest"
version = "0.12.15"
@@ -12088,7 +12236,7 @@ dependencies = [
"http_client_tls",
"log",
"regex",
"reqwest 0.12.15",
"reqwest 0.12.15 (git+https://github.com/zed-industries/reqwest.git?rev=951c770a32f1998d6e999cef3e59e0013e6c4415)",
"serde",
"smol",
"tokio",
@@ -15939,6 +16087,17 @@ dependencies = [
"sha1_smol",
]
[[package]]
name = "uuid-simd"
version = "0.8.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "23b082222b4f6619906941c17eb2297fff4c2fb96cb60164170522942a200bd8"
dependencies = [
"outref",
"uuid",
"vsimd",
]
[[package]]
name = "v_frame"
version = "0.3.8"
@@ -18032,12 +18191,14 @@ dependencies = [
"getrandom 0.2.15",
"getrandom 0.3.2",
"gimli",
"handlebars 4.5.0",
"hashbrown 0.14.5",
"hashbrown 0.15.2",
"heck 0.4.1",
"hmac",
"hyper 0.14.32",
"hyper-rustls 0.27.5",
"idna",
"indexmap",
"inout",
"itertools 0.12.1",
@@ -18061,6 +18222,7 @@ dependencies = [
"num-bigint-dig",
"num-integer",
"num-iter",
"num-rational",
"num-traits",
"object",
"once_cell",

View File

@@ -462,6 +462,7 @@ indexmap = { version = "2.7.0", features = ["serde"] }
indoc = "2"
inventory = "0.3.19"
itertools = "0.14.0"
jsonschema = "0.30.0"
jsonwebtoken = "9.3"
jupyter-protocol = { git = "https://github.com/ConradIrwin/runtimed", rev = "7130c804216b6914355d15d0b91ea91f6babd734" }
jupyter-websocket-client = { git = "https://github.com/ConradIrwin/runtimed" ,rev = "7130c804216b6914355d15d0b91ea91f6babd734" }

1
assets/icons/hammer.svg Normal file
View File

@@ -0,0 +1 @@
<svg xmlns="http://www.w3.org/2000/svg" width="24" height="24" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round" class="lucide lucide-hammer-icon lucide-hammer"><path d="m15 12-8.373 8.373a1 1 0 1 1-3-3L12 9"/><path d="m18 15 4-4"/><path d="m21.5 11.5-1.914-1.914A2 2 0 0 1 19 8.172V7l-2.26-2.26a6 6 0 0 0-4.202-1.756L9 2.96l.92.82A6.18 6.18 0 0 1 12 8.4V10l2 2h1.172a2 2 0 0 1 1.414.586L18.5 14.5"/></svg>

After

Width:  |  Height:  |  Size: 475 B

View File

@@ -0,0 +1 @@
<svg xmlns="http://www.w3.org/2000/svg" width="24" height="24" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round" class="lucide lucide-user-round-check-icon lucide-user-round-check"><path d="M2 21a8 8 0 0 1 13.292-6"/><circle cx="10" cy="8" r="5"/><path d="m16 19 2 2 4-4"/></svg>

After

Width:  |  Height:  |  Size: 348 B

View File

@@ -248,7 +248,6 @@
"ctrl-shift-o": "agent::ToggleNavigationMenu",
"ctrl-shift-i": "agent::ToggleOptionsMenu",
"shift-escape": "agent::ExpandMessageEditor",
"ctrl-e": "agent::ChatMode",
"ctrl-alt-e": "agent::RemoveAllContext"
}
},
@@ -963,6 +962,14 @@
"escape": "menu::Cancel"
}
},
{
"context": "ConfigureContextServerModal > Editor",
"bindings": {
"escape": "menu::Cancel",
"enter": "editor::Newline",
"ctrl-enter": "menu::Confirm"
}
},
{
"context": "Diagnostics",
"use_key_equivalents": true,

View File

@@ -293,7 +293,6 @@
"cmd-shift-o": "agent::ToggleNavigationMenu",
"cmd-shift-i": "agent::ToggleOptionsMenu",
"shift-escape": "agent::ExpandMessageEditor",
"cmd-e": "agent::ChatMode",
"cmd-alt-e": "agent::RemoveAllContext"
}
},
@@ -1069,6 +1068,15 @@
"escape": "menu::Cancel"
}
},
{
"context": "ConfigureContextServerModal > Editor",
"use_key_equivalents": true,
"bindings": {
"escape": "menu::Cancel",
"enter": "editor::Newline",
"cmd-enter": "menu::Confirm"
}
},
{
"context": "Diagnostics",
"use_key_equivalents": true,

View File

@@ -3,11 +3,12 @@ You are a highly skilled software engineer with extensive knowledge in many prog
## Communication
1. Be conversational but professional.
2. Refer to the USER in the second person and yourself in the first person.
2. Refer to the user in the second person and yourself in the first person.
3. Format your responses in markdown. Use backticks to format file, directory, function, and class names.
4. NEVER lie or make things up.
5. Refrain from apologizing all the time when results are unexpected. Instead, just try your best to proceed or explain the circumstances to the user without apologizing.
{{#if has_tools}}
## Tool Use
1. Make sure to adhere to the tools schema.
@@ -22,6 +23,7 @@ You are a highly skilled software engineer with extensive knowledge in many prog
If you are unsure how to fulfill the user's request, gather more information with tool calls and/or clarifying questions.
{{! TODO: If there are files, we should mention it but otherwise omit that fact }}
{{#if has_tools}}
If appropriate, use tool calls to explore the current project, which contains the following root directories:
{{#each worktrees}}
@@ -36,6 +38,14 @@ If appropriate, use tool calls to explore the current project, which contains th
- As you learn about the structure of the project, use that information to scope `grep` searches to targeted subtrees of the project.
- The user might specify a partial file path. If you don't know the full path, use `find_path` (not `grep`) before you read the file.
{{/if}}
{{/if}}
{{else}}
You are being tasked with providing a response, but you have no ability to use tools or to read or write any aspect of the user's system (other than any context the user might have provided to you).
As such, if you need the user to perform any actions for you, you must request them explicitly. Bias towards giving a response to the best of your ability, and then making requests for the user to take action (e.g. to give you more context) only optionally.
The one exception to this is if the user references something you don't know about - for example, the name of a source code file, function, type, or other piece of code that you have no awareness of. In this case, you MUST NOT MAKE SOMETHING UP, or assume you know what that thing is or how it works. Instead, you must ask the user for clarification rather than giving a response.
{{/if}}
## Code Block Formatting
@@ -111,6 +121,8 @@ In Markdown, hash marks signify headings. For example:
```
</bad_example_do_not_do_this>
This example is unacceptable because the path is in the wrong place. The path must be directly after the opening backticks.
{{#if has_tools}}
## Fixing Diagnostics
1. Make 1-2 attempts at fixing diagnostics, then defer to the user.
@@ -124,10 +136,11 @@ Otherwise, follow debugging best practices:
2. Add descriptive logging statements and error messages to track variable and code state.
3. Add test functions and statements to isolate the problem.
{{/if}}
## Calling External APIs
1. Unless explicitly requested by the user, use the best suited external APIs and packages to solve the task. There is no need to ask the user for permission.
2. When selecting which version of an API or package to use, choose one that is compatible with the user's dependency management file. If no such file exists or if the package is not present, use the latest version that is in your training data.
2. When selecting which version of an API or package to use, choose one that is compatible with the user's dependency management file(s). If no such file exists or if the package is not present, use the latest version that is in your training data.
3. If an external API requires an API Key, be sure to point this out to the user. Adhere to best security practices (e.g. DO NOT hardcode an API key in a place where it can be exposed)
## System Information
@@ -135,10 +148,10 @@ Otherwise, follow debugging best practices:
Operating System: {{os}}
Default Shell: {{shell}}
{{#if (or has_rules has_default_user_rules)}}
{{#if (or has_rules has_user_rules)}}
## User's Custom Instructions
The following additional instructions are provided by the user, and should be followed to the best of your ability without interfering with the tool use guidelines.
The following additional instructions are provided by the user, and should be followed to the best of your ability{{#if has_tools}} without interfering with the tool use guidelines{{/if}}.
{{#if has_rules}}
There are project rules that apply to these root directories:

View File

@@ -657,6 +657,8 @@
},
// When enabled, the agent can run potentially destructive actions without asking for your confirmation.
"always_allow_tool_actions": false,
// When enabled, the agent will stream edits.
"stream_edits": false,
"default_profile": "write",
"profiles": {
"ask": {
@@ -933,22 +935,9 @@
"max_severity": null
},
"rust": {
// When enabled, Zed runs `cargo check --message-format=json`-based commands and
// collect cargo diagnostics instead of rust-analyzer.
"fetch_cargo_diagnostics": false,
// A command override for fetching the cargo diagnostics.
// First argument is the command, followed by the arguments.
"diagnostics_fetch_command": [
"cargo",
"check",
"--quiet",
"--workspace",
"--message-format=json",
"--all-targets",
"--keep-going"
],
// Extra environment variables to pass to the diagnostics fetch command.
"env": {}
// When enabled, Zed disables rust-analyzer's check on save and starts to query
// Cargo diagnostics separately.
"fetch_cargo_diagnostics": false
}
},
// Files or globs of files that will be excluded by Zed entirely. They will be skipped during file

View File

@@ -35,6 +35,7 @@ context_server.workspace = true
convert_case.workspace = true
db.workspace = true
editor.workspace = true
extension.workspace = true
feature_flags.workspace = true
file_icons.workspace = true
fs.workspace = true
@@ -47,6 +48,7 @@ html_to_markdown.workspace = true
http_client.workspace = true
indexmap.workspace = true
itertools.workspace = true
jsonschema.workspace = true
language.workspace = true
language_model.workspace = true
language_model_selector.workspace = true
@@ -56,6 +58,7 @@ lsp.workspace = true
markdown.workspace = true
menu.workspace = true
multi_buffer.workspace = true
notifications.workspace = true
ordered-float.workspace = true
parking_lot.workspace = true
paths.workspace = true
@@ -71,6 +74,7 @@ rope.workspace = true
schemars.workspace = true
serde.workspace = true
serde_json.workspace = true
serde_json_lenient.workspace = true
settings.workspace = true
smallvec.workspace = true
smol.workspace = true

View File

@@ -21,9 +21,9 @@ use editor::scroll::Autoscroll;
use editor::{Editor, EditorElement, EditorEvent, EditorStyle, MultiBuffer};
use gpui::{
AbsoluteLength, Animation, AnimationExt, AnyElement, App, ClickEvent, ClipboardEntry,
ClipboardItem, CursorStyle, DefiniteLength, EdgesRefinement, Empty, Entity, EventEmitter,
Focusable, Hsla, ListAlignment, ListState, MouseButton, PlatformDisplay, ScrollHandle,
Stateful, StyleRefinement, Subscription, Task, TextStyle, TextStyleRefinement, Transformation,
ClipboardItem, DefiniteLength, EdgesRefinement, Empty, Entity, EventEmitter, Focusable, Hsla,
ListAlignment, ListState, MouseButton, PlatformDisplay, ScrollHandle, Stateful,
StyleRefinement, Subscription, Task, TextStyle, TextStyleRefinement, Transformation,
UnderlineStyle, WeakEntity, WindowHandle, linear_color_stop, linear_gradient, list, percentage,
pulsating_between,
};
@@ -45,8 +45,8 @@ use std::time::Duration;
use text::ToPoint;
use theme::ThemeSettings;
use ui::{
ButtonLike, Disclosure, IconButton, KeyBinding, PopoverMenuHandle, Scrollbar, ScrollbarState,
TextSize, Tooltip, prelude::*,
Disclosure, IconButton, KeyBinding, PopoverMenuHandle, Scrollbar, ScrollbarState, TextSize,
Tooltip, prelude::*,
};
use util::ResultExt as _;
use util::markdown::MarkdownCodeBlock;
@@ -360,42 +360,25 @@ fn render_markdown_code_block(
cx,
)),
CodeBlockKind::FencedSrc(path_range) => path_range.path.file_name().map(|file_name| {
let language = parsed_markdown
.languages_by_path
.get(&path_range.path)
.or_else(|| {
path_range
.path
.extension()
.and_then(OsStr::to_str)
.and_then(|str| {
let ext = SharedString::new(str.to_string());
parsed_markdown.languages_by_name.get(&ext)
})
});
// We tell the model to use /dev/null for the path instead of using ```language
// because otherwise it consistently fails to use code citations.
if path_range.path.starts_with("/dev/null") {
let icon = language.and_then(|language| {
language
.config()
.matcher
.path_suffixes
.iter()
.find_map(|extension| {
file_icons::FileIcons::get_icon(Path::new(extension), cx)
})
.map(|icon_path| {
code_block_icon(ix, icon_path, Some(language.name().into()))
})
});
let ext = path_range
.path
.extension()
.and_then(OsStr::to_str)
.map(|str| SharedString::new(str.to_string()))
.unwrap_or_default();
div().children(icon).into_any_element()
render_code_language(
parsed_markdown
.languages_by_path
.get(&path_range.path)
.or_else(|| parsed_markdown.languages_by_name.get(&ext)),
ext,
cx,
)
} else {
let icon = file_icons::FileIcons::get_icon(&path_range.path, cx).map(|icon_path| {
code_block_icon(ix, icon_path, language.map(|lang| lang.name().into()))
});
let content = if let Some(parent) = path_range.path.parent() {
h_flex()
.ml_1()
@@ -428,11 +411,19 @@ fn render_markdown_code_block(
.hover(|item| item.bg(cx.theme().colors().element_hover.opacity(0.5)))
.tooltip(Tooltip::text("Jump to File"))
.child(
h_flex().gap_0p5().children(icon).child(content).child(
Icon::new(IconName::ArrowUpRight)
.size(IconSize::XSmall)
.color(Color::Ignored),
),
h_flex()
.gap_0p5()
.children(
file_icons::FileIcons::get_icon(&path_range.path, cx)
.map(Icon::from_path)
.map(|icon| icon.color(Color::Muted).size(IconSize::XSmall)),
)
.child(content)
.child(
Icon::new(IconName::ArrowUpRight)
.size(IconSize::XSmall)
.color(Color::Ignored),
),
)
.on_click({
let path_range = path_range.clone();
@@ -621,26 +612,6 @@ fn render_markdown_code_block(
)
}
fn code_block_icon(
ix: usize,
icon_path: SharedString,
tooltip: Option<SharedString>,
) -> ButtonLike {
let without_tooltip = ButtonLike::new(("code_block_icon", ix))
.disabled(true)
.cursor_style(CursorStyle::Arrow)
.child(
Icon::from_path(icon_path)
.color(Color::Muted)
.size(IconSize::XSmall),
);
match tooltip {
Some(tooltip) => without_tooltip.tooltip(Tooltip::text(tooltip)),
None => without_tooltip,
}
}
fn render_code_language(
language: Option<&Arc<Language>>,
name_fallback: SharedString,

View File

@@ -6,6 +6,7 @@ mod assistant_panel;
mod buffer_codegen;
mod context;
mod context_picker;
mod context_server_configuration;
mod context_store;
mod context_strip;
mod history_store;
@@ -30,6 +31,7 @@ use command_palette_hooks::CommandPaletteFilter;
use feature_flags::{Assistant2FeatureFlag, FeatureFlagAppExt};
use fs::Fs;
use gpui::{App, actions, impl_actions};
use language::LanguageRegistry;
use prompt_store::PromptBuilder;
use schemars::JsonSchema;
use serde::Deserialize;
@@ -44,6 +46,8 @@ pub use crate::inline_assistant::InlineAssistant;
pub use crate::thread::{Message, MessageSegment, Thread, ThreadEvent};
pub use crate::thread_store::ThreadStore;
pub use agent_diff::{AgentDiff, AgentDiffToolbar};
pub use context_store::ContextStore;
pub use ui::{all_agent_previews, get_agent_preview};
actions!(
agent,
@@ -60,7 +64,6 @@ actions!(
AddContextServer,
RemoveSelectedThread,
Chat,
ChatMode,
CycleNextInlineAssist,
CyclePreviousInlineAssist,
FocusUp,
@@ -107,11 +110,13 @@ pub fn init(
fs: Arc<dyn Fs>,
client: Arc<Client>,
prompt_builder: Arc<PromptBuilder>,
language_registry: Arc<LanguageRegistry>,
cx: &mut App,
) {
AssistantSettings::register(cx);
thread_store::init(cx);
assistant_panel::init(cx);
context_server_configuration::init(language_registry, cx);
inline_assistant::init(
fs.clone(),

View File

@@ -1,16 +1,18 @@
mod add_context_server_modal;
mod configure_context_server_modal;
mod manage_profiles_modal;
mod tool_picker;
use std::sync::Arc;
use std::{sync::Arc, time::Duration};
use assistant_settings::AssistantSettings;
use assistant_tool::{ToolSource, ToolWorkingSet};
use collections::HashMap;
use context_server::manager::ContextServerManager;
use context_server::manager::{ContextServer, ContextServerManager, ContextServerStatus};
use fs::Fs;
use gpui::{
Action, AnyView, App, Entity, EventEmitter, FocusHandle, Focusable, ScrollHandle, Subscription,
Action, Animation, AnimationExt as _, AnyView, App, Entity, EventEmitter, FocusHandle,
Focusable, ScrollHandle, Subscription, pulsating_between,
};
use language_model::{LanguageModelProvider, LanguageModelProviderId, LanguageModelRegistry};
use settings::{Settings, update_settings_file};
@@ -22,6 +24,7 @@ use util::ResultExt as _;
use zed_actions::ExtensionCategoryFilter;
pub(crate) use add_context_server_modal::AddContextServerModal;
pub(crate) use configure_context_server_modal::ConfigureContextServerModal;
pub(crate) use manage_profiles_modal::ManageProfilesModal;
use crate::AddContextServer;
@@ -254,10 +257,12 @@ impl AssistantConfiguration {
)
}
fn render_context_servers_section(&mut self, cx: &mut Context<Self>) -> impl IntoElement {
fn render_context_servers_section(
&mut self,
window: &mut Window,
cx: &mut Context<Self>,
) -> impl IntoElement {
let context_servers = self.context_server_manager.read(cx).all_servers().clone();
let tools_by_source = self.tools.read(cx).tools_by_source(cx);
let empty = Vec::new();
const SUBHEADING: &str = "Connect to context servers via the Model Context Protocol either via Zed extensions or directly.";
@@ -272,136 +277,11 @@ impl AssistantConfiguration {
.child(Headline::new("Model Context Protocol (MCP) Servers"))
.child(Label::new(SUBHEADING).color(Color::Muted)),
)
.children(context_servers.into_iter().map(|context_server| {
let is_running = context_server.client().is_some();
let are_tools_expanded = self
.expanded_context_server_tools
.get(&context_server.id())
.copied()
.unwrap_or_default();
let tools = tools_by_source
.get(&ToolSource::ContextServer {
id: context_server.id().into(),
})
.unwrap_or_else(|| &empty);
let tool_count = tools.len();
v_flex()
.id(SharedString::from(context_server.id()))
.border_1()
.rounded_md()
.border_color(cx.theme().colors().border)
.bg(cx.theme().colors().background.opacity(0.25))
.child(
h_flex()
.p_1()
.justify_between()
.when(are_tools_expanded && tool_count > 1, |element| {
element
.border_b_1()
.border_color(cx.theme().colors().border)
})
.child(
h_flex()
.gap_2()
.child(
Disclosure::new("tool-list-disclosure", are_tools_expanded)
.disabled(tool_count == 0)
.on_click(cx.listener({
let context_server_id = context_server.id();
move |this, _event, _window, _cx| {
let is_open = this
.expanded_context_server_tools
.entry(context_server_id.clone())
.or_insert(false);
*is_open = !*is_open;
}
})),
)
.child(Indicator::dot().color(if is_running {
Color::Success
} else {
Color::Error
}))
.child(Label::new(context_server.id()))
.child(
Label::new(format!("{tool_count} tools"))
.color(Color::Muted)
.size(LabelSize::Small),
),
)
.child(
Switch::new("context-server-switch", is_running.into())
.color(SwitchColor::Accent)
.on_click({
let context_server_manager =
self.context_server_manager.clone();
let context_server = context_server.clone();
move |state, _window, cx| match state {
ToggleState::Unselected
| ToggleState::Indeterminate => {
context_server_manager.update(cx, |this, cx| {
this.stop_server(context_server.clone(), cx)
.log_err();
});
}
ToggleState::Selected => {
cx.spawn({
let context_server_manager =
context_server_manager.clone();
let context_server = context_server.clone();
async move |cx| {
if let Some(start_server_task) =
context_server_manager
.update(cx, |this, cx| {
this.start_server(
context_server,
cx,
)
})
.log_err()
{
start_server_task.await.log_err();
}
}
})
.detach();
}
}
}),
),
)
.map(|parent| {
if !are_tools_expanded {
return parent;
}
parent.child(v_flex().py_1p5().px_1().gap_1().children(
tools.into_iter().enumerate().map(|(ix, tool)| {
h_flex()
.id(("tool-item", ix))
.px_1()
.gap_2()
.justify_between()
.hover(|style| style.bg(cx.theme().colors().element_hover))
.rounded_sm()
.child(
Label::new(tool.name())
.buffer_font(cx)
.size(LabelSize::Small),
)
.child(
Icon::new(IconName::Info)
.size(IconSize::Small)
.color(Color::Ignored),
)
.tooltip(Tooltip::text(tool.description()))
}),
))
})
}))
.children(
context_servers
.into_iter()
.map(|context_server| self.render_context_server(context_server, window, cx)),
)
.child(
h_flex()
.justify_between()
@@ -429,7 +309,7 @@ impl AssistantConfiguration {
.style(ButtonStyle::Filled)
.layer(ElevationIndex::ModalSurface)
.full_width()
.icon(IconName::DatabaseZap)
.icon(IconName::Hammer)
.icon_size(IconSize::Small)
.icon_position(IconPosition::Start)
.on_click(|_event, window, cx| {
@@ -447,10 +327,214 @@ impl AssistantConfiguration {
),
)
}
fn render_context_server(
&self,
context_server: Arc<ContextServer>,
window: &mut Window,
cx: &mut Context<Self>,
) -> impl use<> + IntoElement {
let tools_by_source = self.tools.read(cx).tools_by_source(cx);
let server_status = self
.context_server_manager
.read(cx)
.status_for_server(&context_server.id());
let is_running = matches!(server_status, Some(ContextServerStatus::Running));
let error = if let Some(ContextServerStatus::Error(error)) = server_status.clone() {
Some(error)
} else {
None
};
let are_tools_expanded = self
.expanded_context_server_tools
.get(&context_server.id())
.copied()
.unwrap_or_default();
let tools = tools_by_source
.get(&ToolSource::ContextServer {
id: context_server.id().into(),
})
.map_or([].as_slice(), |tools| tools.as_slice());
let tool_count = tools.len();
let border_color = cx.theme().colors().border.opacity(0.6);
v_flex()
.id(SharedString::from(context_server.id()))
.border_1()
.rounded_md()
.border_color(border_color)
.bg(cx.theme().colors().background.opacity(0.2))
.overflow_hidden()
.child(
h_flex()
.p_1()
.justify_between()
.when(
error.is_some() || are_tools_expanded && tool_count > 1,
|element| element.border_b_1().border_color(border_color),
)
.child(
h_flex()
.gap_1p5()
.child(
Disclosure::new(
"tool-list-disclosure",
are_tools_expanded || error.is_some(),
)
.disabled(tool_count == 0)
.on_click(cx.listener({
let context_server_id = context_server.id();
move |this, _event, _window, _cx| {
let is_open = this
.expanded_context_server_tools
.entry(context_server_id.clone())
.or_insert(false);
*is_open = !*is_open;
}
})),
)
.child(match server_status {
Some(ContextServerStatus::Starting) => {
let color = Color::Success.color(cx);
Indicator::dot()
.color(Color::Success)
.with_animation(
SharedString::from(format!(
"{}-starting",
context_server.id(),
)),
Animation::new(Duration::from_secs(2))
.repeat()
.with_easing(pulsating_between(0.4, 1.)),
move |this, delta| {
this.color(color.alpha(delta).into())
},
)
.into_any_element()
}
Some(ContextServerStatus::Running) => {
Indicator::dot().color(Color::Success).into_any_element()
}
Some(ContextServerStatus::Error(_)) => {
Indicator::dot().color(Color::Error).into_any_element()
}
None => Indicator::dot().color(Color::Muted).into_any_element(),
})
.child(Label::new(context_server.id()).ml_0p5())
.when(is_running, |this| {
this.child(
Label::new(if tool_count == 1 {
SharedString::from("1 tool")
} else {
SharedString::from(format!("{} tools", tool_count))
})
.color(Color::Muted)
.size(LabelSize::Small),
)
}),
)
.child(
Switch::new("context-server-switch", is_running.into())
.color(SwitchColor::Accent)
.on_click({
let context_server_manager = self.context_server_manager.clone();
let context_server = context_server.clone();
move |state, _window, cx| match state {
ToggleState::Unselected | ToggleState::Indeterminate => {
context_server_manager.update(cx, |this, cx| {
this.stop_server(context_server.clone(), cx).log_err();
});
}
ToggleState::Selected => {
cx.spawn({
let context_server_manager =
context_server_manager.clone();
let context_server = context_server.clone();
async move |cx| {
if let Some(start_server_task) =
context_server_manager
.update(cx, |this, cx| {
this.start_server(context_server, cx)
})
.log_err()
{
start_server_task.await.log_err();
}
}
})
.detach();
}
}
}),
),
)
.map(|parent| {
if let Some(error) = error {
return parent.child(
h_flex()
.p_2()
.gap_2()
.items_start()
.child(
h_flex()
.flex_none()
.h(window.line_height() / 1.6_f32)
.justify_center()
.child(
Icon::new(IconName::XCircle)
.size(IconSize::XSmall)
.color(Color::Error),
),
)
.child(
div().w_full().child(
Label::new(error)
.buffer_font(cx)
.color(Color::Muted)
.size(LabelSize::Small),
),
),
);
}
if !are_tools_expanded || tools.is_empty() {
return parent;
}
parent.child(v_flex().py_1p5().px_1().gap_1().children(
tools.into_iter().enumerate().map(|(ix, tool)| {
h_flex()
.id(("tool-item", ix))
.px_1()
.gap_2()
.justify_between()
.hover(|style| style.bg(cx.theme().colors().element_hover))
.rounded_sm()
.child(
Label::new(tool.name())
.buffer_font(cx)
.size(LabelSize::Small),
)
.child(
Icon::new(IconName::Info)
.size(IconSize::Small)
.color(Color::Ignored),
)
.tooltip(Tooltip::text(tool.description()))
}),
))
})
}
}
impl Render for AssistantConfiguration {
fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
fn render(&mut self, window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
v_flex()
.id("assistant-configuration")
.key_context("AgentConfiguration")
@@ -467,7 +551,7 @@ impl Render for AssistantConfiguration {
.overflow_y_scroll()
.child(self.render_command_permission(cx))
.child(Divider::horizontal().color(DividerColor::Border))
.child(self.render_context_servers_section(cx))
.child(self.render_context_servers_section(window, cx))
.child(Divider::horizontal().color(DividerColor::Border))
.child(self.render_provider_configuration_section(cx)),
)

View File

@@ -0,0 +1,443 @@
use std::{
sync::{Arc, Mutex},
time::Duration,
};
use anyhow::Context as _;
use context_server::manager::{ContextServerManager, ContextServerStatus};
use editor::{Editor, EditorElement, EditorStyle};
use extension::ContextServerConfiguration;
use gpui::{
Animation, AnimationExt, App, DismissEvent, Entity, EventEmitter, FocusHandle, Focusable, Task,
TextStyle, TextStyleRefinement, Transformation, UnderlineStyle, WeakEntity, percentage,
};
use language::{Language, LanguageRegistry};
use markdown::{Markdown, MarkdownElement, MarkdownStyle};
use notifications::status_toast::{StatusToast, ToastIcon};
use settings::{Settings as _, update_settings_file};
use theme::ThemeSettings;
use ui::{KeyBinding, Modal, ModalFooter, ModalHeader, Section, prelude::*};
use util::ResultExt;
use workspace::{ModalView, Workspace};
pub(crate) struct ConfigureContextServerModal {
workspace: WeakEntity<Workspace>,
context_servers_to_setup: Vec<ConfigureContextServer>,
context_server_manager: Entity<ContextServerManager>,
}
struct ConfigureContextServer {
id: Arc<str>,
installation_instructions: Entity<markdown::Markdown>,
settings_validator: Option<jsonschema::Validator>,
settings_editor: Entity<Editor>,
last_error: Option<SharedString>,
waiting_for_context_server: bool,
}
impl ConfigureContextServerModal {
pub fn new(
configurations: impl Iterator<Item = (Arc<str>, ContextServerConfiguration)>,
jsonc_language: Option<Arc<Language>>,
context_server_manager: Entity<ContextServerManager>,
language_registry: Arc<LanguageRegistry>,
workspace: WeakEntity<Workspace>,
window: &mut Window,
cx: &mut App,
) -> Option<Self> {
let context_servers_to_setup = configurations
.map(|(id, manifest)| {
let jsonc_language = jsonc_language.clone();
let settings_validator = jsonschema::validator_for(&manifest.settings_schema)
.context("Failed to load JSON schema for context server settings")
.log_err();
ConfigureContextServer {
id: id.clone(),
installation_instructions: cx.new(|cx| {
Markdown::new(
manifest.installation_instructions.clone().into(),
Some(language_registry.clone()),
None,
cx,
)
}),
settings_validator,
settings_editor: cx.new(|cx| {
let mut editor = Editor::auto_height(16, window, cx);
editor.set_text(manifest.default_settings.trim(), window, cx);
if let Some(buffer) = editor.buffer().read(cx).as_singleton() {
buffer.update(cx, |buffer, cx| buffer.set_language(jsonc_language, cx))
}
editor
}),
waiting_for_context_server: false,
last_error: None,
}
})
.collect::<Vec<_>>();
if context_servers_to_setup.is_empty() {
return None;
}
Some(Self {
workspace,
context_servers_to_setup,
context_server_manager,
})
}
}
impl ConfigureContextServerModal {
pub fn confirm(&mut self, cx: &mut Context<Self>) {
if self.context_servers_to_setup.is_empty() {
return;
}
let Some(workspace) = self.workspace.upgrade() else {
return;
};
let configuration = &mut self.context_servers_to_setup[0];
if configuration.waiting_for_context_server {
return;
}
let settings_value = match serde_json_lenient::from_str::<serde_json::Value>(
&configuration.settings_editor.read(cx).text(cx),
) {
Ok(value) => value,
Err(error) => {
configuration.last_error = Some(error.to_string().into());
cx.notify();
return;
}
};
if let Some(validator) = configuration.settings_validator.as_ref() {
if let Err(error) = validator.validate(&settings_value) {
configuration.last_error = Some(error.to_string().into());
cx.notify();
return;
}
}
let id = configuration.id.clone();
let settings_changed = context_server::ContextServerSettings::get_global(cx)
.context_servers
.get(&id)
.map_or(true, |config| {
config.settings.as_ref() != Some(&settings_value)
});
let is_running = self.context_server_manager.read(cx).status_for_server(&id)
== Some(ContextServerStatus::Running);
if !settings_changed && is_running {
self.complete_setup(id, cx);
return;
}
configuration.waiting_for_context_server = true;
let task = wait_for_context_server(&self.context_server_manager, id.clone(), cx);
cx.spawn({
let id = id.clone();
async move |this, cx| {
let result = task.await;
this.update(cx, |this, cx| match result {
Ok(_) => {
this.complete_setup(id, cx);
}
Err(err) => {
if let Some(configuration) = this.context_servers_to_setup.get_mut(0) {
configuration.last_error = Some(err.into());
configuration.waiting_for_context_server = false;
} else {
this.dismiss(cx);
}
cx.notify();
}
})
}
})
.detach();
// When we write the settings to the file, the context server will be restarted.
update_settings_file::<context_server::ContextServerSettings>(
workspace.read(cx).app_state().fs.clone(),
cx,
{
let id = id.clone();
|settings, _| {
if let Some(server_config) = settings.context_servers.get_mut(&id) {
server_config.settings = Some(settings_value);
} else {
settings.context_servers.insert(
id,
context_server::ServerConfig {
settings: Some(settings_value),
..Default::default()
},
);
}
}
},
);
}
fn complete_setup(&mut self, id: Arc<str>, cx: &mut Context<Self>) {
self.context_servers_to_setup.remove(0);
cx.notify();
if !self.context_servers_to_setup.is_empty() {
return;
}
self.workspace
.update(cx, {
|workspace, cx| {
let status_toast = StatusToast::new(
format!("{} configured successfully.", id),
cx,
|this, _cx| {
this.icon(ToastIcon::new(IconName::Hammer).color(Color::Muted))
.action("Dismiss", |_, _| {})
},
);
workspace.toggle_status_toast(status_toast, cx);
}
})
.log_err();
self.dismiss(cx);
}
fn dismiss(&self, cx: &mut Context<Self>) {
cx.emit(DismissEvent);
}
}
fn wait_for_context_server(
context_server_manager: &Entity<ContextServerManager>,
context_server_id: Arc<str>,
cx: &mut App,
) -> Task<Result<(), Arc<str>>> {
let (tx, rx) = futures::channel::oneshot::channel();
let tx = Arc::new(Mutex::new(Some(tx)));
let subscription = cx.subscribe(context_server_manager, move |_, event, _cx| match event {
context_server::manager::Event::ServerStatusChanged { server_id, status } => match status {
Some(ContextServerStatus::Running) => {
if server_id == &context_server_id {
if let Some(tx) = tx.lock().unwrap().take() {
let _ = tx.send(Ok(()));
}
}
}
Some(ContextServerStatus::Error(error)) => {
if server_id == &context_server_id {
if let Some(tx) = tx.lock().unwrap().take() {
let _ = tx.send(Err(error.clone()));
}
}
}
_ => {}
},
});
cx.spawn(async move |_cx| {
let result = rx.await.unwrap();
drop(subscription);
result
})
}
impl Render for ConfigureContextServerModal {
fn render(&mut self, window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
let Some(configuration) = self.context_servers_to_setup.first() else {
return div().child("No context servers to setup");
};
let focus_handle = self.focus_handle(cx);
div()
.elevation_3(cx)
.w(rems(34.))
.key_context("ConfigureContextServerModal")
.on_action(cx.listener(|this, _: &menu::Confirm, _window, cx| this.confirm(cx)))
.on_action(cx.listener(|this, _: &menu::Cancel, _window, cx| this.dismiss(cx)))
.capture_any_mouse_down(cx.listener(|this, _, window, cx| {
this.focus_handle(cx).focus(window);
}))
.child(
Modal::new("configure-context-server", None)
.header(ModalHeader::new().headline(format!("Configure {}", configuration.id)))
.section(
Section::new()
.child(div().pb_2().text_sm().child(MarkdownElement::new(
configuration.installation_instructions.clone(),
default_markdown_style(window, cx),
)))
.child(
div()
.p_2()
.rounded_md()
.border_1()
.border_color(cx.theme().colors().border_variant)
.bg(cx.theme().colors().editor_background)
.gap_1()
.child({
let settings = ThemeSettings::get_global(cx);
let text_style = TextStyle {
color: cx.theme().colors().text,
font_family: settings.buffer_font.family.clone(),
font_fallbacks: settings.buffer_font.fallbacks.clone(),
font_size: settings.buffer_font_size(cx).into(),
font_weight: settings.buffer_font.weight,
line_height: relative(
settings.buffer_line_height.value(),
),
..Default::default()
};
EditorElement::new(
&configuration.settings_editor,
EditorStyle {
background: cx.theme().colors().editor_background,
local_player: cx.theme().players().local(),
text: text_style,
syntax: cx.theme().syntax().clone(),
..Default::default()
},
)
})
.when_some(configuration.last_error.clone(), |this, error| {
this.child(
h_flex()
.gap_2()
.px_2()
.py_1()
.child(
Icon::new(IconName::Warning)
.size(IconSize::XSmall)
.color(Color::Warning),
)
.child(
div().w_full().child(
Label::new(error)
.size(LabelSize::Small)
.color(Color::Muted),
),
),
)
}),
)
.when(configuration.waiting_for_context_server, |this| {
this.child(
h_flex()
.gap_1p5()
.child(
Icon::new(IconName::ArrowCircle)
.size(IconSize::XSmall)
.color(Color::Info)
.with_animation(
"arrow-circle",
Animation::new(Duration::from_secs(2)).repeat(),
|icon, delta| {
icon.transform(Transformation::rotate(
percentage(delta),
))
},
)
.into_any_element(),
)
.child(
Label::new("Waiting for Context Server")
.size(LabelSize::Small)
.color(Color::Muted),
),
)
}),
)
.footer(
ModalFooter::new().end_slot(
h_flex()
.gap_1()
.child(
Button::new("cancel", "Cancel")
.key_binding(
KeyBinding::for_action_in(
&menu::Cancel,
&focus_handle,
window,
cx,
)
.map(|kb| kb.size(rems_from_px(12.))),
)
.on_click(cx.listener(|this, _event, _window, cx| {
this.dismiss(cx)
})),
)
.child(
Button::new("configure-server", "Configure MCP")
.disabled(configuration.waiting_for_context_server)
.key_binding(
KeyBinding::for_action_in(
&menu::Confirm,
&focus_handle,
window,
cx,
)
.map(|kb| kb.size(rems_from_px(12.))),
)
.on_click(cx.listener(|this, _event, _window, cx| {
this.confirm(cx)
})),
),
),
),
)
}
}
pub(crate) fn default_markdown_style(window: &Window, cx: &App) -> MarkdownStyle {
let theme_settings = ThemeSettings::get_global(cx);
let colors = cx.theme().colors();
let mut text_style = window.text_style();
text_style.refine(&TextStyleRefinement {
font_family: Some(theme_settings.ui_font.family.clone()),
font_fallbacks: theme_settings.ui_font.fallbacks.clone(),
font_features: Some(theme_settings.ui_font.features.clone()),
font_size: Some(TextSize::XSmall.rems(cx).into()),
color: Some(colors.text_muted),
..Default::default()
});
MarkdownStyle {
base_text_style: text_style.clone(),
selection_background_color: cx.theme().players().local().selection,
link: TextStyleRefinement {
background_color: Some(colors.editor_foreground.opacity(0.025)),
underline: Some(UnderlineStyle {
color: Some(colors.text_accent.opacity(0.5)),
thickness: px(1.),
..Default::default()
}),
..Default::default()
},
..Default::default()
}
}
impl ModalView for ConfigureContextServerModal {}
impl EventEmitter<DismissEvent> for ConfigureContextServerModal {}
impl Focusable for ConfigureContextServerModal {
fn focus_handle(&self, cx: &App) -> FocusHandle {
if let Some(current) = self.context_servers_to_setup.first() {
current.settings_editor.read(cx).focus_handle(cx)
} else {
cx.focus_handle()
}
}
}

View File

@@ -843,7 +843,7 @@ pub fn load_context(
text.push_str(
"\n<context>\n\
The following items were attached by the user. \
You don't need to use other tools to read them.\n\n",
They are up-to-date and don't need to be re-read.\n\n",
);
if !file_context.is_empty() {

View File

@@ -0,0 +1,120 @@
use std::sync::Arc;
use anyhow::Context as _;
use context_server::ContextServerDescriptorRegistry;
use extension::ExtensionManifest;
use language::LanguageRegistry;
use ui::prelude::*;
use util::ResultExt;
use workspace::Workspace;
use crate::{AssistantPanel, assistant_configuration::ConfigureContextServerModal};
pub(crate) fn init(language_registry: Arc<LanguageRegistry>, cx: &mut App) {
cx.observe_new(move |_: &mut Workspace, window, cx| {
let Some(window) = window else {
return;
};
if let Some(extension_events) = extension::ExtensionEvents::try_global(cx).as_ref() {
cx.subscribe_in(extension_events, window, {
let language_registry = language_registry.clone();
move |workspace, _, event, window, cx| match event {
extension::Event::ExtensionInstalled(manifest) => {
show_configure_mcp_modal(
language_registry.clone(),
manifest,
workspace,
window,
cx,
);
}
extension::Event::ConfigureExtensionRequested(manifest) => {
if !manifest.context_servers.is_empty() {
show_configure_mcp_modal(
language_registry.clone(),
manifest,
workspace,
window,
cx,
);
}
}
_ => {}
}
})
.detach();
} else {
log::info!(
"No extension events global found. Skipping context server configuration wizard"
);
}
})
.detach();
}
fn show_configure_mcp_modal(
language_registry: Arc<LanguageRegistry>,
manifest: &Arc<ExtensionManifest>,
workspace: &mut Workspace,
window: &mut Window,
cx: &mut Context<'_, Workspace>,
) {
let Some(context_server_manager) = workspace.panel::<AssistantPanel>(cx).map(|panel| {
panel
.read(cx)
.thread_store()
.read(cx)
.context_server_manager()
}) else {
return;
};
let registry = ContextServerDescriptorRegistry::global(cx).read(cx);
let project = workspace.project().clone();
let configuration_tasks = manifest
.context_servers
.keys()
.cloned()
.filter_map({
|key| {
let descriptor = registry.context_server_descriptor(&key)?;
Some(cx.spawn({
let project = project.clone();
async move |_, cx| {
descriptor
.configuration(project, &cx)
.await
.context("Failed to resolve context server configuration")
.log_err()
.flatten()
.map(|config| (key, config))
}
}))
}
})
.collect::<Vec<_>>();
let jsonc_language = language_registry.language_for_name("jsonc");
cx.spawn_in(window, async move |this, cx| {
let descriptors = futures::future::join_all(configuration_tasks).await;
let jsonc_language = jsonc_language.await.ok();
this.update_in(cx, |this, window, cx| {
let modal = ConfigureContextServerModal::new(
descriptors.into_iter().flatten(),
jsonc_language,
context_server_manager,
language_registry,
cx.entity().downgrade(),
window,
cx,
);
if let Some(modal) = modal {
this.toggle_modal(window, cx, |_, _| modal);
}
})
})
.detach();
}

View File

@@ -4,7 +4,7 @@ use std::sync::Arc;
use crate::assistant_model_selector::{AssistantModelSelector, ModelType};
use crate::context::{ContextLoadResult, load_context};
use crate::tool_compatibility::{IncompatibleToolsState, IncompatibleToolsTooltip};
use crate::ui::AnimatedLabel;
use crate::ui::{AgentPreview, AnimatedLabel};
use buffer_diff::BufferDiff;
use collections::HashSet;
use editor::actions::{MoveUp, Paste};
@@ -42,10 +42,11 @@ use crate::profile_selector::ProfileSelector;
use crate::thread::{Thread, TokenUsageRatio};
use crate::thread_store::ThreadStore;
use crate::{
AgentDiff, Chat, ChatMode, ExpandMessageEditor, NewThread, OpenAgentDiff, RemoveAllContext,
ToggleContextPicker, ToggleProfileSelector,
ActiveThread, AgentDiff, Chat, ExpandMessageEditor, NewThread, OpenAgentDiff, RemoveAllContext,
ToggleContextPicker, ToggleProfileSelector, register_agent_preview,
};
#[derive(RegisterComponent)]
pub struct MessageEditor {
thread: Entity<Thread>,
incompatible_tools_state: Entity<IncompatibleToolsState>,
@@ -206,10 +207,6 @@ impl MessageEditor {
&self.context_store
}
fn toggle_chat_mode(&mut self, _: &ChatMode, _window: &mut Window, cx: &mut Context<Self>) {
cx.notify();
}
pub fn expand_message_editor(
&mut self,
_: &ExpandMessageEditor,
@@ -432,12 +429,13 @@ impl MessageEditor {
Some(
IconButton::new("max-mode", IconName::ZedMaxMode)
.icon_size(IconSize::Small)
.toggle_state(active_completion_mode == Some(CompletionMode::Max))
.icon_color(Color::Muted)
.toggle_state(active_completion_mode == CompletionMode::Max)
.on_click(cx.listener(move |this, _event, _window, cx| {
this.thread.update(cx, |thread, _cx| {
thread.set_completion_mode(match active_completion_mode {
Some(CompletionMode::Max) => Some(CompletionMode::Normal),
Some(CompletionMode::Normal) | None => Some(CompletionMode::Max),
CompletionMode::Max => CompletionMode::Normal,
CompletionMode::Normal => CompletionMode::Max,
});
});
}))
@@ -499,7 +497,6 @@ impl MessageEditor {
.on_action(cx.listener(Self::toggle_context_picker))
.on_action(cx.listener(Self::remove_all_context))
.on_action(cx.listener(Self::move_up))
.on_action(cx.listener(Self::toggle_chat_mode))
.on_action(cx.listener(Self::expand_message_editor))
.capture_action(cx.listener(Self::paste))
.gap_2()
@@ -1206,3 +1203,53 @@ impl Render for MessageEditor {
})
}
}
impl Component for MessageEditor {
fn scope() -> ComponentScope {
ComponentScope::Agent
}
}
impl AgentPreview for MessageEditor {
fn create_preview(
workspace: WeakEntity<Workspace>,
active_thread: Entity<ActiveThread>,
thread_store: WeakEntity<ThreadStore>,
window: &mut Window,
cx: &mut App,
) -> Option<AnyElement> {
if let Some(workspace_entity) = workspace.upgrade() {
let fs = workspace_entity.read(cx).app_state().fs.clone();
let weak_project = workspace_entity.read(cx).project().clone().downgrade();
let context_store = cx.new(|_cx| ContextStore::new(weak_project, None));
let thread = active_thread.read(cx).thread().clone();
let example_message_editor = cx.new(|cx| {
MessageEditor::new(
fs,
workspace,
context_store,
None,
thread_store,
thread,
window,
cx,
)
});
Some(
v_flex()
.gap_4()
.children(vec![single_example(
"Default",
example_message_editor.clone().into_any_element(),
)])
.into_any_element(),
)
} else {
None
}
}
}
register_agent_preview!(MessageEditor);

View File

@@ -301,6 +301,14 @@ pub enum TokenUsageRatio {
Exceeded,
}
fn default_completion_mode(cx: &App) -> CompletionMode {
if cx.is_staff() {
CompletionMode::Max
} else {
CompletionMode::Normal
}
}
/// A thread of conversation with the LLM.
pub struct Thread {
id: ThreadId,
@@ -310,7 +318,7 @@ pub struct Thread {
detailed_summary_task: Task<Option<()>>,
detailed_summary_tx: postage::watch::Sender<DetailedSummaryState>,
detailed_summary_rx: postage::watch::Receiver<DetailedSummaryState>,
completion_mode: Option<CompletionMode>,
completion_mode: CompletionMode,
messages: Vec<Message>,
next_message_id: MessageId,
last_prompt_id: PromptId,
@@ -366,7 +374,7 @@ impl Thread {
detailed_summary_task: Task::ready(None),
detailed_summary_tx,
detailed_summary_rx,
completion_mode: None,
completion_mode: default_completion_mode(cx),
messages: Vec::new(),
next_message_id: MessageId(0),
last_prompt_id: PromptId::new(),
@@ -440,7 +448,7 @@ impl Thread {
detailed_summary_task: Task::ready(None),
detailed_summary_tx,
detailed_summary_rx,
completion_mode: None,
completion_mode: default_completion_mode(cx),
messages: serialized
.messages
.into_iter()
@@ -569,11 +577,11 @@ impl Thread {
}
}
pub fn completion_mode(&self) -> Option<CompletionMode> {
pub fn completion_mode(&self) -> CompletionMode {
self.completion_mode
}
pub fn set_completion_mode(&mut self, mode: Option<CompletionMode>) {
pub fn set_completion_mode(&mut self, mode: CompletionMode) {
self.completion_mode = mode;
}
@@ -1152,9 +1160,9 @@ impl Thread {
request.tools = available_tools;
request.mode = if model.supports_max_mode() {
self.completion_mode
Some(self.completion_mode)
} else {
None
Some(CompletionMode::Normal)
};
request
@@ -2110,7 +2118,7 @@ impl Thread {
.map(|repo| {
repo.update(cx, |repo, _| {
let current_branch =
repo.branch.as_ref().map(|branch| branch.name.to_string());
repo.branch.as_ref().map(|branch| branch.name().to_owned());
repo.send_job(None, |state, _| async move {
let RepositoryState::Local { backend, .. } = state else {
return GitState {
@@ -2509,7 +2517,7 @@ mod tests {
let expected_context = format!(
r#"
<context>
The following items were attached by the user. You don't need to use other tools to read them.
The following items were attached by the user. They are up-to-date and don't need to be re-read.
<files>
```rs {path_part}

View File

@@ -9,8 +9,8 @@ use assistant_settings::{AgentProfile, AgentProfileId, AssistantSettings};
use assistant_tool::{ToolId, ToolSource, ToolWorkingSet};
use chrono::{DateTime, Utc};
use collections::HashMap;
use context_server::manager::ContextServerManager;
use context_server::{ContextServerFactoryRegistry, ContextServerTool};
use context_server::manager::{ContextServerManager, ContextServerStatus};
use context_server::{ContextServerDescriptorRegistry, ContextServerTool};
use futures::channel::{mpsc, oneshot};
use futures::future::{self, BoxFuture, Shared};
use futures::{FutureExt as _, StreamExt as _};
@@ -108,7 +108,7 @@ impl ThreadStore {
prompt_store: Option<Entity<PromptStore>>,
cx: &mut Context<Self>,
) -> (Self, oneshot::Receiver<()>) {
let context_server_factory_registry = ContextServerFactoryRegistry::default_global(cx);
let context_server_factory_registry = ContextServerDescriptorRegistry::default_global(cx);
let context_server_manager = cx.new(|cx| {
ContextServerManager::new(context_server_factory_registry, project.clone(), cx)
});
@@ -555,62 +555,68 @@ impl ThreadStore {
) {
let tool_working_set = self.tools.clone();
match event {
context_server::manager::Event::ServerStarted { server_id } => {
if let Some(server) = context_server_manager.read(cx).get_server(server_id) {
let context_server_manager = context_server_manager.clone();
cx.spawn({
let server = server.clone();
let server_id = server_id.clone();
async move |this, cx| {
let Some(protocol) = server.client() else {
return;
};
context_server::manager::Event::ServerStatusChanged { server_id, status } => {
match status {
Some(ContextServerStatus::Running) => {
if let Some(server) = context_server_manager.read(cx).get_server(server_id)
{
let context_server_manager = context_server_manager.clone();
cx.spawn({
let server = server.clone();
let server_id = server_id.clone();
async move |this, cx| {
let Some(protocol) = server.client() else {
return;
};
if protocol.capable(context_server::protocol::ServerCapability::Tools) {
if let Some(tools) = protocol.list_tools().await.log_err() {
let tool_ids = tool_working_set
.update(cx, |tool_working_set, _| {
tools
.tools
.into_iter()
.map(|tool| {
log::info!(
"registering context server tool: {:?}",
tool.name
);
tool_working_set.insert(Arc::new(
ContextServerTool::new(
context_server_manager.clone(),
server.id(),
tool,
),
))
if protocol.capable(context_server::protocol::ServerCapability::Tools) {
if let Some(tools) = protocol.list_tools().await.log_err() {
let tool_ids = tool_working_set
.update(cx, |tool_working_set, _| {
tools
.tools
.into_iter()
.map(|tool| {
log::info!(
"registering context server tool: {:?}",
tool.name
);
tool_working_set.insert(Arc::new(
ContextServerTool::new(
context_server_manager.clone(),
server.id(),
tool,
),
))
})
.collect::<Vec<_>>()
})
.collect::<Vec<_>>()
})
.log_err();
.log_err();
if let Some(tool_ids) = tool_ids {
this.update(cx, |this, cx| {
this.context_server_tool_ids
.insert(server_id, tool_ids);
this.load_default_profile(cx);
})
.log_err();
if let Some(tool_ids) = tool_ids {
this.update(cx, |this, cx| {
this.context_server_tool_ids
.insert(server_id, tool_ids);
this.load_default_profile(cx);
})
.log_err();
}
}
}
}
}
})
.detach();
}
})
.detach();
}
}
context_server::manager::Event::ServerStopped { server_id } => {
if let Some(tool_ids) = self.context_server_tool_ids.remove(server_id) {
tool_working_set.update(cx, |tool_working_set, _| {
tool_working_set.remove(&tool_ids);
});
self.load_default_profile(cx);
}
None => {
if let Some(tool_ids) = self.context_server_tool_ids.remove(server_id) {
tool_working_set.update(cx, |tool_working_set, _| {
tool_working_set.remove(&tool_ids);
});
self.load_default_profile(cx);
}
}
_ => {}
}
}
}

View File

@@ -1,9 +1,12 @@
mod agent_notification;
pub mod agent_preview;
mod animated_label;
mod context_pill;
mod upsell;
mod usage_banner;
pub use agent_notification::*;
pub use agent_preview::*;
pub use animated_label::*;
pub use context_pill::*;
pub use usage_banner::*;

View File

@@ -0,0 +1,99 @@
use collections::HashMap;
use component::ComponentId;
use gpui::{App, Entity, WeakEntity};
use linkme::distributed_slice;
use std::sync::OnceLock;
use ui::{AnyElement, Component, Window};
use workspace::Workspace;
use crate::{ActiveThread, ThreadStore};
/// Function type for creating agent component previews
pub type PreviewFn = fn(
WeakEntity<Workspace>,
Entity<ActiveThread>,
WeakEntity<ThreadStore>,
&mut Window,
&mut App,
) -> Option<AnyElement>;
/// Distributed slice for preview registration functions
#[distributed_slice]
pub static __ALL_AGENT_PREVIEWS: [fn() -> (ComponentId, PreviewFn)] = [..];
/// Trait that must be implemented by components that provide agent previews.
pub trait AgentPreview: Component {
/// Get the ID for this component
///
/// Eventually this will move to the component trait.
fn id() -> ComponentId
where
Self: Sized,
{
ComponentId(Self::name())
}
/// Static method to create a preview for this component type
fn create_preview(
workspace: WeakEntity<Workspace>,
active_thread: Entity<ActiveThread>,
thread_store: WeakEntity<ThreadStore>,
window: &mut Window,
cx: &mut App,
) -> Option<AnyElement>
where
Self: Sized;
}
/// Register an agent preview for the given component type
#[macro_export]
macro_rules! register_agent_preview {
($type:ty) => {
#[linkme::distributed_slice($crate::ui::agent_preview::__ALL_AGENT_PREVIEWS)]
static __REGISTER_AGENT_PREVIEW: fn() -> (
component::ComponentId,
$crate::ui::agent_preview::PreviewFn,
) = || {
(
<$type as $crate::ui::agent_preview::AgentPreview>::id(),
<$type as $crate::ui::agent_preview::AgentPreview>::create_preview,
)
};
};
}
/// Lazy initialized registry of preview functions
static AGENT_PREVIEW_REGISTRY: OnceLock<HashMap<ComponentId, PreviewFn>> = OnceLock::new();
/// Initialize the agent preview registry if needed
fn get_or_init_registry() -> &'static HashMap<ComponentId, PreviewFn> {
AGENT_PREVIEW_REGISTRY.get_or_init(|| {
let mut map = HashMap::default();
for register_fn in __ALL_AGENT_PREVIEWS.iter() {
let (id, preview_fn) = register_fn();
map.insert(id, preview_fn);
}
map
})
}
/// Get a specific agent preview by component ID.
pub fn get_agent_preview(
id: &ComponentId,
workspace: WeakEntity<Workspace>,
active_thread: Entity<ActiveThread>,
thread_store: WeakEntity<ThreadStore>,
window: &mut Window,
cx: &mut App,
) -> Option<AnyElement> {
let registry = get_or_init_registry();
registry
.get(id)
.and_then(|preview_fn| preview_fn(workspace, active_thread, thread_store, window, cx))
}
/// Get all registered agent previews.
pub fn all_agent_previews() -> Vec<ComponentId> {
let registry = get_or_init_registry();
registry.keys().cloned().collect()
}

View File

@@ -216,9 +216,10 @@ impl RenderOnce for ContextPill {
})
.when_some(on_click.as_ref(), |element, on_click| {
let on_click = on_click.clone();
element
.cursor_pointer()
.on_click(move |event, window, cx| on_click(event, window, cx))
element.cursor_pointer().on_click(move |event, window, cx| {
on_click(event, window, cx);
cx.stop_propagation();
})
})
.into_any_element()
}
@@ -254,7 +255,10 @@ impl RenderOnce for ContextPill {
})
.when_some(on_click.as_ref(), |element, on_click| {
let on_click = on_click.clone();
element.on_click(move |event, window, cx| on_click(event, window, cx))
element.on_click(move |event, window, cx| {
on_click(event, window, cx);
cx.stop_propagation();
})
})
.into_any(),
}

View File

@@ -0,0 +1,163 @@
use component::{Component, ComponentScope, single_example};
use gpui::{
AnyElement, App, ClickEvent, IntoElement, ParentElement, RenderOnce, SharedString, Styled,
Window,
};
use theme::ActiveTheme;
use ui::{
Button, ButtonCommon, ButtonStyle, Checkbox, Clickable, Color, Label, LabelCommon,
RegisterComponent, ToggleState, h_flex, v_flex,
};
/// A component that displays an upsell message with a call-to-action button
///
/// # Example
/// ```
/// let upsell = Upsell::new(
/// "Upgrade to Zed Pro",
/// "Get unlimited access to AI features and more",
/// "Upgrade Now",
/// Box::new(|_, _window, cx| {
/// cx.open_url("https://zed.dev/pricing");
/// }),
/// Box::new(|_, _window, cx| {
/// // Handle dismiss
/// }),
/// Box::new(|checked, window, cx| {
/// // Handle don't show again
/// }),
/// );
/// ```
#[derive(IntoElement, RegisterComponent)]
pub struct Upsell {
title: SharedString,
message: SharedString,
cta_text: SharedString,
on_click: Box<dyn Fn(&ClickEvent, &mut Window, &mut App)>,
on_dismiss: Box<dyn Fn(&ClickEvent, &mut Window, &mut App)>,
on_dont_show_again: Box<dyn Fn(bool, &mut Window, &mut App)>,
}
impl Upsell {
/// Create a new upsell component
pub fn new(
title: impl Into<SharedString>,
message: impl Into<SharedString>,
cta_text: impl Into<SharedString>,
on_click: Box<dyn Fn(&ClickEvent, &mut Window, &mut App)>,
on_dismiss: Box<dyn Fn(&ClickEvent, &mut Window, &mut App)>,
on_dont_show_again: Box<dyn Fn(bool, &mut Window, &mut App)>,
) -> Self {
Self {
title: title.into(),
message: message.into(),
cta_text: cta_text.into(),
on_click,
on_dismiss,
on_dont_show_again,
}
}
}
impl RenderOnce for Upsell {
fn render(self, _window: &mut Window, cx: &mut App) -> impl IntoElement {
v_flex()
.w_full()
.p_4()
.gap_3()
.bg(cx.theme().colors().surface_background)
.rounded_md()
.border_1()
.border_color(cx.theme().colors().border)
.child(
v_flex()
.gap_1()
.child(
Label::new(self.title)
.size(ui::LabelSize::Large)
.weight(gpui::FontWeight::BOLD),
)
.child(Label::new(self.message).color(Color::Muted)),
)
.child(
h_flex()
.w_full()
.justify_between()
.items_center()
.child(
h_flex()
.items_center()
.gap_1()
.child(
Checkbox::new("dont-show-again", ToggleState::Unselected).on_click(
move |_, window, cx| {
(self.on_dont_show_again)(true, window, cx);
},
),
)
.child(
Label::new("Don't show again")
.color(Color::Muted)
.size(ui::LabelSize::Small),
),
)
.child(
h_flex()
.gap_2()
.child(
Button::new("dismiss-button", "Dismiss")
.style(ButtonStyle::Subtle)
.on_click(self.on_dismiss),
)
.child(
Button::new("cta-button", self.cta_text)
.style(ButtonStyle::Filled)
.on_click(self.on_click),
),
),
)
}
}
impl Component for Upsell {
fn scope() -> ComponentScope {
ComponentScope::Agent
}
fn name() -> &'static str {
"Upsell"
}
fn description() -> Option<&'static str> {
Some("A promotional component that displays a message with a call-to-action.")
}
fn preview(window: &mut Window, cx: &mut App) -> Option<AnyElement> {
let examples = vec![
single_example(
"Default",
Upsell::new(
"Upgrade to Zed Pro",
"Get unlimited access to AI features and more with Zed Pro. Unlock advanced AI capabilities and other premium features.",
"Upgrade Now",
Box::new(|_, _, _| {}),
Box::new(|_, _, _| {}),
Box::new(|_, _, _| {}),
).render(window, cx).into_any_element(),
),
single_example(
"Short Message",
Upsell::new(
"Try Zed Pro for free",
"Start your 7-day trial today.",
"Start Trial",
Box::new(|_, _, _| {}),
Box::new(|_, _, _| {}),
Box::new(|_, _, _| {}),
).render(window, cx).into_any_element(),
),
];
Some(v_flex().gap_4().children(examples).into_any_element())
}
}

View File

@@ -98,6 +98,10 @@ impl RenderOnce for UsageBanner {
}
impl Component for UsageBanner {
fn scope() -> ComponentScope {
ComponentScope::Agent
}
fn sort_name() -> &'static str {
"AgentUsageBanner"
}

View File

@@ -7,8 +7,8 @@ use assistant_slash_command::{SlashCommandId, SlashCommandWorkingSet};
use client::{Client, TypedEnvelope, proto, telemetry::Telemetry};
use clock::ReplicaId;
use collections::HashMap;
use context_server::ContextServerFactoryRegistry;
use context_server::manager::ContextServerManager;
use context_server::ContextServerDescriptorRegistry;
use context_server::manager::{ContextServerManager, ContextServerStatus};
use fs::{Fs, RemoveOptions};
use futures::StreamExt;
use fuzzy::StringMatchCandidate;
@@ -99,7 +99,7 @@ impl ContextStore {
let this = cx.new(|cx: &mut Context<Self>| {
let context_server_factory_registry =
ContextServerFactoryRegistry::default_global(cx);
ContextServerDescriptorRegistry::default_global(cx);
let context_server_manager = cx.new(|cx| {
ContextServerManager::new(context_server_factory_registry, project.clone(), cx)
});
@@ -831,54 +831,60 @@ impl ContextStore {
) {
let slash_command_working_set = self.slash_commands.clone();
match event {
context_server::manager::Event::ServerStarted { server_id } => {
if let Some(server) = context_server_manager.read(cx).get_server(server_id) {
let context_server_manager = context_server_manager.clone();
cx.spawn({
let server = server.clone();
let server_id = server_id.clone();
async move |this, cx| {
let Some(protocol) = server.client() else {
return;
};
context_server::manager::Event::ServerStatusChanged { server_id, status } => {
match status {
Some(ContextServerStatus::Running) => {
if let Some(server) = context_server_manager.read(cx).get_server(server_id)
{
let context_server_manager = context_server_manager.clone();
cx.spawn({
let server = server.clone();
let server_id = server_id.clone();
async move |this, cx| {
let Some(protocol) = server.client() else {
return;
};
if protocol.capable(context_server::protocol::ServerCapability::Prompts) {
if let Some(prompts) = protocol.list_prompts().await.log_err() {
let slash_command_ids = prompts
.into_iter()
.filter(assistant_slash_commands::acceptable_prompt)
.map(|prompt| {
log::info!(
"registering context server command: {:?}",
prompt.name
);
slash_command_working_set.insert(Arc::new(
assistant_slash_commands::ContextServerSlashCommand::new(
context_server_manager.clone(),
&server,
prompt,
),
))
})
.collect::<Vec<_>>();
if protocol.capable(context_server::protocol::ServerCapability::Prompts) {
if let Some(prompts) = protocol.list_prompts().await.log_err() {
let slash_command_ids = prompts
.into_iter()
.filter(assistant_slash_commands::acceptable_prompt)
.map(|prompt| {
log::info!(
"registering context server command: {:?}",
prompt.name
);
slash_command_working_set.insert(Arc::new(
assistant_slash_commands::ContextServerSlashCommand::new(
context_server_manager.clone(),
&server,
prompt,
),
))
})
.collect::<Vec<_>>();
this.update( cx, |this, _cx| {
this.context_server_slash_command_ids
.insert(server_id.clone(), slash_command_ids);
})
.log_err();
this.update( cx, |this, _cx| {
this.context_server_slash_command_ids
.insert(server_id.clone(), slash_command_ids);
})
.log_err();
}
}
}
}
})
.detach();
}
})
.detach();
}
}
context_server::manager::Event::ServerStopped { server_id } => {
if let Some(slash_command_ids) =
self.context_server_slash_command_ids.remove(server_id)
{
slash_command_working_set.remove(&slash_command_ids);
}
None => {
if let Some(slash_command_ids) =
self.context_server_slash_command_ids.remove(server_id)
{
slash_command_working_set.remove(&slash_command_ids);
}
}
_ => {}
}
}
}

View File

@@ -6,7 +6,7 @@ use ::open_ai::Model as OpenAiModel;
use anthropic::Model as AnthropicModel;
use anyhow::{Result, bail};
use deepseek::Model as DeepseekModel;
use feature_flags::{Assistant2FeatureFlag, FeatureFlagAppExt};
use feature_flags::{AgentStreamEditsFeatureFlag, Assistant2FeatureFlag, FeatureFlagAppExt};
use gpui::{App, Pixels};
use indexmap::IndexMap;
use language_model::{CloudModel, LanguageModel};
@@ -87,9 +87,14 @@ pub struct AssistantSettings {
pub profiles: IndexMap<AgentProfileId, AgentProfile>,
pub always_allow_tool_actions: bool,
pub notify_when_agent_waiting: NotifyWhenAgentWaiting,
pub stream_edits: bool,
}
impl AssistantSettings {
pub fn stream_edits(&self, cx: &App) -> bool {
cx.has_flag::<AgentStreamEditsFeatureFlag>() || self.stream_edits
}
pub fn are_live_diffs_enabled(&self, cx: &App) -> bool {
if cx.has_flag::<Assistant2FeatureFlag>() {
return false;
@@ -218,6 +223,7 @@ impl AssistantSettingsContent {
profiles: None,
always_allow_tool_actions: None,
notify_when_agent_waiting: None,
stream_edits: None,
},
VersionedAssistantSettingsContent::V2(ref settings) => settings.clone(),
},
@@ -245,6 +251,7 @@ impl AssistantSettingsContent {
profiles: None,
always_allow_tool_actions: None,
notify_when_agent_waiting: None,
stream_edits: None,
},
None => AssistantSettingsContentV2::default(),
}
@@ -495,6 +502,7 @@ impl Default for VersionedAssistantSettingsContent {
profiles: None,
always_allow_tool_actions: None,
notify_when_agent_waiting: None,
stream_edits: None,
})
}
}
@@ -550,6 +558,10 @@ pub struct AssistantSettingsContentV2 {
///
/// Default: "primary_screen"
notify_when_agent_waiting: Option<NotifyWhenAgentWaiting>,
/// Whether to stream edits from the agent as they are received.
///
/// Default: false
stream_edits: Option<bool>,
}
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema, PartialEq)]
@@ -712,6 +724,7 @@ impl Settings for AssistantSettings {
&mut settings.notify_when_agent_waiting,
value.notify_when_agent_waiting,
);
merge(&mut settings.stream_edits, value.stream_edits);
merge(&mut settings.default_profile, value.default_profile);
if let Some(profiles) = value.profiles {
@@ -843,6 +856,7 @@ mod tests {
profiles: None,
always_allow_tool_actions: None,
notify_when_agent_waiting: None,
stream_edits: None,
},
)),
}

View File

@@ -11,16 +11,24 @@ workspace = true
[lib]
path = "src/assistant_tools.rs"
[features]
eval = []
[dependencies]
aho-corasick.workspace = true
anyhow.workspace = true
assistant_tool.workspace = true
assistant_settings.workspace = true
buffer_diff.workspace = true
chrono.workspace = true
collections.workspace = true
component.workspace = true
editor.workspace = true
derive_more.workspace = true
feature_flags.workspace = true
futures.workspace = true
gpui.workspace = true
handlebars = { workspace = true, features = ["rust-embed"] }
html_to_markdown.workspace = true
http_client.workspace = true
indoc.workspace = true
@@ -31,9 +39,14 @@ linkme.workspace = true
open.workspace = true
project.workspace = true
regex.workspace = true
rust-embed.workspace = true
schemars.workspace = true
serde.workspace = true
serde_json.workspace = true
settings.workspace = true
smallvec.workspace = true
streaming_diff.workspace = true
strsim.workspace = true
task.workspace = true
terminal.workspace = true
terminal_view.workspace = true
@@ -49,10 +62,15 @@ client = { workspace = true, features = ["test-support"] }
clock = { workspace = true, features = ["test-support"] }
collections = { workspace = true, features = ["test-support"] }
gpui = { workspace = true, features = ["test-support"] }
gpui_tokio.workspace = true
fs = { workspace = true, features = ["test-support"] }
language = { workspace = true, features = ["test-support"] }
language_model = { workspace = true, features = ["test-support"] }
language_models.workspace = true
project = { workspace = true, features = ["test-support"] }
rand.workspace = true
pretty_assertions.workspace = true
reqwest_client.workspace = true
settings = { workspace = true, features = ["test-support"] }
task = { workspace = true, features = ["test-support"]}
tempfile.workspace = true

View File

@@ -7,6 +7,7 @@ mod create_directory_tool;
mod create_file_tool;
mod delete_path_tool;
mod diagnostics_tool;
mod edit_agent;
mod edit_file_tool;
mod fetch_tool;
mod find_path_tool;
@@ -19,7 +20,9 @@ mod read_file_tool;
mod rename_tool;
mod replace;
mod schema;
mod streaming_edit_file_tool;
mod symbol_info_tool;
mod templates;
mod terminal_tool;
mod thinking_tool;
mod ui;
@@ -27,14 +30,19 @@ mod web_search_tool;
use std::sync::Arc;
use assistant_settings::AssistantSettings;
use assistant_tool::ToolRegistry;
use copy_path_tool::CopyPathTool;
use feature_flags::{AgentStreamEditsFeatureFlag, FeatureFlagAppExt};
use gpui::App;
use http_client::HttpClientWithUrl;
use language_model::LanguageModelRegistry;
use move_path_tool::MovePathTool;
use settings::{Settings, SettingsStore};
use web_search_tool::WebSearchTool;
pub(crate) use templates::*;
use crate::batch_tool::BatchTool;
use crate::code_action_tool::CodeActionTool;
use crate::code_symbols_tool::CodeSymbolsTool;
@@ -52,6 +60,7 @@ use crate::now_tool::NowTool;
use crate::open_tool::OpenTool;
use crate::read_file_tool::ReadFileTool;
use crate::rename_tool::RenameTool;
use crate::streaming_edit_file_tool::StreamingEditFileTool;
use crate::symbol_info_tool::SymbolInfoTool;
use crate::terminal_tool::TerminalTool;
use crate::thinking_tool::ThinkingTool;
@@ -68,10 +77,8 @@ pub fn init(http_client: Arc<HttpClientWithUrl>, cx: &mut App) {
registry.register_tool(TerminalTool);
registry.register_tool(BatchTool);
registry.register_tool(CreateDirectoryTool);
registry.register_tool(CreateFileTool);
registry.register_tool(CopyPathTool);
registry.register_tool(DeletePathTool);
registry.register_tool(EditFileTool);
registry.register_tool(SymbolInfoTool);
registry.register_tool(CodeActionTool);
registry.register_tool(MovePathTool);
@@ -88,6 +95,12 @@ pub fn init(http_client: Arc<HttpClientWithUrl>, cx: &mut App) {
registry.register_tool(ThinkingTool);
registry.register_tool(FetchTool::new(http_client));
register_edit_file_tool(cx);
cx.observe_flag::<AgentStreamEditsFeatureFlag, _>(|_, cx| register_edit_file_tool(cx))
.detach();
cx.observe_global::<SettingsStore>(register_edit_file_tool)
.detach();
cx.subscribe(
&LanguageModelRegistry::global(cx),
move |registry, event, cx| match event {
@@ -108,6 +121,21 @@ pub fn init(http_client: Arc<HttpClientWithUrl>, cx: &mut App) {
.detach();
}
fn register_edit_file_tool(cx: &mut App) {
let registry = ToolRegistry::global(cx);
registry.unregister_tool(CreateFileTool);
registry.unregister_tool(EditFileTool);
registry.unregister_tool(StreamingEditFileTool);
if AssistantSettings::get_global(cx).stream_edits(cx) {
registry.register_tool(StreamingEditFileTool);
} else {
registry.register_tool(CreateFileTool);
registry.register_tool(EditFileTool);
}
}
#[cfg(test)]
mod tests {
use super::*;
@@ -146,6 +174,7 @@ mod tests {
#[gpui::test]
fn test_builtin_tool_schema_compatibility(cx: &mut App) {
settings::init(cx);
AssistantSettings::register(cx);
let client = Client::new(
Arc::new(FakeSystemClock::new()),

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,408 @@
use derive_more::{Add, AddAssign};
use smallvec::SmallVec;
use std::{cmp, mem, ops::Range};
const OLD_TEXT_END_TAG: &str = "</old_text>";
const NEW_TEXT_END_TAG: &str = "</new_text>";
const END_TAG_LEN: usize = OLD_TEXT_END_TAG.len();
const _: () = debug_assert!(OLD_TEXT_END_TAG.len() == NEW_TEXT_END_TAG.len());
#[derive(Debug)]
pub enum EditParserEvent {
OldText(String),
NewTextChunk { chunk: String, done: bool },
}
#[derive(Clone, Debug, Default, PartialEq, Eq, Add, AddAssign)]
pub struct EditParserMetrics {
pub tags: usize,
pub mismatched_tags: usize,
}
#[derive(Debug)]
pub struct EditParser {
state: EditParserState,
buffer: String,
metrics: EditParserMetrics,
}
#[derive(Debug, PartialEq)]
enum EditParserState {
Pending,
WithinOldText,
AfterOldText,
WithinNewText { start: bool },
}
impl EditParser {
pub fn new() -> Self {
EditParser {
state: EditParserState::Pending,
buffer: String::new(),
metrics: EditParserMetrics::default(),
}
}
pub fn push(&mut self, chunk: &str) -> SmallVec<[EditParserEvent; 1]> {
self.buffer.push_str(chunk);
let mut edit_events = SmallVec::new();
loop {
match &mut self.state {
EditParserState::Pending => {
if let Some(start) = self.buffer.find("<old_text>") {
self.buffer.drain(..start + "<old_text>".len());
self.state = EditParserState::WithinOldText;
} else {
break;
}
}
EditParserState::WithinOldText => {
if let Some(tag_range) = self.find_end_tag() {
let mut start = 0;
if self.buffer.starts_with('\n') {
start = 1;
}
let mut old_text = self.buffer[start..tag_range.start].to_string();
if old_text.ends_with('\n') {
old_text.pop();
}
self.metrics.tags += 1;
if &self.buffer[tag_range.clone()] != OLD_TEXT_END_TAG {
self.metrics.mismatched_tags += 1;
}
self.buffer.drain(..tag_range.end);
self.state = EditParserState::AfterOldText;
edit_events.push(EditParserEvent::OldText(old_text));
} else {
break;
}
}
EditParserState::AfterOldText => {
if let Some(start) = self.buffer.find("<new_text>") {
self.buffer.drain(..start + "<new_text>".len());
self.state = EditParserState::WithinNewText { start: true };
} else {
break;
}
}
EditParserState::WithinNewText { start } => {
if !self.buffer.is_empty() {
if *start && self.buffer.starts_with('\n') {
self.buffer.remove(0);
}
*start = false;
}
if let Some(tag_range) = self.find_end_tag() {
let mut chunk = self.buffer[..tag_range.start].to_string();
if chunk.ends_with('\n') {
chunk.pop();
}
self.metrics.tags += 1;
if &self.buffer[tag_range.clone()] != NEW_TEXT_END_TAG {
self.metrics.mismatched_tags += 1;
}
self.buffer.drain(..tag_range.end);
self.state = EditParserState::Pending;
edit_events.push(EditParserEvent::NewTextChunk { chunk, done: true });
} else {
let mut end_prefixes = (1..END_TAG_LEN)
.flat_map(|i| [&NEW_TEXT_END_TAG[..i], &OLD_TEXT_END_TAG[..i]])
.chain(["\n"]);
if end_prefixes.all(|prefix| !self.buffer.ends_with(&prefix)) {
edit_events.push(EditParserEvent::NewTextChunk {
chunk: mem::take(&mut self.buffer),
done: false,
});
}
break;
}
}
}
}
edit_events
}
fn find_end_tag(&self) -> Option<Range<usize>> {
let old_text_end_tag_ix = self.buffer.find(OLD_TEXT_END_TAG);
let new_text_end_tag_ix = self.buffer.find(NEW_TEXT_END_TAG);
let start_ix = if let Some((old_text_ix, new_text_ix)) =
old_text_end_tag_ix.zip(new_text_end_tag_ix)
{
cmp::min(old_text_ix, new_text_ix)
} else {
old_text_end_tag_ix.or(new_text_end_tag_ix)?
};
Some(start_ix..start_ix + END_TAG_LEN)
}
pub fn finish(self) -> EditParserMetrics {
self.metrics
}
}
#[cfg(test)]
mod tests {
use super::*;
use indoc::indoc;
use rand::prelude::*;
use std::cmp;
#[gpui::test(iterations = 1000)]
fn test_single_edit(mut rng: StdRng) {
let mut parser = EditParser::new();
assert_eq!(
parse_random_chunks(
"<old_text>original</old_text><new_text>updated</new_text>",
&mut parser,
&mut rng
),
vec![Edit {
old_text: "original".to_string(),
new_text: "updated".to_string(),
}]
);
assert_eq!(
parser.finish(),
EditParserMetrics {
tags: 2,
mismatched_tags: 0
}
);
}
#[gpui::test(iterations = 1000)]
fn test_multiple_edits(mut rng: StdRng) {
let mut parser = EditParser::new();
assert_eq!(
parse_random_chunks(
indoc! {"
<old_text>
first old
</old_text><new_text>first new</new_text>
<old_text>second old</old_text><new_text>
second new
</new_text>
"},
&mut parser,
&mut rng
),
vec![
Edit {
old_text: "first old".to_string(),
new_text: "first new".to_string(),
},
Edit {
old_text: "second old".to_string(),
new_text: "second new".to_string(),
},
]
);
assert_eq!(
parser.finish(),
EditParserMetrics {
tags: 4,
mismatched_tags: 0
}
);
}
#[gpui::test(iterations = 1000)]
fn test_edits_with_extra_text(mut rng: StdRng) {
let mut parser = EditParser::new();
assert_eq!(
parse_random_chunks(
indoc! {"
ignore this <old_text>
content</old_text>extra stuff<new_text>updated content</new_text>trailing data
more text <old_text>second item
</old_text>middle text<new_text>modified second item</new_text>end
<old_text>third case</old_text><new_text>improved third case</new_text> with trailing text
"},
&mut parser,
&mut rng
),
vec![
Edit {
old_text: "content".to_string(),
new_text: "updated content".to_string(),
},
Edit {
old_text: "second item".to_string(),
new_text: "modified second item".to_string(),
},
Edit {
old_text: "third case".to_string(),
new_text: "improved third case".to_string(),
},
]
);
assert_eq!(
parser.finish(),
EditParserMetrics {
tags: 6,
mismatched_tags: 0
}
);
}
#[gpui::test(iterations = 1000)]
fn test_nested_tags(mut rng: StdRng) {
let mut parser = EditParser::new();
assert_eq!(
parse_random_chunks(
"<old_text>code with <tag>nested</tag> elements</old_text><new_text>new <code>content</code></new_text>",
&mut parser,
&mut rng
),
vec![Edit {
old_text: "code with <tag>nested</tag> elements".to_string(),
new_text: "new <code>content</code>".to_string(),
}]
);
assert_eq!(
parser.finish(),
EditParserMetrics {
tags: 2,
mismatched_tags: 0
}
);
}
#[gpui::test(iterations = 1000)]
fn test_empty_old_and_new_text(mut rng: StdRng) {
let mut parser = EditParser::new();
assert_eq!(
parse_random_chunks(
"<old_text></old_text><new_text></new_text>",
&mut parser,
&mut rng
),
vec![Edit {
old_text: "".to_string(),
new_text: "".to_string(),
}]
);
assert_eq!(
parser.finish(),
EditParserMetrics {
tags: 2,
mismatched_tags: 0
}
);
}
#[gpui::test(iterations = 100)]
fn test_multiline_content(mut rng: StdRng) {
let mut parser = EditParser::new();
assert_eq!(
parse_random_chunks(
"<old_text>line1\nline2\nline3</old_text><new_text>line1\nmodified line2\nline3</new_text>",
&mut parser,
&mut rng
),
vec![Edit {
old_text: "line1\nline2\nline3".to_string(),
new_text: "line1\nmodified line2\nline3".to_string(),
}]
);
assert_eq!(
parser.finish(),
EditParserMetrics {
tags: 2,
mismatched_tags: 0
}
);
}
#[gpui::test(iterations = 1000)]
fn test_mismatched_tags(mut rng: StdRng) {
let mut parser = EditParser::new();
assert_eq!(
parse_random_chunks(
// Reduced from an actual Sonnet 3.7 output
indoc! {"
<old_text>
a
b
c
</new_text>
<new_text>
a
B
c
</old_text>
<old_text>
d
e
f
</new_text>
<new_text>
D
e
F
</old_text>
"},
&mut parser,
&mut rng
),
vec![
Edit {
old_text: "a\nb\nc".to_string(),
new_text: "a\nB\nc".to_string(),
},
Edit {
old_text: "d\ne\nf".to_string(),
new_text: "D\ne\nF".to_string(),
}
]
);
assert_eq!(
parser.finish(),
EditParserMetrics {
tags: 4,
mismatched_tags: 4
}
);
}
#[derive(Default, Debug, PartialEq, Eq)]
struct Edit {
old_text: String,
new_text: String,
}
fn parse_random_chunks(input: &str, parser: &mut EditParser, rng: &mut StdRng) -> Vec<Edit> {
let chunk_count = rng.gen_range(1..=cmp::min(input.len(), 50));
let mut chunk_indices = (0..input.len()).choose_multiple(rng, chunk_count);
chunk_indices.sort();
chunk_indices.push(input.len());
let mut pending_edit = Edit::default();
let mut edits = Vec::new();
let mut last_ix = 0;
for chunk_ix in chunk_indices {
for event in parser.push(&input[last_ix..chunk_ix]) {
match event {
EditParserEvent::OldText(old_text) => {
pending_edit.old_text = old_text;
}
EditParserEvent::NewTextChunk { chunk, done } => {
pending_edit.new_text.push_str(&chunk);
if done {
edits.push(pending_edit);
pending_edit = Edit::default();
}
}
}
}
last_ix = chunk_ix;
}
edits
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,328 @@
use crate::commit::get_messages;
use crate::{GitRemote, Oid};
use anyhow::{Context as _, Result, anyhow};
use collections::{HashMap, HashSet};
use futures::AsyncWriteExt;
use gpui::SharedString;
use serde::{Deserialize, Serialize};
use std::process::Stdio;
use std::{ops::Range, path::Path};
use text::Rope;
use time::OffsetDateTime;
use time::UtcOffset;
use time::macros::format_description;
pub use git2 as libgit;
#[derive(Debug, Clone, Default)]
pub struct Blame {
pub entries: Vec<BlameEntry>,
pub messages: HashMap<Oid, String>,
pub remote_url: Option<String>,
}
#[derive(Clone, Debug, Default)]
pub struct ParsedCommitMessage {
pub message: SharedString,
pub permalink: Option<url::Url>,
pub pull_request: Option<crate::hosting_provider::PullRequest>,
pub remote: Option<GitRemote>,
}
impl Blame {
pub async fn for_path(
git_binary: &Path,
working_directory: &Path,
path: &Path,
content: &Rope,
remote_url: Option<String>,
) -> Result<Self> {
let output = run_git_blame(git_binary, working_directory, path, content).await?;
let mut entries = parse_git_blame(&output)?;
entries.sort_unstable_by(|a, b| a.range.start.cmp(&b.range.start));
let mut unique_shas = HashSet::default();
for entry in entries.iter_mut() {
unique_shas.insert(entry.sha);
}
let shas = unique_shas.into_iter().collect::<Vec<_>>();
let messages = get_messages(working_directory, &shas)
.await
.context("failed to get commit messages")?;
Ok(Self {
entries,
messages,
remote_url,
})
}
}
const GIT_BLAME_NO_COMMIT_ERROR: &str = "fatal: no such ref: HEAD";
const GIT_BLAME_NO_PATH: &str = "fatal: no such path";
#[derive(Serialize, Deserialize, Default, Debug, Clone, PartialEq, Eq)]
pub struct BlameEntry {
pub sha: Oid,
pub range: Range<u32>,
pub original_line_number: u32,
pub author: Option<String>,
pub author_mail: Option<String>,
pub author_time: Option<i64>,
pub author_tz: Option<String>,
pub committer_name: Option<String>,
pub committer_email: Option<String>,
pub committer_time: Option<i64>,
pub committer_tz: Option<String>,
pub summary: Option<String>,
pub previous: Option<String>,
pub filename: String,
}
impl BlameEntry {
// Returns a BlameEntry by parsing the first line of a `git blame --incremental`
// entry. The line MUST have this format:
//
// <40-byte-hex-sha1> <sourceline> <resultline> <num-lines>
fn new_from_blame_line(line: &str) -> Result<BlameEntry> {
let mut parts = line.split_whitespace();
let sha = parts
.next()
.and_then(|line| line.parse::<Oid>().ok())
.ok_or_else(|| anyhow!("failed to parse sha"))?;
let original_line_number = parts
.next()
.and_then(|line| line.parse::<u32>().ok())
.ok_or_else(|| anyhow!("Failed to parse original line number"))?;
let final_line_number = parts
.next()
.and_then(|line| line.parse::<u32>().ok())
.ok_or_else(|| anyhow!("Failed to parse final line number"))?;
let line_count = parts
.next()
.and_then(|line| line.parse::<u32>().ok())
.ok_or_else(|| anyhow!("Failed to parse final line number"))?;
let start_line = final_line_number.saturating_sub(1);
let end_line = start_line + line_count;
let range = start_line..end_line;
Ok(Self {
sha,
range,
original_line_number,
..Default::default()
})
}
pub fn author_offset_date_time(&self) -> Result<time::OffsetDateTime> {
if let (Some(author_time), Some(author_tz)) = (self.author_time, &self.author_tz) {
let format = format_description!("[offset_hour][offset_minute]");
let offset = UtcOffset::parse(author_tz, &format)?;
let date_time_utc = OffsetDateTime::from_unix_timestamp(author_time)?;
Ok(date_time_utc.to_offset(offset))
} else {
// Directly return current time in UTC if there's no committer time or timezone
Ok(time::OffsetDateTime::now_utc())
}
}
}
// parse_git_blame parses the output of `git blame --incremental`, which returns
// all the blame-entries for a given path incrementally, as it finds them.
//
// Each entry *always* starts with:
//
// <40-byte-hex-sha1> <sourceline> <resultline> <num-lines>
//
// Each entry *always* ends with:
//
// filename <whitespace-quoted-filename-goes-here>
//
// Line numbers are 1-indexed.
//
// A `git blame --incremental` entry looks like this:
//
// 6ad46b5257ba16d12c5ca9f0d4900320959df7f4 2 2 1
// author Joe Schmoe
// author-mail <joe.schmoe@example.com>
// author-time 1709741400
// author-tz +0100
// committer Joe Schmoe
// committer-mail <joe.schmoe@example.com>
// committer-time 1709741400
// committer-tz +0100
// summary Joe's cool commit
// previous 486c2409237a2c627230589e567024a96751d475 index.js
// filename index.js
//
// If the entry has the same SHA as an entry that was already printed then no
// signature information is printed:
//
// 6ad46b5257ba16d12c5ca9f0d4900320959df7f4 3 4 1
// previous 486c2409237a2c627230589e567024a96751d475 index.js
// filename index.js
//
// More about `--incremental` output: https://mirrors.edge.kernel.org/pub/software/scm/git/docs/git-blame.html
fn parse_git_blame(output: &str) -> Result<Vec<BlameEntry>> {
let mut entries: Vec<BlameEntry> = Vec::new();
let mut index: HashMap<Oid, usize> = HashMap::default();
let mut current_entry: Option<BlameEntry> = None;
for line in output.lines() {
let mut done = false;
match &mut current_entry {
None => {
let mut new_entry = BlameEntry::new_from_blame_line(line)?;
if let Some(existing_entry) = index
.get(&new_entry.sha)
.and_then(|slot| entries.get(*slot))
{
new_entry.author.clone_from(&existing_entry.author);
new_entry
.author_mail
.clone_from(&existing_entry.author_mail);
new_entry.author_time = existing_entry.author_time;
new_entry.author_tz.clone_from(&existing_entry.author_tz);
new_entry
.committer_name
.clone_from(&existing_entry.committer_name);
new_entry
.committer_email
.clone_from(&existing_entry.committer_email);
new_entry.committer_time = existing_entry.committer_time;
new_entry
.committer_tz
.clone_from(&existing_entry.committer_tz);
new_entry.summary.clone_from(&existing_entry.summary);
}
current_entry.replace(new_entry);
}
Some(entry) => {
let Some((key, value)) = line.split_once(' ') else {
continue;
};
let is_committed = !entry.sha.is_zero();
match key {
"filename" => {
entry.filename = value.into();
done = true;
}
"previous" => entry.previous = Some(value.into()),
"summary" if is_committed => entry.summary = Some(value.into()),
"author" if is_committed => entry.author = Some(value.into()),
"author-mail" if is_committed => entry.author_mail = Some(value.into()),
"author-time" if is_committed => {
entry.author_time = Some(value.parse::<i64>()?)
}
"author-tz" if is_committed => entry.author_tz = Some(value.into()),
"committer" if is_committed => entry.committer_name = Some(value.into()),
"committer-mail" if is_committed => entry.committer_email = Some(value.into()),
"committer-time" if is_committed => {
entry.committer_time = Some(value.parse::<i64>()?)
}
"committer-tz" if is_committed => entry.committer_tz = Some(value.into()),
_ => {}
}
}
};
if done {
if let Some(entry) = current_entry.take() {
index.insert(entry.sha, entries.len());
// We only want annotations that have a commit.
if !entry.sha.is_zero() {
entries.push(entry);
}
}
}
}
Ok(entries)
}
#[cfg(test)]
mod tests {
use std::path::PathBuf;
use super::BlameEntry;
use super::parse_git_blame;
fn read_test_data(filename: &str) -> String {
let mut path = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
path.push("test_data");
path.push(filename);
std::fs::read_to_string(&path)
.unwrap_or_else(|_| panic!("Could not read test data at {:?}. Is it generated?", path))
}
fn assert_eq_golden(entries: &Vec<BlameEntry>, golden_filename: &str) {
let mut path = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
path.push("test_data");
path.push("golden");
path.push(format!("{}.json", golden_filename));
let mut have_json =
serde_json::to_string_pretty(&entries).expect("could not serialize entries to JSON");
// We always want to save with a trailing newline.
have_json.push('\n');
let update = std::env::var("UPDATE_GOLDEN")
.map(|val| val.eq_ignore_ascii_case("true"))
.unwrap_or(false);
if update {
std::fs::create_dir_all(path.parent().unwrap())
.expect("could not create golden test data directory");
std::fs::write(&path, have_json).expect("could not write out golden data");
} else {
let want_json =
std::fs::read_to_string(&path).unwrap_or_else(|_| {
panic!("could not read golden test data file at {:?}. Did you run the test with UPDATE_GOLDEN=true before?", path);
}).replace("\r\n", "\n");
pretty_assertions::assert_eq!(have_json, want_json, "wrong blame entries");
}
}
#[test]
fn test_parse_git_blame_not_committed() {
let output = read_test_data("blame_incremental_not_committed");
let entries = parse_git_blame(&output).unwrap();
assert_eq_golden(&entries, "blame_incremental_not_committed");
}
#[test]
fn test_parse_git_blame_simple() {
let output = read_test_data("blame_incremental_simple");
let entries = parse_git_blame(&output).unwrap();
assert_eq_golden(&entries, "blame_incremental_simple");
}
#[test]
fn test_parse_git_blame_complex() {
let output = read_test_data("blame_incremental_complex");
let entries = parse_git_blame(&output).unwrap();
assert_eq_golden(&entries, "blame_incremental_complex");
}
}

View File

@@ -0,0 +1,374 @@
use crate::commit::get_messages;
use crate::{GitRemote, Oid};
use anyhow::{Context as _, Result, anyhow};
use collections::{HashMap, HashSet};
use futures::AsyncWriteExt;
use gpui::SharedString;
use serde::{Deserialize, Serialize};
use std::process::Stdio;
use std::{ops::Range, path::Path};
use text::Rope;
use time::OffsetDateTime;
use time::UtcOffset;
use time::macros::format_description;
pub use git2 as libgit;
#[derive(Debug, Clone, Default)]
pub struct Blame {
pub entries: Vec<BlameEntry>,
pub messages: HashMap<Oid, String>,
pub remote_url: Option<String>,
}
#[derive(Clone, Debug, Default)]
pub struct ParsedCommitMessage {
pub message: SharedString,
pub permalink: Option<url::Url>,
pub pull_request: Option<crate::hosting_provider::PullRequest>,
pub remote: Option<GitRemote>,
}
impl Blame {
pub async fn for_path(
git_binary: &Path,
working_directory: &Path,
path: &Path,
content: &Rope,
remote_url: Option<String>,
) -> Result<Self> {
let output = run_git_blame(git_binary, working_directory, path, content).await?;
let mut entries = parse_git_blame(&output)?;
entries.sort_unstable_by(|a, b| a.range.start.cmp(&b.range.start));
let mut unique_shas = HashSet::default();
for entry in entries.iter_mut() {
unique_shas.insert(entry.sha);
}
let shas = unique_shas.into_iter().collect::<Vec<_>>();
let messages = get_messages(working_directory, &shas)
.await
.context("failed to get commit messages")?;
Ok(Self {
entries,
messages,
remote_url,
})
}
}
const GIT_BLAME_NO_COMMIT_ERROR: &str = "fatal: no such ref: HEAD";
const GIT_BLAME_NO_PATH: &str = "fatal: no such path";
async fn run_git_blame(
git_binary: &Path,
working_directory: &Path,
path: &Path,
contents: &Rope,
) -> Result<String> {
let mut child = util::command::new_smol_command(git_binary)
.current_dir(working_directory)
.arg("blame")
.arg("--incremental")
.arg("--contents")
.arg("-")
.arg(path.as_os_str())
.stdin(Stdio::piped())
.stdout(Stdio::piped())
.stderr(Stdio::piped())
.spawn()
.map_err(|e| anyhow!("Failed to start git blame process: {}", e))?;
let stdin = child
.stdin
.as_mut()
.context("failed to get pipe to stdin of git blame command")?;
for chunk in contents.chunks() {
stdin.write_all(chunk.as_bytes()).await?;
}
stdin.flush().await?;
let output = child
.output()
.await
.map_err(|e| anyhow!("Failed to read git blame output: {}", e))?;
if !output.status.success() {
let stderr = String::from_utf8_lossy(&output.stderr);
let trimmed = stderr.trim();
if trimmed == GIT_BLAME_NO_COMMIT_ERROR || trimmed.contains(GIT_BLAME_NO_PATH) {
return Ok(String::new());
}
return Err(anyhow!("git blame process failed: {}", stderr));
}
Ok(String::from_utf8(output.stdout)?)
}
#[derive(Serialize, Deserialize, Default, Debug, Clone, PartialEq, Eq)]
pub struct BlameEntry {
pub sha: Oid,
pub range: Range<u32>,
pub original_line_number: u32,
pub author: Option<String>,
pub author_mail: Option<String>,
pub author_time: Option<i64>,
pub author_tz: Option<String>,
pub committer_name: Option<String>,
pub committer_email: Option<String>,
pub committer_time: Option<i64>,
pub committer_tz: Option<String>,
pub summary: Option<String>,
pub previous: Option<String>,
pub filename: String,
}
impl BlameEntry {
// Returns a BlameEntry by parsing the first line of a `git blame --incremental`
// entry. The line MUST have this format:
//
// <40-byte-hex-sha1> <sourceline> <resultline> <num-lines>
fn new_from_blame_line(line: &str) -> Result<BlameEntry> {
let mut parts = line.split_whitespace();
let sha = parts
.next()
.and_then(|line| line.parse::<Oid>().ok())
.ok_or_else(|| anyhow!("failed to parse sha"))?;
let original_line_number = parts
.next()
.and_then(|line| line.parse::<u32>().ok())
.ok_or_else(|| anyhow!("Failed to parse original line number"))?;
let final_line_number = parts
.next()
.and_then(|line| line.parse::<u32>().ok())
.ok_or_else(|| anyhow!("Failed to parse final line number"))?;
let line_count = parts
.next()
.and_then(|line| line.parse::<u32>().ok())
.ok_or_else(|| anyhow!("Failed to parse final line number"))?;
let start_line = final_line_number.saturating_sub(1);
let end_line = start_line + line_count;
let range = start_line..end_line;
Ok(Self {
sha,
range,
original_line_number,
..Default::default()
})
}
pub fn author_offset_date_time(&self) -> Result<time::OffsetDateTime> {
if let (Some(author_time), Some(author_tz)) = (self.author_time, &self.author_tz) {
let format = format_description!("[offset_hour][offset_minute]");
let offset = UtcOffset::parse(author_tz, &format)?;
let date_time_utc = OffsetDateTime::from_unix_timestamp(author_time)?;
Ok(date_time_utc.to_offset(offset))
} else {
// Directly return current time in UTC if there's no committer time or timezone
Ok(time::OffsetDateTime::now_utc())
}
}
}
// parse_git_blame parses the output of `git blame --incremental`, which returns
// all the blame-entries for a given path incrementally, as it finds them.
//
// Each entry *always* starts with:
//
// <40-byte-hex-sha1> <sourceline> <resultline> <num-lines>
//
// Each entry *always* ends with:
//
// filename <whitespace-quoted-filename-goes-here>
//
// Line numbers are 1-indexed.
//
// A `git blame --incremental` entry looks like this:
//
// 6ad46b5257ba16d12c5ca9f0d4900320959df7f4 2 2 1
// author Joe Schmoe
// author-mail <joe.schmoe@example.com>
// author-time 1709741400
// author-tz +0100
// committer Joe Schmoe
// committer-mail <joe.schmoe@example.com>
// committer-time 1709741400
// committer-tz +0100
// summary Joe's cool commit
// previous 486c2409237a2c627230589e567024a96751d475 index.js
// filename index.js
//
// If the entry has the same SHA as an entry that was already printed then no
// signature information is printed:
//
// 6ad46b5257ba16d12c5ca9f0d4900320959df7f4 3 4 1
// previous 486c2409237a2c627230589e567024a96751d475 index.js
// filename index.js
//
// More about `--incremental` output: https://mirrors.edge.kernel.org/pub/software/scm/git/docs/git-blame.html
fn parse_git_blame(output: &str) -> Result<Vec<BlameEntry>> {
let mut entries: Vec<BlameEntry> = Vec::new();
let mut index: HashMap<Oid, usize> = HashMap::default();
let mut current_entry: Option<BlameEntry> = None;
for line in output.lines() {
let mut done = false;
match &mut current_entry {
None => {
let mut new_entry = BlameEntry::new_from_blame_line(line)?;
if let Some(existing_entry) = index
.get(&new_entry.sha)
.and_then(|slot| entries.get(*slot))
{
new_entry.author.clone_from(&existing_entry.author);
new_entry
.author_mail
.clone_from(&existing_entry.author_mail);
new_entry.author_time = existing_entry.author_time;
new_entry.author_tz.clone_from(&existing_entry.author_tz);
new_entry
.committer_name
.clone_from(&existing_entry.committer_name);
new_entry
.committer_email
.clone_from(&existing_entry.committer_email);
new_entry.committer_time = existing_entry.committer_time;
new_entry
.committer_tz
.clone_from(&existing_entry.committer_tz);
new_entry.summary.clone_from(&existing_entry.summary);
}
current_entry.replace(new_entry);
}
Some(entry) => {
let Some((key, value)) = line.split_once(' ') else {
continue;
};
let is_committed = !entry.sha.is_zero();
match key {
"filename" => {
entry.filename = value.into();
done = true;
}
"previous" => entry.previous = Some(value.into()),
"summary" if is_committed => entry.summary = Some(value.into()),
"author" if is_committed => entry.author = Some(value.into()),
"author-mail" if is_committed => entry.author_mail = Some(value.into()),
"author-time" if is_committed => {
entry.author_time = Some(value.parse::<i64>()?)
}
"author-tz" if is_committed => entry.author_tz = Some(value.into()),
"committer" if is_committed => entry.committer_name = Some(value.into()),
"committer-mail" if is_committed => entry.committer_email = Some(value.into()),
"committer-time" if is_committed => {
entry.committer_time = Some(value.parse::<i64>()?)
}
"committer-tz" if is_committed => entry.committer_tz = Some(value.into()),
_ => {}
}
}
};
if done {
if let Some(entry) = current_entry.take() {
index.insert(entry.sha, entries.len());
// We only want annotations that have a commit.
if !entry.sha.is_zero() {
entries.push(entry);
}
}
}
}
Ok(entries)
}
#[cfg(test)]
mod tests {
use std::path::PathBuf;
use super::BlameEntry;
use super::parse_git_blame;
fn read_test_data(filename: &str) -> String {
let mut path = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
path.push("test_data");
path.push(filename);
std::fs::read_to_string(&path)
.unwrap_or_else(|_| panic!("Could not read test data at {:?}. Is it generated?", path))
}
fn assert_eq_golden(entries: &Vec<BlameEntry>, golden_filename: &str) {
let mut path = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
path.push("test_data");
path.push("golden");
path.push(format!("{}.json", golden_filename));
let mut have_json =
serde_json::to_string_pretty(&entries).expect("could not serialize entries to JSON");
// We always want to save with a trailing newline.
have_json.push('\n');
let update = std::env::var("UPDATE_GOLDEN")
.map(|val| val.eq_ignore_ascii_case("true"))
.unwrap_or(false);
if update {
std::fs::create_dir_all(path.parent().unwrap())
.expect("could not create golden test data directory");
std::fs::write(&path, have_json).expect("could not write out golden data");
} else {
let want_json =
std::fs::read_to_string(&path).unwrap_or_else(|_| {
panic!("could not read golden test data file at {:?}. Did you run the test with UPDATE_GOLDEN=true before?", path);
}).replace("\r\n", "\n");
pretty_assertions::assert_eq!(have_json, want_json, "wrong blame entries");
}
}
#[test]
fn test_parse_git_blame_not_committed() {
let output = read_test_data("blame_incremental_not_committed");
let entries = parse_git_blame(&output).unwrap();
assert_eq_golden(&entries, "blame_incremental_not_committed");
}
#[test]
fn test_parse_git_blame_simple() {
let output = read_test_data("blame_incremental_simple");
let entries = parse_git_blame(&output).unwrap();
assert_eq_golden(&entries, "blame_incremental_simple");
}
#[test]
fn test_parse_git_blame_complex() {
let output = read_test_data("blame_incremental_complex");
let entries = parse_git_blame(&output).unwrap();
assert_eq_golden(&entries, "blame_incremental_complex");
}
}

View File

@@ -0,0 +1,378 @@
use crate::commit::get_messages;
use crate::{GitRemote, Oid};
use anyhow::{Context as _, Result, anyhow};
use collections::{HashMap, HashSet};
use futures::AsyncWriteExt;
use gpui::SharedString;
use serde::{Deserialize, Serialize};
use std::process::Stdio;
use std::{ops::Range, path::Path};
use text::Rope;
use time::OffsetDateTime;
use time::UtcOffset;
use time::macros::format_description;
pub use git2 as libgit;
#[derive(Debug, Clone, Default)]
pub struct Blame {
pub entries: Vec<BlameEntry>,
pub messages: HashMap<Oid, String>,
pub remote_url: Option<String>,
}
#[derive(Clone, Debug, Default)]
pub struct ParsedCommitMessage {
pub message: SharedString,
pub permalink: Option<url::Url>,
pub pull_request: Option<crate::hosting_provider::PullRequest>,
pub remote: Option<GitRemote>,
}
impl Blame {
pub async fn for_path(
git_binary: &Path,
working_directory: &Path,
path: &Path,
content: &Rope,
remote_url: Option<String>,
) -> Result<Self> {
let output = run_git_blame(git_binary, working_directory, path, content).await?;
let mut entries = parse_git_blame(&output)?;
entries.sort_unstable_by(|a, b| a.range.start.cmp(&b.range.start));
let mut unique_shas = HashSet::default();
for entry in entries.iter_mut() {
unique_shas.insert(entry.sha);
}
let shas = unique_shas.into_iter().collect::<Vec<_>>();
let messages = get_messages(working_directory, &shas)
.await
.context("failed to get commit messages")?;
Ok(Self {
entries,
messages,
remote_url,
})
}
}
const GIT_BLAME_NO_COMMIT_ERROR: &str = "fatal: no such ref: HEAD";
const GIT_BLAME_NO_PATH: &str = "fatal: no such path";
async fn run_git_blame(
git_binary: &Path,
working_directory: &Path,
path: &Path,
contents: &Rope,
) -> Result<String> {
let mut child = util::command::new_smol_command(git_binary)
.current_dir(working_directory)
.arg("blame")
.arg("--incremental")
.arg("--contents")
.arg("-")
.arg(path.as_os_str())
.stdin(Stdio::piped())
.stdout(Stdio::piped())
.stderr(Stdio::piped())
.spawn()
.map_err(|e| anyhow!("Failed to start git blame process: {}", e))?;
let stdin = child
.stdin
.as_mut()
.context("failed to get pipe to stdin of git blame command")?;
for chunk in contents.chunks() {
stdin.write_all(chunk.as_bytes()).await?;
}
stdin.flush().await?;
let output = child
.output()
.await
.map_err(|e| anyhow!("Failed to read git blame output: {}", e))?;
handle_command_output(output)
}
fn handle_command_output(output: std::process::Output) -> Result<String> {
if !output.status.success() {
let stderr = String::from_utf8_lossy(&output.stderr);
let trimmed = stderr.trim();
if trimmed == GIT_BLAME_NO_COMMIT_ERROR || trimmed.contains(GIT_BLAME_NO_PATH) {
return Ok(String::new());
}
return Err(anyhow!("git blame process failed: {}", stderr));
}
Ok(String::from_utf8(output.stdout)?)
}
#[derive(Serialize, Deserialize, Default, Debug, Clone, PartialEq, Eq)]
pub struct BlameEntry {
pub sha: Oid,
pub range: Range<u32>,
pub original_line_number: u32,
pub author: Option<String>,
pub author_mail: Option<String>,
pub author_time: Option<i64>,
pub author_tz: Option<String>,
pub committer_name: Option<String>,
pub committer_email: Option<String>,
pub committer_time: Option<i64>,
pub committer_tz: Option<String>,
pub summary: Option<String>,
pub previous: Option<String>,
pub filename: String,
}
impl BlameEntry {
// Returns a BlameEntry by parsing the first line of a `git blame --incremental`
// entry. The line MUST have this format:
//
// <40-byte-hex-sha1> <sourceline> <resultline> <num-lines>
fn new_from_blame_line(line: &str) -> Result<BlameEntry> {
let mut parts = line.split_whitespace();
let sha = parts
.next()
.and_then(|line| line.parse::<Oid>().ok())
.ok_or_else(|| anyhow!("failed to parse sha"))?;
let original_line_number = parts
.next()
.and_then(|line| line.parse::<u32>().ok())
.ok_or_else(|| anyhow!("Failed to parse original line number"))?;
let final_line_number = parts
.next()
.and_then(|line| line.parse::<u32>().ok())
.ok_or_else(|| anyhow!("Failed to parse final line number"))?;
let line_count = parts
.next()
.and_then(|line| line.parse::<u32>().ok())
.ok_or_else(|| anyhow!("Failed to parse final line number"))?;
let start_line = final_line_number.saturating_sub(1);
let end_line = start_line + line_count;
let range = start_line..end_line;
Ok(Self {
sha,
range,
original_line_number,
..Default::default()
})
}
pub fn author_offset_date_time(&self) -> Result<time::OffsetDateTime> {
if let (Some(author_time), Some(author_tz)) = (self.author_time, &self.author_tz) {
let format = format_description!("[offset_hour][offset_minute]");
let offset = UtcOffset::parse(author_tz, &format)?;
let date_time_utc = OffsetDateTime::from_unix_timestamp(author_time)?;
Ok(date_time_utc.to_offset(offset))
} else {
// Directly return current time in UTC if there's no committer time or timezone
Ok(time::OffsetDateTime::now_utc())
}
}
}
// parse_git_blame parses the output of `git blame --incremental`, which returns
// all the blame-entries for a given path incrementally, as it finds them.
//
// Each entry *always* starts with:
//
// <40-byte-hex-sha1> <sourceline> <resultline> <num-lines>
//
// Each entry *always* ends with:
//
// filename <whitespace-quoted-filename-goes-here>
//
// Line numbers are 1-indexed.
//
// A `git blame --incremental` entry looks like this:
//
// 6ad46b5257ba16d12c5ca9f0d4900320959df7f4 2 2 1
// author Joe Schmoe
// author-mail <joe.schmoe@example.com>
// author-time 1709741400
// author-tz +0100
// committer Joe Schmoe
// committer-mail <joe.schmoe@example.com>
// committer-time 1709741400
// committer-tz +0100
// summary Joe's cool commit
// previous 486c2409237a2c627230589e567024a96751d475 index.js
// filename index.js
//
// If the entry has the same SHA as an entry that was already printed then no
// signature information is printed:
//
// 6ad46b5257ba16d12c5ca9f0d4900320959df7f4 3 4 1
// previous 486c2409237a2c627230589e567024a96751d475 index.js
// filename index.js
//
// More about `--incremental` output: https://mirrors.edge.kernel.org/pub/software/scm/git/docs/git-blame.html
fn parse_git_blame(output: &str) -> Result<Vec<BlameEntry>> {
let mut entries: Vec<BlameEntry> = Vec::new();
let mut index: HashMap<Oid, usize> = HashMap::default();
let mut current_entry: Option<BlameEntry> = None;
for line in output.lines() {
let mut done = false;
match &mut current_entry {
None => {
let mut new_entry = BlameEntry::new_from_blame_line(line)?;
if let Some(existing_entry) = index
.get(&new_entry.sha)
.and_then(|slot| entries.get(*slot))
{
new_entry.author.clone_from(&existing_entry.author);
new_entry
.author_mail
.clone_from(&existing_entry.author_mail);
new_entry.author_time = existing_entry.author_time;
new_entry.author_tz.clone_from(&existing_entry.author_tz);
new_entry
.committer_name
.clone_from(&existing_entry.committer_name);
new_entry
.committer_email
.clone_from(&existing_entry.committer_email);
new_entry.committer_time = existing_entry.committer_time;
new_entry
.committer_tz
.clone_from(&existing_entry.committer_tz);
new_entry.summary.clone_from(&existing_entry.summary);
}
current_entry.replace(new_entry);
}
Some(entry) => {
let Some((key, value)) = line.split_once(' ') else {
continue;
};
let is_committed = !entry.sha.is_zero();
match key {
"filename" => {
entry.filename = value.into();
done = true;
}
"previous" => entry.previous = Some(value.into()),
"summary" if is_committed => entry.summary = Some(value.into()),
"author" if is_committed => entry.author = Some(value.into()),
"author-mail" if is_committed => entry.author_mail = Some(value.into()),
"author-time" if is_committed => {
entry.author_time = Some(value.parse::<i64>()?)
}
"author-tz" if is_committed => entry.author_tz = Some(value.into()),
"committer" if is_committed => entry.committer_name = Some(value.into()),
"committer-mail" if is_committed => entry.committer_email = Some(value.into()),
"committer-time" if is_committed => {
entry.committer_time = Some(value.parse::<i64>()?)
}
"committer-tz" if is_committed => entry.committer_tz = Some(value.into()),
_ => {}
}
}
};
if done {
if let Some(entry) = current_entry.take() {
index.insert(entry.sha, entries.len());
// We only want annotations that have a commit.
if !entry.sha.is_zero() {
entries.push(entry);
}
}
}
}
Ok(entries)
}
#[cfg(test)]
mod tests {
use std::path::PathBuf;
use super::BlameEntry;
use super::parse_git_blame;
fn read_test_data(filename: &str) -> String {
let mut path = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
path.push("test_data");
path.push(filename);
std::fs::read_to_string(&path)
.unwrap_or_else(|_| panic!("Could not read test data at {:?}. Is it generated?", path))
}
fn assert_eq_golden(entries: &Vec<BlameEntry>, golden_filename: &str) {
let mut path = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
path.push("test_data");
path.push("golden");
path.push(format!("{}.json", golden_filename));
let mut have_json =
serde_json::to_string_pretty(&entries).expect("could not serialize entries to JSON");
// We always want to save with a trailing newline.
have_json.push('\n');
let update = std::env::var("UPDATE_GOLDEN")
.map(|val| val.eq_ignore_ascii_case("true"))
.unwrap_or(false);
if update {
std::fs::create_dir_all(path.parent().unwrap())
.expect("could not create golden test data directory");
std::fs::write(&path, have_json).expect("could not write out golden data");
} else {
let want_json =
std::fs::read_to_string(&path).unwrap_or_else(|_| {
panic!("could not read golden test data file at {:?}. Did you run the test with UPDATE_GOLDEN=true before?", path);
}).replace("\r\n", "\n");
pretty_assertions::assert_eq!(have_json, want_json, "wrong blame entries");
}
}
#[test]
fn test_parse_git_blame_not_committed() {
let output = read_test_data("blame_incremental_not_committed");
let entries = parse_git_blame(&output).unwrap();
assert_eq_golden(&entries, "blame_incremental_not_committed");
}
#[test]
fn test_parse_git_blame_simple() {
let output = read_test_data("blame_incremental_simple");
let entries = parse_git_blame(&output).unwrap();
assert_eq_golden(&entries, "blame_incremental_simple");
}
#[test]
fn test_parse_git_blame_complex() {
let output = read_test_data("blame_incremental_complex");
let entries = parse_git_blame(&output).unwrap();
assert_eq_golden(&entries, "blame_incremental_complex");
}
}

View File

@@ -0,0 +1,374 @@
use crate::commit::get_messages;
use crate::{GitRemote, Oid};
use anyhow::{Context as _, Result, anyhow};
use collections::{HashMap, HashSet};
use futures::AsyncWriteExt;
use gpui::SharedString;
use serde::{Deserialize, Serialize};
use std::process::Stdio;
use std::{ops::Range, path::Path};
use text::Rope;
use time::OffsetDateTime;
use time::UtcOffset;
use time::macros::format_description;
pub use git2 as libgit;
#[derive(Debug, Clone, Default)]
pub struct Blame {
pub entries: Vec<BlameEntry>,
pub messages: HashMap<Oid, String>,
pub remote_url: Option<String>,
}
#[derive(Clone, Debug, Default)]
pub struct ParsedCommitMessage {
pub message: SharedString,
pub permalink: Option<url::Url>,
pub pull_request: Option<crate::hosting_provider::PullRequest>,
pub remote: Option<GitRemote>,
}
impl Blame {
pub async fn for_path(
git_binary: &Path,
working_directory: &Path,
path: &Path,
content: &Rope,
remote_url: Option<String>,
) -> Result<Self> {
let output = run_git_blame(git_binary, working_directory, path, content).await?;
let mut entries = parse_git_blame(&output)?;
entries.sort_unstable_by(|a, b| a.range.start.cmp(&b.range.start));
let mut unique_shas = HashSet::default();
for entry in entries.iter_mut() {
unique_shas.insert(entry.sha);
}
let shas = unique_shas.into_iter().collect::<Vec<_>>();
let messages = get_messages(working_directory, &shas)
.await
.context("failed to get commit messages")?;
Ok(Self {
entries,
messages,
remote_url,
})
}
}
const GIT_BLAME_NO_COMMIT_ERROR: &str = "fatal: no such ref: HEAD";
const GIT_BLAME_NO_PATH: &str = "fatal: no such path";
async fn run_git_blame(
git_binary: &Path,
working_directory: &Path,
path: &Path,
contents: &Rope,
) -> Result<String> {
let mut child = util::command::new_smol_command(git_binary)
.current_dir(working_directory)
.arg("blame")
.arg("--incremental")
.arg("--contents")
.arg("-")
.arg(path.as_os_str())
.stdin(Stdio::piped())
.stdout(Stdio::piped())
.stderr(Stdio::piped())
.spawn()
.map_err(|e| anyhow!("Failed to start git blame process: {}", e))?;
let stdin = child
.stdin
.as_mut()
.context("failed to get pipe to stdin of git blame command")?;
for chunk in contents.chunks() {
stdin.write_all(chunk.as_bytes()).await?;
}
stdin.flush().await?;
let output = child
.output()
.await
.map_err(|e| anyhow!("Failed to read git blame output: {}", e))?;
if !output.status.success() {
let stderr = String::from_utf8_lossy(&output.stderr);
let trimmed = stderr.trim();
if trimmed == GIT_BLAME_NO_COMMIT_ERROR || trimmed.contains(GIT_BLAME_NO_PATH) {
return Ok(String::new());
}
return Err(anyhow!("git blame process failed: {}", stderr));
}
Ok(String::from_utf8(output.stdout)?)
}
#[derive(Serialize, Deserialize, Default, Debug, Clone, PartialEq, Eq)]
pub struct BlameEntry {
pub sha: Oid,
pub range: Range<u32>,
pub original_line_number: u32,
pub author: Option<String>,
pub author_mail: Option<String>,
pub author_time: Option<i64>,
pub author_tz: Option<String>,
pub committer_name: Option<String>,
pub committer_email: Option<String>,
pub committer_time: Option<i64>,
pub committer_tz: Option<String>,
pub summary: Option<String>,
pub previous: Option<String>,
pub filename: String,
}
impl BlameEntry {
// Returns a BlameEntry by parsing the first line of a `git blame --incremental`
// entry. The line MUST have this format:
//
// <40-byte-hex-sha1> <sourceline> <resultline> <num-lines>
fn new_from_blame_line(line: &str) -> Result<BlameEntry> {
let mut parts = line.split_whitespace();
let sha = parts
.next()
.and_then(|line| line.parse::<Oid>().ok())
.ok_or_else(|| anyhow!("failed to parse sha"))?;
let original_line_number = parts
.next()
.and_then(|line| line.parse::<u32>().ok())
.ok_or_else(|| anyhow!("Failed to parse original line number"))?;
let final_line_number = parts
.next()
.and_then(|line| line.parse::<u32>().ok())
.ok_or_else(|| anyhow!("Failed to parse final line number"))?;
let line_count = parts
.next()
.and_then(|line| line.parse::<u32>().ok())
.ok_or_else(|| anyhow!("Failed to parse final line number"))?;
let start_line = final_line_number.saturating_sub(1);
let end_line = start_line + line_count;
let range = start_line..end_line;
Ok(Self {
sha,
range,
original_line_number,
..Default::default()
})
}
pub fn author_offset_date_time(&self) -> Result<time::OffsetDateTime> {
if let (Some(author_time), Some(author_tz)) = (self.author_time, &self.author_tz) {
let format = format_description!("[offset_hour][offset_minute]");
let offset = UtcOffset::parse(author_tz, &format)?;
let date_time_utc = OffsetDateTime::from_unix_timestamp(author_time)?;
Ok(date_time_utc.to_offset(offset))
} else {
// Directly return current time in UTC if there's no committer time or timezone
Ok(time::OffsetDateTime::now_utc())
}
}
}
// parse_git_blame parses the output of `git blame --incremental`, which returns
// all the blame-entries for a given path incrementally, as it finds them.
//
// Each entry *always* starts with:
//
// <40-byte-hex-sha1> <sourceline> <resultline> <num-lines>
//
// Each entry *always* ends with:
//
// filename <whitespace-quoted-filename-goes-here>
//
// Line numbers are 1-indexed.
//
// A `git blame --incremental` entry looks like this:
//
// 6ad46b5257ba16d12c5ca9f0d4900320959df7f4 2 2 1
// author Joe Schmoe
// author-mail <joe.schmoe@example.com>
// author-time 1709741400
// author-tz +0100
// committer Joe Schmoe
// committer-mail <joe.schmoe@example.com>
// committer-time 1709741400
// committer-tz +0100
// summary Joe's cool commit
// previous 486c2409237a2c627230589e567024a96751d475 index.js
// filename index.js
//
// If the entry has the same SHA as an entry that was already printed then no
// signature information is printed:
//
// 6ad46b5257ba16d12c5ca9f0d4900320959df7f4 3 4 1
// previous 486c2409237a2c627230589e567024a96751d475 index.js
// filename index.js
//
// More about `--incremental` output: https://mirrors.edge.kernel.org/pub/software/scm/git/docs/git-blame.html
fn parse_git_blame(output: &str) -> Result<Vec<BlameEntry>> {
let mut entries: Vec<BlameEntry> = Vec::new();
let mut index: HashMap<Oid, usize> = HashMap::default();
let mut current_entry: Option<BlameEntry> = None;
for line in output.lines() {
let mut done = false;
match &mut current_entry {
None => {
let mut new_entry = BlameEntry::new_from_blame_line(line)?;
if let Some(existing_entry) = index
.get(&new_entry.sha)
.and_then(|slot| entries.get(*slot))
{
new_entry.author.clone_from(&existing_entry.author);
new_entry
.author_mail
.clone_from(&existing_entry.author_mail);
new_entry.author_time = existing_entry.author_time;
new_entry.author_tz.clone_from(&existing_entry.author_tz);
new_entry
.committer_name
.clone_from(&existing_entry.committer_name);
new_entry
.committer_email
.clone_from(&existing_entry.committer_email);
new_entry.committer_time = existing_entry.committer_time;
new_entry
.committer_tz
.clone_from(&existing_entry.committer_tz);
new_entry.summary.clone_from(&existing_entry.summary);
}
current_entry.replace(new_entry);
}
Some(entry) => {
let Some((key, value)) = line.split_once(' ') else {
continue;
};
let is_committed = !entry.sha.is_zero();
match key {
"filename" => {
entry.filename = value.into();
done = true;
}
"previous" => entry.previous = Some(value.into()),
"summary" if is_committed => entry.summary = Some(value.into()),
"author" if is_committed => entry.author = Some(value.into()),
"author-mail" if is_committed => entry.author_mail = Some(value.into()),
"author-time" if is_committed => {
entry.author_time = Some(value.parse::<i64>()?)
}
"author-tz" if is_committed => entry.author_tz = Some(value.into()),
"committer" if is_committed => entry.committer_name = Some(value.into()),
"committer-mail" if is_committed => entry.committer_email = Some(value.into()),
"committer-time" if is_committed => {
entry.committer_time = Some(value.parse::<i64>()?)
}
"committer-tz" if is_committed => entry.committer_tz = Some(value.into()),
_ => {}
}
}
};
if done {
if let Some(entry) = current_entry.take() {
index.insert(entry.sha, entries.len());
// We only want annotations that have a commit.
if !entry.sha.is_zero() {
entries.push(entry);
}
}
}
}
Ok(entries)
}
#[cfg(test)]
mod tests {
use std::path::PathBuf;
use super::BlameEntry;
use super::parse_git_blame;
fn read_test_data(filename: &str) -> String {
let mut path = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
path.push("test_data");
path.push(filename);
std::fs::read_to_string(&path)
.unwrap_or_else(|_| panic!("Could not read test data at {:?}. Is it generated?", path))
}
fn assert_eq_golden(entries: &Vec<BlameEntry>, golden_filename: &str) {
let mut path = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
path.push("test_data");
path.push("golden");
path.push(format!("{}.json", golden_filename));
let mut have_json =
serde_json::to_string_pretty(&entries).expect("could not serialize entries to JSON");
// We always want to save with a trailing newline.
have_json.push('\n');
let update = std::env::var("UPDATE_GOLDEN")
.map(|val| val.eq_ignore_ascii_case("true"))
.unwrap_or(false);
if update {
std::fs::create_dir_all(path.parent().unwrap())
.expect("could not create golden test data directory");
std::fs::write(&path, have_json).expect("could not write out golden data");
} else {
let want_json =
std::fs::read_to_string(&path).unwrap_or_else(|_| {
panic!("could not read golden test data file at {:?}. Did you run the test with UPDATE_GOLDEN=true before?", path);
}).replace("\r\n", "\n");
pretty_assertions::assert_eq!(have_json, want_json, "wrong blame entries");
}
}
#[test]
fn test_parse_git_blame_not_committed() {
let output = read_test_data("blame_incremental_not_committed");
let entries = parse_git_blame(&output).unwrap();
assert_eq_golden(&entries, "blame_incremental_not_committed");
}
#[test]
fn test_parse_git_blame_simple() {
let output = read_test_data("blame_incremental_simple");
let entries = parse_git_blame(&output).unwrap();
assert_eq_golden(&entries, "blame_incremental_simple");
}
#[test]
fn test_parse_git_blame_complex() {
let output = read_test_data("blame_incremental_complex");
let entries = parse_git_blame(&output).unwrap();
assert_eq_golden(&entries, "blame_incremental_complex");
}
}

View File

@@ -0,0 +1,339 @@
// font-kit/src/canvas.rs
//
// Copyright © 2018 The Pathfinder Project Developers.
//
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
// option. This file may not be copied, modified, or distributed
// except according to those terms.
//! An in-memory bitmap surface for glyph rasterization.
use lazy_static::lazy_static;
use pathfinder_geometry::rect::RectI;
use pathfinder_geometry::vector::Vector2I;
use std::cmp;
use std::fmt;
use crate::utils;
lazy_static! {
static ref BITMAP_1BPP_TO_8BPP_LUT: [[u8; 8]; 256] = {
let mut lut = [[0; 8]; 256];
for byte in 0..0x100 {
let mut value = [0; 8];
for bit in 0..8 {
if (byte & (0x80 >> bit)) != 0 {
value[bit] = 0xff;
}
}
lut[byte] = value
}
lut
};
}
/// An in-memory bitmap surface for glyph rasterization.
pub struct Canvas {
/// The raw pixel data.
pub pixels: Vec<u8>,
/// The size of the buffer, in pixels.
pub size: Vector2I,
/// The number of *bytes* between successive rows.
pub stride: usize,
/// The image format of the canvas.
pub format: Format,
}
impl Canvas {
/// Creates a new blank canvas with the given pixel size and format.
///
/// Stride is automatically calculated from width.
///
/// The canvas is initialized with transparent black (all values 0).
#[inline]
pub fn new(size: Vector2I, format: Format) -> Canvas {
Canvas::with_stride(
size,
size.x() as usize * format.bytes_per_pixel() as usize,
format,
)
}
/// Creates a new blank canvas with the given pixel size, stride (number of bytes between
/// successive rows), and format.
///
/// The canvas is initialized with transparent black (all values 0).
pub fn with_stride(size: Vector2I, stride: usize, format: Format) -> Canvas {
Canvas {
pixels: vec![0; stride * size.y() as usize],
size,
stride,
format,
}
}
#[allow(dead_code)]
pub(crate) fn blit_from_canvas(&mut self, src: &Canvas) {
self.blit_from(
Vector2I::default(),
&src.pixels,
src.size,
src.stride,
src.format,
)
}
/// Blits to a rectangle with origin at `dst_point` and size according to `src_size`.
/// If the target area overlaps the boundaries of the canvas, only the drawable region is blitted.
/// `dst_point` and `src_size` are specified in pixels. `src_stride` is specified in bytes.
/// `src_stride` must be equal or larger than the actual data length.
#[allow(dead_code)]
pub(crate) fn blit_from(
&mut self,
dst_point: Vector2I,
src_bytes: &[u8],
src_size: Vector2I,
src_stride: usize,
src_format: Format,
) {
assert_eq!(
src_stride * src_size.y() as usize,
src_bytes.len(),
"Number of pixels in src_bytes does not match stride and size."
);
assert!(
src_stride >= src_size.x() as usize * src_format.bytes_per_pixel() as usize,
"src_stride must be >= than src_size.x()"
);
let dst_rect = RectI::new(dst_point, src_size);
let dst_rect = dst_rect.intersection(RectI::new(Vector2I::default(), self.size));
let dst_rect = match dst_rect {
Some(dst_rect) => dst_rect,
None => return,
};
match (self.format, src_format) {
(Format::A8, Format::A8)
| (Format::Rgb24, Format::Rgb24)
| (Format::Rgba32, Format::Rgba32) => {
self.blit_from_with::<BlitMemcpy>(dst_rect, src_bytes, src_stride, src_format)
}
(Format::A8, Format::Rgb24) => {
self.blit_from_with::<BlitRgb24ToA8>(dst_rect, src_bytes, src_stride, src_format)
}
(Format::Rgb24, Format::A8) => {
self.blit_from_with::<BlitA8ToRgb24>(dst_rect, src_bytes, src_stride, src_format)
}
(Format::Rgb24, Format::Rgba32) => self
.blit_from_with::<BlitRgba32ToRgb24>(dst_rect, src_bytes, src_stride, src_format),
(Format::Rgba32, Format::Rgb24) => self
.blit_from_with::<BlitRgb24ToRgba32>(dst_rect, src_bytes, src_stride, src_format),
(Format::Rgba32, Format::A8) | (Format::A8, Format::Rgba32) => unimplemented!(),
}
}
#[allow(dead_code)]
pub(crate) fn blit_from_bitmap_1bpp(
&mut self,
dst_point: Vector2I,
src_bytes: &[u8],
src_size: Vector2I,
src_stride: usize,
) {
if self.format != Format::A8 {
unimplemented!()
}
let dst_rect = RectI::new(dst_point, src_size);
let dst_rect = dst_rect.intersection(RectI::new(Vector2I::default(), self.size));
let dst_rect = match dst_rect {
Some(dst_rect) => dst_rect,
None => return,
};
let size = dst_rect.size();
let dest_bytes_per_pixel = self.format.bytes_per_pixel() as usize;
let dest_row_stride = size.x() as usize * dest_bytes_per_pixel;
let src_row_stride = utils::div_round_up(size.x() as usize, 8);
for y in 0..size.y() {
let (dest_row_start, src_row_start) = (
(y + dst_rect.origin_y()) as usize * self.stride
+ dst_rect.origin_x() as usize * dest_bytes_per_pixel,
y as usize * src_stride,
);
let dest_row_end = dest_row_start + dest_row_stride;
let src_row_end = src_row_start + src_row_stride;
let dest_row_pixels = &mut self.pixels[dest_row_start..dest_row_end];
let src_row_pixels = &src_bytes[src_row_start..src_row_end];
for x in 0..src_row_stride {
let pattern = &BITMAP_1BPP_TO_8BPP_LUT[src_row_pixels[x] as usize];
let dest_start = x * 8;
let dest_end = cmp::min(dest_start + 8, dest_row_stride);
let src = &pattern[0..(dest_end - dest_start)];
dest_row_pixels[dest_start..dest_end].clone_from_slice(src);
}
}
}
/// Blits to area `rect` using the data given in the buffer `src_bytes`.
/// `src_stride` must be specified in bytes.
/// The dimensions of `rect` must be in pixels.
fn blit_from_with<B: Blit>(
&mut self,
rect: RectI,
src_bytes: &[u8],
src_stride: usize,
src_format: Format,
) {
let src_bytes_per_pixel = src_format.bytes_per_pixel() as usize;
let dest_bytes_per_pixel = self.format.bytes_per_pixel() as usize;
for y in 0..rect.height() {
let (dest_row_start, src_row_start) = (
(y + rect.origin_y()) as usize * self.stride
+ rect.origin_x() as usize * dest_bytes_per_pixel,
y as usize * src_stride,
);
let dest_row_end = dest_row_start + rect.width() as usize * dest_bytes_per_pixel;
let src_row_end = src_row_start + rect.width() as usize * src_bytes_per_pixel;
let dest_row_pixels = &mut self.pixels[dest_row_start..dest_row_end];
let src_row_pixels = &src_bytes[src_row_start..src_row_end];
B::blit(dest_row_pixels, src_row_pixels)
}
}
}
impl fmt::Debug for Canvas {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
f.debug_struct("Canvas")
.field("pixels", &self.pixels.len()) // Do not dump a vector content.
.field("size", &self.size)
.field("stride", &self.stride)
.field("format", &self.format)
.finish()
}
}
/// The image format for the canvas.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum Format {
/// Premultiplied R8G8B8A8, little-endian.
Rgba32,
/// R8G8B8, little-endian.
Rgb24,
/// A8.
A8,
}
impl Format {
/// Returns the number of bits per pixel that this image format corresponds to.
#[inline]
pub fn bits_per_pixel(self) -> u8 {
match self {
Format::Rgba32 => 32,
Format::Rgb24 => 24,
Format::A8 => 8,
}
}
/// Returns the number of color channels per pixel that this image format corresponds to.
#[inline]
pub fn components_per_pixel(self) -> u8 {
match self {
Format::Rgba32 => 4,
Format::Rgb24 => 3,
Format::A8 => 1,
}
}
/// Returns the number of bits per color channel that this image format contains.
#[inline]
pub fn bits_per_component(self) -> u8 {
self.bits_per_pixel() / self.components_per_pixel()
}
/// Returns the number of bytes per pixel that this image format corresponds to.
#[inline]
pub fn bytes_per_pixel(self) -> u8 {
self.bits_per_pixel() / 8
}
}
/// The antialiasing strategy that should be used when rasterizing glyphs.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum RasterizationOptions {
/// "Black-and-white" rendering. Each pixel is either entirely on or off.
Bilevel,
/// Grayscale antialiasing. Only one channel is used.
GrayscaleAa,
/// Subpixel RGB antialiasing, for LCD screens.
SubpixelAa,
}
trait Blit {
fn blit(dest: &mut [u8], src: &[u8]);
}
struct BlitMemcpy;
impl Blit for BlitMemcpy {
#[inline]
fn blit(dest: &mut [u8], src: &[u8]) {
dest.clone_from_slice(src)
}
}
struct BlitRgb24ToA8;
impl Blit for BlitRgb24ToA8 {
#[inline]
fn blit(dest: &mut [u8], src: &[u8]) {
// TODO(pcwalton): SIMD.
for (dest, src) in dest.iter_mut().zip(src.chunks(3)) {
*dest = src[1]
}
}
}
struct BlitA8ToRgb24;
impl Blit for BlitA8ToRgb24 {
#[inline]
fn blit(dest: &mut [u8], src: &[u8]) {
for (dest, src) in dest.chunks_mut(3).zip(src.iter()) {
dest[0] = *src;
dest[1] = *src;
dest[2] = *src;
}
}
}
struct BlitRgba32ToRgb24;
impl Blit for BlitRgba32ToRgb24 {
#[inline]
fn blit(dest: &mut [u8], src: &[u8]) {
// TODO(pcwalton): SIMD.
for (dest, src) in dest.chunks_mut(3).zip(src.chunks(4)) {
dest.copy_from_slice(&src[0..3])
}
}
}
struct BlitRgb24ToRgba32;
impl Blit for BlitRgb24ToRgba32 {
fn blit(dest: &mut [u8], src: &[u8]) {
for (dest, src) in dest.chunks_mut(4).zip(src.chunks(3)) {
dest[0] = src[0];
dest[1] = src[1];
dest[2] = src[2];
dest[3] = 255;
}
}
}

File diff suppressed because one or more lines are too long

View File

@@ -0,0 +1,14 @@
class InputCell:
def __init__(self, initial_value):
self.value = None
class ComputeCell:
def __init__(self, inputs, compute_function):
self.value = None
def add_callback(self, callback):
pass
def remove_callback(self, callback):
pass

View File

@@ -0,0 +1,271 @@
# These tests are auto-generated with test data from:
# https://github.com/exercism/problem-specifications/tree/main/exercises/react/canonical-data.json
# File last updated on 2023-07-19
from functools import partial
import unittest
from react import (
InputCell,
ComputeCell,
)
class ReactTest(unittest.TestCase):
def test_input_cells_have_a_value(self):
input = InputCell(10)
self.assertEqual(input.value, 10)
def test_an_input_cell_s_value_can_be_set(self):
input = InputCell(4)
input.value = 20
self.assertEqual(input.value, 20)
def test_compute_cells_calculate_initial_value(self):
input = InputCell(1)
output = ComputeCell(
[
input,
],
lambda inputs: inputs[0] + 1,
)
self.assertEqual(output.value, 2)
def test_compute_cells_take_inputs_in_the_right_order(self):
one = InputCell(1)
two = InputCell(2)
output = ComputeCell(
[
one,
two,
],
lambda inputs: inputs[0] + inputs[1] * 10,
)
self.assertEqual(output.value, 21)
def test_compute_cells_update_value_when_dependencies_are_changed(self):
input = InputCell(1)
output = ComputeCell(
[
input,
],
lambda inputs: inputs[0] + 1,
)
input.value = 3
self.assertEqual(output.value, 4)
def test_compute_cells_can_depend_on_other_compute_cells(self):
input = InputCell(1)
times_two = ComputeCell(
[
input,
],
lambda inputs: inputs[0] * 2,
)
times_thirty = ComputeCell(
[
input,
],
lambda inputs: inputs[0] * 30,
)
output = ComputeCell(
[
times_two,
times_thirty,
],
lambda inputs: inputs[0] + inputs[1],
)
self.assertEqual(output.value, 32)
input.value = 3
self.assertEqual(output.value, 96)
def test_compute_cells_fire_callbacks(self):
input = InputCell(1)
output = ComputeCell(
[
input,
],
lambda inputs: inputs[0] + 1,
)
cb1_observer = []
callback1 = self.callback_factory(cb1_observer)
output.add_callback(callback1)
input.value = 3
self.assertEqual(cb1_observer[-1], 4)
def test_callback_cells_only_fire_on_change(self):
input = InputCell(1)
output = ComputeCell([input], lambda inputs: 111 if inputs[0] < 3 else 222)
cb1_observer = []
callback1 = self.callback_factory(cb1_observer)
output.add_callback(callback1)
input.value = 2
self.assertEqual(cb1_observer, [])
input.value = 4
self.assertEqual(cb1_observer[-1], 222)
def test_callbacks_do_not_report_already_reported_values(self):
input = InputCell(1)
output = ComputeCell(
[
input,
],
lambda inputs: inputs[0] + 1,
)
cb1_observer = []
callback1 = self.callback_factory(cb1_observer)
output.add_callback(callback1)
input.value = 2
self.assertEqual(cb1_observer[-1], 3)
input.value = 3
self.assertEqual(cb1_observer[-1], 4)
def test_callbacks_can_fire_from_multiple_cells(self):
input = InputCell(1)
plus_one = ComputeCell(
[
input,
],
lambda inputs: inputs[0] + 1,
)
minus_one = ComputeCell(
[
input,
],
lambda inputs: inputs[0] - 1,
)
cb1_observer = []
cb2_observer = []
callback1 = self.callback_factory(cb1_observer)
callback2 = self.callback_factory(cb2_observer)
plus_one.add_callback(callback1)
minus_one.add_callback(callback2)
input.value = 10
self.assertEqual(cb1_observer[-1], 11)
self.assertEqual(cb2_observer[-1], 9)
def test_callbacks_can_be_added_and_removed(self):
input = InputCell(11)
output = ComputeCell(
[
input,
],
lambda inputs: inputs[0] + 1,
)
cb1_observer = []
cb2_observer = []
cb3_observer = []
callback1 = self.callback_factory(cb1_observer)
callback2 = self.callback_factory(cb2_observer)
callback3 = self.callback_factory(cb3_observer)
output.add_callback(callback1)
output.add_callback(callback2)
input.value = 31
self.assertEqual(cb1_observer[-1], 32)
self.assertEqual(cb2_observer[-1], 32)
output.remove_callback(callback1)
output.add_callback(callback3)
input.value = 41
self.assertEqual(len(cb1_observer), 1)
self.assertEqual(cb2_observer[-1], 42)
self.assertEqual(cb3_observer[-1], 42)
def test_removing_a_callback_multiple_times_doesn_t_interfere_with_other_callbacks(
self,
):
input = InputCell(1)
output = ComputeCell(
[
input,
],
lambda inputs: inputs[0] + 1,
)
cb1_observer = []
cb2_observer = []
callback1 = self.callback_factory(cb1_observer)
callback2 = self.callback_factory(cb2_observer)
output.add_callback(callback1)
output.add_callback(callback2)
output.remove_callback(callback1)
output.remove_callback(callback1)
output.remove_callback(callback1)
input.value = 2
self.assertEqual(cb1_observer, [])
self.assertEqual(cb2_observer[-1], 3)
def test_callbacks_should_only_be_called_once_even_if_multiple_dependencies_change(
self,
):
input = InputCell(1)
plus_one = ComputeCell(
[
input,
],
lambda inputs: inputs[0] + 1,
)
minus_one1 = ComputeCell(
[
input,
],
lambda inputs: inputs[0] - 1,
)
minus_one2 = ComputeCell(
[
minus_one1,
],
lambda inputs: inputs[0] - 1,
)
output = ComputeCell(
[
plus_one,
minus_one2,
],
lambda inputs: inputs[0] * inputs[1],
)
cb1_observer = []
callback1 = self.callback_factory(cb1_observer)
output.add_callback(callback1)
input.value = 4
self.assertEqual(cb1_observer[-1], 10)
def test_callbacks_should_not_be_called_if_dependencies_change_but_output_value_doesn_t_change(
self,
):
input = InputCell(1)
plus_one = ComputeCell(
[
input,
],
lambda inputs: inputs[0] + 1,
)
minus_one = ComputeCell(
[
input,
],
lambda inputs: inputs[0] - 1,
)
always_two = ComputeCell(
[
plus_one,
minus_one,
],
lambda inputs: inputs[0] - inputs[1],
)
cb1_observer = []
callback1 = self.callback_factory(cb1_observer)
always_two.add_callback(callback1)
input.value = 2
self.assertEqual(cb1_observer, [])
input.value = 3
self.assertEqual(cb1_observer, [])
input.value = 4
self.assertEqual(cb1_observer, [])
input.value = 5
self.assertEqual(cb1_observer, [])
# Utility functions.
def callback_factory(self, observer):
def callback(observer, value):
observer.append(value)
return partial(callback, observer)

View File

@@ -282,7 +282,7 @@ pub struct EditFileToolCard {
}
impl EditFileToolCard {
fn new(path: PathBuf, project: Entity<Project>, window: &mut Window, cx: &mut App) -> Self {
pub fn new(path: PathBuf, project: Entity<Project>, window: &mut Window, cx: &mut App) -> Self {
let multibuffer = cx.new(|_| MultiBuffer::without_headers(Capability::ReadOnly));
let editor = cx.new(|cx| {
let mut editor = Editor::new(
@@ -323,7 +323,7 @@ impl EditFileToolCard {
}
}
fn set_diff(
pub fn set_diff(
&mut self,
path: Arc<Path>,
old_text: String,
@@ -343,6 +343,7 @@ impl EditFileToolCard {
.hunks_intersecting_range(Anchor::MIN..Anchor::MAX, &snapshot, cx)
.map(|diff_hunk| diff_hunk.buffer_range.to_point(&snapshot))
.collect::<Vec<_>>();
multibuffer.clear(cx);
let (_, is_newly_added) = multibuffer.set_excerpts_for_path(
PathKey::for_buffer(&buffer, cx),
buffer,
@@ -576,13 +577,10 @@ impl ToolCard for EditFileToolCard {
card.child(
v_flex()
.relative()
.map(|editor_container| {
if self.full_height_expanded {
editor_container.h_full()
} else {
editor_container
.h(DEFAULT_COLLAPSED_LINES as f32 * editor_line_height)
}
.h_full()
.when(!self.full_height_expanded, |editor_container| {
editor_container
.max_h(DEFAULT_COLLAPSED_LINES as f32 * editor_line_height)
})
.overflow_hidden()
.border_t_1()

View File

@@ -0,0 +1,352 @@
use crate::{
Templates,
edit_agent::{EditAgent, EditAgentOutputEvent},
edit_file_tool::EditFileToolCard,
schema::json_schema_for,
};
use anyhow::{Context as _, Result, anyhow};
use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolResult};
use futures::StreamExt;
use gpui::{AnyWindowHandle, App, AppContext, AsyncApp, Entity, Task};
use indoc::formatdoc;
use language_model::{
LanguageModelRegistry, LanguageModelRequestMessage, LanguageModelToolSchemaFormat,
};
use project::Project;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use std::{path::PathBuf, sync::Arc};
use ui::prelude::*;
use util::ResultExt;
pub struct StreamingEditFileTool;
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct StreamingEditFileToolInput {
/// A one-line, user-friendly markdown description of the edit. This will be
/// shown in the UI and also passed to another model to perform the edit.
///
/// Be terse, but also descriptive in what you want to achieve with this
/// edit. Avoid generic instructions.
///
/// NEVER mention the file path in this description.
///
/// <example>Fix API endpoint URLs</example>
/// <example>Update copyright year in `page_footer`</example>
///
/// Make sure to include this field before all the others in the input object
/// so that we can display it immediately.
pub display_description: String,
/// The full path of the file to create or modify in the project.
///
/// WARNING: When specifying which file path need changing, you MUST
/// start each path with one of the project's root directories.
///
/// The following examples assume we have two root directories in the project:
/// - backend
/// - frontend
///
/// <example>
/// `backend/src/main.rs`
///
/// Notice how the file path starts with root-1. Without that, the path
/// would be ambiguous and the call would fail!
/// </example>
///
/// <example>
/// `frontend/db.js`
/// </example>
pub path: PathBuf,
/// If true, this tool will recreate the file from scratch.
/// If false, this tool will produce granular edits to an existing file.
pub create_or_overwrite: bool,
}
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
struct PartialInput {
#[serde(default)]
path: String,
#[serde(default)]
display_description: String,
}
const DEFAULT_UI_TEXT: &str = "Editing file";
impl Tool for StreamingEditFileTool {
fn name(&self) -> String {
"edit_file".into()
}
fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool {
false
}
fn description(&self) -> String {
include_str!("streaming_edit_file_tool/description.md").to_string()
}
fn icon(&self) -> IconName {
IconName::Pencil
}
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
json_schema_for::<StreamingEditFileToolInput>(format)
}
fn ui_text(&self, input: &serde_json::Value) -> String {
match serde_json::from_value::<StreamingEditFileToolInput>(input.clone()) {
Ok(input) => input.display_description,
Err(_) => "Editing file".to_string(),
}
}
fn still_streaming_ui_text(&self, input: &serde_json::Value) -> String {
if let Some(input) = serde_json::from_value::<PartialInput>(input.clone()).ok() {
let description = input.display_description.trim();
if !description.is_empty() {
return description.to_string();
}
let path = input.path.trim();
if !path.is_empty() {
return path.to_string();
}
}
DEFAULT_UI_TEXT.to_string()
}
fn run(
self: Arc<Self>,
input: serde_json::Value,
messages: &[LanguageModelRequestMessage],
project: Entity<Project>,
action_log: Entity<ActionLog>,
window: Option<AnyWindowHandle>,
cx: &mut App,
) -> ToolResult {
let input = match serde_json::from_value::<StreamingEditFileToolInput>(input) {
Ok(input) => input,
Err(err) => return Task::ready(Err(anyhow!(err))).into(),
};
let Some(project_path) = project.read(cx).find_project_path(&input.path, cx) else {
return Task::ready(Err(anyhow!(
"Path {} not found in project",
input.path.display()
)))
.into();
};
let Some(worktree) = project
.read(cx)
.worktree_for_id(project_path.worktree_id, cx)
else {
return Task::ready(Err(anyhow!("Worktree not found for project path"))).into();
};
let exists = worktree.update(cx, |worktree, cx| {
worktree.file_exists(&project_path.path, cx)
});
let card = window.and_then(|window| {
window
.update(cx, |_, window, cx| {
cx.new(|cx| {
EditFileToolCard::new(input.path.clone(), project.clone(), window, cx)
})
})
.ok()
});
let card_clone = card.clone();
let messages = messages.to_vec();
let task = cx.spawn(async move |cx: &mut AsyncApp| {
if !input.create_or_overwrite && !exists.await? {
return Err(anyhow!("{} not found", input.path.display()));
}
let model = cx
.update(|cx| LanguageModelRegistry::read_global(cx).default_model())?
.context("default model not set")?
.model;
let edit_agent = EditAgent::new(model, action_log, Templates::new());
let buffer = project
.update(cx, |project, cx| {
project.open_buffer(project_path.clone(), cx)
})?
.await?;
let old_snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
let old_text = cx
.background_spawn({
let old_snapshot = old_snapshot.clone();
async move { old_snapshot.text() }
})
.await;
let (output, mut events) = if input.create_or_overwrite {
edit_agent.overwrite(
buffer.clone(),
input.display_description.clone(),
messages,
cx,
)
} else {
edit_agent.edit(
buffer.clone(),
input.display_description.clone(),
messages,
cx,
)
};
let mut hallucinated_old_text = false;
while let Some(event) = events.next().await {
match event {
EditAgentOutputEvent::Edited { position } => {
if let Some(card) = card_clone.as_ref() {
let new_snapshot =
buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
let new_text = cx
.background_spawn({
let new_snapshot = new_snapshot.clone();
async move { new_snapshot.text() }
})
.await;
card.update(cx, |card, cx| {
card.set_diff(
project_path.path.clone(),
old_text.clone(),
new_text,
cx,
);
})
.log_err();
}
}
EditAgentOutputEvent::OldTextNotFound(_) => hallucinated_old_text = true,
}
}
output.await?;
project
.update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))?
.await?;
let new_snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
let new_text = cx.background_spawn({
let new_snapshot = new_snapshot.clone();
async move { new_snapshot.text() }
});
let diff = cx.background_spawn(async move {
language::unified_diff(&old_snapshot.text(), &new_snapshot.text())
});
let (new_text, diff) = futures::join!(new_text, diff);
if let Some(card) = card_clone {
card.update(cx, |card, cx| {
card.set_diff(project_path.path.clone(), old_text, new_text, cx);
})
.log_err();
}
let input_path = input.path.display();
if diff.is_empty() {
if hallucinated_old_text {
Err(anyhow!(formatdoc! {"
Some edits were produced but none of them could be applied.
Read the relevant sections of {input_path} again so that
I can perform the requested edits.
"}))
} else {
Ok("No edits were made.".to_string())
}
} else {
Ok(format!("Edited {}:\n\n```diff\n{}\n```", input_path, diff))
}
});
ToolResult {
output: task,
card: card.map(AnyToolCard::from),
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use serde_json::json;
#[test]
fn still_streaming_ui_text_with_path() {
let input = json!({
"path": "src/main.rs",
"display_description": "",
"old_string": "old code",
"new_string": "new code"
});
assert_eq!(
StreamingEditFileTool.still_streaming_ui_text(&input),
"src/main.rs"
);
}
#[test]
fn still_streaming_ui_text_with_description() {
let input = json!({
"path": "",
"display_description": "Fix error handling",
"old_string": "old code",
"new_string": "new code"
});
assert_eq!(
StreamingEditFileTool.still_streaming_ui_text(&input),
"Fix error handling",
);
}
#[test]
fn still_streaming_ui_text_with_path_and_description() {
let input = json!({
"path": "src/main.rs",
"display_description": "Fix error handling",
"old_string": "old code",
"new_string": "new code"
});
assert_eq!(
StreamingEditFileTool.still_streaming_ui_text(&input),
"Fix error handling",
);
}
#[test]
fn still_streaming_ui_text_no_path_or_description() {
let input = json!({
"path": "",
"display_description": "",
"old_string": "old code",
"new_string": "new code"
});
assert_eq!(
StreamingEditFileTool.still_streaming_ui_text(&input),
DEFAULT_UI_TEXT,
);
}
#[test]
fn still_streaming_ui_text_with_null() {
let input = serde_json::Value::Null;
assert_eq!(
StreamingEditFileTool.still_streaming_ui_text(&input),
DEFAULT_UI_TEXT,
);
}
}

View File

@@ -0,0 +1,8 @@
This is a tool for creating a new file or editing an existing file. For moving or renaming files, you should generally use the `terminal` tool with the 'mv' command instead.
Before using this tool:
1. Use the `read_file` tool to understand the file's contents and context
2. Verify the directory path is correct (only applicable when creating new files):
- Use the `list_directory` tool to verify the parent directory exists and is the correct location

View File

@@ -0,0 +1,32 @@
use anyhow::Result;
use handlebars::Handlebars;
use rust_embed::RustEmbed;
use serde::Serialize;
use std::sync::Arc;
#[derive(RustEmbed)]
#[folder = "src/templates"]
#[include = "*.hbs"]
struct Assets;
pub struct Templates(Handlebars<'static>);
impl Templates {
pub fn new() -> Arc<Self> {
let mut handlebars = Handlebars::new();
handlebars.register_embed_templates::<Assets>().unwrap();
handlebars.register_escape_fn(|text| text.into());
Arc::new(Self(handlebars))
}
}
pub trait Template: Sized {
const TEMPLATE_NAME: &'static str;
fn render(&self, templates: &Templates) -> Result<String>
where
Self: Serialize + Sized,
{
Ok(templates.0.render(Self::TEMPLATE_NAME, self)?)
}
}

View File

@@ -0,0 +1,12 @@
You are an expert engineer and your task is to write a new file from scratch.
<file_to_edit>
{{path}}
</file_to_edit>
<edit_description>
{{edit_description}}
</edit_description>
You MUST respond directly with the file's content, without explanations, additional text or triple backticks.
The text you output will be saved verbatim as the content of the file.

View File

@@ -0,0 +1,23 @@
You are an expert coder, and have been tasked with looking at the following diff:
<diff>
{{diff}}
</diff>
Evaluate the following assertions:
<assertions>
{{assertions}}
</assertions>
You must respond with a short analysis and a score between 0 and 100, where:
- 0 means no assertions pass
- 100 means all the assertions pass perfectly
<analysis>
- Assertion 1: one line describing why the first assertion passes or fails (even partially)
- Assertion 2: one line describing why the second assertion passes or fails (even partially)
- ...
- Assertion N: one line describing why the Nth assertion passes or fails (even partially)
</analysis>
<score>YOUR FINAL SCORE HERE</score>

View File

@@ -0,0 +1,49 @@
You are an expert text editor and your task is to produce a series of edits to a file given a description of the changes you need to make.
You MUST respond with a series of edits to that one file in the following format:
```
<edits>
<old_text>
OLD TEXT 1 HERE
</old_text>
<new_text>
NEW TEXT 1 HERE
</new_text>
<old_text>
OLD TEXT 2 HERE
</old_text>
<new_text>
NEW TEXT 2 HERE
</new_text>
<old_text>
OLD TEXT 3 HERE
</old_text>
<new_text>
NEW TEXT 3 HERE
</new_text>
</edits>
```
Rules for editing:
- `old_text` represents lines in the input file that will be replaced with `new_text`. `old_text` MUST exactly match the existing file content, character for character, including indentation.
- Always include enough context around the lines you want to replace in `old_text` such that it's impossible to mistake them for other lines.
- If you want to replace many occurrences of the same text, repeat the same `old_text`/`new_text` pair multiple times and I will apply them sequentially, one occurrence at a time.
- When reporting multiple edits, each edit assumes the previous one has already been applied! Therefore, you must ensure `old_text` doesn't reference text that has already been modified by a previous edit.
- Don't explain the edits, just report them.
- Only edit the file specified in `<file_to_edit>` and NEVER include edits to other files!
- If you open an <old_text> tag, you MUST close it using </old_text>
- If you open an <new_text> tag, you MUST close it using </new_text>
<file_to_edit>
{{path}}
</file_to_edit>
<edit_description>
{{edit_description}}
</edit_description>

View File

@@ -27,7 +27,9 @@ use crate::db::billing_subscription::{
StripeCancellationReason, StripeSubscriptionStatus, SubscriptionKind,
};
use crate::llm::db::subscription_usage_meter::CompletionMode;
use crate::llm::{DEFAULT_MAX_MONTHLY_SPEND, FREE_TIER_MONTHLY_SPENDING_LIMIT};
use crate::llm::{
AGENT_EXTENDED_TRIAL_FEATURE_FLAG, DEFAULT_MAX_MONTHLY_SPEND, FREE_TIER_MONTHLY_SPENDING_LIMIT,
};
use crate::rpc::{ResultExt as _, Server};
use crate::{AppState, Cents, Error, Result};
use crate::{db::UserId, llm::db::LlmDatabase};
@@ -54,6 +56,10 @@ pub fn router() -> Router {
"/billing/subscriptions/manage",
post(manage_billing_subscription),
)
.route(
"/billing/subscriptions/migrate",
post(migrate_to_new_billing),
)
.route("/billing/monthly_spend", get(get_monthly_spend))
.route("/billing/usage", get(get_current_usage))
}
@@ -256,6 +262,7 @@ async fn list_billing_subscriptions(
enum ProductCode {
ZedPro,
ZedProTrial,
ZedFree,
}
#[derive(Debug, Deserialize)]
@@ -386,6 +393,11 @@ async fn create_billing_subscription(
)
.await?
}
Some(ProductCode::ZedFree) => {
stripe_billing
.checkout_with_zed_free(customer_id, &user.github_login, &success_url)
.await?
}
None => {
let default_model = llm_db.model(
zed_llm_client::LanguageModelProvider::Anthropic,
@@ -604,6 +616,85 @@ async fn manage_billing_subscription(
}))
}
#[derive(Debug, Deserialize)]
struct MigrateToNewBillingBody {
github_user_id: i32,
}
#[derive(Debug, Serialize)]
struct MigrateToNewBillingResponse {
/// The ID of the subscription that was canceled.
canceled_subscription_id: String,
}
async fn migrate_to_new_billing(
Extension(app): Extension<Arc<AppState>>,
extract::Json(body): extract::Json<MigrateToNewBillingBody>,
) -> Result<Json<MigrateToNewBillingResponse>> {
let Some(stripe_client) = app.stripe_client.clone() else {
log::error!("failed to retrieve Stripe client");
Err(Error::http(
StatusCode::NOT_IMPLEMENTED,
"not supported".into(),
))?
};
let user = app
.db
.get_user_by_github_user_id(body.github_user_id)
.await?
.ok_or_else(|| anyhow!("user not found"))?;
let old_billing_subscriptions_by_user = app
.db
.get_active_billing_subscriptions(HashSet::from_iter([user.id]))
.await?;
let Some((_billing_customer, billing_subscription)) =
old_billing_subscriptions_by_user.get(&user.id)
else {
return Err(Error::http(
StatusCode::NOT_FOUND,
"No active billing subscriptions to migrate".into(),
));
};
let stripe_subscription_id = billing_subscription
.stripe_subscription_id
.parse::<stripe::SubscriptionId>()
.context("failed to parse Stripe subscription ID from database")?;
Subscription::cancel(
&stripe_client,
&stripe_subscription_id,
stripe::CancelSubscription {
invoice_now: Some(true),
..Default::default()
},
)
.await?;
let feature_flags = app.db.list_feature_flags().await?;
for feature_flag in ["new-billing", "assistant2"] {
let already_in_feature_flag = feature_flags.iter().any(|flag| flag.flag == feature_flag);
if already_in_feature_flag {
continue;
}
let feature_flag = feature_flags
.iter()
.find(|flag| flag.flag == feature_flag)
.context("failed to find feature flag: {feature_flag:?}")?;
app.db.add_user_flag(user.id, feature_flag.id).await?;
}
Ok(Json(MigrateToNewBillingResponse {
canceled_subscription_id: stripe_subscription_id.to_string(),
}))
}
/// The amount of time we wait in between each poll of Stripe events.
///
/// This value should strike a balance between:
@@ -1168,8 +1259,21 @@ async fn get_current_usage(
SubscriptionKind::ZedFree => zed_llm_client::Plan::Free,
};
let feature_flags = app.db.get_user_flags(user.id).await?;
let has_extended_trial = feature_flags
.iter()
.any(|flag| flag == AGENT_EXTENDED_TRIAL_FEATURE_FLAG);
let model_requests_limit = match plan.model_requests_limit() {
zed_llm_client::UsageLimit::Limited(limit) => Some(limit),
zed_llm_client::UsageLimit::Limited(limit) => {
let limit = if plan == zed_llm_client::Plan::ZedProTrial && has_extended_trial {
1_000
} else {
limit
};
Some(limit)
}
zed_llm_client::UsageLimit::Unlimited => None,
};
let edit_prediction_limit = match plan.edit_predictions_limit() {
@@ -1403,13 +1507,13 @@ async fn sync_model_request_usage_with_stripe(
.await?;
let claude_3_5_sonnet = stripe_billing
.find_price_id_by_lookup_key("claude-3-5-sonnet-requests")
.find_price_by_lookup_key("claude-3-5-sonnet-requests")
.await?;
let claude_3_7_sonnet = stripe_billing
.find_price_id_by_lookup_key("claude-3-7-sonnet-requests")
.find_price_by_lookup_key("claude-3-7-sonnet-requests")
.await?;
let claude_3_7_sonnet_max = stripe_billing
.find_price_id_by_lookup_key("claude-3-7-sonnet-requests-max")
.find_price_by_lookup_key("claude-3-7-sonnet-requests-max")
.await?;
for (usage_meter, usage) in usage_meters {
@@ -1434,7 +1538,7 @@ async fn sync_model_request_usage_with_stripe(
let model = llm_db.model_by_id(usage_meter.model_id)?;
let (price_id, meter_event_name) = match model.name.as_str() {
let (price, meter_event_name) = match model.name.as_str() {
"claude-3-5-sonnet" => (&claude_3_5_sonnet, "claude_3_5_sonnet/requests"),
"claude-3-7-sonnet" => match usage_meter.mode {
CompletionMode::Normal => (&claude_3_7_sonnet, "claude_3_7_sonnet/requests"),
@@ -1448,7 +1552,7 @@ async fn sync_model_request_usage_with_stripe(
};
stripe_billing
.subscribe_to_price(&stripe_subscription_id, price_id)
.subscribe_to_price(&stripe_subscription_id, price)
.await?;
stripe_billing
.bill_model_request_usage(

View File

@@ -32,6 +32,8 @@ pub struct LlmTokenClaims {
pub has_llm_subscription: bool,
pub max_monthly_spend_in_cents: u32,
pub custom_llm_monthly_allowance_in_cents: Option<u32>,
#[serde(default)]
pub use_new_billing: bool,
pub plan: Plan,
#[serde(default)]
pub has_extended_trial: bool,
@@ -90,6 +92,7 @@ impl LlmTokenClaims {
custom_llm_monthly_allowance_in_cents: user
.custom_llm_monthly_allowance_in_cents
.map(|allowance| allowance as u32),
use_new_billing: feature_flags.iter().any(|flag| flag == "new-billing"),
plan: subscription
.as_ref()
.and_then(|subscription| subscription.kind)

View File

@@ -327,6 +327,10 @@ impl Server {
.add_request_handler(
forward_read_only_project_request::<proto::LspExtSwitchSourceHeader>,
)
.add_request_handler(forward_read_only_project_request::<proto::LspExtGoToParentModule>)
.add_request_handler(forward_read_only_project_request::<proto::LspExtCancelFlycheck>)
.add_request_handler(forward_read_only_project_request::<proto::LspExtRunFlycheck>)
.add_request_handler(forward_read_only_project_request::<proto::LspExtClearFlycheck>)
.add_request_handler(
forward_read_only_project_request::<proto::LanguageServerIdForName>,
)

View File

@@ -99,6 +99,16 @@ impl StripeBilling {
.ok_or_else(|| crate::Error::Internal(anyhow!("no price ID found for {lookup_key:?}")))
}
pub async fn find_price_by_lookup_key(&self, lookup_key: &str) -> Result<stripe::Price> {
self.state
.read()
.await
.prices_by_lookup_key
.get(lookup_key)
.cloned()
.ok_or_else(|| crate::Error::Internal(anyhow!("no price found for {lookup_key:?}")))
}
pub async fn register_model_for_token_based_usage(
&self,
model: &llm::db::model::Model,
@@ -238,21 +248,29 @@ impl StripeBilling {
pub async fn subscribe_to_price(
&self,
subscription_id: &stripe::SubscriptionId,
price_id: &stripe::PriceId,
price: &stripe::Price,
) -> Result<()> {
let subscription =
stripe::Subscription::retrieve(&self.client, &subscription_id, &[]).await?;
if subscription_contains_price(&subscription, price_id) {
if subscription_contains_price(&subscription, &price.id) {
return Ok(());
}
const BILLING_THRESHOLD_IN_CENTS: i64 = 20 * 100;
let price_per_unit = price.unit_amount.unwrap_or_default();
let units_for_billing_threshold = BILLING_THRESHOLD_IN_CENTS / price_per_unit;
stripe::Subscription::update(
&self.client,
subscription_id,
stripe::UpdateSubscription {
items: Some(vec![stripe::UpdateSubscriptionItems {
price: Some(price_id.to_string()),
price: Some(price.id.to_string()),
billing_thresholds: Some(stripe::SubscriptionItemBillingThresholds {
usage_gte: Some(units_for_billing_threshold),
}),
..Default::default()
}]),
trial_settings: Some(stripe::UpdateSubscriptionTrialSettings {
@@ -547,6 +565,29 @@ impl StripeBilling {
let session = stripe::CheckoutSession::create(&self.client, params).await?;
Ok(session.url.context("no checkout session URL")?)
}
pub async fn checkout_with_zed_free(
&self,
customer_id: stripe::CustomerId,
github_login: &str,
success_url: &str,
) -> Result<String> {
let zed_free_price_id = self.zed_free_price_id().await?;
let mut params = stripe::CreateCheckoutSession::new();
params.mode = Some(stripe::CheckoutSessionMode::Subscription);
params.customer = Some(customer_id);
params.client_reference_id = Some(github_login);
params.line_items = Some(vec![stripe::CreateCheckoutSessionLineItems {
price: Some(zed_free_price_id.to_string()),
quantity: Some(1),
..Default::default()
}]);
params.success_url = Some(success_url);
let session = stripe::CheckoutSession::create(&self.client, params).await?;
Ok(session.url.context("no checkout session URL")?)
}
}
#[derive(Serialize)]

View File

@@ -25,7 +25,7 @@ use language::{
use project::{
ProjectPath, SERVER_PROGRESS_THROTTLE_TIMEOUT,
lsp_store::{
lsp_ext_command::{ExpandedMacro, LspExpandMacro},
lsp_ext_command::{ExpandedMacro, LspExtExpandMacro},
rust_analyzer_ext::RUST_ANALYZER_NAME,
},
project_settings::{InlineBlameSettings, ProjectSettings},
@@ -2704,8 +2704,8 @@ async fn test_client_can_query_lsp_ext(cx_a: &mut TestAppContext, cx_b: &mut Tes
let fake_language_server = fake_language_servers.next().await.unwrap();
// host
let mut expand_request_a =
fake_language_server.set_request_handler::<LspExpandMacro, _, _>(|params, _| async move {
let mut expand_request_a = fake_language_server.set_request_handler::<LspExtExpandMacro, _, _>(
|params, _| async move {
assert_eq!(
params.text_document.uri,
lsp::Url::from_file_path(path!("/a/main.rs")).unwrap(),
@@ -2715,7 +2715,8 @@ async fn test_client_can_query_lsp_ext(cx_a: &mut TestAppContext, cx_b: &mut Tes
name: "test_macro_name".to_string(),
expansion: "test_macro_expansion on the host".to_string(),
}))
});
},
);
editor_a.update_in(cx_a, |editor, window, cx| {
expand_macro_recursively(editor, &ExpandMacroRecursively, window, cx)
@@ -2738,8 +2739,8 @@ async fn test_client_can_query_lsp_ext(cx_a: &mut TestAppContext, cx_b: &mut Tes
});
// client
let mut expand_request_b =
fake_language_server.set_request_handler::<LspExpandMacro, _, _>(|params, _| async move {
let mut expand_request_b = fake_language_server.set_request_handler::<LspExtExpandMacro, _, _>(
|params, _| async move {
assert_eq!(
params.text_document.uri,
lsp::Url::from_file_path(path!("/a/main.rs")).unwrap(),
@@ -2749,7 +2750,8 @@ async fn test_client_can_query_lsp_ext(cx_a: &mut TestAppContext, cx_b: &mut Tes
name: "test_macro_name".to_string(),
expansion: "test_macro_expansion on the client".to_string(),
}))
});
},
);
editor_b.update_in(cx_b, |editor, window, cx| {
expand_macro_recursively(editor, &ExpandMacroRecursively, window, cx)

View File

@@ -2902,7 +2902,7 @@ async fn test_git_branch_name(
.read(cx)
.branch
.as_ref()
.map(|branch| branch.name.to_string()),
.map(|branch| branch.name().to_owned()),
branch_name
)
}
@@ -6864,7 +6864,7 @@ async fn test_remote_git_branches(
let branches_b = branches_b
.into_iter()
.map(|branch| branch.name.to_string())
.map(|branch| branch.name().to_string())
.collect::<HashSet<_>>();
assert_eq!(branches_b, branches_set);
@@ -6895,7 +6895,7 @@ async fn test_remote_git_branches(
})
});
assert_eq!(host_branch.name, branches[2]);
assert_eq!(host_branch.name(), branches[2]);
// Also try creating a new branch
cx_b.update(|cx| {
@@ -6933,5 +6933,5 @@ async fn test_remote_git_branches(
})
});
assert_eq!(host_branch.name, "totally-new-branch");
assert_eq!(host_branch.name(), "totally-new-branch");
}

View File

@@ -293,7 +293,7 @@ async fn test_ssh_collaboration_git_branches(
let branches_b = branches_b
.into_iter()
.map(|branch| branch.name.to_string())
.map(|branch| branch.name().to_string())
.collect::<HashSet<_>>();
assert_eq!(&branches_b, &branches_set);
@@ -326,7 +326,7 @@ async fn test_ssh_collaboration_git_branches(
})
});
assert_eq!(server_branch.name, branches[2]);
assert_eq!(server_branch.name(), branches[2]);
// Also try creating a new branch
cx_b.update(|cx| {
@@ -366,7 +366,7 @@ async fn test_ssh_collaboration_git_branches(
})
});
assert_eq!(server_branch.name, "totally-new-branch");
assert_eq!(server_branch.name(), "totally-new-branch");
// Remove the git repository and check that all participants get the update.
remote_fs

View File

@@ -15,18 +15,21 @@ path = "src/component_preview.rs"
default = []
[dependencies]
agent.workspace = true
anyhow.workspace = true
client.workspace = true
collections.workspace = true
component.workspace = true
db.workspace = true
gpui.workspace = true
languages.workspace = true
notifications.workspace = true
log.workspace = true
notifications.workspace = true
project.workspace = true
prompt_store.workspace = true
serde.workspace = true
ui.workspace = true
ui_input.workspace = true
workspace-hack.workspace = true
workspace.workspace = true
db.workspace = true
anyhow.workspace = true
serde.workspace = true
assistant_tool.workspace = true

View File

@@ -3,10 +3,12 @@
//! A view for exploring Zed components.
mod persistence;
mod preview_support;
use std::iter::Iterator;
use std::sync::Arc;
use agent::{ActiveThread, ThreadStore};
use client::UserStore;
use component::{ComponentId, ComponentMetadata, components};
use gpui::{
@@ -19,6 +21,7 @@ use gpui::{ListState, ScrollHandle, ScrollStrategy, UniformListScrollHandle};
use languages::LanguageRegistry;
use notifications::status_toast::{StatusToast, ToastIcon};
use persistence::COMPONENT_PREVIEW_DB;
use preview_support::active_thread::{load_preview_thread_store, static_active_thread};
use project::Project;
use ui::{Divider, HighlightedLabel, ListItem, ListSubHeader, prelude::*};
@@ -33,6 +36,7 @@ pub fn init(app_state: Arc<AppState>, cx: &mut App) {
cx.observe_new(move |workspace: &mut Workspace, _window, cx| {
let app_state = app_state.clone();
let project = workspace.project().clone();
let weak_workspace = cx.entity().downgrade();
workspace.register_action(
@@ -45,6 +49,7 @@ pub fn init(app_state: Arc<AppState>, cx: &mut App) {
let component_preview = cx.new(|cx| {
ComponentPreview::new(
weak_workspace.clone(),
project.clone(),
language_registry,
user_store,
None,
@@ -52,6 +57,7 @@ pub fn init(app_state: Arc<AppState>, cx: &mut App) {
window,
cx,
)
.expect("Failed to create component preview")
});
workspace.add_item_to_active_pane(
@@ -69,6 +75,7 @@ pub fn init(app_state: Arc<AppState>, cx: &mut App) {
enum PreviewEntry {
AllComponents,
ActiveThread,
Separator,
Component(ComponentMetadata, Option<Vec<usize>>),
SectionHeader(SharedString),
@@ -91,6 +98,7 @@ enum PreviewPage {
#[default]
AllComponents,
Component(ComponentId),
ActiveThread,
}
struct ComponentPreview {
@@ -102,24 +110,63 @@ struct ComponentPreview {
active_page: PreviewPage,
components: Vec<ComponentMetadata>,
component_list: ListState,
agent_previews: Vec<
Box<
dyn Fn(
&Self,
WeakEntity<Workspace>,
Entity<ActiveThread>,
WeakEntity<ThreadStore>,
&mut Window,
&mut App,
) -> Option<AnyElement>,
>,
>,
cursor_index: usize,
language_registry: Arc<LanguageRegistry>,
workspace: WeakEntity<Workspace>,
project: Entity<Project>,
user_store: Entity<UserStore>,
filter_editor: Entity<SingleLineInput>,
filter_text: String,
// preview support
thread_store: Option<Entity<ThreadStore>>,
active_thread: Option<Entity<ActiveThread>>,
}
impl ComponentPreview {
pub fn new(
workspace: WeakEntity<Workspace>,
project: Entity<Project>,
language_registry: Arc<LanguageRegistry>,
user_store: Entity<UserStore>,
selected_index: impl Into<Option<usize>>,
active_page: Option<PreviewPage>,
window: &mut Window,
cx: &mut Context<Self>,
) -> Self {
) -> anyhow::Result<Self> {
let workspace_clone = workspace.clone();
let project_clone = project.clone();
let entity = cx.weak_entity();
window
.spawn(cx, async move |cx| {
let thread_store_task =
load_preview_thread_store(workspace_clone.clone(), project_clone.clone(), cx)
.await;
if let Ok(thread_store) = thread_store_task.await {
entity
.update_in(cx, |this, window, cx| {
this.thread_store = Some(thread_store.clone());
this.create_active_thread(window, cx);
})
.ok();
}
})
.detach();
let sorted_components = components().all_sorted();
let selected_index = selected_index.into().unwrap_or(0);
let active_page = active_page.unwrap_or(PreviewPage::AllComponents);
@@ -143,6 +190,40 @@ impl ComponentPreview {
},
);
// Initialize agent previews
let agent_previews = agent::all_agent_previews()
.into_iter()
.map(|id| {
Box::new(
move |_self: &ComponentPreview,
workspace: WeakEntity<Workspace>,
active_thread: Entity<ActiveThread>,
thread_store: WeakEntity<ThreadStore>,
window: &mut Window,
cx: &mut App| {
agent::get_agent_preview(
&id,
workspace,
active_thread,
thread_store,
window,
cx,
)
},
)
as Box<
dyn Fn(
&ComponentPreview,
WeakEntity<Workspace>,
Entity<ActiveThread>,
WeakEntity<ThreadStore>,
&mut Window,
&mut App,
) -> Option<AnyElement>,
>
})
.collect::<Vec<_>>();
let mut component_preview = Self {
workspace_id: None,
focus_handle: cx.focus_handle(),
@@ -151,13 +232,17 @@ impl ComponentPreview {
language_registry,
user_store,
workspace,
project,
active_page,
component_map: components().0,
components: sorted_components,
component_list,
agent_previews,
cursor_index: selected_index,
filter_editor,
filter_text: String::new(),
thread_store: None,
active_thread: None,
};
if component_preview.cursor_index > 0 {
@@ -169,13 +254,41 @@ impl ComponentPreview {
let focus_handle = component_preview.filter_editor.read(cx).focus_handle(cx);
window.focus(&focus_handle);
component_preview
Ok(component_preview)
}
pub fn create_active_thread(
&mut self,
window: &mut Window,
cx: &mut Context<Self>,
) -> &mut Self {
let workspace = self.workspace.clone();
let language_registry = self.language_registry.clone();
let weak_handle = self.workspace.clone();
if let Some(workspace) = workspace.upgrade() {
let project = workspace.read(cx).project().clone();
if let Some(thread_store) = self.thread_store.clone() {
let active_thread = static_active_thread(
weak_handle,
project,
language_registry,
thread_store,
window,
cx,
);
self.active_thread = Some(active_thread);
cx.notify();
}
}
self
}
pub fn active_page_id(&self, _cx: &App) -> ActivePageId {
match &self.active_page {
PreviewPage::AllComponents => ActivePageId::default(),
PreviewPage::Component(component_id) => ActivePageId(component_id.0.to_string()),
PreviewPage::ActiveThread => ActivePageId("active_thread".to_string()),
}
}
@@ -289,6 +402,7 @@ impl ComponentPreview {
// Always show all components first
entries.push(PreviewEntry::AllComponents);
entries.push(PreviewEntry::ActiveThread);
entries.push(PreviewEntry::Separator);
let mut scopes: Vec<_> = scope_groups
@@ -389,6 +503,19 @@ impl ComponentPreview {
}))
.into_any_element()
}
PreviewEntry::ActiveThread => {
let selected = self.active_page == PreviewPage::ActiveThread;
ListItem::new(ix)
.child(Label::new("Active Thread").color(Color::Default))
.selectable(true)
.toggle_state(selected)
.inset(true)
.on_click(cx.listener(move |this, _, _, cx| {
this.set_active_page(PreviewPage::ActiveThread, cx);
}))
.into_any_element()
}
PreviewEntry::Separator => ListItem::new(ix)
.child(
h_flex()
@@ -471,6 +598,7 @@ impl ComponentPreview {
.render_scope_header(ix, shared_string.clone(), window, cx)
.into_any_element(),
PreviewEntry::AllComponents => div().w_full().h_0().into_any_element(),
PreviewEntry::ActiveThread => div().w_full().h_0().into_any_element(),
PreviewEntry::Separator => div().w_full().h_0().into_any_element(),
})
.unwrap()
@@ -595,6 +723,41 @@ impl ComponentPreview {
}
}
fn render_active_thread(
&self,
window: &mut Window,
cx: &mut Context<Self>,
) -> impl IntoElement {
v_flex()
.id("render-active-thread")
.size_full()
.child(
v_flex().children(self.agent_previews.iter().filter_map(|preview_fn| {
if let (Some(thread_store), Some(active_thread)) = (
self.thread_store.as_ref().map(|ts| ts.downgrade()),
self.active_thread.clone(),
) {
preview_fn(
self,
self.workspace.clone(),
active_thread,
thread_store,
window,
cx,
)
.map(|element| div().child(element))
} else {
None
}
})),
)
.children(self.active_thread.clone().map(|thread| thread.clone()))
.when_none(&self.active_thread.clone(), |this| {
this.child("No active thread")
})
.into_any_element()
}
fn test_status_toast(&self, cx: &mut Context<Self>) {
if let Some(workspace) = self.workspace.upgrade() {
workspace.update(cx, |workspace, cx| {
@@ -704,6 +867,9 @@ impl Render for ComponentPreview {
PreviewPage::Component(id) => self
.render_component_page(&id, window, cx)
.into_any_element(),
PreviewPage::ActiveThread => {
self.render_active_thread(window, cx).into_any_element()
}
}),
)
}
@@ -759,20 +925,28 @@ impl Item for ComponentPreview {
let language_registry = self.language_registry.clone();
let user_store = self.user_store.clone();
let weak_workspace = self.workspace.clone();
let project = self.project.clone();
let selected_index = self.cursor_index;
let active_page = self.active_page.clone();
Some(cx.new(|cx| {
Self::new(
weak_workspace,
language_registry,
user_store,
selected_index,
Some(active_page),
window,
cx,
)
}))
let self_result = Self::new(
weak_workspace,
project,
language_registry,
user_store,
selected_index,
Some(active_page),
window,
cx,
);
match self_result {
Ok(preview) => Some(cx.new(|_cx| preview)),
Err(e) => {
log::error!("Failed to clone component preview: {}", e);
None
}
}
}
fn to_item_events(event: &Self::Event, mut f: impl FnMut(workspace::item::ItemEvent)) {
@@ -838,10 +1012,12 @@ impl SerializableItem for ComponentPreview {
let user_store = user_store.clone();
let language_registry = language_registry.clone();
let weak_workspace = workspace.clone();
let project = project.clone();
cx.update(move |window, cx| {
Ok(cx.new(|cx| {
ComponentPreview::new(
weak_workspace,
project,
language_registry,
user_store,
None,
@@ -849,6 +1025,7 @@ impl SerializableItem for ComponentPreview {
window,
cx,
)
.expect("Failed to create component preview")
}))
})?
})

View File

@@ -0,0 +1 @@
pub mod active_thread;

View File

@@ -0,0 +1,69 @@
use languages::LanguageRegistry;
use project::Project;
use std::sync::Arc;
use agent::{ActiveThread, ContextStore, MessageSegment, ThreadStore};
use assistant_tool::ToolWorkingSet;
use gpui::{AppContext, AsyncApp, Entity, Task, WeakEntity};
use prompt_store::PromptBuilder;
use ui::{App, Window};
use workspace::Workspace;
pub async fn load_preview_thread_store(
workspace: WeakEntity<Workspace>,
project: Entity<Project>,
cx: &mut AsyncApp,
) -> Task<anyhow::Result<Entity<ThreadStore>>> {
cx.spawn(async move |cx| {
workspace
.update(cx, |_, cx| {
ThreadStore::load(
project.clone(),
cx.new(|_| ToolWorkingSet::default()),
None,
Arc::new(PromptBuilder::new(None).unwrap()),
cx,
)
})?
.await
})
}
pub fn static_active_thread(
workspace: WeakEntity<Workspace>,
project: Entity<Project>,
language_registry: Arc<LanguageRegistry>,
thread_store: Entity<ThreadStore>,
window: &mut Window,
cx: &mut App,
) -> Entity<ActiveThread> {
let context_store =
cx.new(|_| ContextStore::new(project.downgrade(), Some(thread_store.downgrade())));
let thread = thread_store.update(cx, |thread_store, cx| thread_store.create_thread(cx));
thread.update(cx, |thread, cx| {
thread.insert_assistant_message(vec![
MessageSegment::Text("I'll help you fix the lifetime error in your `cx.spawn` call. When working with async operations in GPUI, there are specific patterns to follow for proper lifetime management.".to_string()),
MessageSegment::Text("\n\nLet's look at what's happening in your code:".to_string()),
MessageSegment::Text("\n\n---\n\nLet's check the current state of the active_thread.rs file to understand what might have changed:".to_string()),
MessageSegment::Text("\n\n---\n\nLooking at the implementation of `load_preview_thread_store` and understanding GPUI's async patterns, here's the issue:".to_string()),
MessageSegment::Text("\n\n1. `load_preview_thread_store` returns a `Task<anyhow::Result<Entity<ThreadStore>>>`, which means it's already a task".to_string()),
MessageSegment::Text("\n2. When you call this function inside another `spawn` call, you're nesting tasks incorrectly".to_string()),
MessageSegment::Text("\n3. The `this` parameter you're trying to use in your closure has the wrong context".to_string()),
MessageSegment::Text("\n\nHere's the correct way to implement this:".to_string()),
MessageSegment::Text("\n\n---\n\nThe problem is in how you're setting up the async closure and trying to reference variables like `window` and `language_registry` that aren't accessible in that scope.".to_string()),
MessageSegment::Text("\n\nHere's how to fix it:".to_string()),
], cx);
});
cx.new(|cx| {
ActiveThread::new(
thread,
thread_store,
context_store,
language_registry,
workspace.clone(),
window,
cx,
)
})
}

View File

@@ -34,3 +34,7 @@ smol.workspace = true
url = { workspace = true, features = ["serde"] }
util.workspace = true
workspace-hack.workspace = true
[dev-dependencies]
gpui = { workspace = true, features = ["test-support"] }
project = { workspace = true, features = ["test-support"] }

View File

@@ -140,7 +140,7 @@ impl Client {
/// This function initializes a new Client by spawning a child process for the context server,
/// setting up communication channels, and initializing handlers for input/output operations.
/// It takes a server ID, binary information, and an async app context as input.
pub fn new(
pub fn stdio(
server_id: ContextServerId,
binary: ModelContextServerBinary,
cx: AsyncApp,
@@ -158,7 +158,16 @@ impl Client {
.unwrap_or_else(String::new);
let transport = Arc::new(StdioTransport::new(binary, &cx)?);
Self::new(server_id, server_name.into(), transport, cx)
}
/// Creates a new Client instance for a context server.
pub fn new(
server_id: ContextServerId,
server_name: Arc<str>,
transport: Arc<dyn Transport>,
cx: AsyncApp,
) -> Result<Self> {
let (outbound_tx, outbound_rx) = channel::unbounded::<String>();
let (output_done_tx, output_done_rx) = barrier::channel();
@@ -167,7 +176,7 @@ impl Client {
let response_handlers =
Arc::new(Mutex::new(Some(HashMap::<_, ResponseHandler>::default())));
let stdout_input_task = cx.spawn({
let receive_input_task = cx.spawn({
let notification_handlers = notification_handlers.clone();
let response_handlers = response_handlers.clone();
let transport = transport.clone();
@@ -177,13 +186,13 @@ impl Client {
.await
}
});
let stderr_input_task = cx.spawn({
let receive_err_task = cx.spawn({
let transport = transport.clone();
async move |_| Self::handle_stderr(transport).log_err().await
async move |_| Self::handle_err(transport).log_err().await
});
let input_task = cx.spawn(async move |_| {
let (stdout, stderr) = futures::join!(stdout_input_task, stderr_input_task);
stdout.or(stderr)
let (input, err) = futures::join!(receive_input_task, receive_err_task);
input.or(err)
});
let output_task = cx.background_spawn({
@@ -201,7 +210,7 @@ impl Client {
server_id,
notification_handlers,
response_handlers,
name: server_name.into(),
name: server_name,
next_id: Default::default(),
outbound_tx,
executor: cx.background_executor().clone(),
@@ -247,7 +256,7 @@ impl Client {
/// Handles the stderr output from the context server.
/// Continuously reads and logs any error messages from the server.
async fn handle_stderr(transport: Arc<dyn Transport>) -> anyhow::Result<()> {
async fn handle_err(transport: Arc<dyn Transport>) -> anyhow::Result<()> {
while let Some(err) = transport.receive_err().next().await {
log::warn!("context server stderr: {}", err.trim());
}

View File

@@ -12,7 +12,7 @@ pub use context_server_settings::{ContextServerSettings, ServerCommand, ServerCo
use gpui::{App, actions};
pub use crate::context_server_tool::ContextServerTool;
pub use crate::registry::ContextServerFactoryRegistry;
pub use crate::registry::ContextServerDescriptorRegistry;
actions!(context_servers, [Restart]);
@@ -21,7 +21,7 @@ pub const CONTEXT_SERVERS_NAMESPACE: &'static str = "context_servers";
pub fn init(cx: &mut App) {
context_server_settings::init(cx);
ContextServerFactoryRegistry::default_global(cx);
ContextServerDescriptorRegistry::default_global(cx);
extension_context_server::init(cx);
CommandPaletteFilter::update_global(cx, |filter, _cx| {

View File

@@ -1,9 +1,21 @@
use std::sync::Arc;
use extension::{Extension, ExtensionContextServerProxy, ExtensionHostProxy, ProjectDelegate};
use gpui::{App, Entity};
use anyhow::Result;
use extension::{
ContextServerConfiguration, Extension, ExtensionContextServerProxy, ExtensionHostProxy,
ProjectDelegate,
};
use gpui::{App, AsyncApp, Entity, Task};
use project::Project;
use crate::{ContextServerFactoryRegistry, ServerCommand};
use crate::{ContextServerDescriptorRegistry, ServerCommand, registry};
pub fn init(cx: &mut App) {
let proxy = ExtensionHostProxy::default_global(cx);
proxy.register_context_server_proxy(ContextServerDescriptorRegistryProxy {
context_server_factory_registry: ContextServerDescriptorRegistry::global(cx),
});
}
struct ExtensionProject {
worktree_ids: Vec<u64>,
@@ -15,60 +27,78 @@ impl ProjectDelegate for ExtensionProject {
}
}
pub fn init(cx: &mut App) {
let proxy = ExtensionHostProxy::default_global(cx);
proxy.register_context_server_proxy(ContextServerFactoryRegistryProxy {
context_server_factory_registry: ContextServerFactoryRegistry::global(cx),
});
struct ContextServerDescriptor {
id: Arc<str>,
extension: Arc<dyn Extension>,
}
struct ContextServerFactoryRegistryProxy {
context_server_factory_registry: Entity<ContextServerFactoryRegistry>,
fn extension_project(project: Entity<Project>, cx: &mut AsyncApp) -> Result<Arc<ExtensionProject>> {
project.update(cx, |project, cx| {
Arc::new(ExtensionProject {
worktree_ids: project
.visible_worktrees(cx)
.map(|worktree| worktree.read(cx).id().to_proto())
.collect(),
})
})
}
impl ExtensionContextServerProxy for ContextServerFactoryRegistryProxy {
impl registry::ContextServerDescriptor for ContextServerDescriptor {
fn command(&self, project: Entity<Project>, cx: &AsyncApp) -> Task<Result<ServerCommand>> {
let id = self.id.clone();
let extension = self.extension.clone();
cx.spawn(async move |cx| {
let extension_project = extension_project(project, cx)?;
let mut command = extension
.context_server_command(id.clone(), extension_project.clone())
.await?;
command.command = extension
.path_from_extension(command.command.as_ref())
.to_string_lossy()
.to_string();
log::info!("loaded command for context server {id}: {command:?}");
Ok(ServerCommand {
path: command.command,
args: command.args,
env: Some(command.env.into_iter().collect()),
})
})
}
fn configuration(
&self,
project: Entity<Project>,
cx: &AsyncApp,
) -> Task<Result<Option<ContextServerConfiguration>>> {
let id = self.id.clone();
let extension = self.extension.clone();
cx.spawn(async move |cx| {
let extension_project = extension_project(project, cx)?;
let configuration = extension
.context_server_configuration(id.clone(), extension_project)
.await?;
log::debug!("loaded configuration for context server {id}: {configuration:?}");
Ok(configuration)
})
}
}
struct ContextServerDescriptorRegistryProxy {
context_server_factory_registry: Entity<ContextServerDescriptorRegistry>,
}
impl ExtensionContextServerProxy for ContextServerDescriptorRegistryProxy {
fn register_context_server(&self, extension: Arc<dyn Extension>, id: Arc<str>, cx: &mut App) {
self.context_server_factory_registry
.update(cx, |registry, _| {
registry.register_server_factory(
registry.register_context_server_descriptor(
id.clone(),
Arc::new({
move |project, cx| {
log::info!(
"loading command for context server {id} from extension {}",
extension.manifest().id
);
let id = id.clone();
let extension = extension.clone();
cx.spawn(async move |cx| {
let extension_project = project.update(cx, |project, cx| {
Arc::new(ExtensionProject {
worktree_ids: project
.visible_worktrees(cx)
.map(|worktree| worktree.read(cx).id().to_proto())
.collect(),
})
})?;
let mut command = extension
.context_server_command(id.clone(), extension_project)
.await?;
command.command = extension
.path_from_extension(command.command.as_ref())
.to_string_lossy()
.to_string();
log::info!("loaded command for context server {id}: {command:?}");
Ok(ServerCommand {
path: command.command,
args: command.args,
env: Some(command.env.into_iter().collect()),
})
})
}
}),
Arc::new(ContextServerDescriptor { id, extension })
as Arc<dyn registry::ContextServerDescriptor>,
)
});
}

View File

@@ -27,18 +27,27 @@ use project::Project;
use settings::{Settings, SettingsStore};
use util::ResultExt as _;
use crate::transport::Transport;
use crate::{ContextServerSettings, ServerConfig};
use crate::{
CONTEXT_SERVERS_NAMESPACE, ContextServerFactoryRegistry,
CONTEXT_SERVERS_NAMESPACE, ContextServerDescriptorRegistry,
client::{self, Client},
types,
};
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum ContextServerStatus {
Starting,
Running,
Error(Arc<str>),
}
pub struct ContextServer {
pub id: Arc<str>,
pub config: Arc<ServerConfig>,
pub client: RwLock<Option<Arc<crate::protocol::InitializedContextServerProtocol>>>,
transport: Option<Arc<dyn Transport>>,
}
impl ContextServer {
@@ -47,9 +56,20 @@ impl ContextServer {
id,
config,
client: RwLock::new(None),
transport: None,
}
}
#[cfg(any(test, feature = "test-support"))]
pub fn test(id: Arc<str>, transport: Arc<dyn crate::transport::Transport>) -> Arc<Self> {
Arc::new(Self {
id,
client: RwLock::new(None),
config: Arc::new(ServerConfig::default()),
transport: Some(transport),
})
}
pub fn id(&self) -> Arc<str> {
self.id.clone()
}
@@ -63,20 +83,32 @@ impl ContextServer {
}
pub async fn start(self: Arc<Self>, cx: &AsyncApp) -> Result<()> {
log::info!("starting context server {}", self.id);
let Some(command) = &self.config.command else {
bail!("no command specified for server {}", self.id);
let client = if let Some(transport) = self.transport.clone() {
Client::new(
client::ContextServerId(self.id.clone()),
self.id(),
transport,
cx.clone(),
)?
} else {
let Some(command) = &self.config.command else {
bail!("no command specified for server {}", self.id);
};
Client::stdio(
client::ContextServerId(self.id.clone()),
client::ModelContextServerBinary {
executable: Path::new(&command.path).to_path_buf(),
args: command.args.clone(),
env: command.env.clone(),
},
cx.clone(),
)?
};
let client = Client::new(
client::ContextServerId(self.id.clone()),
client::ModelContextServerBinary {
executable: Path::new(&command.path).to_path_buf(),
args: command.args.clone(),
env: command.env.clone(),
},
cx.clone(),
)?;
self.initialize(client).await
}
async fn initialize(&self, client: Client) -> Result<()> {
log::info!("starting context server {}", self.id);
let protocol = crate::protocol::ModelContextProtocol::new(client);
let client_info = types::Implementation {
name: "Zed".to_string(),
@@ -105,23 +137,26 @@ impl ContextServer {
pub struct ContextServerManager {
servers: HashMap<Arc<str>, Arc<ContextServer>>,
server_status: HashMap<Arc<str>, ContextServerStatus>,
project: Entity<Project>,
registry: Entity<ContextServerFactoryRegistry>,
registry: Entity<ContextServerDescriptorRegistry>,
update_servers_task: Option<Task<Result<()>>>,
needs_server_update: bool,
_subscriptions: Vec<Subscription>,
}
pub enum Event {
ServerStarted { server_id: Arc<str> },
ServerStopped { server_id: Arc<str> },
ServerStatusChanged {
server_id: Arc<str>,
status: Option<ContextServerStatus>,
},
}
impl EventEmitter<Event> for ContextServerManager {}
impl ContextServerManager {
pub fn new(
registry: Entity<ContextServerFactoryRegistry>,
registry: Entity<ContextServerDescriptorRegistry>,
project: Entity<Project>,
cx: &mut Context<Self>,
) -> Self {
@@ -138,6 +173,7 @@ impl ContextServerManager {
registry,
needs_server_update: false,
servers: HashMap::default(),
server_status: HashMap::default(),
update_servers_task: None,
};
this.available_context_servers_changed(cx);
@@ -153,7 +189,9 @@ impl ContextServerManager {
this.needs_server_update = false;
})?;
Self::maintain_servers(this.clone(), cx).await?;
if let Err(err) = Self::maintain_servers(this.clone(), cx).await {
log::error!("Error maintaining context servers: {}", err);
}
this.update(cx, |this, cx| {
let has_any_context_servers = !this.running_servers().is_empty();
@@ -181,52 +219,37 @@ impl ContextServerManager {
.cloned()
}
pub fn status_for_server(&self, id: &str) -> Option<ContextServerStatus> {
self.server_status.get(id).cloned()
}
pub fn start_server(
&self,
server: Arc<ContextServer>,
cx: &mut Context<Self>,
) -> Task<anyhow::Result<()>> {
cx.spawn(async move |this, cx| {
let id = server.id.clone();
server.start(&cx).await?;
this.update(cx, |_, cx| cx.emit(Event::ServerStarted { server_id: id }))?;
Ok(())
})
) -> Task<Result<()>> {
cx.spawn(async move |this, cx| Self::run_server(this, server, cx).await)
}
pub fn stop_server(
&self,
&mut self,
server: Arc<ContextServer>,
cx: &mut Context<Self>,
) -> anyhow::Result<()> {
server.stop()?;
cx.emit(Event::ServerStopped {
server_id: server.id(),
});
) -> Result<()> {
server.stop().log_err();
self.update_server_status(server.id().clone(), None, cx);
Ok(())
}
pub fn restart_server(
&mut self,
id: &Arc<str>,
cx: &mut Context<Self>,
) -> Task<anyhow::Result<()>> {
pub fn restart_server(&mut self, id: &Arc<str>, cx: &mut Context<Self>) -> Task<Result<()>> {
let id = id.clone();
cx.spawn(async move |this, cx| {
if let Some(server) = this.update(cx, |this, _cx| this.servers.remove(&id))? {
server.stop()?;
let config = server.config();
this.update(cx, |this, cx| this.stop_server(server, cx))??;
let new_server = Arc::new(ContextServer::new(id.clone(), config));
new_server.clone().start(&cx).await?;
this.update(cx, |this, cx| {
this.servers.insert(id.clone(), new_server);
cx.emit(Event::ServerStopped {
server_id: id.clone(),
});
cx.emit(Event::ServerStarted {
server_id: id.clone(),
});
})?;
Self::run_server(this, new_server, cx).await?;
}
Ok(())
})
@@ -263,12 +286,14 @@ impl ContextServerManager {
(this.registry.clone(), this.project.clone())
})?;
for (id, factory) in
registry.read_with(cx, |registry, _| registry.context_server_factories())?
for (id, descriptor) in
registry.read_with(cx, |registry, _| registry.context_server_descriptors())?
{
let config = desired_servers.entry(id).or_default();
if config.command.is_none() {
if let Some(extension_command) = factory(project.clone(), &cx).await.log_err() {
if let Some(extension_command) =
descriptor.command(project.clone(), &cx).await.log_err()
{
config.command = Some(extension_command);
}
}
@@ -290,28 +315,270 @@ impl ContextServerManager {
for (id, config) in desired_servers {
let existing_config = this.servers.get(&id).map(|server| server.config());
if existing_config.as_deref() != Some(&config) {
let config = Arc::new(config);
let server = Arc::new(ContextServer::new(id.clone(), config));
let server = Arc::new(ContextServer::new(id.clone(), Arc::new(config)));
servers_to_start.insert(id.clone(), server.clone());
let old_server = this.servers.insert(id.clone(), server);
if let Some(old_server) = old_server {
if let Some(old_server) = this.servers.remove(&id) {
servers_to_stop.insert(id, old_server);
}
}
}
})?;
for (id, server) in servers_to_stop {
server.stop().log_err();
this.update(cx, |_, cx| cx.emit(Event::ServerStopped { server_id: id }))?;
for (_, server) in servers_to_stop {
this.update(cx, |this, cx| this.stop_server(server, cx).ok())?;
}
for (id, server) in servers_to_start {
if server.start(&cx).await.log_err().is_some() {
this.update(cx, |_, cx| cx.emit(Event::ServerStarted { server_id: id }))?;
}
for (_, server) in servers_to_start {
Self::run_server(this.clone(), server, cx).await.ok();
}
Ok(())
}
async fn run_server(
this: WeakEntity<Self>,
server: Arc<ContextServer>,
cx: &mut AsyncApp,
) -> Result<()> {
let id = server.id();
this.update(cx, |this, cx| {
this.update_server_status(id.clone(), Some(ContextServerStatus::Starting), cx);
this.servers.insert(id.clone(), server.clone());
})?;
match server.start(&cx).await {
Ok(_) => {
log::debug!("`{}` context server started", id);
this.update(cx, |this, cx| {
this.update_server_status(id.clone(), Some(ContextServerStatus::Running), cx)
})?;
Ok(())
}
Err(err) => {
log::error!("`{}` context server failed to start\n{}", id, err);
this.update(cx, |this, cx| {
this.update_server_status(
id.clone(),
Some(ContextServerStatus::Error(err.to_string().into())),
cx,
)
})?;
Err(err)
}
}
}
fn update_server_status(
&mut self,
id: Arc<str>,
status: Option<ContextServerStatus>,
cx: &mut Context<Self>,
) {
if let Some(status) = status.clone() {
self.server_status.insert(id.clone(), status);
} else {
self.server_status.remove(&id);
}
cx.emit(Event::ServerStatusChanged {
server_id: id,
status,
});
}
}
#[cfg(test)]
mod tests {
use std::pin::Pin;
use crate::types::{
Implementation, InitializeResponse, ProtocolVersion, RequestType, ServerCapabilities,
};
use super::*;
use futures::{Stream, StreamExt as _, lock::Mutex};
use gpui::{AppContext as _, TestAppContext};
use project::FakeFs;
use serde_json::json;
use util::path;
#[gpui::test]
async fn test_context_server_status(cx: &mut TestAppContext) {
init_test_settings(cx);
let project = create_test_project(cx, json!({"code.rs": ""})).await;
let registry = cx.new(|_| ContextServerDescriptorRegistry::new());
let manager = cx.new(|cx| ContextServerManager::new(registry.clone(), project, cx));
let server_1_id: Arc<str> = "mcp-1".into();
let server_2_id: Arc<str> = "mcp-2".into();
let transport_1 = Arc::new(FakeTransport::new(
|_, request_type, _| match request_type {
Some(RequestType::Initialize) => {
Some(create_initialize_response("mcp-1".to_string()))
}
_ => None,
},
));
let transport_2 = Arc::new(FakeTransport::new(
|_, request_type, _| match request_type {
Some(RequestType::Initialize) => {
Some(create_initialize_response("mcp-2".to_string()))
}
_ => None,
},
));
let server_1 = ContextServer::test(server_1_id.clone(), transport_1.clone());
let server_2 = ContextServer::test(server_2_id.clone(), transport_2.clone());
manager
.update(cx, |manager, cx| manager.start_server(server_1, cx))
.await
.unwrap();
cx.update(|cx| {
assert_eq!(
manager.read(cx).status_for_server(&server_1_id),
Some(ContextServerStatus::Running)
);
assert_eq!(manager.read(cx).status_for_server(&server_2_id), None);
});
manager
.update(cx, |manager, cx| manager.start_server(server_2.clone(), cx))
.await
.unwrap();
cx.update(|cx| {
assert_eq!(
manager.read(cx).status_for_server(&server_1_id),
Some(ContextServerStatus::Running)
);
assert_eq!(
manager.read(cx).status_for_server(&server_2_id),
Some(ContextServerStatus::Running)
);
});
manager
.update(cx, |manager, cx| manager.stop_server(server_2, cx))
.unwrap();
cx.update(|cx| {
assert_eq!(
manager.read(cx).status_for_server(&server_1_id),
Some(ContextServerStatus::Running)
);
assert_eq!(manager.read(cx).status_for_server(&server_2_id), None);
});
}
async fn create_test_project(
cx: &mut TestAppContext,
files: serde_json::Value,
) -> Entity<Project> {
let fs = FakeFs::new(cx.executor());
fs.insert_tree(path!("/test"), files).await;
Project::test(fs, [path!("/test").as_ref()], cx).await
}
fn init_test_settings(cx: &mut TestAppContext) {
cx.update(|cx| {
let settings_store = SettingsStore::test(cx);
cx.set_global(settings_store);
Project::init_settings(cx);
ContextServerSettings::register(cx);
});
}
fn create_initialize_response(server_name: String) -> serde_json::Value {
serde_json::to_value(&InitializeResponse {
protocol_version: ProtocolVersion(types::LATEST_PROTOCOL_VERSION.to_string()),
server_info: Implementation {
name: server_name,
version: "1.0.0".to_string(),
},
capabilities: ServerCapabilities::default(),
meta: None,
})
.unwrap()
}
struct FakeTransport {
on_request: Arc<
dyn Fn(u64, Option<RequestType>, serde_json::Value) -> Option<serde_json::Value>
+ Send
+ Sync,
>,
tx: futures::channel::mpsc::UnboundedSender<String>,
rx: Arc<Mutex<futures::channel::mpsc::UnboundedReceiver<String>>>,
}
impl FakeTransport {
fn new(
on_request: impl Fn(
u64,
Option<RequestType>,
serde_json::Value,
) -> Option<serde_json::Value>
+ 'static
+ Send
+ Sync,
) -> Self {
let (tx, rx) = futures::channel::mpsc::unbounded();
Self {
on_request: Arc::new(on_request),
tx,
rx: Arc::new(Mutex::new(rx)),
}
}
}
#[async_trait::async_trait]
impl Transport for FakeTransport {
async fn send(&self, message: String) -> Result<()> {
if let Ok(msg) = serde_json::from_str::<serde_json::Value>(&message) {
let id = msg.get("id").and_then(|id| id.as_u64()).unwrap_or(0);
if let Some(method) = msg.get("method") {
let request_type = method
.as_str()
.and_then(|method| types::RequestType::try_from(method).ok());
if let Some(payload) = (self.on_request.as_ref())(id, request_type, msg) {
let response = serde_json::json!({
"jsonrpc": "2.0",
"id": id,
"result": payload
});
self.tx
.unbounded_send(response.to_string())
.map_err(|e| anyhow::anyhow!("Failed to send message: {}", e))?;
}
}
}
Ok(())
}
fn receive(&self) -> Pin<Box<dyn Stream<Item = String> + Send>> {
let rx = self.rx.clone();
Box::pin(futures::stream::unfold(rx, |rx| async move {
let mut rx_guard = rx.lock().await;
if let Some(message) = rx_guard.next().await {
drop(rx_guard);
Some((message, rx))
} else {
None
}
}))
}
fn receive_err(&self) -> Pin<Box<dyn Stream<Item = String> + Send>> {
Box::pin(futures::stream::empty())
}
}
}

View File

@@ -2,38 +2,47 @@ use std::sync::Arc;
use anyhow::Result;
use collections::HashMap;
use extension::ContextServerConfiguration;
use gpui::{App, AppContext as _, AsyncApp, Entity, Global, ReadGlobal, Task};
use project::Project;
use crate::ServerCommand;
pub type ContextServerFactory =
Arc<dyn Fn(Entity<Project>, &AsyncApp) -> Task<Result<ServerCommand>> + Send + Sync + 'static>;
struct GlobalContextServerFactoryRegistry(Entity<ContextServerFactoryRegistry>);
impl Global for GlobalContextServerFactoryRegistry {}
#[derive(Default)]
pub struct ContextServerFactoryRegistry {
context_servers: HashMap<Arc<str>, ContextServerFactory>,
pub trait ContextServerDescriptor {
fn command(&self, project: Entity<Project>, cx: &AsyncApp) -> Task<Result<ServerCommand>>;
fn configuration(
&self,
project: Entity<Project>,
cx: &AsyncApp,
) -> Task<Result<Option<ContextServerConfiguration>>>;
}
impl ContextServerFactoryRegistry {
/// Returns the global [`ContextServerFactoryRegistry`].
struct GlobalContextServerDescriptorRegistry(Entity<ContextServerDescriptorRegistry>);
impl Global for GlobalContextServerDescriptorRegistry {}
#[derive(Default)]
pub struct ContextServerDescriptorRegistry {
context_servers: HashMap<Arc<str>, Arc<dyn ContextServerDescriptor>>,
}
impl ContextServerDescriptorRegistry {
/// Returns the global [`ContextServerDescriptorRegistry`].
pub fn global(cx: &App) -> Entity<Self> {
GlobalContextServerFactoryRegistry::global(cx).0.clone()
GlobalContextServerDescriptorRegistry::global(cx).0.clone()
}
/// Returns the global [`ContextServerFactoryRegistry`].
/// Returns the global [`ContextServerDescriptorRegistry`].
///
/// Inserts a default [`ContextServerFactoryRegistry`] if one does not yet exist.
/// Inserts a default [`ContextServerDescriptorRegistry`] if one does not yet exist.
pub fn default_global(cx: &mut App) -> Entity<Self> {
if !cx.has_global::<GlobalContextServerFactoryRegistry>() {
if !cx.has_global::<GlobalContextServerDescriptorRegistry>() {
let registry = cx.new(|_| Self::new());
cx.set_global(GlobalContextServerFactoryRegistry(registry));
cx.set_global(GlobalContextServerDescriptorRegistry(registry));
}
cx.global::<GlobalContextServerFactoryRegistry>().0.clone()
cx.global::<GlobalContextServerDescriptorRegistry>()
.0
.clone()
}
pub fn new() -> Self {
@@ -42,20 +51,28 @@ impl ContextServerFactoryRegistry {
}
}
pub fn context_server_factories(&self) -> Vec<(Arc<str>, ContextServerFactory)> {
pub fn context_server_descriptors(&self) -> Vec<(Arc<str>, Arc<dyn ContextServerDescriptor>)> {
self.context_servers
.iter()
.map(|(id, factory)| (id.clone(), factory.clone()))
.collect()
}
/// Registers the provided [`ContextServerFactory`].
pub fn register_server_factory(&mut self, id: Arc<str>, factory: ContextServerFactory) {
self.context_servers.insert(id, factory);
pub fn context_server_descriptor(&self, id: &str) -> Option<Arc<dyn ContextServerDescriptor>> {
self.context_servers.get(id).cloned()
}
/// Unregisters the [`ContextServerFactory`] for the server with the given ID.
pub fn unregister_server_factory_by_id(&mut self, server_id: &str) {
/// Registers the provided [`ContextServerDescriptor`].
pub fn register_context_server_descriptor(
&mut self,
id: Arc<str>,
descriptor: Arc<dyn ContextServerDescriptor>,
) {
self.context_servers.insert(id, descriptor);
}
/// Unregisters the [`ContextServerDescriptor`] for the server with the given ID.
pub fn unregister_context_server_descriptor_by_id(&mut self, server_id: &str) {
self.context_servers.remove(server_id);
}
}

View File

@@ -42,6 +42,30 @@ impl RequestType {
}
}
impl TryFrom<&str> for RequestType {
type Error = ();
fn try_from(s: &str) -> Result<Self, Self::Error> {
match s {
"initialize" => Ok(RequestType::Initialize),
"tools/call" => Ok(RequestType::CallTool),
"resources/unsubscribe" => Ok(RequestType::ResourcesUnsubscribe),
"resources/subscribe" => Ok(RequestType::ResourcesSubscribe),
"resources/read" => Ok(RequestType::ResourcesRead),
"resources/list" => Ok(RequestType::ResourcesList),
"logging/setLevel" => Ok(RequestType::LoggingSetLevel),
"prompts/get" => Ok(RequestType::PromptsGet),
"prompts/list" => Ok(RequestType::PromptsList),
"completion/complete" => Ok(RequestType::CompletionComplete),
"ping" => Ok(RequestType::Ping),
"tools/list" => Ok(RequestType::ListTools),
"resources/templates/list" => Ok(RequestType::ListResourceTemplates),
"roots/list" => Ok(RequestType::ListRoots),
_ => Err(()),
}
}
}
#[derive(Debug, PartialEq, Eq, Serialize, Deserialize)]
#[serde(transparent)]
pub struct ProtocolVersion(pub String);
@@ -154,7 +178,7 @@ pub struct CompletionArgument {
pub value: String,
}
#[derive(Debug, Deserialize)]
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct InitializeResponse {
pub protocol_version: ProtocolVersion,
@@ -343,7 +367,7 @@ pub struct ClientCapabilities {
pub roots: Option<RootsCapabilities>,
}
#[derive(Debug, Serialize, Deserialize)]
#[derive(Default, Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ServerCapabilities {
#[serde(skip_serializing_if = "Option::is_none")]

View File

@@ -444,10 +444,7 @@ pub trait DebugAdapter: 'static + Send + Sync {
log::info!("Getting latest version of debug adapter {}", self.name());
delegate.update_status(self.name(), DapStatus::CheckingForUpdate);
if let Some(version) = self.fetch_latest_adapter_version(delegate).await.log_err() {
log::info!(
"Installiing latest version of debug adapter {}",
self.name()
);
log::info!("Installing latest version of debug adapter {}", self.name());
delegate.update_status(self.name(), DapStatus::Downloading);
match self.install_binary(version, delegate).await {
Ok(_) => {

View File

@@ -7,6 +7,7 @@ use crate::{
};
use crate::{new_session_modal::NewSessionModal, session::DebugSession};
use anyhow::{Result, anyhow};
use collections::{HashMap, HashSet};
use command_palette_hooks::CommandPaletteFilter;
use dap::DebugRequest;
use dap::{
@@ -26,6 +27,7 @@ use project::{Project, debugger::session::ThreadStatus};
use rpc::proto::{self};
use settings::Settings;
use std::any::TypeId;
use std::path::PathBuf;
use task::{DebugScenario, TaskContext};
use ui::{ContextMenu, Divider, DropdownMenu, Tooltip, prelude::*};
use workspace::SplitDirection;
@@ -403,7 +405,6 @@ impl DebugPanel {
pub fn resolve_scenario(
&self,
scenario: DebugScenario,
task_context: TaskContext,
buffer: Option<Entity<Buffer>>,
window: &Window,
@@ -424,8 +425,60 @@ impl DebugPanel {
stop_on_entry,
} = scenario;
let request = if let Some(mut request) = request {
// Resolve task variables within the request.
if let DebugRequest::Launch(_) = &mut request {}
if let DebugRequest::Launch(launch_config) = &mut request {
let mut variable_names = HashMap::default();
let mut substituted_variables = HashSet::default();
let task_variables = task_context
.task_variables
.iter()
.map(|(key, value)| {
let key_string = key.to_string();
if !variable_names.contains_key(&key_string) {
variable_names.insert(key_string.clone(), key.clone());
}
(key_string, value.as_str())
})
.collect::<HashMap<_, _>>();
let cwd = launch_config
.cwd
.as_ref()
.and_then(|cwd| cwd.to_str())
.and_then(|cwd| {
task::substitute_all_template_variables_in_str(
cwd,
&task_variables,
&variable_names,
&mut substituted_variables,
)
});
if let Some(cwd) = cwd {
launch_config.cwd = Some(PathBuf::from(cwd))
}
if let Some(program) = task::substitute_all_template_variables_in_str(
&launch_config.program,
&task_variables,
&variable_names,
&mut substituted_variables,
) {
launch_config.program = program;
}
for arg in launch_config.args.iter_mut() {
if let Some(substituted_arg) =
task::substitute_all_template_variables_in_str(
&arg,
&task_variables,
&variable_names,
&mut substituted_variables,
)
{
*arg = substituted_arg;
}
}
}
request
} else if let Some(build) = build {
@@ -944,6 +997,7 @@ impl DebugPanel {
past_debug_definition,
weak_panel,
workspace,
None,
window,
cx,
)

View File

@@ -158,6 +158,7 @@ pub fn init(cx: &mut App) {
debug_panel.read(cx).past_debug_definition.clone(),
weak_panel,
weak_workspace,
None,
window,
cx,
)
@@ -166,14 +167,22 @@ pub fn init(cx: &mut App) {
},
)
.register_action(|workspace: &mut Workspace, _: &Start, window, cx| {
tasks_ui::toggle_modal(
workspace,
None,
task::TaskModal::DebugModal,
window,
cx,
)
.detach();
if let Some(debug_panel) = workspace.panel::<DebugPanel>(cx) {
let weak_panel = debug_panel.downgrade();
let weak_workspace = cx.weak_entity();
let task_store = workspace.project().read(cx).task_store().clone();
workspace.toggle_modal(window, cx, |window, cx| {
NewSessionModal::new(
debug_panel.read(cx).past_debug_definition.clone(),
weak_panel,
weak_workspace,
Some(task_store),
window,
cx,
)
});
}
});
})
})

View File

@@ -6,19 +6,25 @@ use std::{
use dap::{DapRegistry, DebugRequest, adapters::DebugTaskDefinition};
use editor::{Editor, EditorElement, EditorStyle};
use fuzzy::{StringMatch, StringMatchCandidate};
use gpui::{
App, AppContext, DismissEvent, Entity, EventEmitter, FocusHandle, Focusable, Render, TextStyle,
WeakEntity,
App, AppContext, DismissEvent, Entity, EventEmitter, FocusHandle, Focusable, Render,
Subscription, TextStyle, WeakEntity,
};
use picker::{Picker, PickerDelegate, highlighted_match_with_paths::HighlightedMatch};
use project::{TaskSourceKind, task_store::TaskStore};
use session_modes::{AttachMode, DebugScenarioDelegate, LaunchMode};
use settings::Settings;
use task::{DebugScenario, LaunchRequest, TaskContext};
use task::{DebugScenario, LaunchRequest};
use theme::ThemeSettings;
use ui::{
ActiveTheme, Button, ButtonCommon, ButtonSize, CheckboxWithLabel, Clickable, Color, Context,
ContextMenu, Disableable, DropdownMenu, FluentBuilder, InteractiveElement, IntoElement, Label,
LabelCommon as _, ParentElement, RenderOnce, SharedString, Styled, StyledExt, ToggleButton,
ToggleState, Toggleable, Window, div, h_flex, relative, rems, v_flex,
ContextMenu, Disableable, DropdownMenu, FluentBuilder, Icon, IconName, InteractiveElement,
IntoElement, Label, LabelCommon as _, ListItem, ListItemSpacing, ParentElement, RenderOnce,
SharedString, Styled, StyledExt, ToggleButton, ToggleState, Toggleable, Window, div, h_flex,
relative, rems, v_flex,
};
use util::ResultExt;
use workspace::{ModalView, Workspace};
use crate::{attach_modal::AttachModal, debugger_panel::DebugPanel};
@@ -57,6 +63,7 @@ impl NewSessionModal {
past_debug_definition: Option<DebugTaskDefinition>,
debug_panel: WeakEntity<DebugPanel>,
workspace: WeakEntity<Workspace>,
task_store: Option<Entity<TaskStore>>,
window: &mut Window,
cx: &mut Context<Self>,
) -> Self {
@@ -73,6 +80,18 @@ impl NewSessionModal {
_ => None,
};
if let Some(task_store) = task_store {
cx.defer_in(window, |this, window, cx| {
this.mode = NewSessionMode::scenario(
this.debug_panel.clone(),
this.workspace.clone(),
task_store,
window,
cx,
);
});
};
Self {
workspace: workspace.clone(),
debugger,
@@ -86,10 +105,10 @@ impl NewSessionModal {
}
}
fn debug_config(&self, cx: &App, debugger: &str) -> DebugScenario {
let request = self.mode.debug_task(cx);
fn debug_config(&self, cx: &App, debugger: &str) -> Option<DebugScenario> {
let request = self.mode.debug_task(cx)?;
let label = suggested_label(&request, debugger);
DebugScenario {
Some(DebugScenario {
adapter: debugger.to_owned().into(),
label,
request: Some(request),
@@ -100,21 +119,42 @@ impl NewSessionModal {
_ => None,
},
build: None,
}
})
}
fn start_new_session(&self, window: &mut Window, cx: &mut Context<Self>) {
let Some(debugger) = self.debugger.as_ref() else {
// todo: show in UI.
// todo(debugger): show in UI.
log::error!("No debugger selected");
return;
};
let config = self.debug_config(cx, debugger);
if let NewSessionMode::Scenario(picker) = &self.mode {
picker.update(cx, |picker, cx| {
picker.delegate.confirm(false, window, cx);
});
return;
}
let Some(config) = self.debug_config(cx, debugger) else {
log::error!("debug config not found in mode: {}", self.mode);
return;
};
let debug_panel = self.debug_panel.clone();
let workspace = self.workspace.clone();
cx.spawn_in(window, async move |this, cx| {
let task_contexts = workspace
.update_in(cx, |workspace, window, cx| {
tasks_ui::task_contexts(workspace, window, cx)
})?
.await;
let task_context = task_contexts.active_context().cloned().unwrap_or_default();
debug_panel.update_in(cx, |debug_panel, window, cx| {
debug_panel.start_session(config, TaskContext::default(), None, window, cx)
debug_panel.start_session(config, task_context, None, window, cx)
})?;
this.update(cx, |_, cx| {
cx.emit(DismissEvent);
@@ -256,9 +296,14 @@ impl NewSessionModal {
.iter()
.flat_map(|task_inventory| {
task_inventory.read(cx).list_debug_scenarios(
worktree.as_ref().map(|worktree| worktree.read(cx).id()),
worktree
.as_ref()
.map(|worktree| worktree.read(cx).id())
.iter()
.copied(),
)
})
.map(|(_source_kind, scenario)| scenario)
.collect()
})
.ok()
@@ -277,102 +322,22 @@ impl NewSessionModal {
}
}
#[derive(Clone)]
struct LaunchMode {
program: Entity<Editor>,
cwd: Entity<Editor>,
}
impl LaunchMode {
fn new(
past_launch_config: Option<LaunchRequest>,
window: &mut Window,
cx: &mut App,
) -> Entity<Self> {
let (past_program, past_cwd) = past_launch_config
.map(|config| (Some(config.program), config.cwd))
.unwrap_or_else(|| (None, None));
let program = cx.new(|cx| Editor::single_line(window, cx));
program.update(cx, |this, cx| {
this.set_placeholder_text("Program path", cx);
if let Some(past_program) = past_program {
this.set_text(past_program, window, cx);
};
});
let cwd = cx.new(|cx| Editor::single_line(window, cx));
cwd.update(cx, |this, cx| {
this.set_placeholder_text("Working Directory", cx);
if let Some(past_cwd) = past_cwd {
this.set_text(past_cwd.to_string_lossy(), window, cx);
};
});
cx.new(|_| Self { program, cwd })
}
fn debug_task(&self, cx: &App) -> task::LaunchRequest {
let path = self.cwd.read(cx).text(cx);
task::LaunchRequest {
program: self.program.read(cx).text(cx),
cwd: path.is_empty().not().then(|| PathBuf::from(path)),
args: Default::default(),
env: Default::default(),
}
}
}
#[derive(Clone)]
struct AttachMode {
definition: DebugTaskDefinition,
attach_picker: Entity<AttachModal>,
}
impl AttachMode {
fn new(
debugger: Option<SharedString>,
workspace: Entity<Workspace>,
window: &mut Window,
cx: &mut Context<NewSessionModal>,
) -> Entity<Self> {
let definition = DebugTaskDefinition {
adapter: debugger.clone().unwrap_or_default(),
label: "Attach New Session Setup".into(),
request: dap::DebugRequest::Attach(task::AttachRequest { process_id: None }),
initialize_args: None,
tcp_connection: None,
stop_on_entry: Some(false),
};
let attach_picker = cx.new(|cx| {
let modal = AttachModal::new(definition.clone(), workspace, false, window, cx);
window.focus(&modal.focus_handle(cx));
modal
});
cx.new(|_| Self {
definition,
attach_picker,
})
}
fn debug_task(&self) -> task::AttachRequest {
task::AttachRequest { process_id: None }
}
}
static SELECT_DEBUGGER_LABEL: SharedString = SharedString::new_static("Select Debugger");
static SELECT_SCENARIO_LABEL: SharedString = SharedString::new_static("Select Profile");
#[derive(Clone)]
enum NewSessionMode {
Launch(Entity<LaunchMode>),
Scenario(Entity<Picker<DebugScenarioDelegate>>),
Attach(Entity<AttachMode>),
}
impl NewSessionMode {
fn debug_task(&self, cx: &App) -> DebugRequest {
fn debug_task(&self, cx: &App) -> Option<DebugRequest> {
match self {
NewSessionMode::Launch(entity) => entity.read(cx).debug_task(cx).into(),
NewSessionMode::Attach(entity) => entity.read(cx).debug_task().into(),
NewSessionMode::Launch(entity) => Some(entity.read(cx).debug_task(cx).into()),
NewSessionMode::Attach(entity) => Some(entity.read(cx).debug_task().into()),
NewSessionMode::Scenario(_) => None,
}
}
fn as_attach(&self) -> Option<&Entity<AttachMode>> {
@@ -382,6 +347,78 @@ impl NewSessionMode {
None
}
}
fn scenario(
debug_panel: WeakEntity<DebugPanel>,
workspace: WeakEntity<Workspace>,
task_store: Entity<TaskStore>,
window: &mut Window,
cx: &mut Context<NewSessionModal>,
) -> NewSessionMode {
let picker = cx.new(|cx| {
Picker::uniform_list(
DebugScenarioDelegate::new(debug_panel, workspace, task_store),
window,
cx,
)
.modal(false)
});
cx.subscribe(&picker, |_, _, _, cx| {
cx.emit(DismissEvent);
})
.detach();
picker.focus_handle(cx).focus(window);
NewSessionMode::Scenario(picker)
}
fn attach(
debugger: Option<SharedString>,
workspace: Entity<Workspace>,
window: &mut Window,
cx: &mut Context<NewSessionModal>,
) -> Self {
Self::Attach(AttachMode::new(debugger, workspace, window, cx))
}
fn launch(
past_launch_config: Option<LaunchRequest>,
window: &mut Window,
cx: &mut Context<NewSessionModal>,
) -> Self {
Self::Launch(LaunchMode::new(past_launch_config, window, cx))
}
fn has_match(&self, cx: &App) -> bool {
match self {
NewSessionMode::Scenario(picker) => picker.read(cx).delegate.match_count() > 0,
NewSessionMode::Attach(picker) => {
picker
.read(cx)
.attach_picker
.read(cx)
.picker
.read(cx)
.delegate
.match_count()
> 0
}
_ => false,
}
}
}
impl std::fmt::Display for NewSessionMode {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let mode = match self {
NewSessionMode::Launch(_) => "launch".to_owned(),
NewSessionMode::Attach(_) => "attach".to_owned(),
NewSessionMode::Scenario(_) => "scenario picker".to_owned(),
};
write!(f, "{}", mode)
}
}
impl Focusable for NewSessionMode {
@@ -389,6 +426,7 @@ impl Focusable for NewSessionMode {
match &self {
NewSessionMode::Launch(entity) => entity.read(cx).program.focus_handle(cx),
NewSessionMode::Attach(entity) => entity.read(cx).attach_picker.focus_handle(cx),
NewSessionMode::Scenario(entity) => entity.read(cx).focus_handle(cx),
}
}
}
@@ -437,27 +475,14 @@ impl RenderOnce for NewSessionMode {
NewSessionMode::Attach(entity) => entity.update(cx, |this, cx| {
this.clone().render(window, cx).into_any_element()
}),
NewSessionMode::Scenario(entity) => v_flex()
.w(rems(34.))
.child(entity.clone())
.into_any_element(),
}
}
}
impl NewSessionMode {
fn attach(
debugger: Option<SharedString>,
workspace: Entity<Workspace>,
window: &mut Window,
cx: &mut Context<NewSessionModal>,
) -> Self {
Self::Attach(AttachMode::new(debugger, workspace, window, cx))
}
fn launch(
past_launch_config: Option<LaunchRequest>,
window: &mut Window,
cx: &mut Context<NewSessionModal>,
) -> Self {
Self::Launch(LaunchMode::new(past_launch_config, window, cx))
}
}
fn render_editor(editor: &Entity<Editor>, window: &mut Window, cx: &App) -> impl IntoElement {
let settings = ThemeSettings::get_global(cx);
let theme = cx.theme();
@@ -519,6 +544,34 @@ impl Render for NewSessionModal {
h_flex()
.justify_start()
.w_full()
.child(
ToggleButton::new("debugger-session-ui-picker-button", "Scenarios")
.size(ButtonSize::Default)
.style(ui::ButtonStyle::Subtle)
.toggle_state(matches!(self.mode, NewSessionMode::Scenario(_)))
.on_click(cx.listener(|this, _, window, cx| {
let Some(task_store) = this
.workspace
.update(cx, |workspace, cx| {
workspace.project().read(cx).task_store().clone()
})
.ok()
else {
return;
};
this.mode = NewSessionMode::scenario(
this.debug_panel.clone(),
this.workspace.clone(),
task_store,
window,
cx,
);
cx.notify();
}))
.first(),
)
.child(
ToggleButton::new(
"debugger-session-ui-launch-button",
@@ -532,7 +585,7 @@ impl Render for NewSessionModal {
this.mode.focus_handle(cx).focus(window);
cx.notify();
}))
.first(),
.middle(),
)
.child(
ToggleButton::new(
@@ -601,10 +654,21 @@ impl Render for NewSessionModal {
})
.child(
Button::new("debugger-spawn", "Start")
.on_click(cx.listener(|this, _, window, cx| {
this.start_new_session(window, cx);
.on_click(cx.listener(|this, _, window, cx| match &this.mode {
NewSessionMode::Scenario(picker) => {
picker.update(cx, |picker, cx| {
picker.delegate.confirm(true, window, cx)
})
}
_ => this.start_new_session(window, cx),
}))
.disabled(self.debugger.is_none()),
.disabled(match self.mode {
NewSessionMode::Scenario(_) => !self.mode.has_match(cx),
NewSessionMode::Attach(_) => {
self.debugger.is_none() || !self.mode.has_match(cx)
}
NewSessionMode::Launch(_) => self.debugger.is_none(),
}),
),
),
)
@@ -619,3 +683,319 @@ impl Focusable for NewSessionModal {
}
impl ModalView for NewSessionModal {}
// This module makes sure that the modes setup the correct subscriptions whenever they're created
mod session_modes {
use std::rc::Rc;
use super::*;
#[derive(Clone)]
#[non_exhaustive]
pub(super) struct LaunchMode {
pub(super) program: Entity<Editor>,
pub(super) cwd: Entity<Editor>,
}
impl LaunchMode {
pub(super) fn new(
past_launch_config: Option<LaunchRequest>,
window: &mut Window,
cx: &mut App,
) -> Entity<Self> {
let (past_program, past_cwd) = past_launch_config
.map(|config| (Some(config.program), config.cwd))
.unwrap_or_else(|| (None, None));
let program = cx.new(|cx| Editor::single_line(window, cx));
program.update(cx, |this, cx| {
this.set_placeholder_text("Program path", cx);
if let Some(past_program) = past_program {
this.set_text(past_program, window, cx);
};
});
let cwd = cx.new(|cx| Editor::single_line(window, cx));
cwd.update(cx, |this, cx| {
this.set_placeholder_text("Working Directory", cx);
if let Some(past_cwd) = past_cwd {
this.set_text(past_cwd.to_string_lossy(), window, cx);
};
});
cx.new(|_| Self { program, cwd })
}
pub(super) fn debug_task(&self, cx: &App) -> task::LaunchRequest {
let path = self.cwd.read(cx).text(cx);
task::LaunchRequest {
program: self.program.read(cx).text(cx),
cwd: path.is_empty().not().then(|| PathBuf::from(path)),
args: Default::default(),
env: Default::default(),
}
}
}
#[derive(Clone)]
pub(super) struct AttachMode {
pub(super) definition: DebugTaskDefinition,
pub(super) attach_picker: Entity<AttachModal>,
_subscription: Rc<Subscription>,
}
impl AttachMode {
pub(super) fn new(
debugger: Option<SharedString>,
workspace: Entity<Workspace>,
window: &mut Window,
cx: &mut Context<NewSessionModal>,
) -> Entity<Self> {
let definition = DebugTaskDefinition {
adapter: debugger.clone().unwrap_or_default(),
label: "Attach New Session Setup".into(),
request: dap::DebugRequest::Attach(task::AttachRequest { process_id: None }),
initialize_args: None,
tcp_connection: None,
stop_on_entry: Some(false),
};
let attach_picker = cx.new(|cx| {
let modal = AttachModal::new(definition.clone(), workspace, false, window, cx);
window.focus(&modal.focus_handle(cx));
modal
});
let subscription = cx.subscribe(&attach_picker, |_, _, _, cx| {
cx.emit(DismissEvent);
});
cx.new(|_| Self {
definition,
attach_picker,
_subscription: Rc::new(subscription),
})
}
pub(super) fn debug_task(&self) -> task::AttachRequest {
task::AttachRequest { process_id: None }
}
}
pub(super) struct DebugScenarioDelegate {
task_store: Entity<TaskStore>,
candidates: Option<Vec<(TaskSourceKind, DebugScenario)>>,
selected_index: usize,
matches: Vec<StringMatch>,
prompt: String,
debug_panel: WeakEntity<DebugPanel>,
workspace: WeakEntity<Workspace>,
}
impl DebugScenarioDelegate {
pub(super) fn new(
debug_panel: WeakEntity<DebugPanel>,
workspace: WeakEntity<Workspace>,
task_store: Entity<TaskStore>,
) -> Self {
Self {
task_store,
candidates: None,
selected_index: 0,
matches: Vec::new(),
prompt: String::new(),
debug_panel,
workspace,
}
}
}
impl PickerDelegate for DebugScenarioDelegate {
type ListItem = ui::ListItem;
fn match_count(&self) -> usize {
self.matches.len()
}
fn selected_index(&self) -> usize {
self.selected_index
}
fn set_selected_index(
&mut self,
ix: usize,
_window: &mut Window,
_cx: &mut Context<picker::Picker<Self>>,
) {
self.selected_index = ix;
}
fn placeholder_text(&self, _window: &mut Window, _cx: &mut App) -> std::sync::Arc<str> {
"".into()
}
fn update_matches(
&mut self,
query: String,
window: &mut Window,
cx: &mut Context<picker::Picker<Self>>,
) -> gpui::Task<()> {
let candidates: Vec<_> = match &self.candidates {
Some(candidates) => candidates
.into_iter()
.enumerate()
.map(|(index, (_, candidate))| {
StringMatchCandidate::new(index, candidate.label.as_ref())
})
.collect(),
None => {
let worktree_ids: Vec<_> = self
.workspace
.update(cx, |this, cx| {
this.visible_worktrees(cx)
.map(|tree| tree.read(cx).id())
.collect()
})
.ok()
.unwrap_or_default();
let scenarios: Vec<_> = self
.task_store
.read(cx)
.task_inventory()
.map(|item| item.read(cx).list_debug_scenarios(worktree_ids.into_iter()))
.unwrap_or_default();
self.candidates = Some(scenarios.clone());
scenarios
.into_iter()
.enumerate()
.map(|(index, (_, candidate))| {
StringMatchCandidate::new(index, candidate.label.as_ref())
})
.collect()
}
};
cx.spawn_in(window, async move |picker, cx| {
let matches = fuzzy::match_strings(
&candidates,
&query,
true,
1000,
&Default::default(),
cx.background_executor().clone(),
)
.await;
picker
.update(cx, |picker, _| {
let delegate = &mut picker.delegate;
delegate.matches = matches;
delegate.prompt = query;
if delegate.matches.is_empty() {
delegate.selected_index = 0;
} else {
delegate.selected_index =
delegate.selected_index.min(delegate.matches.len() - 1);
}
})
.log_err();
})
}
fn confirm(
&mut self,
_: bool,
window: &mut Window,
cx: &mut Context<picker::Picker<Self>>,
) {
let debug_scenario =
self.matches
.get(self.selected_index())
.and_then(|match_candidate| {
self.candidates
.as_ref()
.map(|candidates| candidates[match_candidate.candidate_id].clone())
});
let Some((task_source_kind, debug_scenario)) = debug_scenario else {
return;
};
let task_context = if let TaskSourceKind::Worktree {
id: worktree_id,
directory_in_worktree: _,
id_base: _,
} = task_source_kind
{
let workspace = self.workspace.clone();
cx.spawn_in(window, async move |_, cx| {
workspace
.update_in(cx, |workspace, window, cx| {
tasks_ui::task_contexts(workspace, window, cx)
})
.ok()?
.await
.task_context_for_worktree_id(worktree_id)
.cloned()
})
} else {
gpui::Task::ready(None)
};
cx.spawn_in(window, async move |this, cx| {
let task_context = task_context.await.unwrap_or_default();
this.update_in(cx, |this, window, cx| {
this.delegate
.debug_panel
.update(cx, |panel, cx| {
panel.start_session(debug_scenario, task_context, None, window, cx);
})
.ok();
cx.emit(DismissEvent);
})
.ok();
})
.detach();
}
fn dismissed(&mut self, _: &mut Window, cx: &mut Context<picker::Picker<Self>>) {
cx.emit(DismissEvent);
}
fn render_match(
&self,
ix: usize,
selected: bool,
window: &mut Window,
cx: &mut Context<picker::Picker<Self>>,
) -> Option<Self::ListItem> {
let hit = &self.matches[ix];
let highlighted_location = HighlightedMatch {
text: hit.string.clone(),
highlight_positions: hit.positions.clone(),
char_count: hit.string.chars().count(),
color: Color::Default,
};
let icon = Icon::new(IconName::FileTree)
.color(Color::Muted)
.size(ui::IconSize::Small);
Some(
ListItem::new(SharedString::from(format!("debug-scenario-selection-{ix}")))
.inset(true)
.start_slot::<Icon>(icon)
.spacing(ListItemSpacing::Sparse)
.toggle_state(selected)
.child(highlighted_location.render(window, cx)),
)
}
}
}

View File

@@ -188,7 +188,7 @@ impl Render for SubView {
cx.notify();
}))
.size_full()
// Add border uncoditionally to prevent layout shifts on focus changes.
// Add border unconditionally to prevent layout shifts on focus changes.
.border_1()
.when(self.pane_focus_handle.contains_focused(window, cx), |el| {
el.border_color(cx.theme().colors().pane_focused_border)

View File

@@ -14,7 +14,6 @@ doctest = false
[dependencies]
anyhow.workspace = true
cargo_metadata.workspace = true
collections.workspace = true
component.workspace = true
ctor.workspace = true
@@ -23,7 +22,6 @@ env_logger.workspace = true
futures.workspace = true
gpui.workspace = true
indoc.workspace = true
itertools.workspace = true
language.workspace = true
linkme.workspace = true
log.workspace = true
@@ -34,7 +32,6 @@ rand.workspace = true
serde.workspace = true
serde_json.workspace = true
settings.workspace = true
smol.workspace = true
text.workspace = true
theme.workspace = true
ui.workspace = true

View File

@@ -1,603 +0,0 @@
use std::{
path::{Component, Path, Prefix},
process::Stdio,
sync::atomic::{self, AtomicUsize},
};
use cargo_metadata::{
Message,
diagnostic::{Applicability, Diagnostic as CargoDiagnostic, DiagnosticLevel, DiagnosticSpan},
};
use collections::HashMap;
use gpui::{AppContext, Entity, Task};
use itertools::Itertools as _;
use language::Diagnostic;
use project::{
Worktree, lsp_store::rust_analyzer_ext::CARGO_DIAGNOSTICS_SOURCE_NAME,
project_settings::ProjectSettings,
};
use serde::{Deserialize, Serialize};
use settings::Settings;
use smol::{
channel::Receiver,
io::{AsyncBufReadExt, BufReader},
process::Command,
};
use ui::App;
use util::ResultExt;
use crate::ProjectDiagnosticsEditor;
#[derive(Debug, serde::Deserialize)]
#[serde(untagged)]
enum CargoMessage {
Cargo(Message),
Rustc(CargoDiagnostic),
}
/// Appends formatted string to a `String`.
macro_rules! format_to {
($buf:expr) => ();
($buf:expr, $lit:literal $($arg:tt)*) => {
{
use ::std::fmt::Write as _;
// We can't do ::std::fmt::Write::write_fmt($buf, format_args!($lit $($arg)*))
// unfortunately, as that loses out on autoref behavior.
_ = $buf.write_fmt(format_args!($lit $($arg)*))
}
};
}
pub fn cargo_diagnostics_sources(
editor: &ProjectDiagnosticsEditor,
cx: &App,
) -> Vec<Entity<Worktree>> {
let fetch_cargo_diagnostics = ProjectSettings::get_global(cx)
.diagnostics
.fetch_cargo_diagnostics();
if !fetch_cargo_diagnostics {
return Vec::new();
}
editor
.project
.read(cx)
.worktrees(cx)
.filter(|worktree| worktree.read(cx).entry_for_path("Cargo.toml").is_some())
.collect()
}
#[derive(Debug)]
pub enum FetchUpdate {
Diagnostic(CargoDiagnostic),
Progress(String),
}
#[derive(Debug)]
pub enum FetchStatus {
Started,
Progress { message: String },
Finished,
}
pub fn fetch_worktree_diagnostics(
worktree_root: &Path,
cx: &App,
) -> Option<(Task<()>, Receiver<FetchUpdate>)> {
let diagnostics_settings = ProjectSettings::get_global(cx)
.diagnostics
.cargo
.as_ref()
.filter(|cargo_diagnostics| cargo_diagnostics.fetch_cargo_diagnostics)?;
let command_string = diagnostics_settings
.diagnostics_fetch_command
.iter()
.join(" ");
let mut command_parts = diagnostics_settings.diagnostics_fetch_command.iter();
let mut command = Command::new(command_parts.next()?)
.args(command_parts)
.envs(diagnostics_settings.env.clone())
.current_dir(worktree_root)
.stdout(Stdio::piped())
.stderr(Stdio::null())
.kill_on_drop(true)
.spawn()
.log_err()?;
let stdout = command.stdout.take()?;
let mut reader = BufReader::new(stdout);
let (tx, rx) = smol::channel::unbounded();
let error_threshold = 10;
let cargo_diagnostics_fetch_task = cx.background_spawn(async move {
let _command = command;
let mut errors = 0;
loop {
let mut line = String::new();
match reader.read_line(&mut line).await {
Ok(0) => {
return;
},
Ok(_) => {
errors = 0;
let mut deserializer = serde_json::Deserializer::from_str(&line);
deserializer.disable_recursion_limit();
let send_result = match CargoMessage::deserialize(&mut deserializer) {
Ok(CargoMessage::Cargo(Message::CompilerMessage(message))) => tx.send(FetchUpdate::Diagnostic(message.message)).await,
Ok(CargoMessage::Cargo(Message::CompilerArtifact(artifact))) => tx.send(FetchUpdate::Progress(format!("Compiled {:?}", artifact.manifest_path.parent().unwrap_or(&artifact.manifest_path)))).await,
Ok(CargoMessage::Cargo(_)) => Ok(()),
Ok(CargoMessage::Rustc(rustc_message)) => tx.send(FetchUpdate::Diagnostic(rustc_message)).await,
Err(_) => {
log::debug!("Failed to parse cargo diagnostics from line '{line}'");
Ok(())
},
};
if send_result.is_err() {
return;
}
},
Err(e) => {
log::error!("Failed to read line from {command_string} command output when fetching cargo diagnostics: {e}");
errors += 1;
if errors >= error_threshold {
log::error!("Failed {error_threshold} times, aborting the diagnostics fetch");
return;
}
},
}
}
});
Some((cargo_diagnostics_fetch_task, rx))
}
static CARGO_DIAGNOSTICS_FETCH_GENERATION: AtomicUsize = AtomicUsize::new(0);
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
struct CargoFetchDiagnosticData {
generation: usize,
}
pub fn next_cargo_fetch_generation() {
CARGO_DIAGNOSTICS_FETCH_GENERATION.fetch_add(1, atomic::Ordering::Release);
}
pub fn is_outdated_cargo_fetch_diagnostic(diagnostic: &Diagnostic) -> bool {
if let Some(data) = diagnostic
.data
.clone()
.and_then(|data| serde_json::from_value::<CargoFetchDiagnosticData>(data).ok())
{
let current_generation = CARGO_DIAGNOSTICS_FETCH_GENERATION.load(atomic::Ordering::Acquire);
data.generation < current_generation
} else {
false
}
}
/// Converts a Rust root diagnostic to LSP form
///
/// This flattens the Rust diagnostic by:
///
/// 1. Creating a LSP diagnostic with the root message and primary span.
/// 2. Adding any labelled secondary spans to `relatedInformation`
/// 3. Categorising child diagnostics as either `SuggestedFix`es,
/// `relatedInformation` or additional message lines.
///
/// If the diagnostic has no primary span this will return `None`
///
/// Taken from https://github.com/rust-lang/rust-analyzer/blob/fe7b4f2ad96f7c13cc571f45edc2c578b35dddb4/crates/rust-analyzer/src/diagnostics/to_proto.rs#L275-L285
pub(crate) fn map_rust_diagnostic_to_lsp(
worktree_root: &Path,
cargo_diagnostic: &CargoDiagnostic,
) -> Vec<(lsp::Url, lsp::Diagnostic)> {
let primary_spans: Vec<&DiagnosticSpan> = cargo_diagnostic
.spans
.iter()
.filter(|s| s.is_primary)
.collect();
if primary_spans.is_empty() {
return Vec::new();
}
let severity = diagnostic_severity(cargo_diagnostic.level);
let mut source = String::from(CARGO_DIAGNOSTICS_SOURCE_NAME);
let mut code = cargo_diagnostic.code.as_ref().map(|c| c.code.clone());
if let Some(code_val) = &code {
// See if this is an RFC #2103 scoped lint (e.g. from Clippy)
let scoped_code: Vec<&str> = code_val.split("::").collect();
if scoped_code.len() == 2 {
source = String::from(scoped_code[0]);
code = Some(String::from(scoped_code[1]));
}
}
let mut needs_primary_span_label = true;
let mut subdiagnostics = Vec::new();
let mut tags = Vec::new();
for secondary_span in cargo_diagnostic.spans.iter().filter(|s| !s.is_primary) {
if let Some(label) = secondary_span.label.clone() {
subdiagnostics.push(lsp::DiagnosticRelatedInformation {
location: location(worktree_root, secondary_span),
message: label,
});
}
}
let mut message = cargo_diagnostic.message.clone();
for child in &cargo_diagnostic.children {
let child = map_rust_child_diagnostic(worktree_root, child);
match child {
MappedRustChildDiagnostic::SubDiagnostic(sub) => {
subdiagnostics.push(sub);
}
MappedRustChildDiagnostic::MessageLine(message_line) => {
format_to!(message, "\n{message_line}");
// These secondary messages usually duplicate the content of the
// primary span label.
needs_primary_span_label = false;
}
}
}
if let Some(code) = &cargo_diagnostic.code {
let code = code.code.as_str();
if matches!(
code,
"dead_code"
| "unknown_lints"
| "unreachable_code"
| "unused_attributes"
| "unused_imports"
| "unused_macros"
| "unused_variables"
) {
tags.push(lsp::DiagnosticTag::UNNECESSARY);
}
if matches!(code, "deprecated") {
tags.push(lsp::DiagnosticTag::DEPRECATED);
}
}
let code_description = match source.as_str() {
"rustc" => rustc_code_description(code.as_deref()),
"clippy" => clippy_code_description(code.as_deref()),
_ => None,
};
let generation = CARGO_DIAGNOSTICS_FETCH_GENERATION.load(atomic::Ordering::Acquire);
let data = Some(
serde_json::to_value(CargoFetchDiagnosticData { generation })
.expect("Serializing a regular Rust struct"),
);
primary_spans
.iter()
.flat_map(|primary_span| {
let primary_location = primary_location(worktree_root, primary_span);
let message = {
let mut message = message.clone();
if needs_primary_span_label {
if let Some(primary_span_label) = &primary_span.label {
format_to!(message, "\n{primary_span_label}");
}
}
message
};
// Each primary diagnostic span may result in multiple LSP diagnostics.
let mut diagnostics = Vec::new();
let mut related_info_macro_calls = vec![];
// If error occurs from macro expansion, add related info pointing to
// where the error originated
// Also, we would generate an additional diagnostic, so that exact place of macro
// will be highlighted in the error origin place.
let span_stack = std::iter::successors(Some(*primary_span), |span| {
Some(&span.expansion.as_ref()?.span)
});
for (i, span) in span_stack.enumerate() {
if is_dummy_macro_file(&span.file_name) {
continue;
}
// First span is the original diagnostic, others are macro call locations that
// generated that code.
let is_in_macro_call = i != 0;
let secondary_location = location(worktree_root, span);
if secondary_location == primary_location {
continue;
}
related_info_macro_calls.push(lsp::DiagnosticRelatedInformation {
location: secondary_location.clone(),
message: if is_in_macro_call {
"Error originated from macro call here".to_owned()
} else {
"Actual error occurred here".to_owned()
},
});
// For the additional in-macro diagnostic we add the inverse message pointing to the error location in code.
let information_for_additional_diagnostic =
vec![lsp::DiagnosticRelatedInformation {
location: primary_location.clone(),
message: "Exact error occurred here".to_owned(),
}];
let diagnostic = lsp::Diagnostic {
range: secondary_location.range,
// downgrade to hint if we're pointing at the macro
severity: Some(lsp::DiagnosticSeverity::HINT),
code: code.clone().map(lsp::NumberOrString::String),
code_description: code_description.clone(),
source: Some(source.clone()),
message: message.clone(),
related_information: Some(information_for_additional_diagnostic),
tags: if tags.is_empty() {
None
} else {
Some(tags.clone())
},
data: data.clone(),
};
diagnostics.push((secondary_location.uri, diagnostic));
}
// Emit the primary diagnostic.
diagnostics.push((
primary_location.uri.clone(),
lsp::Diagnostic {
range: primary_location.range,
severity,
code: code.clone().map(lsp::NumberOrString::String),
code_description: code_description.clone(),
source: Some(source.clone()),
message,
related_information: {
let info = related_info_macro_calls
.iter()
.cloned()
.chain(subdiagnostics.iter().cloned())
.collect::<Vec<_>>();
if info.is_empty() { None } else { Some(info) }
},
tags: if tags.is_empty() {
None
} else {
Some(tags.clone())
},
data: data.clone(),
},
));
// Emit hint-level diagnostics for all `related_information` entries such as "help"s.
// This is useful because they will show up in the user's editor, unlike
// `related_information`, which just produces hard-to-read links, at least in VS Code.
let back_ref = lsp::DiagnosticRelatedInformation {
location: primary_location,
message: "original diagnostic".to_owned(),
};
for sub in &subdiagnostics {
diagnostics.push((
sub.location.uri.clone(),
lsp::Diagnostic {
range: sub.location.range,
severity: Some(lsp::DiagnosticSeverity::HINT),
code: code.clone().map(lsp::NumberOrString::String),
code_description: code_description.clone(),
source: Some(source.clone()),
message: sub.message.clone(),
related_information: Some(vec![back_ref.clone()]),
tags: None, // don't apply modifiers again
data: data.clone(),
},
));
}
diagnostics
})
.collect()
}
fn rustc_code_description(code: Option<&str>) -> Option<lsp::CodeDescription> {
code.filter(|code| {
let mut chars = code.chars();
chars.next() == Some('E')
&& chars.by_ref().take(4).all(|c| c.is_ascii_digit())
&& chars.next().is_none()
})
.and_then(|code| {
lsp::Url::parse(&format!(
"https://doc.rust-lang.org/error-index.html#{code}"
))
.ok()
.map(|href| lsp::CodeDescription { href })
})
}
fn clippy_code_description(code: Option<&str>) -> Option<lsp::CodeDescription> {
code.and_then(|code| {
lsp::Url::parse(&format!(
"https://rust-lang.github.io/rust-clippy/master/index.html#{code}"
))
.ok()
.map(|href| lsp::CodeDescription { href })
})
}
/// Determines the LSP severity from a diagnostic
fn diagnostic_severity(level: DiagnosticLevel) -> Option<lsp::DiagnosticSeverity> {
let res = match level {
DiagnosticLevel::Ice => lsp::DiagnosticSeverity::ERROR,
DiagnosticLevel::Error => lsp::DiagnosticSeverity::ERROR,
DiagnosticLevel::Warning => lsp::DiagnosticSeverity::WARNING,
DiagnosticLevel::Note => lsp::DiagnosticSeverity::INFORMATION,
DiagnosticLevel::Help => lsp::DiagnosticSeverity::HINT,
_ => return None,
};
Some(res)
}
enum MappedRustChildDiagnostic {
SubDiagnostic(lsp::DiagnosticRelatedInformation),
MessageLine(String),
}
fn map_rust_child_diagnostic(
worktree_root: &Path,
cargo_diagnostic: &CargoDiagnostic,
) -> MappedRustChildDiagnostic {
let spans: Vec<&DiagnosticSpan> = cargo_diagnostic
.spans
.iter()
.filter(|s| s.is_primary)
.collect();
if spans.is_empty() {
// `rustc` uses these spanless children as a way to print multi-line
// messages
return MappedRustChildDiagnostic::MessageLine(cargo_diagnostic.message.clone());
}
let mut edit_map: HashMap<lsp::Url, Vec<lsp::TextEdit>> = HashMap::default();
let mut suggested_replacements = Vec::new();
for &span in &spans {
if let Some(suggested_replacement) = &span.suggested_replacement {
if !suggested_replacement.is_empty() {
suggested_replacements.push(suggested_replacement);
}
let location = location(worktree_root, span);
let edit = lsp::TextEdit::new(location.range, suggested_replacement.clone());
// Only actually emit a quickfix if the suggestion is "valid enough".
// We accept both "MaybeIncorrect" and "MachineApplicable". "MaybeIncorrect" means that
// the suggestion is *complete* (contains no placeholders where code needs to be
// inserted), but might not be what the user wants, or might need minor adjustments.
if matches!(
span.suggestion_applicability,
None | Some(Applicability::MaybeIncorrect | Applicability::MachineApplicable)
) {
edit_map.entry(location.uri).or_default().push(edit);
}
}
}
// rustc renders suggestion diagnostics by appending the suggested replacement, so do the same
// here, otherwise the diagnostic text is missing useful information.
let mut message = cargo_diagnostic.message.clone();
if !suggested_replacements.is_empty() {
message.push_str(": ");
let suggestions = suggested_replacements
.iter()
.map(|suggestion| format!("`{suggestion}`"))
.join(", ");
message.push_str(&suggestions);
}
MappedRustChildDiagnostic::SubDiagnostic(lsp::DiagnosticRelatedInformation {
location: location(worktree_root, spans[0]),
message,
})
}
/// Converts a Rust span to a LSP location
fn location(worktree_root: &Path, span: &DiagnosticSpan) -> lsp::Location {
let file_name = worktree_root.join(&span.file_name);
let uri = url_from_abs_path(&file_name);
let range = {
lsp::Range::new(
position(span, span.line_start, span.column_start.saturating_sub(1)),
position(span, span.line_end, span.column_end.saturating_sub(1)),
)
};
lsp::Location::new(uri, range)
}
/// Returns a `Url` object from a given path, will lowercase drive letters if present.
/// This will only happen when processing windows paths.
///
/// When processing non-windows path, this is essentially the same as `Url::from_file_path`.
pub(crate) fn url_from_abs_path(path: &Path) -> lsp::Url {
let url = lsp::Url::from_file_path(path).unwrap();
match path.components().next() {
Some(Component::Prefix(prefix))
if matches!(prefix.kind(), Prefix::Disk(_) | Prefix::VerbatimDisk(_)) =>
{
// Need to lowercase driver letter
}
_ => return url,
}
let driver_letter_range = {
let (scheme, drive_letter, _rest) = match url.as_str().splitn(3, ':').collect_tuple() {
Some(it) => it,
None => return url,
};
let start = scheme.len() + ':'.len_utf8();
start..(start + drive_letter.len())
};
// Note: lowercasing the `path` itself doesn't help, the `Url::parse`
// machinery *also* canonicalizes the drive letter. So, just massage the
// string in place.
let mut url: String = url.into();
url[driver_letter_range].make_ascii_lowercase();
lsp::Url::parse(&url).unwrap()
}
fn position(
span: &DiagnosticSpan,
line_number: usize,
column_offset_utf32: usize,
) -> lsp::Position {
let line_index = line_number - span.line_start;
let column_offset_encoded = match span.text.get(line_index) {
// Fast path.
Some(line) if line.text.is_ascii() => column_offset_utf32,
Some(line) => {
let line_prefix_len = line
.text
.char_indices()
.take(column_offset_utf32)
.last()
.map(|(pos, c)| pos + c.len_utf8())
.unwrap_or(0);
let line_prefix = &line.text[..line_prefix_len];
line_prefix.len()
}
None => column_offset_utf32,
};
lsp::Position {
line: (line_number as u32).saturating_sub(1),
character: column_offset_encoded as u32,
}
}
/// Checks whether a file name is from macro invocation and does not refer to an actual file.
fn is_dummy_macro_file(file_name: &str) -> bool {
file_name.starts_with('<') && file_name.ends_with('>')
}
/// Extracts a suitable "primary" location from a rustc diagnostic.
///
/// This takes locations pointing into the standard library, or generally outside the current
/// workspace into account and tries to avoid those, in case macros are involved.
fn primary_location(worktree_root: &Path, span: &DiagnosticSpan) -> lsp::Location {
let span_stack = std::iter::successors(Some(span), |span| Some(&span.expansion.as_ref()?.span));
for span in span_stack.clone() {
let abs_path = worktree_root.join(&span.file_name);
if !is_dummy_macro_file(&span.file_name) && abs_path.starts_with(worktree_root) {
return location(worktree_root, span);
}
}
// Fall back to the outermost macro invocation if no suitable span comes up.
let last_span = span_stack.last().unwrap();
location(worktree_root, last_span)
}

View File

@@ -1,4 +1,3 @@
mod cargo;
pub mod items;
mod toolbar_controls;
@@ -8,18 +7,14 @@ mod diagnostic_renderer;
mod diagnostics_tests;
use anyhow::Result;
use cargo::{
FetchStatus, FetchUpdate, cargo_diagnostics_sources, fetch_worktree_diagnostics,
is_outdated_cargo_fetch_diagnostic, map_rust_diagnostic_to_lsp, next_cargo_fetch_generation,
url_from_abs_path,
};
use collections::{BTreeSet, HashMap, HashSet};
use collections::{BTreeSet, HashMap};
use diagnostic_renderer::DiagnosticBlock;
use editor::{
DEFAULT_MULTIBUFFER_CONTEXT, Editor, EditorEvent, ExcerptRange, MultiBuffer, PathKey,
display_map::{BlockPlacement, BlockProperties, BlockStyle, CustomBlockId},
scroll::Autoscroll,
};
use futures::future::join_all;
use gpui::{
AnyElement, AnyView, App, AsyncApp, Context, Entity, EventEmitter, FocusHandle, Focusable,
Global, InteractiveElement, IntoElement, ParentElement, Render, SharedString, Styled,
@@ -28,10 +23,10 @@ use gpui::{
use language::{
Bias, Buffer, BufferRow, BufferSnapshot, DiagnosticEntry, Point, ToTreeSitterPoint,
};
use lsp::{DiagnosticSeverity, LanguageServerId};
use lsp::DiagnosticSeverity;
use project::{
DiagnosticSummary, Project, ProjectPath, Worktree,
lsp_store::rust_analyzer_ext::{CARGO_DIAGNOSTICS_SOURCE_NAME, RUST_ANALYZER_NAME},
DiagnosticSummary, Project, ProjectPath,
lsp_store::rust_analyzer_ext::{cancel_flycheck, run_flycheck},
project_settings::ProjectSettings,
};
use settings::Settings;
@@ -84,8 +79,9 @@ pub(crate) struct ProjectDiagnosticsEditor {
}
struct CargoDiagnosticsFetchState {
task: Option<Task<()>>,
rust_analyzer: Option<LanguageServerId>,
fetch_task: Option<Task<()>>,
cancel_task: Option<Task<()>>,
diagnostic_sources: Arc<Vec<ProjectPath>>,
}
impl EventEmitter<EditorEvent> for ProjectDiagnosticsEditor {}
@@ -252,8 +248,9 @@ impl ProjectDiagnosticsEditor {
paths_to_update: Default::default(),
update_excerpts_task: None,
cargo_diagnostics_fetch: CargoDiagnosticsFetchState {
task: None,
rust_analyzer: None,
fetch_task: None,
cancel_task: None,
diagnostic_sources: Arc::new(Vec::new()),
},
_subscription: project_event_subscription,
};
@@ -346,7 +343,7 @@ impl ProjectDiagnosticsEditor {
.fetch_cargo_diagnostics();
if fetch_cargo_diagnostics {
if self.cargo_diagnostics_fetch.task.is_some() {
if self.cargo_diagnostics_fetch.fetch_task.is_some() {
self.stop_cargo_diagnostics_fetch(cx);
} else {
self.update_all_diagnostics(window, cx);
@@ -375,300 +372,63 @@ impl ProjectDiagnosticsEditor {
}
fn update_all_diagnostics(&mut self, window: &mut Window, cx: &mut Context<Self>) {
let cargo_diagnostics_sources = cargo_diagnostics_sources(self, cx);
let cargo_diagnostics_sources = self.cargo_diagnostics_sources(cx);
if cargo_diagnostics_sources.is_empty() {
self.update_all_excerpts(window, cx);
} else {
self.fetch_cargo_diagnostics(Arc::new(cargo_diagnostics_sources), window, cx);
self.fetch_cargo_diagnostics(Arc::new(cargo_diagnostics_sources), cx);
}
}
fn fetch_cargo_diagnostics(
&mut self,
diagnostics_sources: Arc<Vec<Entity<Worktree>>>,
window: &mut Window,
diagnostics_sources: Arc<Vec<ProjectPath>>,
cx: &mut Context<Self>,
) {
self.cargo_diagnostics_fetch.task = Some(cx.spawn_in(window, async move |editor, cx| {
let rust_analyzer_server = editor
.update(cx, |editor, cx| {
editor
.project
.read(cx)
.language_server_with_name(RUST_ANALYZER_NAME, cx)
})
.ok();
let rust_analyzer_server = match rust_analyzer_server {
Some(rust_analyzer_server) => rust_analyzer_server.await,
None => None,
};
let project = self.project.clone();
self.cargo_diagnostics_fetch.cancel_task = None;
self.cargo_diagnostics_fetch.fetch_task = None;
self.cargo_diagnostics_fetch.diagnostic_sources = diagnostics_sources.clone();
if self.cargo_diagnostics_fetch.diagnostic_sources.is_empty() {
return;
}
let mut worktree_diagnostics_tasks = Vec::new();
let mut paths_with_reported_cargo_diagnostics = HashSet::default();
if let Some(rust_analyzer_server) = rust_analyzer_server {
let can_continue = editor
.update(cx, |editor, cx| {
editor.cargo_diagnostics_fetch.rust_analyzer = Some(rust_analyzer_server);
let status_inserted =
editor
.project
.read(cx)
.lsp_store()
.update(cx, |lsp_store, cx| {
if let Some(rust_analyzer_status) = lsp_store
.language_server_statuses
.get_mut(&rust_analyzer_server)
{
rust_analyzer_status
.progress_tokens
.insert(fetch_cargo_diagnostics_token());
paths_with_reported_cargo_diagnostics.extend(editor.diagnostics.iter().filter_map(|(buffer_id, diagnostics)| {
if diagnostics.iter().any(|d| d.diagnostic.source.as_deref() == Some(CARGO_DIAGNOSTICS_SOURCE_NAME)) {
Some(*buffer_id)
} else {
None
}
}).filter_map(|buffer_id| {
let buffer = lsp_store.buffer_store().read(cx).get(buffer_id)?;
let path = buffer.read(cx).file()?.as_local()?.abs_path(cx);
Some(url_from_abs_path(&path))
}));
true
} else {
false
}
});
if status_inserted {
editor.update_cargo_fetch_status(FetchStatus::Started, cx);
next_cargo_fetch_generation();
true
} else {
false
}
self.cargo_diagnostics_fetch.fetch_task = Some(cx.spawn(async move |editor, cx| {
let mut fetch_tasks = Vec::new();
for buffer_path in diagnostics_sources.iter().cloned() {
if cx
.update(|cx| {
fetch_tasks.push(run_flycheck(project.clone(), buffer_path, cx));
})
.unwrap_or(false);
if can_continue {
for worktree in diagnostics_sources.iter() {
if let Some(((_task, worktree_diagnostics), worktree_root)) = cx
.update(|_, cx| {
let worktree_root = worktree.read(cx).abs_path();
log::info!("Fetching cargo diagnostics for {worktree_root:?}");
fetch_worktree_diagnostics(&worktree_root, cx)
.zip(Some(worktree_root))
})
.ok()
.flatten()
{
let editor = editor.clone();
worktree_diagnostics_tasks.push(cx.spawn(async move |cx| {
let _task = _task;
let mut file_diagnostics = HashMap::default();
let mut diagnostics_total = 0;
let mut updated_urls = HashSet::default();
while let Ok(fetch_update) = worktree_diagnostics.recv().await {
match fetch_update {
FetchUpdate::Diagnostic(diagnostic) => {
for (url, diagnostic) in map_rust_diagnostic_to_lsp(
&worktree_root,
&diagnostic,
) {
let file_diagnostics = file_diagnostics
.entry(url)
.or_insert_with(Vec::<lsp::Diagnostic>::new);
let i = file_diagnostics
.binary_search_by(|probe| {
probe.range.start.cmp(&diagnostic.range.start)
.then(probe.range.end.cmp(&diagnostic.range.end))
.then(Ordering::Greater)
})
.unwrap_or_else(|i| i);
file_diagnostics.insert(i, diagnostic);
}
let file_changed = file_diagnostics.len() > 1;
if file_changed {
if editor
.update_in(cx, |editor, window, cx| {
editor
.project
.read(cx)
.lsp_store()
.update(cx, |lsp_store, cx| {
for (uri, mut diagnostics) in
file_diagnostics.drain()
{
diagnostics.dedup();
diagnostics_total += diagnostics.len();
updated_urls.insert(uri.clone());
lsp_store.merge_diagnostics(
rust_analyzer_server,
lsp::PublishDiagnosticsParams {
uri,
diagnostics,
version: None,
},
&[],
|diagnostic, _| {
!is_outdated_cargo_fetch_diagnostic(diagnostic)
},
cx,
)?;
}
anyhow::Ok(())
})?;
editor.update_all_excerpts(window, cx);
anyhow::Ok(())
})
.ok()
.transpose()
.ok()
.flatten()
.is_none()
{
break;
}
}
}
FetchUpdate::Progress(message) => {
if editor
.update(cx, |editor, cx| {
editor.update_cargo_fetch_status(
FetchStatus::Progress { message },
cx,
);
})
.is_err()
{
return updated_urls;
}
}
}
}
editor
.update_in(cx, |editor, window, cx| {
editor
.project
.read(cx)
.lsp_store()
.update(cx, |lsp_store, cx| {
for (uri, mut diagnostics) in
file_diagnostics.drain()
{
diagnostics.dedup();
diagnostics_total += diagnostics.len();
updated_urls.insert(uri.clone());
lsp_store.merge_diagnostics(
rust_analyzer_server,
lsp::PublishDiagnosticsParams {
uri,
diagnostics,
version: None,
},
&[],
|diagnostic, _| {
!is_outdated_cargo_fetch_diagnostic(diagnostic)
},
cx,
)?;
}
anyhow::Ok(())
})?;
editor.update_all_excerpts(window, cx);
anyhow::Ok(())
})
.ok();
log::info!("Fetched {diagnostics_total} cargo diagnostics for worktree {worktree_root:?}");
updated_urls
}));
}
}
} else {
log::info!(
"No rust-analyzer language server found, skipping diagnostics fetch"
);
.is_err()
{
break;
}
}
let updated_urls = futures::future::join_all(worktree_diagnostics_tasks).await.into_iter().flatten().collect();
if let Some(rust_analyzer_server) = rust_analyzer_server {
let _ = join_all(fetch_tasks).await;
editor
.update_in(cx, |editor, window, cx| {
editor
.project
.read(cx)
.lsp_store()
.update(cx, |lsp_store, cx| {
for uri_to_cleanup in paths_with_reported_cargo_diagnostics.difference(&updated_urls).cloned() {
lsp_store.merge_diagnostics(
rust_analyzer_server,
lsp::PublishDiagnosticsParams {
uri: uri_to_cleanup,
diagnostics: Vec::new(),
version: None,
},
&[],
|diagnostic, _| {
!is_outdated_cargo_fetch_diagnostic(diagnostic)
},
cx,
).ok();
}
});
editor.update_all_excerpts(window, cx);
editor.stop_cargo_diagnostics_fetch(cx);
cx.notify();
.update(cx, |editor, _| {
editor.cargo_diagnostics_fetch.fetch_task = None;
})
.ok();
}
}));
}
fn update_cargo_fetch_status(&self, status: FetchStatus, cx: &mut App) {
let Some(rust_analyzer) = self.cargo_diagnostics_fetch.rust_analyzer else {
return;
};
let work_done = match status {
FetchStatus::Started => lsp::WorkDoneProgress::Begin(lsp::WorkDoneProgressBegin {
title: "cargo".to_string(),
cancellable: None,
message: Some("Fetching cargo diagnostics".to_string()),
percentage: None,
}),
FetchStatus::Progress { message } => {
lsp::WorkDoneProgress::Report(lsp::WorkDoneProgressReport {
message: Some(message),
cancellable: None,
percentage: None,
})
}
FetchStatus::Finished => {
lsp::WorkDoneProgress::End(lsp::WorkDoneProgressEnd { message: None })
}
};
let progress = lsp::ProgressParams {
token: lsp::NumberOrString::String(fetch_cargo_diagnostics_token()),
value: lsp::ProgressParamsValue::WorkDone(work_done),
};
self.project
.read(cx)
.lsp_store()
.update(cx, |lsp_store, cx| {
lsp_store.on_lsp_progress(progress, rust_analyzer, None, cx)
});
}
fn stop_cargo_diagnostics_fetch(&mut self, cx: &mut App) {
self.update_cargo_fetch_status(FetchStatus::Finished, cx);
self.cargo_diagnostics_fetch.task = None;
log::info!("Finished fetching cargo diagnostics");
self.cargo_diagnostics_fetch.fetch_task = None;
let mut cancel_gasks = Vec::new();
for buffer_path in std::mem::take(&mut self.cargo_diagnostics_fetch.diagnostic_sources)
.iter()
.cloned()
{
cancel_gasks.push(cancel_flycheck(self.project.clone(), buffer_path, cx));
}
self.cargo_diagnostics_fetch.cancel_task = Some(cx.background_spawn(async move {
let _ = join_all(cancel_gasks).await;
log::info!("Finished fetching cargo diagnostics");
}));
}
/// Enqueue an update of all excerpts. Updates all paths that either
@@ -897,6 +657,30 @@ impl ProjectDiagnosticsEditor {
})
})
}
pub fn cargo_diagnostics_sources(&self, cx: &App) -> Vec<ProjectPath> {
let fetch_cargo_diagnostics = ProjectSettings::get_global(cx)
.diagnostics
.fetch_cargo_diagnostics();
if !fetch_cargo_diagnostics {
return Vec::new();
}
self.project
.read(cx)
.worktrees(cx)
.filter_map(|worktree| {
let _cargo_toml_entry = worktree.read(cx).entry_for_path("Cargo.toml")?;
let rust_file_entry = worktree.read(cx).entries(false, 0).find(|entry| {
entry
.path
.extension()
.and_then(|extension| extension.to_str())
== Some("rs")
})?;
self.project.read(cx).path_for_entry(rust_file_entry.id, cx)
})
.collect()
}
}
impl Focusable for ProjectDiagnosticsEditor {
@@ -1286,7 +1070,3 @@ fn is_line_blank_or_indented_less(
let line_indent = snapshot.line_indent_for_row(row);
line_indent.is_line_blank() || line_indent.len(tab_size) < indent_level
}
fn fetch_cargo_diagnostics_token() -> String {
"fetch_cargo_diagnostics".to_string()
}

View File

@@ -1,6 +1,5 @@
use std::sync::Arc;
use crate::cargo::cargo_diagnostics_sources;
use crate::{ProjectDiagnosticsEditor, ToggleDiagnosticsRefresh};
use gpui::{Context, Entity, EventEmitter, ParentElement, Render, WeakEntity, Window};
use ui::prelude::*;
@@ -16,11 +15,9 @@ impl Render for ToolbarControls {
let mut include_warnings = false;
let mut has_stale_excerpts = false;
let mut is_updating = false;
let cargo_diagnostics_sources = Arc::new(
self.diagnostics()
.map(|editor| cargo_diagnostics_sources(editor.read(cx), cx))
.unwrap_or_default(),
);
let cargo_diagnostics_sources = Arc::new(self.diagnostics().map_or(Vec::new(), |editor| {
editor.read(cx).cargo_diagnostics_sources(cx)
}));
let fetch_cargo_diagnostics = !cargo_diagnostics_sources.is_empty();
if let Some(editor) = self.diagnostics() {
@@ -28,7 +25,7 @@ impl Render for ToolbarControls {
include_warnings = diagnostics.include_warnings;
has_stale_excerpts = !diagnostics.paths_to_update.is_empty();
is_updating = if fetch_cargo_diagnostics {
diagnostics.cargo_diagnostics_fetch.task.is_some()
diagnostics.cargo_diagnostics_fetch.fetch_task.is_some()
} else {
diagnostics.update_excerpts_task.is_some()
|| diagnostics
@@ -93,7 +90,6 @@ impl Render for ToolbarControls {
if fetch_cargo_diagnostics {
diagnostics.fetch_cargo_diagnostics(
cargo_diagnostics_sources,
window,
cx,
);
} else {

View File

@@ -249,7 +249,9 @@ actions!(
ApplyDiffHunk,
Backspace,
Cancel,
CancelFlycheck,
CancelLanguageServerWork,
ClearFlycheck,
ConfirmRename,
ConfirmCompletionInsert,
ConfirmCompletionReplace,
@@ -308,6 +310,7 @@ actions!(
GoToImplementation,
GoToImplementationSplit,
GoToNextChange,
GoToParentModule,
GoToPreviousChange,
GoToPreviousDiagnostic,
GoToTypeDefinition,
@@ -371,6 +374,7 @@ actions!(
RevertFile,
ReloadFile,
Rewrap,
RunFlycheck,
ScrollCursorBottom,
ScrollCursorCenter,
ScrollCursorCenterTopBottom,

View File

@@ -6874,7 +6874,12 @@ impl Element for EditorElement {
// The max scroll position for the top of the window
let max_scroll_top = if matches!(
snapshot.mode,
EditorMode::AutoHeight { .. } | EditorMode::SingleLine { .. }
EditorMode::SingleLine { .. }
| EditorMode::AutoHeight { .. }
| EditorMode::Full {
sized_by_content: true,
..
}
) {
(max_row - height_in_lines + 1.).max(0.)
} else {

View File

@@ -4,15 +4,20 @@ use anyhow::Context as _;
use gpui::{App, AppContext as _, Context, Entity, Window};
use language::{Capability, Language, proto::serialize_anchor};
use multi_buffer::MultiBuffer;
use project::lsp_store::{
lsp_ext_command::{DocsUrls, ExpandMacro, ExpandedMacro},
rust_analyzer_ext::RUST_ANALYZER_NAME,
use project::{
ProjectItem,
lsp_command::location_link_from_proto,
lsp_store::{
lsp_ext_command::{DocsUrls, ExpandMacro, ExpandedMacro},
rust_analyzer_ext::{RUST_ANALYZER_NAME, cancel_flycheck, clear_flycheck, run_flycheck},
},
};
use rpc::proto;
use text::ToPointUtf16;
use crate::{
Editor, ExpandMacroRecursively, OpenDocs, element::register_action,
CancelFlycheck, ClearFlycheck, Editor, ExpandMacroRecursively, GoToParentModule,
GotoDefinitionKind, OpenDocs, RunFlycheck, element::register_action, hover_links::HoverLink,
lsp_ext::find_specific_language_server_in_selection,
};
@@ -30,11 +35,97 @@ pub fn apply_related_actions(editor: &Entity<Editor>, window: &mut Window, cx: &
.filter_map(|buffer| buffer.read(cx).language())
.any(|language| is_rust_language(language))
{
register_action(&editor, window, go_to_parent_module);
register_action(&editor, window, expand_macro_recursively);
register_action(&editor, window, open_docs);
register_action(&editor, window, cancel_flycheck_action);
register_action(&editor, window, run_flycheck_action);
register_action(&editor, window, clear_flycheck_action);
}
}
pub fn go_to_parent_module(
editor: &mut Editor,
_: &GoToParentModule,
window: &mut Window,
cx: &mut Context<Editor>,
) {
if editor.selections.count() == 0 {
return;
}
let Some(project) = &editor.project else {
return;
};
let server_lookup = find_specific_language_server_in_selection(
editor,
cx,
is_rust_language,
RUST_ANALYZER_NAME,
);
let project = project.clone();
let lsp_store = project.read(cx).lsp_store();
let upstream_client = lsp_store.read(cx).upstream_client();
cx.spawn_in(window, async move |editor, cx| {
let Some((trigger_anchor, _, server_to_query, buffer)) = server_lookup.await else {
return anyhow::Ok(());
};
let location_links = if let Some((client, project_id)) = upstream_client {
let buffer_id = buffer.update(cx, |buffer, _| buffer.remote_id())?;
let request = proto::LspExtGoToParentModule {
project_id,
buffer_id: buffer_id.to_proto(),
position: Some(serialize_anchor(&trigger_anchor.text_anchor)),
};
let response = client
.request(request)
.await
.context("lsp ext go to parent module proto request")?;
futures::future::join_all(
response
.links
.into_iter()
.map(|link| location_link_from_proto(link, lsp_store.clone(), cx)),
)
.await
.into_iter()
.collect::<anyhow::Result<_>>()
.context("go to parent module via collab")?
} else {
let buffer_snapshot = buffer.update(cx, |buffer, _| buffer.snapshot())?;
let position = trigger_anchor.text_anchor.to_point_utf16(&buffer_snapshot);
project
.update(cx, |project, cx| {
project.request_lsp(
buffer,
project::LanguageServerToQuery::Other(server_to_query),
project::lsp_store::lsp_ext_command::GoToParentModule { position },
cx,
)
})?
.await
.context("go to parent module")?
};
editor
.update_in(cx, |editor, window, cx| {
editor.navigate_to_hover_links(
Some(GotoDefinitionKind::Declaration),
location_links.into_iter().map(HoverLink::Text).collect(),
false,
window,
cx,
)
})?
.await?;
Ok(())
})
.detach_and_log_err(cx);
}
pub fn expand_macro_recursively(
editor: &mut Editor,
_: &ExpandMacroRecursively,
@@ -213,3 +304,87 @@ pub fn open_docs(editor: &mut Editor, _: &OpenDocs, window: &mut Window, cx: &mu
})
.detach_and_log_err(cx);
}
fn cancel_flycheck_action(
editor: &mut Editor,
_: &CancelFlycheck,
_: &mut Window,
cx: &mut Context<Editor>,
) {
let Some(project) = &editor.project else {
return;
};
let Some(buffer_id) = editor
.selections
.disjoint_anchors()
.iter()
.find_map(|selection| {
let buffer_id = selection.start.buffer_id.or(selection.end.buffer_id)?;
let project = project.read(cx);
let entry_id = project
.buffer_for_id(buffer_id, cx)?
.read(cx)
.entry_id(cx)?;
project.path_for_entry(entry_id, cx)
})
else {
return;
};
cancel_flycheck(project.clone(), buffer_id, cx).detach_and_log_err(cx);
}
fn run_flycheck_action(
editor: &mut Editor,
_: &RunFlycheck,
_: &mut Window,
cx: &mut Context<Editor>,
) {
let Some(project) = &editor.project else {
return;
};
let Some(buffer_id) = editor
.selections
.disjoint_anchors()
.iter()
.find_map(|selection| {
let buffer_id = selection.start.buffer_id.or(selection.end.buffer_id)?;
let project = project.read(cx);
let entry_id = project
.buffer_for_id(buffer_id, cx)?
.read(cx)
.entry_id(cx)?;
project.path_for_entry(entry_id, cx)
})
else {
return;
};
run_flycheck(project.clone(), buffer_id, cx).detach_and_log_err(cx);
}
fn clear_flycheck_action(
editor: &mut Editor,
_: &ClearFlycheck,
_: &mut Window,
cx: &mut Context<Editor>,
) {
let Some(project) = &editor.project else {
return;
};
let Some(buffer_id) = editor
.selections
.disjoint_anchors()
.iter()
.find_map(|selection| {
let buffer_id = selection.start.buffer_id.or(selection.end.buffer_id)?;
let project = project.read(cx);
let entry_id = project
.buffer_for_id(buffer_id, cx)?
.read(cx)
.entry_id(cx)?;
project.path_for_entry(entry_id, cx)
})
else {
return;
};
clear_flycheck(project.clone(), buffer_id, cx).detach_and_log_err(cx);
}

View File

@@ -48,6 +48,7 @@ markdown.workspace = true
node_runtime.workspace = true
pathdiff.workspace = true
paths.workspace = true
pretty_assertions.workspace = true
project.workspace = true
prompt_store.workspace = true
regex.workspace = true

View File

@@ -1,6 +1,7 @@
{
"assistant": {
"always_allow_tool_actions": true,
"stream_edits": true,
"version": "2"
}
}

View File

@@ -420,12 +420,18 @@ pub fn init(cx: &mut App) -> Arc<AgentAppState> {
language_model::init(client.clone(), cx);
language_models::init(user_store.clone(), client.clone(), fs.clone(), cx);
languages::init(languages.clone(), node_runtime.clone(), cx);
assistant_tools::init(client.http_client(), cx);
context_server::init(cx);
prompt_store::init(cx);
let stdout_is_a_pty = false;
let prompt_builder = PromptBuilder::load(fs.clone(), stdout_is_a_pty, cx);
agent::init(fs.clone(), client.clone(), prompt_builder.clone(), cx);
agent::init(
fs.clone(),
client.clone(),
prompt_builder.clone(),
languages.clone(),
cx,
);
assistant_tools::init(client.http_client(), cx);
SettingsStore::update_global(cx, |store, cx| {
store.set_user_settings(include_str!("../runner_settings.json"), cx)

View File

@@ -160,7 +160,11 @@ impl ExampleContext {
if left == right {
Ok(())
} else {
println!("{}{:#?} != {:#?}", self.log_prefix, left, right);
println!(
"{}{}",
self.log_prefix,
pretty_assertions::Comparison::new(&left, &right)
);
Err(anyhow::Error::from(FailedAssertion(message.clone())))
},
message,
@@ -334,8 +338,8 @@ impl ExampleContext {
}
pub fn edits(&self) -> HashMap<Arc<Path>, FileEdits> {
self.app
.read_entity(&self.agent_thread, |thread, cx| {
self.agent_thread
.read_with(&self.app, |thread, cx| {
let action_log = thread.action_log().read(cx);
HashMap::from_iter(action_log.changed_buffers(cx).into_iter().map(
|(buffer, diff)| {
@@ -503,16 +507,16 @@ impl ToolUse {
}
}
#[derive(Debug)]
#[derive(Debug, Eq, PartialEq)]
pub struct FileEdits {
hunks: Vec<FileEditHunk>,
pub hunks: Vec<FileEditHunk>,
}
#[derive(Debug)]
struct FileEditHunk {
base_text: String,
text: String,
status: DiffHunkStatus,
#[derive(Debug, Eq, PartialEq)]
pub struct FileEditHunk {
pub base_text: String,
pub text: String,
pub status: DiffHunkStatus,
}
impl FileEdits {

View File

@@ -121,6 +121,12 @@ pub trait Extension: Send + Sync + 'static {
project: Arc<dyn ProjectDelegate>,
) -> Result<Command>;
async fn context_server_configuration(
&self,
context_server_id: Arc<str>,
project: Arc<dyn ProjectDelegate>,
) -> Result<Option<ContextServerConfiguration>>;
async fn suggest_docs_packages(&self, provider: Arc<str>) -> Result<Vec<String>>;
async fn index_docs(

View File

@@ -1,5 +1,9 @@
use std::sync::Arc;
use gpui::{App, AppContext as _, Context, Entity, EventEmitter, Global};
use crate::ExtensionManifest;
pub fn init(cx: &mut App) {
let extension_events = cx.new(ExtensionEvents::new);
cx.set_global(GlobalExtensionEvents(extension_events));
@@ -31,7 +35,9 @@ impl ExtensionEvents {
#[derive(Clone)]
pub enum Event {
ExtensionInstalled(Arc<ExtensionManifest>),
ExtensionsInstalledChanged,
ConfigureExtensionRequested(Arc<ExtensionManifest>),
}
impl EventEmitter<Event> for ExtensionEvents {}

View File

@@ -1,8 +1,10 @@
mod context_server;
mod lsp;
mod slash_command;
use std::ops::Range;
pub use context_server::*;
pub use lsp::*;
pub use slash_command::*;

View File

@@ -0,0 +1,10 @@
/// Configuration for a context server.
#[derive(Debug, Clone)]
pub struct ContextServerConfiguration {
/// Installation instructions for the user.
pub installation_instructions: String,
/// Default settings for the context server.
pub default_settings: String,
/// JSON schema describing server settings.
pub settings_schema: serde_json::Value,
}

View File

@@ -18,6 +18,7 @@ pub use wit::{
CodeLabel, CodeLabelSpan, CodeLabelSpanLiteral, Command, DownloadedFileType, EnvVars,
KeyValueStore, LanguageServerInstallationStatus, Project, Range, Worktree, download_file,
make_file_executable,
zed::extension::context_server::ContextServerConfiguration,
zed::extension::github::{
GithubRelease, GithubReleaseAsset, GithubReleaseOptions, github_release_by_tag_name,
latest_github_release,
@@ -159,6 +160,15 @@ pub trait Extension: Send + Sync {
Err("`context_server_command` not implemented".to_string())
}
/// Returns the configuration options for the specified context server.
fn context_server_configuration(
&mut self,
_context_server_id: &ContextServerId,
_project: &Project,
) -> Result<Option<ContextServerConfiguration>> {
Ok(None)
}
/// Returns a list of package names as suggestions to be included in the
/// search results of the `/docs` slash command.
///
@@ -342,6 +352,14 @@ impl wit::Guest for Component {
extension().context_server_command(&context_server_id, project)
}
fn context_server_configuration(
context_server_id: String,
project: &Project,
) -> Result<Option<ContextServerConfiguration>, String> {
let context_server_id = ContextServerId(context_server_id);
extension().context_server_configuration(&context_server_id, project)
}
fn suggest_docs_packages(provider: String) -> Result<Vec<String>, String> {
extension().suggest_docs_packages(provider)
}

View File

@@ -0,0 +1,11 @@
interface context-server {
///
record context-server-configuration {
///
installation-instructions: string,
///
settings-schema: string,
///
default-settings: string,
}
}

View File

@@ -1,6 +1,7 @@
package zed:extension;
world extension {
import context-server;
import github;
import http-client;
import platform;
@@ -8,6 +9,7 @@ world extension {
import nodejs;
use common.{env-vars, range};
use context-server.{context-server-configuration};
use lsp.{completion, symbol};
use process.{command};
use slash-command.{slash-command, slash-command-argument-completion, slash-command-output};
@@ -139,6 +141,9 @@ world extension {
/// Returns the command used to start up a context server.
export context-server-command: func(context-server-id: string, project: borrow<project>) -> result<command, string>;
/// Returns the configuration for a context server.
export context-server-configuration: func(context-server-id: string, project: borrow<project>) -> result<option<context-server-configuration>, string>;
/// Returns a list of packages as suggestions to be included in the `/docs`
/// search results.
///

View File

@@ -431,6 +431,13 @@ impl ExtensionStore {
.filter_map(|extension| extension.dev.then_some(&extension.manifest))
}
pub fn extension_manifest_for_id(&self, extension_id: &str) -> Option<&Arc<ExtensionManifest>> {
self.extension_index
.extensions
.get(extension_id)
.map(|extension| &extension.manifest)
}
/// Returns the names of themes provided by extensions.
pub fn extension_themes<'a>(
&'a self,
@@ -744,8 +751,18 @@ impl ExtensionStore {
.await;
if let ExtensionOperation::Install = operation {
this.update( cx, |_, cx| {
cx.emit(Event::ExtensionInstalled(extension_id));
this.update( cx, |this, cx| {
cx.emit(Event::ExtensionInstalled(extension_id.clone()));
if let Some(events) = ExtensionEvents::try_global(cx) {
if let Some(manifest) = this.extension_manifest_for_id(&extension_id) {
events.update(cx, |this, cx| {
this.emit(
extension::Event::ExtensionInstalled(manifest.clone()),
cx,
)
});
}
}
})
.ok();
}
@@ -935,6 +952,17 @@ impl ExtensionStore {
.await?;
this.update(cx, |this, cx| this.reload(None, cx))?.await;
this.update(cx, |this, cx| {
cx.emit(Event::ExtensionInstalled(extension_id.clone()));
if let Some(events) = ExtensionEvents::try_global(cx) {
if let Some(manifest) = this.extension_manifest_for_id(&extension_id) {
events.update(cx, |this, cx| {
this.emit(extension::Event::ExtensionInstalled(manifest.clone()), cx)
});
}
}
})?;
Ok(())
})
}

View File

@@ -4,8 +4,9 @@ use crate::ExtensionManifest;
use anyhow::{Context as _, Result, anyhow, bail};
use async_trait::async_trait;
use extension::{
CodeLabel, Command, Completion, ExtensionHostProxy, KeyValueStoreDelegate, ProjectDelegate,
SlashCommand, SlashCommandArgumentCompletion, SlashCommandOutput, Symbol, WorktreeDelegate,
CodeLabel, Command, Completion, ContextServerConfiguration, ExtensionHostProxy,
KeyValueStoreDelegate, ProjectDelegate, SlashCommand, SlashCommandArgumentCompletion,
SlashCommandOutput, Symbol, WorktreeDelegate,
};
use fs::{Fs, normalize_path};
use futures::future::LocalBoxFuture;
@@ -306,6 +307,33 @@ impl extension::Extension for WasmExtension {
.await
}
async fn context_server_configuration(
&self,
context_server_id: Arc<str>,
project: Arc<dyn ProjectDelegate>,
) -> Result<Option<ContextServerConfiguration>> {
self.call(|extension, store| {
async move {
let project_resource = store.data_mut().table().push(project)?;
let Some(configuration) = extension
.call_context_server_configuration(
store,
context_server_id.clone(),
project_resource,
)
.await?
.map_err(|err| anyhow!("{err}"))?
else {
return Ok(None);
};
Ok(Some(configuration.try_into()?))
}
.boxed()
})
.await
}
async fn suggest_docs_packages(&self, provider: Arc<str>) -> Result<Vec<String>> {
self.call(|extension, store| {
async move {

View File

@@ -25,6 +25,7 @@ use wasmtime::{
pub use latest::CodeLabelSpanLiteral;
pub use latest::{
CodeLabel, CodeLabelSpan, Command, ExtensionProject, Range, SlashCommand,
zed::extension::context_server::ContextServerConfiguration,
zed::extension::lsp::{
Completion, CompletionKind, CompletionLabelDetails, InsertTextFormat, Symbol, SymbolKind,
},
@@ -726,6 +727,29 @@ impl Extension {
}
}
pub async fn call_context_server_configuration(
&self,
store: &mut Store<WasmState>,
context_server_id: Arc<str>,
project: Resource<ExtensionProject>,
) -> Result<Result<Option<ContextServerConfiguration>, String>> {
match self {
Extension::V0_5_0(ext) => {
ext.call_context_server_configuration(store, &context_server_id, project)
.await
}
Extension::V0_0_1(_)
| Extension::V0_0_4(_)
| Extension::V0_0_6(_)
| Extension::V0_1_0(_)
| Extension::V0_2_0(_)
| Extension::V0_3_0(_)
| Extension::V0_4_0(_) => Err(anyhow!(
"`context_server_configuration` not available prior to v0.5.0"
)),
}
}
pub async fn call_suggest_docs_packages(
&self,
store: &mut Store<WasmState>,

View File

@@ -247,6 +247,21 @@ impl From<SlashCommandArgumentCompletion> for extension::SlashCommandArgumentCom
}
}
impl TryFrom<ContextServerConfiguration> for extension::ContextServerConfiguration {
type Error = anyhow::Error;
fn try_from(value: ContextServerConfiguration) -> Result<Self, Self::Error> {
let settings_schema: serde_json::Value = serde_json::from_str(&value.settings_schema)
.context("Failed to parse settings_schema")?;
Ok(Self {
installation_instructions: value.installation_instructions,
default_settings: value.default_settings,
settings_schema,
})
}
}
impl HostKeyValueStore for WasmState {
async fn insert(
&mut self,
@@ -610,6 +625,9 @@ impl process::Host for WasmState {
#[async_trait]
impl slash_command::Host for WasmState {}
#[async_trait]
impl context_server::Host for WasmState {}
impl ExtensionImports for WasmState {
async fn get_settings(
&mut self,

View File

@@ -17,6 +17,7 @@ client.workspace = true
collections.workspace = true
db.workspace = true
editor.workspace = true
extension.workspace = true
extension_host.workspace = true
fs.workspace = true
fuzzy.workspace = true

View File

@@ -246,6 +246,12 @@ fn keywords_by_feature() -> &'static BTreeMap<Feature, Vec<&'static str>> {
})
}
struct ExtensionCardButtons {
install_or_uninstall: Button,
upgrade: Option<Button>,
configure: Option<Button>,
}
pub struct ExtensionsPage {
workspace: WeakEntity<Workspace>,
list: UniformListScrollHandle,
@@ -522,6 +528,8 @@ impl ExtensionsPage {
let repository_url = extension.repository.clone();
let can_configure = !extension.context_servers.is_empty();
ExtensionCard::new()
.child(
h_flex()
@@ -568,7 +576,36 @@ impl ExtensionsPage {
})
.color(Color::Accent)
.disabled(matches!(status, ExtensionStatus::Removing)),
),
)
.when(can_configure, |this| {
this.child(
Button::new(
SharedString::from(format!("configure-{}", extension.id)),
"Configure",
)
.on_click({
let manifest = Arc::new(extension.clone());
move |_, _, cx| {
if let Some(events) =
extension::ExtensionEvents::try_global(cx)
{
events.update(cx, |this, cx| {
this.emit(
extension::Event::ConfigureExtensionRequested(
manifest.clone(),
),
cx,
)
});
}
}
})
.color(Color::Accent)
.disabled(matches!(status, ExtensionStatus::Installing)),
)
}),
),
)
.child(
@@ -629,8 +666,7 @@ impl ExtensionsPage {
let has_dev_extension = Self::dev_extension_exists(&extension.id, cx);
let extension_id = extension.id.clone();
let (install_or_uninstall_button, upgrade_button) =
self.buttons_for_entry(extension, &status, has_dev_extension, cx);
let buttons = self.buttons_for_entry(extension, &status, has_dev_extension, cx);
let version = extension.manifest.version.clone();
let repository_url = extension.manifest.repository.clone();
let authors = extension.manifest.authors.clone();
@@ -695,8 +731,9 @@ impl ExtensionsPage {
h_flex()
.gap_2()
.justify_between()
.children(upgrade_button)
.child(install_or_uninstall_button),
.children(buttons.upgrade)
.children(buttons.configure)
.child(buttons.install_or_uninstall),
),
)
.child(
@@ -861,22 +898,35 @@ impl ExtensionsPage {
status: &ExtensionStatus,
has_dev_extension: bool,
cx: &mut Context<Self>,
) -> (Button, Option<Button>) {
) -> ExtensionCardButtons {
let is_compatible =
extension_host::is_version_compatible(ReleaseChannel::global(cx), extension);
if has_dev_extension {
// If we have a dev extension for the given extension, just treat it as uninstalled.
// The button here is a placeholder, as it won't be interactable anyways.
return (
Button::new(SharedString::from(extension.id.clone()), "Install"),
None,
);
return ExtensionCardButtons {
install_or_uninstall: Button::new(
SharedString::from(extension.id.clone()),
"Install",
),
configure: None,
upgrade: None,
};
}
let is_configurable = extension
.manifest
.provides
.contains(&ExtensionProvides::ContextServers);
match status.clone() {
ExtensionStatus::NotInstalled => (
Button::new(SharedString::from(extension.id.clone()), "Install").on_click({
ExtensionStatus::NotInstalled => ExtensionCardButtons {
install_or_uninstall: Button::new(
SharedString::from(extension.id.clone()),
"Install",
)
.on_click({
let extension_id = extension.id.clone();
move |_, _, cx| {
telemetry::event!("Extension Installed");
@@ -885,20 +935,41 @@ impl ExtensionsPage {
});
}
}),
None,
),
ExtensionStatus::Installing => (
Button::new(SharedString::from(extension.id.clone()), "Install").disabled(true),
None,
),
ExtensionStatus::Upgrading => (
Button::new(SharedString::from(extension.id.clone()), "Uninstall").disabled(true),
Some(
configure: None,
upgrade: None,
},
ExtensionStatus::Installing => ExtensionCardButtons {
install_or_uninstall: Button::new(
SharedString::from(extension.id.clone()),
"Install",
)
.disabled(true),
configure: None,
upgrade: None,
},
ExtensionStatus::Upgrading => ExtensionCardButtons {
install_or_uninstall: Button::new(
SharedString::from(extension.id.clone()),
"Uninstall",
)
.disabled(true),
configure: is_configurable.then(|| {
Button::new(
SharedString::from(format!("configure-{}", extension.id.clone())),
"Configure",
)
.disabled(true)
}),
upgrade: Some(
Button::new(SharedString::from(extension.id.clone()), "Upgrade").disabled(true),
),
),
ExtensionStatus::Installed(installed_version) => (
Button::new(SharedString::from(extension.id.clone()), "Uninstall").on_click({
},
ExtensionStatus::Installed(installed_version) => ExtensionCardButtons {
install_or_uninstall: Button::new(
SharedString::from(extension.id.clone()),
"Uninstall",
)
.on_click({
let extension_id = extension.id.clone();
move |_, _, cx| {
telemetry::event!("Extension Uninstalled", extension_id);
@@ -907,7 +978,32 @@ impl ExtensionsPage {
});
}
}),
if installed_version == extension.manifest.version {
configure: is_configurable.then(|| {
Button::new(
SharedString::from(format!("configure-{}", extension.id.clone())),
"Configure",
)
.on_click({
let extension_id = extension.id.clone();
move |_, _, cx| {
if let Some(manifest) = ExtensionStore::global(cx)
.read(cx)
.extension_manifest_for_id(&extension_id)
.cloned()
{
if let Some(events) = extension::ExtensionEvents::try_global(cx) {
events.update(cx, |this, cx| {
this.emit(
extension::Event::ConfigureExtensionRequested(manifest),
cx,
)
});
}
}
}
})
}),
upgrade: if installed_version == extension.manifest.version {
None
} else {
Some(
@@ -944,11 +1040,22 @@ impl ExtensionsPage {
}),
)
},
),
ExtensionStatus::Removing => (
Button::new(SharedString::from(extension.id.clone()), "Uninstall").disabled(true),
None,
),
},
ExtensionStatus::Removing => ExtensionCardButtons {
install_or_uninstall: Button::new(
SharedString::from(extension.id.clone()),
"Uninstall",
)
.disabled(true),
configure: is_configurable.then(|| {
Button::new(
SharedString::from(format!("configure-{}", extension.id.clone())),
"Configure",
)
.disabled(true)
}),
upgrade: None,
},
}
}

View File

@@ -59,6 +59,12 @@ impl FeatureFlag for Assistant2FeatureFlag {
const NAME: &'static str = "assistant2";
}
pub struct AgentStreamEditsFeatureFlag;
impl FeatureFlag for AgentStreamEditsFeatureFlag {
const NAME: &'static str = "agent-stream-edits";
}
pub struct NewBillingFeatureFlag;
impl FeatureFlag for NewBillingFeatureFlag {

View File

@@ -322,7 +322,7 @@ impl GitRepository for FakeGitRepository {
.iter()
.map(|branch_name| Branch {
is_head: Some(branch_name) == current_branch.as_ref(),
name: branch_name.into(),
ref_name: branch_name.into(),
most_recent_commit: None,
upstream: None,
})

View File

@@ -37,12 +37,24 @@ pub const REMOTE_CANCELLED_BY_USER: &str = "Operation cancelled by user";
#[derive(Clone, Debug, Hash, PartialEq, Eq)]
pub struct Branch {
pub is_head: bool,
pub name: SharedString,
pub ref_name: SharedString,
pub upstream: Option<Upstream>,
pub most_recent_commit: Option<CommitSummary>,
}
impl Branch {
pub fn name(&self) -> &str {
self.ref_name
.as_ref()
.strip_prefix("refs/heads/")
.or_else(|| self.ref_name.as_ref().strip_prefix("refs/remotes/"))
.unwrap_or(self.ref_name.as_ref())
}
pub fn is_remote(&self) -> bool {
self.ref_name.starts_with("refs/remotes/")
}
pub fn tracking_status(&self) -> Option<UpstreamTrackingStatus> {
self.upstream
.as_ref()
@@ -71,6 +83,10 @@ impl Upstream {
.strip_prefix("refs/remotes/")
.and_then(|stripped| stripped.split("/").next())
}
pub fn stripped_ref_name(&self) -> Option<&str> {
self.ref_name.strip_prefix("refs/remotes/")
}
}
#[derive(Clone, Copy, Default)]
@@ -803,68 +819,69 @@ impl GitRepository for RealGitRepository {
fn branches(&self) -> BoxFuture<Result<Vec<Branch>>> {
let working_directory = self.working_directory();
let git_binary_path = self.git_binary_path.clone();
async move {
let fields = [
"%(HEAD)",
"%(objectname)",
"%(parent)",
"%(refname)",
"%(upstream)",
"%(upstream:track)",
"%(committerdate:unix)",
"%(contents:subject)",
]
.join("%00");
let args = vec![
"for-each-ref",
"refs/heads/**/*",
"refs/remotes/**/*",
"--format",
&fields,
];
let working_directory = working_directory?;
let output = new_smol_command(&git_binary_path)
.current_dir(&working_directory)
.args(args)
.output()
.await?;
if !output.status.success() {
return Err(anyhow!(
"Failed to git git branches:\n{}",
String::from_utf8_lossy(&output.stderr)
));
}
let input = String::from_utf8_lossy(&output.stdout);
let mut branches = parse_branch_input(&input)?;
if branches.is_empty() {
let args = vec!["symbolic-ref", "--quiet", "--short", "HEAD"];
self.executor
.spawn(async move {
let fields = [
"%(HEAD)",
"%(objectname)",
"%(parent)",
"%(refname)",
"%(upstream)",
"%(upstream:track)",
"%(committerdate:unix)",
"%(contents:subject)",
]
.join("%00");
let args = vec![
"for-each-ref",
"refs/heads/**/*",
"refs/remotes/**/*",
"--format",
&fields,
];
let working_directory = working_directory?;
let output = new_smol_command(&git_binary_path)
.current_dir(&working_directory)
.args(args)
.output()
.await?;
// git symbolic-ref returns a non-0 exit code if HEAD points
// to something other than a branch
if output.status.success() {
let name = String::from_utf8_lossy(&output.stdout).trim().to_string();
branches.push(Branch {
name: name.into(),
is_head: true,
upstream: None,
most_recent_commit: None,
});
if !output.status.success() {
return Err(anyhow!(
"Failed to git git branches:\n{}",
String::from_utf8_lossy(&output.stderr)
));
}
}
Ok(branches)
}
.boxed()
let input = String::from_utf8_lossy(&output.stdout);
let mut branches = parse_branch_input(&input)?;
if branches.is_empty() {
let args = vec!["symbolic-ref", "--quiet", "HEAD"];
let output = new_smol_command(&git_binary_path)
.current_dir(&working_directory)
.args(args)
.output()
.await?;
// git symbolic-ref returns a non-0 exit code if HEAD points
// to something other than a branch
if output.status.success() {
let name = String::from_utf8_lossy(&output.stdout).trim().to_string();
branches.push(Branch {
ref_name: name.into(),
is_head: true,
upstream: None,
most_recent_commit: None,
});
}
}
Ok(branches)
})
.boxed()
}
fn change_branch(&self, name: String) -> BoxFuture<Result<()>> {
@@ -1691,15 +1708,7 @@ fn parse_branch_input(input: &str) -> Result<Vec<Branch>> {
let is_current_branch = fields.next().context("no HEAD")? == "*";
let head_sha: SharedString = fields.next().context("no objectname")?.to_string().into();
let parent_sha: SharedString = fields.next().context("no parent")?.to_string().into();
let raw_ref_name = fields.next().context("no refname")?;
let ref_name: SharedString =
if let Some(ref_name) = raw_ref_name.strip_prefix("refs/heads/") {
ref_name.to_string().into()
} else if let Some(ref_name) = raw_ref_name.strip_prefix("refs/remotes/") {
ref_name.to_string().into()
} else {
return Err(anyhow!("unexpected format for refname"));
};
let ref_name = fields.next().context("no refname")?.to_string().into();
let upstream_name = fields.next().context("no upstream")?.to_string();
let upstream_tracking = parse_upstream_track(fields.next().context("no upstream:track")?)?;
let commiterdate = fields.next().context("no committerdate")?.parse::<i64>()?;
@@ -1711,7 +1720,7 @@ fn parse_branch_input(input: &str) -> Result<Vec<Branch>> {
branches.push(Branch {
is_head: is_current_branch,
name: ref_name,
ref_name: ref_name,
most_recent_commit: Some(CommitSummary {
sha: head_sha,
subject,
@@ -1974,7 +1983,7 @@ mod tests {
parse_branch_input(&input).unwrap(),
vec![Branch {
is_head: true,
name: "zed-patches".into(),
ref_name: "refs/heads/zed-patches".into(),
upstream: Some(Upstream {
ref_name: "refs/remotes/origin/zed-patches".into(),
tracking: UpstreamTracking::Tracked(UpstreamTrackingStatus {

View File

@@ -1,6 +1,7 @@
use anyhow::{Context as _, anyhow};
use fuzzy::StringMatchCandidate;
use collections::HashSet;
use git::repository::Branch;
use gpui::{
App, Context, DismissEvent, Entity, EventEmitter, FocusHandle, Focusable, InteractiveElement,
@@ -95,12 +96,28 @@ impl BranchList {
.context("No active repository")?
.await??;
all_branches.sort_by_key(|branch| {
branch
.most_recent_commit
.as_ref()
.map(|commit| 0 - commit.commit_timestamp)
});
let all_branches = cx
.background_spawn(async move {
let upstreams: HashSet<_> = all_branches
.iter()
.filter_map(|branch| {
let upstream = branch.upstream.as_ref()?;
Some(upstream.ref_name.clone())
})
.collect();
all_branches.retain(|branch| !upstreams.contains(&branch.ref_name));
all_branches.sort_by_key(|branch| {
branch
.most_recent_commit
.as_ref()
.map(|commit| 0 - commit.commit_timestamp)
});
all_branches
})
.await;
this.update_in(cx, |this, window, cx| {
this.picker.update(cx, |picker, cx| {
@@ -266,6 +283,7 @@ impl PickerDelegate for BranchListDelegate {
let mut matches: Vec<BranchEntry> = if query.is_empty() {
all_branches
.into_iter()
.filter(|branch| !branch.is_remote())
.take(RECENT_BRANCHES_COUNT)
.map(|branch| BranchEntry {
branch,
@@ -277,7 +295,7 @@ impl PickerDelegate for BranchListDelegate {
let candidates = all_branches
.iter()
.enumerate()
.map(|(ix, command)| StringMatchCandidate::new(ix, &command.name.clone()))
.map(|(ix, branch)| StringMatchCandidate::new(ix, branch.name()))
.collect::<Vec<StringMatchCandidate>>();
fuzzy::match_strings(
&candidates,
@@ -303,11 +321,11 @@ impl PickerDelegate for BranchListDelegate {
if !query.is_empty()
&& !matches
.first()
.is_some_and(|entry| entry.branch.name == query)
.is_some_and(|entry| entry.branch.name() == query)
{
matches.push(BranchEntry {
branch: Branch {
name: query.clone().into(),
ref_name: format!("refs/heads/{query}").into(),
is_head: false,
upstream: None,
most_recent_commit: None,
@@ -335,19 +353,19 @@ impl PickerDelegate for BranchListDelegate {
return;
};
if entry.is_new {
self.create_branch(entry.branch.name.clone(), window, cx);
self.create_branch(entry.branch.name().to_owned().into(), window, cx);
return;
}
let current_branch = self.repo.as_ref().map(|repo| {
repo.update(cx, |repo, _| {
repo.branch.as_ref().map(|branch| branch.name.clone())
repo.branch.as_ref().map(|branch| branch.ref_name.clone())
})
});
if current_branch
.flatten()
.is_some_and(|current_branch| current_branch == entry.branch.name)
.is_some_and(|current_branch| current_branch == entry.branch.ref_name)
{
cx.emit(DismissEvent);
return;
@@ -368,7 +386,7 @@ impl PickerDelegate for BranchListDelegate {
anyhow::Ok(async move {
repo.update(&mut cx, |repo, _| {
repo.change_branch(branch.name.to_string())
repo.change_branch(branch.name().to_string())
})?
.await?
})
@@ -443,13 +461,13 @@ impl PickerDelegate for BranchListDelegate {
if entry.is_new {
Label::new(format!(
"Create branch \"{}\"",
entry.branch.name
entry.branch.name()
))
.single_line()
.into_any_element()
} else {
HighlightedLabel::new(
entry.branch.name.clone(),
entry.branch.name().to_owned(),
entry.positions.clone(),
)
.truncate()
@@ -470,7 +488,7 @@ impl PickerDelegate for BranchListDelegate {
let message = if entry.is_new {
if let Some(current_branch) =
self.repo.as_ref().and_then(|repo| {
repo.read(cx).branch.as_ref().map(|b| b.name.clone())
repo.read(cx).branch.as_ref().map(|b| b.name())
})
{
format!("based off {}", current_branch)

Some files were not shown because too many files have changed in this diff Show More