Compare commits

...

52 Commits

Author SHA1 Message Date
Antonio Scandurra
7ab62666c3 Emit agent's position when streaming edits
Co-authored-by: Ben Brandt <benjamin.j.brandt@gmail.com>
Co-authored-by: Oleksiy Syvokon <oleksiy.syvokon@gmail.com>
2025-05-02 12:11:29 +02:00
Antonio Scandurra
35539847a4 Allow StreamingEditFileTool to also create files (#29785)
Refs #29733 

This pull request introduces a new field to the `StreamingEditFileTool`
that lets the model create or overwrite a file in a streaming way. When
one of the `assistant.stream_edits` setting / `agent-stream-edits`
feature flag is enabled, we are going to disable the `CreateFileTool` so
that the agent model can only use `StreamingEditFileTool` for file
creation.

Release Notes:

- N/A

---------

Co-authored-by: Ben Brandt <benjamin.j.brandt@gmail.com>
Co-authored-by: Oleksiy Syvokon <oleksiy.syvokon@gmail.com>
2025-05-02 09:57:04 +00:00
Anthony Eid
f619d5f02a debugger: Add debug task picker to new session modal (#29702)
## Preview 

![image](https://github.com/user-attachments/assets/203a577f-3b38-4017-9571-de1234415162)


### TODO
- [x] Add scenario picker to new session modal
- [x] Make debugger start action open new session modal instead of task
modal
- [x] Fix `esc` not clearing the cancelling the new session modal while
it's in scenario or attach mode
- [x] Resolve debug scenario's correctly

Release Notes:

- N/A
2025-05-02 08:38:29 +00:00
Kirill Bulatov
ba59305510 Use rust-analyzer's flycheck as source of cargo diagnostics (#29779)
Follow-up of https://github.com/zed-industries/zed/pull/29706

Instead of doing `cargo check` manually, use rust-analyzer's flycheck:
at the cost of more sophisticated check command configuration, we keep
much less code in Zed, and get a proper progress report.

User-facing UI does not change except `diagnostics_fetch_command` and
`env` settings removed from the diagnostics settings.

Release Notes:

- N/A
2025-05-02 10:07:51 +03:00
Nate Butler
672a1dd553 Add Agent Preview trait (#29760)
Like the title says

Release Notes:

- N/A
2025-05-01 23:03:06 -04:00
Marshall Bowers
93cc4946d8 agent: Make thread completion mode non-optional (#29772)
This PR makes the thread completion mode non-optional.

Release Notes:

- N/A
2025-05-02 02:41:54 +00:00
Marshall Bowers
0c0a4ed866 collab: Return increased limit for extended trials from GET /billing/usage (#29771)
This PR updates the `GET /billing/usage` endpoint to return the
increased usage limit for users in the extended trial.

Release Notes:

- N/A
2025-05-02 02:31:30 +00:00
Marshall Bowers
51f1998107 Fix typo in typos.toml (#29770)
This PR fixes a typo in `typos.toml`. How ironic.

Release Notes:

- N/A
2025-05-02 02:01:07 +00:00
Marshall Bowers
1ffedf4a08 collab: Add endpoint for migrating users to new billing (#29769)
This PR adds a new `POST /billing/subscriptions/migrate` endpoint for
migrating users to the new billing system.

When called with a GitHub user ID this endpoint will:

1. Find the active billing subscription for this user (if they have one)
2. Cancel the subscription and send a final invoice
3. Ensure the user is in the `new-billing` and `assistant2` feature
flags

Release Notes:

- N/A
2025-05-02 01:47:09 +00:00
Cole Miller
d25da9728b Run additional checks from script/clippy if local (#29768)
Should cut down on the number of CI cycles if you're forgetful like I
am!

Release Notes:

- N/A
2025-05-02 01:26:12 +00:00
Cole Miller
e1e3f2e423 Improve handling of remote-tracking branches in the picker (#29744)
Release Notes:

- Changed the git branch picker to make remote-tracking branches less
prominent

---------

Co-authored-by: Anthony Eid <hello@anthonyeid.me>
2025-05-01 21:24:26 -04:00
Finn Evers
92b9ecd7d2 agent: Do not render unnecessary lines in edit file tool card (#29766)
This PR prevents any unnecessary lines from being rendered in the edit
file tool card in the case of small diffs.

I think this (hopefully) addresses the last remaining task from
https://github.com/zed-industries/zed/pull/29448.

| `main` | This PR |
| --- | --- |
| <img width="634" alt="main"
src="https://github.com/user-attachments/assets/7c06394e-957a-4d36-a484-5974687041e9"
/> | <img width="634" alt="PR"
src="https://github.com/user-attachments/assets/84206d5a-a93a-4a42-99ca-7cdebb0d91bb"
/> |

(The last empty line in the second image is an empty line present in the
file itself)

---

n the second commit I also preemtively disabled vertical overscrolling
for full mode editors which are sized by content. This is basically the
same fix as in https://github.com/zed-industries/zed/pull/28471.
Strictly speaking, this is not needed for the fix here, but I thought it
might be nice to have for the future to prevent any issues from occuring
due to overscroll.

Release Notes:

- agent: Improved rendering of small diffs for the edit file tool card.
2025-05-01 20:40:12 -03:00
Marshall Bowers
758d260cec collab: Add ability to initiate a checkout session for the Zed Free plan (#29767)
This PR adds the ability to initiate a checkout session for the Zed Free
plan.

Release Notes:

- N/A
2025-05-01 23:35:23 +00:00
Danilo Leal
8d4d3badf3 agent: Add design adjustments to MCP config flow (#29765)
Mostly somewhat small UI tweaks around the MCP extension config flow and
the settings section.

Release Notes:

- N/A
2025-05-01 19:29:59 -03:00
Marshall Bowers
7c23d13773 agent: Render the max mode toggle using a muted color (#29763)
This PR updates the max mode toggle to use the muted color.

This makes it fit in more with the rest of the controls.

<img width="243" alt="Screenshot 2025-05-01 at 5 24 01 PM"
src="https://github.com/user-attachments/assets/57267d29-3c7b-4ea9-b6b9-81c42f6b7e1c"
/>

Release Notes:

- agent: Adjusted the color of the max mode toggle.
2025-05-01 21:40:10 +00:00
Richard Feldman
ad87c545c7 Make context pills clickable while editing (#29740)
Release Notes:

- Fixed a bug where clicking context pills switched into the "editing
message" state instead of clicking the pill.

Co-authored-by: Michael <michael@zed.dev>
Co-authored-by: Ben <ben@zed.dev>
2025-05-01 20:28:54 +00:00
Richard Feldman
23fbab15ee Manual no tool calls (#29745)
Now instead of the model hallucinating tool calls, we get requests for
more context:

<img width="620" alt="Screenshot 2025-05-01 at 12 45 49 PM"
src="https://github.com/user-attachments/assets/847d5c14-82f6-4234-b85a-8cd2bc7ab11d"
/>

It still knows how to answer general questions:
<img width="624" alt="Screenshot 2025-05-01 at 12 47 44 PM"
src="https://github.com/user-attachments/assets/43ab0fc3-4cc8-452f-b26b-474b5d31919f"
/>

Release Notes:

- Fixed the model still trying to do tool calls when no tools selected
(e.g. in `Manual` profile).

---------

Co-authored-by: Ben <ben@zed.dev>
Co-authored-by: Michael <michael@zed.dev>
2025-05-01 16:11:13 -04:00
Richard Feldman
d7e181576e Respect cursor_pointer when a ButtonLike is disabled (#29737)
This is desirable for when we want to use a `ButtonLike` to show a
tooltip over an icon, and we don't want it to show the "not allowed"
cursor on hover.

Release Notes:

- N/A
2025-05-01 15:34:40 -04:00
Eva Pace
9788aff4b1 Fix license symlinks (#29758)
Closes #29527

It looks funny in the diff, but the symlinks are indeed correct:

-
https://github.com/evaporei/zed/blob/fix/license-symlinks/crates/askpass/LICENSE-GPL
-
https://github.com/evaporei/zed/blob/fix/license-symlinks/crates/ui_macros/LICENSE-GPL

I did check all ~170 crates, these were the only inconsistent ones.

Release Notes:

- N/A
2025-05-01 19:24:14 +00:00
Kirill Bulatov
2a319efade Add editor::GoToParentModule for rust-analyzer backed projects (#29755)
Support rust-analyzer's "go to parent module" action


https://rust-analyzer.github.io/book/contributing/lsp-extensions.html#parent-module

Release Notes:

- Added `editor::GoToParentModule` for rust-analyzer backed projects

---------

Co-authored-by: Julia Ryan <juliaryan3.14@gmail.com>
2025-05-01 18:28:05 +00:00
Jonathan LEI
50ec26c163 Fix user rules ignored by agent (#29754)
Closes #29753

The template contains an error: `has_default_user_rules` is always
undefined and should be `has_user_rules` instead.

Release Notes:

- Fixed default user rules ignored during prompt building.
2025-05-01 18:22:48 +00:00
Danilo Leal
39dd133b1c agent: Remove unused agent: chat mode command palette action (#29741)
We weren't using this one anymore. We used to use it for the switch that
toggled tools on, which doesn't exist anymore.

Release Notes:

- N/A

---------

Co-authored-by: Joseph T. Lyons <josephtlyons@gmail.com>
2025-05-01 15:09:14 -03:00
Bennet Bo Fenner
24eb039752 context servers: Show configuration modal when extension is installed (#29309)
WIP

Release Notes:

- N/A

---------

Co-authored-by: Danilo Leal <67129314+danilo-leal@users.noreply.github.com>
Co-authored-by: Danilo Leal <daniloleal09@gmail.com>
Co-authored-by: Marshall Bowers <git@maxdeviant.com>
Co-authored-by: Cole Miller <m@cole-miller.net>
Co-authored-by: Antonio Scandurra <me@as-cii.com>
Co-authored-by: Oleksiy Syvokon <oleksiy.syvokon@gmail.com>
2025-05-01 20:02:14 +02:00
Peter Tripp
bffa53d706 docs: Reorder macOS development documentation (#29751)
Release Notes:

- N/A
2025-05-01 17:34:17 +00:00
Bennet Bo Fenner
0e5e8f9f8d Allow MIT-0 license in checks (#29748)
Part of #29309

The license is on par with other licenses in the list:
https://github.com/aws/mit-0

Release Notes:

- N/A
2025-05-01 17:30:16 +00:00
Danilo Leal
96d785cb45 git: Improve co-author button (#29742)
This PR changes the tooltip label to say "Remove" when you have the
button toggled on and collaborators in the list.

Release Notes:

- N/A

Co-authored-by: Joseph T. Lyons <josephtlyons@gmail.com>
2025-05-01 14:12:52 -03:00
Marshall Bowers
57610c9935 collab: Add billing thresholds to request overage subscription items (#29738)
This PR adds billing thresholds of the unit equivalent of $20 for model
request overages.

Release Notes:

- N/A
2025-05-01 16:10:06 +00:00
Marshall Bowers
5bf1b4f0a8 collab: Add use_new_billing to LlmTokenClaims (#29739)
This PR adds a `use_new_billing` field to the LLM token claims, based on
the `new-billing` feature flag.

Release Notes:

- N/A
2025-05-01 15:43:53 +00:00
Antonio Scandurra
f891dfb358 Introduce a new StreamingEditFileTool (#29733)
This pull request introduces a new tool for streaming edits. The
short-term goal is for this tool to replace the existing `EditFileTool`,
but we want to get this out the door as soon as possible so that we can
start testing it.

`StreamingEditFileTool` is mutually exclusive with `EditFileTool`. It
will be enabled by default for anyone who has the `agent-stream-edits`
feature flag, as well as people that set `assistant.stream_edits` to
`true` in their settings.

### Implementation

Streaming is achieved by requesting a completion while the `edit_file`
tool gets called. We invoke the model by taking the existing
conversation with the agent and appending a prompt specifically tailored
for editing. In that prompt, we ask the model to produce a stream of
`<old_text>`/`<new_text>` tags. As the model streams text in, we
incrementally parse it and start editing as soon as we can.

### Evals

Note that, as part of this pull request, I also defined some new evals
that I used to drive the behavior of the recursive LLM call. To run
them, use this command:

```bash
cargo test --package=assistant_tools --features eval -- eval_extract_handle_command_output
```

Or comment out the `#[cfg_attr(not(feature = "eval"), ignore)]` macro.

I recommend running them one at a time, because right now we don't
really have a way of orchestrating of all these evals. I think we should
invest into that effort once the new agent panel goes live.

Release Notes:

- N/A

---------

Co-authored-by: Nathan Sobo <nathan@zed.dev>
Co-authored-by: Bennet Bo Fenner <bennetbo@gmx.de>
Co-authored-by: Oleksiy Syvokon <oleksiy.syvokon@gmail.com>
2025-05-01 17:37:43 +02:00
Ben Kunkle
e3a2d52472 zlog: Fall back to printing module path instead of *unknown* or just crate name (#29691)
Closes #ISSUE

Release Notes:

- N/A *or* Added/Fixed/Improved ...
2025-05-01 10:59:51 -04:00
Danilo Leal
122af4fd53 agent: Show nav dropdown close button only on hover (#29732) 2025-05-01 11:21:57 -03:00
Kirill Bulatov
e07ffe7cf1 Allow to fetch cargo diagnostics separately (#29706)
Adjusts the way `cargo` and `rust-analyzer` diagnostics are fetched into
Zed.

Nothing is changed for defaults: in this mode, Zed does nothing but
reports file updates, which trigger rust-analyzers'
mechanisms:

* generating internal diagnostics, which it is able to produce on the
fly, without blocking cargo lock.
Unfortunately, there are not that many diagnostics in r-a, and some of
them have false-positives compared to rustc ones

* running `cargo check --workspace --all-targets` on each file save,
taking the cargo lock
For large projects like Zed, this might take a while, reducing the
ability to choose how to work with the project: e.g. it's impossible to
save multiple times without long diagnostics refreshes (may happen
automatically on e.g. focus loss), save the project and run it instantly
without waiting for cargo check to finish, etc.

In addition, it's relatively tricky to reconfigure r-a to run a
different command, with different arguments and maybe different env
vars: that would require a language server restart (and a large project
reindex) and fiddling with multiple JSON fields.

The new mode aims to separate out cargo diagnostics into its own loop so
that all Zed diagnostics features are supported still.


For that, an extra mode was introduced:

```jsonc
"rust": {
  // When enabled, Zed runs `cargo check --message-format=json`-based commands and
  // collect cargo diagnostics instead of rust-analyzer.
  "fetch_cargo_diagnostics": false,
  // A command override for fetching the cargo diagnostics.
  // First argument is the command, followed by the arguments.
  "diagnostics_fetch_command": [
    "cargo",
    "check",
    "--quiet",
    "--workspace",
    "--message-format=json",
    "--all-targets",
    "--keep-going"
  ],
  // Extra environment variables to pass to the diagnostics fetch command.
  "env": {}
}
```

which calls to cargo, parses its output and mixes in with the existing
diagnostics:




https://github.com/user-attachments/assets/e986f955-b452-4995-8aac-3049683dd22c




Release Notes:

- Added a way to get diagnostics from cargo and rust-analyzer without
mutually locking each other
- Added `ctrl-r` binding to refresh diagnostics in the project
diagnostics editor context
2025-05-01 11:25:52 +03:00
Finn Evers
5e4be013af zed: Fix application menu capitalization (#29722)
This PR is a quick follow-up to #29717 to ensure that the action within
the app menu has the same capitalization as in the context menu.

Release Notes:

- N/A
2025-05-01 08:19:02 +00:00
Aaron Feickert
f055dca592 editor: Fix context menu capitalization (#29717)
Fixes context menu capitalization for consistency.

Release Notes:

- N/A
2025-05-01 09:58:08 +03:00
Richard Feldman
5872276511 Re-enable open tool (#29707)
Release Notes:

- Added `open` tool for opening files or URLs.
2025-04-30 22:33:52 -04:00
Bennet Bo Fenner
1bf9e15f26 agent: Allow adding/removing context when editing existing messages (#29698)
Release Notes:

- agent: Support adding/removing context when editing existing message

---------

Co-authored-by: Cole Miller <m@cole-miller.net>
Co-authored-by: Cole Miller <cole@zed.dev>
2025-05-01 01:39:34 +00:00
Marshall Bowers
f046d70625 collab: Look up Stripe prices with lookup keys (#29715)
This PR makes it so we look up Stripe prices via lookup keys instead of
passing in the price IDs as environment variables.

Release Notes:

- N/A
2025-05-01 00:26:31 +00:00
Richard Feldman
afeb3d4fd9 Make eval more resilient to bad input from LLM (#29703)
I saw a slice panic (for begin > end) in a debug build of the eval. This
should just be a failed assertion, not a panic that takes out the whole
eval run!

Release Notes:

- N/A
2025-04-30 18:13:45 -04:00
Richard Feldman
92dd6b67c7 Fix potential subtraction overflow (#29697)
I saw this come up in an eval, where the LLM provided a start line of 0.

Release Notes:

- N/A
2025-04-30 18:13:37 -04:00
Cole Miller
38ede4bae3 Fix parsing of author name in git show output (#29704)
Closes #ISSUE

Release Notes:

- Fixed a bug causing incorrect formatting of git commit tooltips
2025-04-30 20:54:53 +00:00
Ben Kunkle
fc920bf63d Improve behavior around word-based brackets in bash (#29700)
Closes #28414

Makes it so that `do`, `then`, `done`, `else`, etc are treated as
brackets in bash. They are not auto-closed *yet* as that requires
additional work to function properly, however they can now be toggled
between using `%` in vim. Additionally, newlines are inserted like they
are with regular brackets (`{}()[]""''`) when hitting enter between
them.

While `if <-> fi` `while/for <-> done` and `case <-> esac` are the
*logical* matching pairs, I've opted to instead match between `then <->
else/elif/fi` `do <-> done` and `in <-> esac` as these are the pairs
that delimit the sub-scope, and are more similar to the `{}` style
bracket pairs than `if <-> }` in a c-like syntax. This does cause some
wierd behavior with `else` in `if` expressions as it matches both with
the previous `then` as well as the following `fi`, so in this case

```bash
if true; then
   foo
else
   bar
f|i
```

after hitting `%` twice times (where cursor is `|`), the cursor will end
up on the `then` instead of back on the `fi` as hitting `%` on the else
will *always* navigate up to the `then`

Release Notes:

- vim: Improved behavior around word-based delimiters in bash (`do <->
done`, `then <-> fi`, etc) so they can be toggled between using `%`
2025-04-30 19:57:29 +00:00
Richard Feldman
04c68dc0cf Make the default repetitions be 8, and concurrency 4 (#29576)
This is based on having observed that there is a lot of variation
between runs on `n=1` and `n=3`.

* With `n=8` two runs on the same branch give answers that seem close
enough to be reasonably consistent.
* With higher concurrency, trying to run this many repetitions seems to
lead language servers to time out a lot, causing evals to fail.

Release Notes:

- N/A
2025-04-30 15:21:19 -04:00
Marshall Bowers
399eced884 collab: Return current usage by model from GET /billing/usage (#29693)
This PR updates the `GET /billing/usage` endpoint to return the number
of requests made to each model and mode.

Release Notes:

- N/A
2025-04-30 19:06:39 +00:00
Richard Feldman
50f705e779 Use outline (#29687)
## Before

![Screenshot 2025-04-30 at 10 56
36 AM](https://github.com/user-attachments/assets/3a435f4c-ad45-4f26-a847-2d5c9d03648e)

## After

![Screenshot 2025-04-30 at 10 55
27 AM](https://github.com/user-attachments/assets/cc3a8144-b6fe-4a15-8a47-b2487ce4f66e)

Release Notes:

- Context picker and `@`-mentions now work with very large files.
2025-04-30 18:00:00 +00:00
Ben Kunkle
8173534ad5 zed: Reinstate default file_scan_exclusions in Zed repo project settings (#29690)
Closes #ISSUE

Re-adds default `file_scan_exclusions` to [project
settings](84e4891d54/.zed/settings.json)
that were overridden in #29106

Release Notes:

- N/A *or* Added/Fixed/Improved ...
2025-04-30 17:50:56 +00:00
Nate Butler
8c03934b26 welcome: Theme preview tile (#29689)
![CleanShot 2025-04-30 at 13 26
44@2x](https://github.com/user-attachments/assets/f68fefe2-84a1-48b7-b9a2-47c2547cd06b)


- Adds the ThemePreviewTile component, used for upcoming onboarding UI
- Adds the CornerSolver utility for resolving correct nested corner
radii

Release Notes:

- N/A
2025-04-30 17:46:11 +00:00
Patrick
84e4891d54 file_finder: Add skip_focus_for_active_in_search setting (#27624)
Closes #27073

Currently, when searching for a file with Ctrl+P, and the first file
found is the active one, file_finder skips focus to the second file
automatically. This PR adds a setting to disable this and make the first
file always the focused one.

Default setting is still skipping the active file.

Release Notes: 

- Added the `skip_focus_for_active_in_search` setting for the file
finder, which allows turning off the default behavior of skipping focus
on the active file while searching in the file finder.

---------

Co-authored-by: Smit Barmase <heysmitbarmase@gmail.com>
2025-04-30 22:58:33 +05:30
Ben Kunkle
d03d8ccec1 python: Fix identification of runnable tests within decorated test classes (#29688)
Closes #29486

Release Notes:

- python: Fixed identification of runnable test functions within
decorated pytest classes
2025-04-30 17:26:30 +00:00
Joseph T. Lyons
4d934f2884 Bump Zed to v0.186 (#29680)
Release Notes:

-N/A
2025-04-30 12:52:25 -04:00
Smit Barmase
e697cf9747 editor: Fix edit range for linked edits on do completion (#29650)
Closes #29544

Fixes an issue where accepting an HTML completion would correctly edit
the start tag but incorrectly update the end tag due to incorrect linked
edit ranges.

I want to handle multi cursor case (as it barely works now), but seems
like this should go first. As, it might need whole `do_completions`
overhaul.

Todo:
- [x] Tests for completion aceept on linked edits

Before:


https://github.com/user-attachments/assets/917f8d2a-4a0f-46e8-a004-675fde55fe3d

After:


https://github.com/user-attachments/assets/84b760b6-a5b9-45c4-85d8-b5dccf97775f

Release Notes:

- Fixes an issue where accepting an HTML completion would correctly edit
the start tag but incorrectly update the end tag.
2025-04-30 21:44:20 +05:30
Cole Miller
e4e455ae6b Fix duplicate thread entries in agent navigation menu (#29672)
Closes #ISSUE

Release Notes:

- N/A
2025-04-30 11:26:04 -04:00
Cole Miller
ff4900c942 Mitigate panic in merge conflict resolution (#29678)
We have a report of a panic when indexing into
`BufferConflicts.block_ids` using the `old_range` from the
`ConflictsUpdated` event, indicating that the `block_ids` array can get
out of sync with the underlying `ConflictSet`. This PR adds a mitigation
so that we won't panic in this situation, as a stopgap until the bug can
be reproduced in a test and fixed at the root.

Release Notes:

- N/A
2025-04-30 11:25:26 -04:00
182 changed files with 58366 additions and 1539 deletions

View File

@@ -69,7 +69,7 @@ jobs:
run: cargo build --package=eval
- name: Run eval
run: cargo run --package=eval -- --repetitions=3 --concurrency=1
run: cargo run --package=eval -- --repetitions=8 --concurrency=1
# Even the Linux runner is not stateful, in theory there is no need to do this cleanup.
# But, to avoid potential issues in the future if we choose to use a stateful Linux runner and forget to add code

View File

@@ -46,5 +46,17 @@
"formatter": "auto",
"remove_trailing_whitespace_on_save": true,
"ensure_final_newline_on_save": true,
"file_scan_exclusions": ["crates/eval/worktrees/", "crates/eval/repos/"]
"file_scan_exclusions": [
"crates/eval/worktrees/",
"crates/eval/repos/",
"**/.git",
"**/.svn",
"**/.hg",
"**/.jj",
"**/CVS",
"**/.DS_Store",
"**/Thumbs.db",
"**/.classpath",
"**/.settings"
]
}

178
Cargo.lock generated
View File

@@ -68,6 +68,7 @@ dependencies = [
"convert_case 0.8.0",
"db",
"editor",
"extension",
"feature_flags",
"file_icons",
"fs",
@@ -81,6 +82,7 @@ dependencies = [
"indexmap",
"indoc",
"itertools 0.14.0",
"jsonschema",
"language",
"language_model",
"language_model_selector",
@@ -90,6 +92,7 @@ dependencies = [
"markdown",
"menu",
"multi_buffer",
"notifications",
"ordered-float 2.10.1",
"parking_lot",
"paths",
@@ -106,6 +109,7 @@ dependencies = [
"schemars",
"serde",
"serde_json",
"serde_json_lenient",
"settings",
"smallvec",
"smol",
@@ -148,7 +152,9 @@ checksum = "e89da841a80418a9b391ebaea17f5c112ffaaa96f621d2c285b5174da76b9011"
dependencies = [
"cfg-if",
"const-random",
"getrandom 0.2.15",
"once_cell",
"serde",
"version_check",
"zerocopy 0.7.35",
]
@@ -690,6 +696,7 @@ dependencies = [
"pretty_assertions",
"project",
"rand 0.8.5",
"regex",
"serde",
"serde_json",
"settings",
@@ -703,7 +710,9 @@ dependencies = [
name = "assistant_tools"
version = "0.1.0"
dependencies = [
"aho-corasick",
"anyhow",
"assistant_settings",
"assistant_tool",
"buffer_diff",
"chrono",
@@ -711,26 +720,38 @@ dependencies = [
"clock",
"collections",
"component",
"derive_more",
"editor",
"feature_flags",
"fs",
"futures 0.3.31",
"gpui",
"gpui_tokio",
"handlebars 4.5.0",
"html_to_markdown",
"http_client",
"indoc",
"itertools 0.14.0",
"language",
"language_model",
"language_models",
"linkme",
"open",
"pretty_assertions",
"project",
"rand 0.8.5",
"regex",
"reqwest_client",
"rust-embed",
"schemars",
"serde",
"serde_json",
"settings",
"smallvec",
"streaming_diff",
"strsim",
"task",
"tempfile",
"terminal",
"terminal_view",
"tree-sitter-rust",
@@ -2171,6 +2192,12 @@ dependencies = [
"piper",
]
[[package]]
name = "borrow-or-share"
version = "0.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3eeab4423108c5d7c744f4d234de88d18d636100093ae04caf4825134b9c3a32"
[[package]]
name = "borsh"
version = "1.5.7"
@@ -2286,6 +2313,12 @@ dependencies = [
"syn 1.0.109",
]
[[package]]
name = "bytecount"
version = "0.6.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5ce89b21cab1437276d2650d57e971f9d548a2d9037cc231abdc0562b97498ce"
[[package]]
name = "bytemuck"
version = "1.22.0"
@@ -3206,7 +3239,9 @@ dependencies = [
name = "component_preview"
version = "0.1.0"
dependencies = [
"agent",
"anyhow",
"assistant_tool",
"client",
"collections",
"component",
@@ -3216,6 +3251,7 @@ dependencies = [
"log",
"notifications",
"project",
"prompt_store",
"serde",
"ui",
"ui_input",
@@ -4369,6 +4405,7 @@ dependencies = [
"ctor",
"editor",
"env_logger 0.11.8",
"futures 0.3.31",
"gpui",
"indoc",
"language",
@@ -4764,6 +4801,15 @@ dependencies = [
"zeroize",
]
[[package]]
name = "email_address"
version = "0.2.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e079f19b08ca6239f47f8ba8509c11cf3ea30095831f7fed61441475edd8c449"
dependencies = [
"serde",
]
[[package]]
name = "embed-resource"
version = "3.0.2"
@@ -4997,6 +5043,7 @@ dependencies = [
"node_runtime",
"pathdiff",
"paths",
"pretty_assertions",
"project",
"prompt_store",
"regex",
@@ -5178,6 +5225,7 @@ dependencies = [
"collections",
"db",
"editor",
"extension",
"extension_host",
"fs",
"fuzzy",
@@ -5410,6 +5458,17 @@ version = "1.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8bf7cc16383c4b8d58b9905a8509f02926ce3058053c056376248d958c9df1e8"
[[package]]
name = "fluent-uri"
version = "0.3.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1918b65d96df47d3591bed19c5cca17e3fa5d0707318e4b5ef2eae01764df7e5"
dependencies = [
"borrow-or-share",
"ref-cast",
"serde",
]
[[package]]
name = "flume"
version = "0.11.1"
@@ -5564,6 +5623,16 @@ dependencies = [
"percent-encoding",
]
[[package]]
name = "fraction"
version = "0.15.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0f158e3ff0a1b334408dc9fb811cd99b446986f4d8b741bb08f9df1604085ae7"
dependencies = [
"lazy_static",
"num",
]
[[package]]
name = "freetype-sys"
version = "0.20.1"
@@ -6361,6 +6430,7 @@ dependencies = [
"log",
"pest",
"pest_derive",
"rust-embed",
"serde",
"serde_json",
"thiserror 1.0.69",
@@ -7566,6 +7636,33 @@ dependencies = [
"wasm-bindgen",
]
[[package]]
name = "jsonschema"
version = "0.30.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f1b46a0365a611fbf1d2143104dcf910aada96fafd295bab16c60b802bf6fa1d"
dependencies = [
"ahash 0.8.11",
"base64 0.22.1",
"bytecount",
"email_address",
"fancy-regex 0.14.0",
"fraction",
"idna",
"itoa",
"num-cmp",
"num-traits",
"once_cell",
"percent-encoding",
"referencing",
"regex",
"regex-syntax 0.8.5",
"reqwest 0.12.15 (registry+https://github.com/rust-lang/crates.io-index)",
"serde",
"serde_json",
"uuid-simd",
]
[[package]]
name = "jsonwebtoken"
version = "9.3.1"
@@ -8250,7 +8347,7 @@ dependencies = [
"prost 0.9.0",
"prost-build 0.9.0",
"prost-types 0.9.0",
"reqwest 0.12.15",
"reqwest 0.12.15 (git+https://github.com/zed-industries/reqwest.git?rev=951c770a32f1998d6e999cef3e59e0013e6c4415)",
"serde",
"workspace-hack",
]
@@ -9160,6 +9257,12 @@ dependencies = [
"zeroize",
]
[[package]]
name = "num-cmp"
version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "63335b2e2c34fae2fb0aa2cecfd9f0832a1e24b3b32ecec612c3426d46dc8aaa"
[[package]]
name = "num-complex"
version = "0.4.6"
@@ -11753,6 +11856,20 @@ dependencies = [
"syn 2.0.100",
]
[[package]]
name = "referencing"
version = "0.30.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c8eff4fa778b5c2a57e85c5f2fe3a709c52f0e60d23146e2151cbef5893f420e"
dependencies = [
"ahash 0.8.11",
"fluent-uri",
"once_cell",
"parking_lot",
"percent-encoding",
"serde_json",
]
[[package]]
name = "refineable"
version = "0.1.0"
@@ -12022,6 +12139,43 @@ dependencies = [
"winreg 0.50.0",
]
[[package]]
name = "reqwest"
version = "0.12.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d19c46a6fdd48bc4dab94b6103fccc55d34c67cc0ad04653aad4ea2a07cd7bbb"
dependencies = [
"base64 0.22.1",
"bytes 1.10.1",
"futures-channel",
"futures-core",
"futures-util",
"http 1.3.1",
"http-body 1.0.1",
"http-body-util",
"hyper 1.6.0",
"hyper-util",
"ipnet",
"js-sys",
"log",
"mime",
"once_cell",
"percent-encoding",
"pin-project-lite",
"serde",
"serde_json",
"serde_urlencoded",
"sync_wrapper 1.0.2",
"tokio",
"tower 0.5.2",
"tower-service",
"url",
"wasm-bindgen",
"wasm-bindgen-futures",
"web-sys",
"windows-registry 0.4.0",
]
[[package]]
name = "reqwest"
version = "0.12.15"
@@ -12082,7 +12236,7 @@ dependencies = [
"http_client_tls",
"log",
"regex",
"reqwest 0.12.15",
"reqwest 0.12.15 (git+https://github.com/zed-industries/reqwest.git?rev=951c770a32f1998d6e999cef3e59e0013e6c4415)",
"serde",
"smol",
"tokio",
@@ -15933,6 +16087,17 @@ dependencies = [
"sha1_smol",
]
[[package]]
name = "uuid-simd"
version = "0.8.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "23b082222b4f6619906941c17eb2297fff4c2fb96cb60164170522942a200bd8"
dependencies = [
"outref",
"uuid",
"vsimd",
]
[[package]]
name = "v_frame"
version = "0.3.8"
@@ -16878,18 +17043,22 @@ version = "0.1.0"
dependencies = [
"anyhow",
"client",
"component",
"db",
"documented",
"editor",
"fuzzy",
"gpui",
"install_cli",
"language",
"linkme",
"picker",
"project",
"schemars",
"serde",
"settings",
"telemetry",
"theme",
"ui",
"util",
"vim_mode_setting",
@@ -18022,12 +18191,14 @@ dependencies = [
"getrandom 0.2.15",
"getrandom 0.3.2",
"gimli",
"handlebars 4.5.0",
"hashbrown 0.14.5",
"hashbrown 0.15.2",
"heck 0.4.1",
"hmac",
"hyper 0.14.32",
"hyper-rustls 0.27.5",
"idna",
"indexmap",
"inout",
"itertools 0.12.1",
@@ -18051,6 +18222,7 @@ dependencies = [
"num-bigint-dig",
"num-integer",
"num-iter",
"num-rational",
"num-traits",
"object",
"once_cell",
@@ -18465,7 +18637,7 @@ dependencies = [
[[package]]
name = "zed"
version = "0.185.0"
version = "0.186.0"
dependencies = [
"activity_indicator",
"agent",

View File

@@ -435,6 +435,7 @@ dap-types = { git = "https://github.com/zed-industries/dap-types", rev = "be69a0
dashmap = "6.0"
derive_more = "0.99.17"
dirs = "4.0"
documented = "0.9.1"
dotenv = "0.15.0"
ec4rs = "1.1"
emojis = "0.6.1"
@@ -461,6 +462,7 @@ indexmap = { version = "2.7.0", features = ["serde"] }
indoc = "2"
inventory = "0.3.19"
itertools = "0.14.0"
jsonschema = "0.30.0"
jsonwebtoken = "9.3"
jupyter-protocol = { git = "https://github.com/ConradIrwin/runtimed", rev = "7130c804216b6914355d15d0b91ea91f6babd734" }
jupyter-websocket-client = { git = "https://github.com/ConradIrwin/runtimed" ,rev = "7130c804216b6914355d15d0b91ea91f6babd734" }
@@ -797,5 +799,6 @@ ignored = [
"serde",
"component",
"linkme",
"documented",
"workspace-hack",
]

1
assets/icons/hammer.svg Normal file
View File

@@ -0,0 +1 @@
<svg xmlns="http://www.w3.org/2000/svg" width="24" height="24" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round" class="lucide lucide-hammer-icon lucide-hammer"><path d="m15 12-8.373 8.373a1 1 0 1 1-3-3L12 9"/><path d="m18 15 4-4"/><path d="m21.5 11.5-1.914-1.914A2 2 0 0 1 19 8.172V7l-2.26-2.26a6 6 0 0 0-4.202-1.756L9 2.96l.92.82A6.18 6.18 0 0 1 12 8.4V10l2 2h1.172a2 2 0 0 1 1.414.586L18.5 14.5"/></svg>

After

Width:  |  Height:  |  Size: 475 B

View File

@@ -0,0 +1 @@
<svg xmlns="http://www.w3.org/2000/svg" width="24" height="24" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round" class="lucide lucide-user-round-check-icon lucide-user-round-check"><path d="M2 21a8 8 0 0 1 13.292-6"/><circle cx="10" cy="8" r="5"/><path d="m16 19 2 2 4-4"/></svg>

After

Width:  |  Height:  |  Size: 348 B

View File

@@ -248,7 +248,6 @@
"ctrl-shift-o": "agent::ToggleNavigationMenu",
"ctrl-shift-i": "agent::ToggleOptionsMenu",
"shift-escape": "agent::ExpandMessageEditor",
"ctrl-e": "agent::ChatMode",
"ctrl-alt-e": "agent::RemoveAllContext"
}
},
@@ -962,5 +961,20 @@
"bindings": {
"escape": "menu::Cancel"
}
},
{
"context": "ConfigureContextServerModal > Editor",
"bindings": {
"escape": "menu::Cancel",
"enter": "editor::Newline",
"ctrl-enter": "menu::Confirm"
}
},
{
"context": "Diagnostics",
"use_key_equivalents": true,
"bindings": {
"ctrl-r": "diagnostics::ToggleDiagnosticsRefresh"
}
}
]

View File

@@ -293,7 +293,6 @@
"cmd-shift-o": "agent::ToggleNavigationMenu",
"cmd-shift-i": "agent::ToggleOptionsMenu",
"shift-escape": "agent::ExpandMessageEditor",
"cmd-e": "agent::ChatMode",
"cmd-alt-e": "agent::RemoveAllContext"
}
},
@@ -1068,5 +1067,21 @@
"bindings": {
"escape": "menu::Cancel"
}
},
{
"context": "ConfigureContextServerModal > Editor",
"use_key_equivalents": true,
"bindings": {
"escape": "menu::Cancel",
"enter": "editor::Newline",
"cmd-enter": "menu::Confirm"
}
},
{
"context": "Diagnostics",
"use_key_equivalents": true,
"bindings": {
"ctrl-r": "diagnostics::ToggleDiagnosticsRefresh"
}
}
]

View File

@@ -3,11 +3,12 @@ You are a highly skilled software engineer with extensive knowledge in many prog
## Communication
1. Be conversational but professional.
2. Refer to the USER in the second person and yourself in the first person.
2. Refer to the user in the second person and yourself in the first person.
3. Format your responses in markdown. Use backticks to format file, directory, function, and class names.
4. NEVER lie or make things up.
5. Refrain from apologizing all the time when results are unexpected. Instead, just try your best to proceed or explain the circumstances to the user without apologizing.
{{#if has_tools}}
## Tool Use
1. Make sure to adhere to the tools schema.
@@ -22,6 +23,7 @@ You are a highly skilled software engineer with extensive knowledge in many prog
If you are unsure how to fulfill the user's request, gather more information with tool calls and/or clarifying questions.
{{! TODO: If there are files, we should mention it but otherwise omit that fact }}
{{#if has_tools}}
If appropriate, use tool calls to explore the current project, which contains the following root directories:
{{#each worktrees}}
@@ -36,6 +38,14 @@ If appropriate, use tool calls to explore the current project, which contains th
- As you learn about the structure of the project, use that information to scope `grep` searches to targeted subtrees of the project.
- The user might specify a partial file path. If you don't know the full path, use `find_path` (not `grep`) before you read the file.
{{/if}}
{{/if}}
{{else}}
You are being tasked with providing a response, but you have no ability to use tools or to read or write any aspect of the user's system (other than any context the user might have provided to you).
As such, if you need the user to perform any actions for you, you must request them explicitly. Bias towards giving a response to the best of your ability, and then making requests for the user to take action (e.g. to give you more context) only optionally.
The one exception to this is if the user references something you don't know about - for example, the name of a source code file, function, type, or other piece of code that you have no awareness of. In this case, you MUST NOT MAKE SOMETHING UP, or assume you know what that thing is or how it works. Instead, you must ask the user for clarification rather than giving a response.
{{/if}}
## Code Block Formatting
@@ -111,6 +121,8 @@ In Markdown, hash marks signify headings. For example:
```
</bad_example_do_not_do_this>
This example is unacceptable because the path is in the wrong place. The path must be directly after the opening backticks.
{{#if has_tools}}
## Fixing Diagnostics
1. Make 1-2 attempts at fixing diagnostics, then defer to the user.
@@ -124,10 +136,11 @@ Otherwise, follow debugging best practices:
2. Add descriptive logging statements and error messages to track variable and code state.
3. Add test functions and statements to isolate the problem.
{{/if}}
## Calling External APIs
1. Unless explicitly requested by the user, use the best suited external APIs and packages to solve the task. There is no need to ask the user for permission.
2. When selecting which version of an API or package to use, choose one that is compatible with the user's dependency management file. If no such file exists or if the package is not present, use the latest version that is in your training data.
2. When selecting which version of an API or package to use, choose one that is compatible with the user's dependency management file(s). If no such file exists or if the package is not present, use the latest version that is in your training data.
3. If an external API requires an API Key, be sure to point this out to the user. Adhere to best security practices (e.g. DO NOT hardcode an API key in a place where it can be exposed)
## System Information
@@ -135,10 +148,10 @@ Otherwise, follow debugging best practices:
Operating System: {{os}}
Default Shell: {{shell}}
{{#if (or has_rules has_default_user_rules)}}
{{#if (or has_rules has_user_rules)}}
## User's Custom Instructions
The following additional instructions are provided by the user, and should be followed to the best of your ability without interfering with the tool use guidelines.
The following additional instructions are provided by the user, and should be followed to the best of your ability{{#if has_tools}} without interfering with the tool use guidelines{{/if}}.
{{#if has_rules}}
There are project rules that apply to these root directories:

View File

@@ -657,6 +657,8 @@
},
// When enabled, the agent can run potentially destructive actions without asking for your confirmation.
"always_allow_tool_actions": false,
// When enabled, the agent will stream edits.
"stream_edits": false,
"default_profile": "write",
"profiles": {
"ask": {
@@ -671,6 +673,7 @@
"now": true,
"find_path": true,
"read_file": true,
"open": true,
"grep": true,
"thinking": true,
"web_search": true
@@ -834,7 +837,20 @@
// "modal_max_width": "full"
//
// Default: small
"modal_max_width": "small"
"modal_max_width": "small",
// Determines whether the file finder should skip focus for the active file in search results.
// There are 2 possible values:
//
// 1. true: When searching for files, if the currently active file appears as the first result,
// auto-focus will skip it and focus the second result instead.
// "skip_focus_for_active_in_search": true
//
// 2. false: When searching for files, the first result will always receive focus,
// even if it's the currently active file.
// "skip_focus_for_active_in_search": false
//
// Default: true
"skip_focus_for_active_in_search": true
},
// Whether or not to remove any trailing whitespace from lines of a buffer
// before saving it.
@@ -917,6 +933,11 @@
// The minimum severity of the diagnostics to show inline.
// Shows all diagnostics when not specified.
"max_severity": null
},
"rust": {
// When enabled, Zed disables rust-analyzer's check on save and starts to query
// Cargo diagnostics separately.
"fetch_cargo_diagnostics": false
}
},
// Files or globs of files that will be excluded by Zed entirely. They will be skipped during file

View File

@@ -35,6 +35,7 @@ context_server.workspace = true
convert_case.workspace = true
db.workspace = true
editor.workspace = true
extension.workspace = true
feature_flags.workspace = true
file_icons.workspace = true
fs.workspace = true
@@ -47,6 +48,7 @@ html_to_markdown.workspace = true
http_client.workspace = true
indexmap.workspace = true
itertools.workspace = true
jsonschema.workspace = true
language.workspace = true
language_model.workspace = true
language_model_selector.workspace = true
@@ -56,6 +58,7 @@ lsp.workspace = true
markdown.workspace = true
menu.workspace = true
multi_buffer.workspace = true
notifications.workspace = true
ordered-float.workspace = true
parking_lot.workspace = true
paths.workspace = true
@@ -71,6 +74,7 @@ rope.workspace = true
schemars.workspace = true
serde.workspace = true
serde_json.workspace = true
serde_json_lenient.workspace = true
settings.workspace = true
smallvec.workspace = true
smol.workspace = true

View File

@@ -1,5 +1,7 @@
use crate::context::{AgentContextHandle, RULES_ICON};
use crate::context_picker::MentionLink;
use crate::context_picker::{ContextPicker, MentionLink};
use crate::context_store::ContextStore;
use crate::context_strip::{ContextStrip, ContextStripEvent, SuggestContextKind};
use crate::thread::{
LastRestoreCheckpoint, MessageId, MessageSegment, Thread, ThreadError, ThreadEvent,
ThreadFeedback,
@@ -14,14 +16,16 @@ use anyhow::Context as _;
use assistant_settings::{AssistantSettings, NotifyWhenAgentWaiting};
use assistant_tool::ToolUseStatus;
use collections::{HashMap, HashSet};
use editor::actions::{MoveUp, Paste};
use editor::scroll::Autoscroll;
use editor::{Editor, EditorElement, EditorEvent, EditorStyle, MultiBuffer};
use gpui::{
AbsoluteLength, Animation, AnimationExt, AnyElement, App, ClickEvent, ClipboardItem,
DefiniteLength, EdgesRefinement, Empty, Entity, EventEmitter, Focusable, Hsla, ListAlignment,
ListState, MouseButton, PlatformDisplay, ScrollHandle, Stateful, StyleRefinement, Subscription,
Task, TextStyle, TextStyleRefinement, Transformation, UnderlineStyle, WeakEntity, WindowHandle,
linear_color_stop, linear_gradient, list, percentage, pulsating_between,
AbsoluteLength, Animation, AnimationExt, AnyElement, App, ClickEvent, ClipboardEntry,
ClipboardItem, DefiniteLength, EdgesRefinement, Empty, Entity, EventEmitter, Focusable, Hsla,
ListAlignment, ListState, MouseButton, PlatformDisplay, ScrollHandle, Stateful,
StyleRefinement, Subscription, Task, TextStyle, TextStyleRefinement, Transformation,
UnderlineStyle, WeakEntity, WindowHandle, linear_color_stop, linear_gradient, list, percentage,
pulsating_between,
};
use language::{Buffer, Language, LanguageRegistry};
use language_model::{
@@ -41,7 +45,8 @@ use std::time::Duration;
use text::ToPoint;
use theme::ThemeSettings;
use ui::{
Disclosure, IconButton, KeyBinding, Scrollbar, ScrollbarState, TextSize, Tooltip, prelude::*,
Disclosure, IconButton, KeyBinding, PopoverMenuHandle, Scrollbar, ScrollbarState, TextSize,
Tooltip, prelude::*,
};
use util::ResultExt as _;
use util::markdown::MarkdownCodeBlock;
@@ -49,6 +54,7 @@ use workspace::Workspace;
use zed_actions::assistant::OpenRulesLibrary;
pub struct ActiveThread {
context_store: Entity<ContextStore>,
language_registry: Arc<LanguageRegistry>,
thread_store: Entity<ThreadStore>,
thread: Entity<Thread>,
@@ -61,7 +67,7 @@ pub struct ActiveThread {
hide_scrollbar_task: Option<Task<()>>,
rendered_messages_by_id: HashMap<MessageId, RenderedMessage>,
rendered_tool_uses: HashMap<LanguageModelToolUseId, RenderedToolUse>,
editing_message: Option<(MessageId, EditMessageState)>,
editing_message: Option<(MessageId, EditingMessageState)>,
expanded_tool_uses: HashMap<LanguageModelToolUseId, bool>,
expanded_thinking_segments: HashMap<(MessageId, usize), bool>,
expanded_code_blocks: HashMap<(MessageId, usize), bool>,
@@ -72,6 +78,7 @@ pub struct ActiveThread {
_subscriptions: Vec<Subscription>,
notification_subscriptions: HashMap<WindowHandle<AgentNotification>, Vec<Subscription>>,
open_feedback_editors: HashMap<MessageId, Entity<Editor>>,
_load_edited_message_context_task: Option<Task<()>>,
}
struct RenderedMessage {
@@ -725,10 +732,12 @@ fn open_markdown_link(
}
}
struct EditMessageState {
struct EditingMessageState {
editor: Entity<Editor>,
context_strip: Entity<ContextStrip>,
context_picker_menu_handle: PopoverMenuHandle<ContextPicker>,
last_estimated_token_count: Option<usize>,
_subscription: Subscription,
_subscriptions: [Subscription; 2],
_update_token_count_task: Option<Task<()>>,
}
@@ -736,6 +745,7 @@ impl ActiveThread {
pub fn new(
thread: Entity<Thread>,
thread_store: Entity<ThreadStore>,
context_store: Entity<ContextStore>,
language_registry: Arc<LanguageRegistry>,
workspace: WeakEntity<Workspace>,
window: &mut Window,
@@ -758,6 +768,7 @@ impl ActiveThread {
let mut this = Self {
language_registry,
thread_store,
context_store,
thread: thread.clone(),
workspace,
save_thread_task: None,
@@ -779,6 +790,7 @@ impl ActiveThread {
_subscriptions: subscriptions,
notification_subscriptions: HashMap::default(),
open_feedback_editors: HashMap::default(),
_load_edited_message_context_task: None,
};
for message in thread.read(cx).messages().cloned().collect::<Vec<_>>() {
@@ -1237,33 +1249,49 @@ impl ActiveThread {
return;
};
let buffer = cx.new(|cx| {
MultiBuffer::singleton(cx.new(|cx| Buffer::local(message_text.clone(), cx)), cx)
});
let editor = cx.new(|cx| {
let mut editor = Editor::new(
editor::EditorMode::AutoHeight { max_lines: 8 },
buffer,
None,
window,
cx,
);
let editor = crate::message_editor::create_editor(
self.workspace.clone(),
self.context_store.downgrade(),
self.thread_store.downgrade(),
window,
cx,
);
editor.update(cx, |editor, cx| {
editor.set_text(message_text.clone(), window, cx);
editor.focus_handle(cx).focus(window);
editor.move_to_end(&editor::actions::MoveToEnd, window, cx);
editor
});
let subscription = cx.subscribe(&editor, |this, _, event, cx| match event {
let buffer_edited_subscription = cx.subscribe(&editor, |this, _, event, cx| match event {
EditorEvent::BufferEdited => {
this.update_editing_message_token_count(true, cx);
}
_ => {}
});
let context_picker_menu_handle = PopoverMenuHandle::default();
let context_strip = cx.new(|cx| {
ContextStrip::new(
self.context_store.clone(),
self.workspace.clone(),
Some(self.thread_store.downgrade()),
context_picker_menu_handle.clone(),
SuggestContextKind::File,
window,
cx,
)
});
let context_strip_subscription =
cx.subscribe_in(&context_strip, window, Self::handle_context_strip_event);
self.editing_message = Some((
message_id,
EditMessageState {
EditingMessageState {
editor: editor.clone(),
context_strip,
context_picker_menu_handle,
last_estimated_token_count: None,
_subscription: subscription,
_subscriptions: [buffer_edited_subscription, context_strip_subscription],
_update_token_count_task: None,
},
));
@@ -1271,6 +1299,26 @@ impl ActiveThread {
cx.notify();
}
fn handle_context_strip_event(
&mut self,
_context_strip: &Entity<ContextStrip>,
event: &ContextStripEvent,
window: &mut Window,
cx: &mut Context<Self>,
) {
if let Some((_, state)) = self.editing_message.as_ref() {
match event {
ContextStripEvent::PickerDismissed
| ContextStripEvent::BlurredEmpty
| ContextStripEvent::BlurredDown => {
let editor_focus_handle = state.editor.focus_handle(cx);
window.focus(&editor_focus_handle);
}
ContextStripEvent::BlurredUp => {}
}
}
}
fn update_editing_message_token_count(&mut self, debounce: bool, cx: &mut Context<Self>) {
let Some((message_id, state)) = self.editing_message.as_mut() else {
return;
@@ -1357,6 +1405,68 @@ impl ActiveThread {
}));
}
fn toggle_context_picker(
&mut self,
_: &crate::ToggleContextPicker,
window: &mut Window,
cx: &mut Context<Self>,
) {
if let Some((_, state)) = self.editing_message.as_mut() {
let handle = state.context_picker_menu_handle.clone();
window.defer(cx, move |window, cx| {
handle.toggle(window, cx);
});
}
}
fn remove_all_context(
&mut self,
_: &crate::RemoveAllContext,
_window: &mut Window,
cx: &mut Context<Self>,
) {
self.context_store.update(cx, |store, _cx| store.clear());
cx.notify();
}
fn move_up(&mut self, _: &MoveUp, window: &mut Window, cx: &mut Context<Self>) {
if let Some((_, state)) = self.editing_message.as_mut() {
if state.context_picker_menu_handle.is_deployed() {
cx.propagate();
} else {
state.context_strip.focus_handle(cx).focus(window);
}
}
}
fn paste(&mut self, _: &Paste, _window: &mut Window, cx: &mut Context<Self>) {
let images = cx
.read_from_clipboard()
.map(|item| {
item.into_entries()
.filter_map(|entry| {
if let ClipboardEntry::Image(image) = entry {
Some(image)
} else {
None
}
})
.collect::<Vec<_>>()
})
.unwrap_or_default();
if images.is_empty() {
return;
}
cx.stop_propagation();
self.context_store.update(cx, |store, cx| {
for image in images {
store.add_image_instance(Arc::new(image), cx);
}
});
}
fn cancel_editing_message(&mut self, _: &menu::Cancel, _: &mut Window, cx: &mut Context<Self>) {
self.editing_message.take();
cx.notify();
@@ -1371,21 +1481,11 @@ impl ActiveThread {
let Some((message_id, state)) = self.editing_message.take() else {
return;
};
let edited_text = state.editor.read(cx).text(cx);
let thread_model = self.thread.update(cx, |thread, cx| {
thread.edit_message(
message_id,
Role::User,
vec![MessageSegment::Text(edited_text)],
cx,
);
for message_id in self.messages_after(message_id) {
thread.delete_message(*message_id, cx);
}
thread.get_or_init_configured_model(cx)
});
let Some(model) = thread_model else {
let Some(model) = self
.thread
.update(cx, |thread, cx| thread.get_or_init_configured_model(cx))
else {
return;
};
@@ -1394,11 +1494,45 @@ impl ActiveThread {
return;
}
self.thread.update(cx, |thread, cx| {
thread.advance_prompt_id();
thread.send_to_model(model.model, Some(window.window_handle()), cx);
});
cx.notify();
let edited_text = state.editor.read(cx).text(cx);
let new_context = self
.context_store
.read(cx)
.new_context_for_thread(self.thread.read(cx), Some(message_id));
let project = self.thread.read(cx).project().clone();
let prompt_store = self.thread_store.read(cx).prompt_store().clone();
let load_context_task =
crate::context::load_context(new_context, &project, &prompt_store, cx);
self._load_edited_message_context_task =
Some(cx.spawn_in(window, async move |this, cx| {
let context = load_context_task.await;
let _ = this
.update_in(cx, |this, window, cx| {
this.thread.update(cx, |thread, cx| {
thread.edit_message(
message_id,
Role::User,
vec![MessageSegment::Text(edited_text)],
Some(context.loaded_context),
cx,
);
for message_id in this.messages_after(message_id) {
thread.delete_message(*message_id, cx);
}
});
this.thread.update(cx, |thread, cx| {
thread.advance_prompt_id();
thread.send_to_model(model.model, Some(window.window_handle()), cx);
});
this._load_edited_message_context_task = None;
cx.notify();
})
.log_err();
}));
}
fn messages_after(&self, message_id: MessageId) -> &[MessageId] {
@@ -1519,6 +1653,53 @@ impl ActiveThread {
}
}
fn render_edit_message_editor(
&self,
state: &EditingMessageState,
window: &mut Window,
cx: &Context<Self>,
) -> impl IntoElement {
let settings = ThemeSettings::get_global(cx);
let font_size = TextSize::Small.rems(cx);
let line_height = font_size.to_pixels(window.rem_size()) * 1.75;
let colors = cx.theme().colors();
let text_style = TextStyle {
color: colors.text,
font_family: settings.buffer_font.family.clone(),
font_fallbacks: settings.buffer_font.fallbacks.clone(),
font_features: settings.buffer_font.features.clone(),
font_size: font_size.into(),
line_height: line_height.into(),
..Default::default()
};
v_flex()
.key_context("EditMessageEditor")
.on_action(cx.listener(Self::toggle_context_picker))
.on_action(cx.listener(Self::remove_all_context))
.on_action(cx.listener(Self::move_up))
.on_action(cx.listener(Self::cancel_editing_message))
.on_action(cx.listener(Self::confirm_editing_message))
.capture_action(cx.listener(Self::paste))
.min_h_6()
.flex_grow()
.w_full()
.gap_2()
.child(EditorElement::new(
&state.editor,
EditorStyle {
background: colors.editor_background,
local_player: cx.theme().players().local(),
text: text_style,
syntax: cx.theme().syntax().clone(),
..Default::default()
},
))
.child(state.context_strip.clone())
}
fn render_message(&self, ix: usize, window: &mut Window, cx: &mut Context<Self>) -> AnyElement {
let message_id = self.messages[ix];
let Some(message) = self.thread.read(cx).message(message_id) else {
@@ -1551,11 +1732,11 @@ impl ActiveThread {
let generating_label = (is_generating && is_last_message)
.then(|| AnimatedLabel::new("Generating").size(LabelSize::Small));
let edit_message_editor = self
let editing_message_state = self
.editing_message
.as_ref()
.filter(|(id, _)| *id == message_id)
.map(|(_, state)| state.editor.clone());
.map(|(_, state)| state);
let colors = cx.theme().colors();
let editor_bg_color = colors.editor_background;
@@ -1690,77 +1871,43 @@ impl ActiveThread {
let has_content = !message_is_empty || !added_context.is_empty();
let message_content = has_content.then(|| {
v_flex()
.w_full()
.gap_1()
.when(!message_is_empty, |parent| {
parent.child(
if let Some(edit_message_editor) = edit_message_editor.clone() {
let settings = ThemeSettings::get_global(cx);
let font_size = TextSize::Small.rems(cx);
let line_height = font_size.to_pixels(window.rem_size()) * 1.75;
let text_style = TextStyle {
color: cx.theme().colors().text,
font_family: settings.buffer_font.family.clone(),
font_fallbacks: settings.buffer_font.fallbacks.clone(),
font_features: settings.buffer_font.features.clone(),
font_size: font_size.into(),
line_height: line_height.into(),
..Default::default()
};
div()
.key_context("EditMessageEditor")
.on_action(cx.listener(Self::cancel_editing_message))
.on_action(cx.listener(Self::confirm_editing_message))
.min_h_6()
.flex_grow()
.w_full()
.child(EditorElement::new(
&edit_message_editor,
EditorStyle {
background: colors.editor_background,
local_player: cx.theme().players().local(),
text: text_style,
syntax: cx.theme().syntax().clone(),
..Default::default()
},
))
.into_any()
} else {
div()
.min_h_6()
.child(self.render_message_content(
message_id,
rendered_message,
has_tool_uses,
workspace.clone(),
window,
cx,
))
.into_any()
},
)
})
.when(!added_context.is_empty(), |parent| {
parent.child(h_flex().flex_wrap().gap_1().children(
added_context.into_iter().map(|added_context| {
let context = added_context.handle.clone();
ContextPill::added(added_context, false, false, None).on_click(Rc::new(
cx.listener({
let workspace = workspace.clone();
move |_, _, window, cx| {
if let Some(workspace) = workspace.upgrade() {
open_context(&context, workspace, window, cx);
cx.notify();
if let Some(state) = editing_message_state.as_ref() {
self.render_edit_message_editor(state, window, cx)
.into_any_element()
} else {
v_flex()
.w_full()
.gap_1()
.when(!message_is_empty, |parent| {
parent.child(div().min_h_6().child(self.render_message_content(
message_id,
rendered_message,
has_tool_uses,
workspace.clone(),
window,
cx,
)))
})
.when(!added_context.is_empty(), |parent| {
parent.child(h_flex().flex_wrap().gap_1().children(
added_context.into_iter().map(|added_context| {
let context = added_context.handle.clone();
ContextPill::added(added_context, false, false, None).on_click(
Rc::new(cx.listener({
let workspace = workspace.clone();
move |_, _, window, cx| {
if let Some(workspace) = workspace.upgrade() {
open_context(&context, workspace, window, cx);
cx.notify();
}
}
}
}),
))
}),
))
})
})),
)
}),
))
})
.into_any_element()
}
});
let styled_message = match message.role {
@@ -1785,8 +1932,8 @@ impl ActiveThread {
.p_2p5()
.gap_1()
.children(message_content)
.when_some(edit_message_editor.clone(), |this, edit_editor| {
let edit_editor_clone = edit_editor.clone();
.when_some(editing_message_state, |this, state| {
let focus_handle = state.editor.focus_handle(cx).clone();
this.w_full().justify_between().child(
h_flex()
.gap_0p5()
@@ -1797,16 +1944,17 @@ impl ActiveThread {
)
.shape(ui::IconButtonShape::Square)
.icon_color(Color::Error)
.tooltip(move |window, cx| {
let focus_handle =
edit_editor_clone.focus_handle(cx);
Tooltip::for_action_in(
"Cancel Edit",
&menu::Cancel,
&focus_handle,
window,
cx,
)
.tooltip({
let focus_handle = focus_handle.clone();
move |window, cx| {
Tooltip::for_action_in(
"Cancel Edit",
&menu::Cancel,
&focus_handle,
window,
cx,
)
}
})
.on_click(cx.listener(Self::handle_cancel_click)),
)
@@ -1815,18 +1963,20 @@ impl ActiveThread {
"confirm-edit-message",
IconName::Check,
)
.disabled(edit_editor.read(cx).is_empty(cx))
.disabled(state.editor.read(cx).is_empty(cx))
.shape(ui::IconButtonShape::Square)
.icon_color(Color::Success)
.tooltip(move |window, cx| {
let focus_handle = edit_editor.focus_handle(cx);
Tooltip::for_action_in(
"Regenerate",
&menu::Confirm,
&focus_handle,
window,
cx,
)
.tooltip({
let focus_handle = focus_handle.clone();
move |window, cx| {
Tooltip::for_action_in(
"Regenerate",
&menu::Confirm,
&focus_handle,
window,
cx,
)
}
})
.on_click(
cx.listener(Self::handle_regenerate_click),
@@ -1835,7 +1985,7 @@ impl ActiveThread {
)
}),
)
.when(edit_message_editor.is_none(), |this| {
.when(editing_message_state.is_none(), |this| {
this.tooltip(Tooltip::text("Click To Edit"))
})
.on_click(cx.listener({

View File

@@ -6,6 +6,7 @@ mod assistant_panel;
mod buffer_codegen;
mod context;
mod context_picker;
mod context_server_configuration;
mod context_store;
mod context_strip;
mod history_store;
@@ -30,6 +31,7 @@ use command_palette_hooks::CommandPaletteFilter;
use feature_flags::{Assistant2FeatureFlag, FeatureFlagAppExt};
use fs::Fs;
use gpui::{App, actions, impl_actions};
use language::LanguageRegistry;
use prompt_store::PromptBuilder;
use schemars::JsonSchema;
use serde::Deserialize;
@@ -44,6 +46,8 @@ pub use crate::inline_assistant::InlineAssistant;
pub use crate::thread::{Message, MessageSegment, Thread, ThreadEvent};
pub use crate::thread_store::ThreadStore;
pub use agent_diff::{AgentDiff, AgentDiffToolbar};
pub use context_store::ContextStore;
pub use ui::{all_agent_previews, get_agent_preview};
actions!(
agent,
@@ -60,7 +64,6 @@ actions!(
AddContextServer,
RemoveSelectedThread,
Chat,
ChatMode,
CycleNextInlineAssist,
CyclePreviousInlineAssist,
FocusUp,
@@ -107,11 +110,13 @@ pub fn init(
fs: Arc<dyn Fs>,
client: Arc<Client>,
prompt_builder: Arc<PromptBuilder>,
language_registry: Arc<LanguageRegistry>,
cx: &mut App,
) {
AssistantSettings::register(cx);
thread_store::init(cx);
assistant_panel::init(cx);
context_server_configuration::init(language_registry, cx);
inline_assistant::init(
fs.clone(),

View File

@@ -1,16 +1,18 @@
mod add_context_server_modal;
mod configure_context_server_modal;
mod manage_profiles_modal;
mod tool_picker;
use std::sync::Arc;
use std::{sync::Arc, time::Duration};
use assistant_settings::AssistantSettings;
use assistant_tool::{ToolSource, ToolWorkingSet};
use collections::HashMap;
use context_server::manager::ContextServerManager;
use context_server::manager::{ContextServer, ContextServerManager, ContextServerStatus};
use fs::Fs;
use gpui::{
Action, AnyView, App, Entity, EventEmitter, FocusHandle, Focusable, ScrollHandle, Subscription,
Action, Animation, AnimationExt as _, AnyView, App, Entity, EventEmitter, FocusHandle,
Focusable, ScrollHandle, Subscription, pulsating_between,
};
use language_model::{LanguageModelProvider, LanguageModelProviderId, LanguageModelRegistry};
use settings::{Settings, update_settings_file};
@@ -22,6 +24,7 @@ use util::ResultExt as _;
use zed_actions::ExtensionCategoryFilter;
pub(crate) use add_context_server_modal::AddContextServerModal;
pub(crate) use configure_context_server_modal::ConfigureContextServerModal;
pub(crate) use manage_profiles_modal::ManageProfilesModal;
use crate::AddContextServer;
@@ -254,10 +257,12 @@ impl AssistantConfiguration {
)
}
fn render_context_servers_section(&mut self, cx: &mut Context<Self>) -> impl IntoElement {
fn render_context_servers_section(
&mut self,
window: &mut Window,
cx: &mut Context<Self>,
) -> impl IntoElement {
let context_servers = self.context_server_manager.read(cx).all_servers().clone();
let tools_by_source = self.tools.read(cx).tools_by_source(cx);
let empty = Vec::new();
const SUBHEADING: &str = "Connect to context servers via the Model Context Protocol either via Zed extensions or directly.";
@@ -272,136 +277,11 @@ impl AssistantConfiguration {
.child(Headline::new("Model Context Protocol (MCP) Servers"))
.child(Label::new(SUBHEADING).color(Color::Muted)),
)
.children(context_servers.into_iter().map(|context_server| {
let is_running = context_server.client().is_some();
let are_tools_expanded = self
.expanded_context_server_tools
.get(&context_server.id())
.copied()
.unwrap_or_default();
let tools = tools_by_source
.get(&ToolSource::ContextServer {
id: context_server.id().into(),
})
.unwrap_or_else(|| &empty);
let tool_count = tools.len();
v_flex()
.id(SharedString::from(context_server.id()))
.border_1()
.rounded_md()
.border_color(cx.theme().colors().border)
.bg(cx.theme().colors().background.opacity(0.25))
.child(
h_flex()
.p_1()
.justify_between()
.when(are_tools_expanded && tool_count > 1, |element| {
element
.border_b_1()
.border_color(cx.theme().colors().border)
})
.child(
h_flex()
.gap_2()
.child(
Disclosure::new("tool-list-disclosure", are_tools_expanded)
.disabled(tool_count == 0)
.on_click(cx.listener({
let context_server_id = context_server.id();
move |this, _event, _window, _cx| {
let is_open = this
.expanded_context_server_tools
.entry(context_server_id.clone())
.or_insert(false);
*is_open = !*is_open;
}
})),
)
.child(Indicator::dot().color(if is_running {
Color::Success
} else {
Color::Error
}))
.child(Label::new(context_server.id()))
.child(
Label::new(format!("{tool_count} tools"))
.color(Color::Muted)
.size(LabelSize::Small),
),
)
.child(
Switch::new("context-server-switch", is_running.into())
.color(SwitchColor::Accent)
.on_click({
let context_server_manager =
self.context_server_manager.clone();
let context_server = context_server.clone();
move |state, _window, cx| match state {
ToggleState::Unselected
| ToggleState::Indeterminate => {
context_server_manager.update(cx, |this, cx| {
this.stop_server(context_server.clone(), cx)
.log_err();
});
}
ToggleState::Selected => {
cx.spawn({
let context_server_manager =
context_server_manager.clone();
let context_server = context_server.clone();
async move |cx| {
if let Some(start_server_task) =
context_server_manager
.update(cx, |this, cx| {
this.start_server(
context_server,
cx,
)
})
.log_err()
{
start_server_task.await.log_err();
}
}
})
.detach();
}
}
}),
),
)
.map(|parent| {
if !are_tools_expanded {
return parent;
}
parent.child(v_flex().py_1p5().px_1().gap_1().children(
tools.into_iter().enumerate().map(|(ix, tool)| {
h_flex()
.id(("tool-item", ix))
.px_1()
.gap_2()
.justify_between()
.hover(|style| style.bg(cx.theme().colors().element_hover))
.rounded_sm()
.child(
Label::new(tool.name())
.buffer_font(cx)
.size(LabelSize::Small),
)
.child(
Icon::new(IconName::Info)
.size(IconSize::Small)
.color(Color::Ignored),
)
.tooltip(Tooltip::text(tool.description()))
}),
))
})
}))
.children(
context_servers
.into_iter()
.map(|context_server| self.render_context_server(context_server, window, cx)),
)
.child(
h_flex()
.justify_between()
@@ -429,7 +309,7 @@ impl AssistantConfiguration {
.style(ButtonStyle::Filled)
.layer(ElevationIndex::ModalSurface)
.full_width()
.icon(IconName::DatabaseZap)
.icon(IconName::Hammer)
.icon_size(IconSize::Small)
.icon_position(IconPosition::Start)
.on_click(|_event, window, cx| {
@@ -447,10 +327,214 @@ impl AssistantConfiguration {
),
)
}
fn render_context_server(
&self,
context_server: Arc<ContextServer>,
window: &mut Window,
cx: &mut Context<Self>,
) -> impl use<> + IntoElement {
let tools_by_source = self.tools.read(cx).tools_by_source(cx);
let server_status = self
.context_server_manager
.read(cx)
.status_for_server(&context_server.id());
let is_running = matches!(server_status, Some(ContextServerStatus::Running));
let error = if let Some(ContextServerStatus::Error(error)) = server_status.clone() {
Some(error)
} else {
None
};
let are_tools_expanded = self
.expanded_context_server_tools
.get(&context_server.id())
.copied()
.unwrap_or_default();
let tools = tools_by_source
.get(&ToolSource::ContextServer {
id: context_server.id().into(),
})
.map_or([].as_slice(), |tools| tools.as_slice());
let tool_count = tools.len();
let border_color = cx.theme().colors().border.opacity(0.6);
v_flex()
.id(SharedString::from(context_server.id()))
.border_1()
.rounded_md()
.border_color(border_color)
.bg(cx.theme().colors().background.opacity(0.2))
.overflow_hidden()
.child(
h_flex()
.p_1()
.justify_between()
.when(
error.is_some() || are_tools_expanded && tool_count > 1,
|element| element.border_b_1().border_color(border_color),
)
.child(
h_flex()
.gap_1p5()
.child(
Disclosure::new(
"tool-list-disclosure",
are_tools_expanded || error.is_some(),
)
.disabled(tool_count == 0)
.on_click(cx.listener({
let context_server_id = context_server.id();
move |this, _event, _window, _cx| {
let is_open = this
.expanded_context_server_tools
.entry(context_server_id.clone())
.or_insert(false);
*is_open = !*is_open;
}
})),
)
.child(match server_status {
Some(ContextServerStatus::Starting) => {
let color = Color::Success.color(cx);
Indicator::dot()
.color(Color::Success)
.with_animation(
SharedString::from(format!(
"{}-starting",
context_server.id(),
)),
Animation::new(Duration::from_secs(2))
.repeat()
.with_easing(pulsating_between(0.4, 1.)),
move |this, delta| {
this.color(color.alpha(delta).into())
},
)
.into_any_element()
}
Some(ContextServerStatus::Running) => {
Indicator::dot().color(Color::Success).into_any_element()
}
Some(ContextServerStatus::Error(_)) => {
Indicator::dot().color(Color::Error).into_any_element()
}
None => Indicator::dot().color(Color::Muted).into_any_element(),
})
.child(Label::new(context_server.id()).ml_0p5())
.when(is_running, |this| {
this.child(
Label::new(if tool_count == 1 {
SharedString::from("1 tool")
} else {
SharedString::from(format!("{} tools", tool_count))
})
.color(Color::Muted)
.size(LabelSize::Small),
)
}),
)
.child(
Switch::new("context-server-switch", is_running.into())
.color(SwitchColor::Accent)
.on_click({
let context_server_manager = self.context_server_manager.clone();
let context_server = context_server.clone();
move |state, _window, cx| match state {
ToggleState::Unselected | ToggleState::Indeterminate => {
context_server_manager.update(cx, |this, cx| {
this.stop_server(context_server.clone(), cx).log_err();
});
}
ToggleState::Selected => {
cx.spawn({
let context_server_manager =
context_server_manager.clone();
let context_server = context_server.clone();
async move |cx| {
if let Some(start_server_task) =
context_server_manager
.update(cx, |this, cx| {
this.start_server(context_server, cx)
})
.log_err()
{
start_server_task.await.log_err();
}
}
})
.detach();
}
}
}),
),
)
.map(|parent| {
if let Some(error) = error {
return parent.child(
h_flex()
.p_2()
.gap_2()
.items_start()
.child(
h_flex()
.flex_none()
.h(window.line_height() / 1.6_f32)
.justify_center()
.child(
Icon::new(IconName::XCircle)
.size(IconSize::XSmall)
.color(Color::Error),
),
)
.child(
div().w_full().child(
Label::new(error)
.buffer_font(cx)
.color(Color::Muted)
.size(LabelSize::Small),
),
),
);
}
if !are_tools_expanded || tools.is_empty() {
return parent;
}
parent.child(v_flex().py_1p5().px_1().gap_1().children(
tools.into_iter().enumerate().map(|(ix, tool)| {
h_flex()
.id(("tool-item", ix))
.px_1()
.gap_2()
.justify_between()
.hover(|style| style.bg(cx.theme().colors().element_hover))
.rounded_sm()
.child(
Label::new(tool.name())
.buffer_font(cx)
.size(LabelSize::Small),
)
.child(
Icon::new(IconName::Info)
.size(IconSize::Small)
.color(Color::Ignored),
)
.tooltip(Tooltip::text(tool.description()))
}),
))
})
}
}
impl Render for AssistantConfiguration {
fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
fn render(&mut self, window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
v_flex()
.id("assistant-configuration")
.key_context("AgentConfiguration")
@@ -467,7 +551,7 @@ impl Render for AssistantConfiguration {
.overflow_y_scroll()
.child(self.render_command_permission(cx))
.child(Divider::horizontal().color(DividerColor::Border))
.child(self.render_context_servers_section(cx))
.child(self.render_context_servers_section(window, cx))
.child(Divider::horizontal().color(DividerColor::Border))
.child(self.render_provider_configuration_section(cx)),
)

View File

@@ -0,0 +1,443 @@
use std::{
sync::{Arc, Mutex},
time::Duration,
};
use anyhow::Context as _;
use context_server::manager::{ContextServerManager, ContextServerStatus};
use editor::{Editor, EditorElement, EditorStyle};
use extension::ContextServerConfiguration;
use gpui::{
Animation, AnimationExt, App, DismissEvent, Entity, EventEmitter, FocusHandle, Focusable, Task,
TextStyle, TextStyleRefinement, Transformation, UnderlineStyle, WeakEntity, percentage,
};
use language::{Language, LanguageRegistry};
use markdown::{Markdown, MarkdownElement, MarkdownStyle};
use notifications::status_toast::{StatusToast, ToastIcon};
use settings::{Settings as _, update_settings_file};
use theme::ThemeSettings;
use ui::{KeyBinding, Modal, ModalFooter, ModalHeader, Section, prelude::*};
use util::ResultExt;
use workspace::{ModalView, Workspace};
pub(crate) struct ConfigureContextServerModal {
workspace: WeakEntity<Workspace>,
context_servers_to_setup: Vec<ConfigureContextServer>,
context_server_manager: Entity<ContextServerManager>,
}
struct ConfigureContextServer {
id: Arc<str>,
installation_instructions: Entity<markdown::Markdown>,
settings_validator: Option<jsonschema::Validator>,
settings_editor: Entity<Editor>,
last_error: Option<SharedString>,
waiting_for_context_server: bool,
}
impl ConfigureContextServerModal {
pub fn new(
configurations: impl Iterator<Item = (Arc<str>, ContextServerConfiguration)>,
jsonc_language: Option<Arc<Language>>,
context_server_manager: Entity<ContextServerManager>,
language_registry: Arc<LanguageRegistry>,
workspace: WeakEntity<Workspace>,
window: &mut Window,
cx: &mut App,
) -> Option<Self> {
let context_servers_to_setup = configurations
.map(|(id, manifest)| {
let jsonc_language = jsonc_language.clone();
let settings_validator = jsonschema::validator_for(&manifest.settings_schema)
.context("Failed to load JSON schema for context server settings")
.log_err();
ConfigureContextServer {
id: id.clone(),
installation_instructions: cx.new(|cx| {
Markdown::new(
manifest.installation_instructions.clone().into(),
Some(language_registry.clone()),
None,
cx,
)
}),
settings_validator,
settings_editor: cx.new(|cx| {
let mut editor = Editor::auto_height(16, window, cx);
editor.set_text(manifest.default_settings.trim(), window, cx);
if let Some(buffer) = editor.buffer().read(cx).as_singleton() {
buffer.update(cx, |buffer, cx| buffer.set_language(jsonc_language, cx))
}
editor
}),
waiting_for_context_server: false,
last_error: None,
}
})
.collect::<Vec<_>>();
if context_servers_to_setup.is_empty() {
return None;
}
Some(Self {
workspace,
context_servers_to_setup,
context_server_manager,
})
}
}
impl ConfigureContextServerModal {
pub fn confirm(&mut self, cx: &mut Context<Self>) {
if self.context_servers_to_setup.is_empty() {
return;
}
let Some(workspace) = self.workspace.upgrade() else {
return;
};
let configuration = &mut self.context_servers_to_setup[0];
if configuration.waiting_for_context_server {
return;
}
let settings_value = match serde_json_lenient::from_str::<serde_json::Value>(
&configuration.settings_editor.read(cx).text(cx),
) {
Ok(value) => value,
Err(error) => {
configuration.last_error = Some(error.to_string().into());
cx.notify();
return;
}
};
if let Some(validator) = configuration.settings_validator.as_ref() {
if let Err(error) = validator.validate(&settings_value) {
configuration.last_error = Some(error.to_string().into());
cx.notify();
return;
}
}
let id = configuration.id.clone();
let settings_changed = context_server::ContextServerSettings::get_global(cx)
.context_servers
.get(&id)
.map_or(true, |config| {
config.settings.as_ref() != Some(&settings_value)
});
let is_running = self.context_server_manager.read(cx).status_for_server(&id)
== Some(ContextServerStatus::Running);
if !settings_changed && is_running {
self.complete_setup(id, cx);
return;
}
configuration.waiting_for_context_server = true;
let task = wait_for_context_server(&self.context_server_manager, id.clone(), cx);
cx.spawn({
let id = id.clone();
async move |this, cx| {
let result = task.await;
this.update(cx, |this, cx| match result {
Ok(_) => {
this.complete_setup(id, cx);
}
Err(err) => {
if let Some(configuration) = this.context_servers_to_setup.get_mut(0) {
configuration.last_error = Some(err.into());
configuration.waiting_for_context_server = false;
} else {
this.dismiss(cx);
}
cx.notify();
}
})
}
})
.detach();
// When we write the settings to the file, the context server will be restarted.
update_settings_file::<context_server::ContextServerSettings>(
workspace.read(cx).app_state().fs.clone(),
cx,
{
let id = id.clone();
|settings, _| {
if let Some(server_config) = settings.context_servers.get_mut(&id) {
server_config.settings = Some(settings_value);
} else {
settings.context_servers.insert(
id,
context_server::ServerConfig {
settings: Some(settings_value),
..Default::default()
},
);
}
}
},
);
}
fn complete_setup(&mut self, id: Arc<str>, cx: &mut Context<Self>) {
self.context_servers_to_setup.remove(0);
cx.notify();
if !self.context_servers_to_setup.is_empty() {
return;
}
self.workspace
.update(cx, {
|workspace, cx| {
let status_toast = StatusToast::new(
format!("{} configured successfully.", id),
cx,
|this, _cx| {
this.icon(ToastIcon::new(IconName::Hammer).color(Color::Muted))
.action("Dismiss", |_, _| {})
},
);
workspace.toggle_status_toast(status_toast, cx);
}
})
.log_err();
self.dismiss(cx);
}
fn dismiss(&self, cx: &mut Context<Self>) {
cx.emit(DismissEvent);
}
}
fn wait_for_context_server(
context_server_manager: &Entity<ContextServerManager>,
context_server_id: Arc<str>,
cx: &mut App,
) -> Task<Result<(), Arc<str>>> {
let (tx, rx) = futures::channel::oneshot::channel();
let tx = Arc::new(Mutex::new(Some(tx)));
let subscription = cx.subscribe(context_server_manager, move |_, event, _cx| match event {
context_server::manager::Event::ServerStatusChanged { server_id, status } => match status {
Some(ContextServerStatus::Running) => {
if server_id == &context_server_id {
if let Some(tx) = tx.lock().unwrap().take() {
let _ = tx.send(Ok(()));
}
}
}
Some(ContextServerStatus::Error(error)) => {
if server_id == &context_server_id {
if let Some(tx) = tx.lock().unwrap().take() {
let _ = tx.send(Err(error.clone()));
}
}
}
_ => {}
},
});
cx.spawn(async move |_cx| {
let result = rx.await.unwrap();
drop(subscription);
result
})
}
impl Render for ConfigureContextServerModal {
fn render(&mut self, window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
let Some(configuration) = self.context_servers_to_setup.first() else {
return div().child("No context servers to setup");
};
let focus_handle = self.focus_handle(cx);
div()
.elevation_3(cx)
.w(rems(34.))
.key_context("ConfigureContextServerModal")
.on_action(cx.listener(|this, _: &menu::Confirm, _window, cx| this.confirm(cx)))
.on_action(cx.listener(|this, _: &menu::Cancel, _window, cx| this.dismiss(cx)))
.capture_any_mouse_down(cx.listener(|this, _, window, cx| {
this.focus_handle(cx).focus(window);
}))
.child(
Modal::new("configure-context-server", None)
.header(ModalHeader::new().headline(format!("Configure {}", configuration.id)))
.section(
Section::new()
.child(div().pb_2().text_sm().child(MarkdownElement::new(
configuration.installation_instructions.clone(),
default_markdown_style(window, cx),
)))
.child(
div()
.p_2()
.rounded_md()
.border_1()
.border_color(cx.theme().colors().border_variant)
.bg(cx.theme().colors().editor_background)
.gap_1()
.child({
let settings = ThemeSettings::get_global(cx);
let text_style = TextStyle {
color: cx.theme().colors().text,
font_family: settings.buffer_font.family.clone(),
font_fallbacks: settings.buffer_font.fallbacks.clone(),
font_size: settings.buffer_font_size(cx).into(),
font_weight: settings.buffer_font.weight,
line_height: relative(
settings.buffer_line_height.value(),
),
..Default::default()
};
EditorElement::new(
&configuration.settings_editor,
EditorStyle {
background: cx.theme().colors().editor_background,
local_player: cx.theme().players().local(),
text: text_style,
syntax: cx.theme().syntax().clone(),
..Default::default()
},
)
})
.when_some(configuration.last_error.clone(), |this, error| {
this.child(
h_flex()
.gap_2()
.px_2()
.py_1()
.child(
Icon::new(IconName::Warning)
.size(IconSize::XSmall)
.color(Color::Warning),
)
.child(
div().w_full().child(
Label::new(error)
.size(LabelSize::Small)
.color(Color::Muted),
),
),
)
}),
)
.when(configuration.waiting_for_context_server, |this| {
this.child(
h_flex()
.gap_1p5()
.child(
Icon::new(IconName::ArrowCircle)
.size(IconSize::XSmall)
.color(Color::Info)
.with_animation(
"arrow-circle",
Animation::new(Duration::from_secs(2)).repeat(),
|icon, delta| {
icon.transform(Transformation::rotate(
percentage(delta),
))
},
)
.into_any_element(),
)
.child(
Label::new("Waiting for Context Server")
.size(LabelSize::Small)
.color(Color::Muted),
),
)
}),
)
.footer(
ModalFooter::new().end_slot(
h_flex()
.gap_1()
.child(
Button::new("cancel", "Cancel")
.key_binding(
KeyBinding::for_action_in(
&menu::Cancel,
&focus_handle,
window,
cx,
)
.map(|kb| kb.size(rems_from_px(12.))),
)
.on_click(cx.listener(|this, _event, _window, cx| {
this.dismiss(cx)
})),
)
.child(
Button::new("configure-server", "Configure MCP")
.disabled(configuration.waiting_for_context_server)
.key_binding(
KeyBinding::for_action_in(
&menu::Confirm,
&focus_handle,
window,
cx,
)
.map(|kb| kb.size(rems_from_px(12.))),
)
.on_click(cx.listener(|this, _event, _window, cx| {
this.confirm(cx)
})),
),
),
),
)
}
}
pub(crate) fn default_markdown_style(window: &Window, cx: &App) -> MarkdownStyle {
let theme_settings = ThemeSettings::get_global(cx);
let colors = cx.theme().colors();
let mut text_style = window.text_style();
text_style.refine(&TextStyleRefinement {
font_family: Some(theme_settings.ui_font.family.clone()),
font_fallbacks: theme_settings.ui_font.fallbacks.clone(),
font_features: Some(theme_settings.ui_font.features.clone()),
font_size: Some(TextSize::XSmall.rems(cx).into()),
color: Some(colors.text_muted),
..Default::default()
});
MarkdownStyle {
base_text_style: text_style.clone(),
selection_background_color: cx.theme().players().local().selection,
link: TextStyleRefinement {
background_color: Some(colors.editor_foreground.opacity(0.025)),
underline: Some(UnderlineStyle {
color: Some(colors.text_accent.opacity(0.5)),
thickness: px(1.),
..Default::default()
}),
..Default::default()
},
..Default::default()
}
}
impl ModalView for ConfigureContextServerModal {}
impl EventEmitter<DismissEvent> for ConfigureContextServerModal {}
impl Focusable for ConfigureContextServerModal {
fn focus_handle(&self, cx: &App) -> FocusHandle {
if let Some(current) = self.context_servers_to_setup.first() {
current.settings_editor.read(cx).focus_handle(cx)
} else {
cx.focus_handle()
}
}
}

View File

@@ -401,11 +401,12 @@ impl AssistantPanel {
}
});
let thread_id = thread.read(cx).id().clone();
let history_store = cx.new(|cx| {
HistoryStore::new(
thread_store.clone(),
context_store.clone(),
[RecentEntry::Thread(thread.clone())],
[RecentEntry::Thread(thread_id, thread.clone())],
cx,
)
});
@@ -423,6 +424,7 @@ impl AssistantPanel {
ActiveThread::new(
thread.clone(),
thread_store.clone(),
message_editor_context_store.clone(),
language_registry.clone(),
workspace.clone(),
window,
@@ -456,7 +458,8 @@ impl AssistantPanel {
for entry in recently_opened.iter() {
let summary = entry.summary(cx);
menu = menu.entry_with_end_slot(
menu = menu.entry_with_end_slot_on_hover(
summary,
None,
{
@@ -467,7 +470,7 @@ impl AssistantPanel {
.update(cx, {
let entry = entry.clone();
move |this, cx| match entry {
RecentEntry::Thread(thread) => {
RecentEntry::Thread(_, thread) => {
this.open_thread(thread, window, cx)
}
RecentEntry::Context(context) => {
@@ -625,7 +628,7 @@ impl AssistantPanel {
let thread_view = ActiveView::thread(thread.clone(), window, cx);
self.set_active_view(thread_view, window, cx);
let message_editor_context_store = cx.new(|_cx| {
let context_store = cx.new(|_cx| {
crate::context_store::ContextStore::new(
self.project.downgrade(),
Some(self.thread_store.downgrade()),
@@ -638,7 +641,7 @@ impl AssistantPanel {
.update(cx, |this, cx| this.open_thread(&other_thread_id, cx));
cx.spawn({
let context_store = message_editor_context_store.clone();
let context_store = context_store.clone();
async move |_panel, cx| {
let other_thread = other_thread_task.await?;
@@ -663,6 +666,7 @@ impl AssistantPanel {
ActiveThread::new(
thread.clone(),
self.thread_store.clone(),
context_store.clone(),
self.language_registry.clone(),
self.workspace.clone(),
window,
@@ -681,7 +685,7 @@ impl AssistantPanel {
MessageEditor::new(
self.fs.clone(),
self.workspace.clone(),
message_editor_context_store,
context_store,
self.prompt_store.clone(),
self.thread_store.downgrade(),
thread,
@@ -842,7 +846,7 @@ impl AssistantPanel {
) {
let thread_view = ActiveView::thread(thread.clone(), window, cx);
self.set_active_view(thread_view, window, cx);
let message_editor_context_store = cx.new(|_cx| {
let context_store = cx.new(|_cx| {
crate::context_store::ContextStore::new(
self.project.downgrade(),
Some(self.thread_store.downgrade()),
@@ -859,6 +863,7 @@ impl AssistantPanel {
ActiveThread::new(
thread.clone(),
self.thread_store.clone(),
context_store.clone(),
self.language_registry.clone(),
self.workspace.clone(),
window,
@@ -877,7 +882,7 @@ impl AssistantPanel {
MessageEditor::new(
self.fs.clone(),
self.workspace.clone(),
message_editor_context_store,
context_store,
self.prompt_store.clone(),
self.thread_store.downgrade(),
thread,
@@ -1099,7 +1104,8 @@ impl AssistantPanel {
ActiveView::Thread { thread, .. } => self.history_store.update(cx, |store, cx| {
if let Some(thread) = thread.upgrade() {
if thread.read(cx).is_empty() {
store.remove_recently_opened_entry(&RecentEntry::Thread(thread), cx);
let id = thread.read(cx).id().clone();
store.remove_recently_opened_thread(id, cx);
}
}
}),
@@ -1109,7 +1115,8 @@ impl AssistantPanel {
match &new_view {
ActiveView::Thread { thread, .. } => self.history_store.update(cx, |store, cx| {
if let Some(thread) = thread.upgrade() {
store.push_recently_opened_entry(RecentEntry::Thread(thread), cx);
let id = thread.read(cx).id().clone();
store.push_recently_opened_entry(RecentEntry::Thread(id, thread), cx);
}
}),
ActiveView::PromptEditor { context_editor, .. } => {

View File

@@ -3,11 +3,12 @@ use std::hash::{Hash, Hasher};
use std::path::PathBuf;
use std::{ops::Range, path::Path, sync::Arc};
use assistant_tool::outline;
use collections::HashSet;
use futures::future;
use futures::{FutureExt, future::Shared};
use gpui::{App, AppContext as _, Entity, SharedString, Task};
use language::Buffer;
use language::{Buffer, ParseStatus};
use language_model::{LanguageModelImage, LanguageModelRequestMessage, MessageContent};
use project::{Project, ProjectEntryId, ProjectPath, Worktree};
use prompt_store::{PromptStore, UserPromptId};
@@ -152,6 +153,7 @@ pub struct FileContext {
pub handle: FileContextHandle,
pub full_path: Arc<Path>,
pub text: SharedString,
pub is_outline: bool,
}
impl FileContextHandle {
@@ -177,14 +179,51 @@ impl FileContextHandle {
log::error!("file context missing path");
return Task::ready(None);
};
let full_path = file.full_path(cx);
let full_path: Arc<Path> = file.full_path(cx).into();
let rope = buffer_ref.as_rope().clone();
let buffer = self.buffer.clone();
cx.background_spawn(async move {
cx.spawn(async move |cx| {
// For large files, use outline instead of full content
if rope.len() > outline::AUTO_OUTLINE_SIZE {
// Wait until the buffer has been fully parsed, so we can read its outline
if let Ok(mut parse_status) =
buffer.read_with(cx, |buffer, _| buffer.parse_status())
{
while *parse_status.borrow() != ParseStatus::Idle {
parse_status.changed().await.log_err();
}
if let Ok(snapshot) = buffer.read_with(cx, |buffer, _| buffer.snapshot()) {
if let Some(outline) = snapshot.outline(None) {
let items = outline
.items
.into_iter()
.map(|item| item.to_point(&snapshot));
if let Ok(outline_text) =
outline::render_outline(items, None, 0, usize::MAX).await
{
let context = AgentContext::File(FileContext {
handle: self,
full_path,
text: outline_text.into(),
is_outline: true,
});
return Some((context, vec![buffer]));
}
}
}
}
}
// Fallback to full content if we couldn't build an outline
// (or didn't need to because the file was small enough)
let context = AgentContext::File(FileContext {
handle: self,
full_path: full_path.into(),
full_path,
text: rope.to_string().into(),
is_outline: false,
});
Some((context, vec![buffer]))
})
@@ -804,7 +843,7 @@ pub fn load_context(
text.push_str(
"\n<context>\n\
The following items were attached by the user. \
You don't need to use other tools to read them.\n\n",
They are up-to-date and don't need to be re-read.\n\n",
);
if !file_context.is_empty() {
@@ -996,3 +1035,115 @@ impl Hash for AgentContextKey {
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use gpui::TestAppContext;
use project::{FakeFs, Project};
use serde_json::json;
use settings::SettingsStore;
use util::path;
fn init_test_settings(cx: &mut TestAppContext) {
cx.update(|cx| {
let settings_store = SettingsStore::test(cx);
cx.set_global(settings_store);
language::init(cx);
Project::init_settings(cx);
});
}
// Helper to create a test project with test files
async fn create_test_project(
cx: &mut TestAppContext,
files: serde_json::Value,
) -> Entity<Project> {
let fs = FakeFs::new(cx.background_executor.clone());
fs.insert_tree(path!("/test"), files).await;
Project::test(fs, [path!("/test").as_ref()], cx).await
}
#[gpui::test]
async fn test_large_file_uses_outline(cx: &mut TestAppContext) {
init_test_settings(cx);
// Create a large file that exceeds AUTO_OUTLINE_SIZE
const LINE: &str = "Line with some text\n";
let large_content = LINE.repeat(2 * (outline::AUTO_OUTLINE_SIZE / LINE.len()));
let content_len = large_content.len();
assert!(content_len > outline::AUTO_OUTLINE_SIZE);
let file_context = file_context_for(large_content, cx).await;
assert!(
file_context.is_outline,
"Large file should use outline format"
);
assert!(
file_context.text.len() < content_len,
"Outline should be smaller than original content"
);
}
#[gpui::test]
async fn test_small_file_uses_full_content(cx: &mut TestAppContext) {
init_test_settings(cx);
let small_content = "This is a small file.\n";
let content_len = small_content.len();
assert!(content_len < outline::AUTO_OUTLINE_SIZE);
let file_context = file_context_for(small_content.to_string(), cx).await;
assert!(
!file_context.is_outline,
"Small files should not get an outline"
);
assert_eq!(file_context.text, small_content);
}
async fn file_context_for(content: String, cx: &mut TestAppContext) -> FileContext {
// Create a test project with the file
let project = create_test_project(
cx,
json!({
"file.txt": content,
}),
)
.await;
// Open the buffer
let buffer_path = project
.read_with(cx, |project, cx| project.find_project_path("file.txt", cx))
.unwrap();
let buffer = project
.update(cx, |project, cx| project.open_buffer(buffer_path, cx))
.await
.unwrap();
let context_handle = AgentContextHandle::File(FileContextHandle {
buffer: buffer.clone(),
context_id: ContextId::zero(),
});
cx.update(|cx| load_context(vec![context_handle], &project, &None, cx))
.await
.loaded_context
.contexts
.into_iter()
.find_map(|ctx| {
if let AgentContext::File(file_ctx) = ctx {
Some(file_ctx)
} else {
None
}
})
.expect("Should have found a file context")
}
}

View File

@@ -0,0 +1,120 @@
use std::sync::Arc;
use anyhow::Context as _;
use context_server::ContextServerDescriptorRegistry;
use extension::ExtensionManifest;
use language::LanguageRegistry;
use ui::prelude::*;
use util::ResultExt;
use workspace::Workspace;
use crate::{AssistantPanel, assistant_configuration::ConfigureContextServerModal};
pub(crate) fn init(language_registry: Arc<LanguageRegistry>, cx: &mut App) {
cx.observe_new(move |_: &mut Workspace, window, cx| {
let Some(window) = window else {
return;
};
if let Some(extension_events) = extension::ExtensionEvents::try_global(cx).as_ref() {
cx.subscribe_in(extension_events, window, {
let language_registry = language_registry.clone();
move |workspace, _, event, window, cx| match event {
extension::Event::ExtensionInstalled(manifest) => {
show_configure_mcp_modal(
language_registry.clone(),
manifest,
workspace,
window,
cx,
);
}
extension::Event::ConfigureExtensionRequested(manifest) => {
if !manifest.context_servers.is_empty() {
show_configure_mcp_modal(
language_registry.clone(),
manifest,
workspace,
window,
cx,
);
}
}
_ => {}
}
})
.detach();
} else {
log::info!(
"No extension events global found. Skipping context server configuration wizard"
);
}
})
.detach();
}
fn show_configure_mcp_modal(
language_registry: Arc<LanguageRegistry>,
manifest: &Arc<ExtensionManifest>,
workspace: &mut Workspace,
window: &mut Window,
cx: &mut Context<'_, Workspace>,
) {
let Some(context_server_manager) = workspace.panel::<AssistantPanel>(cx).map(|panel| {
panel
.read(cx)
.thread_store()
.read(cx)
.context_server_manager()
}) else {
return;
};
let registry = ContextServerDescriptorRegistry::global(cx).read(cx);
let project = workspace.project().clone();
let configuration_tasks = manifest
.context_servers
.keys()
.cloned()
.filter_map({
|key| {
let descriptor = registry.context_server_descriptor(&key)?;
Some(cx.spawn({
let project = project.clone();
async move |_, cx| {
descriptor
.configuration(project, &cx)
.await
.context("Failed to resolve context server configuration")
.log_err()
.flatten()
.map(|config| (key, config))
}
}))
}
})
.collect::<Vec<_>>();
let jsonc_language = language_registry.language_for_name("jsonc");
cx.spawn_in(window, async move |this, cx| {
let descriptors = futures::future::join_all(configuration_tasks).await;
let jsonc_language = jsonc_language.await.ok();
this.update_in(cx, |this, window, cx| {
let modal = ConfigureContextServerModal::new(
descriptors.into_iter().flatten(),
jsonc_language,
context_server_manager,
language_registry,
cx.entity().downgrade(),
window,
cx,
);
if let Some(modal) = modal {
this.toggle_modal(window, cx, |_, _| modal);
}
})
})
.detach();
}

View File

@@ -21,7 +21,7 @@ use crate::context::{
SymbolContextHandle, ThreadContextHandle,
};
use crate::context_strip::SuggestedContext;
use crate::thread::{Thread, ThreadId};
use crate::thread::{MessageId, Thread, ThreadId};
pub struct ContextStore {
project: WeakEntity<Project>,
@@ -54,9 +54,14 @@ impl ContextStore {
self.context_thread_ids.clear();
}
pub fn new_context_for_thread(&self, thread: &Thread) -> Vec<AgentContextHandle> {
pub fn new_context_for_thread(
&self,
thread: &Thread,
exclude_messages_from_id: Option<MessageId>,
) -> Vec<AgentContextHandle> {
let existing_context = thread
.messages()
.take_while(|message| exclude_messages_from_id.is_none_or(|id| message.id != id))
.flat_map(|message| {
message
.loaded_context

View File

@@ -36,16 +36,28 @@ impl HistoryEntry {
}
}
#[derive(Clone, Debug, PartialEq, Eq)]
#[derive(Clone, Debug)]
pub(crate) enum RecentEntry {
Thread(Entity<Thread>),
Thread(ThreadId, Entity<Thread>),
Context(Entity<AssistantContext>),
}
impl PartialEq for RecentEntry {
fn eq(&self, other: &Self) -> bool {
match (self, other) {
(Self::Thread(l0, _), Self::Thread(r0, _)) => l0 == r0,
(Self::Context(l0), Self::Context(r0)) => l0 == r0,
_ => false,
}
}
}
impl Eq for RecentEntry {}
impl RecentEntry {
pub(crate) fn summary(&self, cx: &App) -> SharedString {
match self {
RecentEntry::Thread(thread) => thread.read(cx).summary_or_default(),
RecentEntry::Thread(_, thread) => thread.read(cx).summary_or_default(),
RecentEntry::Context(context) => context.read(cx).summary_or_default(),
}
}
@@ -93,9 +105,10 @@ impl HistoryStore {
.map(|serialized| match serialized {
SerializedRecentEntry::Thread(id) => thread_store
.update(cx, |thread_store, cx| {
let thread_id = ThreadId::from(id.as_str());
thread_store
.open_thread(&ThreadId::from(id.as_str()), cx)
.map_ok(RecentEntry::Thread)
.open_thread(&thread_id, cx)
.map_ok(|thread| RecentEntry::Thread(thread_id, thread))
.boxed()
})
.unwrap_or_else(|_| async { Err(anyhow!("no thread store")) }.boxed()),
@@ -170,9 +183,7 @@ impl HistoryStore {
RecentEntry::Context(context) => Some(SerializedRecentEntry::Context(
context.read(cx).path()?.to_str()?.to_owned(),
)),
RecentEntry::Thread(thread) => Some(SerializedRecentEntry::Thread(
thread.read(cx).id().to_string(),
)),
RecentEntry::Thread(id, _) => Some(SerializedRecentEntry::Thread(id.to_string())),
})
.collect::<Vec<_>>();
@@ -200,6 +211,14 @@ impl HistoryStore {
self.save_recently_opened_entries(cx);
}
pub fn remove_recently_opened_thread(&mut self, id: ThreadId, cx: &mut Context<Self>) {
self.recently_opened_entries.retain(|entry| match entry {
RecentEntry::Thread(thread_id, _) if thread_id == &id => false,
_ => true,
});
self.save_recently_opened_entries(cx);
}
pub fn remove_recently_opened_entry(&mut self, entry: &RecentEntry, cx: &mut Context<Self>) {
self.recently_opened_entries
.retain(|old_entry| old_entry != entry);

View File

@@ -4,7 +4,7 @@ use std::sync::Arc;
use crate::assistant_model_selector::{AssistantModelSelector, ModelType};
use crate::context::{ContextLoadResult, load_context};
use crate::tool_compatibility::{IncompatibleToolsState, IncompatibleToolsTooltip};
use crate::ui::AnimatedLabel;
use crate::ui::{AgentPreview, AnimatedLabel};
use buffer_diff::BufferDiff;
use collections::HashSet;
use editor::actions::{MoveUp, Paste};
@@ -42,10 +42,11 @@ use crate::profile_selector::ProfileSelector;
use crate::thread::{Thread, TokenUsageRatio};
use crate::thread_store::ThreadStore;
use crate::{
AgentDiff, Chat, ChatMode, ExpandMessageEditor, NewThread, OpenAgentDiff, RemoveAllContext,
ToggleContextPicker, ToggleProfileSelector,
ActiveThread, AgentDiff, Chat, ExpandMessageEditor, NewThread, OpenAgentDiff, RemoveAllContext,
ToggleContextPicker, ToggleProfileSelector, register_agent_preview,
};
#[derive(RegisterComponent)]
pub struct MessageEditor {
thread: Entity<Thread>,
incompatible_tools_state: Entity<IncompatibleToolsState>,
@@ -69,6 +70,56 @@ pub struct MessageEditor {
const MAX_EDITOR_LINES: usize = 8;
pub(crate) fn create_editor(
workspace: WeakEntity<Workspace>,
context_store: WeakEntity<ContextStore>,
thread_store: WeakEntity<ThreadStore>,
window: &mut Window,
cx: &mut App,
) -> Entity<Editor> {
let language = Language::new(
language::LanguageConfig {
completion_query_characters: HashSet::from_iter(['.', '-', '_', '@']),
..Default::default()
},
None,
);
let editor = cx.new(|cx| {
let buffer = cx.new(|cx| Buffer::local("", cx).with_language(Arc::new(language), cx));
let buffer = cx.new(|cx| MultiBuffer::singleton(buffer, cx));
let mut editor = Editor::new(
editor::EditorMode::AutoHeight {
max_lines: MAX_EDITOR_LINES,
},
buffer,
None,
window,
cx,
);
editor.set_placeholder_text("Ask anything, @ to mention, ↑ to select", cx);
editor.set_show_indent_guides(false, cx);
editor.set_soft_wrap();
editor.set_context_menu_options(ContextMenuOptions {
min_entries_visible: 12,
max_entries_visible: 12,
placement: Some(ContextMenuPlacement::Above),
});
editor
});
let editor_entity = editor.downgrade();
editor.update(cx, |editor, _| {
editor.set_completion_provider(Some(Box::new(ContextPickerCompletionProvider::new(
workspace,
context_store,
Some(thread_store),
editor_entity,
))));
});
editor
}
impl MessageEditor {
pub fn new(
fs: Arc<dyn Fs>,
@@ -83,47 +134,14 @@ impl MessageEditor {
let context_picker_menu_handle = PopoverMenuHandle::default();
let model_selector_menu_handle = PopoverMenuHandle::default();
let language = Language::new(
language::LanguageConfig {
completion_query_characters: HashSet::from_iter(['.', '-', '_', '@']),
..Default::default()
},
None,
let editor = create_editor(
workspace.clone(),
context_store.downgrade(),
thread_store.clone(),
window,
cx,
);
let editor = cx.new(|cx| {
let buffer = cx.new(|cx| Buffer::local("", cx).with_language(Arc::new(language), cx));
let buffer = cx.new(|cx| MultiBuffer::singleton(buffer, cx));
let mut editor = Editor::new(
editor::EditorMode::AutoHeight {
max_lines: MAX_EDITOR_LINES,
},
buffer,
None,
window,
cx,
);
editor.set_placeholder_text("Ask anything, @ to mention, ↑ to select", cx);
editor.set_show_indent_guides(false, cx);
editor.set_soft_wrap();
editor.set_context_menu_options(ContextMenuOptions {
min_entries_visible: 12,
max_entries_visible: 12,
placement: Some(ContextMenuPlacement::Above),
});
editor
});
let editor_entity = editor.downgrade();
editor.update(cx, |editor, _| {
editor.set_completion_provider(Some(Box::new(ContextPickerCompletionProvider::new(
workspace.clone(),
context_store.downgrade(),
Some(thread_store.clone()),
editor_entity,
))));
});
let context_strip = cx.new(|cx| {
ContextStrip::new(
context_store.clone(),
@@ -189,10 +207,6 @@ impl MessageEditor {
&self.context_store
}
fn toggle_chat_mode(&mut self, _: &ChatMode, _window: &mut Window, cx: &mut Context<Self>) {
cx.notify();
}
pub fn expand_message_editor(
&mut self,
_: &ExpandMessageEditor,
@@ -415,12 +429,13 @@ impl MessageEditor {
Some(
IconButton::new("max-mode", IconName::ZedMaxMode)
.icon_size(IconSize::Small)
.toggle_state(active_completion_mode == Some(CompletionMode::Max))
.icon_color(Color::Muted)
.toggle_state(active_completion_mode == CompletionMode::Max)
.on_click(cx.listener(move |this, _event, _window, cx| {
this.thread.update(cx, |thread, _cx| {
thread.set_completion_mode(match active_completion_mode {
Some(CompletionMode::Max) => Some(CompletionMode::Normal),
Some(CompletionMode::Normal) | None => Some(CompletionMode::Max),
CompletionMode::Max => CompletionMode::Normal,
CompletionMode::Normal => CompletionMode::Max,
});
});
}))
@@ -482,7 +497,6 @@ impl MessageEditor {
.on_action(cx.listener(Self::toggle_context_picker))
.on_action(cx.listener(Self::remove_all_context))
.on_action(cx.listener(Self::move_up))
.on_action(cx.listener(Self::toggle_chat_mode))
.on_action(cx.listener(Self::expand_message_editor))
.capture_action(cx.listener(Self::paste))
.gap_2()
@@ -1041,7 +1055,7 @@ impl MessageEditor {
let load_task = cx.spawn(async move |this, cx| {
let Ok(load_task) = this.update(cx, |this, cx| {
let new_context = this.context_store.read_with(cx, |context_store, cx| {
context_store.new_context_for_thread(this.thread.read(cx))
context_store.new_context_for_thread(this.thread.read(cx), None)
});
load_context(new_context, &this.project, &this.prompt_store, cx)
}) else {
@@ -1189,3 +1203,53 @@ impl Render for MessageEditor {
})
}
}
impl Component for MessageEditor {
fn scope() -> ComponentScope {
ComponentScope::Agent
}
}
impl AgentPreview for MessageEditor {
fn create_preview(
workspace: WeakEntity<Workspace>,
active_thread: Entity<ActiveThread>,
thread_store: WeakEntity<ThreadStore>,
window: &mut Window,
cx: &mut App,
) -> Option<AnyElement> {
if let Some(workspace_entity) = workspace.upgrade() {
let fs = workspace_entity.read(cx).app_state().fs.clone();
let weak_project = workspace_entity.read(cx).project().clone().downgrade();
let context_store = cx.new(|_cx| ContextStore::new(weak_project, None));
let thread = active_thread.read(cx).thread().clone();
let example_message_editor = cx.new(|cx| {
MessageEditor::new(
fs,
workspace,
context_store,
None,
thread_store,
thread,
window,
cx,
)
});
Some(
v_flex()
.gap_4()
.children(vec![single_example(
"Default",
example_message_editor.clone().into_any_element(),
)])
.into_any_element(),
)
} else {
None
}
}
}
register_agent_preview!(MessageEditor);

View File

@@ -301,6 +301,14 @@ pub enum TokenUsageRatio {
Exceeded,
}
fn default_completion_mode(cx: &App) -> CompletionMode {
if cx.is_staff() {
CompletionMode::Max
} else {
CompletionMode::Normal
}
}
/// A thread of conversation with the LLM.
pub struct Thread {
id: ThreadId,
@@ -310,7 +318,7 @@ pub struct Thread {
detailed_summary_task: Task<Option<()>>,
detailed_summary_tx: postage::watch::Sender<DetailedSummaryState>,
detailed_summary_rx: postage::watch::Receiver<DetailedSummaryState>,
completion_mode: Option<CompletionMode>,
completion_mode: CompletionMode,
messages: Vec<Message>,
next_message_id: MessageId,
last_prompt_id: PromptId,
@@ -366,7 +374,7 @@ impl Thread {
detailed_summary_task: Task::ready(None),
detailed_summary_tx,
detailed_summary_rx,
completion_mode: None,
completion_mode: default_completion_mode(cx),
messages: Vec::new(),
next_message_id: MessageId(0),
last_prompt_id: PromptId::new(),
@@ -440,7 +448,7 @@ impl Thread {
detailed_summary_task: Task::ready(None),
detailed_summary_tx,
detailed_summary_rx,
completion_mode: None,
completion_mode: default_completion_mode(cx),
messages: serialized
.messages
.into_iter()
@@ -569,11 +577,11 @@ impl Thread {
}
}
pub fn completion_mode(&self) -> Option<CompletionMode> {
pub fn completion_mode(&self) -> CompletionMode {
self.completion_mode
}
pub fn set_completion_mode(&mut self, mode: Option<CompletionMode>) {
pub fn set_completion_mode(&mut self, mode: CompletionMode) {
self.completion_mode = mode;
}
@@ -879,6 +887,7 @@ impl Thread {
id: MessageId,
new_role: Role,
new_segments: Vec<MessageSegment>,
loaded_context: Option<LoadedContext>,
cx: &mut Context<Self>,
) -> bool {
let Some(message) = self.messages.iter_mut().find(|message| message.id == id) else {
@@ -886,6 +895,9 @@ impl Thread {
};
message.role = new_role;
message.segments = new_segments;
if let Some(context) = loaded_context {
message.loaded_context = context;
}
self.touch_updated_at();
cx.emit(ThreadEvent::MessageEdited(id));
true
@@ -1148,9 +1160,9 @@ impl Thread {
request.tools = available_tools;
request.mode = if model.supports_max_mode() {
self.completion_mode
Some(self.completion_mode)
} else {
None
Some(CompletionMode::Normal)
};
request
@@ -2106,7 +2118,7 @@ impl Thread {
.map(|repo| {
repo.update(cx, |repo, _| {
let current_branch =
repo.branch.as_ref().map(|branch| branch.name.to_string());
repo.branch.as_ref().map(|branch| branch.name().to_owned());
repo.send_job(None, |state, _| async move {
let RepositoryState::Local { backend, .. } = state else {
return GitState {
@@ -2505,7 +2517,7 @@ mod tests {
let expected_context = format!(
r#"
<context>
The following items were attached by the user. You don't need to use other tools to read them.
The following items were attached by the user. They are up-to-date and don't need to be re-read.
<files>
```rs {path_part}
@@ -2546,6 +2558,7 @@ fn main() {{
"file1.rs": "fn function1() {}\n",
"file2.rs": "fn function2() {}\n",
"file3.rs": "fn function3() {}\n",
"file4.rs": "fn function4() {}\n",
}),
)
.await;
@@ -2558,7 +2571,7 @@ fn main() {{
.await
.unwrap();
let new_contexts = context_store.update(cx, |store, cx| {
store.new_context_for_thread(thread.read(cx))
store.new_context_for_thread(thread.read(cx), None)
});
assert_eq!(new_contexts.len(), 1);
let loaded_context = cx
@@ -2573,7 +2586,7 @@ fn main() {{
.await
.unwrap();
let new_contexts = context_store.update(cx, |store, cx| {
store.new_context_for_thread(thread.read(cx))
store.new_context_for_thread(thread.read(cx), None)
});
assert_eq!(new_contexts.len(), 1);
let loaded_context = cx
@@ -2589,7 +2602,7 @@ fn main() {{
.await
.unwrap();
let new_contexts = context_store.update(cx, |store, cx| {
store.new_context_for_thread(thread.read(cx))
store.new_context_for_thread(thread.read(cx), None)
});
assert_eq!(new_contexts.len(), 1);
let loaded_context = cx
@@ -2640,6 +2653,55 @@ fn main() {{
assert!(!request.messages[3].string_contents().contains("file1.rs"));
assert!(!request.messages[3].string_contents().contains("file2.rs"));
assert!(request.messages[3].string_contents().contains("file3.rs"));
add_file_to_context(&project, &context_store, "test/file4.rs", cx)
.await
.unwrap();
let new_contexts = context_store.update(cx, |store, cx| {
store.new_context_for_thread(thread.read(cx), Some(message2_id))
});
assert_eq!(new_contexts.len(), 3);
let loaded_context = cx
.update(|cx| load_context(new_contexts, &project, &None, cx))
.await
.loaded_context;
assert!(!loaded_context.text.contains("file1.rs"));
assert!(loaded_context.text.contains("file2.rs"));
assert!(loaded_context.text.contains("file3.rs"));
assert!(loaded_context.text.contains("file4.rs"));
let new_contexts = context_store.update(cx, |store, cx| {
// Remove file4.rs
store.remove_context(&loaded_context.contexts[2].handle(), cx);
store.new_context_for_thread(thread.read(cx), Some(message2_id))
});
assert_eq!(new_contexts.len(), 2);
let loaded_context = cx
.update(|cx| load_context(new_contexts, &project, &None, cx))
.await
.loaded_context;
assert!(!loaded_context.text.contains("file1.rs"));
assert!(loaded_context.text.contains("file2.rs"));
assert!(loaded_context.text.contains("file3.rs"));
assert!(!loaded_context.text.contains("file4.rs"));
let new_contexts = context_store.update(cx, |store, cx| {
// Remove file3.rs
store.remove_context(&loaded_context.contexts[1].handle(), cx);
store.new_context_for_thread(thread.read(cx), Some(message2_id))
});
assert_eq!(new_contexts.len(), 1);
let loaded_context = cx
.update(|cx| load_context(new_contexts, &project, &None, cx))
.await
.loaded_context;
assert!(!loaded_context.text.contains("file1.rs"));
assert!(loaded_context.text.contains("file2.rs"));
assert!(!loaded_context.text.contains("file3.rs"));
assert!(!loaded_context.text.contains("file4.rs"));
}
#[gpui::test]

View File

@@ -9,8 +9,8 @@ use assistant_settings::{AgentProfile, AgentProfileId, AssistantSettings};
use assistant_tool::{ToolId, ToolSource, ToolWorkingSet};
use chrono::{DateTime, Utc};
use collections::HashMap;
use context_server::manager::ContextServerManager;
use context_server::{ContextServerFactoryRegistry, ContextServerTool};
use context_server::manager::{ContextServerManager, ContextServerStatus};
use context_server::{ContextServerDescriptorRegistry, ContextServerTool};
use futures::channel::{mpsc, oneshot};
use futures::future::{self, BoxFuture, Shared};
use futures::{FutureExt as _, StreamExt as _};
@@ -108,7 +108,7 @@ impl ThreadStore {
prompt_store: Option<Entity<PromptStore>>,
cx: &mut Context<Self>,
) -> (Self, oneshot::Receiver<()>) {
let context_server_factory_registry = ContextServerFactoryRegistry::default_global(cx);
let context_server_factory_registry = ContextServerDescriptorRegistry::default_global(cx);
let context_server_manager = cx.new(|cx| {
ContextServerManager::new(context_server_factory_registry, project.clone(), cx)
});
@@ -555,62 +555,68 @@ impl ThreadStore {
) {
let tool_working_set = self.tools.clone();
match event {
context_server::manager::Event::ServerStarted { server_id } => {
if let Some(server) = context_server_manager.read(cx).get_server(server_id) {
let context_server_manager = context_server_manager.clone();
cx.spawn({
let server = server.clone();
let server_id = server_id.clone();
async move |this, cx| {
let Some(protocol) = server.client() else {
return;
};
context_server::manager::Event::ServerStatusChanged { server_id, status } => {
match status {
Some(ContextServerStatus::Running) => {
if let Some(server) = context_server_manager.read(cx).get_server(server_id)
{
let context_server_manager = context_server_manager.clone();
cx.spawn({
let server = server.clone();
let server_id = server_id.clone();
async move |this, cx| {
let Some(protocol) = server.client() else {
return;
};
if protocol.capable(context_server::protocol::ServerCapability::Tools) {
if let Some(tools) = protocol.list_tools().await.log_err() {
let tool_ids = tool_working_set
.update(cx, |tool_working_set, _| {
tools
.tools
.into_iter()
.map(|tool| {
log::info!(
"registering context server tool: {:?}",
tool.name
);
tool_working_set.insert(Arc::new(
ContextServerTool::new(
context_server_manager.clone(),
server.id(),
tool,
),
))
if protocol.capable(context_server::protocol::ServerCapability::Tools) {
if let Some(tools) = protocol.list_tools().await.log_err() {
let tool_ids = tool_working_set
.update(cx, |tool_working_set, _| {
tools
.tools
.into_iter()
.map(|tool| {
log::info!(
"registering context server tool: {:?}",
tool.name
);
tool_working_set.insert(Arc::new(
ContextServerTool::new(
context_server_manager.clone(),
server.id(),
tool,
),
))
})
.collect::<Vec<_>>()
})
.collect::<Vec<_>>()
})
.log_err();
.log_err();
if let Some(tool_ids) = tool_ids {
this.update(cx, |this, cx| {
this.context_server_tool_ids
.insert(server_id, tool_ids);
this.load_default_profile(cx);
})
.log_err();
if let Some(tool_ids) = tool_ids {
this.update(cx, |this, cx| {
this.context_server_tool_ids
.insert(server_id, tool_ids);
this.load_default_profile(cx);
})
.log_err();
}
}
}
}
}
})
.detach();
}
})
.detach();
}
}
context_server::manager::Event::ServerStopped { server_id } => {
if let Some(tool_ids) = self.context_server_tool_ids.remove(server_id) {
tool_working_set.update(cx, |tool_working_set, _| {
tool_working_set.remove(&tool_ids);
});
self.load_default_profile(cx);
}
None => {
if let Some(tool_ids) = self.context_server_tool_ids.remove(server_id) {
tool_working_set.update(cx, |tool_working_set, _| {
tool_working_set.remove(&tool_ids);
});
self.load_default_profile(cx);
}
}
_ => {}
}
}
}

View File

@@ -1,9 +1,12 @@
mod agent_notification;
pub mod agent_preview;
mod animated_label;
mod context_pill;
mod upsell;
mod usage_banner;
pub use agent_notification::*;
pub use agent_preview::*;
pub use animated_label::*;
pub use context_pill::*;
pub use usage_banner::*;

View File

@@ -0,0 +1,99 @@
use collections::HashMap;
use component::ComponentId;
use gpui::{App, Entity, WeakEntity};
use linkme::distributed_slice;
use std::sync::OnceLock;
use ui::{AnyElement, Component, Window};
use workspace::Workspace;
use crate::{ActiveThread, ThreadStore};
/// Function type for creating agent component previews
pub type PreviewFn = fn(
WeakEntity<Workspace>,
Entity<ActiveThread>,
WeakEntity<ThreadStore>,
&mut Window,
&mut App,
) -> Option<AnyElement>;
/// Distributed slice for preview registration functions
#[distributed_slice]
pub static __ALL_AGENT_PREVIEWS: [fn() -> (ComponentId, PreviewFn)] = [..];
/// Trait that must be implemented by components that provide agent previews.
pub trait AgentPreview: Component {
/// Get the ID for this component
///
/// Eventually this will move to the component trait.
fn id() -> ComponentId
where
Self: Sized,
{
ComponentId(Self::name())
}
/// Static method to create a preview for this component type
fn create_preview(
workspace: WeakEntity<Workspace>,
active_thread: Entity<ActiveThread>,
thread_store: WeakEntity<ThreadStore>,
window: &mut Window,
cx: &mut App,
) -> Option<AnyElement>
where
Self: Sized;
}
/// Register an agent preview for the given component type
#[macro_export]
macro_rules! register_agent_preview {
($type:ty) => {
#[linkme::distributed_slice($crate::ui::agent_preview::__ALL_AGENT_PREVIEWS)]
static __REGISTER_AGENT_PREVIEW: fn() -> (
component::ComponentId,
$crate::ui::agent_preview::PreviewFn,
) = || {
(
<$type as $crate::ui::agent_preview::AgentPreview>::id(),
<$type as $crate::ui::agent_preview::AgentPreview>::create_preview,
)
};
};
}
/// Lazy initialized registry of preview functions
static AGENT_PREVIEW_REGISTRY: OnceLock<HashMap<ComponentId, PreviewFn>> = OnceLock::new();
/// Initialize the agent preview registry if needed
fn get_or_init_registry() -> &'static HashMap<ComponentId, PreviewFn> {
AGENT_PREVIEW_REGISTRY.get_or_init(|| {
let mut map = HashMap::default();
for register_fn in __ALL_AGENT_PREVIEWS.iter() {
let (id, preview_fn) = register_fn();
map.insert(id, preview_fn);
}
map
})
}
/// Get a specific agent preview by component ID.
pub fn get_agent_preview(
id: &ComponentId,
workspace: WeakEntity<Workspace>,
active_thread: Entity<ActiveThread>,
thread_store: WeakEntity<ThreadStore>,
window: &mut Window,
cx: &mut App,
) -> Option<AnyElement> {
let registry = get_or_init_registry();
registry
.get(id)
.and_then(|preview_fn| preview_fn(workspace, active_thread, thread_store, window, cx))
}
/// Get all registered agent previews.
pub fn all_agent_previews() -> Vec<ComponentId> {
let registry = get_or_init_registry();
registry.keys().cloned().collect()
}

View File

@@ -216,9 +216,10 @@ impl RenderOnce for ContextPill {
})
.when_some(on_click.as_ref(), |element, on_click| {
let on_click = on_click.clone();
element
.cursor_pointer()
.on_click(move |event, window, cx| on_click(event, window, cx))
element.cursor_pointer().on_click(move |event, window, cx| {
on_click(event, window, cx);
cx.stop_propagation();
})
})
.into_any_element()
}
@@ -254,7 +255,10 @@ impl RenderOnce for ContextPill {
})
.when_some(on_click.as_ref(), |element, on_click| {
let on_click = on_click.clone();
element.on_click(move |event, window, cx| on_click(event, window, cx))
element.on_click(move |event, window, cx| {
on_click(event, window, cx);
cx.stop_propagation();
})
})
.into_any(),
}

View File

@@ -0,0 +1,163 @@
use component::{Component, ComponentScope, single_example};
use gpui::{
AnyElement, App, ClickEvent, IntoElement, ParentElement, RenderOnce, SharedString, Styled,
Window,
};
use theme::ActiveTheme;
use ui::{
Button, ButtonCommon, ButtonStyle, Checkbox, Clickable, Color, Label, LabelCommon,
RegisterComponent, ToggleState, h_flex, v_flex,
};
/// A component that displays an upsell message with a call-to-action button
///
/// # Example
/// ```
/// let upsell = Upsell::new(
/// "Upgrade to Zed Pro",
/// "Get unlimited access to AI features and more",
/// "Upgrade Now",
/// Box::new(|_, _window, cx| {
/// cx.open_url("https://zed.dev/pricing");
/// }),
/// Box::new(|_, _window, cx| {
/// // Handle dismiss
/// }),
/// Box::new(|checked, window, cx| {
/// // Handle don't show again
/// }),
/// );
/// ```
#[derive(IntoElement, RegisterComponent)]
pub struct Upsell {
title: SharedString,
message: SharedString,
cta_text: SharedString,
on_click: Box<dyn Fn(&ClickEvent, &mut Window, &mut App)>,
on_dismiss: Box<dyn Fn(&ClickEvent, &mut Window, &mut App)>,
on_dont_show_again: Box<dyn Fn(bool, &mut Window, &mut App)>,
}
impl Upsell {
/// Create a new upsell component
pub fn new(
title: impl Into<SharedString>,
message: impl Into<SharedString>,
cta_text: impl Into<SharedString>,
on_click: Box<dyn Fn(&ClickEvent, &mut Window, &mut App)>,
on_dismiss: Box<dyn Fn(&ClickEvent, &mut Window, &mut App)>,
on_dont_show_again: Box<dyn Fn(bool, &mut Window, &mut App)>,
) -> Self {
Self {
title: title.into(),
message: message.into(),
cta_text: cta_text.into(),
on_click,
on_dismiss,
on_dont_show_again,
}
}
}
impl RenderOnce for Upsell {
fn render(self, _window: &mut Window, cx: &mut App) -> impl IntoElement {
v_flex()
.w_full()
.p_4()
.gap_3()
.bg(cx.theme().colors().surface_background)
.rounded_md()
.border_1()
.border_color(cx.theme().colors().border)
.child(
v_flex()
.gap_1()
.child(
Label::new(self.title)
.size(ui::LabelSize::Large)
.weight(gpui::FontWeight::BOLD),
)
.child(Label::new(self.message).color(Color::Muted)),
)
.child(
h_flex()
.w_full()
.justify_between()
.items_center()
.child(
h_flex()
.items_center()
.gap_1()
.child(
Checkbox::new("dont-show-again", ToggleState::Unselected).on_click(
move |_, window, cx| {
(self.on_dont_show_again)(true, window, cx);
},
),
)
.child(
Label::new("Don't show again")
.color(Color::Muted)
.size(ui::LabelSize::Small),
),
)
.child(
h_flex()
.gap_2()
.child(
Button::new("dismiss-button", "Dismiss")
.style(ButtonStyle::Subtle)
.on_click(self.on_dismiss),
)
.child(
Button::new("cta-button", self.cta_text)
.style(ButtonStyle::Filled)
.on_click(self.on_click),
),
),
)
}
}
impl Component for Upsell {
fn scope() -> ComponentScope {
ComponentScope::Agent
}
fn name() -> &'static str {
"Upsell"
}
fn description() -> Option<&'static str> {
Some("A promotional component that displays a message with a call-to-action.")
}
fn preview(window: &mut Window, cx: &mut App) -> Option<AnyElement> {
let examples = vec![
single_example(
"Default",
Upsell::new(
"Upgrade to Zed Pro",
"Get unlimited access to AI features and more with Zed Pro. Unlock advanced AI capabilities and other premium features.",
"Upgrade Now",
Box::new(|_, _, _| {}),
Box::new(|_, _, _| {}),
Box::new(|_, _, _| {}),
).render(window, cx).into_any_element(),
),
single_example(
"Short Message",
Upsell::new(
"Try Zed Pro for free",
"Start your 7-day trial today.",
"Start Trial",
Box::new(|_, _, _| {}),
Box::new(|_, _, _| {}),
Box::new(|_, _, _| {}),
).render(window, cx).into_any_element(),
),
];
Some(v_flex().gap_4().children(examples).into_any_element())
}
}

View File

@@ -98,6 +98,10 @@ impl RenderOnce for UsageBanner {
}
impl Component for UsageBanner {
fn scope() -> ComponentScope {
ComponentScope::Agent
}
fn sort_name() -> &'static str {
"AgentUsageBanner"
}

View File

@@ -7,8 +7,8 @@ use assistant_slash_command::{SlashCommandId, SlashCommandWorkingSet};
use client::{Client, TypedEnvelope, proto, telemetry::Telemetry};
use clock::ReplicaId;
use collections::HashMap;
use context_server::ContextServerFactoryRegistry;
use context_server::manager::ContextServerManager;
use context_server::ContextServerDescriptorRegistry;
use context_server::manager::{ContextServerManager, ContextServerStatus};
use fs::{Fs, RemoveOptions};
use futures::StreamExt;
use fuzzy::StringMatchCandidate;
@@ -99,7 +99,7 @@ impl ContextStore {
let this = cx.new(|cx: &mut Context<Self>| {
let context_server_factory_registry =
ContextServerFactoryRegistry::default_global(cx);
ContextServerDescriptorRegistry::default_global(cx);
let context_server_manager = cx.new(|cx| {
ContextServerManager::new(context_server_factory_registry, project.clone(), cx)
});
@@ -831,54 +831,60 @@ impl ContextStore {
) {
let slash_command_working_set = self.slash_commands.clone();
match event {
context_server::manager::Event::ServerStarted { server_id } => {
if let Some(server) = context_server_manager.read(cx).get_server(server_id) {
let context_server_manager = context_server_manager.clone();
cx.spawn({
let server = server.clone();
let server_id = server_id.clone();
async move |this, cx| {
let Some(protocol) = server.client() else {
return;
};
context_server::manager::Event::ServerStatusChanged { server_id, status } => {
match status {
Some(ContextServerStatus::Running) => {
if let Some(server) = context_server_manager.read(cx).get_server(server_id)
{
let context_server_manager = context_server_manager.clone();
cx.spawn({
let server = server.clone();
let server_id = server_id.clone();
async move |this, cx| {
let Some(protocol) = server.client() else {
return;
};
if protocol.capable(context_server::protocol::ServerCapability::Prompts) {
if let Some(prompts) = protocol.list_prompts().await.log_err() {
let slash_command_ids = prompts
.into_iter()
.filter(assistant_slash_commands::acceptable_prompt)
.map(|prompt| {
log::info!(
"registering context server command: {:?}",
prompt.name
);
slash_command_working_set.insert(Arc::new(
assistant_slash_commands::ContextServerSlashCommand::new(
context_server_manager.clone(),
&server,
prompt,
),
))
})
.collect::<Vec<_>>();
if protocol.capable(context_server::protocol::ServerCapability::Prompts) {
if let Some(prompts) = protocol.list_prompts().await.log_err() {
let slash_command_ids = prompts
.into_iter()
.filter(assistant_slash_commands::acceptable_prompt)
.map(|prompt| {
log::info!(
"registering context server command: {:?}",
prompt.name
);
slash_command_working_set.insert(Arc::new(
assistant_slash_commands::ContextServerSlashCommand::new(
context_server_manager.clone(),
&server,
prompt,
),
))
})
.collect::<Vec<_>>();
this.update( cx, |this, _cx| {
this.context_server_slash_command_ids
.insert(server_id.clone(), slash_command_ids);
})
.log_err();
this.update( cx, |this, _cx| {
this.context_server_slash_command_ids
.insert(server_id.clone(), slash_command_ids);
})
.log_err();
}
}
}
}
})
.detach();
}
})
.detach();
}
}
context_server::manager::Event::ServerStopped { server_id } => {
if let Some(slash_command_ids) =
self.context_server_slash_command_ids.remove(server_id)
{
slash_command_working_set.remove(&slash_command_ids);
}
None => {
if let Some(slash_command_ids) =
self.context_server_slash_command_ids.remove(server_id)
{
slash_command_working_set.remove(&slash_command_ids);
}
}
_ => {}
}
}
}

View File

@@ -6,7 +6,7 @@ use ::open_ai::Model as OpenAiModel;
use anthropic::Model as AnthropicModel;
use anyhow::{Result, bail};
use deepseek::Model as DeepseekModel;
use feature_flags::{Assistant2FeatureFlag, FeatureFlagAppExt};
use feature_flags::{AgentStreamEditsFeatureFlag, Assistant2FeatureFlag, FeatureFlagAppExt};
use gpui::{App, Pixels};
use indexmap::IndexMap;
use language_model::{CloudModel, LanguageModel};
@@ -87,9 +87,14 @@ pub struct AssistantSettings {
pub profiles: IndexMap<AgentProfileId, AgentProfile>,
pub always_allow_tool_actions: bool,
pub notify_when_agent_waiting: NotifyWhenAgentWaiting,
pub stream_edits: bool,
}
impl AssistantSettings {
pub fn stream_edits(&self, cx: &App) -> bool {
cx.has_flag::<AgentStreamEditsFeatureFlag>() || self.stream_edits
}
pub fn are_live_diffs_enabled(&self, cx: &App) -> bool {
if cx.has_flag::<Assistant2FeatureFlag>() {
return false;
@@ -218,6 +223,7 @@ impl AssistantSettingsContent {
profiles: None,
always_allow_tool_actions: None,
notify_when_agent_waiting: None,
stream_edits: None,
},
VersionedAssistantSettingsContent::V2(ref settings) => settings.clone(),
},
@@ -245,6 +251,7 @@ impl AssistantSettingsContent {
profiles: None,
always_allow_tool_actions: None,
notify_when_agent_waiting: None,
stream_edits: None,
},
None => AssistantSettingsContentV2::default(),
}
@@ -495,6 +502,7 @@ impl Default for VersionedAssistantSettingsContent {
profiles: None,
always_allow_tool_actions: None,
notify_when_agent_waiting: None,
stream_edits: None,
})
}
}
@@ -550,6 +558,10 @@ pub struct AssistantSettingsContentV2 {
///
/// Default: "primary_screen"
notify_when_agent_waiting: Option<NotifyWhenAgentWaiting>,
/// Whether to stream edits from the agent as they are received.
///
/// Default: false
stream_edits: Option<bool>,
}
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema, PartialEq)]
@@ -712,6 +724,7 @@ impl Settings for AssistantSettings {
&mut settings.notify_when_agent_waiting,
value.notify_when_agent_waiting,
);
merge(&mut settings.stream_edits, value.stream_edits);
merge(&mut settings.default_profile, value.default_profile);
if let Some(profiles) = value.profiles {
@@ -843,6 +856,7 @@ mod tests {
profiles: None,
always_allow_tool_actions: None,
notify_when_agent_waiting: None,
stream_edits: None,
},
)),
}

View File

@@ -24,6 +24,7 @@ language.workspace = true
language_model.workspace = true
parking_lot.workspace = true
project.workspace = true
regex.workspace = true
serde.workspace = true
serde_json.workspace = true
text.workspace = true

View File

@@ -1,4 +1,5 @@
mod action_log;
pub mod outline;
mod tool_registry;
mod tool_schema;
mod tool_working_set;

View File

@@ -0,0 +1,132 @@
use crate::ActionLog;
use anyhow::{Result, anyhow};
use gpui::{AsyncApp, Entity};
use language::{OutlineItem, ParseStatus};
use project::Project;
use regex::Regex;
use std::fmt::Write;
use text::Point;
/// For files over this size, instead of reading them (or including them in context),
/// we automatically provide the file's symbol outline instead, with line numbers.
pub const AUTO_OUTLINE_SIZE: usize = 16384;
pub async fn file_outline(
project: Entity<Project>,
path: String,
action_log: Entity<ActionLog>,
regex: Option<Regex>,
cx: &mut AsyncApp,
) -> anyhow::Result<String> {
let buffer = {
let project_path = project.read_with(cx, |project, cx| {
project
.find_project_path(&path, cx)
.ok_or_else(|| anyhow!("Path {path} not found in project"))
})??;
project
.update(cx, |project, cx| project.open_buffer(project_path, cx))?
.await?
};
action_log.update(cx, |action_log, cx| {
action_log.track_buffer(buffer.clone(), cx);
})?;
// Wait until the buffer has been fully parsed, so that we can read its outline.
let mut parse_status = buffer.read_with(cx, |buffer, _| buffer.parse_status())?;
while *parse_status.borrow() != ParseStatus::Idle {
parse_status.changed().await?;
}
let snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot())?;
let Some(outline) = snapshot.outline(None) else {
return Err(anyhow!("No outline information available for this file."));
};
render_outline(
outline
.items
.into_iter()
.map(|item| item.to_point(&snapshot)),
regex,
0,
usize::MAX,
)
.await
}
pub async fn render_outline(
items: impl IntoIterator<Item = OutlineItem<Point>>,
regex: Option<Regex>,
offset: usize,
results_per_page: usize,
) -> Result<String> {
let mut items = items.into_iter().skip(offset);
let entries = items
.by_ref()
.filter(|item| {
regex
.as_ref()
.is_none_or(|regex| regex.is_match(&item.text))
})
.take(results_per_page)
.collect::<Vec<_>>();
let has_more = items.next().is_some();
let mut output = String::new();
let entries_rendered = render_entries(&mut output, entries);
// Calculate pagination information
let page_start = offset + 1;
let page_end = offset + entries_rendered;
let total_symbols = if has_more {
format!("more than {}", page_end)
} else {
page_end.to_string()
};
// Add pagination information
if has_more {
writeln!(&mut output, "\nShowing symbols {page_start}-{page_end} (there were more symbols found; use offset: {page_end} to see next page)",
)
} else {
writeln!(
&mut output,
"\nShowing symbols {page_start}-{page_end} (total symbols: {total_symbols})",
)
}
.ok();
Ok(output)
}
fn render_entries(
output: &mut String,
items: impl IntoIterator<Item = OutlineItem<Point>>,
) -> usize {
let mut entries_rendered = 0;
for item in items {
// Indent based on depth ("" for level 0, " " for level 1, etc.)
for _ in 0..item.depth {
output.push(' ');
}
output.push_str(&item.text);
// Add position information - convert to 1-based line numbers for display
let start_line = item.range.start.row + 1;
let end_line = item.range.end.row + 1;
if start_line == end_line {
writeln!(output, " [L{}]", start_line).ok();
} else {
writeln!(output, " [L{}-{}]", start_line, end_line).ok();
}
entries_rendered += 1;
}
entries_rendered
}

View File

@@ -11,16 +11,24 @@ workspace = true
[lib]
path = "src/assistant_tools.rs"
[features]
eval = []
[dependencies]
aho-corasick.workspace = true
anyhow.workspace = true
assistant_tool.workspace = true
assistant_settings.workspace = true
buffer_diff.workspace = true
chrono.workspace = true
collections.workspace = true
component.workspace = true
editor.workspace = true
derive_more.workspace = true
feature_flags.workspace = true
futures.workspace = true
gpui.workspace = true
handlebars = { workspace = true, features = ["rust-embed"] }
html_to_markdown.workspace = true
http_client.workspace = true
indoc.workspace = true
@@ -31,9 +39,14 @@ linkme.workspace = true
open.workspace = true
project.workspace = true
regex.workspace = true
rust-embed.workspace = true
schemars.workspace = true
serde.workspace = true
serde_json.workspace = true
settings.workspace = true
smallvec.workspace = true
streaming_diff.workspace = true
strsim.workspace = true
task.workspace = true
terminal.workspace = true
terminal_view.workspace = true
@@ -49,12 +62,18 @@ client = { workspace = true, features = ["test-support"] }
clock = { workspace = true, features = ["test-support"] }
collections = { workspace = true, features = ["test-support"] }
gpui = { workspace = true, features = ["test-support"] }
gpui_tokio.workspace = true
fs = { workspace = true, features = ["test-support"] }
language = { workspace = true, features = ["test-support"] }
language_model = { workspace = true, features = ["test-support"] }
language_models.workspace = true
project = { workspace = true, features = ["test-support"] }
rand.workspace = true
pretty_assertions.workspace = true
reqwest_client.workspace = true
settings = { workspace = true, features = ["test-support"] }
task = { workspace = true, features = ["test-support"]}
tempfile.workspace = true
tree-sitter-rust.workspace = true
workspace = { workspace = true, features = ["test-support"] }
unindent.workspace = true

View File

@@ -7,6 +7,7 @@ mod create_directory_tool;
mod create_file_tool;
mod delete_path_tool;
mod diagnostics_tool;
mod edit_agent;
mod edit_file_tool;
mod fetch_tool;
mod find_path_tool;
@@ -19,7 +20,9 @@ mod read_file_tool;
mod rename_tool;
mod replace;
mod schema;
mod streaming_edit_file_tool;
mod symbol_info_tool;
mod templates;
mod terminal_tool;
mod thinking_tool;
mod ui;
@@ -27,14 +30,19 @@ mod web_search_tool;
use std::sync::Arc;
use assistant_settings::AssistantSettings;
use assistant_tool::ToolRegistry;
use copy_path_tool::CopyPathTool;
use feature_flags::{AgentStreamEditsFeatureFlag, FeatureFlagAppExt};
use gpui::App;
use http_client::HttpClientWithUrl;
use language_model::LanguageModelRegistry;
use move_path_tool::MovePathTool;
use settings::{Settings, SettingsStore};
use web_search_tool::WebSearchTool;
pub(crate) use templates::*;
use crate::batch_tool::BatchTool;
use crate::code_action_tool::CodeActionTool;
use crate::code_symbols_tool::CodeSymbolsTool;
@@ -52,6 +60,7 @@ use crate::now_tool::NowTool;
use crate::open_tool::OpenTool;
use crate::read_file_tool::ReadFileTool;
use crate::rename_tool::RenameTool;
use crate::streaming_edit_file_tool::StreamingEditFileTool;
use crate::symbol_info_tool::SymbolInfoTool;
use crate::terminal_tool::TerminalTool;
use crate::thinking_tool::ThinkingTool;
@@ -68,10 +77,8 @@ pub fn init(http_client: Arc<HttpClientWithUrl>, cx: &mut App) {
registry.register_tool(TerminalTool);
registry.register_tool(BatchTool);
registry.register_tool(CreateDirectoryTool);
registry.register_tool(CreateFileTool);
registry.register_tool(CopyPathTool);
registry.register_tool(DeletePathTool);
registry.register_tool(EditFileTool);
registry.register_tool(SymbolInfoTool);
registry.register_tool(CodeActionTool);
registry.register_tool(MovePathTool);
@@ -88,6 +95,12 @@ pub fn init(http_client: Arc<HttpClientWithUrl>, cx: &mut App) {
registry.register_tool(ThinkingTool);
registry.register_tool(FetchTool::new(http_client));
register_edit_file_tool(cx);
cx.observe_flag::<AgentStreamEditsFeatureFlag, _>(|_, cx| register_edit_file_tool(cx))
.detach();
cx.observe_global::<SettingsStore>(register_edit_file_tool)
.detach();
cx.subscribe(
&LanguageModelRegistry::global(cx),
move |registry, event, cx| match event {
@@ -108,6 +121,21 @@ pub fn init(http_client: Arc<HttpClientWithUrl>, cx: &mut App) {
.detach();
}
fn register_edit_file_tool(cx: &mut App) {
let registry = ToolRegistry::global(cx);
registry.unregister_tool(CreateFileTool);
registry.unregister_tool(EditFileTool);
registry.unregister_tool(StreamingEditFileTool);
if AssistantSettings::get_global(cx).stream_edits(cx) {
registry.register_tool(StreamingEditFileTool);
} else {
registry.register_tool(CreateFileTool);
registry.register_tool(EditFileTool);
}
}
#[cfg(test)]
mod tests {
use super::*;
@@ -146,6 +174,7 @@ mod tests {
#[gpui::test]
fn test_builtin_tool_schema_compatibility(cx: &mut App) {
settings::init(cx);
AssistantSettings::register(cx);
let client = Client::new(
Arc::new(FakeSystemClock::new()),

View File

@@ -4,10 +4,10 @@ use std::sync::Arc;
use crate::schema::json_schema_for;
use anyhow::{Result, anyhow};
use assistant_tool::outline;
use assistant_tool::{ActionLog, Tool, ToolResult};
use collections::IndexMap;
use gpui::{AnyWindowHandle, App, AsyncApp, Entity, Task};
use language::{OutlineItem, ParseStatus, Point};
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
use project::{Project, Symbol};
use regex::{Regex, RegexBuilder};
@@ -148,59 +148,13 @@ impl Tool for CodeSymbolsTool {
};
cx.spawn(async move |cx| match input.path {
Some(path) => file_outline(project, path, action_log, regex, cx).await,
Some(path) => outline::file_outline(project, path, action_log, regex, cx).await,
None => project_symbols(project, regex, input.offset, cx).await,
})
.into()
}
}
pub async fn file_outline(
project: Entity<Project>,
path: String,
action_log: Entity<ActionLog>,
regex: Option<Regex>,
cx: &mut AsyncApp,
) -> anyhow::Result<String> {
let buffer = {
let project_path = project.read_with(cx, |project, cx| {
project
.find_project_path(&path, cx)
.ok_or_else(|| anyhow!("Path {path} not found in project"))
})??;
project
.update(cx, |project, cx| project.open_buffer(project_path, cx))?
.await?
};
action_log.update(cx, |action_log, cx| {
action_log.track_buffer(buffer.clone(), cx);
})?;
// Wait until the buffer has been fully parsed, so that we can read its outline.
let mut parse_status = buffer.read_with(cx, |buffer, _| buffer.parse_status())?;
while *parse_status.borrow() != ParseStatus::Idle {
parse_status.changed().await?;
}
let snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot())?;
let Some(outline) = snapshot.outline(None) else {
return Err(anyhow!("No outline information available for this file."));
};
render_outline(
outline
.items
.into_iter()
.map(|item| item.to_point(&snapshot)),
regex,
0,
usize::MAX,
)
.await
}
async fn project_symbols(
project: Entity<Project>,
regex: Option<Regex>,
@@ -291,77 +245,3 @@ async fn project_symbols(
output
})
}
async fn render_outline(
items: impl IntoIterator<Item = OutlineItem<Point>>,
regex: Option<Regex>,
offset: usize,
results_per_page: usize,
) -> Result<String> {
let mut items = items.into_iter().skip(offset);
let entries = items
.by_ref()
.filter(|item| {
regex
.as_ref()
.is_none_or(|regex| regex.is_match(&item.text))
})
.take(results_per_page)
.collect::<Vec<_>>();
let has_more = items.next().is_some();
let mut output = String::new();
let entries_rendered = render_entries(&mut output, entries);
// Calculate pagination information
let page_start = offset + 1;
let page_end = offset + entries_rendered;
let total_symbols = if has_more {
format!("more than {}", page_end)
} else {
page_end.to_string()
};
// Add pagination information
if has_more {
writeln!(&mut output, "\nShowing symbols {page_start}-{page_end} (there were more symbols found; use offset: {page_end} to see next page)",
)
} else {
writeln!(
&mut output,
"\nShowing symbols {page_start}-{page_end} (total symbols: {total_symbols})",
)
}
.ok();
Ok(output)
}
fn render_entries(
output: &mut String,
items: impl IntoIterator<Item = OutlineItem<Point>>,
) -> usize {
let mut entries_rendered = 0;
for item in items {
// Indent based on depth ("" for level 0, " " for level 1, etc.)
for _ in 0..item.depth {
output.push(' ');
}
output.push_str(&item.text);
// Add position information - convert to 1-based line numbers for display
let start_line = item.range.start.row + 1;
let end_line = item.range.end.row + 1;
if start_line == end_line {
writeln!(output, " [L{}]", start_line).ok();
} else {
writeln!(output, " [L{}-{}]", start_line, end_line).ok();
}
entries_rendered += 1;
}
entries_rendered
}

View File

@@ -1,8 +1,8 @@
use std::sync::Arc;
use crate::{code_symbols_tool::file_outline, schema::json_schema_for};
use crate::schema::json_schema_for;
use anyhow::{Result, anyhow};
use assistant_tool::{ActionLog, Tool, ToolResult};
use assistant_tool::{ActionLog, Tool, ToolResult, outline};
use gpui::{AnyWindowHandle, App, Entity, Task};
use itertools::Itertools;
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
@@ -14,10 +14,6 @@ use ui::IconName;
use util::markdown::MarkdownInlineCode;
/// If the model requests to read a file whose size exceeds this, then
/// the tool will return the file's symbol outline instead of its contents,
/// and suggest trying again using line ranges from the outline.
const MAX_FILE_SIZE_TO_READ: usize = 16384;
/// If the model requests to list the entries in a directory with more
/// entries than this, then the tool will return a subset of the entries
/// and suggest trying again.
@@ -218,7 +214,7 @@ impl Tool for ContentsTool {
// No line ranges specified, so check file size to see if it's too big.
let file_size = buffer.read_with(cx, |buffer, _cx| buffer.text().len())?;
if file_size <= MAX_FILE_SIZE_TO_READ {
if file_size <= outline::AUTO_OUTLINE_SIZE {
let result = buffer.read_with(cx, |buffer, _cx| buffer.text())?;
action_log.update(cx, |log, cx| {
@@ -229,7 +225,7 @@ impl Tool for ContentsTool {
} else {
// File is too big, so return its outline and a suggestion to
// read again with a line number range specified.
let outline = file_outline(project, file_path, action_log, None, cx).await?;
let outline = outline::file_outline(project, file_path, action_log, None, cx).await?;
Ok(format!("This file was too big to read all at once. Here is an outline of its symbols:\n\n{outline}\n\nUsing the line numbers in this outline, you can call this tool again while specifying the start and end fields to see the implementations of symbols in the outline."))
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,408 @@
use derive_more::{Add, AddAssign};
use smallvec::SmallVec;
use std::{cmp, mem, ops::Range};
const OLD_TEXT_END_TAG: &str = "</old_text>";
const NEW_TEXT_END_TAG: &str = "</new_text>";
const END_TAG_LEN: usize = OLD_TEXT_END_TAG.len();
const _: () = debug_assert!(OLD_TEXT_END_TAG.len() == NEW_TEXT_END_TAG.len());
#[derive(Debug)]
pub enum EditParserEvent {
OldText(String),
NewTextChunk { chunk: String, done: bool },
}
#[derive(Clone, Debug, Default, PartialEq, Eq, Add, AddAssign)]
pub struct EditParserMetrics {
pub tags: usize,
pub mismatched_tags: usize,
}
#[derive(Debug)]
pub struct EditParser {
state: EditParserState,
buffer: String,
metrics: EditParserMetrics,
}
#[derive(Debug, PartialEq)]
enum EditParserState {
Pending,
WithinOldText,
AfterOldText,
WithinNewText { start: bool },
}
impl EditParser {
pub fn new() -> Self {
EditParser {
state: EditParserState::Pending,
buffer: String::new(),
metrics: EditParserMetrics::default(),
}
}
pub fn push(&mut self, chunk: &str) -> SmallVec<[EditParserEvent; 1]> {
self.buffer.push_str(chunk);
let mut edit_events = SmallVec::new();
loop {
match &mut self.state {
EditParserState::Pending => {
if let Some(start) = self.buffer.find("<old_text>") {
self.buffer.drain(..start + "<old_text>".len());
self.state = EditParserState::WithinOldText;
} else {
break;
}
}
EditParserState::WithinOldText => {
if let Some(tag_range) = self.find_end_tag() {
let mut start = 0;
if self.buffer.starts_with('\n') {
start = 1;
}
let mut old_text = self.buffer[start..tag_range.start].to_string();
if old_text.ends_with('\n') {
old_text.pop();
}
self.metrics.tags += 1;
if &self.buffer[tag_range.clone()] != OLD_TEXT_END_TAG {
self.metrics.mismatched_tags += 1;
}
self.buffer.drain(..tag_range.end);
self.state = EditParserState::AfterOldText;
edit_events.push(EditParserEvent::OldText(old_text));
} else {
break;
}
}
EditParserState::AfterOldText => {
if let Some(start) = self.buffer.find("<new_text>") {
self.buffer.drain(..start + "<new_text>".len());
self.state = EditParserState::WithinNewText { start: true };
} else {
break;
}
}
EditParserState::WithinNewText { start } => {
if !self.buffer.is_empty() {
if *start && self.buffer.starts_with('\n') {
self.buffer.remove(0);
}
*start = false;
}
if let Some(tag_range) = self.find_end_tag() {
let mut chunk = self.buffer[..tag_range.start].to_string();
if chunk.ends_with('\n') {
chunk.pop();
}
self.metrics.tags += 1;
if &self.buffer[tag_range.clone()] != NEW_TEXT_END_TAG {
self.metrics.mismatched_tags += 1;
}
self.buffer.drain(..tag_range.end);
self.state = EditParserState::Pending;
edit_events.push(EditParserEvent::NewTextChunk { chunk, done: true });
} else {
let mut end_prefixes = (1..END_TAG_LEN)
.flat_map(|i| [&NEW_TEXT_END_TAG[..i], &OLD_TEXT_END_TAG[..i]])
.chain(["\n"]);
if end_prefixes.all(|prefix| !self.buffer.ends_with(&prefix)) {
edit_events.push(EditParserEvent::NewTextChunk {
chunk: mem::take(&mut self.buffer),
done: false,
});
}
break;
}
}
}
}
edit_events
}
fn find_end_tag(&self) -> Option<Range<usize>> {
let old_text_end_tag_ix = self.buffer.find(OLD_TEXT_END_TAG);
let new_text_end_tag_ix = self.buffer.find(NEW_TEXT_END_TAG);
let start_ix = if let Some((old_text_ix, new_text_ix)) =
old_text_end_tag_ix.zip(new_text_end_tag_ix)
{
cmp::min(old_text_ix, new_text_ix)
} else {
old_text_end_tag_ix.or(new_text_end_tag_ix)?
};
Some(start_ix..start_ix + END_TAG_LEN)
}
pub fn finish(self) -> EditParserMetrics {
self.metrics
}
}
#[cfg(test)]
mod tests {
use super::*;
use indoc::indoc;
use rand::prelude::*;
use std::cmp;
#[gpui::test(iterations = 1000)]
fn test_single_edit(mut rng: StdRng) {
let mut parser = EditParser::new();
assert_eq!(
parse_random_chunks(
"<old_text>original</old_text><new_text>updated</new_text>",
&mut parser,
&mut rng
),
vec![Edit {
old_text: "original".to_string(),
new_text: "updated".to_string(),
}]
);
assert_eq!(
parser.finish(),
EditParserMetrics {
tags: 2,
mismatched_tags: 0
}
);
}
#[gpui::test(iterations = 1000)]
fn test_multiple_edits(mut rng: StdRng) {
let mut parser = EditParser::new();
assert_eq!(
parse_random_chunks(
indoc! {"
<old_text>
first old
</old_text><new_text>first new</new_text>
<old_text>second old</old_text><new_text>
second new
</new_text>
"},
&mut parser,
&mut rng
),
vec![
Edit {
old_text: "first old".to_string(),
new_text: "first new".to_string(),
},
Edit {
old_text: "second old".to_string(),
new_text: "second new".to_string(),
},
]
);
assert_eq!(
parser.finish(),
EditParserMetrics {
tags: 4,
mismatched_tags: 0
}
);
}
#[gpui::test(iterations = 1000)]
fn test_edits_with_extra_text(mut rng: StdRng) {
let mut parser = EditParser::new();
assert_eq!(
parse_random_chunks(
indoc! {"
ignore this <old_text>
content</old_text>extra stuff<new_text>updated content</new_text>trailing data
more text <old_text>second item
</old_text>middle text<new_text>modified second item</new_text>end
<old_text>third case</old_text><new_text>improved third case</new_text> with trailing text
"},
&mut parser,
&mut rng
),
vec![
Edit {
old_text: "content".to_string(),
new_text: "updated content".to_string(),
},
Edit {
old_text: "second item".to_string(),
new_text: "modified second item".to_string(),
},
Edit {
old_text: "third case".to_string(),
new_text: "improved third case".to_string(),
},
]
);
assert_eq!(
parser.finish(),
EditParserMetrics {
tags: 6,
mismatched_tags: 0
}
);
}
#[gpui::test(iterations = 1000)]
fn test_nested_tags(mut rng: StdRng) {
let mut parser = EditParser::new();
assert_eq!(
parse_random_chunks(
"<old_text>code with <tag>nested</tag> elements</old_text><new_text>new <code>content</code></new_text>",
&mut parser,
&mut rng
),
vec![Edit {
old_text: "code with <tag>nested</tag> elements".to_string(),
new_text: "new <code>content</code>".to_string(),
}]
);
assert_eq!(
parser.finish(),
EditParserMetrics {
tags: 2,
mismatched_tags: 0
}
);
}
#[gpui::test(iterations = 1000)]
fn test_empty_old_and_new_text(mut rng: StdRng) {
let mut parser = EditParser::new();
assert_eq!(
parse_random_chunks(
"<old_text></old_text><new_text></new_text>",
&mut parser,
&mut rng
),
vec![Edit {
old_text: "".to_string(),
new_text: "".to_string(),
}]
);
assert_eq!(
parser.finish(),
EditParserMetrics {
tags: 2,
mismatched_tags: 0
}
);
}
#[gpui::test(iterations = 100)]
fn test_multiline_content(mut rng: StdRng) {
let mut parser = EditParser::new();
assert_eq!(
parse_random_chunks(
"<old_text>line1\nline2\nline3</old_text><new_text>line1\nmodified line2\nline3</new_text>",
&mut parser,
&mut rng
),
vec![Edit {
old_text: "line1\nline2\nline3".to_string(),
new_text: "line1\nmodified line2\nline3".to_string(),
}]
);
assert_eq!(
parser.finish(),
EditParserMetrics {
tags: 2,
mismatched_tags: 0
}
);
}
#[gpui::test(iterations = 1000)]
fn test_mismatched_tags(mut rng: StdRng) {
let mut parser = EditParser::new();
assert_eq!(
parse_random_chunks(
// Reduced from an actual Sonnet 3.7 output
indoc! {"
<old_text>
a
b
c
</new_text>
<new_text>
a
B
c
</old_text>
<old_text>
d
e
f
</new_text>
<new_text>
D
e
F
</old_text>
"},
&mut parser,
&mut rng
),
vec![
Edit {
old_text: "a\nb\nc".to_string(),
new_text: "a\nB\nc".to_string(),
},
Edit {
old_text: "d\ne\nf".to_string(),
new_text: "D\ne\nF".to_string(),
}
]
);
assert_eq!(
parser.finish(),
EditParserMetrics {
tags: 4,
mismatched_tags: 4
}
);
}
#[derive(Default, Debug, PartialEq, Eq)]
struct Edit {
old_text: String,
new_text: String,
}
fn parse_random_chunks(input: &str, parser: &mut EditParser, rng: &mut StdRng) -> Vec<Edit> {
let chunk_count = rng.gen_range(1..=cmp::min(input.len(), 50));
let mut chunk_indices = (0..input.len()).choose_multiple(rng, chunk_count);
chunk_indices.sort();
chunk_indices.push(input.len());
let mut pending_edit = Edit::default();
let mut edits = Vec::new();
let mut last_ix = 0;
for chunk_ix in chunk_indices {
for event in parser.push(&input[last_ix..chunk_ix]) {
match event {
EditParserEvent::OldText(old_text) => {
pending_edit.old_text = old_text;
}
EditParserEvent::NewTextChunk { chunk, done } => {
pending_edit.new_text.push_str(&chunk);
if done {
edits.push(pending_edit);
pending_edit = Edit::default();
}
}
}
}
last_ix = chunk_ix;
}
edits
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,328 @@
use crate::commit::get_messages;
use crate::{GitRemote, Oid};
use anyhow::{Context as _, Result, anyhow};
use collections::{HashMap, HashSet};
use futures::AsyncWriteExt;
use gpui::SharedString;
use serde::{Deserialize, Serialize};
use std::process::Stdio;
use std::{ops::Range, path::Path};
use text::Rope;
use time::OffsetDateTime;
use time::UtcOffset;
use time::macros::format_description;
pub use git2 as libgit;
#[derive(Debug, Clone, Default)]
pub struct Blame {
pub entries: Vec<BlameEntry>,
pub messages: HashMap<Oid, String>,
pub remote_url: Option<String>,
}
#[derive(Clone, Debug, Default)]
pub struct ParsedCommitMessage {
pub message: SharedString,
pub permalink: Option<url::Url>,
pub pull_request: Option<crate::hosting_provider::PullRequest>,
pub remote: Option<GitRemote>,
}
impl Blame {
pub async fn for_path(
git_binary: &Path,
working_directory: &Path,
path: &Path,
content: &Rope,
remote_url: Option<String>,
) -> Result<Self> {
let output = run_git_blame(git_binary, working_directory, path, content).await?;
let mut entries = parse_git_blame(&output)?;
entries.sort_unstable_by(|a, b| a.range.start.cmp(&b.range.start));
let mut unique_shas = HashSet::default();
for entry in entries.iter_mut() {
unique_shas.insert(entry.sha);
}
let shas = unique_shas.into_iter().collect::<Vec<_>>();
let messages = get_messages(working_directory, &shas)
.await
.context("failed to get commit messages")?;
Ok(Self {
entries,
messages,
remote_url,
})
}
}
const GIT_BLAME_NO_COMMIT_ERROR: &str = "fatal: no such ref: HEAD";
const GIT_BLAME_NO_PATH: &str = "fatal: no such path";
#[derive(Serialize, Deserialize, Default, Debug, Clone, PartialEq, Eq)]
pub struct BlameEntry {
pub sha: Oid,
pub range: Range<u32>,
pub original_line_number: u32,
pub author: Option<String>,
pub author_mail: Option<String>,
pub author_time: Option<i64>,
pub author_tz: Option<String>,
pub committer_name: Option<String>,
pub committer_email: Option<String>,
pub committer_time: Option<i64>,
pub committer_tz: Option<String>,
pub summary: Option<String>,
pub previous: Option<String>,
pub filename: String,
}
impl BlameEntry {
// Returns a BlameEntry by parsing the first line of a `git blame --incremental`
// entry. The line MUST have this format:
//
// <40-byte-hex-sha1> <sourceline> <resultline> <num-lines>
fn new_from_blame_line(line: &str) -> Result<BlameEntry> {
let mut parts = line.split_whitespace();
let sha = parts
.next()
.and_then(|line| line.parse::<Oid>().ok())
.ok_or_else(|| anyhow!("failed to parse sha"))?;
let original_line_number = parts
.next()
.and_then(|line| line.parse::<u32>().ok())
.ok_or_else(|| anyhow!("Failed to parse original line number"))?;
let final_line_number = parts
.next()
.and_then(|line| line.parse::<u32>().ok())
.ok_or_else(|| anyhow!("Failed to parse final line number"))?;
let line_count = parts
.next()
.and_then(|line| line.parse::<u32>().ok())
.ok_or_else(|| anyhow!("Failed to parse final line number"))?;
let start_line = final_line_number.saturating_sub(1);
let end_line = start_line + line_count;
let range = start_line..end_line;
Ok(Self {
sha,
range,
original_line_number,
..Default::default()
})
}
pub fn author_offset_date_time(&self) -> Result<time::OffsetDateTime> {
if let (Some(author_time), Some(author_tz)) = (self.author_time, &self.author_tz) {
let format = format_description!("[offset_hour][offset_minute]");
let offset = UtcOffset::parse(author_tz, &format)?;
let date_time_utc = OffsetDateTime::from_unix_timestamp(author_time)?;
Ok(date_time_utc.to_offset(offset))
} else {
// Directly return current time in UTC if there's no committer time or timezone
Ok(time::OffsetDateTime::now_utc())
}
}
}
// parse_git_blame parses the output of `git blame --incremental`, which returns
// all the blame-entries for a given path incrementally, as it finds them.
//
// Each entry *always* starts with:
//
// <40-byte-hex-sha1> <sourceline> <resultline> <num-lines>
//
// Each entry *always* ends with:
//
// filename <whitespace-quoted-filename-goes-here>
//
// Line numbers are 1-indexed.
//
// A `git blame --incremental` entry looks like this:
//
// 6ad46b5257ba16d12c5ca9f0d4900320959df7f4 2 2 1
// author Joe Schmoe
// author-mail <joe.schmoe@example.com>
// author-time 1709741400
// author-tz +0100
// committer Joe Schmoe
// committer-mail <joe.schmoe@example.com>
// committer-time 1709741400
// committer-tz +0100
// summary Joe's cool commit
// previous 486c2409237a2c627230589e567024a96751d475 index.js
// filename index.js
//
// If the entry has the same SHA as an entry that was already printed then no
// signature information is printed:
//
// 6ad46b5257ba16d12c5ca9f0d4900320959df7f4 3 4 1
// previous 486c2409237a2c627230589e567024a96751d475 index.js
// filename index.js
//
// More about `--incremental` output: https://mirrors.edge.kernel.org/pub/software/scm/git/docs/git-blame.html
fn parse_git_blame(output: &str) -> Result<Vec<BlameEntry>> {
let mut entries: Vec<BlameEntry> = Vec::new();
let mut index: HashMap<Oid, usize> = HashMap::default();
let mut current_entry: Option<BlameEntry> = None;
for line in output.lines() {
let mut done = false;
match &mut current_entry {
None => {
let mut new_entry = BlameEntry::new_from_blame_line(line)?;
if let Some(existing_entry) = index
.get(&new_entry.sha)
.and_then(|slot| entries.get(*slot))
{
new_entry.author.clone_from(&existing_entry.author);
new_entry
.author_mail
.clone_from(&existing_entry.author_mail);
new_entry.author_time = existing_entry.author_time;
new_entry.author_tz.clone_from(&existing_entry.author_tz);
new_entry
.committer_name
.clone_from(&existing_entry.committer_name);
new_entry
.committer_email
.clone_from(&existing_entry.committer_email);
new_entry.committer_time = existing_entry.committer_time;
new_entry
.committer_tz
.clone_from(&existing_entry.committer_tz);
new_entry.summary.clone_from(&existing_entry.summary);
}
current_entry.replace(new_entry);
}
Some(entry) => {
let Some((key, value)) = line.split_once(' ') else {
continue;
};
let is_committed = !entry.sha.is_zero();
match key {
"filename" => {
entry.filename = value.into();
done = true;
}
"previous" => entry.previous = Some(value.into()),
"summary" if is_committed => entry.summary = Some(value.into()),
"author" if is_committed => entry.author = Some(value.into()),
"author-mail" if is_committed => entry.author_mail = Some(value.into()),
"author-time" if is_committed => {
entry.author_time = Some(value.parse::<i64>()?)
}
"author-tz" if is_committed => entry.author_tz = Some(value.into()),
"committer" if is_committed => entry.committer_name = Some(value.into()),
"committer-mail" if is_committed => entry.committer_email = Some(value.into()),
"committer-time" if is_committed => {
entry.committer_time = Some(value.parse::<i64>()?)
}
"committer-tz" if is_committed => entry.committer_tz = Some(value.into()),
_ => {}
}
}
};
if done {
if let Some(entry) = current_entry.take() {
index.insert(entry.sha, entries.len());
// We only want annotations that have a commit.
if !entry.sha.is_zero() {
entries.push(entry);
}
}
}
}
Ok(entries)
}
#[cfg(test)]
mod tests {
use std::path::PathBuf;
use super::BlameEntry;
use super::parse_git_blame;
fn read_test_data(filename: &str) -> String {
let mut path = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
path.push("test_data");
path.push(filename);
std::fs::read_to_string(&path)
.unwrap_or_else(|_| panic!("Could not read test data at {:?}. Is it generated?", path))
}
fn assert_eq_golden(entries: &Vec<BlameEntry>, golden_filename: &str) {
let mut path = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
path.push("test_data");
path.push("golden");
path.push(format!("{}.json", golden_filename));
let mut have_json =
serde_json::to_string_pretty(&entries).expect("could not serialize entries to JSON");
// We always want to save with a trailing newline.
have_json.push('\n');
let update = std::env::var("UPDATE_GOLDEN")
.map(|val| val.eq_ignore_ascii_case("true"))
.unwrap_or(false);
if update {
std::fs::create_dir_all(path.parent().unwrap())
.expect("could not create golden test data directory");
std::fs::write(&path, have_json).expect("could not write out golden data");
} else {
let want_json =
std::fs::read_to_string(&path).unwrap_or_else(|_| {
panic!("could not read golden test data file at {:?}. Did you run the test with UPDATE_GOLDEN=true before?", path);
}).replace("\r\n", "\n");
pretty_assertions::assert_eq!(have_json, want_json, "wrong blame entries");
}
}
#[test]
fn test_parse_git_blame_not_committed() {
let output = read_test_data("blame_incremental_not_committed");
let entries = parse_git_blame(&output).unwrap();
assert_eq_golden(&entries, "blame_incremental_not_committed");
}
#[test]
fn test_parse_git_blame_simple() {
let output = read_test_data("blame_incremental_simple");
let entries = parse_git_blame(&output).unwrap();
assert_eq_golden(&entries, "blame_incremental_simple");
}
#[test]
fn test_parse_git_blame_complex() {
let output = read_test_data("blame_incremental_complex");
let entries = parse_git_blame(&output).unwrap();
assert_eq_golden(&entries, "blame_incremental_complex");
}
}

View File

@@ -0,0 +1,374 @@
use crate::commit::get_messages;
use crate::{GitRemote, Oid};
use anyhow::{Context as _, Result, anyhow};
use collections::{HashMap, HashSet};
use futures::AsyncWriteExt;
use gpui::SharedString;
use serde::{Deserialize, Serialize};
use std::process::Stdio;
use std::{ops::Range, path::Path};
use text::Rope;
use time::OffsetDateTime;
use time::UtcOffset;
use time::macros::format_description;
pub use git2 as libgit;
#[derive(Debug, Clone, Default)]
pub struct Blame {
pub entries: Vec<BlameEntry>,
pub messages: HashMap<Oid, String>,
pub remote_url: Option<String>,
}
#[derive(Clone, Debug, Default)]
pub struct ParsedCommitMessage {
pub message: SharedString,
pub permalink: Option<url::Url>,
pub pull_request: Option<crate::hosting_provider::PullRequest>,
pub remote: Option<GitRemote>,
}
impl Blame {
pub async fn for_path(
git_binary: &Path,
working_directory: &Path,
path: &Path,
content: &Rope,
remote_url: Option<String>,
) -> Result<Self> {
let output = run_git_blame(git_binary, working_directory, path, content).await?;
let mut entries = parse_git_blame(&output)?;
entries.sort_unstable_by(|a, b| a.range.start.cmp(&b.range.start));
let mut unique_shas = HashSet::default();
for entry in entries.iter_mut() {
unique_shas.insert(entry.sha);
}
let shas = unique_shas.into_iter().collect::<Vec<_>>();
let messages = get_messages(working_directory, &shas)
.await
.context("failed to get commit messages")?;
Ok(Self {
entries,
messages,
remote_url,
})
}
}
const GIT_BLAME_NO_COMMIT_ERROR: &str = "fatal: no such ref: HEAD";
const GIT_BLAME_NO_PATH: &str = "fatal: no such path";
async fn run_git_blame(
git_binary: &Path,
working_directory: &Path,
path: &Path,
contents: &Rope,
) -> Result<String> {
let mut child = util::command::new_smol_command(git_binary)
.current_dir(working_directory)
.arg("blame")
.arg("--incremental")
.arg("--contents")
.arg("-")
.arg(path.as_os_str())
.stdin(Stdio::piped())
.stdout(Stdio::piped())
.stderr(Stdio::piped())
.spawn()
.map_err(|e| anyhow!("Failed to start git blame process: {}", e))?;
let stdin = child
.stdin
.as_mut()
.context("failed to get pipe to stdin of git blame command")?;
for chunk in contents.chunks() {
stdin.write_all(chunk.as_bytes()).await?;
}
stdin.flush().await?;
let output = child
.output()
.await
.map_err(|e| anyhow!("Failed to read git blame output: {}", e))?;
if !output.status.success() {
let stderr = String::from_utf8_lossy(&output.stderr);
let trimmed = stderr.trim();
if trimmed == GIT_BLAME_NO_COMMIT_ERROR || trimmed.contains(GIT_BLAME_NO_PATH) {
return Ok(String::new());
}
return Err(anyhow!("git blame process failed: {}", stderr));
}
Ok(String::from_utf8(output.stdout)?)
}
#[derive(Serialize, Deserialize, Default, Debug, Clone, PartialEq, Eq)]
pub struct BlameEntry {
pub sha: Oid,
pub range: Range<u32>,
pub original_line_number: u32,
pub author: Option<String>,
pub author_mail: Option<String>,
pub author_time: Option<i64>,
pub author_tz: Option<String>,
pub committer_name: Option<String>,
pub committer_email: Option<String>,
pub committer_time: Option<i64>,
pub committer_tz: Option<String>,
pub summary: Option<String>,
pub previous: Option<String>,
pub filename: String,
}
impl BlameEntry {
// Returns a BlameEntry by parsing the first line of a `git blame --incremental`
// entry. The line MUST have this format:
//
// <40-byte-hex-sha1> <sourceline> <resultline> <num-lines>
fn new_from_blame_line(line: &str) -> Result<BlameEntry> {
let mut parts = line.split_whitespace();
let sha = parts
.next()
.and_then(|line| line.parse::<Oid>().ok())
.ok_or_else(|| anyhow!("failed to parse sha"))?;
let original_line_number = parts
.next()
.and_then(|line| line.parse::<u32>().ok())
.ok_or_else(|| anyhow!("Failed to parse original line number"))?;
let final_line_number = parts
.next()
.and_then(|line| line.parse::<u32>().ok())
.ok_or_else(|| anyhow!("Failed to parse final line number"))?;
let line_count = parts
.next()
.and_then(|line| line.parse::<u32>().ok())
.ok_or_else(|| anyhow!("Failed to parse final line number"))?;
let start_line = final_line_number.saturating_sub(1);
let end_line = start_line + line_count;
let range = start_line..end_line;
Ok(Self {
sha,
range,
original_line_number,
..Default::default()
})
}
pub fn author_offset_date_time(&self) -> Result<time::OffsetDateTime> {
if let (Some(author_time), Some(author_tz)) = (self.author_time, &self.author_tz) {
let format = format_description!("[offset_hour][offset_minute]");
let offset = UtcOffset::parse(author_tz, &format)?;
let date_time_utc = OffsetDateTime::from_unix_timestamp(author_time)?;
Ok(date_time_utc.to_offset(offset))
} else {
// Directly return current time in UTC if there's no committer time or timezone
Ok(time::OffsetDateTime::now_utc())
}
}
}
// parse_git_blame parses the output of `git blame --incremental`, which returns
// all the blame-entries for a given path incrementally, as it finds them.
//
// Each entry *always* starts with:
//
// <40-byte-hex-sha1> <sourceline> <resultline> <num-lines>
//
// Each entry *always* ends with:
//
// filename <whitespace-quoted-filename-goes-here>
//
// Line numbers are 1-indexed.
//
// A `git blame --incremental` entry looks like this:
//
// 6ad46b5257ba16d12c5ca9f0d4900320959df7f4 2 2 1
// author Joe Schmoe
// author-mail <joe.schmoe@example.com>
// author-time 1709741400
// author-tz +0100
// committer Joe Schmoe
// committer-mail <joe.schmoe@example.com>
// committer-time 1709741400
// committer-tz +0100
// summary Joe's cool commit
// previous 486c2409237a2c627230589e567024a96751d475 index.js
// filename index.js
//
// If the entry has the same SHA as an entry that was already printed then no
// signature information is printed:
//
// 6ad46b5257ba16d12c5ca9f0d4900320959df7f4 3 4 1
// previous 486c2409237a2c627230589e567024a96751d475 index.js
// filename index.js
//
// More about `--incremental` output: https://mirrors.edge.kernel.org/pub/software/scm/git/docs/git-blame.html
fn parse_git_blame(output: &str) -> Result<Vec<BlameEntry>> {
let mut entries: Vec<BlameEntry> = Vec::new();
let mut index: HashMap<Oid, usize> = HashMap::default();
let mut current_entry: Option<BlameEntry> = None;
for line in output.lines() {
let mut done = false;
match &mut current_entry {
None => {
let mut new_entry = BlameEntry::new_from_blame_line(line)?;
if let Some(existing_entry) = index
.get(&new_entry.sha)
.and_then(|slot| entries.get(*slot))
{
new_entry.author.clone_from(&existing_entry.author);
new_entry
.author_mail
.clone_from(&existing_entry.author_mail);
new_entry.author_time = existing_entry.author_time;
new_entry.author_tz.clone_from(&existing_entry.author_tz);
new_entry
.committer_name
.clone_from(&existing_entry.committer_name);
new_entry
.committer_email
.clone_from(&existing_entry.committer_email);
new_entry.committer_time = existing_entry.committer_time;
new_entry
.committer_tz
.clone_from(&existing_entry.committer_tz);
new_entry.summary.clone_from(&existing_entry.summary);
}
current_entry.replace(new_entry);
}
Some(entry) => {
let Some((key, value)) = line.split_once(' ') else {
continue;
};
let is_committed = !entry.sha.is_zero();
match key {
"filename" => {
entry.filename = value.into();
done = true;
}
"previous" => entry.previous = Some(value.into()),
"summary" if is_committed => entry.summary = Some(value.into()),
"author" if is_committed => entry.author = Some(value.into()),
"author-mail" if is_committed => entry.author_mail = Some(value.into()),
"author-time" if is_committed => {
entry.author_time = Some(value.parse::<i64>()?)
}
"author-tz" if is_committed => entry.author_tz = Some(value.into()),
"committer" if is_committed => entry.committer_name = Some(value.into()),
"committer-mail" if is_committed => entry.committer_email = Some(value.into()),
"committer-time" if is_committed => {
entry.committer_time = Some(value.parse::<i64>()?)
}
"committer-tz" if is_committed => entry.committer_tz = Some(value.into()),
_ => {}
}
}
};
if done {
if let Some(entry) = current_entry.take() {
index.insert(entry.sha, entries.len());
// We only want annotations that have a commit.
if !entry.sha.is_zero() {
entries.push(entry);
}
}
}
}
Ok(entries)
}
#[cfg(test)]
mod tests {
use std::path::PathBuf;
use super::BlameEntry;
use super::parse_git_blame;
fn read_test_data(filename: &str) -> String {
let mut path = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
path.push("test_data");
path.push(filename);
std::fs::read_to_string(&path)
.unwrap_or_else(|_| panic!("Could not read test data at {:?}. Is it generated?", path))
}
fn assert_eq_golden(entries: &Vec<BlameEntry>, golden_filename: &str) {
let mut path = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
path.push("test_data");
path.push("golden");
path.push(format!("{}.json", golden_filename));
let mut have_json =
serde_json::to_string_pretty(&entries).expect("could not serialize entries to JSON");
// We always want to save with a trailing newline.
have_json.push('\n');
let update = std::env::var("UPDATE_GOLDEN")
.map(|val| val.eq_ignore_ascii_case("true"))
.unwrap_or(false);
if update {
std::fs::create_dir_all(path.parent().unwrap())
.expect("could not create golden test data directory");
std::fs::write(&path, have_json).expect("could not write out golden data");
} else {
let want_json =
std::fs::read_to_string(&path).unwrap_or_else(|_| {
panic!("could not read golden test data file at {:?}. Did you run the test with UPDATE_GOLDEN=true before?", path);
}).replace("\r\n", "\n");
pretty_assertions::assert_eq!(have_json, want_json, "wrong blame entries");
}
}
#[test]
fn test_parse_git_blame_not_committed() {
let output = read_test_data("blame_incremental_not_committed");
let entries = parse_git_blame(&output).unwrap();
assert_eq_golden(&entries, "blame_incremental_not_committed");
}
#[test]
fn test_parse_git_blame_simple() {
let output = read_test_data("blame_incremental_simple");
let entries = parse_git_blame(&output).unwrap();
assert_eq_golden(&entries, "blame_incremental_simple");
}
#[test]
fn test_parse_git_blame_complex() {
let output = read_test_data("blame_incremental_complex");
let entries = parse_git_blame(&output).unwrap();
assert_eq_golden(&entries, "blame_incremental_complex");
}
}

View File

@@ -0,0 +1,378 @@
use crate::commit::get_messages;
use crate::{GitRemote, Oid};
use anyhow::{Context as _, Result, anyhow};
use collections::{HashMap, HashSet};
use futures::AsyncWriteExt;
use gpui::SharedString;
use serde::{Deserialize, Serialize};
use std::process::Stdio;
use std::{ops::Range, path::Path};
use text::Rope;
use time::OffsetDateTime;
use time::UtcOffset;
use time::macros::format_description;
pub use git2 as libgit;
#[derive(Debug, Clone, Default)]
pub struct Blame {
pub entries: Vec<BlameEntry>,
pub messages: HashMap<Oid, String>,
pub remote_url: Option<String>,
}
#[derive(Clone, Debug, Default)]
pub struct ParsedCommitMessage {
pub message: SharedString,
pub permalink: Option<url::Url>,
pub pull_request: Option<crate::hosting_provider::PullRequest>,
pub remote: Option<GitRemote>,
}
impl Blame {
pub async fn for_path(
git_binary: &Path,
working_directory: &Path,
path: &Path,
content: &Rope,
remote_url: Option<String>,
) -> Result<Self> {
let output = run_git_blame(git_binary, working_directory, path, content).await?;
let mut entries = parse_git_blame(&output)?;
entries.sort_unstable_by(|a, b| a.range.start.cmp(&b.range.start));
let mut unique_shas = HashSet::default();
for entry in entries.iter_mut() {
unique_shas.insert(entry.sha);
}
let shas = unique_shas.into_iter().collect::<Vec<_>>();
let messages = get_messages(working_directory, &shas)
.await
.context("failed to get commit messages")?;
Ok(Self {
entries,
messages,
remote_url,
})
}
}
const GIT_BLAME_NO_COMMIT_ERROR: &str = "fatal: no such ref: HEAD";
const GIT_BLAME_NO_PATH: &str = "fatal: no such path";
async fn run_git_blame(
git_binary: &Path,
working_directory: &Path,
path: &Path,
contents: &Rope,
) -> Result<String> {
let mut child = util::command::new_smol_command(git_binary)
.current_dir(working_directory)
.arg("blame")
.arg("--incremental")
.arg("--contents")
.arg("-")
.arg(path.as_os_str())
.stdin(Stdio::piped())
.stdout(Stdio::piped())
.stderr(Stdio::piped())
.spawn()
.map_err(|e| anyhow!("Failed to start git blame process: {}", e))?;
let stdin = child
.stdin
.as_mut()
.context("failed to get pipe to stdin of git blame command")?;
for chunk in contents.chunks() {
stdin.write_all(chunk.as_bytes()).await?;
}
stdin.flush().await?;
let output = child
.output()
.await
.map_err(|e| anyhow!("Failed to read git blame output: {}", e))?;
handle_command_output(output)
}
fn handle_command_output(output: std::process::Output) -> Result<String> {
if !output.status.success() {
let stderr = String::from_utf8_lossy(&output.stderr);
let trimmed = stderr.trim();
if trimmed == GIT_BLAME_NO_COMMIT_ERROR || trimmed.contains(GIT_BLAME_NO_PATH) {
return Ok(String::new());
}
return Err(anyhow!("git blame process failed: {}", stderr));
}
Ok(String::from_utf8(output.stdout)?)
}
#[derive(Serialize, Deserialize, Default, Debug, Clone, PartialEq, Eq)]
pub struct BlameEntry {
pub sha: Oid,
pub range: Range<u32>,
pub original_line_number: u32,
pub author: Option<String>,
pub author_mail: Option<String>,
pub author_time: Option<i64>,
pub author_tz: Option<String>,
pub committer_name: Option<String>,
pub committer_email: Option<String>,
pub committer_time: Option<i64>,
pub committer_tz: Option<String>,
pub summary: Option<String>,
pub previous: Option<String>,
pub filename: String,
}
impl BlameEntry {
// Returns a BlameEntry by parsing the first line of a `git blame --incremental`
// entry. The line MUST have this format:
//
// <40-byte-hex-sha1> <sourceline> <resultline> <num-lines>
fn new_from_blame_line(line: &str) -> Result<BlameEntry> {
let mut parts = line.split_whitespace();
let sha = parts
.next()
.and_then(|line| line.parse::<Oid>().ok())
.ok_or_else(|| anyhow!("failed to parse sha"))?;
let original_line_number = parts
.next()
.and_then(|line| line.parse::<u32>().ok())
.ok_or_else(|| anyhow!("Failed to parse original line number"))?;
let final_line_number = parts
.next()
.and_then(|line| line.parse::<u32>().ok())
.ok_or_else(|| anyhow!("Failed to parse final line number"))?;
let line_count = parts
.next()
.and_then(|line| line.parse::<u32>().ok())
.ok_or_else(|| anyhow!("Failed to parse final line number"))?;
let start_line = final_line_number.saturating_sub(1);
let end_line = start_line + line_count;
let range = start_line..end_line;
Ok(Self {
sha,
range,
original_line_number,
..Default::default()
})
}
pub fn author_offset_date_time(&self) -> Result<time::OffsetDateTime> {
if let (Some(author_time), Some(author_tz)) = (self.author_time, &self.author_tz) {
let format = format_description!("[offset_hour][offset_minute]");
let offset = UtcOffset::parse(author_tz, &format)?;
let date_time_utc = OffsetDateTime::from_unix_timestamp(author_time)?;
Ok(date_time_utc.to_offset(offset))
} else {
// Directly return current time in UTC if there's no committer time or timezone
Ok(time::OffsetDateTime::now_utc())
}
}
}
// parse_git_blame parses the output of `git blame --incremental`, which returns
// all the blame-entries for a given path incrementally, as it finds them.
//
// Each entry *always* starts with:
//
// <40-byte-hex-sha1> <sourceline> <resultline> <num-lines>
//
// Each entry *always* ends with:
//
// filename <whitespace-quoted-filename-goes-here>
//
// Line numbers are 1-indexed.
//
// A `git blame --incremental` entry looks like this:
//
// 6ad46b5257ba16d12c5ca9f0d4900320959df7f4 2 2 1
// author Joe Schmoe
// author-mail <joe.schmoe@example.com>
// author-time 1709741400
// author-tz +0100
// committer Joe Schmoe
// committer-mail <joe.schmoe@example.com>
// committer-time 1709741400
// committer-tz +0100
// summary Joe's cool commit
// previous 486c2409237a2c627230589e567024a96751d475 index.js
// filename index.js
//
// If the entry has the same SHA as an entry that was already printed then no
// signature information is printed:
//
// 6ad46b5257ba16d12c5ca9f0d4900320959df7f4 3 4 1
// previous 486c2409237a2c627230589e567024a96751d475 index.js
// filename index.js
//
// More about `--incremental` output: https://mirrors.edge.kernel.org/pub/software/scm/git/docs/git-blame.html
fn parse_git_blame(output: &str) -> Result<Vec<BlameEntry>> {
let mut entries: Vec<BlameEntry> = Vec::new();
let mut index: HashMap<Oid, usize> = HashMap::default();
let mut current_entry: Option<BlameEntry> = None;
for line in output.lines() {
let mut done = false;
match &mut current_entry {
None => {
let mut new_entry = BlameEntry::new_from_blame_line(line)?;
if let Some(existing_entry) = index
.get(&new_entry.sha)
.and_then(|slot| entries.get(*slot))
{
new_entry.author.clone_from(&existing_entry.author);
new_entry
.author_mail
.clone_from(&existing_entry.author_mail);
new_entry.author_time = existing_entry.author_time;
new_entry.author_tz.clone_from(&existing_entry.author_tz);
new_entry
.committer_name
.clone_from(&existing_entry.committer_name);
new_entry
.committer_email
.clone_from(&existing_entry.committer_email);
new_entry.committer_time = existing_entry.committer_time;
new_entry
.committer_tz
.clone_from(&existing_entry.committer_tz);
new_entry.summary.clone_from(&existing_entry.summary);
}
current_entry.replace(new_entry);
}
Some(entry) => {
let Some((key, value)) = line.split_once(' ') else {
continue;
};
let is_committed = !entry.sha.is_zero();
match key {
"filename" => {
entry.filename = value.into();
done = true;
}
"previous" => entry.previous = Some(value.into()),
"summary" if is_committed => entry.summary = Some(value.into()),
"author" if is_committed => entry.author = Some(value.into()),
"author-mail" if is_committed => entry.author_mail = Some(value.into()),
"author-time" if is_committed => {
entry.author_time = Some(value.parse::<i64>()?)
}
"author-tz" if is_committed => entry.author_tz = Some(value.into()),
"committer" if is_committed => entry.committer_name = Some(value.into()),
"committer-mail" if is_committed => entry.committer_email = Some(value.into()),
"committer-time" if is_committed => {
entry.committer_time = Some(value.parse::<i64>()?)
}
"committer-tz" if is_committed => entry.committer_tz = Some(value.into()),
_ => {}
}
}
};
if done {
if let Some(entry) = current_entry.take() {
index.insert(entry.sha, entries.len());
// We only want annotations that have a commit.
if !entry.sha.is_zero() {
entries.push(entry);
}
}
}
}
Ok(entries)
}
#[cfg(test)]
mod tests {
use std::path::PathBuf;
use super::BlameEntry;
use super::parse_git_blame;
fn read_test_data(filename: &str) -> String {
let mut path = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
path.push("test_data");
path.push(filename);
std::fs::read_to_string(&path)
.unwrap_or_else(|_| panic!("Could not read test data at {:?}. Is it generated?", path))
}
fn assert_eq_golden(entries: &Vec<BlameEntry>, golden_filename: &str) {
let mut path = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
path.push("test_data");
path.push("golden");
path.push(format!("{}.json", golden_filename));
let mut have_json =
serde_json::to_string_pretty(&entries).expect("could not serialize entries to JSON");
// We always want to save with a trailing newline.
have_json.push('\n');
let update = std::env::var("UPDATE_GOLDEN")
.map(|val| val.eq_ignore_ascii_case("true"))
.unwrap_or(false);
if update {
std::fs::create_dir_all(path.parent().unwrap())
.expect("could not create golden test data directory");
std::fs::write(&path, have_json).expect("could not write out golden data");
} else {
let want_json =
std::fs::read_to_string(&path).unwrap_or_else(|_| {
panic!("could not read golden test data file at {:?}. Did you run the test with UPDATE_GOLDEN=true before?", path);
}).replace("\r\n", "\n");
pretty_assertions::assert_eq!(have_json, want_json, "wrong blame entries");
}
}
#[test]
fn test_parse_git_blame_not_committed() {
let output = read_test_data("blame_incremental_not_committed");
let entries = parse_git_blame(&output).unwrap();
assert_eq_golden(&entries, "blame_incremental_not_committed");
}
#[test]
fn test_parse_git_blame_simple() {
let output = read_test_data("blame_incremental_simple");
let entries = parse_git_blame(&output).unwrap();
assert_eq_golden(&entries, "blame_incremental_simple");
}
#[test]
fn test_parse_git_blame_complex() {
let output = read_test_data("blame_incremental_complex");
let entries = parse_git_blame(&output).unwrap();
assert_eq_golden(&entries, "blame_incremental_complex");
}
}

View File

@@ -0,0 +1,374 @@
use crate::commit::get_messages;
use crate::{GitRemote, Oid};
use anyhow::{Context as _, Result, anyhow};
use collections::{HashMap, HashSet};
use futures::AsyncWriteExt;
use gpui::SharedString;
use serde::{Deserialize, Serialize};
use std::process::Stdio;
use std::{ops::Range, path::Path};
use text::Rope;
use time::OffsetDateTime;
use time::UtcOffset;
use time::macros::format_description;
pub use git2 as libgit;
#[derive(Debug, Clone, Default)]
pub struct Blame {
pub entries: Vec<BlameEntry>,
pub messages: HashMap<Oid, String>,
pub remote_url: Option<String>,
}
#[derive(Clone, Debug, Default)]
pub struct ParsedCommitMessage {
pub message: SharedString,
pub permalink: Option<url::Url>,
pub pull_request: Option<crate::hosting_provider::PullRequest>,
pub remote: Option<GitRemote>,
}
impl Blame {
pub async fn for_path(
git_binary: &Path,
working_directory: &Path,
path: &Path,
content: &Rope,
remote_url: Option<String>,
) -> Result<Self> {
let output = run_git_blame(git_binary, working_directory, path, content).await?;
let mut entries = parse_git_blame(&output)?;
entries.sort_unstable_by(|a, b| a.range.start.cmp(&b.range.start));
let mut unique_shas = HashSet::default();
for entry in entries.iter_mut() {
unique_shas.insert(entry.sha);
}
let shas = unique_shas.into_iter().collect::<Vec<_>>();
let messages = get_messages(working_directory, &shas)
.await
.context("failed to get commit messages")?;
Ok(Self {
entries,
messages,
remote_url,
})
}
}
const GIT_BLAME_NO_COMMIT_ERROR: &str = "fatal: no such ref: HEAD";
const GIT_BLAME_NO_PATH: &str = "fatal: no such path";
async fn run_git_blame(
git_binary: &Path,
working_directory: &Path,
path: &Path,
contents: &Rope,
) -> Result<String> {
let mut child = util::command::new_smol_command(git_binary)
.current_dir(working_directory)
.arg("blame")
.arg("--incremental")
.arg("--contents")
.arg("-")
.arg(path.as_os_str())
.stdin(Stdio::piped())
.stdout(Stdio::piped())
.stderr(Stdio::piped())
.spawn()
.map_err(|e| anyhow!("Failed to start git blame process: {}", e))?;
let stdin = child
.stdin
.as_mut()
.context("failed to get pipe to stdin of git blame command")?;
for chunk in contents.chunks() {
stdin.write_all(chunk.as_bytes()).await?;
}
stdin.flush().await?;
let output = child
.output()
.await
.map_err(|e| anyhow!("Failed to read git blame output: {}", e))?;
if !output.status.success() {
let stderr = String::from_utf8_lossy(&output.stderr);
let trimmed = stderr.trim();
if trimmed == GIT_BLAME_NO_COMMIT_ERROR || trimmed.contains(GIT_BLAME_NO_PATH) {
return Ok(String::new());
}
return Err(anyhow!("git blame process failed: {}", stderr));
}
Ok(String::from_utf8(output.stdout)?)
}
#[derive(Serialize, Deserialize, Default, Debug, Clone, PartialEq, Eq)]
pub struct BlameEntry {
pub sha: Oid,
pub range: Range<u32>,
pub original_line_number: u32,
pub author: Option<String>,
pub author_mail: Option<String>,
pub author_time: Option<i64>,
pub author_tz: Option<String>,
pub committer_name: Option<String>,
pub committer_email: Option<String>,
pub committer_time: Option<i64>,
pub committer_tz: Option<String>,
pub summary: Option<String>,
pub previous: Option<String>,
pub filename: String,
}
impl BlameEntry {
// Returns a BlameEntry by parsing the first line of a `git blame --incremental`
// entry. The line MUST have this format:
//
// <40-byte-hex-sha1> <sourceline> <resultline> <num-lines>
fn new_from_blame_line(line: &str) -> Result<BlameEntry> {
let mut parts = line.split_whitespace();
let sha = parts
.next()
.and_then(|line| line.parse::<Oid>().ok())
.ok_or_else(|| anyhow!("failed to parse sha"))?;
let original_line_number = parts
.next()
.and_then(|line| line.parse::<u32>().ok())
.ok_or_else(|| anyhow!("Failed to parse original line number"))?;
let final_line_number = parts
.next()
.and_then(|line| line.parse::<u32>().ok())
.ok_or_else(|| anyhow!("Failed to parse final line number"))?;
let line_count = parts
.next()
.and_then(|line| line.parse::<u32>().ok())
.ok_or_else(|| anyhow!("Failed to parse final line number"))?;
let start_line = final_line_number.saturating_sub(1);
let end_line = start_line + line_count;
let range = start_line..end_line;
Ok(Self {
sha,
range,
original_line_number,
..Default::default()
})
}
pub fn author_offset_date_time(&self) -> Result<time::OffsetDateTime> {
if let (Some(author_time), Some(author_tz)) = (self.author_time, &self.author_tz) {
let format = format_description!("[offset_hour][offset_minute]");
let offset = UtcOffset::parse(author_tz, &format)?;
let date_time_utc = OffsetDateTime::from_unix_timestamp(author_time)?;
Ok(date_time_utc.to_offset(offset))
} else {
// Directly return current time in UTC if there's no committer time or timezone
Ok(time::OffsetDateTime::now_utc())
}
}
}
// parse_git_blame parses the output of `git blame --incremental`, which returns
// all the blame-entries for a given path incrementally, as it finds them.
//
// Each entry *always* starts with:
//
// <40-byte-hex-sha1> <sourceline> <resultline> <num-lines>
//
// Each entry *always* ends with:
//
// filename <whitespace-quoted-filename-goes-here>
//
// Line numbers are 1-indexed.
//
// A `git blame --incremental` entry looks like this:
//
// 6ad46b5257ba16d12c5ca9f0d4900320959df7f4 2 2 1
// author Joe Schmoe
// author-mail <joe.schmoe@example.com>
// author-time 1709741400
// author-tz +0100
// committer Joe Schmoe
// committer-mail <joe.schmoe@example.com>
// committer-time 1709741400
// committer-tz +0100
// summary Joe's cool commit
// previous 486c2409237a2c627230589e567024a96751d475 index.js
// filename index.js
//
// If the entry has the same SHA as an entry that was already printed then no
// signature information is printed:
//
// 6ad46b5257ba16d12c5ca9f0d4900320959df7f4 3 4 1
// previous 486c2409237a2c627230589e567024a96751d475 index.js
// filename index.js
//
// More about `--incremental` output: https://mirrors.edge.kernel.org/pub/software/scm/git/docs/git-blame.html
fn parse_git_blame(output: &str) -> Result<Vec<BlameEntry>> {
let mut entries: Vec<BlameEntry> = Vec::new();
let mut index: HashMap<Oid, usize> = HashMap::default();
let mut current_entry: Option<BlameEntry> = None;
for line in output.lines() {
let mut done = false;
match &mut current_entry {
None => {
let mut new_entry = BlameEntry::new_from_blame_line(line)?;
if let Some(existing_entry) = index
.get(&new_entry.sha)
.and_then(|slot| entries.get(*slot))
{
new_entry.author.clone_from(&existing_entry.author);
new_entry
.author_mail
.clone_from(&existing_entry.author_mail);
new_entry.author_time = existing_entry.author_time;
new_entry.author_tz.clone_from(&existing_entry.author_tz);
new_entry
.committer_name
.clone_from(&existing_entry.committer_name);
new_entry
.committer_email
.clone_from(&existing_entry.committer_email);
new_entry.committer_time = existing_entry.committer_time;
new_entry
.committer_tz
.clone_from(&existing_entry.committer_tz);
new_entry.summary.clone_from(&existing_entry.summary);
}
current_entry.replace(new_entry);
}
Some(entry) => {
let Some((key, value)) = line.split_once(' ') else {
continue;
};
let is_committed = !entry.sha.is_zero();
match key {
"filename" => {
entry.filename = value.into();
done = true;
}
"previous" => entry.previous = Some(value.into()),
"summary" if is_committed => entry.summary = Some(value.into()),
"author" if is_committed => entry.author = Some(value.into()),
"author-mail" if is_committed => entry.author_mail = Some(value.into()),
"author-time" if is_committed => {
entry.author_time = Some(value.parse::<i64>()?)
}
"author-tz" if is_committed => entry.author_tz = Some(value.into()),
"committer" if is_committed => entry.committer_name = Some(value.into()),
"committer-mail" if is_committed => entry.committer_email = Some(value.into()),
"committer-time" if is_committed => {
entry.committer_time = Some(value.parse::<i64>()?)
}
"committer-tz" if is_committed => entry.committer_tz = Some(value.into()),
_ => {}
}
}
};
if done {
if let Some(entry) = current_entry.take() {
index.insert(entry.sha, entries.len());
// We only want annotations that have a commit.
if !entry.sha.is_zero() {
entries.push(entry);
}
}
}
}
Ok(entries)
}
#[cfg(test)]
mod tests {
use std::path::PathBuf;
use super::BlameEntry;
use super::parse_git_blame;
fn read_test_data(filename: &str) -> String {
let mut path = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
path.push("test_data");
path.push(filename);
std::fs::read_to_string(&path)
.unwrap_or_else(|_| panic!("Could not read test data at {:?}. Is it generated?", path))
}
fn assert_eq_golden(entries: &Vec<BlameEntry>, golden_filename: &str) {
let mut path = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
path.push("test_data");
path.push("golden");
path.push(format!("{}.json", golden_filename));
let mut have_json =
serde_json::to_string_pretty(&entries).expect("could not serialize entries to JSON");
// We always want to save with a trailing newline.
have_json.push('\n');
let update = std::env::var("UPDATE_GOLDEN")
.map(|val| val.eq_ignore_ascii_case("true"))
.unwrap_or(false);
if update {
std::fs::create_dir_all(path.parent().unwrap())
.expect("could not create golden test data directory");
std::fs::write(&path, have_json).expect("could not write out golden data");
} else {
let want_json =
std::fs::read_to_string(&path).unwrap_or_else(|_| {
panic!("could not read golden test data file at {:?}. Did you run the test with UPDATE_GOLDEN=true before?", path);
}).replace("\r\n", "\n");
pretty_assertions::assert_eq!(have_json, want_json, "wrong blame entries");
}
}
#[test]
fn test_parse_git_blame_not_committed() {
let output = read_test_data("blame_incremental_not_committed");
let entries = parse_git_blame(&output).unwrap();
assert_eq_golden(&entries, "blame_incremental_not_committed");
}
#[test]
fn test_parse_git_blame_simple() {
let output = read_test_data("blame_incremental_simple");
let entries = parse_git_blame(&output).unwrap();
assert_eq_golden(&entries, "blame_incremental_simple");
}
#[test]
fn test_parse_git_blame_complex() {
let output = read_test_data("blame_incremental_complex");
let entries = parse_git_blame(&output).unwrap();
assert_eq_golden(&entries, "blame_incremental_complex");
}
}

View File

@@ -0,0 +1,339 @@
// font-kit/src/canvas.rs
//
// Copyright © 2018 The Pathfinder Project Developers.
//
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
// option. This file may not be copied, modified, or distributed
// except according to those terms.
//! An in-memory bitmap surface for glyph rasterization.
use lazy_static::lazy_static;
use pathfinder_geometry::rect::RectI;
use pathfinder_geometry::vector::Vector2I;
use std::cmp;
use std::fmt;
use crate::utils;
lazy_static! {
static ref BITMAP_1BPP_TO_8BPP_LUT: [[u8; 8]; 256] = {
let mut lut = [[0; 8]; 256];
for byte in 0..0x100 {
let mut value = [0; 8];
for bit in 0..8 {
if (byte & (0x80 >> bit)) != 0 {
value[bit] = 0xff;
}
}
lut[byte] = value
}
lut
};
}
/// An in-memory bitmap surface for glyph rasterization.
pub struct Canvas {
/// The raw pixel data.
pub pixels: Vec<u8>,
/// The size of the buffer, in pixels.
pub size: Vector2I,
/// The number of *bytes* between successive rows.
pub stride: usize,
/// The image format of the canvas.
pub format: Format,
}
impl Canvas {
/// Creates a new blank canvas with the given pixel size and format.
///
/// Stride is automatically calculated from width.
///
/// The canvas is initialized with transparent black (all values 0).
#[inline]
pub fn new(size: Vector2I, format: Format) -> Canvas {
Canvas::with_stride(
size,
size.x() as usize * format.bytes_per_pixel() as usize,
format,
)
}
/// Creates a new blank canvas with the given pixel size, stride (number of bytes between
/// successive rows), and format.
///
/// The canvas is initialized with transparent black (all values 0).
pub fn with_stride(size: Vector2I, stride: usize, format: Format) -> Canvas {
Canvas {
pixels: vec![0; stride * size.y() as usize],
size,
stride,
format,
}
}
#[allow(dead_code)]
pub(crate) fn blit_from_canvas(&mut self, src: &Canvas) {
self.blit_from(
Vector2I::default(),
&src.pixels,
src.size,
src.stride,
src.format,
)
}
/// Blits to a rectangle with origin at `dst_point` and size according to `src_size`.
/// If the target area overlaps the boundaries of the canvas, only the drawable region is blitted.
/// `dst_point` and `src_size` are specified in pixels. `src_stride` is specified in bytes.
/// `src_stride` must be equal or larger than the actual data length.
#[allow(dead_code)]
pub(crate) fn blit_from(
&mut self,
dst_point: Vector2I,
src_bytes: &[u8],
src_size: Vector2I,
src_stride: usize,
src_format: Format,
) {
assert_eq!(
src_stride * src_size.y() as usize,
src_bytes.len(),
"Number of pixels in src_bytes does not match stride and size."
);
assert!(
src_stride >= src_size.x() as usize * src_format.bytes_per_pixel() as usize,
"src_stride must be >= than src_size.x()"
);
let dst_rect = RectI::new(dst_point, src_size);
let dst_rect = dst_rect.intersection(RectI::new(Vector2I::default(), self.size));
let dst_rect = match dst_rect {
Some(dst_rect) => dst_rect,
None => return,
};
match (self.format, src_format) {
(Format::A8, Format::A8)
| (Format::Rgb24, Format::Rgb24)
| (Format::Rgba32, Format::Rgba32) => {
self.blit_from_with::<BlitMemcpy>(dst_rect, src_bytes, src_stride, src_format)
}
(Format::A8, Format::Rgb24) => {
self.blit_from_with::<BlitRgb24ToA8>(dst_rect, src_bytes, src_stride, src_format)
}
(Format::Rgb24, Format::A8) => {
self.blit_from_with::<BlitA8ToRgb24>(dst_rect, src_bytes, src_stride, src_format)
}
(Format::Rgb24, Format::Rgba32) => self
.blit_from_with::<BlitRgba32ToRgb24>(dst_rect, src_bytes, src_stride, src_format),
(Format::Rgba32, Format::Rgb24) => self
.blit_from_with::<BlitRgb24ToRgba32>(dst_rect, src_bytes, src_stride, src_format),
(Format::Rgba32, Format::A8) | (Format::A8, Format::Rgba32) => unimplemented!(),
}
}
#[allow(dead_code)]
pub(crate) fn blit_from_bitmap_1bpp(
&mut self,
dst_point: Vector2I,
src_bytes: &[u8],
src_size: Vector2I,
src_stride: usize,
) {
if self.format != Format::A8 {
unimplemented!()
}
let dst_rect = RectI::new(dst_point, src_size);
let dst_rect = dst_rect.intersection(RectI::new(Vector2I::default(), self.size));
let dst_rect = match dst_rect {
Some(dst_rect) => dst_rect,
None => return,
};
let size = dst_rect.size();
let dest_bytes_per_pixel = self.format.bytes_per_pixel() as usize;
let dest_row_stride = size.x() as usize * dest_bytes_per_pixel;
let src_row_stride = utils::div_round_up(size.x() as usize, 8);
for y in 0..size.y() {
let (dest_row_start, src_row_start) = (
(y + dst_rect.origin_y()) as usize * self.stride
+ dst_rect.origin_x() as usize * dest_bytes_per_pixel,
y as usize * src_stride,
);
let dest_row_end = dest_row_start + dest_row_stride;
let src_row_end = src_row_start + src_row_stride;
let dest_row_pixels = &mut self.pixels[dest_row_start..dest_row_end];
let src_row_pixels = &src_bytes[src_row_start..src_row_end];
for x in 0..src_row_stride {
let pattern = &BITMAP_1BPP_TO_8BPP_LUT[src_row_pixels[x] as usize];
let dest_start = x * 8;
let dest_end = cmp::min(dest_start + 8, dest_row_stride);
let src = &pattern[0..(dest_end - dest_start)];
dest_row_pixels[dest_start..dest_end].clone_from_slice(src);
}
}
}
/// Blits to area `rect` using the data given in the buffer `src_bytes`.
/// `src_stride` must be specified in bytes.
/// The dimensions of `rect` must be in pixels.
fn blit_from_with<B: Blit>(
&mut self,
rect: RectI,
src_bytes: &[u8],
src_stride: usize,
src_format: Format,
) {
let src_bytes_per_pixel = src_format.bytes_per_pixel() as usize;
let dest_bytes_per_pixel = self.format.bytes_per_pixel() as usize;
for y in 0..rect.height() {
let (dest_row_start, src_row_start) = (
(y + rect.origin_y()) as usize * self.stride
+ rect.origin_x() as usize * dest_bytes_per_pixel,
y as usize * src_stride,
);
let dest_row_end = dest_row_start + rect.width() as usize * dest_bytes_per_pixel;
let src_row_end = src_row_start + rect.width() as usize * src_bytes_per_pixel;
let dest_row_pixels = &mut self.pixels[dest_row_start..dest_row_end];
let src_row_pixels = &src_bytes[src_row_start..src_row_end];
B::blit(dest_row_pixels, src_row_pixels)
}
}
}
impl fmt::Debug for Canvas {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
f.debug_struct("Canvas")
.field("pixels", &self.pixels.len()) // Do not dump a vector content.
.field("size", &self.size)
.field("stride", &self.stride)
.field("format", &self.format)
.finish()
}
}
/// The image format for the canvas.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum Format {
/// Premultiplied R8G8B8A8, little-endian.
Rgba32,
/// R8G8B8, little-endian.
Rgb24,
/// A8.
A8,
}
impl Format {
/// Returns the number of bits per pixel that this image format corresponds to.
#[inline]
pub fn bits_per_pixel(self) -> u8 {
match self {
Format::Rgba32 => 32,
Format::Rgb24 => 24,
Format::A8 => 8,
}
}
/// Returns the number of color channels per pixel that this image format corresponds to.
#[inline]
pub fn components_per_pixel(self) -> u8 {
match self {
Format::Rgba32 => 4,
Format::Rgb24 => 3,
Format::A8 => 1,
}
}
/// Returns the number of bits per color channel that this image format contains.
#[inline]
pub fn bits_per_component(self) -> u8 {
self.bits_per_pixel() / self.components_per_pixel()
}
/// Returns the number of bytes per pixel that this image format corresponds to.
#[inline]
pub fn bytes_per_pixel(self) -> u8 {
self.bits_per_pixel() / 8
}
}
/// The antialiasing strategy that should be used when rasterizing glyphs.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum RasterizationOptions {
/// "Black-and-white" rendering. Each pixel is either entirely on or off.
Bilevel,
/// Grayscale antialiasing. Only one channel is used.
GrayscaleAa,
/// Subpixel RGB antialiasing, for LCD screens.
SubpixelAa,
}
trait Blit {
fn blit(dest: &mut [u8], src: &[u8]);
}
struct BlitMemcpy;
impl Blit for BlitMemcpy {
#[inline]
fn blit(dest: &mut [u8], src: &[u8]) {
dest.clone_from_slice(src)
}
}
struct BlitRgb24ToA8;
impl Blit for BlitRgb24ToA8 {
#[inline]
fn blit(dest: &mut [u8], src: &[u8]) {
// TODO(pcwalton): SIMD.
for (dest, src) in dest.iter_mut().zip(src.chunks(3)) {
*dest = src[1]
}
}
}
struct BlitA8ToRgb24;
impl Blit for BlitA8ToRgb24 {
#[inline]
fn blit(dest: &mut [u8], src: &[u8]) {
for (dest, src) in dest.chunks_mut(3).zip(src.iter()) {
dest[0] = *src;
dest[1] = *src;
dest[2] = *src;
}
}
}
struct BlitRgba32ToRgb24;
impl Blit for BlitRgba32ToRgb24 {
#[inline]
fn blit(dest: &mut [u8], src: &[u8]) {
// TODO(pcwalton): SIMD.
for (dest, src) in dest.chunks_mut(3).zip(src.chunks(4)) {
dest.copy_from_slice(&src[0..3])
}
}
}
struct BlitRgb24ToRgba32;
impl Blit for BlitRgb24ToRgba32 {
fn blit(dest: &mut [u8], src: &[u8]) {
for (dest, src) in dest.chunks_mut(4).zip(src.chunks(3)) {
dest[0] = src[0];
dest[1] = src[1];
dest[2] = src[2];
dest[3] = 255;
}
}
}

File diff suppressed because one or more lines are too long

View File

@@ -0,0 +1,14 @@
class InputCell:
def __init__(self, initial_value):
self.value = None
class ComputeCell:
def __init__(self, inputs, compute_function):
self.value = None
def add_callback(self, callback):
pass
def remove_callback(self, callback):
pass

View File

@@ -0,0 +1,271 @@
# These tests are auto-generated with test data from:
# https://github.com/exercism/problem-specifications/tree/main/exercises/react/canonical-data.json
# File last updated on 2023-07-19
from functools import partial
import unittest
from react import (
InputCell,
ComputeCell,
)
class ReactTest(unittest.TestCase):
def test_input_cells_have_a_value(self):
input = InputCell(10)
self.assertEqual(input.value, 10)
def test_an_input_cell_s_value_can_be_set(self):
input = InputCell(4)
input.value = 20
self.assertEqual(input.value, 20)
def test_compute_cells_calculate_initial_value(self):
input = InputCell(1)
output = ComputeCell(
[
input,
],
lambda inputs: inputs[0] + 1,
)
self.assertEqual(output.value, 2)
def test_compute_cells_take_inputs_in_the_right_order(self):
one = InputCell(1)
two = InputCell(2)
output = ComputeCell(
[
one,
two,
],
lambda inputs: inputs[0] + inputs[1] * 10,
)
self.assertEqual(output.value, 21)
def test_compute_cells_update_value_when_dependencies_are_changed(self):
input = InputCell(1)
output = ComputeCell(
[
input,
],
lambda inputs: inputs[0] + 1,
)
input.value = 3
self.assertEqual(output.value, 4)
def test_compute_cells_can_depend_on_other_compute_cells(self):
input = InputCell(1)
times_two = ComputeCell(
[
input,
],
lambda inputs: inputs[0] * 2,
)
times_thirty = ComputeCell(
[
input,
],
lambda inputs: inputs[0] * 30,
)
output = ComputeCell(
[
times_two,
times_thirty,
],
lambda inputs: inputs[0] + inputs[1],
)
self.assertEqual(output.value, 32)
input.value = 3
self.assertEqual(output.value, 96)
def test_compute_cells_fire_callbacks(self):
input = InputCell(1)
output = ComputeCell(
[
input,
],
lambda inputs: inputs[0] + 1,
)
cb1_observer = []
callback1 = self.callback_factory(cb1_observer)
output.add_callback(callback1)
input.value = 3
self.assertEqual(cb1_observer[-1], 4)
def test_callback_cells_only_fire_on_change(self):
input = InputCell(1)
output = ComputeCell([input], lambda inputs: 111 if inputs[0] < 3 else 222)
cb1_observer = []
callback1 = self.callback_factory(cb1_observer)
output.add_callback(callback1)
input.value = 2
self.assertEqual(cb1_observer, [])
input.value = 4
self.assertEqual(cb1_observer[-1], 222)
def test_callbacks_do_not_report_already_reported_values(self):
input = InputCell(1)
output = ComputeCell(
[
input,
],
lambda inputs: inputs[0] + 1,
)
cb1_observer = []
callback1 = self.callback_factory(cb1_observer)
output.add_callback(callback1)
input.value = 2
self.assertEqual(cb1_observer[-1], 3)
input.value = 3
self.assertEqual(cb1_observer[-1], 4)
def test_callbacks_can_fire_from_multiple_cells(self):
input = InputCell(1)
plus_one = ComputeCell(
[
input,
],
lambda inputs: inputs[0] + 1,
)
minus_one = ComputeCell(
[
input,
],
lambda inputs: inputs[0] - 1,
)
cb1_observer = []
cb2_observer = []
callback1 = self.callback_factory(cb1_observer)
callback2 = self.callback_factory(cb2_observer)
plus_one.add_callback(callback1)
minus_one.add_callback(callback2)
input.value = 10
self.assertEqual(cb1_observer[-1], 11)
self.assertEqual(cb2_observer[-1], 9)
def test_callbacks_can_be_added_and_removed(self):
input = InputCell(11)
output = ComputeCell(
[
input,
],
lambda inputs: inputs[0] + 1,
)
cb1_observer = []
cb2_observer = []
cb3_observer = []
callback1 = self.callback_factory(cb1_observer)
callback2 = self.callback_factory(cb2_observer)
callback3 = self.callback_factory(cb3_observer)
output.add_callback(callback1)
output.add_callback(callback2)
input.value = 31
self.assertEqual(cb1_observer[-1], 32)
self.assertEqual(cb2_observer[-1], 32)
output.remove_callback(callback1)
output.add_callback(callback3)
input.value = 41
self.assertEqual(len(cb1_observer), 1)
self.assertEqual(cb2_observer[-1], 42)
self.assertEqual(cb3_observer[-1], 42)
def test_removing_a_callback_multiple_times_doesn_t_interfere_with_other_callbacks(
self,
):
input = InputCell(1)
output = ComputeCell(
[
input,
],
lambda inputs: inputs[0] + 1,
)
cb1_observer = []
cb2_observer = []
callback1 = self.callback_factory(cb1_observer)
callback2 = self.callback_factory(cb2_observer)
output.add_callback(callback1)
output.add_callback(callback2)
output.remove_callback(callback1)
output.remove_callback(callback1)
output.remove_callback(callback1)
input.value = 2
self.assertEqual(cb1_observer, [])
self.assertEqual(cb2_observer[-1], 3)
def test_callbacks_should_only_be_called_once_even_if_multiple_dependencies_change(
self,
):
input = InputCell(1)
plus_one = ComputeCell(
[
input,
],
lambda inputs: inputs[0] + 1,
)
minus_one1 = ComputeCell(
[
input,
],
lambda inputs: inputs[0] - 1,
)
minus_one2 = ComputeCell(
[
minus_one1,
],
lambda inputs: inputs[0] - 1,
)
output = ComputeCell(
[
plus_one,
minus_one2,
],
lambda inputs: inputs[0] * inputs[1],
)
cb1_observer = []
callback1 = self.callback_factory(cb1_observer)
output.add_callback(callback1)
input.value = 4
self.assertEqual(cb1_observer[-1], 10)
def test_callbacks_should_not_be_called_if_dependencies_change_but_output_value_doesn_t_change(
self,
):
input = InputCell(1)
plus_one = ComputeCell(
[
input,
],
lambda inputs: inputs[0] + 1,
)
minus_one = ComputeCell(
[
input,
],
lambda inputs: inputs[0] - 1,
)
always_two = ComputeCell(
[
plus_one,
minus_one,
],
lambda inputs: inputs[0] - inputs[1],
)
cb1_observer = []
callback1 = self.callback_factory(cb1_observer)
always_two.add_callback(callback1)
input.value = 2
self.assertEqual(cb1_observer, [])
input.value = 3
self.assertEqual(cb1_observer, [])
input.value = 4
self.assertEqual(cb1_observer, [])
input.value = 5
self.assertEqual(cb1_observer, [])
# Utility functions.
def callback_factory(self, observer):
def callback(observer, value):
observer.append(value)
return partial(callback, observer)

View File

@@ -282,7 +282,7 @@ pub struct EditFileToolCard {
}
impl EditFileToolCard {
fn new(path: PathBuf, project: Entity<Project>, window: &mut Window, cx: &mut App) -> Self {
pub fn new(path: PathBuf, project: Entity<Project>, window: &mut Window, cx: &mut App) -> Self {
let multibuffer = cx.new(|_| MultiBuffer::without_headers(Capability::ReadOnly));
let editor = cx.new(|cx| {
let mut editor = Editor::new(
@@ -323,7 +323,7 @@ impl EditFileToolCard {
}
}
fn set_diff(
pub fn set_diff(
&mut self,
path: Arc<Path>,
old_text: String,
@@ -343,6 +343,7 @@ impl EditFileToolCard {
.hunks_intersecting_range(Anchor::MIN..Anchor::MAX, &snapshot, cx)
.map(|diff_hunk| diff_hunk.buffer_range.to_point(&snapshot))
.collect::<Vec<_>>();
multibuffer.clear(cx);
let (_, is_newly_added) = multibuffer.set_excerpts_for_path(
PathKey::for_buffer(&buffer, cx),
buffer,
@@ -576,13 +577,10 @@ impl ToolCard for EditFileToolCard {
card.child(
v_flex()
.relative()
.map(|editor_container| {
if self.full_height_expanded {
editor_container.h_full()
} else {
editor_container
.h(DEFAULT_COLLAPSED_LINES as f32 * editor_line_height)
}
.h_full()
.when(!self.full_height_expanded, |editor_container| {
editor_container
.max_h(DEFAULT_COLLAPSED_LINES as f32 * editor_line_height)
})
.overflow_hidden()
.border_t_1()

View File

@@ -6,7 +6,7 @@ use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat}
use project::Project;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use std::{path::PathBuf, sync::Arc};
use ui::IconName;
use util::markdown::MarkdownEscaped;
@@ -50,7 +50,7 @@ impl Tool for OpenTool {
self: Arc<Self>,
input: serde_json::Value,
_messages: &[LanguageModelRequestMessage],
_project: Entity<Project>,
project: Entity<Project>,
_action_log: Entity<ActionLog>,
_window: Option<AnyWindowHandle>,
cx: &mut App,
@@ -60,11 +60,107 @@ impl Tool for OpenTool {
Err(err) => return Task::ready(Err(anyhow!(err))).into(),
};
// If path_or_url turns out to be a path in the project, make it absolute.
let abs_path = to_absolute_path(&input.path_or_url, project, cx);
cx.background_spawn(async move {
open::that(&input.path_or_url).context("Failed to open URL or file path")?;
match abs_path {
Some(path) => open::that(path),
None => open::that(&input.path_or_url),
}
.context("Failed to open URL or file path")?;
Ok(format!("Successfully opened {}", input.path_or_url))
})
.into()
}
}
fn to_absolute_path(
potential_path: &str,
project: Entity<Project>,
cx: &mut App,
) -> Option<PathBuf> {
let project = project.read(cx);
project
.find_project_path(PathBuf::from(potential_path), cx)
.and_then(|project_path| project.absolute_path(&project_path, cx))
}
#[cfg(test)]
mod tests {
use super::*;
use gpui::TestAppContext;
use project::{FakeFs, Project};
use settings::SettingsStore;
use std::path::Path;
use tempfile::TempDir;
#[gpui::test]
async fn test_to_absolute_path(cx: &mut TestAppContext) {
init_test(cx);
let temp_dir = TempDir::new().expect("Failed to create temp directory");
let temp_path = temp_dir.path().to_string_lossy().to_string();
let fs = FakeFs::new(cx.executor());
fs.insert_tree(
&temp_path,
serde_json::json!({
"src": {
"main.rs": "fn main() {}",
"lib.rs": "pub fn lib_fn() {}"
},
"docs": {
"readme.md": "# Project Documentation"
}
}),
)
.await;
// Use the temp_path as the root directory, not just its filename
let project = Project::test(fs.clone(), [temp_dir.path()], cx).await;
// Test cases where the function should return Some
cx.update(|cx| {
// Project-relative paths should return Some
// Create paths using the last segment of the temp path to simulate a project-relative path
let root_dir_name = Path::new(&temp_path)
.file_name()
.unwrap_or_else(|| std::ffi::OsStr::new("temp"))
.to_string_lossy();
assert!(
to_absolute_path(&format!("{root_dir_name}/src/main.rs"), project.clone(), cx)
.is_some(),
"Failed to resolve main.rs path"
);
assert!(
to_absolute_path(
&format!("{root_dir_name}/docs/readme.md",),
project.clone(),
cx,
)
.is_some(),
"Failed to resolve readme.md path"
);
// External URL should return None
let result = to_absolute_path("https://example.com", project.clone(), cx);
assert_eq!(result, None, "External URLs should return None");
// Path outside project
let result = to_absolute_path("../invalid/path", project.clone(), cx);
assert_eq!(result, None, "Paths outside the project should return None");
});
}
fn init_test(cx: &mut TestAppContext) {
cx.update(|cx| {
let settings_store = SettingsStore::test(cx);
cx.set_global(settings_store);
language::init(cx);
Project::init_settings(cx);
});
}
}

View File

@@ -1,5 +1,6 @@
use crate::{code_symbols_tool::file_outline, schema::json_schema_for};
use crate::schema::json_schema_for;
use anyhow::{Result, anyhow};
use assistant_tool::outline;
use assistant_tool::{ActionLog, Tool, ToolResult};
use gpui::{AnyWindowHandle, App, Entity, Task};
@@ -14,10 +15,6 @@ use ui::IconName;
use util::markdown::MarkdownInlineCode;
/// If the model requests to read a file whose size exceeds this, then
/// the tool will return an error along with the model's symbol outline,
/// and suggest trying again using line ranges from the outline.
const MAX_FILE_SIZE_TO_READ: usize = 16384;
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct ReadFileToolInput {
/// The relative path of the file to read.
@@ -125,7 +122,8 @@ impl Tool for ReadFileTool {
if input.start_line.is_some() || input.end_line.is_some() {
let result = buffer.read_with(cx, |buffer, _cx| {
let text = buffer.text();
let start = input.start_line.unwrap_or(1);
// .max(1) because despite instructions to be 1-indexed, sometimes the model passes 0.
let start = input.start_line.unwrap_or(1).max(1);
let lines = text.split('\n').skip(start - 1);
if let Some(end) = input.end_line {
let count = end.saturating_sub(start).saturating_add(1); // Ensure at least 1 line
@@ -144,7 +142,7 @@ impl Tool for ReadFileTool {
// No line ranges specified, so check file size to see if it's too big.
let file_size = buffer.read_with(cx, |buffer, _cx| buffer.text().len())?;
if file_size <= MAX_FILE_SIZE_TO_READ {
if file_size <= outline::AUTO_OUTLINE_SIZE {
// File is small enough, so return its contents.
let result = buffer.read_with(cx, |buffer, _cx| buffer.text())?;
@@ -154,9 +152,9 @@ impl Tool for ReadFileTool {
Ok(result)
} else {
// File is too big, so return an error with the outline
// File is too big, so return the outline
// and a suggestion to read again with line numbers.
let outline = file_outline(project, file_path, action_log, None, cx).await?;
let outline = outline::file_outline(project, file_path, action_log, None, cx).await?;
Ok(formatdoc! {"
This file was too big to read all at once. Here is an outline of its symbols:
@@ -332,6 +330,67 @@ mod test {
assert_eq!(result.unwrap(), "Line 2\nLine 3\nLine 4");
}
#[gpui::test]
async fn test_read_file_line_range_edge_cases(cx: &mut TestAppContext) {
init_test(cx);
let fs = FakeFs::new(cx.executor());
fs.insert_tree(
"/root",
json!({
"multiline.txt": "Line 1\nLine 2\nLine 3\nLine 4\nLine 5"
}),
)
.await;
let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
let action_log = cx.new(|_| ActionLog::new(project.clone()));
// start_line of 0 should be treated as 1
let result = cx
.update(|cx| {
let input = json!({
"path": "root/multiline.txt",
"start_line": 0,
"end_line": 2
});
Arc::new(ReadFileTool)
.run(input, &[], project.clone(), action_log.clone(), None, cx)
.output
})
.await;
assert_eq!(result.unwrap(), "Line 1\nLine 2");
// end_line of 0 should result in at least 1 line
let result = cx
.update(|cx| {
let input = json!({
"path": "root/multiline.txt",
"start_line": 1,
"end_line": 0
});
Arc::new(ReadFileTool)
.run(input, &[], project.clone(), action_log.clone(), None, cx)
.output
})
.await;
assert_eq!(result.unwrap(), "Line 1");
// when start_line > end_line, should still return at least 1 line
let result = cx
.update(|cx| {
let input = json!({
"path": "root/multiline.txt",
"start_line": 3,
"end_line": 2
});
Arc::new(ReadFileTool)
.run(input, &[], project.clone(), action_log, None, cx)
.output
})
.await;
assert_eq!(result.unwrap(), "Line 3");
}
fn init_test(cx: &mut TestAppContext) {
cx.update(|cx| {
let settings_store = SettingsStore::test(cx);

View File

@@ -0,0 +1,352 @@
use crate::{
Templates,
edit_agent::{EditAgent, EditAgentOutputEvent},
edit_file_tool::EditFileToolCard,
schema::json_schema_for,
};
use anyhow::{Context as _, Result, anyhow};
use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolResult};
use futures::StreamExt;
use gpui::{AnyWindowHandle, App, AppContext, AsyncApp, Entity, Task};
use indoc::formatdoc;
use language_model::{
LanguageModelRegistry, LanguageModelRequestMessage, LanguageModelToolSchemaFormat,
};
use project::Project;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use std::{path::PathBuf, sync::Arc};
use ui::prelude::*;
use util::ResultExt;
pub struct StreamingEditFileTool;
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct StreamingEditFileToolInput {
/// A one-line, user-friendly markdown description of the edit. This will be
/// shown in the UI and also passed to another model to perform the edit.
///
/// Be terse, but also descriptive in what you want to achieve with this
/// edit. Avoid generic instructions.
///
/// NEVER mention the file path in this description.
///
/// <example>Fix API endpoint URLs</example>
/// <example>Update copyright year in `page_footer`</example>
///
/// Make sure to include this field before all the others in the input object
/// so that we can display it immediately.
pub display_description: String,
/// The full path of the file to create or modify in the project.
///
/// WARNING: When specifying which file path need changing, you MUST
/// start each path with one of the project's root directories.
///
/// The following examples assume we have two root directories in the project:
/// - backend
/// - frontend
///
/// <example>
/// `backend/src/main.rs`
///
/// Notice how the file path starts with root-1. Without that, the path
/// would be ambiguous and the call would fail!
/// </example>
///
/// <example>
/// `frontend/db.js`
/// </example>
pub path: PathBuf,
/// If true, this tool will recreate the file from scratch.
/// If false, this tool will produce granular edits to an existing file.
pub create_or_overwrite: bool,
}
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
struct PartialInput {
#[serde(default)]
path: String,
#[serde(default)]
display_description: String,
}
const DEFAULT_UI_TEXT: &str = "Editing file";
impl Tool for StreamingEditFileTool {
fn name(&self) -> String {
"edit_file".into()
}
fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool {
false
}
fn description(&self) -> String {
include_str!("streaming_edit_file_tool/description.md").to_string()
}
fn icon(&self) -> IconName {
IconName::Pencil
}
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
json_schema_for::<StreamingEditFileToolInput>(format)
}
fn ui_text(&self, input: &serde_json::Value) -> String {
match serde_json::from_value::<StreamingEditFileToolInput>(input.clone()) {
Ok(input) => input.display_description,
Err(_) => "Editing file".to_string(),
}
}
fn still_streaming_ui_text(&self, input: &serde_json::Value) -> String {
if let Some(input) = serde_json::from_value::<PartialInput>(input.clone()).ok() {
let description = input.display_description.trim();
if !description.is_empty() {
return description.to_string();
}
let path = input.path.trim();
if !path.is_empty() {
return path.to_string();
}
}
DEFAULT_UI_TEXT.to_string()
}
fn run(
self: Arc<Self>,
input: serde_json::Value,
messages: &[LanguageModelRequestMessage],
project: Entity<Project>,
action_log: Entity<ActionLog>,
window: Option<AnyWindowHandle>,
cx: &mut App,
) -> ToolResult {
let input = match serde_json::from_value::<StreamingEditFileToolInput>(input) {
Ok(input) => input,
Err(err) => return Task::ready(Err(anyhow!(err))).into(),
};
let Some(project_path) = project.read(cx).find_project_path(&input.path, cx) else {
return Task::ready(Err(anyhow!(
"Path {} not found in project",
input.path.display()
)))
.into();
};
let Some(worktree) = project
.read(cx)
.worktree_for_id(project_path.worktree_id, cx)
else {
return Task::ready(Err(anyhow!("Worktree not found for project path"))).into();
};
let exists = worktree.update(cx, |worktree, cx| {
worktree.file_exists(&project_path.path, cx)
});
let card = window.and_then(|window| {
window
.update(cx, |_, window, cx| {
cx.new(|cx| {
EditFileToolCard::new(input.path.clone(), project.clone(), window, cx)
})
})
.ok()
});
let card_clone = card.clone();
let messages = messages.to_vec();
let task = cx.spawn(async move |cx: &mut AsyncApp| {
if !input.create_or_overwrite && !exists.await? {
return Err(anyhow!("{} not found", input.path.display()));
}
let model = cx
.update(|cx| LanguageModelRegistry::read_global(cx).default_model())?
.context("default model not set")?
.model;
let edit_agent = EditAgent::new(model, action_log, Templates::new());
let buffer = project
.update(cx, |project, cx| {
project.open_buffer(project_path.clone(), cx)
})?
.await?;
let old_snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
let old_text = cx
.background_spawn({
let old_snapshot = old_snapshot.clone();
async move { old_snapshot.text() }
})
.await;
let (output, mut events) = if input.create_or_overwrite {
edit_agent.overwrite(
buffer.clone(),
input.display_description.clone(),
messages,
cx,
)
} else {
edit_agent.edit(
buffer.clone(),
input.display_description.clone(),
messages,
cx,
)
};
let mut hallucinated_old_text = false;
while let Some(event) = events.next().await {
match event {
EditAgentOutputEvent::Edited { position } => {
if let Some(card) = card_clone.as_ref() {
let new_snapshot =
buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
let new_text = cx
.background_spawn({
let new_snapshot = new_snapshot.clone();
async move { new_snapshot.text() }
})
.await;
card.update(cx, |card, cx| {
card.set_diff(
project_path.path.clone(),
old_text.clone(),
new_text,
cx,
);
})
.log_err();
}
}
EditAgentOutputEvent::OldTextNotFound(_) => hallucinated_old_text = true,
}
}
output.await?;
project
.update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))?
.await?;
let new_snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
let new_text = cx.background_spawn({
let new_snapshot = new_snapshot.clone();
async move { new_snapshot.text() }
});
let diff = cx.background_spawn(async move {
language::unified_diff(&old_snapshot.text(), &new_snapshot.text())
});
let (new_text, diff) = futures::join!(new_text, diff);
if let Some(card) = card_clone {
card.update(cx, |card, cx| {
card.set_diff(project_path.path.clone(), old_text, new_text, cx);
})
.log_err();
}
let input_path = input.path.display();
if diff.is_empty() {
if hallucinated_old_text {
Err(anyhow!(formatdoc! {"
Some edits were produced but none of them could be applied.
Read the relevant sections of {input_path} again so that
I can perform the requested edits.
"}))
} else {
Ok("No edits were made.".to_string())
}
} else {
Ok(format!("Edited {}:\n\n```diff\n{}\n```", input_path, diff))
}
});
ToolResult {
output: task,
card: card.map(AnyToolCard::from),
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use serde_json::json;
#[test]
fn still_streaming_ui_text_with_path() {
let input = json!({
"path": "src/main.rs",
"display_description": "",
"old_string": "old code",
"new_string": "new code"
});
assert_eq!(
StreamingEditFileTool.still_streaming_ui_text(&input),
"src/main.rs"
);
}
#[test]
fn still_streaming_ui_text_with_description() {
let input = json!({
"path": "",
"display_description": "Fix error handling",
"old_string": "old code",
"new_string": "new code"
});
assert_eq!(
StreamingEditFileTool.still_streaming_ui_text(&input),
"Fix error handling",
);
}
#[test]
fn still_streaming_ui_text_with_path_and_description() {
let input = json!({
"path": "src/main.rs",
"display_description": "Fix error handling",
"old_string": "old code",
"new_string": "new code"
});
assert_eq!(
StreamingEditFileTool.still_streaming_ui_text(&input),
"Fix error handling",
);
}
#[test]
fn still_streaming_ui_text_no_path_or_description() {
let input = json!({
"path": "",
"display_description": "",
"old_string": "old code",
"new_string": "new code"
});
assert_eq!(
StreamingEditFileTool.still_streaming_ui_text(&input),
DEFAULT_UI_TEXT,
);
}
#[test]
fn still_streaming_ui_text_with_null() {
let input = serde_json::Value::Null;
assert_eq!(
StreamingEditFileTool.still_streaming_ui_text(&input),
DEFAULT_UI_TEXT,
);
}
}

View File

@@ -0,0 +1,8 @@
This is a tool for creating a new file or editing an existing file. For moving or renaming files, you should generally use the `terminal` tool with the 'mv' command instead.
Before using this tool:
1. Use the `read_file` tool to understand the file's contents and context
2. Verify the directory path is correct (only applicable when creating new files):
- Use the `list_directory` tool to verify the parent directory exists and is the correct location

View File

@@ -0,0 +1,32 @@
use anyhow::Result;
use handlebars::Handlebars;
use rust_embed::RustEmbed;
use serde::Serialize;
use std::sync::Arc;
#[derive(RustEmbed)]
#[folder = "src/templates"]
#[include = "*.hbs"]
struct Assets;
pub struct Templates(Handlebars<'static>);
impl Templates {
pub fn new() -> Arc<Self> {
let mut handlebars = Handlebars::new();
handlebars.register_embed_templates::<Assets>().unwrap();
handlebars.register_escape_fn(|text| text.into());
Arc::new(Self(handlebars))
}
}
pub trait Template: Sized {
const TEMPLATE_NAME: &'static str;
fn render(&self, templates: &Templates) -> Result<String>
where
Self: Serialize + Sized,
{
Ok(templates.0.render(Self::TEMPLATE_NAME, self)?)
}
}

View File

@@ -0,0 +1,12 @@
You are an expert engineer and your task is to write a new file from scratch.
<file_to_edit>
{{path}}
</file_to_edit>
<edit_description>
{{edit_description}}
</edit_description>
You MUST respond directly with the file's content, without explanations, additional text or triple backticks.
The text you output will be saved verbatim as the content of the file.

View File

@@ -0,0 +1,23 @@
You are an expert coder, and have been tasked with looking at the following diff:
<diff>
{{diff}}
</diff>
Evaluate the following assertions:
<assertions>
{{assertions}}
</assertions>
You must respond with a short analysis and a score between 0 and 100, where:
- 0 means no assertions pass
- 100 means all the assertions pass perfectly
<analysis>
- Assertion 1: one line describing why the first assertion passes or fails (even partially)
- Assertion 2: one line describing why the second assertion passes or fails (even partially)
- ...
- Assertion N: one line describing why the Nth assertion passes or fails (even partially)
</analysis>
<score>YOUR FINAL SCORE HERE</score>

View File

@@ -0,0 +1,49 @@
You are an expert text editor and your task is to produce a series of edits to a file given a description of the changes you need to make.
You MUST respond with a series of edits to that one file in the following format:
```
<edits>
<old_text>
OLD TEXT 1 HERE
</old_text>
<new_text>
NEW TEXT 1 HERE
</new_text>
<old_text>
OLD TEXT 2 HERE
</old_text>
<new_text>
NEW TEXT 2 HERE
</new_text>
<old_text>
OLD TEXT 3 HERE
</old_text>
<new_text>
NEW TEXT 3 HERE
</new_text>
</edits>
```
Rules for editing:
- `old_text` represents lines in the input file that will be replaced with `new_text`. `old_text` MUST exactly match the existing file content, character for character, including indentation.
- Always include enough context around the lines you want to replace in `old_text` such that it's impossible to mistake them for other lines.
- If you want to replace many occurrences of the same text, repeat the same `old_text`/`new_text` pair multiple times and I will apply them sequentially, one occurrence at a time.
- When reporting multiple edits, each edit assumes the previous one has already been applied! Therefore, you must ensure `old_text` doesn't reference text that has already been modified by a previous edit.
- Don't explain the edits, just report them.
- Only edit the file specified in `<file_to_edit>` and NEVER include edits to other files!
- If you open an <old_text> tag, you MUST close it using </old_text>
- If you open an <new_text> tag, you MUST close it using </new_text>
<file_to_edit>
{{path}}
</file_to_edit>
<edit_description>
{{edit_description}}
</edit_description>

View File

@@ -27,7 +27,9 @@ use crate::db::billing_subscription::{
StripeCancellationReason, StripeSubscriptionStatus, SubscriptionKind,
};
use crate::llm::db::subscription_usage_meter::CompletionMode;
use crate::llm::{DEFAULT_MAX_MONTHLY_SPEND, FREE_TIER_MONTHLY_SPENDING_LIMIT};
use crate::llm::{
AGENT_EXTENDED_TRIAL_FEATURE_FLAG, DEFAULT_MAX_MONTHLY_SPEND, FREE_TIER_MONTHLY_SPENDING_LIMIT,
};
use crate::rpc::{ResultExt as _, Server};
use crate::{AppState, Cents, Error, Result};
use crate::{db::UserId, llm::db::LlmDatabase};
@@ -54,6 +56,10 @@ pub fn router() -> Router {
"/billing/subscriptions/manage",
post(manage_billing_subscription),
)
.route(
"/billing/subscriptions/migrate",
post(migrate_to_new_billing),
)
.route("/billing/monthly_spend", get(get_monthly_spend))
.route("/billing/usage", get(get_current_usage))
}
@@ -256,6 +262,7 @@ async fn list_billing_subscriptions(
enum ProductCode {
ZedPro,
ZedProTrial,
ZedFree,
}
#[derive(Debug, Deserialize)]
@@ -362,12 +369,7 @@ async fn create_billing_subscription(
let checkout_session_url = match body.product {
Some(ProductCode::ZedPro) => {
stripe_billing
.checkout_with_price(
app.config.zed_pro_price_id()?,
customer_id,
&user.github_login,
&success_url,
)
.checkout_with_zed_pro(customer_id, &user.github_login, &success_url)
.await?
}
Some(ProductCode::ZedProTrial) => {
@@ -384,7 +386,6 @@ async fn create_billing_subscription(
stripe_billing
.checkout_with_zed_pro_trial(
app.config.zed_pro_price_id()?,
customer_id,
&user.github_login,
feature_flags,
@@ -392,6 +393,11 @@ async fn create_billing_subscription(
)
.await?
}
Some(ProductCode::ZedFree) => {
stripe_billing
.checkout_with_zed_free(customer_id, &user.github_login, &success_url)
.await?
}
None => {
let default_model = llm_db.model(
zed_llm_client::LanguageModelProvider::Anthropic,
@@ -458,6 +464,14 @@ async fn manage_billing_subscription(
))?
};
let Some(stripe_billing) = app.stripe_billing.clone() else {
log::error!("failed to retrieve Stripe billing object");
Err(Error::http(
StatusCode::NOT_IMPLEMENTED,
"not supported".into(),
))?
};
let customer = app
.db
.get_billing_customer_by_user_id(user.id)
@@ -508,8 +522,8 @@ async fn manage_billing_subscription(
let flow = match body.intent {
ManageSubscriptionIntent::ManageSubscription => None,
ManageSubscriptionIntent::UpgradeToPro => {
let zed_pro_price_id = app.config.zed_pro_price_id()?;
let zed_free_price_id = app.config.zed_free_price_id()?;
let zed_pro_price_id = stripe_billing.zed_pro_price_id().await?;
let zed_free_price_id = stripe_billing.zed_free_price_id().await?;
let stripe_subscription =
Subscription::retrieve(&stripe_client, &subscription_id, &[]).await?;
@@ -602,6 +616,85 @@ async fn manage_billing_subscription(
}))
}
#[derive(Debug, Deserialize)]
struct MigrateToNewBillingBody {
github_user_id: i32,
}
#[derive(Debug, Serialize)]
struct MigrateToNewBillingResponse {
/// The ID of the subscription that was canceled.
canceled_subscription_id: String,
}
async fn migrate_to_new_billing(
Extension(app): Extension<Arc<AppState>>,
extract::Json(body): extract::Json<MigrateToNewBillingBody>,
) -> Result<Json<MigrateToNewBillingResponse>> {
let Some(stripe_client) = app.stripe_client.clone() else {
log::error!("failed to retrieve Stripe client");
Err(Error::http(
StatusCode::NOT_IMPLEMENTED,
"not supported".into(),
))?
};
let user = app
.db
.get_user_by_github_user_id(body.github_user_id)
.await?
.ok_or_else(|| anyhow!("user not found"))?;
let old_billing_subscriptions_by_user = app
.db
.get_active_billing_subscriptions(HashSet::from_iter([user.id]))
.await?;
let Some((_billing_customer, billing_subscription)) =
old_billing_subscriptions_by_user.get(&user.id)
else {
return Err(Error::http(
StatusCode::NOT_FOUND,
"No active billing subscriptions to migrate".into(),
));
};
let stripe_subscription_id = billing_subscription
.stripe_subscription_id
.parse::<stripe::SubscriptionId>()
.context("failed to parse Stripe subscription ID from database")?;
Subscription::cancel(
&stripe_client,
&stripe_subscription_id,
stripe::CancelSubscription {
invoice_now: Some(true),
..Default::default()
},
)
.await?;
let feature_flags = app.db.list_feature_flags().await?;
for feature_flag in ["new-billing", "assistant2"] {
let already_in_feature_flag = feature_flags.iter().any(|flag| flag.flag == feature_flag);
if already_in_feature_flag {
continue;
}
let feature_flag = feature_flags
.iter()
.find(|flag| flag.flag == feature_flag)
.context("failed to find feature flag: {feature_flag:?}")?;
app.db.add_user_flag(user.id, feature_flag.id).await?;
}
Ok(Json(MigrateToNewBillingResponse {
canceled_subscription_id: stripe_subscription_id.to_string(),
}))
}
/// The amount of time we wait in between each poll of Stripe events.
///
/// This value should strike a balance between:
@@ -856,9 +949,11 @@ async fn handle_customer_subscription_event(
log::info!("handling Stripe {} event: {}", event.type_, event.id);
let subscription_kind = maybe!({
let zed_pro_price_id = app.config.zed_pro_price_id().ok()?;
let zed_free_price_id = app.config.zed_free_price_id().ok()?;
let subscription_kind = maybe!(async {
let stripe_billing = app.stripe_billing.clone()?;
let zed_pro_price_id = stripe_billing.zed_pro_price_id().await.ok()?;
let zed_free_price_id = stripe_billing.zed_free_price_id().await.ok()?;
subscription.items.data.iter().find_map(|item| {
let price = item.price.as_ref()?;
@@ -875,7 +970,8 @@ async fn handle_customer_subscription_event(
None
}
})
});
})
.await;
let billing_customer =
find_or_create_billing_customer(app, stripe_client, subscription.customer)
@@ -1090,9 +1186,17 @@ struct UsageCounts {
pub remaining: Option<i32>,
}
#[derive(Debug, Serialize)]
struct ModelRequestUsage {
pub model: String,
pub mode: CompletionMode,
pub requests: i32,
}
#[derive(Debug, Serialize)]
struct GetCurrentUsageResponse {
pub model_requests: UsageCounts,
pub model_request_usage: Vec<ModelRequestUsage>,
pub edit_predictions: UsageCounts,
}
@@ -1119,6 +1223,7 @@ async fn get_current_usage(
limit: Some(0),
remaining: Some(0),
},
model_request_usage: Vec::new(),
edit_predictions: UsageCounts {
used: 0,
limit: Some(0),
@@ -1154,8 +1259,21 @@ async fn get_current_usage(
SubscriptionKind::ZedFree => zed_llm_client::Plan::Free,
};
let feature_flags = app.db.get_user_flags(user.id).await?;
let has_extended_trial = feature_flags
.iter()
.any(|flag| flag == AGENT_EXTENDED_TRIAL_FEATURE_FLAG);
let model_requests_limit = match plan.model_requests_limit() {
zed_llm_client::UsageLimit::Limited(limit) => Some(limit),
zed_llm_client::UsageLimit::Limited(limit) => {
let limit = if plan == zed_llm_client::Plan::ZedProTrial && has_extended_trial {
1_000
} else {
limit
};
Some(limit)
}
zed_llm_client::UsageLimit::Unlimited => None,
};
let edit_prediction_limit = match plan.edit_predictions_limit() {
@@ -1163,12 +1281,30 @@ async fn get_current_usage(
zed_llm_client::UsageLimit::Unlimited => None,
};
let subscription_usage_meters = llm_db
.get_current_subscription_usage_meters_for_user(user.id, Utc::now())
.await?;
let model_request_usage = subscription_usage_meters
.into_iter()
.filter_map(|(usage_meter, _usage)| {
let model = llm_db.model_by_id(usage_meter.model_id).ok()?;
Some(ModelRequestUsage {
model: model.name.clone(),
mode: usage_meter.mode,
requests: usage_meter.requests,
})
})
.collect::<Vec<_>>();
Ok(Json(GetCurrentUsageResponse {
model_requests: UsageCounts {
used: usage.model_requests,
limit: model_requests_limit,
remaining: model_requests_limit.map(|limit| (limit - usage.model_requests).max(0)),
},
model_request_usage,
edit_predictions: UsageCounts {
used: usage.edit_predictions,
limit: edit_prediction_limit,
@@ -1402,12 +1538,12 @@ async fn sync_model_request_usage_with_stripe(
let model = llm_db.model_by_id(usage_meter.model_id)?;
let (price_id, meter_event_name) = match model.name.as_str() {
"claude-3-5-sonnet" => (&claude_3_5_sonnet.id, "claude_3_5_sonnet/requests"),
let (price, meter_event_name) = match model.name.as_str() {
"claude-3-5-sonnet" => (&claude_3_5_sonnet, "claude_3_5_sonnet/requests"),
"claude-3-7-sonnet" => match usage_meter.mode {
CompletionMode::Normal => (&claude_3_7_sonnet.id, "claude_3_7_sonnet/requests"),
CompletionMode::Normal => (&claude_3_7_sonnet, "claude_3_7_sonnet/requests"),
CompletionMode::Max => {
(&claude_3_7_sonnet_max.id, "claude_3_7_sonnet/requests/max")
(&claude_3_7_sonnet_max, "claude_3_7_sonnet/requests/max")
}
},
model_name => {
@@ -1416,7 +1552,7 @@ async fn sync_model_request_usage_with_stripe(
};
stripe_billing
.subscribe_to_price(&stripe_subscription_id, price_id)
.subscribe_to_price(&stripe_subscription_id, price)
.await?;
stripe_billing
.bill_model_request_usage(

View File

@@ -180,9 +180,6 @@ pub struct Config {
pub slack_panics_webhook: Option<String>,
pub auto_join_channel_id: Option<ChannelId>,
pub stripe_api_key: Option<String>,
pub stripe_zed_pro_price_id: Option<String>,
pub stripe_zed_pro_trial_price_id: Option<String>,
pub stripe_zed_free_price_id: Option<String>,
pub supermaven_admin_api_key: Option<Arc<str>>,
pub user_backfiller_github_access_token: Option<Arc<str>>,
}
@@ -201,22 +198,6 @@ impl Config {
}
}
pub fn zed_pro_price_id(&self) -> anyhow::Result<stripe::PriceId> {
Self::parse_stripe_price_id("Zed Pro", self.stripe_zed_pro_price_id.as_deref())
}
pub fn zed_free_price_id(&self) -> anyhow::Result<stripe::PriceId> {
Self::parse_stripe_price_id("Zed Free", self.stripe_zed_pro_price_id.as_deref())
}
fn parse_stripe_price_id(name: &str, value: Option<&str>) -> anyhow::Result<stripe::PriceId> {
use std::str::FromStr as _;
let price_id = value.ok_or_else(|| anyhow!("{name} price ID not set"))?;
Ok(stripe::PriceId::from_str(price_id)?)
}
#[cfg(test)]
pub fn test() -> Self {
Self {
@@ -254,9 +235,6 @@ impl Config {
migrations_path: None,
seed_path: None,
stripe_api_key: None,
stripe_zed_pro_price_id: None,
stripe_zed_pro_trial_price_id: None,
stripe_zed_free_price_id: None,
supermaven_admin_api_key: None,
user_backfiller_github_access_token: None,
kinesis_region: None,

View File

@@ -1,3 +1,4 @@
use crate::db::UserId;
use crate::llm::db::queries::subscription_usages::convert_chrono_to_time;
use super::*;
@@ -34,4 +35,38 @@ impl LlmDatabase {
})
.await
}
/// Returns all current subscription usage meters for the given user as of the given timestamp.
pub async fn get_current_subscription_usage_meters_for_user(
&self,
user_id: UserId,
now: DateTimeUtc,
) -> Result<Vec<(subscription_usage_meter::Model, subscription_usage::Model)>> {
let now = convert_chrono_to_time(now)?;
self.transaction(|tx| async move {
let result = subscription_usage_meter::Entity::find()
.inner_join(subscription_usage::Entity)
.filter(subscription_usage::Column::UserId.eq(user_id))
.filter(
subscription_usage::Column::PeriodStartAt
.lte(now)
.and(subscription_usage::Column::PeriodEndAt.gte(now)),
)
.select_also(subscription_usage::Entity)
.all(&*tx)
.await?;
let result = result
.into_iter()
.filter_map(|(meter, usage)| {
let usage = usage?;
Some((meter, usage))
})
.collect();
Ok(result)
})
.await
}
}

View File

@@ -32,6 +32,8 @@ pub struct LlmTokenClaims {
pub has_llm_subscription: bool,
pub max_monthly_spend_in_cents: u32,
pub custom_llm_monthly_allowance_in_cents: Option<u32>,
#[serde(default)]
pub use_new_billing: bool,
pub plan: Plan,
#[serde(default)]
pub has_extended_trial: bool,
@@ -90,6 +92,7 @@ impl LlmTokenClaims {
custom_llm_monthly_allowance_in_cents: user
.custom_llm_monthly_allowance_in_cents
.map(|allowance| allowance as u32),
use_new_billing: feature_flags.iter().any(|flag| flag == "new-billing"),
plan: subscription
.as_ref()
.and_then(|subscription| subscription.kind)

View File

@@ -327,6 +327,10 @@ impl Server {
.add_request_handler(
forward_read_only_project_request::<proto::LspExtSwitchSourceHeader>,
)
.add_request_handler(forward_read_only_project_request::<proto::LspExtGoToParentModule>)
.add_request_handler(forward_read_only_project_request::<proto::LspExtCancelFlycheck>)
.add_request_handler(forward_read_only_project_request::<proto::LspExtRunFlycheck>)
.add_request_handler(forward_read_only_project_request::<proto::LspExtClearFlycheck>)
.add_request_handler(
forward_read_only_project_request::<proto::LanguageServerIdForName>,
)

View File

@@ -81,6 +81,24 @@ impl StripeBilling {
Ok(())
}
pub async fn zed_pro_price_id(&self) -> Result<PriceId> {
self.find_price_id_by_lookup_key("zed-pro").await
}
pub async fn zed_free_price_id(&self) -> Result<PriceId> {
self.find_price_id_by_lookup_key("zed-free").await
}
pub async fn find_price_id_by_lookup_key(&self, lookup_key: &str) -> Result<PriceId> {
self.state
.read()
.await
.prices_by_lookup_key
.get(lookup_key)
.map(|price| price.id.clone())
.ok_or_else(|| crate::Error::Internal(anyhow!("no price ID found for {lookup_key:?}")))
}
pub async fn find_price_by_lookup_key(&self, lookup_key: &str) -> Result<stripe::Price> {
self.state
.read()
@@ -88,7 +106,7 @@ impl StripeBilling {
.prices_by_lookup_key
.get(lookup_key)
.cloned()
.ok_or_else(|| crate::Error::Internal(anyhow!("no price ID found for {lookup_key:?}")))
.ok_or_else(|| crate::Error::Internal(anyhow!("no price found for {lookup_key:?}")))
}
pub async fn register_model_for_token_based_usage(
@@ -230,21 +248,29 @@ impl StripeBilling {
pub async fn subscribe_to_price(
&self,
subscription_id: &stripe::SubscriptionId,
price_id: &stripe::PriceId,
price: &stripe::Price,
) -> Result<()> {
let subscription =
stripe::Subscription::retrieve(&self.client, &subscription_id, &[]).await?;
if subscription_contains_price(&subscription, price_id) {
if subscription_contains_price(&subscription, &price.id) {
return Ok(());
}
const BILLING_THRESHOLD_IN_CENTS: i64 = 20 * 100;
let price_per_unit = price.unit_amount.unwrap_or_default();
let units_for_billing_threshold = BILLING_THRESHOLD_IN_CENTS / price_per_unit;
stripe::Subscription::update(
&self.client,
subscription_id,
stripe::UpdateSubscription {
items: Some(vec![stripe::UpdateSubscriptionItems {
price: Some(price_id.to_string()),
price: Some(price.id.to_string()),
billing_thresholds: Some(stripe::SubscriptionItemBillingThresholds {
usage_gte: Some(units_for_billing_threshold),
}),
..Default::default()
}]),
trial_settings: Some(stripe::UpdateSubscriptionTrialSettings {
@@ -463,19 +489,20 @@ impl StripeBilling {
Ok(session.url.context("no checkout session URL")?)
}
pub async fn checkout_with_price(
pub async fn checkout_with_zed_pro(
&self,
price_id: PriceId,
customer_id: stripe::CustomerId,
github_login: &str,
success_url: &str,
) -> Result<String> {
let zed_pro_price_id = self.zed_pro_price_id().await?;
let mut params = stripe::CreateCheckoutSession::new();
params.mode = Some(stripe::CheckoutSessionMode::Subscription);
params.customer = Some(customer_id);
params.client_reference_id = Some(github_login);
params.line_items = Some(vec![stripe::CreateCheckoutSessionLineItems {
price: Some(price_id.to_string()),
price: Some(zed_pro_price_id.to_string()),
quantity: Some(1),
..Default::default()
}]);
@@ -487,12 +514,13 @@ impl StripeBilling {
pub async fn checkout_with_zed_pro_trial(
&self,
zed_pro_price_id: PriceId,
customer_id: stripe::CustomerId,
github_login: &str,
feature_flags: Vec<String>,
success_url: &str,
) -> Result<String> {
let zed_pro_price_id = self.zed_pro_price_id().await?;
let eligible_for_extended_trial = feature_flags
.iter()
.any(|flag| flag == AGENT_EXTENDED_TRIAL_FEATURE_FLAG);
@@ -537,6 +565,29 @@ impl StripeBilling {
let session = stripe::CheckoutSession::create(&self.client, params).await?;
Ok(session.url.context("no checkout session URL")?)
}
pub async fn checkout_with_zed_free(
&self,
customer_id: stripe::CustomerId,
github_login: &str,
success_url: &str,
) -> Result<String> {
let zed_free_price_id = self.zed_free_price_id().await?;
let mut params = stripe::CreateCheckoutSession::new();
params.mode = Some(stripe::CheckoutSessionMode::Subscription);
params.customer = Some(customer_id);
params.client_reference_id = Some(github_login);
params.line_items = Some(vec![stripe::CreateCheckoutSessionLineItems {
price: Some(zed_free_price_id.to_string()),
quantity: Some(1),
..Default::default()
}]);
params.success_url = Some(success_url);
let session = stripe::CheckoutSession::create(&self.client, params).await?;
Ok(session.url.context("no checkout session URL")?)
}
}
#[derive(Serialize)]

View File

@@ -25,7 +25,7 @@ use language::{
use project::{
ProjectPath, SERVER_PROGRESS_THROTTLE_TIMEOUT,
lsp_store::{
lsp_ext_command::{ExpandedMacro, LspExpandMacro},
lsp_ext_command::{ExpandedMacro, LspExtExpandMacro},
rust_analyzer_ext::RUST_ANALYZER_NAME,
},
project_settings::{InlineBlameSettings, ProjectSettings},
@@ -2704,8 +2704,8 @@ async fn test_client_can_query_lsp_ext(cx_a: &mut TestAppContext, cx_b: &mut Tes
let fake_language_server = fake_language_servers.next().await.unwrap();
// host
let mut expand_request_a =
fake_language_server.set_request_handler::<LspExpandMacro, _, _>(|params, _| async move {
let mut expand_request_a = fake_language_server.set_request_handler::<LspExtExpandMacro, _, _>(
|params, _| async move {
assert_eq!(
params.text_document.uri,
lsp::Url::from_file_path(path!("/a/main.rs")).unwrap(),
@@ -2715,7 +2715,8 @@ async fn test_client_can_query_lsp_ext(cx_a: &mut TestAppContext, cx_b: &mut Tes
name: "test_macro_name".to_string(),
expansion: "test_macro_expansion on the host".to_string(),
}))
});
},
);
editor_a.update_in(cx_a, |editor, window, cx| {
expand_macro_recursively(editor, &ExpandMacroRecursively, window, cx)
@@ -2738,8 +2739,8 @@ async fn test_client_can_query_lsp_ext(cx_a: &mut TestAppContext, cx_b: &mut Tes
});
// client
let mut expand_request_b =
fake_language_server.set_request_handler::<LspExpandMacro, _, _>(|params, _| async move {
let mut expand_request_b = fake_language_server.set_request_handler::<LspExtExpandMacro, _, _>(
|params, _| async move {
assert_eq!(
params.text_document.uri,
lsp::Url::from_file_path(path!("/a/main.rs")).unwrap(),
@@ -2749,7 +2750,8 @@ async fn test_client_can_query_lsp_ext(cx_a: &mut TestAppContext, cx_b: &mut Tes
name: "test_macro_name".to_string(),
expansion: "test_macro_expansion on the client".to_string(),
}))
});
},
);
editor_b.update_in(cx_b, |editor, window, cx| {
expand_macro_recursively(editor, &ExpandMacroRecursively, window, cx)

View File

@@ -2902,7 +2902,7 @@ async fn test_git_branch_name(
.read(cx)
.branch
.as_ref()
.map(|branch| branch.name.to_string()),
.map(|branch| branch.name().to_owned()),
branch_name
)
}
@@ -6864,7 +6864,7 @@ async fn test_remote_git_branches(
let branches_b = branches_b
.into_iter()
.map(|branch| branch.name.to_string())
.map(|branch| branch.name().to_string())
.collect::<HashSet<_>>();
assert_eq!(branches_b, branches_set);
@@ -6895,7 +6895,7 @@ async fn test_remote_git_branches(
})
});
assert_eq!(host_branch.name, branches[2]);
assert_eq!(host_branch.name(), branches[2]);
// Also try creating a new branch
cx_b.update(|cx| {
@@ -6933,5 +6933,5 @@ async fn test_remote_git_branches(
})
});
assert_eq!(host_branch.name, "totally-new-branch");
assert_eq!(host_branch.name(), "totally-new-branch");
}

View File

@@ -293,7 +293,7 @@ async fn test_ssh_collaboration_git_branches(
let branches_b = branches_b
.into_iter()
.map(|branch| branch.name.to_string())
.map(|branch| branch.name().to_string())
.collect::<HashSet<_>>();
assert_eq!(&branches_b, &branches_set);
@@ -326,7 +326,7 @@ async fn test_ssh_collaboration_git_branches(
})
});
assert_eq!(server_branch.name, branches[2]);
assert_eq!(server_branch.name(), branches[2]);
// Also try creating a new branch
cx_b.update(|cx| {
@@ -366,7 +366,7 @@ async fn test_ssh_collaboration_git_branches(
})
});
assert_eq!(server_branch.name, "totally-new-branch");
assert_eq!(server_branch.name(), "totally-new-branch");
// Remove the git repository and check that all participants get the update.
remote_fs

View File

@@ -554,9 +554,6 @@ impl TestServer {
migrations_path: None,
seed_path: None,
stripe_api_key: None,
stripe_zed_pro_price_id: None,
stripe_zed_pro_trial_price_id: None,
stripe_zed_free_price_id: None,
supermaven_admin_api_key: None,
user_backfiller_github_access_token: None,
kinesis_region: None,

View File

@@ -15,18 +15,21 @@ path = "src/component_preview.rs"
default = []
[dependencies]
agent.workspace = true
anyhow.workspace = true
client.workspace = true
collections.workspace = true
component.workspace = true
db.workspace = true
gpui.workspace = true
languages.workspace = true
notifications.workspace = true
log.workspace = true
notifications.workspace = true
project.workspace = true
prompt_store.workspace = true
serde.workspace = true
ui.workspace = true
ui_input.workspace = true
workspace-hack.workspace = true
workspace.workspace = true
db.workspace = true
anyhow.workspace = true
serde.workspace = true
assistant_tool.workspace = true

View File

@@ -3,10 +3,12 @@
//! A view for exploring Zed components.
mod persistence;
mod preview_support;
use std::iter::Iterator;
use std::sync::Arc;
use agent::{ActiveThread, ThreadStore};
use client::UserStore;
use component::{ComponentId, ComponentMetadata, components};
use gpui::{
@@ -19,6 +21,7 @@ use gpui::{ListState, ScrollHandle, ScrollStrategy, UniformListScrollHandle};
use languages::LanguageRegistry;
use notifications::status_toast::{StatusToast, ToastIcon};
use persistence::COMPONENT_PREVIEW_DB;
use preview_support::active_thread::{load_preview_thread_store, static_active_thread};
use project::Project;
use ui::{Divider, HighlightedLabel, ListItem, ListSubHeader, prelude::*};
@@ -33,6 +36,7 @@ pub fn init(app_state: Arc<AppState>, cx: &mut App) {
cx.observe_new(move |workspace: &mut Workspace, _window, cx| {
let app_state = app_state.clone();
let project = workspace.project().clone();
let weak_workspace = cx.entity().downgrade();
workspace.register_action(
@@ -45,6 +49,7 @@ pub fn init(app_state: Arc<AppState>, cx: &mut App) {
let component_preview = cx.new(|cx| {
ComponentPreview::new(
weak_workspace.clone(),
project.clone(),
language_registry,
user_store,
None,
@@ -52,6 +57,7 @@ pub fn init(app_state: Arc<AppState>, cx: &mut App) {
window,
cx,
)
.expect("Failed to create component preview")
});
workspace.add_item_to_active_pane(
@@ -69,6 +75,7 @@ pub fn init(app_state: Arc<AppState>, cx: &mut App) {
enum PreviewEntry {
AllComponents,
ActiveThread,
Separator,
Component(ComponentMetadata, Option<Vec<usize>>),
SectionHeader(SharedString),
@@ -91,6 +98,7 @@ enum PreviewPage {
#[default]
AllComponents,
Component(ComponentId),
ActiveThread,
}
struct ComponentPreview {
@@ -102,24 +110,63 @@ struct ComponentPreview {
active_page: PreviewPage,
components: Vec<ComponentMetadata>,
component_list: ListState,
agent_previews: Vec<
Box<
dyn Fn(
&Self,
WeakEntity<Workspace>,
Entity<ActiveThread>,
WeakEntity<ThreadStore>,
&mut Window,
&mut App,
) -> Option<AnyElement>,
>,
>,
cursor_index: usize,
language_registry: Arc<LanguageRegistry>,
workspace: WeakEntity<Workspace>,
project: Entity<Project>,
user_store: Entity<UserStore>,
filter_editor: Entity<SingleLineInput>,
filter_text: String,
// preview support
thread_store: Option<Entity<ThreadStore>>,
active_thread: Option<Entity<ActiveThread>>,
}
impl ComponentPreview {
pub fn new(
workspace: WeakEntity<Workspace>,
project: Entity<Project>,
language_registry: Arc<LanguageRegistry>,
user_store: Entity<UserStore>,
selected_index: impl Into<Option<usize>>,
active_page: Option<PreviewPage>,
window: &mut Window,
cx: &mut Context<Self>,
) -> Self {
) -> anyhow::Result<Self> {
let workspace_clone = workspace.clone();
let project_clone = project.clone();
let entity = cx.weak_entity();
window
.spawn(cx, async move |cx| {
let thread_store_task =
load_preview_thread_store(workspace_clone.clone(), project_clone.clone(), cx)
.await;
if let Ok(thread_store) = thread_store_task.await {
entity
.update_in(cx, |this, window, cx| {
this.thread_store = Some(thread_store.clone());
this.create_active_thread(window, cx);
})
.ok();
}
})
.detach();
let sorted_components = components().all_sorted();
let selected_index = selected_index.into().unwrap_or(0);
let active_page = active_page.unwrap_or(PreviewPage::AllComponents);
@@ -143,6 +190,40 @@ impl ComponentPreview {
},
);
// Initialize agent previews
let agent_previews = agent::all_agent_previews()
.into_iter()
.map(|id| {
Box::new(
move |_self: &ComponentPreview,
workspace: WeakEntity<Workspace>,
active_thread: Entity<ActiveThread>,
thread_store: WeakEntity<ThreadStore>,
window: &mut Window,
cx: &mut App| {
agent::get_agent_preview(
&id,
workspace,
active_thread,
thread_store,
window,
cx,
)
},
)
as Box<
dyn Fn(
&ComponentPreview,
WeakEntity<Workspace>,
Entity<ActiveThread>,
WeakEntity<ThreadStore>,
&mut Window,
&mut App,
) -> Option<AnyElement>,
>
})
.collect::<Vec<_>>();
let mut component_preview = Self {
workspace_id: None,
focus_handle: cx.focus_handle(),
@@ -151,13 +232,17 @@ impl ComponentPreview {
language_registry,
user_store,
workspace,
project,
active_page,
component_map: components().0,
components: sorted_components,
component_list,
agent_previews,
cursor_index: selected_index,
filter_editor,
filter_text: String::new(),
thread_store: None,
active_thread: None,
};
if component_preview.cursor_index > 0 {
@@ -169,13 +254,41 @@ impl ComponentPreview {
let focus_handle = component_preview.filter_editor.read(cx).focus_handle(cx);
window.focus(&focus_handle);
component_preview
Ok(component_preview)
}
pub fn create_active_thread(
&mut self,
window: &mut Window,
cx: &mut Context<Self>,
) -> &mut Self {
let workspace = self.workspace.clone();
let language_registry = self.language_registry.clone();
let weak_handle = self.workspace.clone();
if let Some(workspace) = workspace.upgrade() {
let project = workspace.read(cx).project().clone();
if let Some(thread_store) = self.thread_store.clone() {
let active_thread = static_active_thread(
weak_handle,
project,
language_registry,
thread_store,
window,
cx,
);
self.active_thread = Some(active_thread);
cx.notify();
}
}
self
}
pub fn active_page_id(&self, _cx: &App) -> ActivePageId {
match &self.active_page {
PreviewPage::AllComponents => ActivePageId::default(),
PreviewPage::Component(component_id) => ActivePageId(component_id.0.to_string()),
PreviewPage::ActiveThread => ActivePageId("active_thread".to_string()),
}
}
@@ -289,6 +402,7 @@ impl ComponentPreview {
// Always show all components first
entries.push(PreviewEntry::AllComponents);
entries.push(PreviewEntry::ActiveThread);
entries.push(PreviewEntry::Separator);
let mut scopes: Vec<_> = scope_groups
@@ -389,6 +503,19 @@ impl ComponentPreview {
}))
.into_any_element()
}
PreviewEntry::ActiveThread => {
let selected = self.active_page == PreviewPage::ActiveThread;
ListItem::new(ix)
.child(Label::new("Active Thread").color(Color::Default))
.selectable(true)
.toggle_state(selected)
.inset(true)
.on_click(cx.listener(move |this, _, _, cx| {
this.set_active_page(PreviewPage::ActiveThread, cx);
}))
.into_any_element()
}
PreviewEntry::Separator => ListItem::new(ix)
.child(
h_flex()
@@ -471,6 +598,7 @@ impl ComponentPreview {
.render_scope_header(ix, shared_string.clone(), window, cx)
.into_any_element(),
PreviewEntry::AllComponents => div().w_full().h_0().into_any_element(),
PreviewEntry::ActiveThread => div().w_full().h_0().into_any_element(),
PreviewEntry::Separator => div().w_full().h_0().into_any_element(),
})
.unwrap()
@@ -595,6 +723,41 @@ impl ComponentPreview {
}
}
fn render_active_thread(
&self,
window: &mut Window,
cx: &mut Context<Self>,
) -> impl IntoElement {
v_flex()
.id("render-active-thread")
.size_full()
.child(
v_flex().children(self.agent_previews.iter().filter_map(|preview_fn| {
if let (Some(thread_store), Some(active_thread)) = (
self.thread_store.as_ref().map(|ts| ts.downgrade()),
self.active_thread.clone(),
) {
preview_fn(
self,
self.workspace.clone(),
active_thread,
thread_store,
window,
cx,
)
.map(|element| div().child(element))
} else {
None
}
})),
)
.children(self.active_thread.clone().map(|thread| thread.clone()))
.when_none(&self.active_thread.clone(), |this| {
this.child("No active thread")
})
.into_any_element()
}
fn test_status_toast(&self, cx: &mut Context<Self>) {
if let Some(workspace) = self.workspace.upgrade() {
workspace.update(cx, |workspace, cx| {
@@ -704,6 +867,9 @@ impl Render for ComponentPreview {
PreviewPage::Component(id) => self
.render_component_page(&id, window, cx)
.into_any_element(),
PreviewPage::ActiveThread => {
self.render_active_thread(window, cx).into_any_element()
}
}),
)
}
@@ -759,20 +925,28 @@ impl Item for ComponentPreview {
let language_registry = self.language_registry.clone();
let user_store = self.user_store.clone();
let weak_workspace = self.workspace.clone();
let project = self.project.clone();
let selected_index = self.cursor_index;
let active_page = self.active_page.clone();
Some(cx.new(|cx| {
Self::new(
weak_workspace,
language_registry,
user_store,
selected_index,
Some(active_page),
window,
cx,
)
}))
let self_result = Self::new(
weak_workspace,
project,
language_registry,
user_store,
selected_index,
Some(active_page),
window,
cx,
);
match self_result {
Ok(preview) => Some(cx.new(|_cx| preview)),
Err(e) => {
log::error!("Failed to clone component preview: {}", e);
None
}
}
}
fn to_item_events(event: &Self::Event, mut f: impl FnMut(workspace::item::ItemEvent)) {
@@ -838,10 +1012,12 @@ impl SerializableItem for ComponentPreview {
let user_store = user_store.clone();
let language_registry = language_registry.clone();
let weak_workspace = workspace.clone();
let project = project.clone();
cx.update(move |window, cx| {
Ok(cx.new(|cx| {
ComponentPreview::new(
weak_workspace,
project,
language_registry,
user_store,
None,
@@ -849,6 +1025,7 @@ impl SerializableItem for ComponentPreview {
window,
cx,
)
.expect("Failed to create component preview")
}))
})?
})

View File

@@ -0,0 +1 @@
pub mod active_thread;

View File

@@ -0,0 +1,69 @@
use languages::LanguageRegistry;
use project::Project;
use std::sync::Arc;
use agent::{ActiveThread, ContextStore, MessageSegment, ThreadStore};
use assistant_tool::ToolWorkingSet;
use gpui::{AppContext, AsyncApp, Entity, Task, WeakEntity};
use prompt_store::PromptBuilder;
use ui::{App, Window};
use workspace::Workspace;
pub async fn load_preview_thread_store(
workspace: WeakEntity<Workspace>,
project: Entity<Project>,
cx: &mut AsyncApp,
) -> Task<anyhow::Result<Entity<ThreadStore>>> {
cx.spawn(async move |cx| {
workspace
.update(cx, |_, cx| {
ThreadStore::load(
project.clone(),
cx.new(|_| ToolWorkingSet::default()),
None,
Arc::new(PromptBuilder::new(None).unwrap()),
cx,
)
})?
.await
})
}
pub fn static_active_thread(
workspace: WeakEntity<Workspace>,
project: Entity<Project>,
language_registry: Arc<LanguageRegistry>,
thread_store: Entity<ThreadStore>,
window: &mut Window,
cx: &mut App,
) -> Entity<ActiveThread> {
let context_store =
cx.new(|_| ContextStore::new(project.downgrade(), Some(thread_store.downgrade())));
let thread = thread_store.update(cx, |thread_store, cx| thread_store.create_thread(cx));
thread.update(cx, |thread, cx| {
thread.insert_assistant_message(vec![
MessageSegment::Text("I'll help you fix the lifetime error in your `cx.spawn` call. When working with async operations in GPUI, there are specific patterns to follow for proper lifetime management.".to_string()),
MessageSegment::Text("\n\nLet's look at what's happening in your code:".to_string()),
MessageSegment::Text("\n\n---\n\nLet's check the current state of the active_thread.rs file to understand what might have changed:".to_string()),
MessageSegment::Text("\n\n---\n\nLooking at the implementation of `load_preview_thread_store` and understanding GPUI's async patterns, here's the issue:".to_string()),
MessageSegment::Text("\n\n1. `load_preview_thread_store` returns a `Task<anyhow::Result<Entity<ThreadStore>>>`, which means it's already a task".to_string()),
MessageSegment::Text("\n2. When you call this function inside another `spawn` call, you're nesting tasks incorrectly".to_string()),
MessageSegment::Text("\n3. The `this` parameter you're trying to use in your closure has the wrong context".to_string()),
MessageSegment::Text("\n\nHere's the correct way to implement this:".to_string()),
MessageSegment::Text("\n\n---\n\nThe problem is in how you're setting up the async closure and trying to reference variables like `window` and `language_registry` that aren't accessible in that scope.".to_string()),
MessageSegment::Text("\n\nHere's how to fix it:".to_string()),
], cx);
});
cx.new(|cx| {
ActiveThread::new(
thread,
thread_store,
context_store,
language_registry,
workspace.clone(),
window,
cx,
)
})
}

View File

@@ -34,3 +34,7 @@ smol.workspace = true
url = { workspace = true, features = ["serde"] }
util.workspace = true
workspace-hack.workspace = true
[dev-dependencies]
gpui = { workspace = true, features = ["test-support"] }
project = { workspace = true, features = ["test-support"] }

View File

@@ -140,7 +140,7 @@ impl Client {
/// This function initializes a new Client by spawning a child process for the context server,
/// setting up communication channels, and initializing handlers for input/output operations.
/// It takes a server ID, binary information, and an async app context as input.
pub fn new(
pub fn stdio(
server_id: ContextServerId,
binary: ModelContextServerBinary,
cx: AsyncApp,
@@ -158,7 +158,16 @@ impl Client {
.unwrap_or_else(String::new);
let transport = Arc::new(StdioTransport::new(binary, &cx)?);
Self::new(server_id, server_name.into(), transport, cx)
}
/// Creates a new Client instance for a context server.
pub fn new(
server_id: ContextServerId,
server_name: Arc<str>,
transport: Arc<dyn Transport>,
cx: AsyncApp,
) -> Result<Self> {
let (outbound_tx, outbound_rx) = channel::unbounded::<String>();
let (output_done_tx, output_done_rx) = barrier::channel();
@@ -167,7 +176,7 @@ impl Client {
let response_handlers =
Arc::new(Mutex::new(Some(HashMap::<_, ResponseHandler>::default())));
let stdout_input_task = cx.spawn({
let receive_input_task = cx.spawn({
let notification_handlers = notification_handlers.clone();
let response_handlers = response_handlers.clone();
let transport = transport.clone();
@@ -177,13 +186,13 @@ impl Client {
.await
}
});
let stderr_input_task = cx.spawn({
let receive_err_task = cx.spawn({
let transport = transport.clone();
async move |_| Self::handle_stderr(transport).log_err().await
async move |_| Self::handle_err(transport).log_err().await
});
let input_task = cx.spawn(async move |_| {
let (stdout, stderr) = futures::join!(stdout_input_task, stderr_input_task);
stdout.or(stderr)
let (input, err) = futures::join!(receive_input_task, receive_err_task);
input.or(err)
});
let output_task = cx.background_spawn({
@@ -201,7 +210,7 @@ impl Client {
server_id,
notification_handlers,
response_handlers,
name: server_name.into(),
name: server_name,
next_id: Default::default(),
outbound_tx,
executor: cx.background_executor().clone(),
@@ -247,7 +256,7 @@ impl Client {
/// Handles the stderr output from the context server.
/// Continuously reads and logs any error messages from the server.
async fn handle_stderr(transport: Arc<dyn Transport>) -> anyhow::Result<()> {
async fn handle_err(transport: Arc<dyn Transport>) -> anyhow::Result<()> {
while let Some(err) = transport.receive_err().next().await {
log::warn!("context server stderr: {}", err.trim());
}

View File

@@ -12,7 +12,7 @@ pub use context_server_settings::{ContextServerSettings, ServerCommand, ServerCo
use gpui::{App, actions};
pub use crate::context_server_tool::ContextServerTool;
pub use crate::registry::ContextServerFactoryRegistry;
pub use crate::registry::ContextServerDescriptorRegistry;
actions!(context_servers, [Restart]);
@@ -21,7 +21,7 @@ pub const CONTEXT_SERVERS_NAMESPACE: &'static str = "context_servers";
pub fn init(cx: &mut App) {
context_server_settings::init(cx);
ContextServerFactoryRegistry::default_global(cx);
ContextServerDescriptorRegistry::default_global(cx);
extension_context_server::init(cx);
CommandPaletteFilter::update_global(cx, |filter, _cx| {

View File

@@ -1,9 +1,21 @@
use std::sync::Arc;
use extension::{Extension, ExtensionContextServerProxy, ExtensionHostProxy, ProjectDelegate};
use gpui::{App, Entity};
use anyhow::Result;
use extension::{
ContextServerConfiguration, Extension, ExtensionContextServerProxy, ExtensionHostProxy,
ProjectDelegate,
};
use gpui::{App, AsyncApp, Entity, Task};
use project::Project;
use crate::{ContextServerFactoryRegistry, ServerCommand};
use crate::{ContextServerDescriptorRegistry, ServerCommand, registry};
pub fn init(cx: &mut App) {
let proxy = ExtensionHostProxy::default_global(cx);
proxy.register_context_server_proxy(ContextServerDescriptorRegistryProxy {
context_server_factory_registry: ContextServerDescriptorRegistry::global(cx),
});
}
struct ExtensionProject {
worktree_ids: Vec<u64>,
@@ -15,60 +27,78 @@ impl ProjectDelegate for ExtensionProject {
}
}
pub fn init(cx: &mut App) {
let proxy = ExtensionHostProxy::default_global(cx);
proxy.register_context_server_proxy(ContextServerFactoryRegistryProxy {
context_server_factory_registry: ContextServerFactoryRegistry::global(cx),
});
struct ContextServerDescriptor {
id: Arc<str>,
extension: Arc<dyn Extension>,
}
struct ContextServerFactoryRegistryProxy {
context_server_factory_registry: Entity<ContextServerFactoryRegistry>,
fn extension_project(project: Entity<Project>, cx: &mut AsyncApp) -> Result<Arc<ExtensionProject>> {
project.update(cx, |project, cx| {
Arc::new(ExtensionProject {
worktree_ids: project
.visible_worktrees(cx)
.map(|worktree| worktree.read(cx).id().to_proto())
.collect(),
})
})
}
impl ExtensionContextServerProxy for ContextServerFactoryRegistryProxy {
impl registry::ContextServerDescriptor for ContextServerDescriptor {
fn command(&self, project: Entity<Project>, cx: &AsyncApp) -> Task<Result<ServerCommand>> {
let id = self.id.clone();
let extension = self.extension.clone();
cx.spawn(async move |cx| {
let extension_project = extension_project(project, cx)?;
let mut command = extension
.context_server_command(id.clone(), extension_project.clone())
.await?;
command.command = extension
.path_from_extension(command.command.as_ref())
.to_string_lossy()
.to_string();
log::info!("loaded command for context server {id}: {command:?}");
Ok(ServerCommand {
path: command.command,
args: command.args,
env: Some(command.env.into_iter().collect()),
})
})
}
fn configuration(
&self,
project: Entity<Project>,
cx: &AsyncApp,
) -> Task<Result<Option<ContextServerConfiguration>>> {
let id = self.id.clone();
let extension = self.extension.clone();
cx.spawn(async move |cx| {
let extension_project = extension_project(project, cx)?;
let configuration = extension
.context_server_configuration(id.clone(), extension_project)
.await?;
log::debug!("loaded configuration for context server {id}: {configuration:?}");
Ok(configuration)
})
}
}
struct ContextServerDescriptorRegistryProxy {
context_server_factory_registry: Entity<ContextServerDescriptorRegistry>,
}
impl ExtensionContextServerProxy for ContextServerDescriptorRegistryProxy {
fn register_context_server(&self, extension: Arc<dyn Extension>, id: Arc<str>, cx: &mut App) {
self.context_server_factory_registry
.update(cx, |registry, _| {
registry.register_server_factory(
registry.register_context_server_descriptor(
id.clone(),
Arc::new({
move |project, cx| {
log::info!(
"loading command for context server {id} from extension {}",
extension.manifest().id
);
let id = id.clone();
let extension = extension.clone();
cx.spawn(async move |cx| {
let extension_project = project.update(cx, |project, cx| {
Arc::new(ExtensionProject {
worktree_ids: project
.visible_worktrees(cx)
.map(|worktree| worktree.read(cx).id().to_proto())
.collect(),
})
})?;
let mut command = extension
.context_server_command(id.clone(), extension_project)
.await?;
command.command = extension
.path_from_extension(command.command.as_ref())
.to_string_lossy()
.to_string();
log::info!("loaded command for context server {id}: {command:?}");
Ok(ServerCommand {
path: command.command,
args: command.args,
env: Some(command.env.into_iter().collect()),
})
})
}
}),
Arc::new(ContextServerDescriptor { id, extension })
as Arc<dyn registry::ContextServerDescriptor>,
)
});
}

View File

@@ -27,18 +27,27 @@ use project::Project;
use settings::{Settings, SettingsStore};
use util::ResultExt as _;
use crate::transport::Transport;
use crate::{ContextServerSettings, ServerConfig};
use crate::{
CONTEXT_SERVERS_NAMESPACE, ContextServerFactoryRegistry,
CONTEXT_SERVERS_NAMESPACE, ContextServerDescriptorRegistry,
client::{self, Client},
types,
};
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum ContextServerStatus {
Starting,
Running,
Error(Arc<str>),
}
pub struct ContextServer {
pub id: Arc<str>,
pub config: Arc<ServerConfig>,
pub client: RwLock<Option<Arc<crate::protocol::InitializedContextServerProtocol>>>,
transport: Option<Arc<dyn Transport>>,
}
impl ContextServer {
@@ -47,9 +56,20 @@ impl ContextServer {
id,
config,
client: RwLock::new(None),
transport: None,
}
}
#[cfg(any(test, feature = "test-support"))]
pub fn test(id: Arc<str>, transport: Arc<dyn crate::transport::Transport>) -> Arc<Self> {
Arc::new(Self {
id,
client: RwLock::new(None),
config: Arc::new(ServerConfig::default()),
transport: Some(transport),
})
}
pub fn id(&self) -> Arc<str> {
self.id.clone()
}
@@ -63,20 +83,32 @@ impl ContextServer {
}
pub async fn start(self: Arc<Self>, cx: &AsyncApp) -> Result<()> {
log::info!("starting context server {}", self.id);
let Some(command) = &self.config.command else {
bail!("no command specified for server {}", self.id);
let client = if let Some(transport) = self.transport.clone() {
Client::new(
client::ContextServerId(self.id.clone()),
self.id(),
transport,
cx.clone(),
)?
} else {
let Some(command) = &self.config.command else {
bail!("no command specified for server {}", self.id);
};
Client::stdio(
client::ContextServerId(self.id.clone()),
client::ModelContextServerBinary {
executable: Path::new(&command.path).to_path_buf(),
args: command.args.clone(),
env: command.env.clone(),
},
cx.clone(),
)?
};
let client = Client::new(
client::ContextServerId(self.id.clone()),
client::ModelContextServerBinary {
executable: Path::new(&command.path).to_path_buf(),
args: command.args.clone(),
env: command.env.clone(),
},
cx.clone(),
)?;
self.initialize(client).await
}
async fn initialize(&self, client: Client) -> Result<()> {
log::info!("starting context server {}", self.id);
let protocol = crate::protocol::ModelContextProtocol::new(client);
let client_info = types::Implementation {
name: "Zed".to_string(),
@@ -105,23 +137,26 @@ impl ContextServer {
pub struct ContextServerManager {
servers: HashMap<Arc<str>, Arc<ContextServer>>,
server_status: HashMap<Arc<str>, ContextServerStatus>,
project: Entity<Project>,
registry: Entity<ContextServerFactoryRegistry>,
registry: Entity<ContextServerDescriptorRegistry>,
update_servers_task: Option<Task<Result<()>>>,
needs_server_update: bool,
_subscriptions: Vec<Subscription>,
}
pub enum Event {
ServerStarted { server_id: Arc<str> },
ServerStopped { server_id: Arc<str> },
ServerStatusChanged {
server_id: Arc<str>,
status: Option<ContextServerStatus>,
},
}
impl EventEmitter<Event> for ContextServerManager {}
impl ContextServerManager {
pub fn new(
registry: Entity<ContextServerFactoryRegistry>,
registry: Entity<ContextServerDescriptorRegistry>,
project: Entity<Project>,
cx: &mut Context<Self>,
) -> Self {
@@ -138,6 +173,7 @@ impl ContextServerManager {
registry,
needs_server_update: false,
servers: HashMap::default(),
server_status: HashMap::default(),
update_servers_task: None,
};
this.available_context_servers_changed(cx);
@@ -153,7 +189,9 @@ impl ContextServerManager {
this.needs_server_update = false;
})?;
Self::maintain_servers(this.clone(), cx).await?;
if let Err(err) = Self::maintain_servers(this.clone(), cx).await {
log::error!("Error maintaining context servers: {}", err);
}
this.update(cx, |this, cx| {
let has_any_context_servers = !this.running_servers().is_empty();
@@ -181,52 +219,37 @@ impl ContextServerManager {
.cloned()
}
pub fn status_for_server(&self, id: &str) -> Option<ContextServerStatus> {
self.server_status.get(id).cloned()
}
pub fn start_server(
&self,
server: Arc<ContextServer>,
cx: &mut Context<Self>,
) -> Task<anyhow::Result<()>> {
cx.spawn(async move |this, cx| {
let id = server.id.clone();
server.start(&cx).await?;
this.update(cx, |_, cx| cx.emit(Event::ServerStarted { server_id: id }))?;
Ok(())
})
) -> Task<Result<()>> {
cx.spawn(async move |this, cx| Self::run_server(this, server, cx).await)
}
pub fn stop_server(
&self,
&mut self,
server: Arc<ContextServer>,
cx: &mut Context<Self>,
) -> anyhow::Result<()> {
server.stop()?;
cx.emit(Event::ServerStopped {
server_id: server.id(),
});
) -> Result<()> {
server.stop().log_err();
self.update_server_status(server.id().clone(), None, cx);
Ok(())
}
pub fn restart_server(
&mut self,
id: &Arc<str>,
cx: &mut Context<Self>,
) -> Task<anyhow::Result<()>> {
pub fn restart_server(&mut self, id: &Arc<str>, cx: &mut Context<Self>) -> Task<Result<()>> {
let id = id.clone();
cx.spawn(async move |this, cx| {
if let Some(server) = this.update(cx, |this, _cx| this.servers.remove(&id))? {
server.stop()?;
let config = server.config();
this.update(cx, |this, cx| this.stop_server(server, cx))??;
let new_server = Arc::new(ContextServer::new(id.clone(), config));
new_server.clone().start(&cx).await?;
this.update(cx, |this, cx| {
this.servers.insert(id.clone(), new_server);
cx.emit(Event::ServerStopped {
server_id: id.clone(),
});
cx.emit(Event::ServerStarted {
server_id: id.clone(),
});
})?;
Self::run_server(this, new_server, cx).await?;
}
Ok(())
})
@@ -263,12 +286,14 @@ impl ContextServerManager {
(this.registry.clone(), this.project.clone())
})?;
for (id, factory) in
registry.read_with(cx, |registry, _| registry.context_server_factories())?
for (id, descriptor) in
registry.read_with(cx, |registry, _| registry.context_server_descriptors())?
{
let config = desired_servers.entry(id).or_default();
if config.command.is_none() {
if let Some(extension_command) = factory(project.clone(), &cx).await.log_err() {
if let Some(extension_command) =
descriptor.command(project.clone(), &cx).await.log_err()
{
config.command = Some(extension_command);
}
}
@@ -290,28 +315,270 @@ impl ContextServerManager {
for (id, config) in desired_servers {
let existing_config = this.servers.get(&id).map(|server| server.config());
if existing_config.as_deref() != Some(&config) {
let config = Arc::new(config);
let server = Arc::new(ContextServer::new(id.clone(), config));
let server = Arc::new(ContextServer::new(id.clone(), Arc::new(config)));
servers_to_start.insert(id.clone(), server.clone());
let old_server = this.servers.insert(id.clone(), server);
if let Some(old_server) = old_server {
if let Some(old_server) = this.servers.remove(&id) {
servers_to_stop.insert(id, old_server);
}
}
}
})?;
for (id, server) in servers_to_stop {
server.stop().log_err();
this.update(cx, |_, cx| cx.emit(Event::ServerStopped { server_id: id }))?;
for (_, server) in servers_to_stop {
this.update(cx, |this, cx| this.stop_server(server, cx).ok())?;
}
for (id, server) in servers_to_start {
if server.start(&cx).await.log_err().is_some() {
this.update(cx, |_, cx| cx.emit(Event::ServerStarted { server_id: id }))?;
}
for (_, server) in servers_to_start {
Self::run_server(this.clone(), server, cx).await.ok();
}
Ok(())
}
async fn run_server(
this: WeakEntity<Self>,
server: Arc<ContextServer>,
cx: &mut AsyncApp,
) -> Result<()> {
let id = server.id();
this.update(cx, |this, cx| {
this.update_server_status(id.clone(), Some(ContextServerStatus::Starting), cx);
this.servers.insert(id.clone(), server.clone());
})?;
match server.start(&cx).await {
Ok(_) => {
log::debug!("`{}` context server started", id);
this.update(cx, |this, cx| {
this.update_server_status(id.clone(), Some(ContextServerStatus::Running), cx)
})?;
Ok(())
}
Err(err) => {
log::error!("`{}` context server failed to start\n{}", id, err);
this.update(cx, |this, cx| {
this.update_server_status(
id.clone(),
Some(ContextServerStatus::Error(err.to_string().into())),
cx,
)
})?;
Err(err)
}
}
}
fn update_server_status(
&mut self,
id: Arc<str>,
status: Option<ContextServerStatus>,
cx: &mut Context<Self>,
) {
if let Some(status) = status.clone() {
self.server_status.insert(id.clone(), status);
} else {
self.server_status.remove(&id);
}
cx.emit(Event::ServerStatusChanged {
server_id: id,
status,
});
}
}
#[cfg(test)]
mod tests {
use std::pin::Pin;
use crate::types::{
Implementation, InitializeResponse, ProtocolVersion, RequestType, ServerCapabilities,
};
use super::*;
use futures::{Stream, StreamExt as _, lock::Mutex};
use gpui::{AppContext as _, TestAppContext};
use project::FakeFs;
use serde_json::json;
use util::path;
#[gpui::test]
async fn test_context_server_status(cx: &mut TestAppContext) {
init_test_settings(cx);
let project = create_test_project(cx, json!({"code.rs": ""})).await;
let registry = cx.new(|_| ContextServerDescriptorRegistry::new());
let manager = cx.new(|cx| ContextServerManager::new(registry.clone(), project, cx));
let server_1_id: Arc<str> = "mcp-1".into();
let server_2_id: Arc<str> = "mcp-2".into();
let transport_1 = Arc::new(FakeTransport::new(
|_, request_type, _| match request_type {
Some(RequestType::Initialize) => {
Some(create_initialize_response("mcp-1".to_string()))
}
_ => None,
},
));
let transport_2 = Arc::new(FakeTransport::new(
|_, request_type, _| match request_type {
Some(RequestType::Initialize) => {
Some(create_initialize_response("mcp-2".to_string()))
}
_ => None,
},
));
let server_1 = ContextServer::test(server_1_id.clone(), transport_1.clone());
let server_2 = ContextServer::test(server_2_id.clone(), transport_2.clone());
manager
.update(cx, |manager, cx| manager.start_server(server_1, cx))
.await
.unwrap();
cx.update(|cx| {
assert_eq!(
manager.read(cx).status_for_server(&server_1_id),
Some(ContextServerStatus::Running)
);
assert_eq!(manager.read(cx).status_for_server(&server_2_id), None);
});
manager
.update(cx, |manager, cx| manager.start_server(server_2.clone(), cx))
.await
.unwrap();
cx.update(|cx| {
assert_eq!(
manager.read(cx).status_for_server(&server_1_id),
Some(ContextServerStatus::Running)
);
assert_eq!(
manager.read(cx).status_for_server(&server_2_id),
Some(ContextServerStatus::Running)
);
});
manager
.update(cx, |manager, cx| manager.stop_server(server_2, cx))
.unwrap();
cx.update(|cx| {
assert_eq!(
manager.read(cx).status_for_server(&server_1_id),
Some(ContextServerStatus::Running)
);
assert_eq!(manager.read(cx).status_for_server(&server_2_id), None);
});
}
async fn create_test_project(
cx: &mut TestAppContext,
files: serde_json::Value,
) -> Entity<Project> {
let fs = FakeFs::new(cx.executor());
fs.insert_tree(path!("/test"), files).await;
Project::test(fs, [path!("/test").as_ref()], cx).await
}
fn init_test_settings(cx: &mut TestAppContext) {
cx.update(|cx| {
let settings_store = SettingsStore::test(cx);
cx.set_global(settings_store);
Project::init_settings(cx);
ContextServerSettings::register(cx);
});
}
fn create_initialize_response(server_name: String) -> serde_json::Value {
serde_json::to_value(&InitializeResponse {
protocol_version: ProtocolVersion(types::LATEST_PROTOCOL_VERSION.to_string()),
server_info: Implementation {
name: server_name,
version: "1.0.0".to_string(),
},
capabilities: ServerCapabilities::default(),
meta: None,
})
.unwrap()
}
struct FakeTransport {
on_request: Arc<
dyn Fn(u64, Option<RequestType>, serde_json::Value) -> Option<serde_json::Value>
+ Send
+ Sync,
>,
tx: futures::channel::mpsc::UnboundedSender<String>,
rx: Arc<Mutex<futures::channel::mpsc::UnboundedReceiver<String>>>,
}
impl FakeTransport {
fn new(
on_request: impl Fn(
u64,
Option<RequestType>,
serde_json::Value,
) -> Option<serde_json::Value>
+ 'static
+ Send
+ Sync,
) -> Self {
let (tx, rx) = futures::channel::mpsc::unbounded();
Self {
on_request: Arc::new(on_request),
tx,
rx: Arc::new(Mutex::new(rx)),
}
}
}
#[async_trait::async_trait]
impl Transport for FakeTransport {
async fn send(&self, message: String) -> Result<()> {
if let Ok(msg) = serde_json::from_str::<serde_json::Value>(&message) {
let id = msg.get("id").and_then(|id| id.as_u64()).unwrap_or(0);
if let Some(method) = msg.get("method") {
let request_type = method
.as_str()
.and_then(|method| types::RequestType::try_from(method).ok());
if let Some(payload) = (self.on_request.as_ref())(id, request_type, msg) {
let response = serde_json::json!({
"jsonrpc": "2.0",
"id": id,
"result": payload
});
self.tx
.unbounded_send(response.to_string())
.map_err(|e| anyhow::anyhow!("Failed to send message: {}", e))?;
}
}
}
Ok(())
}
fn receive(&self) -> Pin<Box<dyn Stream<Item = String> + Send>> {
let rx = self.rx.clone();
Box::pin(futures::stream::unfold(rx, |rx| async move {
let mut rx_guard = rx.lock().await;
if let Some(message) = rx_guard.next().await {
drop(rx_guard);
Some((message, rx))
} else {
None
}
}))
}
fn receive_err(&self) -> Pin<Box<dyn Stream<Item = String> + Send>> {
Box::pin(futures::stream::empty())
}
}
}

View File

@@ -2,38 +2,47 @@ use std::sync::Arc;
use anyhow::Result;
use collections::HashMap;
use extension::ContextServerConfiguration;
use gpui::{App, AppContext as _, AsyncApp, Entity, Global, ReadGlobal, Task};
use project::Project;
use crate::ServerCommand;
pub type ContextServerFactory =
Arc<dyn Fn(Entity<Project>, &AsyncApp) -> Task<Result<ServerCommand>> + Send + Sync + 'static>;
struct GlobalContextServerFactoryRegistry(Entity<ContextServerFactoryRegistry>);
impl Global for GlobalContextServerFactoryRegistry {}
#[derive(Default)]
pub struct ContextServerFactoryRegistry {
context_servers: HashMap<Arc<str>, ContextServerFactory>,
pub trait ContextServerDescriptor {
fn command(&self, project: Entity<Project>, cx: &AsyncApp) -> Task<Result<ServerCommand>>;
fn configuration(
&self,
project: Entity<Project>,
cx: &AsyncApp,
) -> Task<Result<Option<ContextServerConfiguration>>>;
}
impl ContextServerFactoryRegistry {
/// Returns the global [`ContextServerFactoryRegistry`].
struct GlobalContextServerDescriptorRegistry(Entity<ContextServerDescriptorRegistry>);
impl Global for GlobalContextServerDescriptorRegistry {}
#[derive(Default)]
pub struct ContextServerDescriptorRegistry {
context_servers: HashMap<Arc<str>, Arc<dyn ContextServerDescriptor>>,
}
impl ContextServerDescriptorRegistry {
/// Returns the global [`ContextServerDescriptorRegistry`].
pub fn global(cx: &App) -> Entity<Self> {
GlobalContextServerFactoryRegistry::global(cx).0.clone()
GlobalContextServerDescriptorRegistry::global(cx).0.clone()
}
/// Returns the global [`ContextServerFactoryRegistry`].
/// Returns the global [`ContextServerDescriptorRegistry`].
///
/// Inserts a default [`ContextServerFactoryRegistry`] if one does not yet exist.
/// Inserts a default [`ContextServerDescriptorRegistry`] if one does not yet exist.
pub fn default_global(cx: &mut App) -> Entity<Self> {
if !cx.has_global::<GlobalContextServerFactoryRegistry>() {
if !cx.has_global::<GlobalContextServerDescriptorRegistry>() {
let registry = cx.new(|_| Self::new());
cx.set_global(GlobalContextServerFactoryRegistry(registry));
cx.set_global(GlobalContextServerDescriptorRegistry(registry));
}
cx.global::<GlobalContextServerFactoryRegistry>().0.clone()
cx.global::<GlobalContextServerDescriptorRegistry>()
.0
.clone()
}
pub fn new() -> Self {
@@ -42,20 +51,28 @@ impl ContextServerFactoryRegistry {
}
}
pub fn context_server_factories(&self) -> Vec<(Arc<str>, ContextServerFactory)> {
pub fn context_server_descriptors(&self) -> Vec<(Arc<str>, Arc<dyn ContextServerDescriptor>)> {
self.context_servers
.iter()
.map(|(id, factory)| (id.clone(), factory.clone()))
.collect()
}
/// Registers the provided [`ContextServerFactory`].
pub fn register_server_factory(&mut self, id: Arc<str>, factory: ContextServerFactory) {
self.context_servers.insert(id, factory);
pub fn context_server_descriptor(&self, id: &str) -> Option<Arc<dyn ContextServerDescriptor>> {
self.context_servers.get(id).cloned()
}
/// Unregisters the [`ContextServerFactory`] for the server with the given ID.
pub fn unregister_server_factory_by_id(&mut self, server_id: &str) {
/// Registers the provided [`ContextServerDescriptor`].
pub fn register_context_server_descriptor(
&mut self,
id: Arc<str>,
descriptor: Arc<dyn ContextServerDescriptor>,
) {
self.context_servers.insert(id, descriptor);
}
/// Unregisters the [`ContextServerDescriptor`] for the server with the given ID.
pub fn unregister_context_server_descriptor_by_id(&mut self, server_id: &str) {
self.context_servers.remove(server_id);
}
}

View File

@@ -42,6 +42,30 @@ impl RequestType {
}
}
impl TryFrom<&str> for RequestType {
type Error = ();
fn try_from(s: &str) -> Result<Self, Self::Error> {
match s {
"initialize" => Ok(RequestType::Initialize),
"tools/call" => Ok(RequestType::CallTool),
"resources/unsubscribe" => Ok(RequestType::ResourcesUnsubscribe),
"resources/subscribe" => Ok(RequestType::ResourcesSubscribe),
"resources/read" => Ok(RequestType::ResourcesRead),
"resources/list" => Ok(RequestType::ResourcesList),
"logging/setLevel" => Ok(RequestType::LoggingSetLevel),
"prompts/get" => Ok(RequestType::PromptsGet),
"prompts/list" => Ok(RequestType::PromptsList),
"completion/complete" => Ok(RequestType::CompletionComplete),
"ping" => Ok(RequestType::Ping),
"tools/list" => Ok(RequestType::ListTools),
"resources/templates/list" => Ok(RequestType::ListResourceTemplates),
"roots/list" => Ok(RequestType::ListRoots),
_ => Err(()),
}
}
}
#[derive(Debug, PartialEq, Eq, Serialize, Deserialize)]
#[serde(transparent)]
pub struct ProtocolVersion(pub String);
@@ -154,7 +178,7 @@ pub struct CompletionArgument {
pub value: String,
}
#[derive(Debug, Deserialize)]
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct InitializeResponse {
pub protocol_version: ProtocolVersion,
@@ -343,7 +367,7 @@ pub struct ClientCapabilities {
pub roots: Option<RootsCapabilities>,
}
#[derive(Debug, Serialize, Deserialize)]
#[derive(Default, Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ServerCapabilities {
#[serde(skip_serializing_if = "Option::is_none")]

View File

@@ -444,10 +444,7 @@ pub trait DebugAdapter: 'static + Send + Sync {
log::info!("Getting latest version of debug adapter {}", self.name());
delegate.update_status(self.name(), DapStatus::CheckingForUpdate);
if let Some(version) = self.fetch_latest_adapter_version(delegate).await.log_err() {
log::info!(
"Installiing latest version of debug adapter {}",
self.name()
);
log::info!("Installing latest version of debug adapter {}", self.name());
delegate.update_status(self.name(), DapStatus::Downloading);
match self.install_binary(version, delegate).await {
Ok(_) => {

View File

@@ -7,6 +7,7 @@ use crate::{
};
use crate::{new_session_modal::NewSessionModal, session::DebugSession};
use anyhow::{Result, anyhow};
use collections::{HashMap, HashSet};
use command_palette_hooks::CommandPaletteFilter;
use dap::DebugRequest;
use dap::{
@@ -26,6 +27,7 @@ use project::{Project, debugger::session::ThreadStatus};
use rpc::proto::{self};
use settings::Settings;
use std::any::TypeId;
use std::path::PathBuf;
use task::{DebugScenario, TaskContext};
use ui::{ContextMenu, Divider, DropdownMenu, Tooltip, prelude::*};
use workspace::SplitDirection;
@@ -403,7 +405,6 @@ impl DebugPanel {
pub fn resolve_scenario(
&self,
scenario: DebugScenario,
task_context: TaskContext,
buffer: Option<Entity<Buffer>>,
window: &Window,
@@ -424,8 +425,60 @@ impl DebugPanel {
stop_on_entry,
} = scenario;
let request = if let Some(mut request) = request {
// Resolve task variables within the request.
if let DebugRequest::Launch(_) = &mut request {}
if let DebugRequest::Launch(launch_config) = &mut request {
let mut variable_names = HashMap::default();
let mut substituted_variables = HashSet::default();
let task_variables = task_context
.task_variables
.iter()
.map(|(key, value)| {
let key_string = key.to_string();
if !variable_names.contains_key(&key_string) {
variable_names.insert(key_string.clone(), key.clone());
}
(key_string, value.as_str())
})
.collect::<HashMap<_, _>>();
let cwd = launch_config
.cwd
.as_ref()
.and_then(|cwd| cwd.to_str())
.and_then(|cwd| {
task::substitute_all_template_variables_in_str(
cwd,
&task_variables,
&variable_names,
&mut substituted_variables,
)
});
if let Some(cwd) = cwd {
launch_config.cwd = Some(PathBuf::from(cwd))
}
if let Some(program) = task::substitute_all_template_variables_in_str(
&launch_config.program,
&task_variables,
&variable_names,
&mut substituted_variables,
) {
launch_config.program = program;
}
for arg in launch_config.args.iter_mut() {
if let Some(substituted_arg) =
task::substitute_all_template_variables_in_str(
&arg,
&task_variables,
&variable_names,
&mut substituted_variables,
)
{
*arg = substituted_arg;
}
}
}
request
} else if let Some(build) = build {
@@ -944,6 +997,7 @@ impl DebugPanel {
past_debug_definition,
weak_panel,
workspace,
None,
window,
cx,
)

View File

@@ -158,6 +158,7 @@ pub fn init(cx: &mut App) {
debug_panel.read(cx).past_debug_definition.clone(),
weak_panel,
weak_workspace,
None,
window,
cx,
)
@@ -166,14 +167,22 @@ pub fn init(cx: &mut App) {
},
)
.register_action(|workspace: &mut Workspace, _: &Start, window, cx| {
tasks_ui::toggle_modal(
workspace,
None,
task::TaskModal::DebugModal,
window,
cx,
)
.detach();
if let Some(debug_panel) = workspace.panel::<DebugPanel>(cx) {
let weak_panel = debug_panel.downgrade();
let weak_workspace = cx.weak_entity();
let task_store = workspace.project().read(cx).task_store().clone();
workspace.toggle_modal(window, cx, |window, cx| {
NewSessionModal::new(
debug_panel.read(cx).past_debug_definition.clone(),
weak_panel,
weak_workspace,
Some(task_store),
window,
cx,
)
});
}
});
})
})

View File

@@ -6,19 +6,25 @@ use std::{
use dap::{DapRegistry, DebugRequest, adapters::DebugTaskDefinition};
use editor::{Editor, EditorElement, EditorStyle};
use fuzzy::{StringMatch, StringMatchCandidate};
use gpui::{
App, AppContext, DismissEvent, Entity, EventEmitter, FocusHandle, Focusable, Render, TextStyle,
WeakEntity,
App, AppContext, DismissEvent, Entity, EventEmitter, FocusHandle, Focusable, Render,
Subscription, TextStyle, WeakEntity,
};
use picker::{Picker, PickerDelegate, highlighted_match_with_paths::HighlightedMatch};
use project::{TaskSourceKind, task_store::TaskStore};
use session_modes::{AttachMode, DebugScenarioDelegate, LaunchMode};
use settings::Settings;
use task::{DebugScenario, LaunchRequest, TaskContext};
use task::{DebugScenario, LaunchRequest};
use theme::ThemeSettings;
use ui::{
ActiveTheme, Button, ButtonCommon, ButtonSize, CheckboxWithLabel, Clickable, Color, Context,
ContextMenu, Disableable, DropdownMenu, FluentBuilder, InteractiveElement, IntoElement, Label,
LabelCommon as _, ParentElement, RenderOnce, SharedString, Styled, StyledExt, ToggleButton,
ToggleState, Toggleable, Window, div, h_flex, relative, rems, v_flex,
ContextMenu, Disableable, DropdownMenu, FluentBuilder, Icon, IconName, InteractiveElement,
IntoElement, Label, LabelCommon as _, ListItem, ListItemSpacing, ParentElement, RenderOnce,
SharedString, Styled, StyledExt, ToggleButton, ToggleState, Toggleable, Window, div, h_flex,
relative, rems, v_flex,
};
use util::ResultExt;
use workspace::{ModalView, Workspace};
use crate::{attach_modal::AttachModal, debugger_panel::DebugPanel};
@@ -57,6 +63,7 @@ impl NewSessionModal {
past_debug_definition: Option<DebugTaskDefinition>,
debug_panel: WeakEntity<DebugPanel>,
workspace: WeakEntity<Workspace>,
task_store: Option<Entity<TaskStore>>,
window: &mut Window,
cx: &mut Context<Self>,
) -> Self {
@@ -73,6 +80,18 @@ impl NewSessionModal {
_ => None,
};
if let Some(task_store) = task_store {
cx.defer_in(window, |this, window, cx| {
this.mode = NewSessionMode::scenario(
this.debug_panel.clone(),
this.workspace.clone(),
task_store,
window,
cx,
);
});
};
Self {
workspace: workspace.clone(),
debugger,
@@ -86,10 +105,10 @@ impl NewSessionModal {
}
}
fn debug_config(&self, cx: &App, debugger: &str) -> DebugScenario {
let request = self.mode.debug_task(cx);
fn debug_config(&self, cx: &App, debugger: &str) -> Option<DebugScenario> {
let request = self.mode.debug_task(cx)?;
let label = suggested_label(&request, debugger);
DebugScenario {
Some(DebugScenario {
adapter: debugger.to_owned().into(),
label,
request: Some(request),
@@ -100,21 +119,42 @@ impl NewSessionModal {
_ => None,
},
build: None,
}
})
}
fn start_new_session(&self, window: &mut Window, cx: &mut Context<Self>) {
let Some(debugger) = self.debugger.as_ref() else {
// todo: show in UI.
// todo(debugger): show in UI.
log::error!("No debugger selected");
return;
};
let config = self.debug_config(cx, debugger);
if let NewSessionMode::Scenario(picker) = &self.mode {
picker.update(cx, |picker, cx| {
picker.delegate.confirm(false, window, cx);
});
return;
}
let Some(config) = self.debug_config(cx, debugger) else {
log::error!("debug config not found in mode: {}", self.mode);
return;
};
let debug_panel = self.debug_panel.clone();
let workspace = self.workspace.clone();
cx.spawn_in(window, async move |this, cx| {
let task_contexts = workspace
.update_in(cx, |workspace, window, cx| {
tasks_ui::task_contexts(workspace, window, cx)
})?
.await;
let task_context = task_contexts.active_context().cloned().unwrap_or_default();
debug_panel.update_in(cx, |debug_panel, window, cx| {
debug_panel.start_session(config, TaskContext::default(), None, window, cx)
debug_panel.start_session(config, task_context, None, window, cx)
})?;
this.update(cx, |_, cx| {
cx.emit(DismissEvent);
@@ -256,9 +296,14 @@ impl NewSessionModal {
.iter()
.flat_map(|task_inventory| {
task_inventory.read(cx).list_debug_scenarios(
worktree.as_ref().map(|worktree| worktree.read(cx).id()),
worktree
.as_ref()
.map(|worktree| worktree.read(cx).id())
.iter()
.copied(),
)
})
.map(|(_source_kind, scenario)| scenario)
.collect()
})
.ok()
@@ -277,102 +322,22 @@ impl NewSessionModal {
}
}
#[derive(Clone)]
struct LaunchMode {
program: Entity<Editor>,
cwd: Entity<Editor>,
}
impl LaunchMode {
fn new(
past_launch_config: Option<LaunchRequest>,
window: &mut Window,
cx: &mut App,
) -> Entity<Self> {
let (past_program, past_cwd) = past_launch_config
.map(|config| (Some(config.program), config.cwd))
.unwrap_or_else(|| (None, None));
let program = cx.new(|cx| Editor::single_line(window, cx));
program.update(cx, |this, cx| {
this.set_placeholder_text("Program path", cx);
if let Some(past_program) = past_program {
this.set_text(past_program, window, cx);
};
});
let cwd = cx.new(|cx| Editor::single_line(window, cx));
cwd.update(cx, |this, cx| {
this.set_placeholder_text("Working Directory", cx);
if let Some(past_cwd) = past_cwd {
this.set_text(past_cwd.to_string_lossy(), window, cx);
};
});
cx.new(|_| Self { program, cwd })
}
fn debug_task(&self, cx: &App) -> task::LaunchRequest {
let path = self.cwd.read(cx).text(cx);
task::LaunchRequest {
program: self.program.read(cx).text(cx),
cwd: path.is_empty().not().then(|| PathBuf::from(path)),
args: Default::default(),
env: Default::default(),
}
}
}
#[derive(Clone)]
struct AttachMode {
definition: DebugTaskDefinition,
attach_picker: Entity<AttachModal>,
}
impl AttachMode {
fn new(
debugger: Option<SharedString>,
workspace: Entity<Workspace>,
window: &mut Window,
cx: &mut Context<NewSessionModal>,
) -> Entity<Self> {
let definition = DebugTaskDefinition {
adapter: debugger.clone().unwrap_or_default(),
label: "Attach New Session Setup".into(),
request: dap::DebugRequest::Attach(task::AttachRequest { process_id: None }),
initialize_args: None,
tcp_connection: None,
stop_on_entry: Some(false),
};
let attach_picker = cx.new(|cx| {
let modal = AttachModal::new(definition.clone(), workspace, false, window, cx);
window.focus(&modal.focus_handle(cx));
modal
});
cx.new(|_| Self {
definition,
attach_picker,
})
}
fn debug_task(&self) -> task::AttachRequest {
task::AttachRequest { process_id: None }
}
}
static SELECT_DEBUGGER_LABEL: SharedString = SharedString::new_static("Select Debugger");
static SELECT_SCENARIO_LABEL: SharedString = SharedString::new_static("Select Profile");
#[derive(Clone)]
enum NewSessionMode {
Launch(Entity<LaunchMode>),
Scenario(Entity<Picker<DebugScenarioDelegate>>),
Attach(Entity<AttachMode>),
}
impl NewSessionMode {
fn debug_task(&self, cx: &App) -> DebugRequest {
fn debug_task(&self, cx: &App) -> Option<DebugRequest> {
match self {
NewSessionMode::Launch(entity) => entity.read(cx).debug_task(cx).into(),
NewSessionMode::Attach(entity) => entity.read(cx).debug_task().into(),
NewSessionMode::Launch(entity) => Some(entity.read(cx).debug_task(cx).into()),
NewSessionMode::Attach(entity) => Some(entity.read(cx).debug_task().into()),
NewSessionMode::Scenario(_) => None,
}
}
fn as_attach(&self) -> Option<&Entity<AttachMode>> {
@@ -382,6 +347,78 @@ impl NewSessionMode {
None
}
}
fn scenario(
debug_panel: WeakEntity<DebugPanel>,
workspace: WeakEntity<Workspace>,
task_store: Entity<TaskStore>,
window: &mut Window,
cx: &mut Context<NewSessionModal>,
) -> NewSessionMode {
let picker = cx.new(|cx| {
Picker::uniform_list(
DebugScenarioDelegate::new(debug_panel, workspace, task_store),
window,
cx,
)
.modal(false)
});
cx.subscribe(&picker, |_, _, _, cx| {
cx.emit(DismissEvent);
})
.detach();
picker.focus_handle(cx).focus(window);
NewSessionMode::Scenario(picker)
}
fn attach(
debugger: Option<SharedString>,
workspace: Entity<Workspace>,
window: &mut Window,
cx: &mut Context<NewSessionModal>,
) -> Self {
Self::Attach(AttachMode::new(debugger, workspace, window, cx))
}
fn launch(
past_launch_config: Option<LaunchRequest>,
window: &mut Window,
cx: &mut Context<NewSessionModal>,
) -> Self {
Self::Launch(LaunchMode::new(past_launch_config, window, cx))
}
fn has_match(&self, cx: &App) -> bool {
match self {
NewSessionMode::Scenario(picker) => picker.read(cx).delegate.match_count() > 0,
NewSessionMode::Attach(picker) => {
picker
.read(cx)
.attach_picker
.read(cx)
.picker
.read(cx)
.delegate
.match_count()
> 0
}
_ => false,
}
}
}
impl std::fmt::Display for NewSessionMode {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let mode = match self {
NewSessionMode::Launch(_) => "launch".to_owned(),
NewSessionMode::Attach(_) => "attach".to_owned(),
NewSessionMode::Scenario(_) => "scenario picker".to_owned(),
};
write!(f, "{}", mode)
}
}
impl Focusable for NewSessionMode {
@@ -389,6 +426,7 @@ impl Focusable for NewSessionMode {
match &self {
NewSessionMode::Launch(entity) => entity.read(cx).program.focus_handle(cx),
NewSessionMode::Attach(entity) => entity.read(cx).attach_picker.focus_handle(cx),
NewSessionMode::Scenario(entity) => entity.read(cx).focus_handle(cx),
}
}
}
@@ -437,27 +475,14 @@ impl RenderOnce for NewSessionMode {
NewSessionMode::Attach(entity) => entity.update(cx, |this, cx| {
this.clone().render(window, cx).into_any_element()
}),
NewSessionMode::Scenario(entity) => v_flex()
.w(rems(34.))
.child(entity.clone())
.into_any_element(),
}
}
}
impl NewSessionMode {
fn attach(
debugger: Option<SharedString>,
workspace: Entity<Workspace>,
window: &mut Window,
cx: &mut Context<NewSessionModal>,
) -> Self {
Self::Attach(AttachMode::new(debugger, workspace, window, cx))
}
fn launch(
past_launch_config: Option<LaunchRequest>,
window: &mut Window,
cx: &mut Context<NewSessionModal>,
) -> Self {
Self::Launch(LaunchMode::new(past_launch_config, window, cx))
}
}
fn render_editor(editor: &Entity<Editor>, window: &mut Window, cx: &App) -> impl IntoElement {
let settings = ThemeSettings::get_global(cx);
let theme = cx.theme();
@@ -519,6 +544,34 @@ impl Render for NewSessionModal {
h_flex()
.justify_start()
.w_full()
.child(
ToggleButton::new("debugger-session-ui-picker-button", "Scenarios")
.size(ButtonSize::Default)
.style(ui::ButtonStyle::Subtle)
.toggle_state(matches!(self.mode, NewSessionMode::Scenario(_)))
.on_click(cx.listener(|this, _, window, cx| {
let Some(task_store) = this
.workspace
.update(cx, |workspace, cx| {
workspace.project().read(cx).task_store().clone()
})
.ok()
else {
return;
};
this.mode = NewSessionMode::scenario(
this.debug_panel.clone(),
this.workspace.clone(),
task_store,
window,
cx,
);
cx.notify();
}))
.first(),
)
.child(
ToggleButton::new(
"debugger-session-ui-launch-button",
@@ -532,7 +585,7 @@ impl Render for NewSessionModal {
this.mode.focus_handle(cx).focus(window);
cx.notify();
}))
.first(),
.middle(),
)
.child(
ToggleButton::new(
@@ -601,10 +654,21 @@ impl Render for NewSessionModal {
})
.child(
Button::new("debugger-spawn", "Start")
.on_click(cx.listener(|this, _, window, cx| {
this.start_new_session(window, cx);
.on_click(cx.listener(|this, _, window, cx| match &this.mode {
NewSessionMode::Scenario(picker) => {
picker.update(cx, |picker, cx| {
picker.delegate.confirm(true, window, cx)
})
}
_ => this.start_new_session(window, cx),
}))
.disabled(self.debugger.is_none()),
.disabled(match self.mode {
NewSessionMode::Scenario(_) => !self.mode.has_match(cx),
NewSessionMode::Attach(_) => {
self.debugger.is_none() || !self.mode.has_match(cx)
}
NewSessionMode::Launch(_) => self.debugger.is_none(),
}),
),
),
)
@@ -619,3 +683,319 @@ impl Focusable for NewSessionModal {
}
impl ModalView for NewSessionModal {}
// This module makes sure that the modes setup the correct subscriptions whenever they're created
mod session_modes {
use std::rc::Rc;
use super::*;
#[derive(Clone)]
#[non_exhaustive]
pub(super) struct LaunchMode {
pub(super) program: Entity<Editor>,
pub(super) cwd: Entity<Editor>,
}
impl LaunchMode {
pub(super) fn new(
past_launch_config: Option<LaunchRequest>,
window: &mut Window,
cx: &mut App,
) -> Entity<Self> {
let (past_program, past_cwd) = past_launch_config
.map(|config| (Some(config.program), config.cwd))
.unwrap_or_else(|| (None, None));
let program = cx.new(|cx| Editor::single_line(window, cx));
program.update(cx, |this, cx| {
this.set_placeholder_text("Program path", cx);
if let Some(past_program) = past_program {
this.set_text(past_program, window, cx);
};
});
let cwd = cx.new(|cx| Editor::single_line(window, cx));
cwd.update(cx, |this, cx| {
this.set_placeholder_text("Working Directory", cx);
if let Some(past_cwd) = past_cwd {
this.set_text(past_cwd.to_string_lossy(), window, cx);
};
});
cx.new(|_| Self { program, cwd })
}
pub(super) fn debug_task(&self, cx: &App) -> task::LaunchRequest {
let path = self.cwd.read(cx).text(cx);
task::LaunchRequest {
program: self.program.read(cx).text(cx),
cwd: path.is_empty().not().then(|| PathBuf::from(path)),
args: Default::default(),
env: Default::default(),
}
}
}
#[derive(Clone)]
pub(super) struct AttachMode {
pub(super) definition: DebugTaskDefinition,
pub(super) attach_picker: Entity<AttachModal>,
_subscription: Rc<Subscription>,
}
impl AttachMode {
pub(super) fn new(
debugger: Option<SharedString>,
workspace: Entity<Workspace>,
window: &mut Window,
cx: &mut Context<NewSessionModal>,
) -> Entity<Self> {
let definition = DebugTaskDefinition {
adapter: debugger.clone().unwrap_or_default(),
label: "Attach New Session Setup".into(),
request: dap::DebugRequest::Attach(task::AttachRequest { process_id: None }),
initialize_args: None,
tcp_connection: None,
stop_on_entry: Some(false),
};
let attach_picker = cx.new(|cx| {
let modal = AttachModal::new(definition.clone(), workspace, false, window, cx);
window.focus(&modal.focus_handle(cx));
modal
});
let subscription = cx.subscribe(&attach_picker, |_, _, _, cx| {
cx.emit(DismissEvent);
});
cx.new(|_| Self {
definition,
attach_picker,
_subscription: Rc::new(subscription),
})
}
pub(super) fn debug_task(&self) -> task::AttachRequest {
task::AttachRequest { process_id: None }
}
}
pub(super) struct DebugScenarioDelegate {
task_store: Entity<TaskStore>,
candidates: Option<Vec<(TaskSourceKind, DebugScenario)>>,
selected_index: usize,
matches: Vec<StringMatch>,
prompt: String,
debug_panel: WeakEntity<DebugPanel>,
workspace: WeakEntity<Workspace>,
}
impl DebugScenarioDelegate {
pub(super) fn new(
debug_panel: WeakEntity<DebugPanel>,
workspace: WeakEntity<Workspace>,
task_store: Entity<TaskStore>,
) -> Self {
Self {
task_store,
candidates: None,
selected_index: 0,
matches: Vec::new(),
prompt: String::new(),
debug_panel,
workspace,
}
}
}
impl PickerDelegate for DebugScenarioDelegate {
type ListItem = ui::ListItem;
fn match_count(&self) -> usize {
self.matches.len()
}
fn selected_index(&self) -> usize {
self.selected_index
}
fn set_selected_index(
&mut self,
ix: usize,
_window: &mut Window,
_cx: &mut Context<picker::Picker<Self>>,
) {
self.selected_index = ix;
}
fn placeholder_text(&self, _window: &mut Window, _cx: &mut App) -> std::sync::Arc<str> {
"".into()
}
fn update_matches(
&mut self,
query: String,
window: &mut Window,
cx: &mut Context<picker::Picker<Self>>,
) -> gpui::Task<()> {
let candidates: Vec<_> = match &self.candidates {
Some(candidates) => candidates
.into_iter()
.enumerate()
.map(|(index, (_, candidate))| {
StringMatchCandidate::new(index, candidate.label.as_ref())
})
.collect(),
None => {
let worktree_ids: Vec<_> = self
.workspace
.update(cx, |this, cx| {
this.visible_worktrees(cx)
.map(|tree| tree.read(cx).id())
.collect()
})
.ok()
.unwrap_or_default();
let scenarios: Vec<_> = self
.task_store
.read(cx)
.task_inventory()
.map(|item| item.read(cx).list_debug_scenarios(worktree_ids.into_iter()))
.unwrap_or_default();
self.candidates = Some(scenarios.clone());
scenarios
.into_iter()
.enumerate()
.map(|(index, (_, candidate))| {
StringMatchCandidate::new(index, candidate.label.as_ref())
})
.collect()
}
};
cx.spawn_in(window, async move |picker, cx| {
let matches = fuzzy::match_strings(
&candidates,
&query,
true,
1000,
&Default::default(),
cx.background_executor().clone(),
)
.await;
picker
.update(cx, |picker, _| {
let delegate = &mut picker.delegate;
delegate.matches = matches;
delegate.prompt = query;
if delegate.matches.is_empty() {
delegate.selected_index = 0;
} else {
delegate.selected_index =
delegate.selected_index.min(delegate.matches.len() - 1);
}
})
.log_err();
})
}
fn confirm(
&mut self,
_: bool,
window: &mut Window,
cx: &mut Context<picker::Picker<Self>>,
) {
let debug_scenario =
self.matches
.get(self.selected_index())
.and_then(|match_candidate| {
self.candidates
.as_ref()
.map(|candidates| candidates[match_candidate.candidate_id].clone())
});
let Some((task_source_kind, debug_scenario)) = debug_scenario else {
return;
};
let task_context = if let TaskSourceKind::Worktree {
id: worktree_id,
directory_in_worktree: _,
id_base: _,
} = task_source_kind
{
let workspace = self.workspace.clone();
cx.spawn_in(window, async move |_, cx| {
workspace
.update_in(cx, |workspace, window, cx| {
tasks_ui::task_contexts(workspace, window, cx)
})
.ok()?
.await
.task_context_for_worktree_id(worktree_id)
.cloned()
})
} else {
gpui::Task::ready(None)
};
cx.spawn_in(window, async move |this, cx| {
let task_context = task_context.await.unwrap_or_default();
this.update_in(cx, |this, window, cx| {
this.delegate
.debug_panel
.update(cx, |panel, cx| {
panel.start_session(debug_scenario, task_context, None, window, cx);
})
.ok();
cx.emit(DismissEvent);
})
.ok();
})
.detach();
}
fn dismissed(&mut self, _: &mut Window, cx: &mut Context<picker::Picker<Self>>) {
cx.emit(DismissEvent);
}
fn render_match(
&self,
ix: usize,
selected: bool,
window: &mut Window,
cx: &mut Context<picker::Picker<Self>>,
) -> Option<Self::ListItem> {
let hit = &self.matches[ix];
let highlighted_location = HighlightedMatch {
text: hit.string.clone(),
highlight_positions: hit.positions.clone(),
char_count: hit.string.chars().count(),
color: Color::Default,
};
let icon = Icon::new(IconName::FileTree)
.color(Color::Muted)
.size(ui::IconSize::Small);
Some(
ListItem::new(SharedString::from(format!("debug-scenario-selection-{ix}")))
.inset(true)
.start_slot::<Icon>(icon)
.spacing(ListItemSpacing::Sparse)
.toggle_state(selected)
.child(highlighted_location.render(window, cx)),
)
}
}
}

View File

@@ -188,7 +188,7 @@ impl Render for SubView {
cx.notify();
}))
.size_full()
// Add border uncoditionally to prevent layout shifts on focus changes.
// Add border unconditionally to prevent layout shifts on focus changes.
.border_1()
.when(self.pane_focus_handle.contains_focused(window, cx), |el| {
el.border_color(cx.theme().colors().pane_focused_border)

View File

@@ -19,6 +19,7 @@ component.workspace = true
ctor.workspace = true
editor.workspace = true
env_logger.workspace = true
futures.workspace = true
gpui.workspace = true
indoc.workspace = true
language.workspace = true
@@ -29,6 +30,7 @@ markdown.workspace = true
project.workspace = true
rand.workspace = true
serde.workspace = true
serde_json.workspace = true
settings.workspace = true
text.workspace = true
theme.workspace = true

View File

@@ -14,6 +14,7 @@ use editor::{
display_map::{BlockPlacement, BlockProperties, BlockStyle, CustomBlockId},
scroll::Autoscroll,
};
use futures::future::join_all;
use gpui::{
AnyElement, AnyView, App, AsyncApp, Context, Entity, EventEmitter, FocusHandle, Focusable,
Global, InteractiveElement, IntoElement, ParentElement, Render, SharedString, Styled,
@@ -23,13 +24,15 @@ use language::{
Bias, Buffer, BufferRow, BufferSnapshot, DiagnosticEntry, Point, ToTreeSitterPoint,
};
use lsp::DiagnosticSeverity;
use project::{DiagnosticSummary, Project, ProjectPath, project_settings::ProjectSettings};
use project::{
DiagnosticSummary, Project, ProjectPath,
lsp_store::rust_analyzer_ext::{cancel_flycheck, run_flycheck},
project_settings::ProjectSettings,
};
use settings::Settings;
use std::{
any::{Any, TypeId},
cmp,
cmp::Ordering,
cmp::{self, Ordering},
ops::{Range, RangeInclusive},
sync::Arc,
time::Duration,
@@ -45,7 +48,10 @@ use workspace::{
searchable::SearchableItemHandle,
};
actions!(diagnostics, [Deploy, ToggleWarnings]);
actions!(
diagnostics,
[Deploy, ToggleWarnings, ToggleDiagnosticsRefresh]
);
#[derive(Default)]
pub(crate) struct IncludeWarnings(bool);
@@ -68,9 +74,16 @@ pub(crate) struct ProjectDiagnosticsEditor {
paths_to_update: BTreeSet<ProjectPath>,
include_warnings: bool,
update_excerpts_task: Option<Task<Result<()>>>,
cargo_diagnostics_fetch: CargoDiagnosticsFetchState,
_subscription: Subscription,
}
struct CargoDiagnosticsFetchState {
fetch_task: Option<Task<()>>,
cancel_task: Option<Task<()>>,
diagnostic_sources: Arc<Vec<ProjectPath>>,
}
impl EventEmitter<EditorEvent> for ProjectDiagnosticsEditor {}
const DIAGNOSTICS_UPDATE_DELAY: Duration = Duration::from_millis(50);
@@ -126,6 +139,7 @@ impl Render for ProjectDiagnosticsEditor {
.track_focus(&self.focus_handle(cx))
.size_full()
.on_action(cx.listener(Self::toggle_warnings))
.on_action(cx.listener(Self::toggle_diagnostics_refresh))
.child(child)
}
}
@@ -212,7 +226,11 @@ impl ProjectDiagnosticsEditor {
cx.observe_global_in::<IncludeWarnings>(window, |this, window, cx| {
this.include_warnings = cx.global::<IncludeWarnings>().0;
this.diagnostics.clear();
this.update_all_excerpts(window, cx);
this.update_all_diagnostics(window, cx);
})
.detach();
cx.observe_release(&cx.entity(), |editor, _, cx| {
editor.stop_cargo_diagnostics_fetch(cx);
})
.detach();
@@ -229,9 +247,14 @@ impl ProjectDiagnosticsEditor {
editor,
paths_to_update: Default::default(),
update_excerpts_task: None,
cargo_diagnostics_fetch: CargoDiagnosticsFetchState {
fetch_task: None,
cancel_task: None,
diagnostic_sources: Arc::new(Vec::new()),
},
_subscription: project_event_subscription,
};
this.update_all_excerpts(window, cx);
this.update_all_diagnostics(window, cx);
this
}
@@ -239,15 +262,17 @@ impl ProjectDiagnosticsEditor {
if self.update_excerpts_task.is_some() {
return;
}
let project_handle = self.project.clone();
self.update_excerpts_task = Some(cx.spawn_in(window, async move |this, cx| {
cx.background_executor()
.timer(DIAGNOSTICS_UPDATE_DELAY)
.await;
loop {
let Some(path) = this.update(cx, |this, _| {
let Some(path) = this.update(cx, |this, cx| {
let Some(path) = this.paths_to_update.pop_first() else {
this.update_excerpts_task.take();
this.update_excerpts_task = None;
cx.notify();
return None;
};
Some(path)
@@ -307,6 +332,32 @@ impl ProjectDiagnosticsEditor {
cx.set_global(IncludeWarnings(!self.include_warnings));
}
fn toggle_diagnostics_refresh(
&mut self,
_: &ToggleDiagnosticsRefresh,
window: &mut Window,
cx: &mut Context<Self>,
) {
let fetch_cargo_diagnostics = ProjectSettings::get_global(cx)
.diagnostics
.fetch_cargo_diagnostics();
if fetch_cargo_diagnostics {
if self.cargo_diagnostics_fetch.fetch_task.is_some() {
self.stop_cargo_diagnostics_fetch(cx);
} else {
self.update_all_diagnostics(window, cx);
}
} else {
if self.update_excerpts_task.is_some() {
self.update_excerpts_task = None;
} else {
self.update_all_diagnostics(window, cx);
}
}
cx.notify();
}
fn focus_in(&mut self, window: &mut Window, cx: &mut Context<Self>) {
if self.focus_handle.is_focused(window) && !self.multibuffer.read(cx).is_empty() {
self.editor.focus_handle(cx).focus(window)
@@ -320,6 +371,66 @@ impl ProjectDiagnosticsEditor {
}
}
fn update_all_diagnostics(&mut self, window: &mut Window, cx: &mut Context<Self>) {
let cargo_diagnostics_sources = self.cargo_diagnostics_sources(cx);
if cargo_diagnostics_sources.is_empty() {
self.update_all_excerpts(window, cx);
} else {
self.fetch_cargo_diagnostics(Arc::new(cargo_diagnostics_sources), cx);
}
}
fn fetch_cargo_diagnostics(
&mut self,
diagnostics_sources: Arc<Vec<ProjectPath>>,
cx: &mut Context<Self>,
) {
let project = self.project.clone();
self.cargo_diagnostics_fetch.cancel_task = None;
self.cargo_diagnostics_fetch.fetch_task = None;
self.cargo_diagnostics_fetch.diagnostic_sources = diagnostics_sources.clone();
if self.cargo_diagnostics_fetch.diagnostic_sources.is_empty() {
return;
}
self.cargo_diagnostics_fetch.fetch_task = Some(cx.spawn(async move |editor, cx| {
let mut fetch_tasks = Vec::new();
for buffer_path in diagnostics_sources.iter().cloned() {
if cx
.update(|cx| {
fetch_tasks.push(run_flycheck(project.clone(), buffer_path, cx));
})
.is_err()
{
break;
}
}
let _ = join_all(fetch_tasks).await;
editor
.update(cx, |editor, _| {
editor.cargo_diagnostics_fetch.fetch_task = None;
})
.ok();
}));
}
fn stop_cargo_diagnostics_fetch(&mut self, cx: &mut App) {
self.cargo_diagnostics_fetch.fetch_task = None;
let mut cancel_gasks = Vec::new();
for buffer_path in std::mem::take(&mut self.cargo_diagnostics_fetch.diagnostic_sources)
.iter()
.cloned()
{
cancel_gasks.push(cancel_flycheck(self.project.clone(), buffer_path, cx));
}
self.cargo_diagnostics_fetch.cancel_task = Some(cx.background_spawn(async move {
let _ = join_all(cancel_gasks).await;
log::info!("Finished fetching cargo diagnostics");
}));
}
/// Enqueue an update of all excerpts. Updates all paths that either
/// currently have diagnostics or are currently present in this view.
fn update_all_excerpts(&mut self, window: &mut Window, cx: &mut Context<Self>) {
@@ -422,20 +533,17 @@ impl ProjectDiagnosticsEditor {
})?;
for item in more {
let insert_pos = blocks
.binary_search_by(|existing| {
match existing.initial_range.start.cmp(&item.initial_range.start) {
Ordering::Equal => item
.initial_range
.end
.cmp(&existing.initial_range.end)
.reverse(),
other => other,
}
let i = blocks
.binary_search_by(|probe| {
probe
.initial_range
.start
.cmp(&item.initial_range.start)
.then(probe.initial_range.end.cmp(&item.initial_range.end))
.then(Ordering::Greater)
})
.unwrap_or_else(|pos| pos);
blocks.insert(insert_pos, item);
.unwrap_or_else(|i| i);
blocks.insert(i, item);
}
}
@@ -448,10 +556,25 @@ impl ProjectDiagnosticsEditor {
&mut cx,
)
.await;
excerpt_ranges.push(ExcerptRange {
context: excerpt_range,
primary: b.initial_range.clone(),
})
let i = excerpt_ranges
.binary_search_by(|probe| {
probe
.context
.start
.cmp(&excerpt_range.start)
.then(probe.context.end.cmp(&excerpt_range.end))
.then(probe.primary.start.cmp(&b.initial_range.start))
.then(probe.primary.end.cmp(&b.initial_range.end))
.then(cmp::Ordering::Greater)
})
.unwrap_or_else(|i| i);
excerpt_ranges.insert(
i,
ExcerptRange {
context: excerpt_range,
primary: b.initial_range.clone(),
},
)
}
this.update_in(cx, |this, window, cx| {
@@ -534,6 +657,30 @@ impl ProjectDiagnosticsEditor {
})
})
}
pub fn cargo_diagnostics_sources(&self, cx: &App) -> Vec<ProjectPath> {
let fetch_cargo_diagnostics = ProjectSettings::get_global(cx)
.diagnostics
.fetch_cargo_diagnostics();
if !fetch_cargo_diagnostics {
return Vec::new();
}
self.project
.read(cx)
.worktrees(cx)
.filter_map(|worktree| {
let _cargo_toml_entry = worktree.read(cx).entry_for_path("Cargo.toml")?;
let rust_file_entry = worktree.read(cx).entries(false, 0).find(|entry| {
entry
.path
.extension()
.and_then(|extension| extension.to_str())
== Some("rs")
})?;
self.project.read(cx).path_for_entry(rust_file_entry.id, cx)
})
.collect()
}
}
impl Focusable for ProjectDiagnosticsEditor {

View File

@@ -1,4 +1,6 @@
use crate::ProjectDiagnosticsEditor;
use std::sync::Arc;
use crate::{ProjectDiagnosticsEditor, ToggleDiagnosticsRefresh};
use gpui::{Context, Entity, EventEmitter, ParentElement, Render, WeakEntity, Window};
use ui::prelude::*;
use ui::{IconButton, IconButtonShape, IconName, Tooltip};
@@ -13,18 +15,26 @@ impl Render for ToolbarControls {
let mut include_warnings = false;
let mut has_stale_excerpts = false;
let mut is_updating = false;
let cargo_diagnostics_sources = Arc::new(self.diagnostics().map_or(Vec::new(), |editor| {
editor.read(cx).cargo_diagnostics_sources(cx)
}));
let fetch_cargo_diagnostics = !cargo_diagnostics_sources.is_empty();
if let Some(editor) = self.diagnostics() {
let diagnostics = editor.read(cx);
include_warnings = diagnostics.include_warnings;
has_stale_excerpts = !diagnostics.paths_to_update.is_empty();
is_updating = diagnostics.update_excerpts_task.is_some()
|| diagnostics
.project
.read(cx)
.language_servers_running_disk_based_diagnostics(cx)
.next()
.is_some();
is_updating = if fetch_cargo_diagnostics {
diagnostics.cargo_diagnostics_fetch.fetch_task.is_some()
} else {
diagnostics.update_excerpts_task.is_some()
|| diagnostics
.project
.read(cx)
.language_servers_running_disk_based_diagnostics(cx)
.next()
.is_some()
};
}
let tooltip = if include_warnings {
@@ -41,21 +51,56 @@ impl Render for ToolbarControls {
h_flex()
.gap_1()
.when(has_stale_excerpts, |div| {
div.child(
IconButton::new("update-excerpts", IconName::Update)
.icon_color(Color::Info)
.shape(IconButtonShape::Square)
.disabled(is_updating)
.tooltip(Tooltip::text("Update excerpts"))
.on_click(cx.listener(|this, _, window, cx| {
if let Some(diagnostics) = this.diagnostics() {
diagnostics.update(cx, |diagnostics, cx| {
diagnostics.update_all_excerpts(window, cx);
});
}
})),
)
.map(|div| {
if is_updating {
div.child(
IconButton::new("stop-updating", IconName::StopFilled)
.icon_color(Color::Info)
.shape(IconButtonShape::Square)
.tooltip(Tooltip::for_action_title(
"Stop diagnostics update",
&ToggleDiagnosticsRefresh,
))
.on_click(cx.listener(move |toolbar_controls, _, _, cx| {
if let Some(diagnostics) = toolbar_controls.diagnostics() {
diagnostics.update(cx, |diagnostics, cx| {
diagnostics.stop_cargo_diagnostics_fetch(cx);
diagnostics.update_excerpts_task = None;
cx.notify();
});
}
})),
)
} else {
div.child(
IconButton::new("refresh-diagnostics", IconName::Update)
.icon_color(Color::Info)
.shape(IconButtonShape::Square)
.disabled(!has_stale_excerpts && !fetch_cargo_diagnostics)
.tooltip(Tooltip::for_action_title(
"Refresh diagnostics",
&ToggleDiagnosticsRefresh,
))
.on_click(cx.listener({
move |toolbar_controls, _, window, cx| {
if let Some(diagnostics) = toolbar_controls.diagnostics() {
let cargo_diagnostics_sources =
Arc::clone(&cargo_diagnostics_sources);
diagnostics.update(cx, move |diagnostics, cx| {
if fetch_cargo_diagnostics {
diagnostics.fetch_cargo_diagnostics(
cargo_diagnostics_sources,
cx,
);
} else {
diagnostics.update_all_excerpts(window, cx);
}
});
}
}
})),
)
}
})
.child(
IconButton::new("toggle-warnings", IconName::Warning)

View File

@@ -249,7 +249,9 @@ actions!(
ApplyDiffHunk,
Backspace,
Cancel,
CancelFlycheck,
CancelLanguageServerWork,
ClearFlycheck,
ConfirmRename,
ConfirmCompletionInsert,
ConfirmCompletionReplace,
@@ -308,6 +310,7 @@ actions!(
GoToImplementation,
GoToImplementationSplit,
GoToNextChange,
GoToParentModule,
GoToPreviousChange,
GoToPreviousDiagnostic,
GoToTypeDefinition,
@@ -371,6 +374,7 @@ actions!(
RevertFile,
ReloadFile,
Rewrap,
RunFlycheck,
ScrollCursorBottom,
ScrollCursorCenter,
ScrollCursorCenterTopBottom,

View File

@@ -5005,11 +5005,11 @@ impl Editor {
range
};
ranges.push(range);
ranges.push(range.clone());
if !self.linked_edit_ranges.is_empty() {
let start_anchor = snapshot.anchor_before(selection.head());
let end_anchor = snapshot.anchor_after(selection.tail());
let start_anchor = snapshot.anchor_before(range.start);
let end_anchor = snapshot.anchor_after(range.end);
if let Some(ranges) = self
.linked_editing_ranges_for(start_anchor.text_anchor..end_anchor.text_anchor, cx)
{

View File

@@ -1,6 +1,7 @@
use super::*;
use crate::{
JoinLines,
linked_editing_ranges::LinkedEditingRanges,
scroll::scroll_amount::ScrollAmount,
test::{
assert_text_with_selections, build_editor,
@@ -19559,6 +19560,146 @@ async fn test_hide_mouse_context_menu_on_modal_opened(cx: &mut TestAppContext) {
});
}
#[gpui::test]
async fn test_html_linked_edits_on_completion(cx: &mut TestAppContext) {
init_test(cx, |_| {});
let fs = FakeFs::new(cx.executor());
fs.insert_file(path!("/file.html"), Default::default())
.await;
let project = Project::test(fs, [path!("/").as_ref()], cx).await;
let language_registry = project.read_with(cx, |project, _| project.languages().clone());
let html_language = Arc::new(Language::new(
LanguageConfig {
name: "HTML".into(),
matcher: LanguageMatcher {
path_suffixes: vec!["html".to_string()],
..LanguageMatcher::default()
},
brackets: BracketPairConfig {
pairs: vec![BracketPair {
start: "<".into(),
end: ">".into(),
close: true,
..Default::default()
}],
..Default::default()
},
..Default::default()
},
Some(tree_sitter_html::LANGUAGE.into()),
));
language_registry.add(html_language);
let mut fake_servers = language_registry.register_fake_lsp(
"HTML",
FakeLspAdapter {
capabilities: lsp::ServerCapabilities {
completion_provider: Some(lsp::CompletionOptions {
resolve_provider: Some(true),
..Default::default()
}),
..Default::default()
},
..Default::default()
},
);
let workspace = cx.add_window(|window, cx| Workspace::test_new(project.clone(), window, cx));
let cx = &mut VisualTestContext::from_window(*workspace, cx);
let worktree_id = workspace
.update(cx, |workspace, _window, cx| {
workspace.project().update(cx, |project, cx| {
project.worktrees(cx).next().unwrap().read(cx).id()
})
})
.unwrap();
project
.update(cx, |project, cx| {
project.open_local_buffer_with_lsp(path!("/file.html"), cx)
})
.await
.unwrap();
let editor = workspace
.update(cx, |workspace, window, cx| {
workspace.open_path((worktree_id, "file.html"), None, true, window, cx)
})
.unwrap()
.await
.unwrap()
.downcast::<Editor>()
.unwrap();
let fake_server = fake_servers.next().await.unwrap();
editor.update_in(cx, |editor, window, cx| {
editor.set_text("<ad></ad>", window, cx);
editor.change_selections(None, window, cx, |selections| {
selections.select_ranges([Point::new(0, 3)..Point::new(0, 3)]);
});
let Some((buffer, _)) = editor
.buffer
.read(cx)
.text_anchor_for_position(editor.selections.newest_anchor().start, cx)
else {
panic!("Failed to get buffer for selection position");
};
let buffer = buffer.read(cx);
let buffer_id = buffer.remote_id();
let opening_range =
buffer.anchor_before(Point::new(0, 1))..buffer.anchor_after(Point::new(0, 3));
let closing_range =
buffer.anchor_before(Point::new(0, 6))..buffer.anchor_after(Point::new(0, 8));
let mut linked_ranges = HashMap::default();
linked_ranges.insert(
buffer_id,
vec![(opening_range.clone(), vec![closing_range.clone()])],
);
editor.linked_edit_ranges = LinkedEditingRanges(linked_ranges);
});
let mut completion_handle =
fake_server.set_request_handler::<lsp::request::Completion, _, _>(move |_, _| async move {
Ok(Some(lsp::CompletionResponse::Array(vec![
lsp::CompletionItem {
label: "head".to_string(),
text_edit: Some(lsp::CompletionTextEdit::InsertAndReplace(
lsp::InsertReplaceEdit {
new_text: "head".to_string(),
insert: lsp::Range::new(
lsp::Position::new(0, 1),
lsp::Position::new(0, 3),
),
replace: lsp::Range::new(
lsp::Position::new(0, 1),
lsp::Position::new(0, 3),
),
},
)),
..Default::default()
},
])))
});
editor.update_in(cx, |editor, window, cx| {
editor.show_completions(&ShowCompletions { trigger: None }, window, cx);
});
cx.run_until_parked();
completion_handle.next().await.unwrap();
editor.update(cx, |editor, _| {
assert!(
editor.context_menu_visible(),
"Completion menu should be visible"
);
});
editor.update_in(cx, |editor, window, cx| {
editor.confirm_completion(&ConfirmCompletion::default(), window, cx)
});
cx.executor().run_until_parked();
editor.update(cx, |editor, cx| {
assert_eq!(editor.text(cx), "<head></head>");
});
}
fn empty_range(row: usize, column: usize) -> Range<DisplayPoint> {
let point = DisplayPoint::new(DisplayRow(row as u32), column as u32);
point..point

View File

@@ -2276,6 +2276,9 @@ impl EditorElement {
}
let display_row = multibuffer_point.to_display_point(snapshot).row();
if !range.contains(&display_row) {
return None;
}
if row_infos
.get((display_row - range.start).0 as usize)
.is_some_and(|row_info| row_info.expand_info.is_some())
@@ -6871,7 +6874,12 @@ impl Element for EditorElement {
// The max scroll position for the top of the window
let max_scroll_top = if matches!(
snapshot.mode,
EditorMode::AutoHeight { .. } | EditorMode::SingleLine { .. }
EditorMode::SingleLine { .. }
| EditorMode::AutoHeight { .. }
| EditorMode::Full {
sized_by_content: true,
..
}
) {
(max_row - height_in_lines + 1.).max(0.)
} else {

View File

@@ -233,7 +233,7 @@ pub fn deploy_context_menu(
.separator()
.action("Cut", Box::new(Cut))
.action("Copy", Box::new(Copy))
.action("Copy and trim", Box::new(CopyAndTrim))
.action("Copy and Trim", Box::new(CopyAndTrim))
.action("Paste", Box::new(Paste))
.separator()
.map(|builder| {

View File

@@ -4,15 +4,20 @@ use anyhow::Context as _;
use gpui::{App, AppContext as _, Context, Entity, Window};
use language::{Capability, Language, proto::serialize_anchor};
use multi_buffer::MultiBuffer;
use project::lsp_store::{
lsp_ext_command::{DocsUrls, ExpandMacro, ExpandedMacro},
rust_analyzer_ext::RUST_ANALYZER_NAME,
use project::{
ProjectItem,
lsp_command::location_link_from_proto,
lsp_store::{
lsp_ext_command::{DocsUrls, ExpandMacro, ExpandedMacro},
rust_analyzer_ext::{RUST_ANALYZER_NAME, cancel_flycheck, clear_flycheck, run_flycheck},
},
};
use rpc::proto;
use text::ToPointUtf16;
use crate::{
Editor, ExpandMacroRecursively, OpenDocs, element::register_action,
CancelFlycheck, ClearFlycheck, Editor, ExpandMacroRecursively, GoToParentModule,
GotoDefinitionKind, OpenDocs, RunFlycheck, element::register_action, hover_links::HoverLink,
lsp_ext::find_specific_language_server_in_selection,
};
@@ -30,11 +35,97 @@ pub fn apply_related_actions(editor: &Entity<Editor>, window: &mut Window, cx: &
.filter_map(|buffer| buffer.read(cx).language())
.any(|language| is_rust_language(language))
{
register_action(&editor, window, go_to_parent_module);
register_action(&editor, window, expand_macro_recursively);
register_action(&editor, window, open_docs);
register_action(&editor, window, cancel_flycheck_action);
register_action(&editor, window, run_flycheck_action);
register_action(&editor, window, clear_flycheck_action);
}
}
pub fn go_to_parent_module(
editor: &mut Editor,
_: &GoToParentModule,
window: &mut Window,
cx: &mut Context<Editor>,
) {
if editor.selections.count() == 0 {
return;
}
let Some(project) = &editor.project else {
return;
};
let server_lookup = find_specific_language_server_in_selection(
editor,
cx,
is_rust_language,
RUST_ANALYZER_NAME,
);
let project = project.clone();
let lsp_store = project.read(cx).lsp_store();
let upstream_client = lsp_store.read(cx).upstream_client();
cx.spawn_in(window, async move |editor, cx| {
let Some((trigger_anchor, _, server_to_query, buffer)) = server_lookup.await else {
return anyhow::Ok(());
};
let location_links = if let Some((client, project_id)) = upstream_client {
let buffer_id = buffer.update(cx, |buffer, _| buffer.remote_id())?;
let request = proto::LspExtGoToParentModule {
project_id,
buffer_id: buffer_id.to_proto(),
position: Some(serialize_anchor(&trigger_anchor.text_anchor)),
};
let response = client
.request(request)
.await
.context("lsp ext go to parent module proto request")?;
futures::future::join_all(
response
.links
.into_iter()
.map(|link| location_link_from_proto(link, lsp_store.clone(), cx)),
)
.await
.into_iter()
.collect::<anyhow::Result<_>>()
.context("go to parent module via collab")?
} else {
let buffer_snapshot = buffer.update(cx, |buffer, _| buffer.snapshot())?;
let position = trigger_anchor.text_anchor.to_point_utf16(&buffer_snapshot);
project
.update(cx, |project, cx| {
project.request_lsp(
buffer,
project::LanguageServerToQuery::Other(server_to_query),
project::lsp_store::lsp_ext_command::GoToParentModule { position },
cx,
)
})?
.await
.context("go to parent module")?
};
editor
.update_in(cx, |editor, window, cx| {
editor.navigate_to_hover_links(
Some(GotoDefinitionKind::Declaration),
location_links.into_iter().map(HoverLink::Text).collect(),
false,
window,
cx,
)
})?
.await?;
Ok(())
})
.detach_and_log_err(cx);
}
pub fn expand_macro_recursively(
editor: &mut Editor,
_: &ExpandMacroRecursively,
@@ -213,3 +304,87 @@ pub fn open_docs(editor: &mut Editor, _: &OpenDocs, window: &mut Window, cx: &mu
})
.detach_and_log_err(cx);
}
fn cancel_flycheck_action(
editor: &mut Editor,
_: &CancelFlycheck,
_: &mut Window,
cx: &mut Context<Editor>,
) {
let Some(project) = &editor.project else {
return;
};
let Some(buffer_id) = editor
.selections
.disjoint_anchors()
.iter()
.find_map(|selection| {
let buffer_id = selection.start.buffer_id.or(selection.end.buffer_id)?;
let project = project.read(cx);
let entry_id = project
.buffer_for_id(buffer_id, cx)?
.read(cx)
.entry_id(cx)?;
project.path_for_entry(entry_id, cx)
})
else {
return;
};
cancel_flycheck(project.clone(), buffer_id, cx).detach_and_log_err(cx);
}
fn run_flycheck_action(
editor: &mut Editor,
_: &RunFlycheck,
_: &mut Window,
cx: &mut Context<Editor>,
) {
let Some(project) = &editor.project else {
return;
};
let Some(buffer_id) = editor
.selections
.disjoint_anchors()
.iter()
.find_map(|selection| {
let buffer_id = selection.start.buffer_id.or(selection.end.buffer_id)?;
let project = project.read(cx);
let entry_id = project
.buffer_for_id(buffer_id, cx)?
.read(cx)
.entry_id(cx)?;
project.path_for_entry(entry_id, cx)
})
else {
return;
};
run_flycheck(project.clone(), buffer_id, cx).detach_and_log_err(cx);
}
fn clear_flycheck_action(
editor: &mut Editor,
_: &ClearFlycheck,
_: &mut Window,
cx: &mut Context<Editor>,
) {
let Some(project) = &editor.project else {
return;
};
let Some(buffer_id) = editor
.selections
.disjoint_anchors()
.iter()
.find_map(|selection| {
let buffer_id = selection.start.buffer_id.or(selection.end.buffer_id)?;
let project = project.read(cx);
let entry_id = project
.buffer_for_id(buffer_id, cx)?
.read(cx)
.entry_id(cx)?;
project.path_for_entry(entry_id, cx)
})
else {
return;
};
clear_flycheck(project.clone(), buffer_id, cx).detach_and_log_err(cx);
}

View File

@@ -48,6 +48,7 @@ markdown.workspace = true
node_runtime.workspace = true
pathdiff.workspace = true
paths.workspace = true
pretty_assertions.workspace = true
project.workspace = true
prompt_store.workspace = true
regex.workspace = true

View File

@@ -1,6 +1,7 @@
{
"assistant": {
"always_allow_tool_actions": true,
"stream_edits": true,
"version": "2"
}
}

View File

@@ -52,10 +52,10 @@ struct Args {
#[arg(long, value_delimiter = ',', default_value = "rs,ts")]
languages: Vec<String>,
/// How many times to run each example.
#[arg(long, default_value = "1")]
#[arg(long, default_value = "8")]
repetitions: usize,
/// Maximum number of examples to run concurrently.
#[arg(long, default_value = "10")]
#[arg(long, default_value = "4")]
concurrency: usize,
}
@@ -420,12 +420,18 @@ pub fn init(cx: &mut App) -> Arc<AgentAppState> {
language_model::init(client.clone(), cx);
language_models::init(user_store.clone(), client.clone(), fs.clone(), cx);
languages::init(languages.clone(), node_runtime.clone(), cx);
assistant_tools::init(client.http_client(), cx);
context_server::init(cx);
prompt_store::init(cx);
let stdout_is_a_pty = false;
let prompt_builder = PromptBuilder::load(fs.clone(), stdout_is_a_pty, cx);
agent::init(fs.clone(), client.clone(), prompt_builder.clone(), cx);
agent::init(
fs.clone(),
client.clone(),
prompt_builder.clone(),
languages.clone(),
cx,
);
assistant_tools::init(client.http_client(), cx);
SettingsStore::update_global(cx, |store, cx| {
store.set_user_settings(include_str!("../runner_settings.json"), cx)

View File

@@ -160,7 +160,11 @@ impl ExampleContext {
if left == right {
Ok(())
} else {
println!("{}{:#?} != {:#?}", self.log_prefix, left, right);
println!(
"{}{}",
self.log_prefix,
pretty_assertions::Comparison::new(&left, &right)
);
Err(anyhow::Error::from(FailedAssertion(message.clone())))
},
message,
@@ -334,8 +338,8 @@ impl ExampleContext {
}
pub fn edits(&self) -> HashMap<Arc<Path>, FileEdits> {
self.app
.read_entity(&self.agent_thread, |thread, cx| {
self.agent_thread
.read_with(&self.app, |thread, cx| {
let action_log = thread.action_log().read(cx);
HashMap::from_iter(action_log.changed_buffers(cx).into_iter().map(
|(buffer, diff)| {
@@ -503,16 +507,16 @@ impl ToolUse {
}
}
#[derive(Debug)]
#[derive(Debug, Eq, PartialEq)]
pub struct FileEdits {
hunks: Vec<FileEditHunk>,
pub hunks: Vec<FileEditHunk>,
}
#[derive(Debug)]
struct FileEditHunk {
base_text: String,
text: String,
status: DiffHunkStatus,
#[derive(Debug, Eq, PartialEq)]
pub struct FileEditHunk {
pub base_text: String,
pub text: String,
pub status: DiffHunkStatus,
}
impl FileEdits {

Some files were not shown because too many files have changed in this diff Show More