Compare commits

...

100 Commits

Author SHA1 Message Date
Oleksiy Syvokon
9ca4d3c33d Add reminders to sync prompts with the backend 2025-05-30 11:07:35 +03:00
Michael Sloan
d7f0241d7b editor: Defer the effects of change_selections to end of transact (#31731)
In quite a few places the selection is changed multiple times in a
transaction. For example, `backspace` might do it 3 times:

* `select_autoclose_pair`
* selection of the ranges to delete
* `insert` of empty string also updates selection

Before this change, each of these selection changes appended to
selection history and did a bunch of work that's only relevant to
selections the user actually sees. So for each backspace,
`editor::UndoSelection` would need to be invoked 3-4 times before the
cursor actually moves. It still needs to be run twice after this change,
but that is a separate issue.

Signature help even had a `backspace_pressed: bool` as an incomplete
workaround, to avoid it flickering due to the selection switching
between being a range and being cursor-like.

The original motivation for this change is work I'm doing on not
re-querying completions when the language server provides a response
that has `is_incomplete: false`. Whether the menu is still visible is
determined by the cursor position, and this was complicated by it seeing
`backspace` temporarily moving the head of the selection 1 character to
the left.

This change also removes some redundant uses of
`push_to_selection_history`.

Not super stoked with the name `DeferredSelectionEffectsState`. Naming
is hard.

Release Notes:

- N/A
2025-05-30 01:53:02 -06:00
Cole Miller
1445af559b Unify the tasks modal and the new session modal (#31646)
Release Notes:

- Debugger Beta: added a button to the quick action bar to start a debug
session or spawn a task, depending on which of these actions was taken
most recently.
- Debugger Beta: incorporated the tasks modal into the new session modal
as an additional tab.

---------

Co-authored-by: Julia Ryan <juliaryan3.14@gmail.com>
Co-authored-by: Julia Ryan <p1n3appl3@users.noreply.github.com>
Co-authored-by: Anthony Eid <hello@anthonyeid.me>
Co-authored-by: Mikayla <mikayla@zed.dev>
2025-05-29 21:33:52 -04:00
Anthony Eid
804de3316e debugger: Update docs with more examples (#31597)
This PR also shows more completion items when defining a debug config in
a `debug.json` file. Mainly when using a pre build task argument.

### Follow ups
- Add docs for Go, JS, PHP
- Add attach docs

Release Notes:

- debugger beta: Show build task completions when editing a debug.json
configuration with a pre build task
- debugger beta: Add Python and Native Code debug config
[examples](https://zed.dev/docs/debugger)
2025-05-30 04:22:16 +03:00
Smit Barmase
a387bf5f54 zed: Fix migration banner not hiding after migration has been carried out (#31723)
- https://github.com/zed-industries/zed/pull/30444

This PR broke migration notification to only emit event when content is
migrated. This resulted in the migration banner not going away after
clicking "Backup and Migrate". It should also emit event when it's not
migrated which removes the banner.

Future: I think we should have better tests in place for banner
visibility.

Release Notes:

- Fixed an issue where migration banner wouldn't go away after clicking
"Backup and Migrate".
2025-05-30 06:00:37 +05:30
Marshall Bowers
c7047d5f0a collab: Fully move StripeBilling over to using StripeClient (#31722)
This PR moves over the last method on `StripeBilling` to use the
`StripeClient` trait, allowing us to fully mock out Stripe behaviors for
`StripeBilling` in tests.

Release Notes:

- N/A
2025-05-29 23:49:14 +00:00
Kirill Bulatov
406d975f39 Cleanup corresponding task history on task file update (#31720)
Closes https://github.com/zed-industries/zed/issues/31715

Release Notes:

- Fixed old task history not erased after task file update
2025-05-29 23:02:59 +00:00
Finn Evers
cbed580db0 workspace: Ensure pane handle hitbox blocks mouse events (#31719)
Follow-up to #31712

Pane handle hitboxes were opaque prior to the linked PR. This was the
case because pane handles have an intentionally larger hitbox than the
pane dividers size to allow for easier dragging. The cursor style is
also updated for that hitbox to indicate that resizing is possible:


9086784038/crates/workspace/src/pane_group.rs (L1297-L1301)

Not blocking the mouse events here causes mouse events to bleed through
this hitbox whilst actually any clicks will only cause a pane resize to
happen. Hence, this hitbox should continue to block mouse events to
avoid any confusion when resizing panes.

I considered using `HitboxBehavior::BlockMouseExceptScroll` here,
however, due to the reasons mentioned above, I decided against it. The
cursor will not indicate that scrolling should be possible. Since all
other mouse events on underlying elements (like hovers) are blocked, it
felt more reasonable to just go with `HitboxBehavior::BlockMouse`.

Release Notes:

- N/A
2025-05-29 16:35:22 -06:00
Michael Sloan
8aef64bbfa Remove block_mouse_down in favor of stop_mouse_events_except_scroll (#30401)
This method was added in #20649 to be an alternative of `occlude` which
allows scroll events. It seems a bit arbitrary to only stop left mouse
downs, so this seems like it's probably an improvement.

Release Notes:

- N/A
2025-05-29 22:07:34 +00:00
Michael Sloan
9086784038 gpui: Support hitbox blocking mouse interaction except scrolling (#31712)
tl;dr: This adds `.block_mouse_except_scroll()` which should typically
be used instead of `.occlude()` for cases when the mouse shouldn't
interact with elements drawn below an element. The rationale for
treating scroll events differently:

* Mouse move / click / styles / tooltips are for elements the user is
interacting with directly.
* Mouse scroll events are about finding the current outer scroll
container.

Most use of `occlude` should probably be switched to this, but I figured
I'd derisk this change by minimizing behavior changes to just the 3 uses
of `block_mouse_except_scroll`.

GPUI changes:

* Added `InteractiveElement::block_mouse_except_scroll()`, and removes
`stop_mouse_events_except_scroll()`

* Added `Hitbox::should_handle_scroll()` to be used when handling scroll
wheel events.

* `Window::insert_hitbox` now takes `HitboxBehavior` instead of
`occlude: bool`.

    - `false` for that bool is now `HitboxBehavior::Normal`.

    - `true` for that bool is now `HitboxBehavior::BlockMouse`.
    
    - The new mode is `HitboxBehavior::BlockMouseExceptScroll`.

* Removes `Default` impl for `HitboxId` since applications should not
manually create `HitboxId(0)`.

Release Notes:

- N/A
2025-05-29 21:41:15 +00:00
Kirill Bulatov
2abc5893c1 Improve TypeScript task detection (#31711)
Parses project's package.json to better detect Jasmine, Jest, Vitest and
Mocha and `test`, `build` scripts presence.
Also tries to detect `pnpm` and `npx` as test runners, falls back to
`npm`.


https://github.com/user-attachments/assets/112d3d8b-8daa-4ba5-8cb5-2f483036bd98

Release Notes:

- Improved TypeScript task detection
2025-05-29 20:51:20 +00:00
Marshall Bowers
a23ee61a4b Pass up intent with completion requests (#31710)
This PR adds a new `intent` field to completion requests to assist in
categorizing them correctly.

Release Notes:

- N/A

---------

Co-authored-by: Ben Brandt <benjamin.j.brandt@gmail.com>
2025-05-29 20:43:12 +00:00
Simon Pham
38e45e828b Add View Release Notes to Help menu (#31704)
<img width="891" alt="image"
src="https://github.com/user-attachments/assets/59e98fdb-c1b5-4948-8d69-661561d838f1"
/>



Release Notes:

- Added `View Release Notes` to `Help` menu
2025-05-29 19:39:54 +00:00
Danilo Leal
181bf78b7d agent: Change the navigation menu keybinding (#31709)
As much as I enjoyed the previous keybinding, it was causing a conflict
with the editor where it wouldn't open on text threads. To not get into
a rabbit hole and complicate the fix too much, I figured simply changing
it to something non-conflictual would be a good move.

Release Notes:

- agent: Fixed a bug where the panel navigation menu wouldn't open with
the keybinding on text threads.
2025-05-29 19:31:57 +00:00
Piotr Osiewicz
c42d060509 Update debug.json in Zed repo to run the build on session startup (#31707)
Closes #ISSUE

Release Notes:

- N/A
2025-05-29 21:29:18 +02:00
Peter Tripp
6ea9abdc1b Cursor keymap (#31702)
To use this, spawn `weclome: toggle base keymap selector` from the
command palette.

<img width="589" alt="Screenshot 2025-05-29 at 14 07 35"
src="https://github.com/user-attachments/assets/0d4c4eff-6a3b-40f4-9032-5d8ca7664d20"
/>

MacOS is well tested to match Cursor. The [curors keymap
documentation](https://docs.cursor.com/kbd) is does not explicitly state
windows/linux keymap entries only "All Cmd keys can be replaced with
Ctrl on Windows." so that is what we've done. We welcome feedback /
refinements.

Note, because this provides a mapping for `cmd-k` (macos) and `ctrl-k`
(linux/windows) using this keymap will disable all of the default
chorded keymap entries which have `cmd-k` / `ctrl-k` as a prefix. For
example `cmd-k cmd-s` for open keymap will no longer function.

Release Notes:

- Added Cursor compatibility keymap

---------

Co-authored-by: Joseph Lyons <joseph@zed.dev>
2025-05-29 15:20:58 -04:00
Piotr Osiewicz
070eac28e3 go: Use delve-dap-shim for spawning delve (#31700)
This allows us to support terminal with go sessions

Closes #ISSUE

Release Notes:

- debugger: Add support for terminal when debugging Go programs
2025-05-29 21:19:56 +02:00
Danilo Leal
05692e298a agent: Fix panel "go back" button (#31706)
Closes https://github.com/zed-industries/zed/issues/31652.

Release Notes:

- agent: Fixed a bug where the "go back" button wouldn't go back to the
Text Thread after visiting another view from it.
2025-05-29 16:00:37 -03:00
Richard Feldman
ccb049bd97 Format streamed edits on save (#31623)
Re-enables format on save for agent changes (when the user has that
enabled in settings), except differently from before:
- Now we do the format-on-save in the separate buffer the edit tool
uses, *before* the diff
- This means it never triggers separate staleness
- It has the downside that edits are now blocked on the formatter
completing, but that's true of saving in general.

Release Notes:

- N/A

---------

Co-authored-by: Agus Zubiaga <hi@aguz.me>
2025-05-29 14:33:41 -04:00
Danilo Leal
fe57eedb44 agent: Rename PromptEditor to TextThread in the panel's ActiveView (#31705)
Was touching this part of the Agent Panel and thought it could be a
quick name consistency win here, so it is aligned with the terminology
we're currently actively using in the product/docs.

Release Notes:

- N/A
2025-05-29 15:31:35 -03:00
5brian
c57e6bc784 tab_switcher: Add placeholder text (#31697)
| Before | After |
|---|---|
|<img width="478" alt="image"
src="https://github.com/user-attachments/assets/5baba783-ee31-42cd-9760-7ee19edb1123"
/>|<img width="478" alt="image"
src="https://github.com/user-attachments/assets/1b149500-4a97-4085-80e5-fd628c92471a"
/>|

Release Notes:

- N/A
2025-05-29 16:09:07 +00:00
Piotr Osiewicz
83135e98e6 Introduce $ZED_CUSTOM_PYTHON_ACTIVE_ZED_TOOLCHAIN_RAW to work around (#31685)
Follow up to #31674 

Release Notes:

- N/A

Co-authored-by: Kirill Bulatov <mail4score@gmail.com>
2025-05-29 13:44:55 +00:00
Umesh Yadav
703ee29658 Rename Max Mode to Burn Mode throughout code and docs (#31668)
Follow up to https://github.com/zed-industries/zed/pull/31470.

I started looking at config and changed preferred_completion_mode to
burn to only find its max so made changes to align it better with
rebrand. As this is in preview build now.

This doesn't touch zed_llm_client. Only the Zed changes the code and doc
to match the new UI of burn mode. There are still more things to be
renamed, though.

Release Notes:

- N/A

---------

Signed-off-by: Umesh Yadav <git@umesh.dev>
Co-authored-by: Danilo Leal <daniloleal09@gmail.com>
2025-05-29 13:12:42 +00:00
Kirill Bulatov
f792827a01 Allow to reuse PickerPopoverMenu outside of the model selector (#31684)
LSP button preparation step: move out the component that will be used to
build the button's context menu.

Release Notes:

- N/A
2025-05-29 12:55:47 +00:00
Danilo Leal
45f9edcbb9 docs: Add small refinements to CSS adjacent pages (#31683)
Follow up to https://github.com/zed-industries/zed/pull/31681. Was
visiting some of these pages and noticed these somewhat small formatting
and copywriting improvement opportunities. The docs for Svelte in
particular felt somewhat unorganized.

Release Notes:

- N/A
2025-05-29 08:43:54 -03:00
Danilo Leal
e3354543c0 docs: Improve the Tailwind CSS page (#31681)
Namely, ensuring we mention the support for their Prettier plugins.

Release Notes:

- N/A
2025-05-29 08:15:59 -03:00
Oleksiy Syvokon
cb187b0b4d evals: Configurable number of max dialog turns (#31680)
Release Notes:

- N/A
2025-05-29 10:35:29 +00:00
Kirill Bulatov
d989b2260b Do not react on settings change for disabled minimaps (#31677)
Turning minimap on during debug sessions would cause the console editor
to gain the minimap, despite it being explicitly disabled in the code.

Release Notes:

- N/A
2025-05-29 10:04:27 +00:00
Dhruvin Gandhi
ae076fa415 task: Add ZED_RELATIVE_DIR task variable (#31657)
This is my first contribution to zed, let me know if I missed anything.

There is no corresponding issue/discussion.

`$ZED_RELATIVE_DIR` can be used in cases where a task's command's
filesystem namespace (e.g. inside a container) is different than the
host, where absolute paths cannot work.

I modified `relative_path` to `relative_file` after the addition of
`relative_dir`.

For top-level files, where `relative_file.parent() == Some("")`, I use
`"."` for `$ZED_RELATIVE_DIR`, which is a valid relative path in both
*nix and windows.

Thank you for building zed, and open-sourcing it. I hope to contribute
more as I use it as my primary editor.

Release Notes:

- Added ZED_RELATIVE_DIR (path to current file's directory relative to
worktree root) task variable.
2025-05-29 11:50:36 +02:00
Kirill Bulatov
b4af61edfe Revert "task: Wrap programs in ""s (#31537)" (#31674)
That commit broke a lot, as our one-off tasks (alt-enter in the tasks
modal), npm, jest tasks are all not real commands, but a composition of
commands and arguments.

This reverts commit 5db14d315b.

Closes https://github.com/zed-industries/zed/issues/31554

Release Notes:

- N/A

Co-authored-by: Piotr Osiewicz <24362066+osiewicz@users.noreply.github.com>
2025-05-29 09:19:23 +00:00
Smit Barmase
ea8a3be91b recent_projects: Move SSH server entry to initialize once instead of every render (#31650)
Currently, `RemoteEntry::SshConfig` for `ssh_config_servers` initializes
on every render. This leads to side effects like a new focus handle
being created on every render, which leads to breaking navigating
up/down for `ssh_config_servers` items.

This PR fixes it by moving the logic of remote entry
for`ssh_config_servers` into `default_mode`, and only rebuilding it when
`ssh_config_servers` actually changes.

Before:


https://github.com/user-attachments/assets/8c7187d3-16b5-4f96-aa73-fe4f8227b7d0

After:


https://github.com/user-attachments/assets/21588628-8b1c-43fb-bcb8-0b93c70a1e2b

Release Notes:

- Fixed issue navigating SSH config servers in Remote Projects with
keyboard.
2025-05-29 09:24:39 +05:30
Smit Barmase
5173a1a968 recent_projects: Fix remote projects not regaining focus after SSH server connect (#31651)
Closes #28071

Release Notes:

- Fixed issue preventing remote projects modal from regaining focus
after a successful SSH server connection.
2025-05-29 08:55:29 +05:30
Smit Barmase
87f097a0ab terminal_view: Fix terminal stealing focus on editor selection (#31639)
Closes #28234

Release Notes:

- Fixed the issue where the terminal focused when the mouse hovered over
it after selecting text in the editor.
2025-05-29 08:55:12 +05:30
Cole Miller
f9407db7d6 debugger: Add spinners while session is starting up (#31548)
Release Notes:

- Debugger Beta: Added a spinner to the debug panel when a session is
starting up.

---------

Co-authored-by: Remco Smits <djsmits12@gmail.com>
Co-authored-by: Julia <julia@zed.dev>
2025-05-29 01:58:40 +00:00
Cole Miller
384b11392a debugger: Disambiguate child session labels (#31526)
Add `(child)` instead of using the same label.

Release Notes:

- Debugger Beta: Made child sessions appear distinct from their parents
in the session selector.
2025-05-28 21:44:15 -04:00
Cole Miller
f20596c33b debugger: Don't open non-absolute paths from stack frame list (#31534)
Follow-up to #31524 with a more general fix

Release Notes:

- N/A

---------

Co-authored-by: Piotr <piotr@zed.dev>
2025-05-28 21:44:00 -04:00
Marshall Bowers
eb863f8fd6 collab: Use StripeClient when creating Stripe Checkout sessions (#31644)
This PR updates the `StripeBilling::checkout_with_zed_pro` and
`StripeBilling::checkout_with_zed_pro_trial` methods to use the
`StripeClient` trait instead of using `stripe::Client` directly.

Release Notes:

- N/A
2025-05-29 00:57:04 +00:00
Max Brunsfeld
97579662e6 Fix editor rendering slowness with large folds (#31569)
Closes https://github.com/zed-industries/zed/issues/31565

* Looking up settings on every row was very slow in the case of large
folds, especially if there was an `.editorconfig` file with numerous
glob patterns
* Checking whether each indent guide was within a fold was very slow,
when a fold spanned many indent guides.

Release Notes:

- Fixed slowness that could happen when editing in the presence of large
folds.
2025-05-28 23:05:06 +00:00
Marshall Bowers
53849cf983 collab: Remove Zed Free as an option when initiating a checkout session (#31638)
This PR removes Zed Free as an option when initiating a checkout
session, as we manage this plan automatically now.

Release Notes:

- N/A
2025-05-28 23:00:54 +00:00
Danilo Leal
1e25249055 docs: Adjust the channels page a bit (#31636)
All the docs related to collaboration could use some deep revamp, but
this PR is just formatting tweaks so it doesn't look broken. The images
weren't showing at all!

Release Notes:

- N/A
2025-05-28 19:27:47 -03:00
Marshall Bowers
469824c350 collab: Use StripeClient for creating model usage meter events (#31633)
This PR updates the `StripeBilling::bill_model_request_usage` method to
use the `StripeClient` trait.

Release Notes:

- N/A
2025-05-28 22:19:43 +00:00
Danilo Leal
a1c645e57e docs: Improve footer button design (#31634)
Just touching up these a little bit. I think including the page
destination as a label here is a good move!

Release Notes:

- N/A
2025-05-28 19:16:40 -03:00
Danilo Leal
0791596cda docs: Hide "on this page" element when there are no headings (#31635)
We were still showing the "On this page" element even when the page
didn't contain any h2s or h3s.

Release Notes:

- N/A
2025-05-28 19:16:32 -03:00
Finn Evers
9cc1851be7 python: Improve docstring highlighting (#31628)
This PR broadens the highlighting for docstrings in Python. 

Previously, only the first docstring for e.g. type aliases was
highlighted in Python files. This happened as only the first occurrence
in the module was considered a docstring. With this change, now all
existing docstrings are actually highlighted as such.

| `main` | This PR | 
| --- | --- |
|
![main](https://github.com/user-attachments/assets/facc96a9-4e98-4063-8b93-d6e9884221ff)
|
![PR](https://github.com/user-attachments/assets/9da557a1-b327-466a-be87-65d6a811e24c)
|

Release Notes:

- Added more docstring highlights for Python.
2025-05-29 00:02:40 +02:00
Finn Evers
50bd8770bd file_finder: Reduce vertical padding in footer (#31632)
Follow-up to #31542

This PR reduces the vertical padding in the file finders footer. We can
remove this padding as we already apply it just above


a5a116439e/crates/file_finder/src/file_finder.rs (L1500)

This also ensures that the items on the right side have the same padding
to the border as the icon on the left side. Currently, due to the
padding being applied twice, the items on the right side have `pr_4` as
well as `py_4` in practice, which seems a little excessive.

| `main` | This PR |
| --- | --- |
|
![file_finder_main](https://github.com/user-attachments/assets/352d2ac9-04a9-487d-96ca-b009b797809b)
|
![file_finder_pr](https://github.com/user-attachments/assets/c0b44beb-ff2c-4e93-a5b1-2393652a2a58)
|


Release Notes:

- N/A
2025-05-28 21:29:51 +00:00
Marshall Bowers
00bdebc89d collab: Use StripeClient in StripeBilling::subscribe_to_price (#31631)
This PR updates the `StripeBilling::subscribe_to_price` method to use
the `StripeClient` trait.

Release Notes:

- N/A
2025-05-28 21:17:11 +00:00
Danilo Leal
d5134062ac agent: Add keybinding to toggle Burn Mode (#31630)
One caveat with this PR is that the keybinding still doesn't work for text threads. Will do that in a follow-up.

Release Notes:

- agent: Added a keybinding to toggle Burn Mode on and off.
2025-05-28 18:08:58 -03:00
Julia Ryan
0e9f6986cf nix: Add job names and garnix substitutor (#31625)
This should result in some additional cache hits as I personally use
garnix.

Also added `-v` cachix arg to try to figure out why CI jobs aren't
pushing any paths. Right now they just show ["Pushing is
disabled."](https://github.com/zed-industries/zed/actions/runs/15293723678/job/43018512167#step:13:3)
but I'm not sure if that's due to the `pushFilter` or misconfigured
secrets.

Release Notes:

- N/A
2025-05-28 13:32:12 -07:00
Finn Evers
1035c6aab5 editor: Fix horizontal scrollbar alignment if indent guides are disabled (#31621)
Follow-up to #24887
Follow-up to #31510

This PR ensures that [this misalignment of the horizontal
scrollbar](https://github.com/zed-industries/zed/pull/31510#issuecomment-2912842457)
does not occur. See the entire discussion in the first linked PR as to
why this gap is there in the first place.

I am also aware of the general stance towards comments. Yet, I felt for
this case it is better to just straight up explain how these two things
are connected, as I do believe this is not intuitively clear after all.

Might also be a good time to bring
https://github.com/zed-industries/zed/issues/25519 up again. The
horizontal scrollbar seems huge for the edit file tool card.
Furthermore, since we do not reserve space for the horizontal scrollbar
(yet), this will lead to the last line being not clickable.

Release Notes:

- N/A
2025-05-28 22:59:51 +03:00
Marshall Bowers
75e69a5ae9 collab: Use StripeClient to retrieve prices and meters from Stripe (#31624)
This PR updates `StripeBilling` to use the `StripeClient` trait to
retrieve prices and meters from Stripe instead of using the
`stripe::Client` directly.

Release Notes:

- N/A
2025-05-28 19:51:06 +00:00
Oleksiy Syvokon
05afe95539 agent: Fix bug in creating empty files (#31626)
Release Notes:

- NA
2025-05-28 19:31:54 +00:00
Oleksiy Syvokon
a5a116439e agent: Rejecting agent changes shouldn't discard user edits (#31617)
The fix prevents data loss, but it also results in a somewhat confusing
UX. Specifically, after the user has made changes to an AI-created file,
selecting "Reject" will leave AI changes in place.

This is because there's no trivial way to disentangle user edits from
the edits made by the AI.

A better solution might exist. In the meantime, this change should do.
    
Closes
* #30527 

Release Notes:

- Prevent data loss when reverting changes in an agent-created file
2025-05-28 18:44:49 +00:00
Marshall Bowers
361ceee72b collab: Introduce StripeClient trait to abstract over Stripe interactions (#31615)
This PR introduces a new `StripeClient` trait to abstract over
interacting with the Stripe API.

This will allow us to more easily test our billing code.

This initial cut is small and focuses just on making
`StripeBilling::find_or_create_customer_by_email` testable. I'll follow
up with using the `StripeClient` in more places.

Release Notes:

- N/A
2025-05-28 18:34:44 +00:00
Danilo Leal
68724ea99e agent: Make clicking on the backdrop to dismiss message editing more reliable (#31614)
Previously, the click on the backdrop to dismiss the message editing was
unreliable. You would click on it and sometimes it would work and others
it wouldn't. This PR fixes that now.

Release Notes:

- agent: Fixes the previous message dismissal by clicking on the
backdrop
2025-05-28 15:29:52 -03:00
Danilo Leal
e12106e025 agent: Move focus to the panel after dismissing a user message edit (#31611)
Previously, when you clicked on a previous message to edit it and then
dismissed it, your focus would jump to the buffer. This caught me
several times as the most obvious place to return to for me was the
agent panel main message editor, so I can continue prompting something
else. And this is what this PR changes.

Release Notes:

- agent: Improved previous message editing UX by returning focus to the
main panel's text area after dismissing it.
2025-05-28 15:24:58 -03:00
Umesh Yadav
77aa667bf3 docs: Update LM Studio docs to show tool use is supported (#31610)
As the lmstudio tool call support was added recently:
https://github.com/zed-industries/zed/pull/30589. This updates the doc
to reflect it.

Release Notes:

- N/A
2025-05-28 20:09:20 +02:00
Peter Tripp
8b47b40dc0 Improve AI GitHub Issue template (#31598)
Release Notes:

- N/A
2025-05-28 13:54:07 -04:00
Max Brunsfeld
01990c8375 Bump Tree-sitter to 0.25.5 for YAML-editing crash fix (#31603)
Closes https://github.com/zed-industries/zed/issues/31380

See https://github.com/tree-sitter/tree-sitter/pull/4472 for the fix

Release Notes:

- Fixed a crash that could occur when editing YAML files.
2025-05-28 10:12:27 -07:00
Umesh Yadav
4e7dc37f01 language_models: Remove handling of WrappedTextContent in tool result content (#31605)
Fixes ci pipeline

Release Notes:

- N/A
2025-05-28 16:43:08 +00:00
Richard Feldman
00fd045844 Make language model deserialization more resilient (#31311)
This expands our deserialization of JSON from models to be more tolerant
of different variations that the model may send, including
capitalization, wrapping things in objects vs. being plain strings, etc.

Also when deserialization fails, it reports the entire error in the JSON
so we can see what failed to deserialize. (Previously these errors were
very unhelpful at diagnosing the problem.)

Finally, also removes the `WrappedText` variant since the custom
deserializer just turns that style of JSON into a normal `Text` variant.

Release Notes:

- N/A
2025-05-28 12:06:07 -04:00
Joseph T. Lyons
7443fde4e9 Show version info when downloading and installing updates (#31568)
Follow up to #31179 

In addition to seeing the version when in the `Click to restart and
update Zed` status, this PR allows us to see the version when in
`Downloading Zed update…` or `Installing Zed update…` status, in a
tooltip, when hovering on the activity indicator.

Will merge after tomorrow's release.

Release Notes:

- Added version information, in a tooltip, when hovering on the activity
indicator for both the download and install status.
2025-05-28 11:51:21 -04:00
Joseph T. Lyons
d5ab42aeb8 Clean up some auto updater code (#31543)
This PR simply does a tiny bit of cleanup on some code, where I wasn't
quite happy with the naming and ordering of parameters of the now
`check_if_fetched_version_is_newer` function. There should be no
functional changes here, but I will wait until after tomorrow's release
to merge.

Release Notes:

- N/A
2025-05-28 11:46:41 -04:00
Kirill Bulatov
07403f0b08 Improve LSP tasks ergonomics (#31551)
* stopped fetching LSP tasks for too long (but still use the hardcoded
value for the time being — the LSP tasks settings part is a simple bool
key and it's not very simple to fit in another value there)

* introduced `prefer_lsp` language task settings value, to control
whether in the gutter/modal/both/none LSP tasks are shown exclusively,
if possible

Release Notes:

- Added a way to prefer LSP tasks over Zed tasks
2025-05-28 18:36:25 +03:00
Remco Smits
00bc154c46 debugger: Fix invalid schema for pathMappings (#31595)
See
https://github.com/xdebug/vscode-php-debug?tab=readme-ov-file#remote-host-debugging

Release Notes:

- Debugger Beta: Fixed invalid schema for `pathMappings`
2025-05-28 15:16:12 +00:00
Joseph T. Lyons
f627ac92ee Bump Zed to v0.190 (#31592)
Release Notes:

-N/A
2025-05-28 14:36:50 +00:00
Cole Miller
218e8d09c5 Revert "Fix text wrapping in commit message editors (#31030)" (#31587)
This reverts commit f2601ce52c.

Release Notes:

- N/A
2025-05-28 10:16:34 -04:00
Peter Tripp
2c4b75ab30 Remove agent label for github issues (#31591)
Release Notes:

- N/A
2025-05-28 14:09:35 +00:00
Anthony Eid
aab76208b5 debugger beta: Fix bug where debug Rust main running action failed (#31291)
@osiewicz @SomeoneToIgnore If you guys have time to look this over it
would be greatly appreciated. I wanted to move the bug fix into the task
resolution code but wasn't sure if there was a reason that we didn't
already.

The bug is caused by an env variable being empty when we send it as a
terminal command. When the shell resolves all the env variables there's
an extra space that gets added due to the empty env variable being
placed between two other variables.

Closes #31240

Release Notes:

- debugger beta: Fix a bug where debug main Rust runner action wouldn't
work
2025-05-28 13:59:48 +00:00
Umesh Yadav
f3f0766242 assistant_tools: Remove description.md files of removed tools (#31586)
This pull request removes orphaned description.md files for tools that
were deleted in [PR
#29808](https://github.com/zed-industries/zed/pull/29808). These
descriptions are no longer needed as their corresponding tools no longer
exist
Closes #ISSUE

Release Notes:

- N/A
2025-05-28 09:55:38 -04:00
Ben Brandt
148e9adec2 Revert "agent: Namespace MCP server tools" (#31588)
Reverts zed-industries/zed#30600
2025-05-28 13:25:53 +00:00
Alvaro Parker
e314963f5b agent: Add max mode on text threads (#31361)
Related discussions #30240 #30596

Release Notes:

- Added the ability to use max mode on text threads.

---------

Co-authored-by: Danilo Leal <daniloleal09@gmail.com>
2025-05-28 10:12:38 -03:00
Antonio Scandurra
957e4adc3f Fix lag when interacting with MarkdownElement (#31585)
Previously, we forgot to associate the `Markdown` entity to
`MarkdownElement` during `prepaint`. This caused calls to
`Context<Markdown>::notify` to not invalidate the view cache, which
meant we would have to wait for some other invalidation before seeing
the results of that initial notify.

Release Notes:

- Improved responsiveness of mouse interactions with the agent panel.

Co-authored-by: Ben Brandt <benjamin.j.brandt@gmail.com>
2025-05-28 12:50:45 +00:00
Anthony Eid
fee6f13887 debugger: Fix go locator creating false scenarios (#31583)
This caused other locators to fail because go would accept build tasks
that it couldn't actually resolve

Release Notes:

- N/A
2025-05-28 12:34:14 +00:00
Antonio Scandurra
4f78165ee8 Show progress as the agent locates which range it needs to edit (#31582)
Release Notes:

- Improved latency when the agent starts streaming edits.

---------

Co-authored-by: Ben Brandt <benjamin.j.brandt@gmail.com>
2025-05-28 12:32:54 +00:00
Alex
94a5fe265d debugger: Improve Go support (#31559)
Supersedes https://github.com/zed-industries/zed/pull/31345 
This PR does not have any terminal/console related stuff so that it can
be solved separately.

Introduces inline hints in debugger:
<img width="1141" alt="image"
src="https://github.com/user-attachments/assets/b0575f1e-ddf8-41fe-8958-2da6d4974912"
/>
Adds locators for go, so that you can your app in debug mode:
<img width="706" alt="image"
src="https://github.com/user-attachments/assets/df29bba5-8264-4bea-976f-686c32a5605b"
/>
As well is allows you to specify an existing compiled binary:
<img width="604" alt="image"
src="https://github.com/user-attachments/assets/548f2ab5-88c1-41fb-af84-115a19e685ea"
/>

Release Notes:

- Added inline value hints for Go debugging, displaying variable values
directly in the editor during debug sessions
- Added Go debug locator support, enabling debugging of Go applications
through task templates
- Improved Go debug adapter to support both source debugging (mode:
"debug") and binary execution (mode: "exec") based on program path

cc @osiewicz, @Anthony-Eid
2025-05-28 12:59:05 +02:00
Piotr Osiewicz
c0a5ace8b8 debugger: Add locator for Python tasks (#31533)
Closes #ISSUE

Release Notes:

- debugger: Python tests/main functions can now we debugged from the
gutter.

---------

Co-authored-by: Kirill Bulatov <kirill@zed.dev>
2025-05-28 12:27:12 +02:00
Smit Barmase
15d59fcda9 vim: Fix crash when using ‘ge’ motion on multibyte character (#31566)
Closes #30919

- [x] Test

Release Notes:

- Fixed the issue where using the Vim motion `ge` on multibyte character
would cause Zed to crash.
2025-05-28 06:30:51 +05:30
Smit Barmase
6545c5ebe0 linux: Fix crash when switching repository via git panel (#31556)
Closes #30409

Handles edge case where `f32` turns into `Nan` and causes panic down the
code.

Release Notes:

- Fixed issue where Zed crashes on switching repository via git panel on
Linux.
2025-05-28 05:26:00 +05:30
Michael Sloan
506beafe10 Add caching of parsed completion documentation markdown to reduce flicker when selecting (#31546)
Related to #31460 and #28635.

Release Notes:

- Fixed redraw delay of documentation from language server completions
and added caching to reduce flicker when using arrow keys to change
selection.
2025-05-27 23:12:38 +00:00
tongjicoder
31d908fc74 Remove redundant words in comments (#31512)
remove redundant word in comment


Release Notes:

- N/A

Signed-off-by: tongjicoder <tongjicoder@icloud.com>
2025-05-27 23:01:31 +00:00
Danilo Leal
0731097ee5 agent: Improve consecutive tool call UX and rebrand Max Mode (#31470)
This PR improves the consecutive tool call UX by allowing users to
quickly continue an interrupted with one-click. What we do here is
insert a hidden "Continue" message that will just nudge the LLM to keep
going. We're also using the opportunity to upsell the previously called
"Max Mode", now rebranded as "Burn Mode", which allows users to don't be
interrupted anymore if they ever have 25 consecutive tool calls again.

Release Notes:

- agent: Improve consecutive tool call UX by allowing users to quickly
continue an interrupted thread with one click.

---------

Co-authored-by: Ben Brandt <benjamin.j.brandt@gmail.com>
Co-authored-by: Agus Zubiaga <hi@aguz.me>
Co-authored-by: Agus Zubiaga <agus@zed.dev>
2025-05-27 19:44:10 -03:00
Finn Evers
233b73b385 ui: Implement hover color for scrollbar component (#25525)
This PR implements color changing for the scrollbar component based upon
user mouse interaction.


https://github.com/user-attachments/assets/2fd14e2d-cc5c-4272-906e-bd39bfb007e4


This PR also already adds the state for a scrollbar being actively
dragged. However, as themes currently do not provide a color for this
scenario, this implementation re-uses the hover color as a placeholder
instead. If this feature is at all wanted, I can quickly open up a
follow-up PR which adds support for that property to themes as well as
this component.

Release Notes:

- Added hover state to scrollbars outside of the editor.
2025-05-27 18:16:04 -04:00
Marshall Bowers
0145e2c101 inline_completion_button: Fix links to account page (#31558)
This PR fixes an issue where the various links to the account page from
the Edit Prediction menu were not working.

The `OpenZedUrl` action is opening URLs that deep-link _into_ Zed.

Fixes https://github.com/zed-industries/zed/issues/31060.

Release Notes:

- Fixed an issue with opening links to the Zed account page from the
Edit Prediction menu.
2025-05-27 21:52:42 +00:00
Marshall Bowers
09fc64e0c5 collab: Downgrade non-collab queries to READ COMMITTED isolation level (#31552)
This PR downgrades a number of database queries that aren't part of the
actual collaboration from `SERIALIZABLE` to `READ COMMITTED`.

The serializable isolation level is overkill for these queries.

Release Notes:

- N/A
2025-05-27 17:02:27 -04:00
Marshall Bowers
fc803ce9d4 collab: Increase max database connections to 250 (#31553)
This PR increases the number of max database connections to 250.

Release Notes:

- N/A
2025-05-27 16:48:50 -04:00
Max Brunsfeld
697c2ba71f Enable merge conflict parsing for currently-unmerged files (#31549)
Previously, we only enabled merge conflict parsing for files that were
unmerged at the last time a change was detected to the repo's merge
heads. Now we enable the parsing for these files *and* any files that
are currently unmerged.

The old strategy meant that conflicts produced via `git stash pop` would
not be parsed.

Release Notes:

- Fixed parsing of merge conflicts when the conflict was produced by a
`git stash pop`
2025-05-27 13:34:39 -07:00
Richard Feldman
f54c057001 Add warning message when editing a message in a thread (#31508)
<img width="479" alt="Screenshot 2025-05-27 at 9 42 44 AM"
src="https://github.com/user-attachments/assets/7bd9e1b9-26b4-4396-9f93-e92a5f4ac2e1"
/>

Release Notes:

- Added notice that editing a message in the agent panel will restart
the thread from that point.

---------

Co-authored-by: Danilo Leal <daniloleal09@gmail.com>
2025-05-27 16:26:47 -04:00
Marshall Bowers
32848e9c8a collab: Add support for overage billing for Claude Opus 4 (#31544)
This PR adds support for billing for overages for Claude Opus 4.

Release Notes:

- N/A
2025-05-27 18:37:57 +00:00
Anthony Eid
86b75759d1 debugger beta: Autoscroll to recently saved debug scenario when saving a scenario (#31528)
I added a test to this too as one of my first steps of improving
`NewSessionModal`'s test coverage.


Release Notes:

- debugger beta: Select saved debug config when opening debug.json from
`NewSessionModal`
2025-05-27 21:35:17 +03:00
Kirill Bulatov
94c006236e Properly handle ignored files in the file finder (#31542)
Follow-up of https://github.com/zed-industries/zed/pull/31457

Add a button and also allows to use `search::ToggleIncludeIgnored`
action in the file finder to toggle whether to show gitignored files or
not.
By default, returns back to the gitignored treatment before the PR
above.


![image](https://github.com/user-attachments/assets/c3117488-9c51-4b34-b630-42098fe14b4d)


Release Notes:

- Improved file finder to include indexed gitignored files in its search
results
2025-05-27 18:34:28 +00:00
Julia Ryan
5b6b911946 nix: Refactor gh-actions and re-enable nightly builds (#31489)
Now that the nix build is working again, re-enable nightly builds and
refactor the workflow for re-use between nightly releases and CI jobs.

Release Notes:

- N/A

---------

Co-authored-by: Rahul Butani <rrbutani@users.noreply.github.com>
2025-05-27 11:34:15 -07:00
Ben Kunkle
b9a5d437db Cursor settings import (#31424)
Closes #ISSUE

Release Notes:

- Added support for importing settings from cursor. Cursor settings can
be imported using the `zed: import cursor settings` command from the
command palette
2025-05-27 14:14:25 -04:00
Bennet Bo Fenner
21bd91a773 agent: Namespace MCP server tools (#30600)
This fixes an issue where requests were failing when MCP servers were
registering tools with the same name.
We now prefix the tool names with the context server name, in the UI we
still show the name that the MCP server gives us

Release Notes:

- agent: Fix an error were requests would fail if two MCP servers were
using an identical tool name
2025-05-27 17:47:44 +00:00
Piotr Osiewicz
5db14d315b task: Wrap programs in ""s (#31537)
This commit effectively re-implements #21981 in task system. commands
with spaces cannot be spawned currently, and we don't want to have to
deal with shell variables wrapped in "" in DAP locators.

Closes #ISSUE

Release Notes:

- Fixed an issue where tasks with spaces in `command` field could not be
spawned.
2025-05-27 19:33:16 +02:00
Anthony Eid
b63cea1f17 debugger beta: Fix gdb/delve JSON data conversion from New Session Modal (#31501)
test that check's that each conversion works properly based on the
adapter's config validation function. 

Co-authored-by: Zed AI \<ai@zed.dev\>

Release Notes:

- debugger beta: Fix bug where Go/GDB configuration's wouldn't work from
NewSessionModal
2025-05-27 17:28:41 +00:00
5brian
b7c5540075 git_ui: Replace spaces with hyphens in new branch names (#27873)
This PR improves UX by converting spaces to hyphens, following branch
naming conventions and allowing users to create branches without
worrying about naming restrictions.

I think a few other git tools do this, which was nice.



![image](https://github.com/user-attachments/assets/db40ec31-e461-4ab3-a3de-e249559994fc)

Release Notes:

- Updated the branch picker to convert spaces to hyphens when creating
new branch names.

---------

Co-authored-by: Marshall Bowers <git@maxdeviant.com>
2025-05-27 17:10:20 +00:00
Cole Miller
b01f7c848b Make it possible to use cargo-zigbuild for ZED_BUILD_REMOTE_SERVER (#31467)
This is significantly faster for me than using Cross.

Release Notes:

- N/A
2025-05-27 16:56:27 +00:00
Finn Evers
3476705bbb docs_preprocessor: Ensure keybind is found for actions with arguments (#27224)
Tried fixing a keybind in
https://github.com/zed-industries/zed/pull/27217 just to find out it
[still doesnt render
afterwards](https://zed.dev/docs/extensions/languages#language-metadata)
😅 This PR is a quick follow-up to fix this issue.

Issue here is (as seen in the code comment) that the
`editor::ToggleComments` command has additional arguments which caused
the match to fail. However, simply adding the missing arguments does not
work, since the regex only matches the first closing brace and fails to
match multiple closing braces. I decided against changing the matching
since it additionally looked confusing and unintuitive to use.

To not be too intrusive with this change, I just decided to add some
processing for the action string (the `KeymapAction` is not exported
from the settings and the `Value` it holds is also private). The
processing basically reverts the conversion done in `keymap_file.rs`
4b5df2189b/crates/settings/src/keymap_file.rs (L102-L115)
and extracts just the action name. It changes nothing for existing
keybinds and fixes the aforementioned issue.

Release Notes:

- N/A

---------

Co-authored-by: Marshall Bowers <git@maxdeviant.com>
2025-05-27 16:56:03 +00:00
张小白
ba6b5a59f9 windows: Fix title bar not responsing (#31532)
Closes #31431

Release Notes:

- N/A
2025-05-27 16:43:48 +00:00
Antonio Scandurra
28d6362964 Revert "Highlight file finder entries according to their git status" (#31529)
Reverts zed-industries/zed#31469

This isn't looking great, so reverting for now.

/cc @SomeoneToIgnore
2025-05-27 16:10:49 +00:00
233 changed files with 10088 additions and 3478 deletions

View File

@@ -1,8 +1,8 @@
name: Bug Report (Agent Panel)
name: Bug Report (AI Related)
description: Zed Agent Panel Bugs
type: "Bug"
labels: ["agent", "ai"]
title: "Agent Panel: <a short description of the Agent Panel bug>"
labels: ["ai"]
title: "AI: <a short description of the AI Related bug>"
body:
- type: textarea
attributes:
@@ -14,7 +14,6 @@ body:
### Description
<!-- Describe with sufficient detail to reproduce from a clean Zed install. -->
<!-- Please include the LLM provider and model name you are using -->
Steps to trigger the problem:
1.
2.
@@ -22,6 +21,13 @@ body:
Actual Behavior:
Expected Behavior:
### Model Provider Details
- Provider: (Anthropic via ZedPro, Anthropic via API key, Copilot Chat, Mistral, OpenAI, etc)
- Model Name:
- Mode: (Agent Panel, Inline Assistant, Terminal Assistant or Text Threads)
- MCP Servers in-use:
- Other Details:
validations:
required: true

View File

@@ -714,48 +714,13 @@ jobs:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
nix-build:
timeout-minutes: 60
name: Nix Build
continue-on-error: true
name: Build with Nix
uses: ./.github/workflows/nix.yml
if: github.repository_owner == 'zed-industries' && contains(github.event.pull_request.labels.*.name, 'run-nix')
strategy:
fail-fast: false
matrix:
system:
- os: x86 Linux
runner: buildjet-16vcpu-ubuntu-2204
install_nix: true
- os: arm Mac
runner: [macOS, ARM64, test]
install_nix: false
runs-on: ${{ matrix.system.runner }}
env:
ZED_CLIENT_CHECKSUM_SEED: ${{ secrets.ZED_CLIENT_CHECKSUM_SEED }}
GIT_LFS_SKIP_SMUDGE: 1 # breaks the livekit rust sdk examples which we don't actually depend on
steps:
- name: Checkout repo
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4
with:
clean: false
- name: Set path
if: ${{ ! matrix.system.install_nix }}
run: |
echo "/nix/var/nix/profiles/default/bin" >> $GITHUB_PATH
echo "/Users/administrator/.nix-profile/bin" >> $GITHUB_PATH
- uses: cachix/install-nix-action@d1ca217b388ee87b2507a9a93bf01368bde7cec2 # v31
if: ${{ matrix.system.install_nix }}
with:
github_access_token: ${{ secrets.GITHUB_TOKEN }}
- uses: cachix/cachix-action@0fc020193b5a1fa3ac4575aa3a7d3aa6a35435ad # v16
with:
name: zed-industries
authToken: "${{ secrets.CACHIX_AUTH_TOKEN }}"
skipPush: true
- run: nix build .#debug
- name: Limit /nix/store to 50GB
run: "[ $(du -sm /nix/store | cut -f1) -gt 50000 ] && nix-collect-garbage -d"
with:
flake-output: debug
# excludes the final package to only cache dependencies
cachix-filter: "-zed-editor-[0-9.]*-nightly"
auto-release-preview:
name: Auto release preview

66
.github/workflows/nix.yml vendored Normal file
View File

@@ -0,0 +1,66 @@
name: "Nix build"
on:
workflow_call:
inputs:
flake-output:
type: string
default: "default"
cachix-filter:
type: string
default: ""
jobs:
nix-build:
timeout-minutes: 60
name: (${{ matrix.system.os }}) Nix Build
continue-on-error: true # TODO: remove when we want this to start blocking CI
strategy:
fail-fast: false
matrix:
system:
- os: x86 Linux
runner: buildjet-16vcpu-ubuntu-2204
install_nix: true
- os: arm Mac
runner: [macOS, ARM64, test]
install_nix: false
if: github.repository_owner == 'zed-industries'
runs-on: ${{ matrix.system.runner }}
env:
ZED_CLIENT_CHECKSUM_SEED: ${{ secrets.ZED_CLIENT_CHECKSUM_SEED }}
ZED_CLOUD_PROVIDER_ADDITIONAL_MODELS_JSON: ${{ secrets.ZED_CLOUD_PROVIDER_ADDITIONAL_MODELS_JSON }}
GIT_LFS_SKIP_SMUDGE: 1 # breaks the livekit rust sdk examples which we don't actually depend on
steps:
- name: Checkout repo
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4
with:
clean: false
# on our macs we manually install nix. for some reason the cachix action is running
# under a non-login /bin/bash shell which doesn't source the proper script to add the
# nix profile to PATH, so we manually add them here
- name: Set path
if: ${{ ! matrix.system.install_nix }}
run: |
echo "/nix/var/nix/profiles/default/bin" >> $GITHUB_PATH
echo "/Users/administrator/.nix-profile/bin" >> $GITHUB_PATH
- uses: cachix/install-nix-action@02a151ada4993995686f9ed4f1be7cfbb229e56f # v31
if: ${{ matrix.system.install_nix }}
with:
github_access_token: ${{ secrets.GITHUB_TOKEN }}
- uses: cachix/cachix-action@0fc020193b5a1fa3ac4575aa3a7d3aa6a35435ad # v16
with:
name: zed
authToken: "${{ secrets.CACHIX_AUTH_TOKEN }}"
pushFilter: "${{ inputs.cachix-filter }}"
cachixArgs: '-v'
- run: nix build .#${{ inputs.flake-output }} -L --accept-flake-config
- name: Limit /nix/store to 50GB on macs
if: ${{ ! matrix.system.install_nix }}
run: |
[ $(du -sm /nix/store | cut -f1) -gt 50000 ] && nix-collect-garbage -d || :

View File

@@ -167,6 +167,11 @@ jobs:
- name: Upload Zed Nightly
run: script/upload-nightly linux-targz
bundle-nix:
name: Build and cache Nix package
needs: tests
uses: ./.github/workflows/nix.yml
update-nightly-tag:
name: Update nightly tag
if: github.repository_owner == 'zed-industries'

View File

@@ -2,16 +2,11 @@
{
"label": "Debug Zed (CodeLLDB)",
"adapter": "CodeLLDB",
"program": "$ZED_WORKTREE_ROOT/target/debug/zed",
"request": "launch"
"build": { "label": "Build Zed", "command": "cargo", "args": ["build"] }
},
{
"label": "Debug Zed (GDB)",
"adapter": "GDB",
"program": "$ZED_WORKTREE_ROOT/target/debug/zed",
"request": "launch",
"initialize_args": {
"stopAtBeginningOfMainSubprogram": true
}
"build": { "label": "Build Zed", "command": "cargo", "args": ["build"] }
}
]

27
Cargo.lock generated
View File

@@ -559,6 +559,7 @@ dependencies = [
"workspace",
"workspace-hack",
"zed_actions",
"zed_llm_client",
]
[[package]]
@@ -658,9 +659,9 @@ name = "assistant_tools"
version = "0.1.0"
dependencies = [
"agent_settings",
"aho-corasick",
"anyhow",
"assistant_tool",
"async-watch",
"buffer_diff",
"chrono",
"client",
@@ -683,6 +684,7 @@ dependencies = [
"language_model",
"language_models",
"log",
"lsp",
"markdown",
"open",
"paths",
@@ -4031,6 +4033,8 @@ dependencies = [
"smol",
"task",
"telemetry",
"tree-sitter",
"tree-sitter-go",
"util",
"workspace-hack",
"zlog",
@@ -4730,6 +4734,7 @@ dependencies = [
"tree-sitter-rust",
"tree-sitter-typescript",
"ui",
"unicode-script",
"unicode-segmentation",
"unindent",
"url",
@@ -5043,6 +5048,7 @@ dependencies = [
"util",
"uuid",
"workspace-hack",
"zed_llm_client",
]
[[package]]
@@ -5381,8 +5387,10 @@ dependencies = [
"language",
"menu",
"picker",
"pretty_assertions",
"project",
"schemars",
"search",
"serde",
"serde_derive",
"serde_json",
@@ -6143,6 +6151,7 @@ dependencies = [
"workspace",
"workspace-hack",
"zed_actions",
"zed_llm_client",
"zlog",
]
@@ -8925,6 +8934,7 @@ dependencies = [
"async-compression",
"async-tar",
"async-trait",
"chrono",
"collections",
"dap",
"futures 0.3.31",
@@ -8978,6 +8988,7 @@ dependencies = [
"tree-sitter-yaml",
"unindent",
"util",
"which 6.0.3",
"workspace",
"workspace-hack",
]
@@ -9565,6 +9576,7 @@ dependencies = [
"assets",
"base64 0.22.1",
"env_logger 0.11.8",
"futures 0.3.31",
"gpui",
"language",
"languages",
@@ -15576,6 +15588,7 @@ dependencies = [
"futures 0.3.31",
"gpui",
"hex",
"log",
"parking_lot",
"pretty_assertions",
"proto",
@@ -16475,9 +16488,9 @@ dependencies = [
[[package]]
name = "tree-sitter"
version = "0.25.3"
version = "0.25.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b9ac5ea5e7f2f1700842ec071401010b9c59bf735295f6e9fa079c3dc035b167"
checksum = "ac5fff5c47490dfdf473b5228039bfacad9d765d9b6939d26bf7cc064c1c7822"
dependencies = [
"cc",
"regex",
@@ -17110,8 +17123,6 @@ dependencies = [
"tempfile",
"tendril",
"unicase",
"unicode-script",
"unicode-segmentation",
"util_macros",
"walkdir",
"workspace-hack",
@@ -19675,7 +19686,7 @@ dependencies = [
[[package]]
name = "zed"
version = "0.189.0"
version = "0.190.0"
dependencies = [
"activity_indicator",
"agent",
@@ -19871,9 +19882,9 @@ dependencies = [
[[package]]
name = "zed_llm_client"
version = "0.8.3"
version = "0.8.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "22a8b9575b215536ed8ad254ba07171e4e13bd029eda3b54cca4b184d2768050"
checksum = "de7d9523255f4e00ee3d0918e5407bd252d798a4a8e71f6d37f23317a1588203"
dependencies = [
"anyhow",
"serde",

View File

@@ -572,7 +572,7 @@ tokio = { version = "1" }
tokio-tungstenite = { version = "0.26", features = ["__rustls-tls"] }
toml = "0.8"
tower-http = "0.4.4"
tree-sitter = { version = "0.25.3", features = ["wasm"] }
tree-sitter = { version = "0.25.5", features = ["wasm"] }
tree-sitter-bash = "0.23"
tree-sitter-c = "0.23"
tree-sitter-cpp = "0.23"
@@ -617,7 +617,7 @@ wasmtime = { version = "29", default-features = false, features = [
wasmtime-wasi = "29"
which = "6.0.0"
workspace-hack = "0.1.0"
zed_llm_client = "0.8.3"
zed_llm_client = "0.8.4"
zstd = "0.11"
[workspace.dependencies.async-stripe]

View File

@@ -0,0 +1,3 @@
<svg width="14" height="14" viewBox="0 0 14 14" fill="none" xmlns="http://www.w3.org/2000/svg">
<path d="M4.99207 8.14741C5.37246 8.14741 5.73726 7.9963 6.00623 7.72733C6.27521 7.45836 6.42631 7.09355 6.42631 6.71317C6.42631 5.92147 6.13946 5.56578 5.85262 4.99208C5.23761 3.76265 5.72411 2.66631 7.00001 1.5499C7.28686 2.98414 8.1474 4.36101 9.2948 5.27893C10.4422 6.19684 11.0159 7.28687 11.0159 8.43426C11.0159 8.96163 10.912 9.48384 10.7102 9.97107C10.5084 10.4583 10.2126 10.901 9.83967 11.2739C9.46676 11.6468 9.02405 11.9426 8.53682 12.1444C8.04959 12.3463 7.52738 12.4501 7.00001 12.4501C6.47264 12.4501 5.95043 12.3463 5.4632 12.1444C4.97597 11.9426 4.53326 11.6468 4.16035 11.2739C3.78745 10.901 3.49164 10.4583 3.28982 9.97107C3.088 9.48384 2.98413 8.96163 2.98413 8.43426C2.98413 7.77279 3.23254 7.1182 3.55783 6.71317C3.55783 7.09355 3.70894 7.45836 3.97791 7.72733C4.24688 7.9963 4.61169 8.14741 4.99207 8.14741Z" stroke="black" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"/>
</svg>

After

Width:  |  Height:  |  Size: 1018 B

View File

@@ -0,0 +1,13 @@
<svg width="14" height="14" viewBox="0 0 14 14" fill="none" xmlns="http://www.w3.org/2000/svg">
<g clip-path="url(#clip0_2595_5640)">
<path d="M4.99207 8.14741C5.37246 8.14741 5.73726 7.9963 6.00623 7.72733C6.27521 7.45836 6.42631 7.09355 6.42631 6.71317C6.42631 5.92147 6.13946 5.56578 5.85262 4.99208C5.23761 3.76265 5.72411 2.66631 7.00001 1.5499C7.28686 2.98414 8.1474 4.36101 9.2948 5.27893C10.4422 6.19684 11.0159 7.28687 11.0159 8.43426C11.0159 8.96163 10.912 9.48384 10.7102 9.97107C10.5084 10.4583 10.2126 10.901 9.83967 11.2739C9.46676 11.6468 9.02405 11.9426 8.53682 12.1444C8.04959 12.3463 7.52738 12.4501 7.00001 12.4501C6.47264 12.4501 5.95043 12.3463 5.4632 12.1444C4.97597 11.9426 4.53326 11.6468 4.16035 11.2739C3.78745 10.901 3.49164 10.4583 3.28982 9.97107C3.088 9.48384 2.98413 8.96163 2.98413 8.43426C2.98413 7.77279 3.23254 7.1182 3.55783 6.71317C3.55783 7.09355 3.70894 7.45836 3.97791 7.72733C4.24688 7.9963 4.61169 8.14741 4.99207 8.14741Z" fill="black" fill-opacity="0.5" stroke="black" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"/>
<path d="M2 4C2.55228 4 3 3.55228 3 3C3 2.44772 2.55228 2 2 2C1.44772 2 1 2.44772 1 3C1 3.55228 1.44772 4 2 4Z" fill="black"/>
<path d="M10 2C10.5523 2 11 1.55228 11 1C11 0.44772 10.5523 0 10 0C9.44772 0 9 0.44772 9 1C9 1.55228 9.44772 2 10 2Z" fill="black"/>
<path d="M13 5C13.5522 5 14 4.55228 14 4C14 3.44772 13.5522 3 13 3C12.4478 3 12 3.44772 12 4C12 4.55228 12.4478 5 13 5Z" fill="black"/>
</g>
<defs>
<clipPath id="clip0_2595_5640">
<rect width="14" height="14" fill="white"/>
</clipPath>
</defs>
</svg>

After

Width:  |  Height:  |  Size: 1.6 KiB

View File

@@ -1,14 +0,0 @@
<svg width="16" height="16" viewBox="0 0 16 16" fill="none" xmlns="http://www.w3.org/2000/svg">
<g clip-path="url(#clip0_2489_484)">
<path d="M11 8.9V11C8.51716 11 7.48284 11 5 11V10.4L11 5.6V5H5V7.1" stroke="black" stroke-width="1.5"/>
<path d="M1.5 5.5V1.5H5" stroke="black" stroke-opacity="0.5" stroke-width="1.5"/>
<path d="M14.5 5.5V1.5H11" stroke="black" stroke-opacity="0.5" stroke-width="1.5"/>
<path d="M1.5 10.5V14.5H5" stroke="black" stroke-opacity="0.5" stroke-width="1.5"/>
<path d="M14.5 10.5V14.5H11" stroke="black" stroke-opacity="0.5" stroke-width="1.5"/>
</g>
<defs>
<clipPath id="clip0_2489_484">
<rect width="16" height="16" fill="white"/>
</clipPath>
</defs>
</svg>

Before

Width:  |  Height:  |  Size: 687 B

View File

@@ -127,9 +127,7 @@
"shift-f10": "editor::OpenContextMenu",
"ctrl-shift-e": "editor::ToggleEditPrediction",
"f9": "editor::ToggleBreakpoint",
"shift-f9": "editor::EditLogBreakpoint",
"ctrl-shift-backspace": "editor::GoToPreviousChange",
"ctrl-shift-alt-backspace": "editor::GoToNextChange"
"shift-f9": "editor::EditLogBreakpoint"
}
},
{
@@ -148,6 +146,8 @@
"ctrl->": "assistant::QuoteSelection",
"ctrl-<": "assistant::InsertIntoEditor",
"ctrl-alt-e": "editor::SelectEnclosingSymbol",
"ctrl-shift-backspace": "editor::GoToPreviousChange",
"ctrl-shift-alt-backspace": "editor::GoToNextChange",
"alt-enter": "editor::OpenSelectionsInMultibuffer"
}
},
@@ -244,11 +244,14 @@
"ctrl-i": "agent::ToggleProfileSelector",
"ctrl-alt-/": "agent::ToggleModelSelector",
"ctrl-shift-a": "agent::ToggleContextPicker",
"ctrl-shift-o": "agent::ToggleNavigationMenu",
"ctrl-shift-j": "agent::ToggleNavigationMenu",
"ctrl-shift-i": "agent::ToggleOptionsMenu",
"shift-alt-escape": "agent::ExpandMessageEditor",
"ctrl-alt-e": "agent::RemoveAllContext",
"ctrl-shift-e": "project_panel::ToggleFocus"
"ctrl-shift-e": "project_panel::ToggleFocus",
"ctrl-shift-enter": "agent::ContinueThread",
"alt-enter": "agent::ContinueWithBurnMode",
"ctrl-alt-b": "agent::ToggleBurnMode"
}
},
{
@@ -1016,5 +1019,12 @@
"bindings": {
"enter": "menu::Confirm"
}
},
{
"context": "RunModal",
"bindings": {
"ctrl-tab": "pane::ActivateNextItem",
"ctrl-shift-tab": "pane::ActivatePreviousItem"
}
}
]

View File

@@ -279,11 +279,14 @@
"cmd-i": "agent::ToggleProfileSelector",
"cmd-alt-/": "agent::ToggleModelSelector",
"cmd-shift-a": "agent::ToggleContextPicker",
"cmd-shift-o": "agent::ToggleNavigationMenu",
"cmd-shift-j": "agent::ToggleNavigationMenu",
"cmd-shift-i": "agent::ToggleOptionsMenu",
"shift-alt-escape": "agent::ExpandMessageEditor",
"cmd-alt-e": "agent::RemoveAllContext",
"cmd-shift-e": "project_panel::ToggleFocus"
"cmd-shift-e": "project_panel::ToggleFocus",
"cmd-shift-enter": "agent::ContinueThread",
"alt-enter": "agent::ContinueWithBurnMode",
"cmd-alt-b": "agent::ToggleBurnMode"
}
},
{
@@ -543,9 +546,7 @@
"cmd-\\": "pane::SplitRight",
"cmd-k v": "markdown::OpenPreviewToTheSide",
"cmd-shift-v": "markdown::OpenPreview",
"ctrl-cmd-c": "editor::DisplayCursorNames",
"cmd-shift-backspace": "editor::GoToPreviousChange",
"cmd-shift-alt-backspace": "editor::GoToNextChange"
"ctrl-cmd-c": "editor::DisplayCursorNames"
}
},
{
@@ -553,7 +554,9 @@
"use_key_equivalents": true,
"bindings": {
"cmd-shift-o": "outline::Toggle",
"ctrl-g": "go_to_line::Toggle"
"ctrl-g": "go_to_line::Toggle",
"cmd-shift-backspace": "editor::GoToPreviousChange",
"cmd-shift-alt-backspace": "editor::GoToNextChange"
}
},
{
@@ -1106,5 +1109,13 @@
"bindings": {
"enter": "menu::Confirm"
}
},
{
"context": "RunModal",
"use_key_equivalents": true,
"bindings": {
"ctrl-tab": "pane::ActivateNextItem",
"ctrl-shift-tab": "pane::ActivatePreviousItem"
}
}
]

View File

@@ -0,0 +1,85 @@
[
// Cursor for MacOS. See: https://docs.cursor.com/kbd
{
"context": "Workspace",
"use_key_equivalents": true,
"bindings": {
"ctrl-i": "agent::ToggleFocus",
"ctrl-shift-i": "agent::ToggleFocus",
"ctrl-l": "agent::ToggleFocus",
"ctrl-shift-l": "agent::ToggleFocus",
"ctrl-alt-b": "agent::ToggleFocus",
"ctrl-shift-j": "agent::OpenConfiguration"
}
},
{
"context": "Editor && mode == full",
"use_key_equivalents": true,
"bindings": {
"ctrl-i": "agent::ToggleFocus",
"ctrl-shift-i": "agent::ToggleFocus",
"ctrl-shift-l": "assistant::QuoteSelection", // In cursor uses "Ask" mode
"ctrl-l": "assistant::QuoteSelection", // In cursor uses "Agent" mode
"ctrl-k": "assistant::InlineAssist",
"ctrl-shift-k": "assistant::InsertIntoEditor"
}
},
{
"context": "InlineAssistEditor",
"use_key_equivalents": true,
"bindings": {
"ctrl-shift-backspace": "editor::Cancel"
// "alt-enter": // Quick Question
// "ctrl-shift-enter": // Full File Context
// "ctrl-shift-k": // Toggle input focus (editor <> inline assist)
}
},
{
"context": "AgentPanel || ContextEditor || (MessageEditor > Editor)",
"use_key_equivalents": true,
"bindings": {
"ctrl-i": "workspace::ToggleRightDock",
"ctrl-shift-i": "workspace::ToggleRightDock",
"ctrl-l": "workspace::ToggleRightDock",
"ctrl-shift-l": "workspace::ToggleRightDock",
"ctrl-alt-b": "workspace::ToggleRightDock",
"ctrl-w": "workspace::ToggleRightDock", // technically should close chat
"ctrl-.": "agent::ToggleProfileSelector",
"ctrl-/": "agent::ToggleModelSelector",
"ctrl-shift-backspace": "editor::Cancel",
"ctrl-r": "agent::NewThread",
"ctrl-shift-v": "editor::Paste",
"ctrl-shift-k": "assistant::InsertIntoEditor"
// "escape": "agent::ToggleFocus"
///// Enable when Zed supports multiple thread tabs
// "ctrl-t": // new thread tab
// "ctrl-[": // next thread tab
// "ctrl-]": // next thread tab
///// Enable if Zed adds support for keyboard navigation of thread elements
// "tab": // cycle to next message
// "shift-tab": // cycle to previous message
}
},
{
"context": "Editor && editor_agent_diff",
"use_key_equivalents": true,
"bindings": {
"ctrl-enter": "agent::KeepAll",
"ctrl-backspace": "agent::RejectAll"
}
},
{
"context": "Editor && mode == full && edit_prediction",
"use_key_equivalents": true,
"bindings": {
"ctrl-right": "editor::AcceptPartialEditPrediction"
}
},
{
"context": "Terminal",
"use_key_equivalents": true,
"bindings": {
"ctrl-k": "assistant::InlineAssist"
}
}
]

View File

@@ -0,0 +1,85 @@
[
// Cursor for MacOS. See: https://docs.cursor.com/kbd
{
"context": "Workspace",
"use_key_equivalents": true,
"bindings": {
"cmd-i": "agent::ToggleFocus",
"cmd-shift-i": "agent::ToggleFocus",
"cmd-l": "agent::ToggleFocus",
"cmd-shift-l": "agent::ToggleFocus",
"cmd-alt-b": "agent::ToggleFocus",
"cmd-shift-j": "agent::OpenConfiguration"
}
},
{
"context": "Editor && mode == full",
"use_key_equivalents": true,
"bindings": {
"cmd-i": "agent::ToggleFocus",
"cmd-shift-i": "agent::ToggleFocus",
"cmd-shift-l": "assistant::QuoteSelection", // In cursor uses "Ask" mode
"cmd-l": "assistant::QuoteSelection", // In cursor uses "Agent" mode
"cmd-k": "assistant::InlineAssist",
"cmd-shift-k": "assistant::InsertIntoEditor"
}
},
{
"context": "InlineAssistEditor",
"use_key_equivalents": true,
"bindings": {
"cmd-shift-backspace": "editor::Cancel"
// "alt-enter": // Quick Question
// "cmd-shift-enter": // Full File Context
// "cmd-shift-k": // Toggle input focus (editor <> inline assist)
}
},
{
"context": "AgentPanel || ContextEditor || (MessageEditor > Editor)",
"use_key_equivalents": true,
"bindings": {
"cmd-i": "workspace::ToggleRightDock",
"cmd-shift-i": "workspace::ToggleRightDock",
"cmd-l": "workspace::ToggleRightDock",
"cmd-shift-l": "workspace::ToggleRightDock",
"cmd-alt-b": "workspace::ToggleRightDock",
"cmd-w": "workspace::ToggleRightDock", // technically should close chat
"cmd-.": "agent::ToggleProfileSelector",
"cmd-/": "agent::ToggleModelSelector",
"cmd-shift-backspace": "editor::Cancel",
"cmd-r": "agent::NewThread",
"cmd-shift-v": "editor::Paste",
"cmd-shift-k": "assistant::InsertIntoEditor"
// "escape": "agent::ToggleFocus"
///// Enable when Zed supports multiple thread tabs
// "cmd-t": // new thread tab
// "cmd-[": // next thread tab
// "cmd-]": // next thread tab
///// Enable if Zed adds support for keyboard navigation of thread elements
// "tab": // cycle to next message
// "shift-tab": // cycle to previous message
}
},
{
"context": "Editor && editor_agent_diff",
"use_key_equivalents": true,
"bindings": {
"cmd-enter": "agent::KeepAll",
"cmd-backspace": "agent::RejectAll"
}
},
{
"context": "Editor && mode == full && edit_prediction",
"use_key_equivalents": true,
"bindings": {
"cmd-right": "editor::AcceptPartialEditPrediction"
}
},
{
"context": "Terminal",
"use_key_equivalents": true,
"bindings": {
"cmd-k": "assistant::InlineAssist"
}
}
]

View File

@@ -1,3 +1,6 @@
{{!----------------------------------------------------------------------------------
NOTE: Changes to this prompt require a symmetric update in monorepo/request_kind.rs
------------------------------------------------------------------------------------}}
{{#if language_name}}
Here's a file of {{language_name}} that I'm going to ask you to make an edit to.
{{else}}

View File

@@ -1,3 +1,6 @@
{{!----------------------------------------------------------------------------------
NOTE: Changes to this prompt require a symmetric update in monorepo/request_kind.rs
------------------------------------------------------------------------------------}}
You are an expert terminal user.
You will be given a description of a command and you need to respond with a command that matches the description.
Do not include markdown blocks or any other text formatting in your response, always respond with a single command that can be executed in the given shell.

View File

@@ -714,7 +714,7 @@
"version": "2",
// Whether the agent is enabled.
"enabled": true,
/// What completion mode to start new threads in, if available. Can be 'normal' or 'max'.
/// What completion mode to start new threads in, if available. Can be 'normal' or 'burn'.
"preferred_completion_mode": "normal",
// Whether to show the agent panel button in the status bar.
"button": true,
@@ -961,7 +961,15 @@
// Default: true
"skip_focus_for_active_in_search": true,
// Whether to show the git status in the file finder.
"git_status": true
"git_status": true,
// Whether to use gitignored files when searching.
// Only the file Zed had indexed will be used, not necessary all the gitignored files.
//
// Can accept 3 values:
// * `true`: Use all gitignored files
// * `false`: Use only the files Zed had indexed
// * `null`: Be smart and search for ignored when called from a gitignored worktree
"include_ignored": null
},
// Whether or not to remove any trailing whitespace from lines of a buffer
// before saving it.
@@ -1306,7 +1314,17 @@
// Settings related to running tasks.
"tasks": {
"variables": {},
"enabled": true
"enabled": true,
// Use LSP tasks over Zed language extension ones.
// If no LSP tasks are returned due to error/timeout or regular execution,
// Zed language extension tasks will be used instead.
//
// Other Zed tasks will still be shown:
// * Zed task from either of the task config file
// * Zed task from history (e.g. one-off task was spawned before)
//
// Default: true
"prefer_lsp": true
},
// An object whose keys are language names, and whose values
// are arrays of filenames or extensions of files that should
@@ -1444,9 +1462,7 @@
"language_servers": ["erlang-ls", "!elp", "..."]
},
"Git Commit": {
"allow_rewrap": "anywhere",
"preferred_line_length": 72,
"soft_wrap": "bounded"
"allow_rewrap": "anywhere"
},
"Go": {
"code_actions_on_format": {

View File

@@ -311,6 +311,31 @@ impl ActivityIndicator {
});
}
if let Some(session) = self
.project
.read(cx)
.dap_store()
.read(cx)
.sessions()
.find(|s| !s.read(cx).is_started())
{
return Some(Content {
icon: Some(
Icon::new(IconName::ArrowCircle)
.size(IconSize::Small)
.with_animation(
"arrow-circle",
Animation::new(Duration::from_secs(2)).repeat(),
|icon, delta| icon.transform(Transformation::rotate(percentage(delta))),
)
.into_any_element(),
),
message: format!("Debug: {}", session.read(cx).adapter()),
tooltip_message: Some(session.read(cx).label().to_string()),
on_click: None,
});
}
let current_job = self
.project
.read(cx)
@@ -472,7 +497,7 @@ impl ActivityIndicator {
})),
tooltip_message: None,
}),
AutoUpdateStatus::Downloading => Some(Content {
AutoUpdateStatus::Downloading { version } => Some(Content {
icon: Some(
Icon::new(IconName::Download)
.size(IconSize::Small)
@@ -482,9 +507,9 @@ impl ActivityIndicator {
on_click: Some(Arc::new(|this, window, cx| {
this.dismiss_error_message(&DismissErrorMessage, window, cx)
})),
tooltip_message: None,
tooltip_message: Some(Self::version_tooltip_message(&version)),
}),
AutoUpdateStatus::Installing => Some(Content {
AutoUpdateStatus::Installing { version } => Some(Content {
icon: Some(
Icon::new(IconName::Download)
.size(IconSize::Small)
@@ -494,7 +519,7 @@ impl ActivityIndicator {
on_click: Some(Arc::new(|this, window, cx| {
this.dismiss_error_message(&DismissErrorMessage, window, cx)
})),
tooltip_message: None,
tooltip_message: Some(Self::version_tooltip_message(&version)),
}),
AutoUpdateStatus::Updated {
binary_path,
@@ -508,7 +533,7 @@ impl ActivityIndicator {
};
move |_, _, cx| workspace::reload(&reload, cx)
})),
tooltip_message: Some(Self::install_version_tooltip_message(&version)),
tooltip_message: Some(Self::version_tooltip_message(&version)),
}),
AutoUpdateStatus::Errored => Some(Content {
icon: Some(
@@ -548,8 +573,8 @@ impl ActivityIndicator {
None
}
fn install_version_tooltip_message(version: &VersionCheckType) -> String {
format!("Install version: {}", {
fn version_tooltip_message(version: &VersionCheckType) -> String {
format!("Version: {}", {
match version {
auto_update::VersionCheckType::Sha(sha) => format!("{}", sha.short()),
auto_update::VersionCheckType::Semantic(semantic_version) => {
@@ -699,17 +724,17 @@ mod tests {
use super::*;
#[test]
fn test_install_version_tooltip_message() {
let message = ActivityIndicator::install_version_tooltip_message(
&VersionCheckType::Semantic(SemanticVersion::new(1, 0, 0)),
);
fn test_version_tooltip_message() {
let message = ActivityIndicator::version_tooltip_message(&VersionCheckType::Semantic(
SemanticVersion::new(1, 0, 0),
));
assert_eq!(message, "Install version: 1.0.0");
assert_eq!(message, "Version: 1.0.0");
let message = ActivityIndicator::install_version_tooltip_message(&VersionCheckType::Sha(
let message = ActivityIndicator::version_tooltip_message(&VersionCheckType::Sha(
AppCommitSha::new("14d9a4189f058d8736339b06ff2340101eaea5af".to_string()),
));
assert_eq!(message, "Install version: 14d9a41…");
assert_eq!(message, "Version: 14d9a41…");
}
}

View File

@@ -55,6 +55,7 @@ use util::ResultExt as _;
use util::markdown::MarkdownCodeBlock;
use workspace::{CollaboratorId, Workspace};
use zed_actions::assistant::OpenRulesLibrary;
use zed_llm_client::CompletionIntent;
pub struct ActiveThread {
context_store: Entity<ContextStore>,
@@ -1436,6 +1437,7 @@ impl ActiveThread {
let request = language_model::LanguageModelRequest {
thread_id: None,
prompt_id: None,
intent: None,
mode: None,
messages: vec![request_message],
tools: vec![],
@@ -1533,9 +1535,22 @@ impl ActiveThread {
});
}
fn cancel_editing_message(&mut self, _: &menu::Cancel, _: &mut Window, cx: &mut Context<Self>) {
fn cancel_editing_message(
&mut self,
_: &menu::Cancel,
window: &mut Window,
cx: &mut Context<Self>,
) {
self.editing_message.take();
cx.notify();
if let Some(workspace) = self.workspace.upgrade() {
workspace.update(cx, |workspace, cx| {
if let Some(panel) = workspace.panel::<AgentPanel>(cx) {
panel.focus_handle(cx).focus(window);
}
});
}
}
fn confirm_editing_message(
@@ -1597,7 +1612,12 @@ impl ActiveThread {
this.thread.update(cx, |thread, cx| {
thread.advance_prompt_id();
thread.send_to_model(model.model, Some(window.window_handle()), cx);
thread.send_to_model(
model.model,
CompletionIntent::UserPrompt,
Some(window.window_handle()),
cx,
);
});
this._load_edited_message_context_task = None;
cx.notify();
@@ -1778,6 +1798,11 @@ impl ActiveThread {
let Some(message) = self.thread.read(cx).message(message_id) else {
return Empty.into_any();
};
if message.is_hidden {
return Empty.into_any();
}
let message_creases = message.creases.clone();
let Some(rendered_message) = self.rendered_messages_by_id.get(&message_id) else {
@@ -1813,6 +1838,7 @@ impl ActiveThread {
let colors = cx.theme().colors();
let editor_bg_color = colors.editor_background;
let panel_bg = colors.panel_background;
let open_as_markdown = IconButton::new(("open-as-markdown", ix), IconName::DocumentText)
.icon_size(IconSize::XSmall)
@@ -1833,7 +1859,6 @@ impl ActiveThread {
const RESPONSE_PADDING_X: Pixels = px(19.);
let show_feedback = thread.is_turn_end(ix);
let feedback_container = h_flex()
.group("feedback_container")
.mt_1()
@@ -2006,65 +2031,89 @@ impl ActiveThread {
.border_1()
.border_color(colors.border)
.hover(|hover| hover.border_color(colors.text_accent.opacity(0.5)))
.cursor_pointer()
.child(
h_flex()
v_flex()
.p_2p5()
.gap_1()
.items_end()
.children(message_content)
.when_some(editing_message_state, |this, state| {
let focus_handle = state.editor.focus_handle(cx).clone();
this.w_full().justify_between().child(
this.child(
h_flex()
.gap_0p5()
.w_full()
.gap_1()
.justify_between()
.flex_wrap()
.child(
IconButton::new(
"cancel-edit-message",
IconName::Close,
)
.shape(ui::IconButtonShape::Square)
.icon_color(Color::Error)
.icon_size(IconSize::Small)
.tooltip({
let focus_handle = focus_handle.clone();
move |window, cx| {
Tooltip::for_action_in(
"Cancel Edit",
&menu::Cancel,
&focus_handle,
window,
cx,
)
}
})
.on_click(cx.listener(Self::handle_cancel_click)),
h_flex()
.gap_1p5()
.child(
div()
.opacity(0.8)
.child(
Icon::new(IconName::Warning)
.size(IconSize::Indicator)
.color(Color::Warning)
),
)
.child(
Label::new("Editing will restart the thread from this point.")
.color(Color::Muted)
.size(LabelSize::XSmall),
),
)
.child(
IconButton::new(
"confirm-edit-message",
IconName::Return,
)
.disabled(state.editor.read(cx).is_empty(cx))
.shape(ui::IconButtonShape::Square)
.icon_color(Color::Muted)
.icon_size(IconSize::Small)
.tooltip({
let focus_handle = focus_handle.clone();
move |window, cx| {
Tooltip::for_action_in(
"Regenerate",
&menu::Confirm,
&focus_handle,
window,
cx,
h_flex()
.gap_0p5()
.child(
IconButton::new(
"cancel-edit-message",
IconName::Close,
)
}
})
.on_click(
cx.listener(Self::handle_regenerate_click),
),
),
.shape(ui::IconButtonShape::Square)
.icon_color(Color::Error)
.icon_size(IconSize::Small)
.tooltip({
let focus_handle = focus_handle.clone();
move |window, cx| {
Tooltip::for_action_in(
"Cancel Edit",
&menu::Cancel,
&focus_handle,
window,
cx,
)
}
})
.on_click(cx.listener(Self::handle_cancel_click)),
)
.child(
IconButton::new(
"confirm-edit-message",
IconName::Return,
)
.disabled(state.editor.read(cx).is_empty(cx))
.shape(ui::IconButtonShape::Square)
.icon_color(Color::Muted)
.icon_size(IconSize::Small)
.tooltip({
let focus_handle = focus_handle.clone();
move |window, cx| {
Tooltip::for_action_in(
"Regenerate",
&menu::Confirm,
&focus_handle,
window,
cx,
)
}
})
.on_click(
cx.listener(Self::handle_regenerate_click),
),
),
)
)
}),
)
@@ -2106,16 +2155,14 @@ impl ActiveThread {
message_id > *editing_message_id
});
let panel_background = cx.theme().colors().panel_background;
let backdrop = div()
.id("backdrop")
.stop_mouse_events_except_scroll()
.id(("backdrop", ix))
.size_full()
.absolute()
.inset_0()
.size_full()
.bg(panel_background)
.bg(panel_bg)
.opacity(0.8)
.block_mouse_except_scroll()
.on_click(cx.listener(Self::handle_cancel_click));
v_flex()
@@ -3662,7 +3709,8 @@ mod tests {
// Stream response to user message
thread.update(cx, |thread, cx| {
let request = thread.to_completion_request(model.clone(), cx);
let request =
thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx);
thread.stream_completion(request, model, cx.active_window(), cx)
});
// Follow the agent

View File

@@ -87,6 +87,9 @@ actions!(
Follow,
ResetTrialUpsell,
ResetTrialEndUpsell,
ContinueThread,
ContinueWithBurnMode,
ToggleBurnMode,
]
);

View File

@@ -699,7 +699,7 @@ fn render_diff_hunk_controls(
.rounded_b_md()
.bg(cx.theme().colors().editor_background)
.gap_1()
.stop_mouse_events_except_scroll()
.block_mouse_except_scroll()
.shadow_md()
.children(vec![
Button::new(("reject", row as u64), "Reject")

View File

@@ -1,10 +1,11 @@
use agent_settings::AgentSettings;
use fs::Fs;
use gpui::{Entity, FocusHandle, SharedString};
use picker::popover_menu::PickerPopoverMenu;
use crate::Thread;
use assistant_context_editor::language_model_selector::{
LanguageModelSelector, LanguageModelSelectorPopoverMenu, ToggleModelSelector,
LanguageModelSelector, ToggleModelSelector, language_model_selector,
};
use language_model::{ConfiguredModel, LanguageModelRegistry};
use settings::update_settings_file;
@@ -35,7 +36,7 @@ impl AgentModelSelector {
Self {
selector: cx.new(move |cx| {
let fs = fs.clone();
LanguageModelSelector::new(
language_model_selector(
{
let model_type = model_type.clone();
move |cx| match &model_type {
@@ -100,15 +101,14 @@ impl AgentModelSelector {
}
impl Render for AgentModelSelector {
fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
fn render(&mut self, window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
let focus_handle = self.focus_handle.clone();
let model = self.selector.read(cx).active_model(cx);
let model = self.selector.read(cx).delegate.active_model(cx);
let model_name = model
.map(|model| model.model.name().0)
.unwrap_or_else(|| SharedString::from("No model selected"));
LanguageModelSelectorPopoverMenu::new(
PickerPopoverMenu::new(
self.selector.clone(),
Button::new("active-model", model_name)
.label_size(LabelSize::Small)
@@ -127,7 +127,9 @@ impl Render for AgentModelSelector {
)
},
gpui::Corner::BottomRight,
cx,
)
.with_handle(self.menu_handle.clone())
.render(window, cx)
}
}

View File

@@ -7,7 +7,7 @@ use std::time::Duration;
use db::kvp::{Dismissable, KEY_VALUE_STORE};
use serde::{Deserialize, Serialize};
use agent_settings::{AgentDockPosition, AgentSettings, DefaultView};
use agent_settings::{AgentDockPosition, AgentSettings, CompletionMode, DefaultView};
use anyhow::{Result, anyhow};
use assistant_context_editor::{
AgentPanelDelegate, AssistantContext, ConfigurationError, ContextEditor, ContextEvent,
@@ -41,8 +41,8 @@ use theme::ThemeSettings;
use time::UtcOffset;
use ui::utils::WithRemSize;
use ui::{
Banner, CheckboxWithLabel, ContextMenu, KeyBinding, PopoverMenu, PopoverMenuHandle,
ProgressBar, Tab, Tooltip, Vector, VectorName, prelude::*,
Banner, CheckboxWithLabel, ContextMenu, ElevationIndex, KeyBinding, PopoverMenu,
PopoverMenuHandle, ProgressBar, Tab, Tooltip, Vector, VectorName, prelude::*,
};
use util::{ResultExt as _, maybe};
use workspace::dock::{DockPosition, Panel, PanelEvent};
@@ -52,7 +52,7 @@ use workspace::{
use zed_actions::agent::{OpenConfiguration, OpenOnboardingModal, ResetOnboarding};
use zed_actions::assistant::{OpenRulesLibrary, ToggleFocus};
use zed_actions::{DecreaseBufferFontSize, IncreaseBufferFontSize, ResetBufferFontSize};
use zed_llm_client::UsageLimit;
use zed_llm_client::{CompletionIntent, UsageLimit};
use crate::active_thread::{self, ActiveThread, ActiveThreadEvent};
use crate::agent_configuration::{AgentConfiguration, AssistantConfigurationEvent};
@@ -64,10 +64,11 @@ use crate::thread_history::{HistoryEntryElement, ThreadHistory};
use crate::thread_store::ThreadStore;
use crate::ui::AgentOnboardingModal;
use crate::{
AddContextServer, AgentDiffPane, ContextStore, DeleteRecentlyOpenThread, ExpandMessageEditor,
Follow, InlineAssistant, NewTextThread, NewThread, OpenActiveThreadAsMarkdown, OpenAgentDiff,
OpenHistory, ResetTrialEndUpsell, ResetTrialUpsell, TextThreadStore, ThreadEvent,
ToggleContextPicker, ToggleNavigationMenu, ToggleOptionsMenu,
AddContextServer, AgentDiffPane, ContextStore, ContinueThread, ContinueWithBurnMode,
DeleteRecentlyOpenThread, ExpandMessageEditor, Follow, InlineAssistant, NewTextThread,
NewThread, OpenActiveThreadAsMarkdown, OpenAgentDiff, OpenHistory, ResetTrialEndUpsell,
ResetTrialUpsell, TextThreadStore, ThreadEvent, ToggleBurnMode, ToggleContextPicker,
ToggleNavigationMenu, ToggleOptionsMenu,
};
const AGENT_PANEL_KEY: &str = "agent_panel";
@@ -173,7 +174,7 @@ enum ActiveView {
thread: WeakEntity<Thread>,
_subscriptions: Vec<gpui::Subscription>,
},
PromptEditor {
TextThread {
context_editor: Entity<ContextEditor>,
title_editor: Entity<Editor>,
buffer_search_bar: Entity<BufferSearchBar>,
@@ -193,7 +194,7 @@ impl ActiveView {
pub fn which_font_size_used(&self) -> WhichFontSize {
match self {
ActiveView::Thread { .. } | ActiveView::History => WhichFontSize::AgentFont,
ActiveView::PromptEditor { .. } => WhichFontSize::BufferFont,
ActiveView::TextThread { .. } => WhichFontSize::BufferFont,
ActiveView::Configuration => WhichFontSize::None,
}
}
@@ -332,7 +333,7 @@ impl ActiveView {
buffer_search_bar.set_active_pane_item(Some(&context_editor), window, cx)
});
Self::PromptEditor {
Self::TextThread {
context_editor,
title_editor: editor,
buffer_search_bar,
@@ -1083,9 +1084,23 @@ impl AgentPanel {
pub fn go_back(&mut self, _: &workspace::GoBack, window: &mut Window, cx: &mut Context<Self>) {
match self.active_view {
ActiveView::Configuration | ActiveView::History => {
self.active_view =
ActiveView::thread(self.thread.read(cx).thread().clone(), window, cx);
self.message_editor.focus_handle(cx).focus(window);
if let Some(previous_view) = self.previous_view.take() {
self.active_view = previous_view;
match &self.active_view {
ActiveView::Thread { .. } => {
self.message_editor.focus_handle(cx).focus(window);
}
ActiveView::TextThread { context_editor, .. } => {
context_editor.focus_handle(cx).focus(window);
}
_ => {}
}
} else {
self.active_view =
ActiveView::thread(self.thread.read(cx).thread().clone(), window, cx);
self.message_editor.focus_handle(cx).focus(window);
}
cx.notify();
}
_ => {}
@@ -1283,9 +1298,52 @@ impl AgentPanel {
matches!(self.active_view, ActiveView::Thread { .. })
}
fn continue_conversation(&mut self, window: &mut Window, cx: &mut Context<Self>) {
let thread_state = self.thread.read(cx).thread().read(cx);
if !thread_state.tool_use_limit_reached() {
return;
}
let model = thread_state.configured_model().map(|cm| cm.model.clone());
if let Some(model) = model {
self.thread.update(cx, |active_thread, cx| {
active_thread.thread().update(cx, |thread, cx| {
thread.insert_invisible_continue_message(cx);
thread.advance_prompt_id();
thread.send_to_model(
model,
CompletionIntent::UserPrompt,
Some(window.window_handle()),
cx,
);
});
});
} else {
log::warn!("No configured model available for continuation");
}
}
fn toggle_burn_mode(
&mut self,
_: &ToggleBurnMode,
_window: &mut Window,
cx: &mut Context<Self>,
) {
self.thread.update(cx, |active_thread, cx| {
active_thread.thread().update(cx, |thread, _cx| {
let current_mode = thread.completion_mode();
thread.set_completion_mode(match current_mode {
CompletionMode::Burn => CompletionMode::Normal,
CompletionMode::Normal => CompletionMode::Burn,
});
});
});
}
pub(crate) fn active_context_editor(&self) -> Option<Entity<ContextEditor>> {
match &self.active_view {
ActiveView::PromptEditor { context_editor, .. } => Some(context_editor.clone()),
ActiveView::TextThread { context_editor, .. } => Some(context_editor.clone()),
_ => None,
}
}
@@ -1308,6 +1366,12 @@ impl AgentPanel {
let current_is_history = matches!(self.active_view, ActiveView::History);
let new_is_history = matches!(new_view, ActiveView::History);
let current_is_config = matches!(self.active_view, ActiveView::Configuration);
let new_is_config = matches!(new_view, ActiveView::Configuration);
let current_is_special = current_is_history || current_is_config;
let new_is_special = new_is_history || new_is_config;
match &self.active_view {
ActiveView::Thread { thread, .. } => {
if let Some(thread) = thread.upgrade() {
@@ -1319,7 +1383,7 @@ impl AgentPanel {
}
}
}
ActiveView::PromptEditor { context_editor, .. } => {
ActiveView::TextThread { context_editor, .. } => {
let context = context_editor.read(cx).context();
// When switching away from an unsaved text thread, delete its entry.
if context.read(cx).path().is_none() {
@@ -1339,7 +1403,7 @@ impl AgentPanel {
store.push_recently_opened_entry(RecentEntry::Thread(id, thread), cx);
}
}),
ActiveView::PromptEditor { context_editor, .. } => {
ActiveView::TextThread { context_editor, .. } => {
self.history_store.update(cx, |store, cx| {
let context = context_editor.read(cx).context().clone();
store.push_recently_opened_entry(RecentEntry::Context(context), cx)
@@ -1348,12 +1412,12 @@ impl AgentPanel {
_ => {}
}
if current_is_history && !new_is_history {
if current_is_special && !new_is_special {
self.active_view = new_view;
} else if !current_is_history && new_is_history {
} else if !current_is_special && new_is_special {
self.previous_view = Some(std::mem::replace(&mut self.active_view, new_view));
} else {
if !new_is_history {
if !new_is_special {
self.previous_view = None;
}
self.active_view = new_view;
@@ -1368,7 +1432,7 @@ impl Focusable for AgentPanel {
match &self.active_view {
ActiveView::Thread { .. } => self.message_editor.focus_handle(cx),
ActiveView::History => self.history.focus_handle(cx),
ActiveView::PromptEditor { context_editor, .. } => context_editor.focus_handle(cx),
ActiveView::TextThread { context_editor, .. } => context_editor.focus_handle(cx),
ActiveView::Configuration => {
if let Some(configuration) = self.configuration.as_ref() {
configuration.focus_handle(cx)
@@ -1520,7 +1584,7 @@ impl AgentPanel {
.into_any_element(),
}
}
ActiveView::PromptEditor {
ActiveView::TextThread {
title_editor,
context_editor,
..
@@ -1612,7 +1676,7 @@ impl AgentPanel {
let show_token_count = match &self.active_view {
ActiveView::Thread { .. } => !is_empty || !editor_empty,
ActiveView::PromptEditor { .. } => true,
ActiveView::TextThread { .. } => true,
_ => false,
};
@@ -1928,7 +1992,7 @@ impl AgentPanel {
Some(token_count)
}
ActiveView::PromptEditor { context_editor, .. } => {
ActiveView::TextThread { context_editor, .. } => {
let element = render_remaining_tokens(context_editor, cx)?;
Some(element.into_any_element())
@@ -2574,7 +2638,11 @@ impl AgentPanel {
})
}
fn render_tool_use_limit_reached(&self, cx: &mut Context<Self>) -> Option<AnyElement> {
fn render_tool_use_limit_reached(
&self,
window: &mut Window,
cx: &mut Context<Self>,
) -> Option<AnyElement> {
let tool_use_limit_reached = self
.thread
.read(cx)
@@ -2593,17 +2661,59 @@ impl AgentPanel {
.configured_model()?
.model;
let max_mode_upsell = if model.supports_max_mode() {
" Enable max mode for unlimited tool use."
} else {
""
};
let focus_handle = self.focus_handle(cx);
let banner = Banner::new()
.severity(ui::Severity::Info)
.child(h_flex().child(Label::new(format!(
"Consecutive tool use limit reached.{max_mode_upsell}"
))));
.child(Label::new("Consecutive tool use limit reached.").size(LabelSize::Small))
.action_slot(
h_flex()
.gap_1()
.child(
Button::new("continue-conversation", "Continue")
.layer(ElevationIndex::ModalSurface)
.label_size(LabelSize::Small)
.key_binding(
KeyBinding::for_action_in(
&ContinueThread,
&focus_handle,
window,
cx,
)
.map(|kb| kb.size(rems_from_px(10.))),
)
.on_click(cx.listener(|this, _, window, cx| {
this.continue_conversation(window, cx);
})),
)
.when(model.supports_max_mode(), |this| {
this.child(
Button::new("continue-burn-mode", "Continue with Burn Mode")
.style(ButtonStyle::Filled)
.style(ButtonStyle::Tinted(ui::TintColor::Accent))
.layer(ElevationIndex::ModalSurface)
.label_size(LabelSize::Small)
.key_binding(
KeyBinding::for_action_in(
&ContinueWithBurnMode,
&focus_handle,
window,
cx,
)
.map(|kb| kb.size(rems_from_px(10.))),
)
.tooltip(Tooltip::text("Enable Burn Mode for unlimited tool use."))
.on_click(cx.listener(|this, _, window, cx| {
this.thread.update(cx, |active_thread, cx| {
active_thread.thread().update(cx, |thread, _cx| {
thread.set_completion_mode(CompletionMode::Burn);
});
});
this.continue_conversation(window, cx);
})),
)
}),
);
Some(div().px_2().pb_2().child(banner).into_any_element())
}
@@ -2800,7 +2910,7 @@ impl AgentPanel {
) -> Div {
let mut registrar = buffer_search::DivRegistrar::new(
|this, _, _cx| match &this.active_view {
ActiveView::PromptEditor {
ActiveView::TextThread {
buffer_search_bar, ..
} => Some(buffer_search_bar.clone()),
_ => None,
@@ -2918,7 +3028,7 @@ impl AgentPanel {
.detach();
});
}
ActiveView::PromptEditor { context_editor, .. } => {
ActiveView::TextThread { context_editor, .. } => {
context_editor.update(cx, |context_editor, cx| {
ContextEditor::insert_dragged_files(
context_editor,
@@ -2945,7 +3055,7 @@ impl AgentPanel {
fn key_context(&self) -> KeyContext {
let mut key_context = KeyContext::new_with_defaults();
key_context.add("AgentPanel");
if matches!(self.active_view, ActiveView::PromptEditor { .. }) {
if matches!(self.active_view, ActiveView::TextThread { .. }) {
key_context.add("prompt_editor");
}
key_context
@@ -2958,9 +3068,9 @@ impl Render for AgentPanel {
// non-obvious implications to the layout of children.
//
// If you need to change it, please confirm:
// - The message editor expands (esc) correctly
// - The message editor expands (cmd-option-esc) correctly
// - When expanded, the buttons at the bottom of the panel are displayed correctly
// - Font size works as expected and can be changed with ⌘+/⌘-
// - Font size works as expected and can be changed with cmd-+/cmd-
// - Scrolling in all views works as expected
// - Files can be dropped into the panel
let content = v_flex()
@@ -2987,6 +3097,18 @@ impl Render for AgentPanel {
.on_action(cx.listener(Self::decrease_font_size))
.on_action(cx.listener(Self::reset_font_size))
.on_action(cx.listener(Self::toggle_zoom))
.on_action(cx.listener(|this, _: &ContinueThread, window, cx| {
this.continue_conversation(window, cx);
}))
.on_action(cx.listener(|this, _: &ContinueWithBurnMode, window, cx| {
this.thread.update(cx, |active_thread, cx| {
active_thread.thread().update(cx, |thread, _cx| {
thread.set_completion_mode(CompletionMode::Burn);
});
});
this.continue_conversation(window, cx);
}))
.on_action(cx.listener(Self::toggle_burn_mode))
.child(self.render_toolbar(window, cx))
.children(self.render_upsell(window, cx))
.children(self.render_trial_end_upsell(window, cx))
@@ -2994,12 +3116,12 @@ impl Render for AgentPanel {
ActiveView::Thread { .. } => parent
.relative()
.child(self.render_active_thread_or_empty_state(window, cx))
.children(self.render_tool_use_limit_reached(cx))
.children(self.render_tool_use_limit_reached(window, cx))
.child(h_flex().child(self.message_editor.clone()))
.children(self.render_last_error(cx))
.child(self.render_drag_target(cx)),
ActiveView::History => parent.child(self.history.clone()),
ActiveView::PromptEditor {
ActiveView::TextThread {
context_editor,
buffer_search_bar,
..

View File

@@ -34,6 +34,7 @@ use std::{
};
use streaming_diff::{CharOperation, LineDiff, LineOperation, StreamingDiff};
use telemetry_events::{AssistantEventData, AssistantKind, AssistantPhase};
use zed_llm_client::CompletionIntent;
pub struct BufferCodegen {
alternatives: Vec<Entity<CodegenAlternative>>,
@@ -464,6 +465,7 @@ impl CodegenAlternative {
LanguageModelRequest {
thread_id: None,
prompt_id: None,
intent: Some(CompletionIntent::InlineAssist),
mode: None,
tools: Vec::new(),
tool_choice: None,

View File

@@ -1445,7 +1445,7 @@ impl InlineAssistant {
style: BlockStyle::Flex,
render: Arc::new(move |cx| {
div()
.block_mouse_down()
.block_mouse_except_scroll()
.bg(cx.theme().status().deleted_background)
.size_full()
.h(height as f32 * cx.window.line_height())

View File

@@ -100,7 +100,7 @@ impl<T: 'static> Render for PromptEditor<T> {
v_flex()
.key_context("PromptEditor")
.bg(cx.theme().colors().editor_background)
.block_mouse_down()
.block_mouse_except_scroll()
.gap_0p5()
.border_y_1()
.border_color(cx.theme().status().info_border)

View File

@@ -42,6 +42,7 @@ use theme::ThemeSettings;
use ui::{Disclosure, KeyBinding, PopoverMenuHandle, Tooltip, prelude::*};
use util::{ResultExt as _, maybe};
use workspace::{CollaboratorId, Workspace};
use zed_llm_client::CompletionIntent;
use crate::context_picker::{ContextPicker, ContextPickerCompletionProvider, crease_for_mention};
use crate::context_store::ContextStore;
@@ -51,7 +52,7 @@ use crate::thread::{MessageCrease, Thread, TokenUsageRatio};
use crate::thread_store::{TextThreadStore, ThreadStore};
use crate::{
ActiveThread, AgentDiffPane, Chat, ChatWithFollow, ExpandMessageEditor, Follow, NewThread,
OpenAgentDiff, RemoveAllContext, ToggleContextPicker, ToggleProfileSelector,
OpenAgentDiff, RemoveAllContext, ToggleBurnMode, ToggleContextPicker, ToggleProfileSelector,
register_agent_preview,
};
@@ -375,7 +376,12 @@ impl MessageEditor {
thread
.update(cx, |thread, cx| {
thread.advance_prompt_id();
thread.send_to_model(model, Some(window_handle), cx);
thread.send_to_model(
model,
CompletionIntent::UserPrompt,
Some(window_handle),
cx,
);
})
.log_err();
})
@@ -471,6 +477,22 @@ impl MessageEditor {
}
}
pub fn toggle_burn_mode(
&mut self,
_: &ToggleBurnMode,
_window: &mut Window,
cx: &mut Context<Self>,
) {
self.thread.update(cx, |thread, _cx| {
let active_completion_mode = thread.completion_mode();
thread.set_completion_mode(match active_completion_mode {
CompletionMode::Burn => CompletionMode::Normal,
CompletionMode::Normal => CompletionMode::Burn,
});
});
}
fn render_max_mode_toggle(&self, cx: &mut Context<Self>) -> Option<AnyElement> {
let thread = self.thread.read(cx);
let model = thread.configured_model();
@@ -479,27 +501,24 @@ impl MessageEditor {
}
let active_completion_mode = thread.completion_mode();
let max_mode_enabled = active_completion_mode == CompletionMode::Max;
let burn_mode_enabled = active_completion_mode == CompletionMode::Burn;
let icon = if burn_mode_enabled {
IconName::ZedBurnModeOn
} else {
IconName::ZedBurnMode
};
Some(
Button::new("max-mode", "Max Mode")
.label_size(LabelSize::Small)
.color(Color::Muted)
.icon(IconName::ZedMaxMode)
IconButton::new("burn-mode", icon)
.icon_size(IconSize::Small)
.icon_color(Color::Muted)
.icon_position(IconPosition::Start)
.toggle_state(max_mode_enabled)
.on_click(cx.listener(move |this, _event, _window, cx| {
this.thread.update(cx, |thread, _cx| {
thread.set_completion_mode(match active_completion_mode {
CompletionMode::Max => CompletionMode::Normal,
CompletionMode::Normal => CompletionMode::Max,
});
});
.toggle_state(burn_mode_enabled)
.selected_icon_color(Color::Error)
.on_click(cx.listener(|this, _event, window, cx| {
this.toggle_burn_mode(&ToggleBurnMode, window, cx);
}))
.tooltip(move |_window, cx| {
cx.new(|_| MaxModeTooltip::new().selected(max_mode_enabled))
cx.new(|_| MaxModeTooltip::new().selected(burn_mode_enabled))
.into()
})
.into_any_element(),
@@ -594,6 +613,7 @@ impl MessageEditor {
.on_action(cx.listener(Self::remove_all_context))
.on_action(cx.listener(Self::move_up))
.on_action(cx.listener(Self::expand_message_editor))
.on_action(cx.listener(Self::toggle_burn_mode))
.capture_action(cx.listener(Self::paste))
.gap_2()
.p_2()
@@ -686,7 +706,6 @@ impl MessageEditor {
.justify_between()
.child(
h_flex()
.gap_1()
.child(self.render_follow_toggle(cx))
.children(self.render_max_mode_toggle(cx)),
)
@@ -1267,6 +1286,7 @@ impl MessageEditor {
let request = language_model::LanguageModelRequest {
thread_id: None,
prompt_id: None,
intent: None,
mode: None,
messages: vec![request_message],
tools: vec![],

View File

@@ -25,6 +25,7 @@ use terminal_view::TerminalView;
use ui::prelude::*;
use util::ResultExt;
use workspace::{Toast, Workspace, notifications::NotificationId};
use zed_llm_client::CompletionIntent;
pub fn init(
fs: Arc<dyn Fs>,
@@ -291,6 +292,7 @@ impl TerminalInlineAssistant {
thread_id: None,
prompt_id: None,
mode: None,
intent: Some(CompletionIntent::TerminalInlineAssist),
messages: vec![request_message],
tools: Vec::new(),
tool_choice: None,

View File

@@ -24,7 +24,7 @@ use language_model::{
LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
LanguageModelToolResultContent, LanguageModelToolUseId, MessageContent,
ModelRequestLimitReachedError, PaymentRequiredError, RequestUsage, Role, SelectedModel,
StopReason, TokenUsage, WrappedTextContent,
StopReason, TokenUsage,
};
use postage::stream::Stream as _;
use project::Project;
@@ -38,7 +38,7 @@ use thiserror::Error;
use ui::Window;
use util::{ResultExt as _, post_inc};
use uuid::Uuid;
use zed_llm_client::CompletionRequestStatus;
use zed_llm_client::{CompletionIntent, CompletionRequestStatus};
use crate::ThreadStore;
use crate::context::{AgentContext, AgentContextHandle, ContextLoadResult, LoadedContext};
@@ -115,6 +115,7 @@ pub struct Message {
pub segments: Vec<MessageSegment>,
pub loaded_context: LoadedContext,
pub creases: Vec<MessageCrease>,
pub is_hidden: bool,
}
impl Message {
@@ -540,6 +541,7 @@ impl Thread {
context: None,
})
.collect(),
is_hidden: message.is_hidden,
})
.collect(),
next_message_id,
@@ -560,7 +562,7 @@ impl Thread {
cumulative_token_usage: serialized.cumulative_token_usage,
exceeded_window_error: None,
last_usage: None,
tool_use_limit_reached: false,
tool_use_limit_reached: serialized.tool_use_limit_reached,
feedback: None,
message_feedback: HashMap::default(),
last_auto_capture_at: None,
@@ -849,7 +851,7 @@ impl Thread {
.get(ix + 1)
.and_then(|message| {
self.message(message.id)
.map(|next_message| next_message.role == Role::User)
.map(|next_message| next_message.role == Role::User && !next_message.is_hidden)
})
.unwrap_or(false)
}
@@ -889,10 +891,7 @@ impl Thread {
pub fn output_for_tool(&self, id: &LanguageModelToolUseId) -> Option<&Arc<str>> {
match &self.tool_use.tool_result(id)?.content {
LanguageModelToolResultContent::Text(text)
| LanguageModelToolResultContent::WrappedText(WrappedTextContent { text, .. }) => {
Some(text)
}
LanguageModelToolResultContent::Text(text) => Some(text),
LanguageModelToolResultContent::Image(_) => {
// TODO: We should display image
None
@@ -951,6 +950,7 @@ impl Thread {
vec![MessageSegment::Text(text.into())],
loaded_context.loaded_context,
creases,
false,
cx,
);
@@ -966,6 +966,20 @@ impl Thread {
message_id
}
pub fn insert_invisible_continue_message(&mut self, cx: &mut Context<Self>) -> MessageId {
let id = self.insert_message(
Role::User,
vec![MessageSegment::Text("Continue where you left off".into())],
LoadedContext::default(),
vec![],
true,
cx,
);
self.pending_checkpoint = None;
id
}
pub fn insert_assistant_message(
&mut self,
segments: Vec<MessageSegment>,
@@ -976,6 +990,7 @@ impl Thread {
segments,
LoadedContext::default(),
Vec::new(),
false,
cx,
)
}
@@ -986,6 +1001,7 @@ impl Thread {
segments: Vec<MessageSegment>,
loaded_context: LoadedContext,
creases: Vec<MessageCrease>,
is_hidden: bool,
cx: &mut Context<Self>,
) -> MessageId {
let id = self.next_message_id.post_inc();
@@ -995,6 +1011,7 @@ impl Thread {
segments,
loaded_context,
creases,
is_hidden,
});
self.touch_updated_at();
cx.emit(ThreadEvent::MessageAdded(id));
@@ -1135,6 +1152,7 @@ impl Thread {
label: crease.metadata.label.clone(),
})
.collect(),
is_hidden: message.is_hidden,
})
.collect(),
initial_project_snapshot,
@@ -1150,6 +1168,7 @@ impl Thread {
model: model.model.id().0.to_string(),
}),
completion_mode: Some(this.completion_mode),
tool_use_limit_reached: this.tool_use_limit_reached,
})
})
}
@@ -1165,6 +1184,7 @@ impl Thread {
pub fn send_to_model(
&mut self,
model: Arc<dyn LanguageModel>,
intent: CompletionIntent,
window: Option<AnyWindowHandle>,
cx: &mut Context<Self>,
) {
@@ -1174,7 +1194,7 @@ impl Thread {
self.remaining_turns -= 1;
let request = self.to_completion_request(model.clone(), cx);
let request = self.to_completion_request(model.clone(), intent, cx);
self.stream_completion(request, model, window, cx);
}
@@ -1194,11 +1214,13 @@ impl Thread {
pub fn to_completion_request(
&self,
model: Arc<dyn LanguageModel>,
intent: CompletionIntent,
cx: &mut Context<Self>,
) -> LanguageModelRequest {
let mut request = LanguageModelRequest {
thread_id: Some(self.id.to_string()),
prompt_id: Some(self.last_prompt_id.to_string()),
intent: Some(intent),
mode: None,
messages: vec![],
tools: Vec::new(),
@@ -1352,12 +1374,14 @@ impl Thread {
fn to_summarize_request(
&self,
model: &Arc<dyn LanguageModel>,
intent: CompletionIntent,
added_user_message: String,
cx: &App,
) -> LanguageModelRequest {
let mut request = LanguageModelRequest {
thread_id: None,
prompt_id: None,
intent: Some(intent),
mode: None,
messages: vec![],
tools: Vec::new(),
@@ -1404,6 +1428,7 @@ impl Thread {
messages: &mut Vec<LanguageModelRequestMessage>,
cx: &App,
) {
// NOTE: Changes to this prompt require a symmetric update in monorepo/request_kind.rs
const STALE_FILES_HEADER: &str = "These files changed since last read:";
let mut stale_message = String::new();
@@ -1781,6 +1806,7 @@ impl Thread {
thread.cancel_last_completion(window, cx);
}
}
cx.emit(ThreadEvent::Stopped(result.map_err(Arc::new)));
if let Some((request_callback, (request, response_events))) = thread
@@ -1829,12 +1855,18 @@ impl Thread {
return;
}
// NOTE: Changes to this prompt require a symmetric update in monorepo/request_kind.rs
let added_user_message = "Generate a concise 3-7 word title for this conversation, omitting punctuation. \
Go straight to the title, without any preamble and prefix like `Here's a concise suggestion:...` or `Title:`. \
If the conversation is about a specific subject, include it in the title. \
Be descriptive. DO NOT speak in the first person.";
let request = self.to_summarize_request(&model.model, added_user_message.into(), cx);
let request = self.to_summarize_request(
&model.model,
CompletionIntent::ThreadSummarization,
added_user_message.into(),
cx,
);
self.summary = ThreadSummary::Generating;
@@ -1928,6 +1960,7 @@ impl Thread {
return;
}
// NOTE: Changes to this prompt require a symmetric update in monorepo/request_kind.rs
let added_user_message = "Generate a detailed summary of this conversation. Include:\n\
1. A brief overview of what was discussed\n\
2. Key facts or information discovered\n\
@@ -1935,7 +1968,12 @@ impl Thread {
4. Any action items or next steps if any\n\
Format it in Markdown with headings and bullet points.";
let request = self.to_summarize_request(&model, added_user_message.into(), cx);
let request = self.to_summarize_request(
&model,
CompletionIntent::ThreadContextSummarization,
added_user_message.into(),
cx,
);
*self.detailed_summary_tx.borrow_mut() = DetailedSummaryState::Generating {
message_id: last_message_id,
@@ -2027,7 +2065,8 @@ impl Thread {
model: Arc<dyn LanguageModel>,
) -> Vec<PendingToolUse> {
self.auto_capture_telemetry(cx);
let request = Arc::new(self.to_completion_request(model.clone(), cx));
let request =
Arc::new(self.to_completion_request(model.clone(), CompletionIntent::ToolResults, cx));
let pending_tool_uses = self
.tool_use
.pending_tool_uses()
@@ -2223,7 +2262,7 @@ impl Thread {
if self.all_tools_finished() {
if let Some(ConfiguredModel { model, .. }) = self.configured_model.as_ref() {
if !canceled {
self.send_to_model(model.clone(), window, cx);
self.send_to_model(model.clone(), CompletionIntent::ToolResults, window, cx);
}
self.auto_capture_telemetry(cx);
}
@@ -2570,11 +2609,7 @@ impl Thread {
writeln!(markdown, "**\n")?;
match &tool_result.content {
LanguageModelToolResultContent::Text(text)
| LanguageModelToolResultContent::WrappedText(WrappedTextContent {
text,
..
}) => {
LanguageModelToolResultContent::Text(text) => {
writeln!(markdown, "{text}")?;
}
LanguageModelToolResultContent::Image(image) => {
@@ -2918,7 +2953,7 @@ fn main() {{
// Check message in request
let request = thread.update(cx, |thread, cx| {
thread.to_completion_request(model.clone(), cx)
thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
});
assert_eq!(request.messages.len(), 2);
@@ -3013,7 +3048,7 @@ fn main() {{
// Check entire request to make sure all contexts are properly included
let request = thread.update(cx, |thread, cx| {
thread.to_completion_request(model.clone(), cx)
thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
});
// The request should contain all 3 messages
@@ -3120,7 +3155,7 @@ fn main() {{
// Check message in request
let request = thread.update(cx, |thread, cx| {
thread.to_completion_request(model.clone(), cx)
thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
});
assert_eq!(request.messages.len(), 2);
@@ -3146,7 +3181,7 @@ fn main() {{
// Check that both messages appear in the request
let request = thread.update(cx, |thread, cx| {
thread.to_completion_request(model.clone(), cx)
thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
});
assert_eq!(request.messages.len(), 3);
@@ -3191,7 +3226,7 @@ fn main() {{
// Create a request and check that it doesn't have a stale buffer warning yet
let initial_request = thread.update(cx, |thread, cx| {
thread.to_completion_request(model.clone(), cx)
thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
});
// Make sure we don't have a stale file warning yet
@@ -3227,7 +3262,7 @@ fn main() {{
// Create a new request and check for the stale buffer warning
let new_request = thread.update(cx, |thread, cx| {
thread.to_completion_request(model.clone(), cx)
thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
});
// We should have a stale file warning as the last message
@@ -3277,7 +3312,7 @@ fn main() {{
});
let request = thread.update(cx, |thread, cx| {
thread.to_completion_request(model.clone(), cx)
thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
});
assert_eq!(request.temperature, Some(0.66));
@@ -3297,7 +3332,7 @@ fn main() {{
});
let request = thread.update(cx, |thread, cx| {
thread.to_completion_request(model.clone(), cx)
thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
});
assert_eq!(request.temperature, Some(0.66));
@@ -3317,7 +3352,7 @@ fn main() {{
});
let request = thread.update(cx, |thread, cx| {
thread.to_completion_request(model.clone(), cx)
thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
});
assert_eq!(request.temperature, Some(0.66));
@@ -3337,7 +3372,7 @@ fn main() {{
});
let request = thread.update(cx, |thread, cx| {
thread.to_completion_request(model.clone(), cx)
thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
});
assert_eq!(request.temperature, None);
}
@@ -3369,7 +3404,12 @@ fn main() {{
// Send a message
thread.update(cx, |thread, cx| {
thread.insert_user_message("Hi!", ContextLoadResult::default(), None, vec![], cx);
thread.send_to_model(model.clone(), None, cx);
thread.send_to_model(
model.clone(),
CompletionIntent::ThreadSummarization,
None,
cx,
);
});
let fake_model = model.as_fake();
@@ -3391,8 +3431,8 @@ fn main() {{
});
cx.run_until_parked();
fake_model.stream_last_completion_response("Brief".into());
fake_model.stream_last_completion_response(" Introduction".into());
fake_model.stream_last_completion_response("Brief");
fake_model.stream_last_completion_response(" Introduction");
fake_model.end_last_completion_stream();
cx.run_until_parked();
@@ -3464,7 +3504,7 @@ fn main() {{
vec![],
cx,
);
thread.send_to_model(model.clone(), None, cx);
thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
});
let fake_model = model.as_fake();
@@ -3485,7 +3525,7 @@ fn main() {{
});
cx.run_until_parked();
fake_model.stream_last_completion_response("A successful summary".into());
fake_model.stream_last_completion_response("A successful summary");
fake_model.end_last_completion_stream();
cx.run_until_parked();
@@ -3502,7 +3542,12 @@ fn main() {{
) {
thread.update(cx, |thread, cx| {
thread.insert_user_message("Hi!", ContextLoadResult::default(), None, vec![], cx);
thread.send_to_model(model.clone(), None, cx);
thread.send_to_model(
model.clone(),
CompletionIntent::ThreadSummarization,
None,
cx,
);
});
let fake_model = model.as_fake();
@@ -3527,7 +3572,7 @@ fn main() {{
fn simulate_successful_response(fake_model: &FakeLanguageModel, cx: &mut TestAppContext) {
cx.run_until_parked();
fake_model.stream_last_completion_response("Assistant response".into());
fake_model.stream_last_completion_response("Assistant response");
fake_model.end_last_completion_stream();
cx.run_until_parked();
}

View File

@@ -676,6 +676,8 @@ pub struct SerializedThread {
pub model: Option<SerializedLanguageModel>,
#[serde(default)]
pub completion_mode: Option<CompletionMode>,
#[serde(default)]
pub tool_use_limit_reached: bool,
}
#[derive(Serialize, Deserialize, Debug)]
@@ -757,6 +759,8 @@ pub struct SerializedMessage {
pub context: String,
#[serde(default)]
pub creases: Vec<SerializedCrease>,
#[serde(default)]
pub is_hidden: bool,
}
#[derive(Debug, Serialize, Deserialize)]
@@ -815,6 +819,7 @@ impl LegacySerializedThread {
exceeded_window_error: None,
model: None,
completion_mode: None,
tool_use_limit_reached: false,
}
}
}
@@ -840,6 +845,7 @@ impl LegacySerializedMessage {
tool_results: self.tool_results,
context: String::new(),
creases: Vec::new(),
is_hidden: false,
}
}
}

View File

@@ -1,5 +1,6 @@
use gpui::{Context, IntoElement, Render, Window};
use ui::{prelude::*, tooltip_container};
use crate::ToggleBurnMode;
use gpui::{Context, FontWeight, IntoElement, Render, Window};
use ui::{KeyBinding, prelude::*, tooltip_container};
pub struct MaxModeTooltip {
selected: bool,
@@ -18,38 +19,48 @@ impl MaxModeTooltip {
impl Render for MaxModeTooltip {
fn render(&mut self, window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
let (icon, color) = if self.selected {
(IconName::ZedBurnModeOn, Color::Error)
} else {
(IconName::ZedBurnMode, Color::Default)
};
let turned_on = h_flex()
.h_4()
.px_1()
.border_1()
.border_color(cx.theme().colors().border)
.bg(cx.theme().colors().text_accent.opacity(0.1))
.rounded_sm()
.child(
Label::new("ON")
.size(LabelSize::XSmall)
.weight(FontWeight::SEMIBOLD)
.color(Color::Accent),
);
let title = h_flex()
.gap_1p5()
.child(Icon::new(icon).size(IconSize::Small).color(color))
.child(Label::new("Burn Mode"))
.when(self.selected, |title| title.child(turned_on));
let keybinding = KeyBinding::for_action(&ToggleBurnMode, window, cx)
.map(|kb| kb.size(rems_from_px(12.)));
tooltip_container(window, cx, |this, _, _| {
this.gap_1()
.map(|header| if self.selected {
header.child(
h_flex()
.justify_between()
.child(
h_flex()
.gap_1p5()
.child(Icon::new(IconName::ZedMaxMode).size(IconSize::Small).color(Color::Accent))
.child(Label::new("Zed's Max Mode"))
)
.child(
h_flex()
.gap_0p5()
.child(Icon::new(IconName::Check).size(IconSize::XSmall).color(Color::Accent))
.child(Label::new("Turned On").size(LabelSize::XSmall).color(Color::Accent))
)
)
} else {
header.child(
h_flex()
.gap_1p5()
.child(Icon::new(IconName::ZedMaxMode).size(IconSize::Small))
.child(Label::new("Zed's Max Mode"))
)
})
this
.child(
h_flex()
.justify_between()
.child(title)
.children(keybinding)
)
.child(
div()
.max_w_72()
.max_w_64()
.child(
Label::new("This mode enables models to use large context windows, unlimited tool calls, and other capabilities for expanded reasoning, offering an unfettered agentic experience.")
Label::new("Enables models to use large context windows, unlimited tool calls, and other capabilities for expanded reasoning.")
.size(LabelSize::Small)
.color(Color::Muted)
)

View File

@@ -689,14 +689,14 @@ pub struct AgentSettingsContentV2 {
pub enum CompletionMode {
#[default]
Normal,
Max,
Burn,
}
impl From<CompletionMode> for zed_llm_client::CompletionMode {
fn from(value: CompletionMode) -> Self {
match value {
CompletionMode::Normal => zed_llm_client::CompletionMode::Normal,
CompletionMode::Max => zed_llm_client::CompletionMode::Max,
CompletionMode::Burn => zed_llm_client::CompletionMode::Max,
}
}
}

View File

@@ -57,6 +57,7 @@ uuid.workspace = true
workspace-hack.workspace = true
workspace.workspace = true
zed_actions.workspace = true
zed_llm_client.workspace = true
[dev-dependencies]
language_model = { workspace = true, features = ["test-support"] }

View File

@@ -3,6 +3,7 @@ mod context_editor;
mod context_history;
mod context_store;
pub mod language_model_selector;
mod max_mode_tooltip;
mod slash_command;
mod slash_command_picker;

View File

@@ -29,6 +29,7 @@ use paths::contexts_dir;
use project::Project;
use prompt_store::PromptBuilder;
use serde::{Deserialize, Serialize};
use settings::Settings;
use smallvec::SmallVec;
use std::{
cmp::{Ordering, max},
@@ -44,6 +45,7 @@ use text::{BufferSnapshot, ToPoint};
use ui::IconName;
use util::{ResultExt, TryFutureExt, post_inc};
use uuid::Uuid;
use zed_llm_client::CompletionIntent;
#[derive(Clone, Debug, Eq, PartialEq, Hash, PartialOrd, Ord, Serialize, Deserialize)]
pub struct ContextId(String);
@@ -682,6 +684,7 @@ pub struct AssistantContext {
language_registry: Arc<LanguageRegistry>,
project: Option<Entity<Project>>,
prompt_builder: Arc<PromptBuilder>,
completion_mode: agent_settings::CompletionMode,
}
trait ContextAnnotation {
@@ -718,6 +721,14 @@ impl AssistantContext {
)
}
pub fn completion_mode(&self) -> agent_settings::CompletionMode {
self.completion_mode
}
pub fn set_completion_mode(&mut self, completion_mode: agent_settings::CompletionMode) {
self.completion_mode = completion_mode;
}
pub fn new(
id: ContextId,
replica_id: ReplicaId,
@@ -764,6 +775,7 @@ impl AssistantContext {
pending_cache_warming_task: Task::ready(None),
_subscriptions: vec![cx.subscribe(&buffer, Self::handle_buffer_event)],
pending_save: Task::ready(Ok(())),
completion_mode: AgentSettings::get_global(cx).preferred_completion_mode,
path: None,
buffer,
telemetry,
@@ -2261,6 +2273,7 @@ impl AssistantContext {
let mut completion_request = LanguageModelRequest {
thread_id: None,
prompt_id: None,
intent: Some(CompletionIntent::UserPrompt),
mode: None,
messages: Vec::new(),
tools: Vec::new(),
@@ -2321,7 +2334,15 @@ impl AssistantContext {
completion_request.messages.push(request_message);
}
}
let supports_max_mode = if let Some(model) = model {
model.supports_max_mode()
} else {
false
};
if supports_max_mode {
completion_request.mode = Some(self.completion_mode.into());
}
completion_request
}

View File

@@ -1210,8 +1210,8 @@ async fn test_summarization(cx: &mut TestAppContext) {
});
cx.run_until_parked();
fake_model.stream_last_completion_response("Brief".into());
fake_model.stream_last_completion_response(" Introduction".into());
fake_model.stream_last_completion_response("Brief");
fake_model.stream_last_completion_response(" Introduction");
fake_model.end_last_completion_stream();
cx.run_until_parked();
@@ -1274,7 +1274,7 @@ async fn test_thread_summary_error_retry(cx: &mut TestAppContext) {
});
cx.run_until_parked();
fake_model.stream_last_completion_response("A successful summary".into());
fake_model.stream_last_completion_response("A successful summary");
fake_model.end_last_completion_stream();
cx.run_until_parked();
@@ -1356,7 +1356,7 @@ fn setup_context_editor_with_fake_model(
fn simulate_successful_response(fake_model: &FakeLanguageModel, cx: &mut TestAppContext) {
cx.run_until_parked();
fake_model.stream_last_completion_response("Assistant response".into());
fake_model.stream_last_completion_response("Assistant response");
fake_model.end_last_completion_stream();
cx.run_until_parked();
}

View File

@@ -1,7 +1,10 @@
use crate::language_model_selector::{
LanguageModelSelector, LanguageModelSelectorPopoverMenu, ToggleModelSelector,
use crate::{
language_model_selector::{
LanguageModelSelector, ToggleModelSelector, language_model_selector,
},
max_mode_tooltip::MaxModeTooltip,
};
use agent_settings::AgentSettings;
use agent_settings::{AgentSettings, CompletionMode};
use anyhow::Result;
use assistant_slash_command::{SlashCommand, SlashCommandOutputSection, SlashCommandWorkingSet};
use assistant_slash_commands::{
@@ -40,7 +43,7 @@ use language_model::{
Role,
};
use multi_buffer::MultiBufferRow;
use picker::Picker;
use picker::{Picker, popover_menu::PickerPopoverMenu};
use project::{Project, Worktree};
use project::{ProjectPath, lsp_store::LocalLspAdapterDelegate};
use rope::Point;
@@ -280,7 +283,7 @@ impl ContextEditor {
slash_menu_handle: Default::default(),
dragged_file_worktrees: Vec::new(),
language_model_selector: cx.new(|cx| {
LanguageModelSelector::new(
language_model_selector(
|cx| LanguageModelRegistry::read_global(cx).default_model(),
move |model, cx| {
update_settings_file::<AgentSettings>(
@@ -2008,17 +2011,17 @@ impl ContextEditor {
None => (ButtonStyle::Filled, None),
};
ButtonLike::new("send_button")
Button::new("send_button", "Send")
.label_size(LabelSize::Small)
.disabled(self.sending_disabled(cx))
.style(style)
.when_some(tooltip, |button, tooltip| {
button.tooltip(move |_, _| tooltip.clone())
})
.layer(ElevationIndex::ModalSurface)
.child(Label::new("Send"))
.children(
.key_binding(
KeyBinding::for_action_in(&Assist, &focus_handle, window, cx)
.map(|binding| binding.into_any_element()),
.map(|kb| kb.size(rems_from_px(12.))),
)
.on_click(move |_event, window, cx| {
focus_handle.dispatch_action(&Assist, window, cx);
@@ -2058,7 +2061,50 @@ impl ContextEditor {
)
}
fn render_language_model_selector(&self, cx: &mut Context<Self>) -> impl IntoElement {
fn render_max_mode_toggle(&self, cx: &mut Context<Self>) -> Option<AnyElement> {
let context = self.context().read(cx);
let active_model = LanguageModelRegistry::read_global(cx)
.default_model()
.map(|default| default.model)?;
if !active_model.supports_max_mode() {
return None;
}
let active_completion_mode = context.completion_mode();
let burn_mode_enabled = active_completion_mode == CompletionMode::Burn;
let icon = if burn_mode_enabled {
IconName::ZedBurnModeOn
} else {
IconName::ZedBurnMode
};
Some(
IconButton::new("burn-mode", icon)
.icon_size(IconSize::Small)
.icon_color(Color::Muted)
.toggle_state(burn_mode_enabled)
.selected_icon_color(Color::Error)
.on_click(cx.listener(move |this, _event, _window, cx| {
this.context().update(cx, |context, _cx| {
context.set_completion_mode(match active_completion_mode {
CompletionMode::Burn => CompletionMode::Normal,
CompletionMode::Normal => CompletionMode::Burn,
});
});
}))
.tooltip(move |_window, cx| {
cx.new(|_| MaxModeTooltip::new().selected(burn_mode_enabled))
.into()
})
.into_any_element(),
)
}
fn render_language_model_selector(
&self,
window: &mut Window,
cx: &mut Context<Self>,
) -> impl IntoElement {
let active_model = LanguageModelRegistry::read_global(cx)
.default_model()
.map(|default| default.model);
@@ -2068,7 +2114,7 @@ impl ContextEditor {
None => SharedString::from("No model selected"),
};
LanguageModelSelectorPopoverMenu::new(
PickerPopoverMenu::new(
self.language_model_selector.clone(),
ButtonLike::new("active-model")
.style(ButtonStyle::Subtle)
@@ -2096,8 +2142,10 @@ impl ContextEditor {
)
},
gpui::Corner::BottomLeft,
cx,
)
.with_handle(self.language_model_selector_menu_handle.clone())
.render(window, cx)
}
fn render_last_error(&self, cx: &mut Context<Self>) -> Option<AnyElement> {
@@ -2503,6 +2551,7 @@ impl Render for ContextEditor {
let provider = LanguageModelRegistry::read_global(cx)
.default_model()
.map(|default| default.provider);
let accept_terms = if self.show_accept_terms {
provider.as_ref().and_then(|provider| {
provider.render_accept_terms(LanguageModelProviderTosView::PromptEditorPopup, cx)
@@ -2512,6 +2561,8 @@ impl Render for ContextEditor {
};
let language_model_selector = self.language_model_selector_menu_handle.clone();
let max_mode_toggle = self.render_max_mode_toggle(cx);
v_flex()
.key_context("ContextEditor")
.capture_action(cx.listener(ContextEditor::cancel))
@@ -2551,31 +2602,28 @@ impl Render for ContextEditor {
})
.children(self.render_last_error(cx))
.child(
h_flex().w_full().relative().child(
h_flex()
.p_2()
.w_full()
.border_t_1()
.border_color(cx.theme().colors().border_variant)
.bg(cx.theme().colors().editor_background)
.child(
h_flex()
.gap_1()
.child(self.render_inject_context_menu(cx))
.child(ui::Divider::vertical())
.child(
div()
.pl_0p5()
.child(self.render_language_model_selector(cx)),
),
)
.child(
h_flex()
.w_full()
.justify_end()
.child(self.render_send_button(window, cx)),
),
),
h_flex()
.relative()
.py_2()
.pl_1p5()
.pr_2()
.w_full()
.justify_between()
.border_t_1()
.border_color(cx.theme().colors().border_variant)
.bg(cx.theme().colors().editor_background)
.child(
h_flex()
.gap_0p5()
.child(self.render_inject_context_menu(cx))
.when_some(max_mode_toggle, |this, element| this.child(element)),
)
.child(
h_flex()
.gap_1()
.child(self.render_language_model_selector(window, cx))
.child(self.render_send_button(window, cx)),
),
)
}
}

View File

@@ -4,8 +4,7 @@ use collections::{HashSet, IndexMap};
use feature_flags::ZedProFeatureFlag;
use fuzzy::{StringMatch, StringMatchCandidate, match_strings};
use gpui::{
Action, AnyElement, AnyView, App, BackgroundExecutor, Corner, DismissEvent, Entity,
EventEmitter, FocusHandle, Focusable, Subscription, Task, WeakEntity,
Action, AnyElement, App, BackgroundExecutor, DismissEvent, Subscription, Task,
action_with_deprecated_aliases,
};
use language_model::{
@@ -15,7 +14,7 @@ use language_model::{
use ordered_float::OrderedFloat;
use picker::{Picker, PickerDelegate};
use proto::Plan;
use ui::{ListItem, ListItemSpacing, PopoverMenu, PopoverMenuHandle, PopoverTrigger, prelude::*};
use ui::{ListItem, ListItemSpacing, prelude::*};
action_with_deprecated_aliases!(
agent,
@@ -31,77 +30,146 @@ const TRY_ZED_PRO_URL: &str = "https://zed.dev/pro";
type OnModelChanged = Arc<dyn Fn(Arc<dyn LanguageModel>, &mut App) + 'static>;
type GetActiveModel = Arc<dyn Fn(&App) -> Option<ConfiguredModel> + 'static>;
pub struct LanguageModelSelector {
picker: Entity<Picker<LanguageModelPickerDelegate>>,
pub type LanguageModelSelector = Picker<LanguageModelPickerDelegate>;
pub fn language_model_selector(
get_active_model: impl Fn(&App) -> Option<ConfiguredModel> + 'static,
on_model_changed: impl Fn(Arc<dyn LanguageModel>, &mut App) + 'static,
window: &mut Window,
cx: &mut Context<LanguageModelSelector>,
) -> LanguageModelSelector {
let delegate = LanguageModelPickerDelegate::new(get_active_model, on_model_changed, window, cx);
Picker::list(delegate, window, cx)
.show_scrollbar(true)
.width(rems(20.))
.max_height(Some(rems(20.).into()))
}
fn all_models(cx: &App) -> GroupedModels {
let mut recommended = Vec::new();
let mut recommended_set = HashSet::default();
for provider in LanguageModelRegistry::global(cx)
.read(cx)
.providers()
.iter()
{
let models = provider.recommended_models(cx);
recommended_set.extend(models.iter().map(|model| (model.provider_id(), model.id())));
recommended.extend(
provider
.recommended_models(cx)
.into_iter()
.map(move |model| ModelInfo {
model: model.clone(),
icon: provider.icon(),
}),
);
}
let other_models = LanguageModelRegistry::global(cx)
.read(cx)
.providers()
.iter()
.map(|provider| {
(
provider.id(),
provider
.provided_models(cx)
.into_iter()
.filter_map(|model| {
let not_included =
!recommended_set.contains(&(model.provider_id(), model.id()));
not_included.then(|| ModelInfo {
model: model.clone(),
icon: provider.icon(),
})
})
.collect::<Vec<_>>(),
)
})
.collect::<IndexMap<_, _>>();
GroupedModels {
recommended,
other: other_models,
}
}
#[derive(Clone)]
struct ModelInfo {
model: Arc<dyn LanguageModel>,
icon: IconName,
}
pub struct LanguageModelPickerDelegate {
on_model_changed: OnModelChanged,
get_active_model: GetActiveModel,
all_models: Arc<GroupedModels>,
filtered_entries: Vec<LanguageModelPickerEntry>,
selected_index: usize,
_authenticate_all_providers_task: Task<()>,
_subscriptions: Vec<Subscription>,
}
impl LanguageModelSelector {
pub fn new(
impl LanguageModelPickerDelegate {
fn new(
get_active_model: impl Fn(&App) -> Option<ConfiguredModel> + 'static,
on_model_changed: impl Fn(Arc<dyn LanguageModel>, &mut App) + 'static,
window: &mut Window,
cx: &mut Context<Self>,
cx: &mut Context<Picker<Self>>,
) -> Self {
let on_model_changed = Arc::new(on_model_changed);
let models = all_models(cx);
let entries = models.entries();
let all_models = Self::all_models(cx);
let entries = all_models.entries();
let delegate = LanguageModelPickerDelegate {
language_model_selector: cx.entity().downgrade(),
Self {
on_model_changed: on_model_changed.clone(),
all_models: Arc::new(all_models),
all_models: Arc::new(models),
selected_index: Self::get_active_model_index(&entries, get_active_model(cx)),
filtered_entries: entries,
get_active_model: Arc::new(get_active_model),
};
let picker = cx.new(|cx| {
Picker::list(delegate, window, cx)
.show_scrollbar(true)
.width(rems(20.))
.max_height(Some(rems(20.).into()))
});
let subscription = cx.subscribe(&picker, |_, _, _, cx| cx.emit(DismissEvent));
LanguageModelSelector {
picker,
_authenticate_all_providers_task: Self::authenticate_all_providers(cx),
_subscriptions: vec![
cx.subscribe_in(
&LanguageModelRegistry::global(cx),
window,
Self::handle_language_model_registry_event,
),
subscription,
],
_subscriptions: vec![cx.subscribe_in(
&LanguageModelRegistry::global(cx),
window,
|picker, _, event, window, cx| {
match event {
language_model::Event::ProviderStateChanged
| language_model::Event::AddedProvider(_)
| language_model::Event::RemovedProvider(_) => {
let query = picker.query(cx);
picker.delegate.all_models = Arc::new(all_models(cx));
// Update matches will automatically drop the previous task
// if we get a provider event again
picker.update_matches(query, window, cx)
}
_ => {}
}
},
)],
}
}
fn handle_language_model_registry_event(
&mut self,
_registry: &Entity<LanguageModelRegistry>,
event: &language_model::Event,
window: &mut Window,
cx: &mut Context<Self>,
) {
match event {
language_model::Event::ProviderStateChanged
| language_model::Event::AddedProvider(_)
| language_model::Event::RemovedProvider(_) => {
self.picker.update(cx, |this, cx| {
let query = this.query(cx);
this.delegate.all_models = Arc::new(Self::all_models(cx));
// Update matches will automatically drop the previous task
// if we get a provider event again
this.update_matches(query, window, cx)
});
}
_ => {}
}
fn get_active_model_index(
entries: &[LanguageModelPickerEntry],
active_model: Option<ConfiguredModel>,
) -> usize {
entries
.iter()
.position(|entry| {
if let LanguageModelPickerEntry::Model(model) = entry {
active_model
.as_ref()
.map(|active_model| {
active_model.model.id() == model.model.id()
&& active_model.provider.id() == model.model.provider_id()
})
.unwrap_or_default()
} else {
false
}
})
.unwrap_or(0)
}
/// Authenticates all providers in the [`LanguageModelRegistry`].
@@ -154,169 +222,9 @@ impl LanguageModelSelector {
})
}
fn all_models(cx: &App) -> GroupedModels {
let mut recommended = Vec::new();
let mut recommended_set = HashSet::default();
for provider in LanguageModelRegistry::global(cx)
.read(cx)
.providers()
.iter()
{
let models = provider.recommended_models(cx);
recommended_set.extend(models.iter().map(|model| (model.provider_id(), model.id())));
recommended.extend(
provider
.recommended_models(cx)
.into_iter()
.map(move |model| ModelInfo {
model: model.clone(),
icon: provider.icon(),
}),
);
}
let other_models = LanguageModelRegistry::global(cx)
.read(cx)
.providers()
.iter()
.map(|provider| {
(
provider.id(),
provider
.provided_models(cx)
.into_iter()
.filter_map(|model| {
let not_included =
!recommended_set.contains(&(model.provider_id(), model.id()));
not_included.then(|| ModelInfo {
model: model.clone(),
icon: provider.icon(),
})
})
.collect::<Vec<_>>(),
)
})
.collect::<IndexMap<_, _>>();
GroupedModels {
recommended,
other: other_models,
}
}
pub fn active_model(&self, cx: &App) -> Option<ConfiguredModel> {
(self.picker.read(cx).delegate.get_active_model)(cx)
(self.get_active_model)(cx)
}
fn get_active_model_index(
entries: &[LanguageModelPickerEntry],
active_model: Option<ConfiguredModel>,
) -> usize {
entries
.iter()
.position(|entry| {
if let LanguageModelPickerEntry::Model(model) = entry {
active_model
.as_ref()
.map(|active_model| {
active_model.model.id() == model.model.id()
&& active_model.provider.id() == model.model.provider_id()
})
.unwrap_or_default()
} else {
false
}
})
.unwrap_or(0)
}
}
impl EventEmitter<DismissEvent> for LanguageModelSelector {}
impl Focusable for LanguageModelSelector {
fn focus_handle(&self, cx: &App) -> FocusHandle {
self.picker.focus_handle(cx)
}
}
impl Render for LanguageModelSelector {
fn render(&mut self, _window: &mut Window, _cx: &mut Context<Self>) -> impl IntoElement {
self.picker.clone()
}
}
#[derive(IntoElement)]
pub struct LanguageModelSelectorPopoverMenu<T, TT>
where
T: PopoverTrigger + ButtonCommon,
TT: Fn(&mut Window, &mut App) -> AnyView + 'static,
{
language_model_selector: Entity<LanguageModelSelector>,
trigger: T,
tooltip: TT,
handle: Option<PopoverMenuHandle<LanguageModelSelector>>,
anchor: Corner,
}
impl<T, TT> LanguageModelSelectorPopoverMenu<T, TT>
where
T: PopoverTrigger + ButtonCommon,
TT: Fn(&mut Window, &mut App) -> AnyView + 'static,
{
pub fn new(
language_model_selector: Entity<LanguageModelSelector>,
trigger: T,
tooltip: TT,
anchor: Corner,
) -> Self {
Self {
language_model_selector,
trigger,
tooltip,
handle: None,
anchor,
}
}
pub fn with_handle(mut self, handle: PopoverMenuHandle<LanguageModelSelector>) -> Self {
self.handle = Some(handle);
self
}
}
impl<T, TT> RenderOnce for LanguageModelSelectorPopoverMenu<T, TT>
where
T: PopoverTrigger + ButtonCommon,
TT: Fn(&mut Window, &mut App) -> AnyView + 'static,
{
fn render(self, _window: &mut Window, _cx: &mut App) -> impl IntoElement {
let language_model_selector = self.language_model_selector.clone();
PopoverMenu::new("model-switcher")
.menu(move |_window, _cx| Some(language_model_selector.clone()))
.trigger_with_tooltip(self.trigger, self.tooltip)
.anchor(self.anchor)
.when_some(self.handle.clone(), |menu, handle| menu.with_handle(handle))
.offset(gpui::Point {
x: px(0.0),
y: px(-2.0),
})
}
}
#[derive(Clone)]
struct ModelInfo {
model: Arc<dyn LanguageModel>,
icon: IconName,
}
pub struct LanguageModelPickerDelegate {
language_model_selector: WeakEntity<LanguageModelSelector>,
on_model_changed: OnModelChanged,
get_active_model: GetActiveModel,
all_models: Arc<GroupedModels>,
filtered_entries: Vec<LanguageModelPickerEntry>,
selected_index: usize,
}
struct GroupedModels {
@@ -577,9 +485,7 @@ impl PickerDelegate for LanguageModelPickerDelegate {
}
fn dismissed(&mut self, _: &mut Window, cx: &mut Context<Picker<Self>>) {
self.language_model_selector
.update(cx, |_this, cx| cx.emit(DismissEvent))
.ok();
cx.emit(DismissEvent);
}
fn render_match(

View File

@@ -0,0 +1,61 @@
use gpui::{Context, FontWeight, IntoElement, Render, Window};
use ui::{prelude::*, tooltip_container};
pub struct MaxModeTooltip {
selected: bool,
}
impl MaxModeTooltip {
pub fn new() -> Self {
Self { selected: false }
}
pub fn selected(mut self, selected: bool) -> Self {
self.selected = selected;
self
}
}
impl Render for MaxModeTooltip {
fn render(&mut self, window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
let (icon, color) = if self.selected {
(IconName::ZedBurnModeOn, Color::Error)
} else {
(IconName::ZedBurnMode, Color::Default)
};
let turned_on = h_flex()
.h_4()
.px_1()
.border_1()
.border_color(cx.theme().colors().border)
.bg(cx.theme().colors().text_accent.opacity(0.1))
.rounded_sm()
.child(
Label::new("ON")
.size(LabelSize::XSmall)
.weight(FontWeight::SEMIBOLD)
.color(Color::Accent),
);
let title = h_flex()
.gap_1p5()
.child(Icon::new(icon).size(IconSize::Small).color(color))
.child(Label::new("Burn Mode"))
.when(self.selected, |title| title.child(turned_on));
tooltip_container(window, cx, |this, _, _| {
this
.child(title)
.child(
div()
.max_w_64()
.child(
Label::new("Enables models to use large context windows, unlimited tool calls, and other capabilities for expanded reasoning.")
.size(LabelSize::Small)
.color(Color::Muted)
)
)
})
}
}

View File

@@ -415,14 +415,38 @@ impl ActionLog {
self.project
.update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))
} else {
buffer
.read(cx)
.entry_id(cx)
.and_then(|entry_id| {
self.project
.update(cx, |project, cx| project.delete_entry(entry_id, false, cx))
})
.unwrap_or(Task::ready(Ok(())))
// For a file created by AI with no pre-existing content,
// only delete the file if we're certain it contains only AI content
// with no edits from the user.
let initial_version = tracked_buffer.version.clone();
let current_version = buffer.read(cx).version();
let current_content = buffer.read(cx).text();
let tracked_content = tracked_buffer.snapshot.text();
let is_ai_only_content =
initial_version == current_version && current_content == tracked_content;
if is_ai_only_content {
buffer
.read(cx)
.entry_id(cx)
.and_then(|entry_id| {
self.project.update(cx, |project, cx| {
project.delete_entry(entry_id, false, cx)
})
})
.unwrap_or(Task::ready(Ok(())))
} else {
// Not sure how to disentangle edits made by the user
// from edits made by the AI at this point.
// For now, preserve both to avoid data loss.
//
// TODO: Better solution (disable "Reject" after user makes some
// edit or find a way to differentiate between AI and user edits)
Task::ready(Ok(()))
}
};
self.tracked_buffers.remove(&buffer);
@@ -1576,7 +1600,6 @@ mod tests {
project.find_project_path("dir/new_file", cx)
})
.unwrap();
let buffer = project
.update(cx, |project, cx| project.open_buffer(file_path, cx))
.await
@@ -1619,6 +1642,72 @@ mod tests {
assert_eq!(unreviewed_hunks(&action_log, cx), vec![]);
}
#[gpui::test]
async fn test_reject_created_file_with_user_edits(cx: &mut TestAppContext) {
init_test(cx);
let fs = FakeFs::new(cx.executor());
let project = Project::test(fs.clone(), [path!("/dir").as_ref()], cx).await;
let action_log = cx.new(|_| ActionLog::new(project.clone()));
let file_path = project
.read_with(cx, |project, cx| {
project.find_project_path("dir/new_file", cx)
})
.unwrap();
let buffer = project
.update(cx, |project, cx| project.open_buffer(file_path, cx))
.await
.unwrap();
// AI creates file with initial content
cx.update(|cx| {
action_log.update(cx, |log, cx| log.buffer_created(buffer.clone(), cx));
buffer.update(cx, |buffer, cx| buffer.set_text("ai content", cx));
action_log.update(cx, |log, cx| log.buffer_edited(buffer.clone(), cx));
});
project
.update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))
.await
.unwrap();
cx.run_until_parked();
// User makes additional edits
cx.update(|cx| {
buffer.update(cx, |buffer, cx| {
buffer.edit([(10..10, "\nuser added this line")], None, cx);
});
});
project
.update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))
.await
.unwrap();
assert!(fs.is_file(path!("/dir/new_file").as_ref()).await);
// Reject all
action_log
.update(cx, |log, cx| {
log.reject_edits_in_ranges(
buffer.clone(),
vec![Point::new(0, 0)..Point::new(100, 0)],
cx,
)
})
.await
.unwrap();
cx.run_until_parked();
// File should still contain all the content
assert!(fs.is_file(path!("/dir/new_file").as_ref()).await);
let content = buffer.read_with(cx, |buffer, _| buffer.text());
assert_eq!(content, "ai content\nuser added this line");
}
#[gpui::test(iterations = 100)]
async fn test_random_diffs(mut rng: StdRng, cx: &mut TestAppContext) {
init_test(cx);

View File

@@ -16,9 +16,9 @@ eval = []
[dependencies]
agent_settings.workspace = true
aho-corasick.workspace = true
anyhow.workspace = true
assistant_tool.workspace = true
async-watch.workspace = true
buffer_diff.workspace = true
chrono.workspace = true
collections.workspace = true
@@ -36,6 +36,7 @@ itertools.workspace = true
language.workspace = true
language_model.workspace = true
log.workspace = true
lsp.workspace = true
markdown.workspace = true
open.workspace = true
paths.workspace = true
@@ -64,6 +65,7 @@ workspace.workspace = true
zed_llm_client.workspace = true
[dev-dependencies]
lsp = { workspace = true, features = ["test-support"] }
client = { workspace = true, features = ["test-support"] }
clock = { workspace = true, features = ["test-support"] }
collections = { workspace = true, features = ["test-support"] }

View File

@@ -1,9 +0,0 @@
Invoke multiple other tool calls either sequentially or concurrently.
This tool is useful when you need to perform several operations at once, improving efficiency by reducing the number of back-and-forth interactions needed to complete complex tasks.
If the tool calls are set to be run sequentially, then each tool call within the batch is executed in the order provided. If it's set to run concurrently, then they may run in a different order. Regardless, all tool calls will have the same permissions and context as if they were called individually.
This tool should never be used to run a total of one tool. Instead, just run that one tool directly. You can run batches within batches if desired, which is a way you can mix concurrent and sequential tool call execution.
When it's possible to run tools in a batch, you should run as many as possible in the batch, up to a maximum of 32. For example, don't run multiple consecutive batches of 10 when you could instead run one batch of 30.

View File

@@ -1,19 +0,0 @@
A tool for applying code actions to specific sections of your code. It uses language servers to provide refactoring capabilities similar to what you'd find in an IDE.
This tool can:
- List all available code actions for a selected text range
- Execute a specific code action on that range
- Rename symbols across your codebase. This tool is the preferred way to rename things, and you should always prefer to rename code symbols using this tool rather than using textual find/replace when both are available.
Use this tool when you want to:
- Discover what code actions are available for a piece of code
- Apply automatic fixes and code transformations
- Rename variables, functions, or other symbols consistently throughout your project
- Clean up imports, implement interfaces, or perform other language-specific operations
- If unsure what actions are available, call the tool without specifying an action to get a list
- For common operations, you can directly specify actions like "quickfix.all" or "source.organizeImports"
- For renaming, use the special "textDocument/rename" action and provide the new name in the arguments field
- Be specific with your text range and context to ensure the tool identifies the correct code location
The tool will automatically save any changes it makes to your files.

View File

@@ -1,39 +0,0 @@
Returns either an outline of the public code symbols in the entire project (grouped by file) or else an outline of both the public and private code symbols within a particular file.
When a path is provided, this tool returns a hierarchical outline of code symbols for that specific file.
When no path is provided, it returns a list of all public code symbols in the project, organized by file.
You can also provide an optional regular expression which filters the output by only showing code symbols which match that regex.
Results are paginated with 2000 entries per page. Use the optional 'offset' parameter to request subsequent pages.
Markdown headings indicate the structure of the output; just like
with markdown headings, the more # symbols there are at the beginning of a line,
the deeper it is in the hierarchy.
Each code symbol entry ends with a line number or range, which tells you what portion of the
underlying source code file corresponds to that part of the outline. You can use
that line information with other tools, to strategically read portions of the source code.
For example, you can use this tool to find a relevant symbol in the project, then get the outline of the file which contains that symbol, then use the line number information from that file's outline to read different sections of that file, without having to read the entire file all at once (which can be slow, or use a lot of tokens).
<example>
# class Foo [L123-136]
## method do_something(arg1, arg2) [L124-126]
## method process_data(data) [L128-135]
# class Bar [L145-161]
## method initialize() [L146-149]
## method update_state(new_state) [L160]
## private method _validate_state(state) [L161-162]
</example>
This example shows how tree-sitter outlines the structure of source code:
1. `class Foo` is defined on lines 123-136
- It contains a method `do_something` spanning lines 124-126
- It also has a method `process_data` spanning lines 128-135
2. `class Bar` is defined on lines 145-161
- It has an `initialize` method spanning lines 146-149
- It has an `update_state` method on line 160
- It has a private method `_validate_state` spanning lines 161-162

View File

@@ -1,9 +0,0 @@
Reads the contents of a path on the filesystem.
If the path is a directory, this lists all files and directories within that path.
If the path is a file, this returns the file's contents.
When reading a file, if the file is too big and no line range is specified, an outline of the file's code symbols is listed instead, which can be used to request specific line ranges in a subsequent call.
Similarly, if a directory has too many entries to show at once, a subset of entries will be shown,
and subsequent requests can use starting and ending line numbers to get other subsets.

File diff suppressed because it is too large Load Diff

View File

@@ -4,7 +4,7 @@ use std::cell::LazyCell;
use util::debug_panic;
const START_MARKER: LazyCell<Regex> = LazyCell::new(|| Regex::new(r"\n?```\S*\n").unwrap());
const END_MARKER: LazyCell<Regex> = LazyCell::new(|| Regex::new(r"\n```\s*$").unwrap());
const END_MARKER: LazyCell<Regex> = LazyCell::new(|| Regex::new(r"(^|\n)```\s*$").unwrap());
#[derive(Debug)]
pub enum CreateFileParserEvent {
@@ -184,6 +184,22 @@ mod tests {
);
}
#[gpui::test(iterations = 10)]
fn test_empty_file(mut rng: StdRng) {
let mut parser = CreateFileParser::new();
assert_eq!(
parse_random_chunks(
indoc! {"
```
```
"},
&mut parser,
&mut rng
),
"".to_string()
);
}
fn parse_random_chunks(input: &str, parser: &mut CreateFileParser, rng: &mut StdRng) -> String {
let chunk_count = rng.gen_range(1..=cmp::min(input.len(), 50));
let mut chunk_indices = (0..input.len()).choose_multiple(rng, chunk_count);

View File

@@ -11,7 +11,7 @@ const END_TAGS: [&str; 3] = [OLD_TEXT_END_TAG, NEW_TEXT_END_TAG, EDITS_END_TAG];
#[derive(Debug)]
pub enum EditParserEvent {
OldText(String),
OldTextChunk { chunk: String, done: bool },
NewTextChunk { chunk: String, done: bool },
}
@@ -33,7 +33,7 @@ pub struct EditParser {
#[derive(Debug, PartialEq)]
enum EditParserState {
Pending,
WithinOldText,
WithinOldText { start: bool },
AfterOldText,
WithinNewText { start: bool },
}
@@ -56,20 +56,23 @@ impl EditParser {
EditParserState::Pending => {
if let Some(start) = self.buffer.find("<old_text>") {
self.buffer.drain(..start + "<old_text>".len());
self.state = EditParserState::WithinOldText;
self.state = EditParserState::WithinOldText { start: true };
} else {
break;
}
}
EditParserState::WithinOldText => {
if let Some(tag_range) = self.find_end_tag() {
let mut start = 0;
if self.buffer.starts_with('\n') {
start = 1;
EditParserState::WithinOldText { start } => {
if !self.buffer.is_empty() {
if *start && self.buffer.starts_with('\n') {
self.buffer.remove(0);
}
let mut old_text = self.buffer[start..tag_range.start].to_string();
if old_text.ends_with('\n') {
old_text.pop();
*start = false;
}
if let Some(tag_range) = self.find_end_tag() {
let mut chunk = self.buffer[..tag_range.start].to_string();
if chunk.ends_with('\n') {
chunk.pop();
}
self.metrics.tags += 1;
@@ -79,8 +82,14 @@ impl EditParser {
self.buffer.drain(..tag_range.end);
self.state = EditParserState::AfterOldText;
edit_events.push(EditParserEvent::OldText(old_text));
edit_events.push(EditParserEvent::OldTextChunk { chunk, done: true });
} else {
if !self.ends_with_tag_prefix() {
edit_events.push(EditParserEvent::OldTextChunk {
chunk: mem::take(&mut self.buffer),
done: false,
});
}
break;
}
}
@@ -115,11 +124,7 @@ impl EditParser {
self.state = EditParserState::Pending;
edit_events.push(EditParserEvent::NewTextChunk { chunk, done: true });
} else {
let mut end_prefixes = END_TAGS
.iter()
.flat_map(|tag| (1..tag.len()).map(move |i| &tag[..i]))
.chain(["\n"]);
if end_prefixes.all(|prefix| !self.buffer.ends_with(&prefix)) {
if !self.ends_with_tag_prefix() {
edit_events.push(EditParserEvent::NewTextChunk {
chunk: mem::take(&mut self.buffer),
done: false,
@@ -141,6 +146,14 @@ impl EditParser {
Some(start_ix..start_ix + tag.len())
}
fn ends_with_tag_prefix(&self) -> bool {
let mut end_prefixes = END_TAGS
.iter()
.flat_map(|tag| (1..tag.len()).map(move |i| &tag[..i]))
.chain(["\n"]);
end_prefixes.any(|prefix| self.buffer.ends_with(&prefix))
}
pub fn finish(self) -> EditParserMetrics {
self.metrics
}
@@ -412,20 +425,28 @@ mod tests {
chunk_indices.sort();
chunk_indices.push(input.len());
let mut old_text = Some(String::new());
let mut new_text = None;
let mut pending_edit = Edit::default();
let mut edits = Vec::new();
let mut last_ix = 0;
for chunk_ix in chunk_indices {
for event in parser.push(&input[last_ix..chunk_ix]) {
match event {
EditParserEvent::OldText(old_text) => {
pending_edit.old_text = old_text;
EditParserEvent::OldTextChunk { chunk, done } => {
old_text.as_mut().unwrap().push_str(&chunk);
if done {
pending_edit.old_text = old_text.take().unwrap();
new_text = Some(String::new());
}
}
EditParserEvent::NewTextChunk { chunk, done } => {
pending_edit.new_text.push_str(&chunk);
new_text.as_mut().unwrap().push_str(&chunk);
if done {
pending_edit.new_text = new_text.take().unwrap();
edits.push(pending_edit);
pending_edit = Edit::default();
old_text = Some(String::new());
}
}
}
@@ -433,8 +454,6 @@ mod tests {
last_ix = chunk_ix;
}
assert_eq!(pending_edit, Edit::default(), "unfinished edit");
edits
}
}

View File

@@ -0,0 +1,694 @@
use language::{Point, TextBufferSnapshot};
use std::{cmp, ops::Range};
const REPLACEMENT_COST: u32 = 1;
const INSERTION_COST: u32 = 3;
const DELETION_COST: u32 = 10;
/// A streaming fuzzy matcher that can process text chunks incrementally
/// and return the best match found so far at each step.
pub struct StreamingFuzzyMatcher {
snapshot: TextBufferSnapshot,
query_lines: Vec<String>,
incomplete_line: String,
best_match: Option<Range<usize>>,
matrix: SearchMatrix,
}
impl StreamingFuzzyMatcher {
pub fn new(snapshot: TextBufferSnapshot) -> Self {
let buffer_line_count = snapshot.max_point().row as usize + 1;
Self {
snapshot,
query_lines: Vec::new(),
incomplete_line: String::new(),
best_match: None,
matrix: SearchMatrix::new(buffer_line_count + 1),
}
}
/// Returns the query lines.
pub fn query_lines(&self) -> &[String] {
&self.query_lines
}
/// Push a new chunk of text and get the best match found so far.
///
/// This method accumulates text chunks and processes complete lines.
/// Partial lines are buffered internally until a newline is received.
///
/// # Returns
///
/// Returns `Some(range)` if a match has been found with the accumulated
/// query so far, or `None` if no suitable match exists yet.
pub fn push(&mut self, chunk: &str) -> Option<Range<usize>> {
// Add the chunk to our incomplete line buffer
self.incomplete_line.push_str(chunk);
if let Some((last_pos, _)) = self.incomplete_line.match_indices('\n').next_back() {
let complete_part = &self.incomplete_line[..=last_pos];
// Split into lines and add to query_lines
for line in complete_part.lines() {
self.query_lines.push(line.to_string());
}
self.incomplete_line.replace_range(..last_pos + 1, "");
self.best_match = self.resolve_location_fuzzy();
}
self.best_match.clone()
}
/// Finish processing and return the final best match.
///
/// This processes any remaining incomplete line before returning the final
/// match result.
pub fn finish(&mut self) -> Option<Range<usize>> {
// Process any remaining incomplete line
if !self.incomplete_line.is_empty() {
self.query_lines.push(self.incomplete_line.clone());
self.best_match = self.resolve_location_fuzzy();
}
self.best_match.clone()
}
fn resolve_location_fuzzy(&mut self) -> Option<Range<usize>> {
let new_query_line_count = self.query_lines.len();
let old_query_line_count = self.matrix.rows.saturating_sub(1);
if new_query_line_count == old_query_line_count {
return None;
}
self.matrix.resize_rows(new_query_line_count + 1);
// Process only the new query lines
for row in old_query_line_count..new_query_line_count {
let query_line = self.query_lines[row].trim();
let leading_deletion_cost = (row + 1) as u32 * DELETION_COST;
self.matrix.set(
row + 1,
0,
SearchState::new(leading_deletion_cost, SearchDirection::Up),
);
let mut buffer_lines = self.snapshot.as_rope().chunks().lines();
let mut col = 0;
while let Some(buffer_line) = buffer_lines.next() {
let buffer_line = buffer_line.trim();
let up = SearchState::new(
self.matrix
.get(row, col + 1)
.cost
.saturating_add(DELETION_COST),
SearchDirection::Up,
);
let left = SearchState::new(
self.matrix
.get(row + 1, col)
.cost
.saturating_add(INSERTION_COST),
SearchDirection::Left,
);
let diagonal = SearchState::new(
if query_line == buffer_line {
self.matrix.get(row, col).cost
} else if fuzzy_eq(query_line, buffer_line) {
self.matrix.get(row, col).cost + REPLACEMENT_COST
} else {
self.matrix
.get(row, col)
.cost
.saturating_add(DELETION_COST + INSERTION_COST)
},
SearchDirection::Diagonal,
);
self.matrix
.set(row + 1, col + 1, up.min(left).min(diagonal));
col += 1;
}
}
// Traceback to find the best match
let buffer_line_count = self.snapshot.max_point().row as usize + 1;
let mut buffer_row_end = buffer_line_count as u32;
let mut best_cost = u32::MAX;
for col in 1..=buffer_line_count {
let cost = self.matrix.get(new_query_line_count, col).cost;
if cost < best_cost {
best_cost = cost;
buffer_row_end = col as u32;
}
}
let mut matched_lines = 0;
let mut query_row = new_query_line_count;
let mut buffer_row_start = buffer_row_end;
while query_row > 0 && buffer_row_start > 0 {
let current = self.matrix.get(query_row, buffer_row_start as usize);
match current.direction {
SearchDirection::Diagonal => {
query_row -= 1;
buffer_row_start -= 1;
matched_lines += 1;
}
SearchDirection::Up => {
query_row -= 1;
}
SearchDirection::Left => {
buffer_row_start -= 1;
}
}
}
let matched_buffer_row_count = buffer_row_end - buffer_row_start;
let matched_ratio = matched_lines as f32
/ (matched_buffer_row_count as f32).max(new_query_line_count as f32);
if matched_ratio >= 0.8 {
let buffer_start_ix = self
.snapshot
.point_to_offset(Point::new(buffer_row_start, 0));
let buffer_end_ix = self.snapshot.point_to_offset(Point::new(
buffer_row_end - 1,
self.snapshot.line_len(buffer_row_end - 1),
));
Some(buffer_start_ix..buffer_end_ix)
} else {
None
}
}
}
fn fuzzy_eq(left: &str, right: &str) -> bool {
const THRESHOLD: f64 = 0.8;
let min_levenshtein = left.len().abs_diff(right.len());
let min_normalized_levenshtein =
1. - (min_levenshtein as f64 / cmp::max(left.len(), right.len()) as f64);
if min_normalized_levenshtein < THRESHOLD {
return false;
}
strsim::normalized_levenshtein(left, right) >= THRESHOLD
}
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)]
enum SearchDirection {
Up,
Left,
Diagonal,
}
#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord)]
struct SearchState {
cost: u32,
direction: SearchDirection,
}
impl SearchState {
fn new(cost: u32, direction: SearchDirection) -> Self {
Self { cost, direction }
}
}
struct SearchMatrix {
cols: usize,
rows: usize,
data: Vec<SearchState>,
}
impl SearchMatrix {
fn new(cols: usize) -> Self {
SearchMatrix {
cols,
rows: 0,
data: Vec::new(),
}
}
fn resize_rows(&mut self, needed_rows: usize) {
debug_assert!(needed_rows > self.rows);
self.rows = needed_rows;
self.data.resize(
self.rows * self.cols,
SearchState::new(0, SearchDirection::Diagonal),
);
}
fn get(&self, row: usize, col: usize) -> SearchState {
debug_assert!(row < self.rows && col < self.cols);
self.data[row * self.cols + col]
}
fn set(&mut self, row: usize, col: usize, state: SearchState) {
debug_assert!(row < self.rows && col < self.cols);
self.data[row * self.cols + col] = state;
}
}
#[cfg(test)]
mod tests {
use super::*;
use indoc::indoc;
use language::{BufferId, TextBuffer};
use rand::prelude::*;
use util::test::{generate_marked_text, marked_text_ranges};
#[test]
fn test_empty_query() {
let buffer = TextBuffer::new(
0,
BufferId::new(1).unwrap(),
"Hello world\nThis is a test\nFoo bar baz",
);
let snapshot = buffer.snapshot();
let mut finder = StreamingFuzzyMatcher::new(snapshot.clone());
assert_eq!(push(&mut finder, ""), None);
assert_eq!(finish(finder), None);
}
#[test]
fn test_streaming_exact_match() {
let buffer = TextBuffer::new(
0,
BufferId::new(1).unwrap(),
"Hello world\nThis is a test\nFoo bar baz",
);
let snapshot = buffer.snapshot();
let mut finder = StreamingFuzzyMatcher::new(snapshot.clone());
// Push partial query
assert_eq!(push(&mut finder, "This"), None);
// Complete the line
assert_eq!(
push(&mut finder, " is a test\n"),
Some("This is a test".to_string())
);
// Finish should return the same result
assert_eq!(finish(finder), Some("This is a test".to_string()));
}
#[test]
fn test_streaming_fuzzy_match() {
let buffer = TextBuffer::new(
0,
BufferId::new(1).unwrap(),
indoc! {"
function foo(a, b) {
return a + b;
}
function bar(x, y) {
return x * y;
}
"},
);
let snapshot = buffer.snapshot();
let mut finder = StreamingFuzzyMatcher::new(snapshot.clone());
// Push a fuzzy query that should match the first function
assert_eq!(
push(&mut finder, "function foo(a, c) {\n").as_deref(),
Some("function foo(a, b) {")
);
assert_eq!(
push(&mut finder, " return a + c;\n}\n").as_deref(),
Some(concat!(
"function foo(a, b) {\n",
" return a + b;\n",
"}"
))
);
}
#[test]
fn test_incremental_improvement() {
let buffer = TextBuffer::new(
0,
BufferId::new(1).unwrap(),
"Line 1\nLine 2\nLine 3\nLine 4\nLine 5",
);
let snapshot = buffer.snapshot();
let mut finder = StreamingFuzzyMatcher::new(snapshot.clone());
// No match initially
assert_eq!(push(&mut finder, "Lin"), None);
// Get a match when we complete a line
assert_eq!(push(&mut finder, "e 3\n"), Some("Line 3".to_string()));
// The match might change if we add more specific content
assert_eq!(
push(&mut finder, "Line 4\n"),
Some("Line 3\nLine 4".to_string())
);
assert_eq!(finish(finder), Some("Line 3\nLine 4".to_string()));
}
#[test]
fn test_incomplete_lines_buffering() {
let buffer = TextBuffer::new(
0,
BufferId::new(1).unwrap(),
indoc! {"
The quick brown fox
jumps over the lazy dog
Pack my box with five dozen liquor jugs
"},
);
let snapshot = buffer.snapshot();
let mut finder = StreamingFuzzyMatcher::new(snapshot.clone());
// Push text in small chunks across line boundaries
assert_eq!(push(&mut finder, "jumps "), None); // No newline yet
assert_eq!(push(&mut finder, "over the"), None); // Still no newline
assert_eq!(push(&mut finder, " lazy"), None); // Still incomplete
// Complete the line
assert_eq!(
push(&mut finder, " dog\n"),
Some("jumps over the lazy dog".to_string())
);
}
#[test]
fn test_multiline_fuzzy_match() {
let buffer = TextBuffer::new(
0,
BufferId::new(1).unwrap(),
indoc! {r#"
impl Display for User {
fn fmt(&self, f: &mut Formatter) -> fmt::Result {
write!(f, "User: {} ({})", self.name, self.email)
}
}
impl Debug for User {
fn fmt(&self, f: &mut Formatter) -> fmt::Result {
f.debug_struct("User")
.field("name", &self.name)
.field("email", &self.email)
.finish()
}
}
"#},
);
let snapshot = buffer.snapshot();
let mut finder = StreamingFuzzyMatcher::new(snapshot.clone());
assert_eq!(
push(&mut finder, "impl Debug for User {\n"),
Some("impl Debug for User {".to_string())
);
assert_eq!(
push(
&mut finder,
" fn fmt(&self, f: &mut Formatter) -> Result {\n"
)
.as_deref(),
Some(concat!(
"impl Debug for User {\n",
" fn fmt(&self, f: &mut Formatter) -> fmt::Result {"
))
);
assert_eq!(
push(&mut finder, " f.debug_struct(\"User\")\n").as_deref(),
Some(concat!(
"impl Debug for User {\n",
" fn fmt(&self, f: &mut Formatter) -> fmt::Result {\n",
" f.debug_struct(\"User\")"
))
);
assert_eq!(
push(
&mut finder,
" .field(\"name\", &self.username)\n"
)
.as_deref(),
Some(concat!(
"impl Debug for User {\n",
" fn fmt(&self, f: &mut Formatter) -> fmt::Result {\n",
" f.debug_struct(\"User\")\n",
" .field(\"name\", &self.name)"
))
);
assert_eq!(
finish(finder).as_deref(),
Some(concat!(
"impl Debug for User {\n",
" fn fmt(&self, f: &mut Formatter) -> fmt::Result {\n",
" f.debug_struct(\"User\")\n",
" .field(\"name\", &self.name)"
))
);
}
#[gpui::test(iterations = 100)]
fn test_resolve_location_single_line(mut rng: StdRng) {
assert_location_resolution(
concat!(
" Lorem\n",
"« ipsum»\n",
" dolor sit amet\n",
" consecteur",
),
"ipsum",
&mut rng,
);
}
#[gpui::test(iterations = 100)]
fn test_resolve_location_multiline(mut rng: StdRng) {
assert_location_resolution(
concat!(
" Lorem\n",
"« ipsum\n",
" dolor sit amet»\n",
" consecteur",
),
"ipsum\ndolor sit amet",
&mut rng,
);
}
#[gpui::test(iterations = 100)]
fn test_resolve_location_function_with_typo(mut rng: StdRng) {
assert_location_resolution(
indoc! {"
«fn foo1(a: usize) -> usize {
40
fn foo2(b: usize) -> usize {
42
}
"},
"fn foo1(a: usize) -> u32 {\n40\n}",
&mut rng,
);
}
#[gpui::test(iterations = 100)]
fn test_resolve_location_class_methods(mut rng: StdRng) {
assert_location_resolution(
indoc! {"
class Something {
one() { return 1; }
« two() { return 2222; }
three() { return 333; }
four() { return 4444; }
five() { return 5555; }
six() { return 6666; }»
seven() { return 7; }
eight() { return 8; }
}
"},
indoc! {"
two() { return 2222; }
four() { return 4444; }
five() { return 5555; }
six() { return 6666; }
"},
&mut rng,
);
}
#[gpui::test(iterations = 100)]
fn test_resolve_location_imports_no_match(mut rng: StdRng) {
assert_location_resolution(
indoc! {"
use std::ops::Range;
use std::sync::Mutex;
use std::{
collections::HashMap,
env,
ffi::{OsStr, OsString},
fs,
io::{BufRead, BufReader},
mem,
path::{Path, PathBuf},
process::Command,
sync::LazyLock,
time::SystemTime,
};
"},
indoc! {"
use std::collections::{HashMap, HashSet};
use std::ffi::{OsStr, OsString};
use std::fmt::Write as _;
use std::fs;
use std::io::{BufReader, Read, Write};
use std::mem;
use std::path::{Path, PathBuf};
use std::process::Command;
use std::sync::Arc;
"},
&mut rng,
);
}
#[gpui::test(iterations = 100)]
fn test_resolve_location_nested_closure(mut rng: StdRng) {
assert_location_resolution(
indoc! {"
impl Foo {
fn new() -> Self {
Self {
subscriptions: vec![
cx.observe_window_activation(window, |editor, window, cx| {
let active = window.is_window_active();
editor.blink_manager.update(cx, |blink_manager, cx| {
if active {
blink_manager.enable(cx);
} else {
blink_manager.disable(cx);
}
});
}),
];
}
}
}
"},
concat!(
" editor.blink_manager.update(cx, |blink_manager, cx| {\n",
" blink_manager.enable(cx);\n",
" });",
),
&mut rng,
);
}
#[gpui::test(iterations = 100)]
fn test_resolve_location_tool_invocation(mut rng: StdRng) {
assert_location_resolution(
indoc! {r#"
let tool = cx
.update(|cx| working_set.tool(&tool_name, cx))
.map_err(|err| {
anyhow!("Failed to look up tool '{}': {}", tool_name, err)
})?;
let Some(tool) = tool else {
return Err(anyhow!("Tool '{}' not found", tool_name));
};
let project = project.clone();
let action_log = action_log.clone();
let messages = messages.clone();
let tool_result = cx
.update(|cx| tool.run(invocation.input, &messages, project, action_log, cx))
.map_err(|err| anyhow!("Failed to start tool '{}': {}", tool_name, err))?;
tasks.push(tool_result.output);
"#},
concat!(
"let tool_result = cx\n",
" .update(|cx| tool.run(invocation.input, &messages, project, action_log, cx))\n",
" .output;",
),
&mut rng,
);
}
#[track_caller]
fn assert_location_resolution(text_with_expected_range: &str, query: &str, rng: &mut StdRng) {
let (text, expected_ranges) = marked_text_ranges(text_with_expected_range, false);
let buffer = TextBuffer::new(0, BufferId::new(1).unwrap(), text.clone());
let snapshot = buffer.snapshot();
let mut matcher = StreamingFuzzyMatcher::new(snapshot.clone());
// Split query into random chunks
let chunks = to_random_chunks(rng, query);
// Push chunks incrementally
for chunk in &chunks {
matcher.push(chunk);
}
let result = matcher.finish();
// If no expected ranges, we expect no match
if expected_ranges.is_empty() {
assert_eq!(
result, None,
"Expected no match for query: {:?}, but found: {:?}",
query, result
);
} else {
let mut actual_ranges = Vec::new();
if let Some(range) = result {
actual_ranges.push(range);
}
let text_with_actual_range = generate_marked_text(&text, &actual_ranges, false);
pretty_assertions::assert_eq!(
text_with_actual_range,
text_with_expected_range,
"Query: {:?}, Chunks: {:?}",
query,
chunks
);
}
}
fn to_random_chunks(rng: &mut StdRng, input: &str) -> Vec<String> {
let chunk_count = rng.gen_range(1..=cmp::min(input.len(), 50));
let mut chunk_indices = (0..input.len()).choose_multiple(rng, chunk_count);
chunk_indices.sort();
chunk_indices.push(input.len());
let mut chunks = Vec::new();
let mut last_ix = 0;
for chunk_ix in chunk_indices {
chunks.push(input[last_ix..chunk_ix].to_string());
last_ix = chunk_ix;
}
chunks
}
fn push(finder: &mut StreamingFuzzyMatcher, chunk: &str) -> Option<String> {
finder
.push(chunk)
.map(|range| finder.snapshot.text_for_range(range).collect::<String>())
}
fn finish(mut finder: StreamingFuzzyMatcher) -> Option<String> {
let snapshot = finder.snapshot.clone();
finder
.finish()
.map(|range| snapshot.text_for_range(range).collect::<String>())
}
}

View File

@@ -12,21 +12,28 @@ use buffer_diff::{BufferDiff, BufferDiffSnapshot};
use editor::{Editor, EditorMode, MinimapVisibility, MultiBuffer, PathKey};
use futures::StreamExt;
use gpui::{
Animation, AnimationExt, AnyWindowHandle, App, AppContext, AsyncApp, Entity, EntityId, Task,
Animation, AnimationExt, AnyWindowHandle, App, AppContext, AsyncApp, Entity, Task,
TextStyleRefinement, WeakEntity, pulsating_between,
};
use indoc::formatdoc;
use language::{
Anchor, Buffer, Capability, LanguageRegistry, LineEnding, OffsetRangeExt, Rope, TextBuffer,
language_settings::SoftWrap,
Anchor, Buffer, Capability, LanguageRegistry, LineEnding, OffsetRangeExt, Point, Rope,
TextBuffer,
language_settings::{self, FormatOnSave, SoftWrap},
};
use language_model::{LanguageModel, LanguageModelRequest, LanguageModelToolSchemaFormat};
use markdown::{Markdown, MarkdownElement, MarkdownStyle};
use project::{Project, ProjectPath};
use project::{
Project, ProjectPath,
lsp_store::{FormatTrigger, LspFormatTarget},
};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use settings::Settings;
use std::{
cmp::Reverse,
collections::HashSet,
ops::Range,
path::{Path, PathBuf},
sync::Arc,
time::Duration,
@@ -98,7 +105,7 @@ pub enum EditFileMode {
pub struct EditFileToolOutput {
pub original_path: PathBuf,
pub new_text: String,
pub old_text: String,
pub old_text: Arc<String>,
pub raw_output: Option<EditAgentOutput>,
}
@@ -187,8 +194,10 @@ impl Tool for EditFileTool {
});
let card_clone = card.clone();
let action_log_clone = action_log.clone();
let task = cx.spawn(async move |cx: &mut AsyncApp| {
let edit_agent = EditAgent::new(model, project.clone(), action_log, Templates::new());
let edit_agent =
EditAgent::new(model, project.clone(), action_log_clone, Templates::new());
let buffer = project
.update(cx, |project, cx| {
@@ -200,10 +209,14 @@ impl Tool for EditFileTool {
let old_text = cx
.background_spawn({
let old_snapshot = old_snapshot.clone();
async move { old_snapshot.text() }
async move { Arc::new(old_snapshot.text()) }
})
.await;
if let Some(card) = card_clone.as_ref() {
card.update(cx, |card, cx| card.initialize(buffer.clone(), cx))?;
}
let (output, mut events) = if matches!(input.mode, EditFileMode::Edit) {
edit_agent.edit(
buffer.clone(),
@@ -225,54 +238,78 @@ impl Tool for EditFileTool {
match event {
EditAgentOutputEvent::Edited => {
if let Some(card) = card_clone.as_ref() {
let new_snapshot =
buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
let new_text = cx
.background_spawn({
let new_snapshot = new_snapshot.clone();
async move { new_snapshot.text() }
})
.await;
card.update(cx, |card, cx| {
card.set_diff(
project_path.path.clone(),
old_text.clone(),
new_text,
cx,
);
})
.log_err();
card.update(cx, |card, cx| card.update_diff(cx))?;
}
}
EditAgentOutputEvent::UnresolvedEditRange => hallucinated_old_text = true,
EditAgentOutputEvent::ResolvingEditRange(range) => {
if let Some(card) = card_clone.as_ref() {
card.update(cx, |card, cx| card.reveal_range(range, cx))?;
}
}
EditAgentOutputEvent::OldTextNotFound(_) => hallucinated_old_text = true,
}
}
let agent_output = output.await?;
// If format_on_save is enabled, format the buffer
let format_on_save_enabled = buffer
.read_with(cx, |buffer, cx| {
let settings = language_settings::language_settings(
buffer.language().map(|l| l.name()),
buffer.file(),
cx,
);
!matches!(settings.format_on_save, FormatOnSave::Off)
})
.unwrap_or(false);
if format_on_save_enabled {
let format_task = project.update(cx, |project, cx| {
project.format(
HashSet::from_iter([buffer.clone()]),
LspFormatTarget::Buffers,
false, // Don't push to history since the tool did it.
FormatTrigger::Save,
cx,
)
})?;
format_task.await.log_err();
}
project
.update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))?
.await?;
// Notify the action log that we've edited the buffer (*after* formatting has completed).
action_log.update(cx, |log, cx| {
log.buffer_edited(buffer.clone(), cx);
})?;
let new_snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
let new_text = cx.background_spawn({
let new_snapshot = new_snapshot.clone();
async move { new_snapshot.text() }
});
let diff = cx.background_spawn(async move {
language::unified_diff(&old_snapshot.text(), &new_snapshot.text())
});
let (new_text, diff) = futures::join!(new_text, diff);
let (new_text, diff) = cx
.background_spawn({
let new_snapshot = new_snapshot.clone();
let old_text = old_text.clone();
async move {
let new_text = new_snapshot.text();
let diff = language::unified_diff(&old_text, &new_text);
(new_text, diff)
}
})
.await;
let output = EditFileToolOutput {
original_path: project_path.path.to_path_buf(),
new_text: new_text.clone(),
old_text: old_text.clone(),
old_text,
raw_output: Some(agent_output),
};
if let Some(card) = card_clone {
card.update(cx, |card, cx| {
card.set_diff(project_path.path.clone(), old_text, new_text, cx);
card.update_diff(cx);
card.finalize(cx)
})
.log_err();
}
@@ -282,12 +319,15 @@ impl Tool for EditFileTool {
anyhow::ensure!(
!hallucinated_old_text,
formatdoc! {"
Some edits were produced but none of them could be applied.
Read the relevant sections of {input_path} again so that
I can perform the requested edits.
"}
Some edits were produced but none of them could be applied.
Read the relevant sections of {input_path} again so that
I can perform the requested edits.
"}
);
Ok("No edits were made.".to_string().into())
Ok(ToolResultOutput {
content: ToolResultContent::Text("No edits were made.".into()),
output: serde_json::to_value(output).ok(),
})
} else {
Ok(ToolResultOutput {
content: ToolResultContent::Text(format!(
@@ -318,16 +358,48 @@ impl Tool for EditFileTool {
};
let card = cx.new(|cx| {
let mut card = EditFileToolCard::new(output.original_path.clone(), project, window, cx);
card.set_diff(
output.original_path.into(),
output.old_text,
output.new_text,
cx,
);
card
EditFileToolCard::new(output.original_path.clone(), project.clone(), window, cx)
});
cx.spawn({
let path: Arc<Path> = output.original_path.into();
let language_registry = project.read(cx).languages().clone();
let card = card.clone();
async move |cx| {
let buffer =
build_buffer(output.new_text, path.clone(), &language_registry, cx).await?;
let buffer_diff =
build_buffer_diff(output.old_text.clone(), &buffer, &language_registry, cx)
.await?;
card.update(cx, |card, cx| {
card.multibuffer.update(cx, |multibuffer, cx| {
let snapshot = buffer.read(cx).snapshot();
let diff = buffer_diff.read(cx);
let diff_hunk_ranges = diff
.hunks_intersecting_range(Anchor::MIN..Anchor::MAX, &snapshot, cx)
.map(|diff_hunk| diff_hunk.buffer_range.to_point(&snapshot))
.collect::<Vec<_>>();
multibuffer.set_excerpts_for_path(
PathKey::for_buffer(&buffer, cx),
buffer,
diff_hunk_ranges,
editor::DEFAULT_MULTIBUFFER_CONTEXT,
cx,
);
multibuffer.add_diff(buffer_diff, cx);
let end = multibuffer.len(cx);
card.total_lines =
Some(multibuffer.snapshot(cx).offset_to_point(end).row + 1);
});
cx.notify();
})?;
anyhow::Ok(())
}
})
.detach_and_log_err(cx);
Some(card.into())
}
}
@@ -402,12 +474,15 @@ pub struct EditFileToolCard {
editor: Entity<Editor>,
multibuffer: Entity<MultiBuffer>,
project: Entity<Project>,
buffer: Option<Entity<Buffer>>,
base_text: Option<Arc<String>>,
buffer_diff: Option<Entity<BufferDiff>>,
revealed_ranges: Vec<Range<Anchor>>,
diff_task: Option<Task<Result<()>>>,
preview_expanded: bool,
error_expanded: Option<Entity<Markdown>>,
full_height_expanded: bool,
total_lines: Option<u32>,
editor_unique_id: EntityId,
}
impl EditFileToolCard {
@@ -442,11 +517,14 @@ impl EditFileToolCard {
editor
});
Self {
editor_unique_id: editor.entity_id(),
path,
project,
editor,
multibuffer,
buffer: None,
base_text: None,
buffer_diff: None,
revealed_ranges: Vec::new(),
diff_task: None,
preview_expanded: true,
error_expanded: None,
@@ -455,46 +533,184 @@ impl EditFileToolCard {
}
}
pub fn has_diff(&self) -> bool {
self.total_lines.is_some()
pub fn initialize(&mut self, buffer: Entity<Buffer>, cx: &mut App) {
let buffer_snapshot = buffer.read(cx).snapshot();
let base_text = buffer_snapshot.text();
let language_registry = buffer.read(cx).language_registry();
let text_snapshot = buffer.read(cx).text_snapshot();
// Create a buffer diff with the current text as the base
let buffer_diff = cx.new(|cx| {
let mut diff = BufferDiff::new(&text_snapshot, cx);
let _ = diff.set_base_text(
buffer_snapshot.clone(),
language_registry,
text_snapshot,
cx,
);
diff
});
self.buffer = Some(buffer.clone());
self.base_text = Some(base_text.into());
self.buffer_diff = Some(buffer_diff.clone());
// Add the diff to the multibuffer
self.multibuffer
.update(cx, |multibuffer, cx| multibuffer.add_diff(buffer_diff, cx));
}
pub fn set_diff(
&mut self,
path: Arc<Path>,
old_text: String,
new_text: String,
cx: &mut Context<Self>,
) {
let language_registry = self.project.read(cx).languages().clone();
self.diff_task = Some(cx.spawn(async move |this, cx| {
let buffer = build_buffer(new_text, path.clone(), &language_registry, cx).await?;
let buffer_diff = build_buffer_diff(old_text, &buffer, &language_registry, cx).await?;
pub fn is_loading(&self) -> bool {
self.total_lines.is_none()
}
pub fn update_diff(&mut self, cx: &mut Context<Self>) {
let Some(buffer) = self.buffer.as_ref() else {
return;
};
let Some(buffer_diff) = self.buffer_diff.as_ref() else {
return;
};
let buffer = buffer.clone();
let buffer_diff = buffer_diff.clone();
let base_text = self.base_text.clone();
self.diff_task = Some(cx.spawn(async move |this, cx| {
let text_snapshot = buffer.read_with(cx, |buffer, _| buffer.text_snapshot())?;
let diff_snapshot = BufferDiff::update_diff(
buffer_diff.clone(),
text_snapshot.clone(),
base_text,
false,
false,
None,
None,
cx,
)
.await?;
buffer_diff.update(cx, |diff, cx| {
diff.set_snapshot(diff_snapshot, &text_snapshot, cx)
})?;
this.update(cx, |this, cx| this.update_visible_ranges(cx))
}));
}
pub fn reveal_range(&mut self, range: Range<Anchor>, cx: &mut Context<Self>) {
self.revealed_ranges.push(range);
self.update_visible_ranges(cx);
}
fn update_visible_ranges(&mut self, cx: &mut Context<Self>) {
let Some(buffer) = self.buffer.as_ref() else {
return;
};
let ranges = self.excerpt_ranges(cx);
self.total_lines = self.multibuffer.update(cx, |multibuffer, cx| {
multibuffer.set_excerpts_for_path(
PathKey::for_buffer(buffer, cx),
buffer.clone(),
ranges,
editor::DEFAULT_MULTIBUFFER_CONTEXT,
cx,
);
let end = multibuffer.len(cx);
Some(multibuffer.snapshot(cx).offset_to_point(end).row + 1)
});
cx.notify();
}
fn excerpt_ranges(&self, cx: &App) -> Vec<Range<Point>> {
let Some(buffer) = self.buffer.as_ref() else {
return Vec::new();
};
let Some(diff) = self.buffer_diff.as_ref() else {
return Vec::new();
};
let buffer = buffer.read(cx);
let diff = diff.read(cx);
let mut ranges = diff
.hunks_intersecting_range(Anchor::MIN..Anchor::MAX, &buffer, cx)
.map(|diff_hunk| diff_hunk.buffer_range.to_point(&buffer))
.collect::<Vec<_>>();
ranges.extend(
self.revealed_ranges
.iter()
.map(|range| range.to_point(&buffer)),
);
ranges.sort_unstable_by_key(|range| (range.start, Reverse(range.end)));
// Merge adjacent ranges
let mut ranges = ranges.into_iter().peekable();
let mut merged_ranges = Vec::new();
while let Some(mut range) = ranges.next() {
while let Some(next_range) = ranges.peek() {
if range.end >= next_range.start {
range.end = range.end.max(next_range.end);
ranges.next();
} else {
break;
}
}
merged_ranges.push(range);
}
merged_ranges
}
pub fn finalize(&mut self, cx: &mut Context<Self>) -> Result<()> {
let ranges = self.excerpt_ranges(cx);
let buffer = self.buffer.take().context("card was already finalized")?;
let base_text = self
.base_text
.take()
.context("card was already finalized")?;
let language_registry = self.project.read(cx).languages().clone();
// Replace the buffer in the multibuffer with the snapshot
let buffer = cx.new(|cx| {
let language = buffer.read(cx).language().cloned();
let buffer = TextBuffer::new_normalized(
0,
cx.entity_id().as_non_zero_u64().into(),
buffer.read(cx).line_ending(),
buffer.read(cx).as_rope().clone(),
);
let mut buffer = Buffer::build(buffer, None, Capability::ReadWrite);
buffer.set_language(language, cx);
buffer
});
let buffer_diff = cx.spawn({
let buffer = buffer.clone();
let language_registry = language_registry.clone();
async move |_this, cx| {
build_buffer_diff(base_text, &buffer, &language_registry, cx).await
}
});
cx.spawn(async move |this, cx| {
let buffer_diff = buffer_diff.await?;
this.update(cx, |this, cx| {
this.total_lines = this.multibuffer.update(cx, |multibuffer, cx| {
let snapshot = buffer.read(cx).snapshot();
let diff = buffer_diff.read(cx);
let diff_hunk_ranges = diff
.hunks_intersecting_range(Anchor::MIN..Anchor::MAX, &snapshot, cx)
.map(|diff_hunk| diff_hunk.buffer_range.to_point(&snapshot))
.collect::<Vec<_>>();
this.multibuffer.update(cx, |multibuffer, cx| {
let path_key = PathKey::for_buffer(&buffer, cx);
multibuffer.clear(cx);
multibuffer.set_excerpts_for_path(
PathKey::for_buffer(&buffer, cx),
path_key,
buffer,
diff_hunk_ranges,
ranges,
editor::DEFAULT_MULTIBUFFER_CONTEXT,
cx,
);
multibuffer.add_diff(buffer_diff, cx);
let end = multibuffer.len(cx);
Some(multibuffer.snapshot(cx).offset_to_point(end).row + 1)
multibuffer.add_diff(buffer_diff.clone(), cx);
});
cx.notify();
})
}));
})
.detach_and_log_err(cx);
Ok(())
}
}
@@ -512,7 +728,7 @@ impl ToolCard for EditFileToolCard {
};
let path_label_button = h_flex()
.id(("edit-tool-path-label-button", self.editor_unique_id))
.id(("edit-tool-path-label-button", self.editor.entity_id()))
.w_full()
.max_w_full()
.px_1()
@@ -611,7 +827,7 @@ impl ToolCard for EditFileToolCard {
)
.child(
Disclosure::new(
("edit-file-error-disclosure", self.editor_unique_id),
("edit-file-error-disclosure", self.editor.entity_id()),
self.error_expanded.is_some(),
)
.opened_icon(IconName::ChevronUp)
@@ -633,10 +849,10 @@ impl ToolCard for EditFileToolCard {
),
)
})
.when(error_message.is_none() && self.has_diff(), |header| {
.when(error_message.is_none() && !self.is_loading(), |header| {
header.child(
Disclosure::new(
("edit-file-disclosure", self.editor_unique_id),
("edit-file-disclosure", self.editor.entity_id()),
self.preview_expanded,
)
.opened_icon(IconName::ChevronUp)
@@ -772,10 +988,10 @@ impl ToolCard for EditFileToolCard {
),
)
})
.when(!self.has_diff() && error_message.is_none(), |card| {
.when(self.is_loading() && error_message.is_none(), |card| {
card.child(waiting_for_diff)
})
.when(self.preview_expanded && self.has_diff(), |card| {
.when(self.preview_expanded && !self.is_loading(), |card| {
card.child(
v_flex()
.relative()
@@ -797,7 +1013,7 @@ impl ToolCard for EditFileToolCard {
.when(is_collapsible, |card| {
card.child(
h_flex()
.id(("expand-button", self.editor_unique_id))
.id(("expand-button", self.editor.entity_id()))
.flex_none()
.cursor_pointer()
.h_5()
@@ -871,19 +1087,23 @@ async fn build_buffer(
}
async fn build_buffer_diff(
mut old_text: String,
old_text: Arc<String>,
buffer: &Entity<Buffer>,
language_registry: &Arc<LanguageRegistry>,
cx: &mut AsyncApp,
) -> Result<Entity<BufferDiff>> {
LineEnding::normalize(&mut old_text);
let buffer = cx.update(|cx| buffer.read(cx).snapshot())?;
let old_text_rope = cx
.background_spawn({
let old_text = old_text.clone();
async move { Rope::from(old_text.as_str()) }
})
.await;
let base_buffer = cx
.update(|cx| {
Buffer::build_snapshot(
old_text.clone().into(),
old_text_rope,
buffer.language().cloned(),
Some(language_registry.clone()),
cx,
@@ -895,7 +1115,7 @@ async fn build_buffer_diff(
.update(|cx| {
BufferDiffSnapshot::new_with_base_buffer(
buffer.text.clone(),
Some(old_text.into()),
Some(old_text),
base_buffer,
cx,
)
@@ -920,8 +1140,8 @@ async fn build_buffer_diff(
mod tests {
use super::*;
use client::TelemetrySettings;
use fs::FakeFs;
use gpui::TestAppContext;
use fs::{FakeFs, Fs};
use gpui::{TestAppContext, UpdateGlobal};
use language_model::fake_provider::FakeLanguageModel;
use serde_json::json;
use settings::SettingsStore;
@@ -1131,4 +1351,340 @@ mod tests {
Project::init_settings(cx);
});
}
#[gpui::test]
async fn test_format_on_save(cx: &mut TestAppContext) {
init_test(cx);
let fs = FakeFs::new(cx.executor());
fs.insert_tree("/root", json!({"src": {}})).await;
let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
// Set up a Rust language with LSP formatting support
let rust_language = Arc::new(language::Language::new(
language::LanguageConfig {
name: "Rust".into(),
matcher: language::LanguageMatcher {
path_suffixes: vec!["rs".to_string()],
..Default::default()
},
..Default::default()
},
None,
));
// Register the language and fake LSP
let language_registry = project.read_with(cx, |project, _| project.languages().clone());
language_registry.add(rust_language);
let mut fake_language_servers = language_registry.register_fake_lsp(
"Rust",
language::FakeLspAdapter {
capabilities: lsp::ServerCapabilities {
document_formatting_provider: Some(lsp::OneOf::Left(true)),
..Default::default()
},
..Default::default()
},
);
// Create the file
fs.save(
path!("/root/src/main.rs").as_ref(),
&"initial content".into(),
language::LineEnding::Unix,
)
.await
.unwrap();
// Open the buffer to trigger LSP initialization
let buffer = project
.update(cx, |project, cx| {
project.open_local_buffer(path!("/root/src/main.rs"), cx)
})
.await
.unwrap();
// Register the buffer with language servers
let _handle = project.update(cx, |project, cx| {
project.register_buffer_with_language_servers(&buffer, cx)
});
const UNFORMATTED_CONTENT: &str = "fn main() {println!(\"Hello!\");}\n";
const FORMATTED_CONTENT: &str =
"This file was formatted by the fake formatter in the test.\n";
// Get the fake language server and set up formatting handler
let fake_language_server = fake_language_servers.next().await.unwrap();
fake_language_server.set_request_handler::<lsp::request::Formatting, _, _>({
|_, _| async move {
Ok(Some(vec![lsp::TextEdit {
range: lsp::Range::new(lsp::Position::new(0, 0), lsp::Position::new(1, 0)),
new_text: FORMATTED_CONTENT.to_string(),
}]))
}
});
let action_log = cx.new(|_| ActionLog::new(project.clone()));
let model = Arc::new(FakeLanguageModel::default());
// First, test with format_on_save enabled
cx.update(|cx| {
SettingsStore::update_global(cx, |store, cx| {
store.update_user_settings::<language::language_settings::AllLanguageSettings>(
cx,
|settings| {
settings.defaults.format_on_save = Some(FormatOnSave::On);
settings.defaults.formatter =
Some(language::language_settings::SelectedFormatter::Auto);
},
);
});
});
// Have the model stream unformatted content
let edit_result = {
let edit_task = cx.update(|cx| {
let input = serde_json::to_value(EditFileToolInput {
display_description: "Create main function".into(),
path: "root/src/main.rs".into(),
mode: EditFileMode::Overwrite,
})
.unwrap();
Arc::new(EditFileTool)
.run(
input,
Arc::default(),
project.clone(),
action_log.clone(),
model.clone(),
None,
cx,
)
.output
});
// Stream the unformatted content
cx.executor().run_until_parked();
model.stream_last_completion_response(UNFORMATTED_CONTENT.to_string());
model.end_last_completion_stream();
edit_task.await
};
assert!(edit_result.is_ok());
// Wait for any async operations (e.g. formatting) to complete
cx.executor().run_until_parked();
// Read the file to verify it was formatted automatically
let new_content = fs.load(path!("/root/src/main.rs").as_ref()).await.unwrap();
assert_eq!(
// Ignore carriage returns on Windows
new_content.replace("\r\n", "\n"),
FORMATTED_CONTENT,
"Code should be formatted when format_on_save is enabled"
);
let stale_buffer_count = action_log.read_with(cx, |log, cx| log.stale_buffers(cx).count());
assert_eq!(
stale_buffer_count, 0,
"BUG: Buffer is incorrectly marked as stale after format-on-save. Found {} stale buffers. \
This causes the agent to think the file was modified externally when it was just formatted.",
stale_buffer_count
);
// Next, test with format_on_save disabled
cx.update(|cx| {
SettingsStore::update_global(cx, |store, cx| {
store.update_user_settings::<language::language_settings::AllLanguageSettings>(
cx,
|settings| {
settings.defaults.format_on_save = Some(FormatOnSave::Off);
},
);
});
});
// Stream unformatted edits again
let edit_result = {
let edit_task = cx.update(|cx| {
let input = serde_json::to_value(EditFileToolInput {
display_description: "Update main function".into(),
path: "root/src/main.rs".into(),
mode: EditFileMode::Overwrite,
})
.unwrap();
Arc::new(EditFileTool)
.run(
input,
Arc::default(),
project.clone(),
action_log.clone(),
model.clone(),
None,
cx,
)
.output
});
// Stream the unformatted content
cx.executor().run_until_parked();
model.stream_last_completion_response(UNFORMATTED_CONTENT.to_string());
model.end_last_completion_stream();
edit_task.await
};
assert!(edit_result.is_ok());
// Wait for any async operations (e.g. formatting) to complete
cx.executor().run_until_parked();
// Verify the file was not formatted
let new_content = fs.load(path!("/root/src/main.rs").as_ref()).await.unwrap();
assert_eq!(
// Ignore carriage returns on Windows
new_content.replace("\r\n", "\n"),
UNFORMATTED_CONTENT,
"Code should not be formatted when format_on_save is disabled"
);
}
#[gpui::test]
async fn test_remove_trailing_whitespace(cx: &mut TestAppContext) {
init_test(cx);
let fs = FakeFs::new(cx.executor());
fs.insert_tree("/root", json!({"src": {}})).await;
// Create a simple file with trailing whitespace
fs.save(
path!("/root/src/main.rs").as_ref(),
&"initial content".into(),
language::LineEnding::Unix,
)
.await
.unwrap();
let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
let action_log = cx.new(|_| ActionLog::new(project.clone()));
let model = Arc::new(FakeLanguageModel::default());
// First, test with remove_trailing_whitespace_on_save enabled
cx.update(|cx| {
SettingsStore::update_global(cx, |store, cx| {
store.update_user_settings::<language::language_settings::AllLanguageSettings>(
cx,
|settings| {
settings.defaults.remove_trailing_whitespace_on_save = Some(true);
},
);
});
});
const CONTENT_WITH_TRAILING_WHITESPACE: &str =
"fn main() { \n println!(\"Hello!\"); \n}\n";
// Have the model stream content that contains trailing whitespace
let edit_result = {
let edit_task = cx.update(|cx| {
let input = serde_json::to_value(EditFileToolInput {
display_description: "Create main function".into(),
path: "root/src/main.rs".into(),
mode: EditFileMode::Overwrite,
})
.unwrap();
Arc::new(EditFileTool)
.run(
input,
Arc::default(),
project.clone(),
action_log.clone(),
model.clone(),
None,
cx,
)
.output
});
// Stream the content with trailing whitespace
cx.executor().run_until_parked();
model.stream_last_completion_response(CONTENT_WITH_TRAILING_WHITESPACE.to_string());
model.end_last_completion_stream();
edit_task.await
};
assert!(edit_result.is_ok());
// Wait for any async operations (e.g. formatting) to complete
cx.executor().run_until_parked();
// Read the file to verify trailing whitespace was removed automatically
assert_eq!(
// Ignore carriage returns on Windows
fs.load(path!("/root/src/main.rs").as_ref())
.await
.unwrap()
.replace("\r\n", "\n"),
"fn main() {\n println!(\"Hello!\");\n}\n",
"Trailing whitespace should be removed when remove_trailing_whitespace_on_save is enabled"
);
// Next, test with remove_trailing_whitespace_on_save disabled
cx.update(|cx| {
SettingsStore::update_global(cx, |store, cx| {
store.update_user_settings::<language::language_settings::AllLanguageSettings>(
cx,
|settings| {
settings.defaults.remove_trailing_whitespace_on_save = Some(false);
},
);
});
});
// Stream edits again with trailing whitespace
let edit_result = {
let edit_task = cx.update(|cx| {
let input = serde_json::to_value(EditFileToolInput {
display_description: "Update main function".into(),
path: "root/src/main.rs".into(),
mode: EditFileMode::Overwrite,
})
.unwrap();
Arc::new(EditFileTool)
.run(
input,
Arc::default(),
project.clone(),
action_log.clone(),
model.clone(),
None,
cx,
)
.output
});
// Stream the content with trailing whitespace
cx.executor().run_until_parked();
model.stream_last_completion_response(CONTENT_WITH_TRAILING_WHITESPACE.to_string());
model.end_last_completion_stream();
edit_task.await
};
assert!(edit_result.is_ok());
// Wait for any async operations (e.g. formatting) to complete
cx.executor().run_until_parked();
// Verify the file still has trailing whitespace
// Read the file again - it should still have trailing whitespace
let final_content = fs.load(path!("/root/src/main.rs").as_ref()).await.unwrap();
assert_eq!(
// Ignore carriage returns on Windows
final_content.replace("\r\n", "\n"),
CONTENT_WITH_TRAILING_WHITESPACE,
"Trailing whitespace should remain when remove_trailing_whitespace_on_save is disabled"
);
}
}

View File

@@ -1,15 +0,0 @@
Renames a symbol across your codebase using the language server's semantic knowledge.
This tool performs a rename refactoring operation on a specified symbol. It uses the project's language server to analyze the code and perform the rename correctly across all files where the symbol is referenced.
Unlike a simple find and replace, this tool understands the semantic meaning of the code, so it only renames the specific symbol you specify and not unrelated text that happens to have the same name.
Examples of symbols you can rename:
- Variables
- Functions
- Classes/structs
- Fields/properties
- Methods
- Interfaces/traits
The language server handles updating all references to the renamed symbol throughout the codebase.

View File

@@ -1,11 +0,0 @@
Gives detailed information about code symbols in your project such as variables, functions, classes, interface, traits, and other programming constructs, using the editor's integrated Language Server Protocol (LSP) servers.
This tool is the preferred way to do things like:
* Find out where a code symbol is first declared (or first defined - that is, assigned)
* Find all the places where a code symbol is referenced
* Find the type definition for a code symbol
* Find a code symbol's implementation
This tool gives more reliable answers than things like regex searches, because it can account for relevant semantics like aliases. It should be used over textual search tools (e.g. regex) when searching for information about code symbols that this tool supports directly.
This tool should not be used when you need to search for something that is not a code symbol.

View File

@@ -1,3 +1,6 @@
{{!----------------------------------------------------------------------------------
NOTE: Changes to this prompt require a symmetric update in monorepo/request_kind.rs
------------------------------------------------------------------------------------}}
You are an expert engineer and your task is to write a new file from scratch.
You MUST respond with the file's content wrapped in triple backticks (```).

View File

@@ -1,3 +1,6 @@
{{!----------------------------------------------------------------------------------
NOTE: Changes to this prompt require a symmetric update in monorepo/request_kind.rs
------------------------------------------------------------------------------------}}
You MUST respond with a series of edits to a file, using the following format:
```

View File

@@ -49,8 +49,12 @@ pub enum VersionCheckType {
pub enum AutoUpdateStatus {
Idle,
Checking,
Downloading,
Installing,
Downloading {
version: VersionCheckType,
},
Installing {
version: VersionCheckType,
},
Updated {
binary_path: PathBuf,
version: VersionCheckType,
@@ -511,12 +515,12 @@ impl AutoUpdater {
Self::get_latest_release(&this, "zed", OS, ARCH, release_channel, &mut cx).await?;
let fetched_version = fetched_release_data.clone().version;
let app_commit_sha = cx.update(|cx| AppCommitSha::try_global(cx).map(|sha| sha.full()));
let newer_version = Self::check_for_newer_version(
let newer_version = Self::check_if_fetched_version_is_newer(
*RELEASE_CHANNEL,
app_commit_sha,
installed_version,
previous_status.clone(),
fetched_version,
previous_status.clone(),
)?;
let Some(newer_version) = newer_version else {
@@ -531,7 +535,9 @@ impl AutoUpdater {
};
this.update(&mut cx, |this, cx| {
this.status = AutoUpdateStatus::Downloading;
this.status = AutoUpdateStatus::Downloading {
version: newer_version.clone(),
};
cx.notify();
})?;
@@ -540,7 +546,9 @@ impl AutoUpdater {
download_release(&target_path, fetched_release_data, client, &cx).await?;
this.update(&mut cx, |this, cx| {
this.status = AutoUpdateStatus::Installing;
this.status = AutoUpdateStatus::Installing {
version: newer_version.clone(),
};
cx.notify();
})?;
@@ -557,12 +565,12 @@ impl AutoUpdater {
})
}
fn check_for_newer_version(
fn check_if_fetched_version_is_newer(
release_channel: ReleaseChannel,
app_commit_sha: Result<Option<String>>,
installed_version: SemanticVersion,
status: AutoUpdateStatus,
fetched_version: String,
status: AutoUpdateStatus,
) -> Result<Option<VersionCheckType>> {
let parsed_fetched_version = fetched_version.parse::<SemanticVersion>();
@@ -575,7 +583,7 @@ impl AutoUpdater {
return Ok(newer_version);
}
VersionCheckType::Semantic(cached_version) => {
return Self::check_for_newer_version_non_nightly(
return Self::check_if_fetched_version_is_newer_non_nightly(
cached_version,
parsed_fetched_version?,
);
@@ -594,7 +602,7 @@ impl AutoUpdater {
.then(|| VersionCheckType::Sha(AppCommitSha::new(fetched_version)));
Ok(newer_version)
}
_ => Self::check_for_newer_version_non_nightly(
_ => Self::check_if_fetched_version_is_newer_non_nightly(
installed_version,
parsed_fetched_version?,
),
@@ -631,7 +639,7 @@ impl AutoUpdater {
}
}
fn check_for_newer_version_non_nightly(
fn check_if_fetched_version_is_newer_non_nightly(
installed_version: SemanticVersion,
fetched_version: SemanticVersion,
) -> Result<Option<VersionCheckType>> {
@@ -925,12 +933,12 @@ mod tests {
let status = AutoUpdateStatus::Idle;
let fetched_version = SemanticVersion::new(1, 0, 0);
let newer_version = AutoUpdater::check_for_newer_version(
let newer_version = AutoUpdater::check_if_fetched_version_is_newer(
release_channel,
app_commit_sha,
installed_version,
status,
fetched_version.to_string(),
status,
);
assert_eq!(newer_version.unwrap(), None);
@@ -944,12 +952,12 @@ mod tests {
let status = AutoUpdateStatus::Idle;
let fetched_version = SemanticVersion::new(1, 0, 1);
let newer_version = AutoUpdater::check_for_newer_version(
let newer_version = AutoUpdater::check_if_fetched_version_is_newer(
release_channel,
app_commit_sha,
installed_version,
status,
fetched_version.to_string(),
status,
);
assert_eq!(
@@ -969,12 +977,12 @@ mod tests {
};
let fetched_version = SemanticVersion::new(1, 0, 1);
let newer_version = AutoUpdater::check_for_newer_version(
let newer_version = AutoUpdater::check_if_fetched_version_is_newer(
release_channel,
app_commit_sha,
installed_version,
status,
fetched_version.to_string(),
status,
);
assert_eq!(newer_version.unwrap(), None);
@@ -991,12 +999,12 @@ mod tests {
};
let fetched_version = SemanticVersion::new(1, 0, 2);
let newer_version = AutoUpdater::check_for_newer_version(
let newer_version = AutoUpdater::check_if_fetched_version_is_newer(
release_channel,
app_commit_sha,
installed_version,
status,
fetched_version.to_string(),
status,
);
assert_eq!(
@@ -1013,12 +1021,12 @@ mod tests {
let status = AutoUpdateStatus::Idle;
let fetched_sha = "a".to_string();
let newer_version = AutoUpdater::check_for_newer_version(
let newer_version = AutoUpdater::check_if_fetched_version_is_newer(
release_channel,
app_commit_sha,
installed_version,
status,
fetched_sha,
status,
);
assert_eq!(newer_version.unwrap(), None);
@@ -1032,12 +1040,12 @@ mod tests {
let status = AutoUpdateStatus::Idle;
let fetched_sha = "b".to_string();
let newer_version = AutoUpdater::check_for_newer_version(
let newer_version = AutoUpdater::check_if_fetched_version_is_newer(
release_channel,
app_commit_sha,
installed_version,
status,
fetched_sha.clone(),
status,
);
assert_eq!(
@@ -1057,12 +1065,12 @@ mod tests {
};
let fetched_sha = "b".to_string();
let newer_version = AutoUpdater::check_for_newer_version(
let newer_version = AutoUpdater::check_if_fetched_version_is_newer(
release_channel,
app_commit_sha,
installed_version,
status,
fetched_sha,
status,
);
assert_eq!(newer_version.unwrap(), None);
@@ -1079,12 +1087,12 @@ mod tests {
};
let fetched_sha = "c".to_string();
let newer_version = AutoUpdater::check_for_newer_version(
let newer_version = AutoUpdater::check_if_fetched_version_is_newer(
release_channel,
app_commit_sha,
installed_version,
status,
fetched_sha.clone(),
status,
);
assert_eq!(
@@ -1101,12 +1109,12 @@ mod tests {
let status = AutoUpdateStatus::Idle;
let fetched_sha = "a".to_string();
let newer_version = AutoUpdater::check_for_newer_version(
let newer_version = AutoUpdater::check_if_fetched_version_is_newer(
release_channel,
app_commit_sha,
installed_version,
status,
fetched_sha.clone(),
status,
);
assert_eq!(
@@ -1127,12 +1135,12 @@ mod tests {
};
let fetched_sha = "b".to_string();
let newer_version = AutoUpdater::check_for_newer_version(
let newer_version = AutoUpdater::check_if_fetched_version_is_newer(
release_channel,
app_commit_sha,
installed_version,
status,
fetched_sha,
status,
);
assert_eq!(newer_version.unwrap(), None);
@@ -1150,12 +1158,12 @@ mod tests {
};
let fetched_sha = "c".to_string();
let newer_version = AutoUpdater::check_for_newer_version(
let newer_version = AutoUpdater::check_if_fetched_version_is_newer(
release_channel,
app_commit_sha,
installed_version,
status,
fetched_sha.clone(),
status,
);
assert_eq!(

View File

@@ -20,6 +20,7 @@ test-support = ["sqlite"]
[dependencies]
anyhow.workspace = true
async-stripe.workspace = true
async-trait.workspace = true
async-tungstenite.workspace = true
aws-config = { version = "1.1.5" }
aws-sdk-s3 = { version = "1.15.0" }

View File

@@ -2,5 +2,5 @@ ZED_ENVIRONMENT=production
RUST_LOG=info
INVITE_LINK_PREFIX=https://zed.dev/invites/
AUTO_JOIN_CHANNEL_ID=283
DATABASE_MAX_CONNECTIONS=85
DATABASE_MAX_CONNECTIONS=250
LLM_DATABASE_MAX_CONNECTIONS=25

View File

@@ -29,6 +29,7 @@ use crate::db::billing_subscription::{
use crate::llm::db::subscription_usage_meter::CompletionMode;
use crate::llm::{AGENT_EXTENDED_TRIAL_FEATURE_FLAG, DEFAULT_MAX_MONTHLY_SPEND};
use crate::rpc::{ResultExt as _, Server};
use crate::stripe_client::{StripeCustomerId, StripeSubscriptionId};
use crate::{AppState, Error, Result};
use crate::{db::UserId, llm::db::LlmDatabase};
use crate::{
@@ -282,7 +283,6 @@ async fn list_billing_subscriptions(
enum ProductCode {
ZedPro,
ZedProTrial,
ZedFree,
}
#[derive(Debug, Deserialize)]
@@ -338,8 +338,7 @@ async fn create_billing_subscription(
}
let customer_id = if let Some(existing_customer) = &existing_billing_customer {
CustomerId::from_str(&existing_customer.stripe_customer_id)
.context("failed to parse customer ID")?
StripeCustomerId(existing_customer.stripe_customer_id.clone().into())
} else {
stripe_billing
.find_or_create_customer_by_email(user.email_address.as_deref())
@@ -354,7 +353,7 @@ async fn create_billing_subscription(
let checkout_session_url = match body.product {
ProductCode::ZedPro => {
stripe_billing
.checkout_with_zed_pro(customer_id, &user.github_login, &success_url)
.checkout_with_zed_pro(&customer_id, &user.github_login, &success_url)
.await?
}
ProductCode::ZedProTrial => {
@@ -371,18 +370,13 @@ async fn create_billing_subscription(
stripe_billing
.checkout_with_zed_pro_trial(
customer_id,
&customer_id,
&user.github_login,
feature_flags,
&success_url,
)
.await?
}
ProductCode::ZedFree => {
stripe_billing
.checkout_with_zed_free(customer_id, &user.github_login, &success_url)
.await?
}
};
Ok(Json(CreateBillingSubscriptionResponse {
@@ -498,8 +492,10 @@ async fn manage_billing_subscription(
let flow = match body.intent {
ManageSubscriptionIntent::ManageSubscription => None,
ManageSubscriptionIntent::UpgradeToPro => {
let zed_pro_price_id = stripe_billing.zed_pro_price_id().await?;
let zed_free_price_id = stripe_billing.zed_free_price_id().await?;
let zed_pro_price_id: stripe::PriceId =
stripe_billing.zed_pro_price_id().await?.try_into()?;
let zed_free_price_id: stripe::PriceId =
stripe_billing.zed_free_price_id().await?.try_into()?;
let stripe_subscription =
Subscription::retrieve(&stripe_client, &subscription_id, &[]).await?;
@@ -1186,10 +1182,8 @@ async fn sync_subscription(
.has_active_billing_subscription(billing_customer.user_id)
.await?;
if !already_has_active_billing_subscription {
let stripe_customer_id = billing_customer
.stripe_customer_id
.parse::<stripe::CustomerId>()
.context("failed to parse Stripe customer ID from database")?;
let stripe_customer_id =
StripeCustomerId(billing_customer.stripe_customer_id.clone().into());
stripe_billing
.subscribe_to_zed_free(stripe_customer_id)
@@ -1515,6 +1509,12 @@ async fn sync_model_request_usage_with_stripe(
let claude_sonnet_4_max = stripe_billing
.find_price_by_lookup_key("claude-sonnet-4-requests-max")
.await?;
let claude_opus_4 = stripe_billing
.find_price_by_lookup_key("claude-opus-4-requests")
.await?;
let claude_opus_4_max = stripe_billing
.find_price_by_lookup_key("claude-opus-4-requests-max")
.await?;
let claude_3_5_sonnet = stripe_billing
.find_price_by_lookup_key("claude-3-5-sonnet-requests")
.await?;
@@ -1536,18 +1536,18 @@ async fn sync_model_request_usage_with_stripe(
);
};
let stripe_customer_id = billing_customer
.stripe_customer_id
.parse::<stripe::CustomerId>()
.context("failed to parse Stripe customer ID from database")?;
let stripe_subscription_id = billing_subscription
.stripe_subscription_id
.parse::<stripe::SubscriptionId>()
.context("failed to parse Stripe subscription ID from database")?;
let stripe_customer_id =
StripeCustomerId(billing_customer.stripe_customer_id.clone().into());
let stripe_subscription_id =
StripeSubscriptionId(billing_subscription.stripe_subscription_id.clone().into());
let model = llm_db.model_by_id(usage_meter.model_id)?;
let (price, meter_event_name) = match model.name.as_str() {
"claude-opus-4" => match usage_meter.mode {
CompletionMode::Normal => (&claude_opus_4, "claude_opus_4/requests"),
CompletionMode::Max => (&claude_opus_4_max, "claude_opus_4/requests/max"),
},
"claude-sonnet-4" => match usage_meter.mode {
CompletionMode::Normal => (&claude_sonnet_4, "claude_sonnet_4/requests"),
CompletionMode::Max => (&claude_sonnet_4_max, "claude_sonnet_4/requests/max"),

View File

@@ -20,7 +20,7 @@ impl Database {
&self,
params: &CreateBillingCustomerParams,
) -> Result<billing_customer::Model> {
self.transaction(|tx| async move {
self.weak_transaction(|tx| async move {
let customer = billing_customer::Entity::insert(billing_customer::ActiveModel {
user_id: ActiveValue::set(params.user_id),
stripe_customer_id: ActiveValue::set(params.stripe_customer_id.clone()),
@@ -40,7 +40,7 @@ impl Database {
id: BillingCustomerId,
params: &UpdateBillingCustomerParams,
) -> Result<()> {
self.transaction(|tx| async move {
self.weak_transaction(|tx| async move {
billing_customer::Entity::update(billing_customer::ActiveModel {
id: ActiveValue::set(id),
user_id: params.user_id.clone(),
@@ -61,7 +61,7 @@ impl Database {
&self,
id: BillingCustomerId,
) -> Result<Option<billing_customer::Model>> {
self.transaction(|tx| async move {
self.weak_transaction(|tx| async move {
Ok(billing_customer::Entity::find()
.filter(billing_customer::Column::Id.eq(id))
.one(&*tx)
@@ -75,7 +75,7 @@ impl Database {
&self,
user_id: UserId,
) -> Result<Option<billing_customer::Model>> {
self.transaction(|tx| async move {
self.weak_transaction(|tx| async move {
Ok(billing_customer::Entity::find()
.filter(billing_customer::Column::UserId.eq(user_id))
.one(&*tx)
@@ -89,7 +89,7 @@ impl Database {
&self,
stripe_customer_id: &str,
) -> Result<Option<billing_customer::Model>> {
self.transaction(|tx| async move {
self.weak_transaction(|tx| async move {
Ok(billing_customer::Entity::find()
.filter(billing_customer::Column::StripeCustomerId.eq(stripe_customer_id))
.one(&*tx)

View File

@@ -22,7 +22,7 @@ impl Database {
&self,
user_id: UserId,
) -> Result<Option<billing_preference::Model>> {
self.transaction(|tx| async move {
self.weak_transaction(|tx| async move {
Ok(billing_preference::Entity::find()
.filter(billing_preference::Column::UserId.eq(user_id))
.one(&*tx)
@@ -37,7 +37,7 @@ impl Database {
user_id: UserId,
params: &CreateBillingPreferencesParams,
) -> Result<billing_preference::Model> {
self.transaction(|tx| async move {
self.weak_transaction(|tx| async move {
let preferences = billing_preference::Entity::insert(billing_preference::ActiveModel {
user_id: ActiveValue::set(user_id),
max_monthly_llm_usage_spending_in_cents: ActiveValue::set(
@@ -65,7 +65,7 @@ impl Database {
user_id: UserId,
params: &UpdateBillingPreferencesParams,
) -> Result<billing_preference::Model> {
self.transaction(|tx| async move {
self.weak_transaction(|tx| async move {
let preferences = billing_preference::Entity::update_many()
.set(billing_preference::ActiveModel {
max_monthly_llm_usage_spending_in_cents: params

View File

@@ -35,7 +35,7 @@ impl Database {
&self,
params: &CreateBillingSubscriptionParams,
) -> Result<billing_subscription::Model> {
self.transaction(|tx| async move {
self.weak_transaction(|tx| async move {
let id = billing_subscription::Entity::insert(billing_subscription::ActiveModel {
billing_customer_id: ActiveValue::set(params.billing_customer_id),
kind: ActiveValue::set(params.kind),
@@ -64,7 +64,7 @@ impl Database {
id: BillingSubscriptionId,
params: &UpdateBillingSubscriptionParams,
) -> Result<()> {
self.transaction(|tx| async move {
self.weak_transaction(|tx| async move {
billing_subscription::Entity::update(billing_subscription::ActiveModel {
id: ActiveValue::set(id),
billing_customer_id: params.billing_customer_id.clone(),
@@ -90,7 +90,7 @@ impl Database {
&self,
id: BillingSubscriptionId,
) -> Result<Option<billing_subscription::Model>> {
self.transaction(|tx| async move {
self.weak_transaction(|tx| async move {
Ok(billing_subscription::Entity::find_by_id(id)
.one(&*tx)
.await?)
@@ -103,7 +103,7 @@ impl Database {
&self,
stripe_subscription_id: &str,
) -> Result<Option<billing_subscription::Model>> {
self.transaction(|tx| async move {
self.weak_transaction(|tx| async move {
Ok(billing_subscription::Entity::find()
.filter(
billing_subscription::Column::StripeSubscriptionId.eq(stripe_subscription_id),
@@ -118,7 +118,7 @@ impl Database {
&self,
user_id: UserId,
) -> Result<Option<billing_subscription::Model>> {
self.transaction(|tx| async move {
self.weak_transaction(|tx| async move {
Ok(billing_subscription::Entity::find()
.inner_join(billing_customer::Entity)
.filter(billing_customer::Column::UserId.eq(user_id))
@@ -152,7 +152,7 @@ impl Database {
&self,
user_id: UserId,
) -> Result<Vec<billing_subscription::Model>> {
self.transaction(|tx| async move {
self.weak_transaction(|tx| async move {
let subscriptions = billing_subscription::Entity::find()
.inner_join(billing_customer::Entity)
.filter(billing_customer::Column::UserId.eq(user_id))
@@ -169,7 +169,7 @@ impl Database {
&self,
user_ids: HashSet<UserId>,
) -> Result<HashMap<UserId, (billing_customer::Model, billing_subscription::Model)>> {
self.transaction(|tx| {
self.weak_transaction(|tx| {
let user_ids = user_ids.clone();
async move {
let mut rows = billing_subscription::Entity::find()
@@ -201,7 +201,7 @@ impl Database {
&self,
user_ids: HashSet<UserId>,
) -> Result<HashMap<UserId, (billing_customer::Model, billing_subscription::Model)>> {
self.transaction(|tx| {
self.weak_transaction(|tx| {
let user_ids = user_ids.clone();
async move {
let mut rows = billing_subscription::Entity::find()
@@ -236,7 +236,7 @@ impl Database {
/// Returns the count of the active billing subscriptions for the user with the specified ID.
pub async fn count_active_billing_subscriptions(&self, user_id: UserId) -> Result<usize> {
self.transaction(|tx| async move {
self.weak_transaction(|tx| async move {
let count = billing_subscription::Entity::find()
.inner_join(billing_customer::Entity)
.filter(

View File

@@ -9,7 +9,7 @@ pub enum ContributorSelector {
impl Database {
/// Retrieves the GitHub logins of all users who have signed the CLA.
pub async fn get_contributors(&self) -> Result<Vec<String>> {
self.transaction(|tx| async move {
self.weak_transaction(|tx| async move {
#[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)]
enum QueryGithubLogin {
GithubLogin,
@@ -32,7 +32,7 @@ impl Database {
&self,
selector: &ContributorSelector,
) -> Result<Option<DateTime>> {
self.transaction(|tx| async move {
self.weak_transaction(|tx| async move {
let condition = match selector {
ContributorSelector::GitHubUserId { github_user_id } => {
user::Column::GithubUserId.eq(*github_user_id)
@@ -69,7 +69,7 @@ impl Database {
github_user_created_at: DateTimeUtc,
initial_channel_id: Option<ChannelId>,
) -> Result<()> {
self.transaction(|tx| async move {
self.weak_transaction(|tx| async move {
let user = self
.get_or_create_user_by_github_account_tx(
github_login,

View File

@@ -15,7 +15,7 @@ impl Database {
max_schema_version: i32,
limit: usize,
) -> Result<Vec<ExtensionMetadata>> {
self.transaction(|tx| async move {
self.weak_transaction(|tx| async move {
let mut condition = Condition::all()
.add(
extension::Column::LatestVersion
@@ -43,7 +43,7 @@ impl Database {
ids: &[&str],
constraints: Option<&ExtensionVersionConstraints>,
) -> Result<Vec<ExtensionMetadata>> {
self.transaction(|tx| async move {
self.weak_transaction(|tx| async move {
let extensions = extension::Entity::find()
.filter(extension::Column::ExternalId.is_in(ids.iter().copied()))
.all(&*tx)
@@ -123,7 +123,7 @@ impl Database {
&self,
extension_id: &str,
) -> Result<Vec<ExtensionMetadata>> {
self.transaction(|tx| async move {
self.weak_transaction(|tx| async move {
let condition = extension::Column::ExternalId
.eq(extension_id)
.into_condition();
@@ -162,7 +162,7 @@ impl Database {
extension_id: &str,
constraints: Option<&ExtensionVersionConstraints>,
) -> Result<Option<ExtensionMetadata>> {
self.transaction(|tx| async move {
self.weak_transaction(|tx| async move {
let extension = extension::Entity::find()
.filter(extension::Column::ExternalId.eq(extension_id))
.one(&*tx)
@@ -187,7 +187,7 @@ impl Database {
extension_id: &str,
version: &str,
) -> Result<Option<ExtensionMetadata>> {
self.transaction(|tx| async move {
self.weak_transaction(|tx| async move {
let extension = extension::Entity::find()
.filter(extension::Column::ExternalId.eq(extension_id))
.filter(extension_version::Column::Version.eq(version))
@@ -204,7 +204,7 @@ impl Database {
}
pub async fn get_known_extension_versions(&self) -> Result<HashMap<String, Vec<String>>> {
self.transaction(|tx| async move {
self.weak_transaction(|tx| async move {
let mut extension_external_ids_by_id = HashMap::default();
let mut rows = extension::Entity::find().stream(&*tx).await?;
@@ -242,7 +242,7 @@ impl Database {
&self,
versions_by_extension_id: &HashMap<&str, Vec<NewExtensionVersion>>,
) -> Result<()> {
self.transaction(|tx| async move {
self.weak_transaction(|tx| async move {
for (external_id, versions) in versions_by_extension_id {
if versions.is_empty() {
continue;
@@ -346,7 +346,7 @@ impl Database {
}
pub async fn record_extension_download(&self, extension: &str, version: &str) -> Result<bool> {
self.transaction(|tx| async move {
self.weak_transaction(|tx| async move {
#[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)]
enum QueryId {
Id,

View File

@@ -13,7 +13,7 @@ impl Database {
&self,
params: &CreateProcessedStripeEventParams,
) -> Result<()> {
self.transaction(|tx| async move {
self.weak_transaction(|tx| async move {
processed_stripe_event::Entity::insert(processed_stripe_event::ActiveModel {
stripe_event_id: ActiveValue::set(params.stripe_event_id.clone()),
stripe_event_type: ActiveValue::set(params.stripe_event_type.clone()),
@@ -35,7 +35,7 @@ impl Database {
&self,
event_id: &str,
) -> Result<Option<processed_stripe_event::Model>> {
self.transaction(|tx| async move {
self.weak_transaction(|tx| async move {
Ok(processed_stripe_event::Entity::find_by_id(event_id)
.one(&*tx)
.await?)
@@ -48,7 +48,7 @@ impl Database {
&self,
event_ids: &[&str],
) -> Result<Vec<processed_stripe_event::Model>> {
self.transaction(|tx| async move {
self.weak_transaction(|tx| async move {
Ok(processed_stripe_event::Entity::find()
.filter(
processed_stripe_event::Column::StripeEventId.is_in(event_ids.iter().copied()),

View File

@@ -382,7 +382,7 @@ impl Database {
/// Returns the active flags for the user.
pub async fn get_user_flags(&self, user: UserId) -> Result<Vec<String>> {
self.transaction(|tx| async move {
self.weak_transaction(|tx| async move {
#[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)]
enum QueryAs {
Flag,

View File

@@ -9,6 +9,7 @@ pub mod migrations;
pub mod rpc;
pub mod seed;
pub mod stripe_billing;
pub mod stripe_client;
pub mod user_backfiller;
#[cfg(test)]

View File

@@ -5,6 +5,7 @@ use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
use crate::db::billing_subscription::SubscriptionKind;
use crate::llm::db::LlmDatabase;
use crate::llm::{AGENT_EXTENDED_TRIAL_FEATURE_FLAG, LlmTokenClaims};
use crate::stripe_client::StripeCustomerId;
use crate::{
AppState, Error, Result, auth,
db::{
@@ -4039,7 +4040,8 @@ async fn get_llm_api_token(
} else {
let customer_id = stripe_billing
.find_or_create_customer_by_email(user.email_address.as_deref())
.await?;
.await?
.try_into()?;
find_or_create_billing_customer(
&session.app_state,
@@ -4054,10 +4056,8 @@ async fn get_llm_api_token(
if let Some(billing_subscription) = db.get_active_billing_subscription(user.id).await? {
billing_subscription
} else {
let stripe_customer_id = billing_customer
.stripe_customer_id
.parse::<stripe::CustomerId>()
.context("failed to parse Stripe customer ID from database")?;
let stripe_customer_id =
StripeCustomerId(billing_customer.stripe_customer_id.clone().into());
let stripe_subscription = stripe_billing
.subscribe_to_zed_free(stripe_customer_id)

View File

@@ -1,30 +1,49 @@
use std::sync::Arc;
use anyhow::{Context as _, anyhow};
use chrono::Utc;
use collections::HashMap;
use stripe::SubscriptionStatus;
use tokio::sync::RwLock;
use uuid::Uuid;
use crate::Result;
use crate::db::billing_subscription::SubscriptionKind;
use crate::llm::AGENT_EXTENDED_TRIAL_FEATURE_FLAG;
use anyhow::{Context as _, anyhow};
use chrono::Utc;
use collections::HashMap;
use serde::{Deserialize, Serialize};
use stripe::{CreateCustomer, Customer, CustomerId, PriceId, SubscriptionStatus};
use tokio::sync::RwLock;
use uuid::Uuid;
use crate::stripe_client::{
RealStripeClient, StripeCheckoutSessionMode, StripeCheckoutSessionPaymentMethodCollection,
StripeClient, StripeCreateCheckoutSessionLineItems, StripeCreateCheckoutSessionParams,
StripeCreateCheckoutSessionSubscriptionData, StripeCreateMeterEventParams,
StripeCreateMeterEventPayload, StripeCreateSubscriptionItems, StripeCreateSubscriptionParams,
StripeCustomerId, StripeMeter, StripePrice, StripePriceId, StripeSubscription,
StripeSubscriptionId, StripeSubscriptionTrialSettings,
StripeSubscriptionTrialSettingsEndBehavior,
StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod, UpdateSubscriptionItems,
UpdateSubscriptionParams,
};
pub struct StripeBilling {
state: RwLock<StripeBillingState>,
client: Arc<stripe::Client>,
client: Arc<dyn StripeClient>,
}
#[derive(Default)]
struct StripeBillingState {
meters_by_event_name: HashMap<String, StripeMeter>,
price_ids_by_meter_id: HashMap<String, stripe::PriceId>,
prices_by_lookup_key: HashMap<String, stripe::Price>,
price_ids_by_meter_id: HashMap<String, StripePriceId>,
prices_by_lookup_key: HashMap<String, StripePrice>,
}
impl StripeBilling {
pub fn new(client: Arc<stripe::Client>) -> Self {
Self {
client: Arc::new(RealStripeClient::new(client.clone())),
state: RwLock::default(),
}
}
#[cfg(test)]
pub fn test(client: Arc<crate::stripe_client::FakeStripeClient>) -> Self {
Self {
client,
state: RwLock::default(),
@@ -36,24 +55,16 @@ impl StripeBilling {
let mut state = self.state.write().await;
let (meters, prices) = futures::try_join!(
StripeMeter::list(&self.client),
stripe::Price::list(
&self.client,
&stripe::ListPrices {
limit: Some(100),
..Default::default()
}
)
)?;
let (meters, prices) =
futures::try_join!(self.client.list_meters(), self.client.list_prices())?;
for meter in meters.data {
for meter in meters {
state
.meters_by_event_name
.insert(meter.event_name.clone(), meter);
}
for price in prices.data {
for price in prices {
if let Some(lookup_key) = price.lookup_key.clone() {
state.prices_by_lookup_key.insert(lookup_key, price.clone());
}
@@ -70,15 +81,15 @@ impl StripeBilling {
Ok(())
}
pub async fn zed_pro_price_id(&self) -> Result<PriceId> {
pub async fn zed_pro_price_id(&self) -> Result<StripePriceId> {
self.find_price_id_by_lookup_key("zed-pro").await
}
pub async fn zed_free_price_id(&self) -> Result<PriceId> {
pub async fn zed_free_price_id(&self) -> Result<StripePriceId> {
self.find_price_id_by_lookup_key("zed-free").await
}
pub async fn find_price_id_by_lookup_key(&self, lookup_key: &str) -> Result<PriceId> {
pub async fn find_price_id_by_lookup_key(&self, lookup_key: &str) -> Result<StripePriceId> {
self.state
.read()
.await
@@ -88,7 +99,7 @@ impl StripeBilling {
.ok_or_else(|| crate::Error::Internal(anyhow!("no price ID found for {lookup_key:?}")))
}
pub async fn find_price_by_lookup_key(&self, lookup_key: &str) -> Result<stripe::Price> {
pub async fn find_price_by_lookup_key(&self, lookup_key: &str) -> Result<StripePrice> {
self.state
.read()
.await
@@ -102,8 +113,10 @@ impl StripeBilling {
&self,
subscription: &stripe::Subscription,
) -> Option<SubscriptionKind> {
let zed_pro_price_id = self.zed_pro_price_id().await.ok()?;
let zed_free_price_id = self.zed_free_price_id().await.ok()?;
let zed_pro_price_id: stripe::PriceId =
self.zed_pro_price_id().await.ok()?.try_into().ok()?;
let zed_free_price_id: stripe::PriceId =
self.zed_free_price_id().await.ok()?.try_into().ok()?;
subscription.items.data.iter().find_map(|item| {
let price = item.price.as_ref()?;
@@ -129,18 +142,11 @@ impl StripeBilling {
pub async fn find_or_create_customer_by_email(
&self,
email_address: Option<&str>,
) -> Result<CustomerId> {
) -> Result<StripeCustomerId> {
let existing_customer = if let Some(email) = email_address {
let customers = Customer::list(
&self.client,
&stripe::ListCustomers {
email: Some(email),
..Default::default()
},
)
.await?;
let customers = self.client.list_customers_by_email(email).await?;
customers.data.first().cloned()
customers.first().cloned()
} else {
None
};
@@ -148,14 +154,12 @@ impl StripeBilling {
let customer_id = if let Some(existing_customer) = existing_customer {
existing_customer.id
} else {
let customer = Customer::create(
&self.client,
CreateCustomer {
let customer = self
.client
.create_customer(crate::stripe_client::CreateCustomerParams {
email: email_address,
..Default::default()
},
)
.await?;
})
.await?;
customer.id
};
@@ -165,11 +169,10 @@ impl StripeBilling {
pub async fn subscribe_to_price(
&self,
subscription_id: &stripe::SubscriptionId,
price: &stripe::Price,
subscription_id: &StripeSubscriptionId,
price: &StripePrice,
) -> Result<()> {
let subscription =
stripe::Subscription::retrieve(&self.client, &subscription_id, &[]).await?;
let subscription = self.client.get_subscription(subscription_id).await?;
if subscription_contains_price(&subscription, &price.id) {
return Ok(());
@@ -180,39 +183,36 @@ impl StripeBilling {
let price_per_unit = price.unit_amount.unwrap_or_default();
let _units_for_billing_threshold = BILLING_THRESHOLD_IN_CENTS / price_per_unit;
stripe::Subscription::update(
&self.client,
subscription_id,
stripe::UpdateSubscription {
items: Some(vec![stripe::UpdateSubscriptionItems {
price: Some(price.id.to_string()),
..Default::default()
}]),
trial_settings: Some(stripe::UpdateSubscriptionTrialSettings {
end_behavior: stripe::UpdateSubscriptionTrialSettingsEndBehavior {
missing_payment_method: stripe::UpdateSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::Cancel,
},
}),
..Default::default()
},
)
.await?;
self.client
.update_subscription(
subscription_id,
UpdateSubscriptionParams {
items: Some(vec![UpdateSubscriptionItems {
price: Some(price.id.clone()),
}]),
trial_settings: Some(StripeSubscriptionTrialSettings {
end_behavior: StripeSubscriptionTrialSettingsEndBehavior {
missing_payment_method: StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::Cancel
},
}),
},
)
.await?;
Ok(())
}
pub async fn bill_model_request_usage(
&self,
customer_id: &stripe::CustomerId,
customer_id: &StripeCustomerId,
event_name: &str,
requests: i32,
) -> Result<()> {
let timestamp = Utc::now().timestamp();
let idempotency_key = Uuid::new_v4();
StripeMeterEvent::create(
&self.client,
StripeCreateMeterEventParams {
self.client
.create_meter_event(StripeCreateMeterEventParams {
identifier: &format!("model_requests/{}", idempotency_key),
event_name,
payload: StripeCreateMeterEventPayload {
@@ -220,39 +220,37 @@ impl StripeBilling {
stripe_customer_id: customer_id,
},
timestamp: Some(timestamp),
},
)
.await?;
})
.await?;
Ok(())
}
pub async fn checkout_with_zed_pro(
&self,
customer_id: stripe::CustomerId,
customer_id: &StripeCustomerId,
github_login: &str,
success_url: &str,
) -> Result<String> {
let zed_pro_price_id = self.zed_pro_price_id().await?;
let mut params = stripe::CreateCheckoutSession::new();
params.mode = Some(stripe::CheckoutSessionMode::Subscription);
let mut params = StripeCreateCheckoutSessionParams::default();
params.mode = Some(StripeCheckoutSessionMode::Subscription);
params.customer = Some(customer_id);
params.client_reference_id = Some(github_login);
params.line_items = Some(vec![stripe::CreateCheckoutSessionLineItems {
params.line_items = Some(vec![StripeCreateCheckoutSessionLineItems {
price: Some(zed_pro_price_id.to_string()),
quantity: Some(1),
..Default::default()
}]);
params.success_url = Some(success_url);
let session = stripe::CheckoutSession::create(&self.client, params).await?;
let session = self.client.create_checkout_session(params).await?;
Ok(session.url.context("no checkout session URL")?)
}
pub async fn checkout_with_zed_pro_trial(
&self,
customer_id: stripe::CustomerId,
customer_id: &StripeCustomerId,
github_login: &str,
feature_flags: Vec<String>,
success_url: &str,
@@ -273,172 +271,75 @@ impl StripeBilling {
);
}
let mut params = stripe::CreateCheckoutSession::new();
params.subscription_data = Some(stripe::CreateCheckoutSessionSubscriptionData {
let mut params = StripeCreateCheckoutSessionParams::default();
params.subscription_data = Some(StripeCreateCheckoutSessionSubscriptionData {
trial_period_days: Some(trial_period_days),
trial_settings: Some(stripe::CreateCheckoutSessionSubscriptionDataTrialSettings {
end_behavior: stripe::CreateCheckoutSessionSubscriptionDataTrialSettingsEndBehavior {
missing_payment_method: stripe::CreateCheckoutSessionSubscriptionDataTrialSettingsEndBehaviorMissingPaymentMethod::Cancel,
}
trial_settings: Some(StripeSubscriptionTrialSettings {
end_behavior: StripeSubscriptionTrialSettingsEndBehavior {
missing_payment_method:
StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::Cancel,
},
}),
metadata: if !subscription_metadata.is_empty() {
Some(subscription_metadata)
} else {
None
},
..Default::default()
});
params.mode = Some(stripe::CheckoutSessionMode::Subscription);
params.mode = Some(StripeCheckoutSessionMode::Subscription);
params.payment_method_collection =
Some(stripe::CheckoutSessionPaymentMethodCollection::IfRequired);
Some(StripeCheckoutSessionPaymentMethodCollection::IfRequired);
params.customer = Some(customer_id);
params.client_reference_id = Some(github_login);
params.line_items = Some(vec![stripe::CreateCheckoutSessionLineItems {
params.line_items = Some(vec![StripeCreateCheckoutSessionLineItems {
price: Some(zed_pro_price_id.to_string()),
quantity: Some(1),
..Default::default()
}]);
params.success_url = Some(success_url);
let session = stripe::CheckoutSession::create(&self.client, params).await?;
let session = self.client.create_checkout_session(params).await?;
Ok(session.url.context("no checkout session URL")?)
}
pub async fn subscribe_to_zed_free(
&self,
customer_id: stripe::CustomerId,
) -> Result<stripe::Subscription> {
customer_id: StripeCustomerId,
) -> Result<StripeSubscription> {
let zed_free_price_id = self.zed_free_price_id().await?;
let existing_subscriptions = stripe::Subscription::list(
&self.client,
&stripe::ListSubscriptions {
customer: Some(customer_id.clone()),
status: None,
..Default::default()
},
)
.await?;
let existing_subscriptions = self
.client
.list_subscriptions_for_customer(&customer_id)
.await?;
let existing_active_subscription =
existing_subscriptions
.data
.into_iter()
.find(|subscription| {
subscription.status == SubscriptionStatus::Active
|| subscription.status == SubscriptionStatus::Trialing
});
existing_subscriptions.into_iter().find(|subscription| {
subscription.status == SubscriptionStatus::Active
|| subscription.status == SubscriptionStatus::Trialing
});
if let Some(subscription) = existing_active_subscription {
return Ok(subscription);
}
let mut params = stripe::CreateSubscription::new(customer_id);
params.items = Some(vec![stripe::CreateSubscriptionItems {
price: Some(zed_free_price_id.to_string()),
quantity: Some(1),
..Default::default()
}]);
let params = StripeCreateSubscriptionParams {
customer: customer_id,
items: vec![StripeCreateSubscriptionItems {
price: Some(zed_free_price_id),
quantity: Some(1),
}],
};
let subscription = stripe::Subscription::create(&self.client, params).await?;
let subscription = self.client.create_subscription(params).await?;
Ok(subscription)
}
pub async fn checkout_with_zed_free(
&self,
customer_id: stripe::CustomerId,
github_login: &str,
success_url: &str,
) -> Result<String> {
let zed_free_price_id = self.zed_free_price_id().await?;
let mut params = stripe::CreateCheckoutSession::new();
params.mode = Some(stripe::CheckoutSessionMode::Subscription);
params.payment_method_collection =
Some(stripe::CheckoutSessionPaymentMethodCollection::IfRequired);
params.customer = Some(customer_id);
params.client_reference_id = Some(github_login);
params.line_items = Some(vec![stripe::CreateCheckoutSessionLineItems {
price: Some(zed_free_price_id.to_string()),
quantity: Some(1),
..Default::default()
}]);
params.success_url = Some(success_url);
let session = stripe::CheckoutSession::create(&self.client, params).await?;
Ok(session.url.context("no checkout session URL")?)
}
}
#[derive(Clone, Deserialize)]
struct StripeMeter {
id: String,
event_name: String,
}
impl StripeMeter {
pub fn list(client: &stripe::Client) -> stripe::Response<stripe::List<Self>> {
#[derive(Serialize)]
struct Params {
#[serde(skip_serializing_if = "Option::is_none")]
limit: Option<u64>,
}
client.get_query("/billing/meters", Params { limit: Some(100) })
}
}
#[derive(Deserialize)]
struct StripeMeterEvent {
identifier: String,
}
impl StripeMeterEvent {
pub async fn create(
client: &stripe::Client,
params: StripeCreateMeterEventParams<'_>,
) -> Result<Self, stripe::StripeError> {
let identifier = params.identifier;
match client.post_form("/billing/meter_events", params).await {
Ok(event) => Ok(event),
Err(stripe::StripeError::Stripe(error)) => {
if error.http_status == 400
&& error
.message
.as_ref()
.map_or(false, |message| message.contains(identifier))
{
Ok(Self {
identifier: identifier.to_string(),
})
} else {
Err(stripe::StripeError::Stripe(error))
}
}
Err(error) => Err(error),
}
}
}
#[derive(Serialize)]
struct StripeCreateMeterEventParams<'a> {
identifier: &'a str,
event_name: &'a str,
payload: StripeCreateMeterEventPayload<'a>,
timestamp: Option<i64>,
}
#[derive(Serialize)]
struct StripeCreateMeterEventPayload<'a> {
value: u64,
stripe_customer_id: &'a stripe::CustomerId,
}
fn subscription_contains_price(
subscription: &stripe::Subscription,
price_id: &stripe::PriceId,
subscription: &StripeSubscription,
price_id: &StripePriceId,
) -> bool {
subscription.items.data.iter().any(|item| {
subscription.items.iter().any(|item| {
item.price
.as_ref()
.map_or(false, |price| price.id == *price_id)

View File

@@ -0,0 +1,211 @@
#[cfg(test)]
mod fake_stripe_client;
mod real_stripe_client;
use std::collections::HashMap;
use std::sync::Arc;
use anyhow::Result;
use async_trait::async_trait;
#[cfg(test)]
pub use fake_stripe_client::*;
pub use real_stripe_client::*;
use serde::{Deserialize, Serialize};
#[derive(Debug, PartialEq, Eq, Hash, Clone, derive_more::Display, Serialize)]
pub struct StripeCustomerId(pub Arc<str>);
#[derive(Debug, Clone)]
pub struct StripeCustomer {
pub id: StripeCustomerId,
pub email: Option<String>,
}
#[derive(Debug)]
pub struct CreateCustomerParams<'a> {
pub email: Option<&'a str>,
}
#[derive(Debug, PartialEq, Eq, Hash, Clone, derive_more::Display)]
pub struct StripeSubscriptionId(pub Arc<str>);
#[derive(Debug, PartialEq, Clone)]
pub struct StripeSubscription {
pub id: StripeSubscriptionId,
pub customer: StripeCustomerId,
// TODO: Create our own version of this enum.
pub status: stripe::SubscriptionStatus,
pub current_period_end: i64,
pub current_period_start: i64,
pub items: Vec<StripeSubscriptionItem>,
}
#[derive(Debug, PartialEq, Eq, Hash, Clone, derive_more::Display)]
pub struct StripeSubscriptionItemId(pub Arc<str>);
#[derive(Debug, PartialEq, Clone)]
pub struct StripeSubscriptionItem {
pub id: StripeSubscriptionItemId,
pub price: Option<StripePrice>,
}
#[derive(Debug)]
pub struct StripeCreateSubscriptionParams {
pub customer: StripeCustomerId,
pub items: Vec<StripeCreateSubscriptionItems>,
}
#[derive(Debug)]
pub struct StripeCreateSubscriptionItems {
pub price: Option<StripePriceId>,
pub quantity: Option<u64>,
}
#[derive(Debug, Clone)]
pub struct UpdateSubscriptionParams {
pub items: Option<Vec<UpdateSubscriptionItems>>,
pub trial_settings: Option<StripeSubscriptionTrialSettings>,
}
#[derive(Debug, PartialEq, Clone)]
pub struct UpdateSubscriptionItems {
pub price: Option<StripePriceId>,
}
#[derive(Debug, PartialEq, Clone)]
pub struct StripeSubscriptionTrialSettings {
pub end_behavior: StripeSubscriptionTrialSettingsEndBehavior,
}
#[derive(Debug, PartialEq, Clone)]
pub struct StripeSubscriptionTrialSettingsEndBehavior {
pub missing_payment_method: StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod,
}
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub enum StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod {
Cancel,
CreateInvoice,
Pause,
}
#[derive(Debug, PartialEq, Eq, Hash, Clone, derive_more::Display)]
pub struct StripePriceId(pub Arc<str>);
#[derive(Debug, PartialEq, Clone)]
pub struct StripePrice {
pub id: StripePriceId,
pub unit_amount: Option<i64>,
pub lookup_key: Option<String>,
pub recurring: Option<StripePriceRecurring>,
}
#[derive(Debug, PartialEq, Clone)]
pub struct StripePriceRecurring {
pub meter: Option<String>,
}
#[derive(Debug, PartialEq, Eq, Hash, Clone, derive_more::Display, Deserialize)]
pub struct StripeMeterId(pub Arc<str>);
#[derive(Debug, Clone, Deserialize)]
pub struct StripeMeter {
pub id: StripeMeterId,
pub event_name: String,
}
#[derive(Debug, Serialize)]
pub struct StripeCreateMeterEventParams<'a> {
pub identifier: &'a str,
pub event_name: &'a str,
pub payload: StripeCreateMeterEventPayload<'a>,
pub timestamp: Option<i64>,
}
#[derive(Debug, Serialize)]
pub struct StripeCreateMeterEventPayload<'a> {
pub value: u64,
pub stripe_customer_id: &'a StripeCustomerId,
}
#[derive(Debug, Default)]
pub struct StripeCreateCheckoutSessionParams<'a> {
pub customer: Option<&'a StripeCustomerId>,
pub client_reference_id: Option<&'a str>,
pub mode: Option<StripeCheckoutSessionMode>,
pub line_items: Option<Vec<StripeCreateCheckoutSessionLineItems>>,
pub payment_method_collection: Option<StripeCheckoutSessionPaymentMethodCollection>,
pub subscription_data: Option<StripeCreateCheckoutSessionSubscriptionData>,
pub success_url: Option<&'a str>,
}
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub enum StripeCheckoutSessionMode {
Payment,
Setup,
Subscription,
}
#[derive(Debug, PartialEq, Clone)]
pub struct StripeCreateCheckoutSessionLineItems {
pub price: Option<String>,
pub quantity: Option<u64>,
}
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub enum StripeCheckoutSessionPaymentMethodCollection {
Always,
IfRequired,
}
#[derive(Debug, PartialEq, Clone)]
pub struct StripeCreateCheckoutSessionSubscriptionData {
pub metadata: Option<HashMap<String, String>>,
pub trial_period_days: Option<u32>,
pub trial_settings: Option<StripeSubscriptionTrialSettings>,
}
#[derive(Debug)]
pub struct StripeCheckoutSession {
pub url: Option<String>,
}
#[async_trait]
pub trait StripeClient: Send + Sync {
async fn list_customers_by_email(&self, email: &str) -> Result<Vec<StripeCustomer>>;
async fn create_customer(&self, params: CreateCustomerParams<'_>) -> Result<StripeCustomer>;
async fn list_subscriptions_for_customer(
&self,
customer_id: &StripeCustomerId,
) -> Result<Vec<StripeSubscription>>;
async fn get_subscription(
&self,
subscription_id: &StripeSubscriptionId,
) -> Result<StripeSubscription>;
async fn create_subscription(
&self,
params: StripeCreateSubscriptionParams,
) -> Result<StripeSubscription>;
async fn update_subscription(
&self,
subscription_id: &StripeSubscriptionId,
params: UpdateSubscriptionParams,
) -> Result<()>;
async fn list_prices(&self) -> Result<Vec<StripePrice>>;
async fn list_meters(&self) -> Result<Vec<StripeMeter>>;
async fn create_meter_event(&self, params: StripeCreateMeterEventParams<'_>) -> Result<()>;
async fn create_checkout_session(
&self,
params: StripeCreateCheckoutSessionParams<'_>,
) -> Result<StripeCheckoutSession>;
}

View File

@@ -0,0 +1,207 @@
use std::sync::Arc;
use anyhow::{Result, anyhow};
use async_trait::async_trait;
use chrono::{Duration, Utc};
use collections::HashMap;
use parking_lot::Mutex;
use uuid::Uuid;
use crate::stripe_client::{
CreateCustomerParams, StripeCheckoutSession, StripeCheckoutSessionMode,
StripeCheckoutSessionPaymentMethodCollection, StripeClient,
StripeCreateCheckoutSessionLineItems, StripeCreateCheckoutSessionParams,
StripeCreateCheckoutSessionSubscriptionData, StripeCreateMeterEventParams,
StripeCreateSubscriptionParams, StripeCustomer, StripeCustomerId, StripeMeter, StripeMeterId,
StripePrice, StripePriceId, StripeSubscription, StripeSubscriptionId, StripeSubscriptionItem,
StripeSubscriptionItemId, UpdateSubscriptionParams,
};
#[derive(Debug, Clone)]
pub struct StripeCreateMeterEventCall {
pub identifier: Arc<str>,
pub event_name: Arc<str>,
pub value: u64,
pub stripe_customer_id: StripeCustomerId,
pub timestamp: Option<i64>,
}
#[derive(Debug, Clone)]
pub struct StripeCreateCheckoutSessionCall {
pub customer: Option<StripeCustomerId>,
pub client_reference_id: Option<String>,
pub mode: Option<StripeCheckoutSessionMode>,
pub line_items: Option<Vec<StripeCreateCheckoutSessionLineItems>>,
pub payment_method_collection: Option<StripeCheckoutSessionPaymentMethodCollection>,
pub subscription_data: Option<StripeCreateCheckoutSessionSubscriptionData>,
pub success_url: Option<String>,
}
pub struct FakeStripeClient {
pub customers: Arc<Mutex<HashMap<StripeCustomerId, StripeCustomer>>>,
pub subscriptions: Arc<Mutex<HashMap<StripeSubscriptionId, StripeSubscription>>>,
pub update_subscription_calls:
Arc<Mutex<Vec<(StripeSubscriptionId, UpdateSubscriptionParams)>>>,
pub prices: Arc<Mutex<HashMap<StripePriceId, StripePrice>>>,
pub meters: Arc<Mutex<HashMap<StripeMeterId, StripeMeter>>>,
pub create_meter_event_calls: Arc<Mutex<Vec<StripeCreateMeterEventCall>>>,
pub create_checkout_session_calls: Arc<Mutex<Vec<StripeCreateCheckoutSessionCall>>>,
}
impl FakeStripeClient {
pub fn new() -> Self {
Self {
customers: Arc::new(Mutex::new(HashMap::default())),
subscriptions: Arc::new(Mutex::new(HashMap::default())),
update_subscription_calls: Arc::new(Mutex::new(Vec::new())),
prices: Arc::new(Mutex::new(HashMap::default())),
meters: Arc::new(Mutex::new(HashMap::default())),
create_meter_event_calls: Arc::new(Mutex::new(Vec::new())),
create_checkout_session_calls: Arc::new(Mutex::new(Vec::new())),
}
}
}
#[async_trait]
impl StripeClient for FakeStripeClient {
async fn list_customers_by_email(&self, email: &str) -> Result<Vec<StripeCustomer>> {
Ok(self
.customers
.lock()
.values()
.filter(|customer| customer.email.as_deref() == Some(email))
.cloned()
.collect())
}
async fn create_customer(&self, params: CreateCustomerParams<'_>) -> Result<StripeCustomer> {
let customer = StripeCustomer {
id: StripeCustomerId(format!("cus_{}", Uuid::new_v4()).into()),
email: params.email.map(|email| email.to_string()),
};
self.customers
.lock()
.insert(customer.id.clone(), customer.clone());
Ok(customer)
}
async fn list_subscriptions_for_customer(
&self,
customer_id: &StripeCustomerId,
) -> Result<Vec<StripeSubscription>> {
let subscriptions = self
.subscriptions
.lock()
.values()
.filter(|subscription| subscription.customer == *customer_id)
.cloned()
.collect();
Ok(subscriptions)
}
async fn get_subscription(
&self,
subscription_id: &StripeSubscriptionId,
) -> Result<StripeSubscription> {
self.subscriptions
.lock()
.get(subscription_id)
.cloned()
.ok_or_else(|| anyhow!("no subscription found for {subscription_id:?}"))
}
async fn create_subscription(
&self,
params: StripeCreateSubscriptionParams,
) -> Result<StripeSubscription> {
let now = Utc::now();
let subscription = StripeSubscription {
id: StripeSubscriptionId(format!("sub_{}", Uuid::new_v4()).into()),
customer: params.customer,
status: stripe::SubscriptionStatus::Active,
current_period_start: now.timestamp(),
current_period_end: (now + Duration::days(30)).timestamp(),
items: params
.items
.into_iter()
.map(|item| StripeSubscriptionItem {
id: StripeSubscriptionItemId(format!("si_{}", Uuid::new_v4()).into()),
price: item
.price
.and_then(|price_id| self.prices.lock().get(&price_id).cloned()),
})
.collect(),
};
self.subscriptions
.lock()
.insert(subscription.id.clone(), subscription.clone());
Ok(subscription)
}
async fn update_subscription(
&self,
subscription_id: &StripeSubscriptionId,
params: UpdateSubscriptionParams,
) -> Result<()> {
let subscription = self.get_subscription(subscription_id).await?;
self.update_subscription_calls
.lock()
.push((subscription.id, params));
Ok(())
}
async fn list_prices(&self) -> Result<Vec<StripePrice>> {
let prices = self.prices.lock().values().cloned().collect();
Ok(prices)
}
async fn list_meters(&self) -> Result<Vec<StripeMeter>> {
let meters = self.meters.lock().values().cloned().collect();
Ok(meters)
}
async fn create_meter_event(&self, params: StripeCreateMeterEventParams<'_>) -> Result<()> {
self.create_meter_event_calls
.lock()
.push(StripeCreateMeterEventCall {
identifier: params.identifier.into(),
event_name: params.event_name.into(),
value: params.payload.value,
stripe_customer_id: params.payload.stripe_customer_id.clone(),
timestamp: params.timestamp,
});
Ok(())
}
async fn create_checkout_session(
&self,
params: StripeCreateCheckoutSessionParams<'_>,
) -> Result<StripeCheckoutSession> {
self.create_checkout_session_calls
.lock()
.push(StripeCreateCheckoutSessionCall {
customer: params.customer.cloned(),
client_reference_id: params.client_reference_id.map(|id| id.to_string()),
mode: params.mode,
line_items: params.line_items,
payment_method_collection: params.payment_method_collection,
subscription_data: params.subscription_data,
success_url: params.success_url.map(|url| url.to_string()),
});
Ok(StripeCheckoutSession {
url: Some("https://checkout.stripe.com/c/pay/cs_test_1".to_string()),
})
}
}

View File

@@ -0,0 +1,456 @@
use std::str::FromStr as _;
use std::sync::Arc;
use anyhow::{Context as _, Result, anyhow};
use async_trait::async_trait;
use serde::Serialize;
use stripe::{
CheckoutSession, CheckoutSessionMode, CheckoutSessionPaymentMethodCollection,
CreateCheckoutSession, CreateCheckoutSessionLineItems, CreateCheckoutSessionSubscriptionData,
CreateCheckoutSessionSubscriptionDataTrialSettings,
CreateCheckoutSessionSubscriptionDataTrialSettingsEndBehavior,
CreateCheckoutSessionSubscriptionDataTrialSettingsEndBehaviorMissingPaymentMethod,
CreateCustomer, Customer, CustomerId, ListCustomers, Price, PriceId, Recurring, Subscription,
SubscriptionId, SubscriptionItem, SubscriptionItemId, UpdateSubscriptionItems,
UpdateSubscriptionTrialSettings, UpdateSubscriptionTrialSettingsEndBehavior,
UpdateSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod,
};
use crate::stripe_client::{
CreateCustomerParams, StripeCheckoutSession, StripeCheckoutSessionMode,
StripeCheckoutSessionPaymentMethodCollection, StripeClient,
StripeCreateCheckoutSessionLineItems, StripeCreateCheckoutSessionParams,
StripeCreateCheckoutSessionSubscriptionData, StripeCreateMeterEventParams,
StripeCreateSubscriptionParams, StripeCustomer, StripeCustomerId, StripeMeter, StripePrice,
StripePriceId, StripePriceRecurring, StripeSubscription, StripeSubscriptionId,
StripeSubscriptionItem, StripeSubscriptionItemId, StripeSubscriptionTrialSettings,
StripeSubscriptionTrialSettingsEndBehavior,
StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod, UpdateSubscriptionParams,
};
pub struct RealStripeClient {
client: Arc<stripe::Client>,
}
impl RealStripeClient {
pub fn new(client: Arc<stripe::Client>) -> Self {
Self { client }
}
}
#[async_trait]
impl StripeClient for RealStripeClient {
async fn list_customers_by_email(&self, email: &str) -> Result<Vec<StripeCustomer>> {
let response = Customer::list(
&self.client,
&ListCustomers {
email: Some(email),
..Default::default()
},
)
.await?;
Ok(response
.data
.into_iter()
.map(StripeCustomer::from)
.collect())
}
async fn create_customer(&self, params: CreateCustomerParams<'_>) -> Result<StripeCustomer> {
let customer = Customer::create(
&self.client,
CreateCustomer {
email: params.email,
..Default::default()
},
)
.await?;
Ok(StripeCustomer::from(customer))
}
async fn list_subscriptions_for_customer(
&self,
customer_id: &StripeCustomerId,
) -> Result<Vec<StripeSubscription>> {
let customer_id = customer_id.try_into()?;
let subscriptions = stripe::Subscription::list(
&self.client,
&stripe::ListSubscriptions {
customer: Some(customer_id),
status: None,
..Default::default()
},
)
.await?;
Ok(subscriptions
.data
.into_iter()
.map(StripeSubscription::from)
.collect())
}
async fn get_subscription(
&self,
subscription_id: &StripeSubscriptionId,
) -> Result<StripeSubscription> {
let subscription_id = subscription_id.try_into()?;
let subscription = Subscription::retrieve(&self.client, &subscription_id, &[]).await?;
Ok(StripeSubscription::from(subscription))
}
async fn create_subscription(
&self,
params: StripeCreateSubscriptionParams,
) -> Result<StripeSubscription> {
let customer_id = params.customer.try_into()?;
let mut create_subscription = stripe::CreateSubscription::new(customer_id);
create_subscription.items = Some(
params
.items
.into_iter()
.map(|item| stripe::CreateSubscriptionItems {
price: item.price.map(|price| price.to_string()),
quantity: item.quantity,
..Default::default()
})
.collect(),
);
let subscription = Subscription::create(&self.client, create_subscription).await?;
Ok(StripeSubscription::from(subscription))
}
async fn update_subscription(
&self,
subscription_id: &StripeSubscriptionId,
params: UpdateSubscriptionParams,
) -> Result<()> {
let subscription_id = subscription_id.try_into()?;
stripe::Subscription::update(
&self.client,
&subscription_id,
stripe::UpdateSubscription {
items: params.items.map(|items| {
items
.into_iter()
.map(|item| UpdateSubscriptionItems {
price: item.price.map(|price| price.to_string()),
..Default::default()
})
.collect()
}),
trial_settings: params.trial_settings.map(Into::into),
..Default::default()
},
)
.await?;
Ok(())
}
async fn list_prices(&self) -> Result<Vec<StripePrice>> {
let response = stripe::Price::list(
&self.client,
&stripe::ListPrices {
limit: Some(100),
..Default::default()
},
)
.await?;
Ok(response.data.into_iter().map(StripePrice::from).collect())
}
async fn list_meters(&self) -> Result<Vec<StripeMeter>> {
#[derive(Serialize)]
struct Params {
#[serde(skip_serializing_if = "Option::is_none")]
limit: Option<u64>,
}
let response = self
.client
.get_query::<stripe::List<StripeMeter>, _>(
"/billing/meters",
Params { limit: Some(100) },
)
.await?;
Ok(response.data)
}
async fn create_meter_event(&self, params: StripeCreateMeterEventParams<'_>) -> Result<()> {
let identifier = params.identifier;
match self.client.post_form("/billing/meter_events", params).await {
Ok(event) => Ok(event),
Err(stripe::StripeError::Stripe(error)) => {
if error.http_status == 400
&& error
.message
.as_ref()
.map_or(false, |message| message.contains(identifier))
{
Ok(())
} else {
Err(anyhow!(stripe::StripeError::Stripe(error)))
}
}
Err(error) => Err(anyhow!(error)),
}
}
async fn create_checkout_session(
&self,
params: StripeCreateCheckoutSessionParams<'_>,
) -> Result<StripeCheckoutSession> {
let params = params.try_into()?;
let session = CheckoutSession::create(&self.client, params).await?;
Ok(session.into())
}
}
impl From<CustomerId> for StripeCustomerId {
fn from(value: CustomerId) -> Self {
Self(value.as_str().into())
}
}
impl TryFrom<StripeCustomerId> for CustomerId {
type Error = anyhow::Error;
fn try_from(value: StripeCustomerId) -> Result<Self, Self::Error> {
Self::from_str(value.0.as_ref()).context("failed to parse Stripe customer ID")
}
}
impl TryFrom<&StripeCustomerId> for CustomerId {
type Error = anyhow::Error;
fn try_from(value: &StripeCustomerId) -> Result<Self, Self::Error> {
Self::from_str(value.0.as_ref()).context("failed to parse Stripe customer ID")
}
}
impl From<Customer> for StripeCustomer {
fn from(value: Customer) -> Self {
StripeCustomer {
id: value.id.into(),
email: value.email,
}
}
}
impl From<SubscriptionId> for StripeSubscriptionId {
fn from(value: SubscriptionId) -> Self {
Self(value.as_str().into())
}
}
impl TryFrom<&StripeSubscriptionId> for SubscriptionId {
type Error = anyhow::Error;
fn try_from(value: &StripeSubscriptionId) -> Result<Self, Self::Error> {
Self::from_str(value.0.as_ref()).context("failed to parse Stripe subscription ID")
}
}
impl From<Subscription> for StripeSubscription {
fn from(value: Subscription) -> Self {
Self {
id: value.id.into(),
customer: value.customer.id().into(),
status: value.status,
current_period_start: value.current_period_start,
current_period_end: value.current_period_end,
items: value.items.data.into_iter().map(Into::into).collect(),
}
}
}
impl From<SubscriptionItemId> for StripeSubscriptionItemId {
fn from(value: SubscriptionItemId) -> Self {
Self(value.as_str().into())
}
}
impl From<SubscriptionItem> for StripeSubscriptionItem {
fn from(value: SubscriptionItem) -> Self {
Self {
id: value.id.into(),
price: value.price.map(Into::into),
}
}
}
impl From<StripeSubscriptionTrialSettings> for UpdateSubscriptionTrialSettings {
fn from(value: StripeSubscriptionTrialSettings) -> Self {
Self {
end_behavior: value.end_behavior.into(),
}
}
}
impl From<StripeSubscriptionTrialSettingsEndBehavior>
for UpdateSubscriptionTrialSettingsEndBehavior
{
fn from(value: StripeSubscriptionTrialSettingsEndBehavior) -> Self {
Self {
missing_payment_method: value.missing_payment_method.into(),
}
}
}
impl From<StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod>
for UpdateSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod
{
fn from(value: StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod) -> Self {
match value {
StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::Cancel => Self::Cancel,
StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::CreateInvoice => {
Self::CreateInvoice
}
StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::Pause => Self::Pause,
}
}
}
impl From<PriceId> for StripePriceId {
fn from(value: PriceId) -> Self {
Self(value.as_str().into())
}
}
impl TryFrom<StripePriceId> for PriceId {
type Error = anyhow::Error;
fn try_from(value: StripePriceId) -> Result<Self, Self::Error> {
Self::from_str(value.0.as_ref()).context("failed to parse Stripe price ID")
}
}
impl From<Price> for StripePrice {
fn from(value: Price) -> Self {
Self {
id: value.id.into(),
unit_amount: value.unit_amount,
lookup_key: value.lookup_key,
recurring: value.recurring.map(StripePriceRecurring::from),
}
}
}
impl From<Recurring> for StripePriceRecurring {
fn from(value: Recurring) -> Self {
Self { meter: value.meter }
}
}
impl<'a> TryFrom<StripeCreateCheckoutSessionParams<'a>> for CreateCheckoutSession<'a> {
type Error = anyhow::Error;
fn try_from(value: StripeCreateCheckoutSessionParams<'a>) -> Result<Self, Self::Error> {
Ok(Self {
customer: value
.customer
.map(|customer_id| customer_id.try_into())
.transpose()?,
client_reference_id: value.client_reference_id,
mode: value.mode.map(Into::into),
line_items: value
.line_items
.map(|line_items| line_items.into_iter().map(Into::into).collect()),
payment_method_collection: value.payment_method_collection.map(Into::into),
subscription_data: value.subscription_data.map(Into::into),
success_url: value.success_url,
..Default::default()
})
}
}
impl From<StripeCheckoutSessionMode> for CheckoutSessionMode {
fn from(value: StripeCheckoutSessionMode) -> Self {
match value {
StripeCheckoutSessionMode::Payment => Self::Payment,
StripeCheckoutSessionMode::Setup => Self::Setup,
StripeCheckoutSessionMode::Subscription => Self::Subscription,
}
}
}
impl From<StripeCreateCheckoutSessionLineItems> for CreateCheckoutSessionLineItems {
fn from(value: StripeCreateCheckoutSessionLineItems) -> Self {
Self {
price: value.price,
quantity: value.quantity,
..Default::default()
}
}
}
impl From<StripeCheckoutSessionPaymentMethodCollection> for CheckoutSessionPaymentMethodCollection {
fn from(value: StripeCheckoutSessionPaymentMethodCollection) -> Self {
match value {
StripeCheckoutSessionPaymentMethodCollection::Always => Self::Always,
StripeCheckoutSessionPaymentMethodCollection::IfRequired => Self::IfRequired,
}
}
}
impl From<StripeCreateCheckoutSessionSubscriptionData> for CreateCheckoutSessionSubscriptionData {
fn from(value: StripeCreateCheckoutSessionSubscriptionData) -> Self {
Self {
trial_period_days: value.trial_period_days,
trial_settings: value.trial_settings.map(Into::into),
metadata: value.metadata,
..Default::default()
}
}
}
impl From<StripeSubscriptionTrialSettings> for CreateCheckoutSessionSubscriptionDataTrialSettings {
fn from(value: StripeSubscriptionTrialSettings) -> Self {
Self {
end_behavior: value.end_behavior.into(),
}
}
}
impl From<StripeSubscriptionTrialSettingsEndBehavior>
for CreateCheckoutSessionSubscriptionDataTrialSettingsEndBehavior
{
fn from(value: StripeSubscriptionTrialSettingsEndBehavior) -> Self {
Self {
missing_payment_method: value.missing_payment_method.into(),
}
}
}
impl From<StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod>
for CreateCheckoutSessionSubscriptionDataTrialSettingsEndBehaviorMissingPaymentMethod
{
fn from(value: StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod) -> Self {
match value {
StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::Cancel => Self::Cancel,
StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::CreateInvoice => {
Self::CreateInvoice
}
StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::Pause => Self::Pause,
}
}
}
impl From<CheckoutSession> for StripeCheckoutSession {
fn from(value: CheckoutSession) -> Self {
Self { url: value.url }
}
}

View File

@@ -18,6 +18,7 @@ mod random_channel_buffer_tests;
mod random_project_collaboration_tests;
mod randomized_test_helpers;
mod remote_editing_collaboration_tests;
mod stripe_billing_tests;
mod test_server;
use language::{Language, LanguageConfig, LanguageMatcher, tree_sitter_rust};

View File

@@ -0,0 +1,557 @@
use std::sync::Arc;
use chrono::{Duration, Utc};
use pretty_assertions::assert_eq;
use crate::llm::AGENT_EXTENDED_TRIAL_FEATURE_FLAG;
use crate::stripe_billing::StripeBilling;
use crate::stripe_client::{
FakeStripeClient, StripeCheckoutSessionMode, StripeCheckoutSessionPaymentMethodCollection,
StripeCreateCheckoutSessionLineItems, StripeCreateCheckoutSessionSubscriptionData,
StripeCustomerId, StripeMeter, StripeMeterId, StripePrice, StripePriceId, StripePriceRecurring,
StripeSubscription, StripeSubscriptionId, StripeSubscriptionItem, StripeSubscriptionItemId,
StripeSubscriptionTrialSettings, StripeSubscriptionTrialSettingsEndBehavior,
StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod, UpdateSubscriptionItems,
};
fn make_stripe_billing() -> (StripeBilling, Arc<FakeStripeClient>) {
let stripe_client = Arc::new(FakeStripeClient::new());
let stripe_billing = StripeBilling::test(stripe_client.clone());
(stripe_billing, stripe_client)
}
#[gpui::test]
async fn test_initialize() {
let (stripe_billing, stripe_client) = make_stripe_billing();
// Add test meters
let meter1 = StripeMeter {
id: StripeMeterId("meter_1".into()),
event_name: "event_1".to_string(),
};
let meter2 = StripeMeter {
id: StripeMeterId("meter_2".into()),
event_name: "event_2".to_string(),
};
stripe_client
.meters
.lock()
.insert(meter1.id.clone(), meter1);
stripe_client
.meters
.lock()
.insert(meter2.id.clone(), meter2);
// Add test prices
let price1 = StripePrice {
id: StripePriceId("price_1".into()),
unit_amount: Some(1_000),
lookup_key: Some("zed-pro".to_string()),
recurring: None,
};
let price2 = StripePrice {
id: StripePriceId("price_2".into()),
unit_amount: Some(0),
lookup_key: Some("zed-free".to_string()),
recurring: None,
};
let price3 = StripePrice {
id: StripePriceId("price_3".into()),
unit_amount: Some(500),
lookup_key: None,
recurring: Some(StripePriceRecurring {
meter: Some("meter_1".to_string()),
}),
};
stripe_client
.prices
.lock()
.insert(price1.id.clone(), price1);
stripe_client
.prices
.lock()
.insert(price2.id.clone(), price2);
stripe_client
.prices
.lock()
.insert(price3.id.clone(), price3);
// Initialize the billing system
stripe_billing.initialize().await.unwrap();
// Verify that prices can be found by lookup key
let zed_pro_price_id = stripe_billing.zed_pro_price_id().await.unwrap();
assert_eq!(zed_pro_price_id.to_string(), "price_1");
let zed_free_price_id = stripe_billing.zed_free_price_id().await.unwrap();
assert_eq!(zed_free_price_id.to_string(), "price_2");
// Verify that a price can be found by lookup key
let zed_pro_price = stripe_billing
.find_price_by_lookup_key("zed-pro")
.await
.unwrap();
assert_eq!(zed_pro_price.id.to_string(), "price_1");
assert_eq!(zed_pro_price.unit_amount, Some(1_000));
// Verify that finding a non-existent lookup key returns an error
let result = stripe_billing
.find_price_by_lookup_key("non-existent")
.await;
assert!(result.is_err());
}
#[gpui::test]
async fn test_find_or_create_customer_by_email() {
let (stripe_billing, stripe_client) = make_stripe_billing();
// Create a customer with an email that doesn't yet correspond to a customer.
{
let email = "user@example.com";
let customer_id = stripe_billing
.find_or_create_customer_by_email(Some(email))
.await
.unwrap();
let customer = stripe_client
.customers
.lock()
.get(&customer_id)
.unwrap()
.clone();
assert_eq!(customer.email.as_deref(), Some(email));
}
// Create a customer with an email that corresponds to an existing customer.
{
let email = "user2@example.com";
let existing_customer_id = stripe_billing
.find_or_create_customer_by_email(Some(email))
.await
.unwrap();
let customer_id = stripe_billing
.find_or_create_customer_by_email(Some(email))
.await
.unwrap();
assert_eq!(customer_id, existing_customer_id);
let customer = stripe_client
.customers
.lock()
.get(&customer_id)
.unwrap()
.clone();
assert_eq!(customer.email.as_deref(), Some(email));
}
}
#[gpui::test]
async fn test_subscribe_to_price() {
let (stripe_billing, stripe_client) = make_stripe_billing();
let price = StripePrice {
id: StripePriceId("price_test".into()),
unit_amount: Some(2000),
lookup_key: Some("test-price".to_string()),
recurring: None,
};
stripe_client
.prices
.lock()
.insert(price.id.clone(), price.clone());
let now = Utc::now();
let subscription = StripeSubscription {
id: StripeSubscriptionId("sub_test".into()),
customer: StripeCustomerId("cus_test".into()),
status: stripe::SubscriptionStatus::Active,
current_period_start: now.timestamp(),
current_period_end: (now + Duration::days(30)).timestamp(),
items: vec![],
};
stripe_client
.subscriptions
.lock()
.insert(subscription.id.clone(), subscription.clone());
stripe_billing
.subscribe_to_price(&subscription.id, &price)
.await
.unwrap();
let update_subscription_calls = stripe_client
.update_subscription_calls
.lock()
.iter()
.map(|(id, params)| (id.clone(), params.clone()))
.collect::<Vec<_>>();
assert_eq!(update_subscription_calls.len(), 1);
assert_eq!(update_subscription_calls[0].0, subscription.id);
assert_eq!(
update_subscription_calls[0].1.items,
Some(vec![UpdateSubscriptionItems {
price: Some(price.id.clone())
}])
);
// Subscribing to a price that is already on the subscription is a no-op.
{
let now = Utc::now();
let subscription = StripeSubscription {
id: StripeSubscriptionId("sub_test".into()),
customer: StripeCustomerId("cus_test".into()),
status: stripe::SubscriptionStatus::Active,
current_period_start: now.timestamp(),
current_period_end: (now + Duration::days(30)).timestamp(),
items: vec![StripeSubscriptionItem {
id: StripeSubscriptionItemId("si_test".into()),
price: Some(price.clone()),
}],
};
stripe_client
.subscriptions
.lock()
.insert(subscription.id.clone(), subscription.clone());
stripe_billing
.subscribe_to_price(&subscription.id, &price)
.await
.unwrap();
assert_eq!(stripe_client.update_subscription_calls.lock().len(), 1);
}
}
#[gpui::test]
async fn test_subscribe_to_zed_free() {
let (stripe_billing, stripe_client) = make_stripe_billing();
let zed_pro_price = StripePrice {
id: StripePriceId("price_1".into()),
unit_amount: Some(0),
lookup_key: Some("zed-pro".to_string()),
recurring: None,
};
stripe_client
.prices
.lock()
.insert(zed_pro_price.id.clone(), zed_pro_price.clone());
let zed_free_price = StripePrice {
id: StripePriceId("price_2".into()),
unit_amount: Some(0),
lookup_key: Some("zed-free".to_string()),
recurring: None,
};
stripe_client
.prices
.lock()
.insert(zed_free_price.id.clone(), zed_free_price.clone());
stripe_billing.initialize().await.unwrap();
// Customer is subscribed to Zed Free when not already subscribed to a plan.
{
let customer_id = StripeCustomerId("cus_no_plan".into());
let subscription = stripe_billing
.subscribe_to_zed_free(customer_id)
.await
.unwrap();
assert_eq!(subscription.items[0].price.as_ref(), Some(&zed_free_price));
}
// Customer is not subscribed to Zed Free when they already have an active subscription.
{
let customer_id = StripeCustomerId("cus_active_subscription".into());
let now = Utc::now();
let existing_subscription = StripeSubscription {
id: StripeSubscriptionId("sub_existing_active".into()),
customer: customer_id.clone(),
status: stripe::SubscriptionStatus::Active,
current_period_start: now.timestamp(),
current_period_end: (now + Duration::days(30)).timestamp(),
items: vec![StripeSubscriptionItem {
id: StripeSubscriptionItemId("si_test".into()),
price: Some(zed_pro_price.clone()),
}],
};
stripe_client.subscriptions.lock().insert(
existing_subscription.id.clone(),
existing_subscription.clone(),
);
let subscription = stripe_billing
.subscribe_to_zed_free(customer_id)
.await
.unwrap();
assert_eq!(subscription, existing_subscription);
}
// Customer is not subscribed to Zed Free when they already have a trial subscription.
{
let customer_id = StripeCustomerId("cus_trial_subscription".into());
let now = Utc::now();
let existing_subscription = StripeSubscription {
id: StripeSubscriptionId("sub_existing_trial".into()),
customer: customer_id.clone(),
status: stripe::SubscriptionStatus::Trialing,
current_period_start: now.timestamp(),
current_period_end: (now + Duration::days(14)).timestamp(),
items: vec![StripeSubscriptionItem {
id: StripeSubscriptionItemId("si_test".into()),
price: Some(zed_pro_price.clone()),
}],
};
stripe_client.subscriptions.lock().insert(
existing_subscription.id.clone(),
existing_subscription.clone(),
);
let subscription = stripe_billing
.subscribe_to_zed_free(customer_id)
.await
.unwrap();
assert_eq!(subscription, existing_subscription);
}
}
#[gpui::test]
async fn test_bill_model_request_usage() {
let (stripe_billing, stripe_client) = make_stripe_billing();
let customer_id = StripeCustomerId("cus_test".into());
stripe_billing
.bill_model_request_usage(&customer_id, "some_model/requests", 73)
.await
.unwrap();
let create_meter_event_calls = stripe_client
.create_meter_event_calls
.lock()
.iter()
.cloned()
.collect::<Vec<_>>();
assert_eq!(create_meter_event_calls.len(), 1);
assert!(
create_meter_event_calls[0]
.identifier
.starts_with("model_requests/")
);
assert_eq!(create_meter_event_calls[0].stripe_customer_id, customer_id);
assert_eq!(
create_meter_event_calls[0].event_name.as_ref(),
"some_model/requests"
);
assert_eq!(create_meter_event_calls[0].value, 73);
}
#[gpui::test]
async fn test_checkout_with_zed_pro() {
let (stripe_billing, stripe_client) = make_stripe_billing();
let customer_id = StripeCustomerId("cus_test".into());
let github_login = "zeduser1";
let success_url = "https://example.com/success";
// It returns an error when the Zed Pro price doesn't exist.
{
let result = stripe_billing
.checkout_with_zed_pro(&customer_id, github_login, success_url)
.await;
assert!(result.is_err());
assert_eq!(
result.err().unwrap().to_string(),
r#"no price ID found for "zed-pro""#
);
}
// Successful checkout.
{
let price = StripePrice {
id: StripePriceId("price_1".into()),
unit_amount: Some(2000),
lookup_key: Some("zed-pro".to_string()),
recurring: None,
};
stripe_client
.prices
.lock()
.insert(price.id.clone(), price.clone());
stripe_billing.initialize().await.unwrap();
let checkout_url = stripe_billing
.checkout_with_zed_pro(&customer_id, github_login, success_url)
.await
.unwrap();
assert!(checkout_url.starts_with("https://checkout.stripe.com/c/pay"));
let create_checkout_session_calls = stripe_client
.create_checkout_session_calls
.lock()
.drain(..)
.collect::<Vec<_>>();
assert_eq!(create_checkout_session_calls.len(), 1);
let call = create_checkout_session_calls.into_iter().next().unwrap();
assert_eq!(call.customer, Some(customer_id));
assert_eq!(call.client_reference_id.as_deref(), Some(github_login));
assert_eq!(call.mode, Some(StripeCheckoutSessionMode::Subscription));
assert_eq!(
call.line_items,
Some(vec![StripeCreateCheckoutSessionLineItems {
price: Some(price.id.to_string()),
quantity: Some(1)
}])
);
assert_eq!(call.payment_method_collection, None);
assert_eq!(call.subscription_data, None);
assert_eq!(call.success_url.as_deref(), Some(success_url));
}
}
#[gpui::test]
async fn test_checkout_with_zed_pro_trial() {
let (stripe_billing, stripe_client) = make_stripe_billing();
let customer_id = StripeCustomerId("cus_test".into());
let github_login = "zeduser1";
let success_url = "https://example.com/success";
// It returns an error when the Zed Pro price doesn't exist.
{
let result = stripe_billing
.checkout_with_zed_pro_trial(&customer_id, github_login, Vec::new(), success_url)
.await;
assert!(result.is_err());
assert_eq!(
result.err().unwrap().to_string(),
r#"no price ID found for "zed-pro""#
);
}
let price = StripePrice {
id: StripePriceId("price_1".into()),
unit_amount: Some(2000),
lookup_key: Some("zed-pro".to_string()),
recurring: None,
};
stripe_client
.prices
.lock()
.insert(price.id.clone(), price.clone());
stripe_billing.initialize().await.unwrap();
// Successful checkout.
{
let checkout_url = stripe_billing
.checkout_with_zed_pro_trial(&customer_id, github_login, Vec::new(), success_url)
.await
.unwrap();
assert!(checkout_url.starts_with("https://checkout.stripe.com/c/pay"));
let create_checkout_session_calls = stripe_client
.create_checkout_session_calls
.lock()
.drain(..)
.collect::<Vec<_>>();
assert_eq!(create_checkout_session_calls.len(), 1);
let call = create_checkout_session_calls.into_iter().next().unwrap();
assert_eq!(call.customer.as_ref(), Some(&customer_id));
assert_eq!(call.client_reference_id.as_deref(), Some(github_login));
assert_eq!(call.mode, Some(StripeCheckoutSessionMode::Subscription));
assert_eq!(
call.line_items,
Some(vec![StripeCreateCheckoutSessionLineItems {
price: Some(price.id.to_string()),
quantity: Some(1)
}])
);
assert_eq!(
call.payment_method_collection,
Some(StripeCheckoutSessionPaymentMethodCollection::IfRequired)
);
assert_eq!(
call.subscription_data,
Some(StripeCreateCheckoutSessionSubscriptionData {
trial_period_days: Some(14),
trial_settings: Some(StripeSubscriptionTrialSettings {
end_behavior: StripeSubscriptionTrialSettingsEndBehavior {
missing_payment_method:
StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::Cancel,
},
}),
metadata: None,
})
);
assert_eq!(call.success_url.as_deref(), Some(success_url));
}
// Successful checkout with extended trial.
{
let checkout_url = stripe_billing
.checkout_with_zed_pro_trial(
&customer_id,
github_login,
vec![AGENT_EXTENDED_TRIAL_FEATURE_FLAG.to_string()],
success_url,
)
.await
.unwrap();
assert!(checkout_url.starts_with("https://checkout.stripe.com/c/pay"));
let create_checkout_session_calls = stripe_client
.create_checkout_session_calls
.lock()
.drain(..)
.collect::<Vec<_>>();
assert_eq!(create_checkout_session_calls.len(), 1);
let call = create_checkout_session_calls.into_iter().next().unwrap();
assert_eq!(call.customer, Some(customer_id));
assert_eq!(call.client_reference_id.as_deref(), Some(github_login));
assert_eq!(call.mode, Some(StripeCheckoutSessionMode::Subscription));
assert_eq!(
call.line_items,
Some(vec![StripeCreateCheckoutSessionLineItems {
price: Some(price.id.to_string()),
quantity: Some(1)
}])
);
assert_eq!(
call.payment_method_collection,
Some(StripeCheckoutSessionPaymentMethodCollection::IfRequired)
);
assert_eq!(
call.subscription_data,
Some(StripeCreateCheckoutSessionSubscriptionData {
trial_period_days: Some(60),
trial_settings: Some(StripeSubscriptionTrialSettings {
end_behavior: StripeSubscriptionTrialSettingsEndBehavior {
missing_payment_method:
StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::Cancel,
},
}),
metadata: Some(std::collections::HashMap::from_iter([(
"promo_feature_flag".into(),
AGENT_EXTENDED_TRIAL_FEATURE_FLAG.into()
)])),
})
);
assert_eq!(call.success_url.as_deref(), Some(success_url));
}
}

View File

@@ -56,5 +56,7 @@ async-pipe.workspace = true
gpui = { workspace = true, features = ["test-support"] }
settings = { workspace = true, features = ["test-support"] }
task = { workspace = true, features = ["test-support"] }
tree-sitter.workspace = true
tree-sitter-go.workspace = true
util = { workspace = true, features = ["test-support"] }
zlog.workspace = true

View File

@@ -275,3 +275,386 @@ impl InlineValueProvider for PythonInlineValueProvider {
variables
}
}
pub struct GoInlineValueProvider;
impl InlineValueProvider for GoInlineValueProvider {
fn provide(
&self,
mut node: language::Node,
source: &str,
max_row: usize,
) -> Vec<InlineValueLocation> {
let mut variables = Vec::new();
let mut variable_names = HashSet::new();
let mut scope = VariableScope::Local;
loop {
let mut variable_names_in_scope = HashMap::new();
for child in node.named_children(&mut node.walk()) {
if child.start_position().row >= max_row {
break;
}
if scope == VariableScope::Local {
match child.kind() {
"var_declaration" => {
for var_spec in child.named_children(&mut child.walk()) {
if var_spec.kind() == "var_spec" {
if let Some(name_node) = var_spec.child_by_field_name("name") {
let variable_name =
source[name_node.byte_range()].to_string();
if variable_names.contains(&variable_name) {
continue;
}
if let Some(index) =
variable_names_in_scope.get(&variable_name)
{
variables.remove(*index);
}
variable_names_in_scope
.insert(variable_name.clone(), variables.len());
variables.push(InlineValueLocation {
variable_name,
scope: VariableScope::Local,
lookup: VariableLookupKind::Variable,
row: name_node.end_position().row,
column: name_node.end_position().column,
});
}
}
}
}
"short_var_declaration" => {
if let Some(left_side) = child.child_by_field_name("left") {
for identifier in left_side.named_children(&mut left_side.walk()) {
if identifier.kind() == "identifier" {
let variable_name =
source[identifier.byte_range()].to_string();
if variable_names.contains(&variable_name) {
continue;
}
if let Some(index) =
variable_names_in_scope.get(&variable_name)
{
variables.remove(*index);
}
variable_names_in_scope
.insert(variable_name.clone(), variables.len());
variables.push(InlineValueLocation {
variable_name,
scope: VariableScope::Local,
lookup: VariableLookupKind::Variable,
row: identifier.end_position().row,
column: identifier.end_position().column,
});
}
}
}
}
"assignment_statement" => {
if let Some(left_side) = child.child_by_field_name("left") {
for identifier in left_side.named_children(&mut left_side.walk()) {
if identifier.kind() == "identifier" {
let variable_name =
source[identifier.byte_range()].to_string();
if variable_names.contains(&variable_name) {
continue;
}
if let Some(index) =
variable_names_in_scope.get(&variable_name)
{
variables.remove(*index);
}
variable_names_in_scope
.insert(variable_name.clone(), variables.len());
variables.push(InlineValueLocation {
variable_name,
scope: VariableScope::Local,
lookup: VariableLookupKind::Variable,
row: identifier.end_position().row,
column: identifier.end_position().column,
});
}
}
}
}
"function_declaration" | "method_declaration" => {
if let Some(params) = child.child_by_field_name("parameters") {
for param in params.named_children(&mut params.walk()) {
if param.kind() == "parameter_declaration" {
if let Some(name_node) = param.child_by_field_name("name") {
let variable_name =
source[name_node.byte_range()].to_string();
if variable_names.contains(&variable_name) {
continue;
}
if let Some(index) =
variable_names_in_scope.get(&variable_name)
{
variables.remove(*index);
}
variable_names_in_scope
.insert(variable_name.clone(), variables.len());
variables.push(InlineValueLocation {
variable_name,
scope: VariableScope::Local,
lookup: VariableLookupKind::Variable,
row: name_node.end_position().row,
column: name_node.end_position().column,
});
}
}
}
}
}
"for_statement" => {
if let Some(clause) = child.named_child(0) {
if clause.kind() == "for_clause" {
if let Some(init) = clause.named_child(0) {
if init.kind() == "short_var_declaration" {
if let Some(left_side) =
init.child_by_field_name("left")
{
if left_side.kind() == "expression_list" {
for identifier in left_side
.named_children(&mut left_side.walk())
{
if identifier.kind() == "identifier" {
let variable_name = source
[identifier.byte_range()]
.to_string();
if variable_names
.contains(&variable_name)
{
continue;
}
if let Some(index) =
variable_names_in_scope
.get(&variable_name)
{
variables.remove(*index);
}
variable_names_in_scope.insert(
variable_name.clone(),
variables.len(),
);
variables.push(InlineValueLocation {
variable_name,
scope: VariableScope::Local,
lookup:
VariableLookupKind::Variable,
row: identifier.end_position().row,
column: identifier
.end_position()
.column,
});
}
}
}
}
}
}
} else if clause.kind() == "range_clause" {
if let Some(left) = clause.child_by_field_name("left") {
if left.kind() == "expression_list" {
for identifier in left.named_children(&mut left.walk())
{
if identifier.kind() == "identifier" {
let variable_name =
source[identifier.byte_range()].to_string();
if variable_name == "_" {
continue;
}
if variable_names.contains(&variable_name) {
continue;
}
if let Some(index) =
variable_names_in_scope.get(&variable_name)
{
variables.remove(*index);
}
variable_names_in_scope.insert(
variable_name.clone(),
variables.len(),
);
variables.push(InlineValueLocation {
variable_name,
scope: VariableScope::Local,
lookup: VariableLookupKind::Variable,
row: identifier.end_position().row,
column: identifier.end_position().column,
});
}
}
}
}
}
}
}
_ => {}
}
} else if child.kind() == "var_declaration" {
for var_spec in child.named_children(&mut child.walk()) {
if var_spec.kind() == "var_spec" {
if let Some(name_node) = var_spec.child_by_field_name("name") {
let variable_name = source[name_node.byte_range()].to_string();
variables.push(InlineValueLocation {
variable_name,
scope: VariableScope::Global,
lookup: VariableLookupKind::Expression,
row: name_node.end_position().row,
column: name_node.end_position().column,
});
}
}
}
}
}
variable_names.extend(variable_names_in_scope.keys().cloned());
if matches!(node.kind(), "function_declaration" | "method_declaration") {
scope = VariableScope::Global;
}
if let Some(parent) = node.parent() {
node = parent;
} else {
break;
}
}
variables
}
}
#[cfg(test)]
mod tests {
use super::*;
use tree_sitter::Parser;
#[test]
fn test_go_inline_value_provider() {
let provider = GoInlineValueProvider;
let source = r#"
package main
func main() {
items := []int{1, 2, 3, 4, 5}
for i, v := range items {
println(i, v)
}
for j := 0; j < 10; j++ {
println(j)
}
}
"#;
let mut parser = Parser::new();
if parser
.set_language(&tree_sitter_go::LANGUAGE.into())
.is_err()
{
return;
}
let Some(tree) = parser.parse(source, None) else {
return;
};
let root_node = tree.root_node();
let mut main_body = None;
for child in root_node.named_children(&mut root_node.walk()) {
if child.kind() == "function_declaration" {
if let Some(name) = child.child_by_field_name("name") {
if &source[name.byte_range()] == "main" {
if let Some(body) = child.child_by_field_name("body") {
main_body = Some(body);
break;
}
}
}
}
}
let Some(main_body) = main_body else {
return;
};
let variables = provider.provide(main_body, source, 100);
assert!(variables.len() >= 2);
let variable_names: Vec<&str> =
variables.iter().map(|v| v.variable_name.as_str()).collect();
assert!(variable_names.contains(&"items"));
assert!(variable_names.contains(&"j"));
}
#[test]
fn test_go_inline_value_provider_counter_pattern() {
let provider = GoInlineValueProvider;
let source = r#"
package main
func main() {
N := 10
for i := range N {
println(i)
}
}
"#;
let mut parser = Parser::new();
if parser
.set_language(&tree_sitter_go::LANGUAGE.into())
.is_err()
{
return;
}
let Some(tree) = parser.parse(source, None) else {
return;
};
let root_node = tree.root_node();
let mut main_body = None;
for child in root_node.named_children(&mut root_node.walk()) {
if child.kind() == "function_declaration" {
if let Some(name) = child.child_by_field_name("name") {
if &source[name.byte_range()] == "main" {
if let Some(body) = child.child_by_field_name("body") {
main_body = Some(body);
break;
}
}
}
}
}
let Some(main_body) = main_body else {
return;
};
let variables = provider.provide(main_body, source, 100);
let variable_names: Vec<&str> =
variables.iter().map(|v| v.variable_name.as_str()).collect();
assert!(variable_names.contains(&"N"));
assert!(variable_names.contains(&"i"));
}
}

View File

@@ -658,9 +658,13 @@ impl StdioTransport {
.stderr(Stdio::piped())
.kill_on_drop(true);
let mut process = command
.spawn()
.with_context(|| "failed to spawn command.")?;
let mut process = command.spawn().with_context(|| {
format!(
"failed to spawn command `{} {}`.",
binary.command,
binary.arguments.join(" ")
)
})?;
let stdin = process.stdin.take().context("Failed to open stdin")?;
let stdout = process.stdout.take().context("Failed to open stdout")?;

View File

@@ -18,7 +18,7 @@ use dap::{
GithubRepo,
},
configure_tcp_connection,
inline_value::{PythonInlineValueProvider, RustInlineValueProvider},
inline_value::{GoInlineValueProvider, PythonInlineValueProvider, RustInlineValueProvider},
};
use gdb::GdbDebugAdapter;
use go::GoDebugAdapter;
@@ -37,7 +37,7 @@ pub fn init(cx: &mut App) {
registry.add_adapter(Arc::from(PhpDebugAdapter::default()));
registry.add_adapter(Arc::from(JsDebugAdapter::default()));
registry.add_adapter(Arc::from(RubyDebugAdapter));
registry.add_adapter(Arc::from(GoDebugAdapter));
registry.add_adapter(Arc::from(GoDebugAdapter::default()));
registry.add_adapter(Arc::from(GdbDebugAdapter));
#[cfg(any(test, feature = "test-support"))]
@@ -48,5 +48,6 @@ pub fn init(cx: &mut App) {
registry.add_inline_value_provider("Rust".to_string(), Arc::from(RustInlineValueProvider));
registry
.add_inline_value_provider("Python".to_string(), Arc::from(PythonInlineValueProvider));
registry.add_inline_value_provider("Go".to_string(), Arc::from(GoInlineValueProvider));
})
}

View File

@@ -26,10 +26,12 @@ impl DebugAdapter for GdbDebugAdapter {
match &zed_scenario.request {
dap::DebugRequest::Attach(attach) => {
obj.insert("request".into(), "attach".into());
obj.insert("pid".into(), attach.process_id.into());
}
dap::DebugRequest::Launch(launch) => {
obj.insert("request".into(), "launch".into());
obj.insert("program".into(), launch.program.clone().into());
if !launch.args.is_empty() {

View File

@@ -1,22 +1,87 @@
use anyhow::{Context as _, anyhow, bail};
use dap::{
StartDebuggingRequestArguments, StartDebuggingRequestArgumentsRequest,
adapters::DebugTaskDefinition,
adapters::{
DebugTaskDefinition, DownloadedFileType, download_adapter_from_github,
latest_github_release,
},
};
use gpui::{AsyncApp, SharedString};
use language::LanguageName;
use std::{collections::HashMap, ffi::OsStr, path::PathBuf};
use std::{collections::HashMap, env::consts, ffi::OsStr, path::PathBuf, sync::OnceLock};
use util;
use crate::*;
#[derive(Default, Debug)]
pub(crate) struct GoDebugAdapter;
pub(crate) struct GoDebugAdapter {
shim_path: OnceLock<PathBuf>,
}
impl GoDebugAdapter {
const ADAPTER_NAME: &'static str = "Delve";
const DEFAULT_TIMEOUT_MS: u64 = 60000;
async fn fetch_latest_adapter_version(
delegate: &Arc<dyn DapDelegate>,
) -> Result<AdapterVersion> {
let release = latest_github_release(
&"zed-industries/delve-shim-dap",
true,
false,
delegate.http_client(),
)
.await?;
let os = match consts::OS {
"macos" => "apple-darwin",
"linux" => "unknown-linux-gnu",
"windows" => "pc-windows-msvc",
other => bail!("Running on unsupported os: {other}"),
};
let suffix = if consts::OS == "windows" {
".zip"
} else {
".tar.gz"
};
let asset_name = format!("delve-shim-dap-{}-{os}{suffix}", consts::ARCH);
let asset = release
.assets
.iter()
.find(|asset| asset.name == asset_name)
.with_context(|| format!("no asset found matching `{asset_name:?}`"))?;
Ok(AdapterVersion {
tag_name: release.tag_name,
url: asset.browser_download_url.clone(),
})
}
async fn install_shim(&self, delegate: &Arc<dyn DapDelegate>) -> anyhow::Result<PathBuf> {
if let Some(path) = self.shim_path.get().cloned() {
return Ok(path);
}
let asset = Self::fetch_latest_adapter_version(delegate).await?;
let ty = if consts::OS == "windows" {
DownloadedFileType::Zip
} else {
DownloadedFileType::GzipTar
};
download_adapter_from_github(
"delve-shim-dap".into(),
asset.clone(),
ty,
delegate.as_ref(),
)
.await?;
let path = paths::debug_adapters_dir()
.join("delve-shim-dap")
.join(format!("delve-shim-dap{}", asset.tag_name))
.join("delve-shim-dap");
self.shim_path.set(path.clone()).ok();
Ok(path)
}
}
#[async_trait(?Send)]
@@ -307,15 +372,27 @@ impl DebugAdapter for GoDebugAdapter {
let mut args = match &zed_scenario.request {
dap::DebugRequest::Attach(attach_config) => {
json!({
"request": "attach",
"mode": "debug",
"processId": attach_config.process_id,
})
}
dap::DebugRequest::Launch(launch_config) => json!({
"program": launch_config.program,
"cwd": launch_config.cwd,
"args": launch_config.args,
"env": launch_config.env_json()
}),
dap::DebugRequest::Launch(launch_config) => {
let mode = if launch_config.program != "." {
"exec"
} else {
"debug"
};
json!({
"request": "launch",
"mode": mode,
"program": launch_config.program,
"cwd": launch_config.cwd,
"args": launch_config.args,
"env": launch_config.env_json()
})
}
};
let map = args.as_object_mut().unwrap();
@@ -372,16 +449,10 @@ impl DebugAdapter for GoDebugAdapter {
adapter_path.join("dlv").to_string_lossy().to_string()
};
let minidelve_path = self.install_shim(delegate).await?;
let tcp_connection = task_definition.tcp_connection.clone().unwrap_or_default();
let mut tcp_connection = task_definition.tcp_connection.clone().unwrap_or_default();
if tcp_connection.timeout.is_none()
|| tcp_connection.timeout.unwrap_or(0) < Self::DEFAULT_TIMEOUT_MS
{
tcp_connection.timeout = Some(Self::DEFAULT_TIMEOUT_MS);
}
let (host, port, timeout) = crate::configure_tcp_connection(tcp_connection).await?;
let (host, port, _) = crate::configure_tcp_connection(tcp_connection).await?;
let cwd = task_definition
.config
@@ -392,6 +463,7 @@ impl DebugAdapter for GoDebugAdapter {
let arguments = if cfg!(windows) {
vec![
delve_path,
"dap".into(),
"--listen".into(),
format!("{}:{}", host, port),
@@ -399,6 +471,7 @@ impl DebugAdapter for GoDebugAdapter {
]
} else {
vec![
delve_path,
"dap".into(),
"--listen".into(),
format!("{}:{}", host, port),
@@ -406,15 +479,11 @@ impl DebugAdapter for GoDebugAdapter {
};
Ok(DebugAdapterBinary {
command: delve_path,
command: minidelve_path.to_string_lossy().into_owned(),
arguments,
cwd: Some(cwd),
envs: HashMap::default(),
connection: Some(adapters::TcpArguments {
host,
port,
timeout,
}),
connection: None,
request_args: StartDebuggingRequestArguments {
configuration: task_definition.config.clone(),
request: self.validate_config(&task_definition.config)?,

View File

@@ -47,13 +47,6 @@ impl PhpDebugAdapter {
})
}
fn validate_config(
&self,
_: &serde_json::Value,
) -> Result<StartDebuggingRequestArgumentsRequest> {
Ok(StartDebuggingRequestArgumentsRequest::Launch)
}
async fn get_installed_binary(
&self,
delegate: &Arc<dyn DapDelegate>,
@@ -101,7 +94,7 @@ impl PhpDebugAdapter {
envs: HashMap::default(),
request_args: StartDebuggingRequestArguments {
configuration: task_definition.config.clone(),
request: self.validate_config(&task_definition.config)?,
request: <Self as DebugAdapter>::validate_config(self, &task_definition.config)?,
},
})
}
@@ -156,22 +149,8 @@ impl DebugAdapter for PhpDebugAdapter {
"default": false
},
"pathMappings": {
"type": "array",
"description": "A list of server paths mapping to the local source paths on your machine for remote host debugging",
"items": {
"type": "object",
"properties": {
"serverPath": {
"type": "string",
"description": "Path on the server"
},
"localPath": {
"type": "string",
"description": "Corresponding path on the local machine"
}
},
"required": ["serverPath", "localPath"]
}
"type": "object",
"description": "A mapping of server paths to local paths.",
},
"log": {
"type": "boolean",
@@ -303,6 +282,13 @@ impl DebugAdapter for PhpDebugAdapter {
Some(SharedString::new_static("PHP").into())
}
fn validate_config(
&self,
_: &serde_json::Value,
) -> Result<StartDebuggingRequestArgumentsRequest> {
Ok(StartDebuggingRequestArgumentsRequest::Launch)
}
fn config_from_zed_format(&self, zed_scenario: ZedDebugConfig) -> Result<DebugScenario> {
let obj = match &zed_scenario.request {
dap::DebugRequest::Attach(_) => {

View File

@@ -660,7 +660,7 @@ impl DebugAdapter for PythonDebugAdapter {
}
}
self.get_installed_binary(delegate, &config, None, None, false)
self.get_installed_binary(delegate, &config, None, toolchain, false)
.await
}
}

View File

@@ -5,7 +5,7 @@ use crate::{
ClearAllBreakpoints, Continue, Detach, FocusBreakpointList, FocusConsole, FocusFrames,
FocusLoadedSources, FocusModules, FocusTerminal, FocusVariables, Pause, Restart,
ShowStackTrace, StepBack, StepInto, StepOut, StepOver, Stop, ToggleIgnoreBreakpoints,
ToggleSessionPicker, ToggleThreadPicker, persistence,
ToggleSessionPicker, ToggleThreadPicker, persistence, spawn_task_or_modal,
};
use anyhow::{Context as _, Result, anyhow};
use command_palette_hooks::CommandPaletteFilter;
@@ -65,6 +65,7 @@ pub struct DebugPanel {
workspace: WeakEntity<Workspace>,
focus_handle: FocusHandle,
context_menu: Option<(Entity<ContextMenu>, Point<Pixels>, Subscription)>,
debug_scenario_scheduled_last: bool,
pub(crate) thread_picker_menu_handle: PopoverMenuHandle<ContextMenu>,
pub(crate) session_picker_menu_handle: PopoverMenuHandle<ContextMenu>,
fs: Arc<dyn Fs>,
@@ -103,6 +104,7 @@ impl DebugPanel {
thread_picker_menu_handle,
session_picker_menu_handle,
_subscriptions: [focus_subscription],
debug_scenario_scheduled_last: true,
}
})
}
@@ -264,6 +266,7 @@ impl DebugPanel {
cx,
)
});
self.debug_scenario_scheduled_last = true;
if let Some(inventory) = self
.project
.read(cx)
@@ -295,7 +298,6 @@ impl DebugPanel {
})
})?
.await?;
dap_store
.update(cx, |dap_store, cx| {
dap_store.boot_session(session.clone(), definition, cx)
@@ -433,7 +435,10 @@ impl DebugPanel {
};
let dap_store_handle = self.project.read(cx).dap_store().clone();
let label = parent_session.read(cx).label().clone();
let mut label = parent_session.read(cx).label().clone();
if !label.ends_with("(child)") {
label = format!("{label} (child)").into();
}
let adapter = parent_session.read(cx).adapter().clone();
let mut binary = parent_session.read(cx).binary().clone();
binary.request_args = request.clone();
@@ -1379,4 +1384,30 @@ impl workspace::DebuggerProvider for DebuggerProvider {
})
})
}
fn spawn_task_or_modal(
&self,
workspace: &mut Workspace,
action: &tasks_ui::Spawn,
window: &mut Window,
cx: &mut Context<Workspace>,
) {
spawn_task_or_modal(workspace, action, window, cx);
}
fn debug_scenario_scheduled(&self, cx: &mut App) {
self.0.update(cx, |this, _| {
this.debug_scenario_scheduled_last = true;
});
}
fn task_scheduled(&self, cx: &mut App) {
self.0.update(cx, |this, _| {
this.debug_scenario_scheduled_last = false;
})
}
fn debug_scenario_scheduled_last(&self, cx: &App) -> bool {
self.0.read(cx).debug_scenario_scheduled_last
}
}

View File

@@ -3,11 +3,12 @@ use debugger_panel::{DebugPanel, ToggleFocus};
use editor::Editor;
use feature_flags::{DebuggerFeatureFlag, FeatureFlagViewExt};
use gpui::{App, EntityInputHandler, actions};
use new_session_modal::NewSessionModal;
use new_session_modal::{NewSessionModal, NewSessionMode};
use project::debugger::{self, breakpoint_store::SourceBreakpoint};
use session::DebugSession;
use settings::Settings;
use stack_trace_view::StackTraceView;
use tasks_ui::{Spawn, TaskOverrides};
use util::maybe;
use workspace::{ItemHandle, ShutdownDebugAdapters, Workspace};
@@ -62,6 +63,7 @@ pub fn init(cx: &mut App) {
cx.when_flag_enabled::<DebuggerFeatureFlag>(window, |workspace, _, _| {
workspace
.register_action(spawn_task_or_modal)
.register_action(|workspace, _: &ToggleFocus, window, cx| {
workspace.toggle_panel_focus::<DebugPanel>(window, cx);
})
@@ -208,7 +210,7 @@ pub fn init(cx: &mut App) {
},
)
.register_action(|workspace: &mut Workspace, _: &Start, window, cx| {
NewSessionModal::show(workspace, window, cx);
NewSessionModal::show(workspace, window, NewSessionMode::Launch, None, cx);
})
.register_action(
|workspace: &mut Workspace, _: &RerunLastSession, window, cx| {
@@ -309,3 +311,48 @@ pub fn init(cx: &mut App) {
})
.detach();
}
fn spawn_task_or_modal(
workspace: &mut Workspace,
action: &Spawn,
window: &mut ui::Window,
cx: &mut ui::Context<Workspace>,
) {
match action {
Spawn::ByName {
task_name,
reveal_target,
} => {
let overrides = reveal_target.map(|reveal_target| TaskOverrides {
reveal_target: Some(reveal_target),
});
let name = task_name.clone();
tasks_ui::spawn_tasks_filtered(
move |(_, task)| task.label.eq(&name),
overrides,
window,
cx,
)
.detach_and_log_err(cx)
}
Spawn::ByTag {
task_tag,
reveal_target,
} => {
let overrides = reveal_target.map(|reveal_target| TaskOverrides {
reveal_target: Some(reveal_target),
});
let tag = task_tag.clone();
tasks_ui::spawn_tasks_filtered(
move |(_, task)| task.tags.contains(&tag),
overrides,
window,
cx,
)
.detach_and_log_err(cx)
}
Spawn::ViaModal { reveal_target } => {
NewSessionModal::show(workspace, window, NewSessionMode::Task, *reveal_target, cx);
}
}
}

View File

@@ -1,4 +1,6 @@
use gpui::Entity;
use std::time::Duration;
use gpui::{Animation, AnimationExt as _, Entity, Transformation, percentage};
use project::debugger::session::{ThreadId, ThreadStatus};
use ui::{ContextMenu, DropdownMenu, DropdownStyle, Indicator, prelude::*};
@@ -23,31 +25,40 @@ impl DebugPanel {
let sessions = self.sessions().clone();
let weak = cx.weak_entity();
let running_state = running_state.read(cx);
let label = if let Some(active_session) = active_session {
let label = if let Some(active_session) = active_session.clone() {
active_session.read(cx).session(cx).read(cx).label()
} else {
SharedString::new_static("Unknown Session")
};
let is_terminated = running_state.session().read(cx).is_terminated();
let session_state_indicator = {
if is_terminated {
Some(Indicator::dot().color(Color::Error))
} else {
match running_state.thread_status(cx).unwrap_or_default() {
project::debugger::session::ThreadStatus::Stopped => {
Some(Indicator::dot().color(Color::Conflict))
}
_ => Some(Indicator::dot().color(Color::Success)),
let is_started = active_session
.is_some_and(|session| session.read(cx).session(cx).read(cx).is_started());
let session_state_indicator = if is_terminated {
Indicator::dot().color(Color::Error).into_any_element()
} else if !is_started {
Icon::new(IconName::ArrowCircle)
.size(IconSize::Small)
.color(Color::Muted)
.with_animation(
"arrow-circle",
Animation::new(Duration::from_secs(2)).repeat(),
|icon, delta| icon.transform(Transformation::rotate(percentage(delta))),
)
.into_any_element()
} else {
match running_state.thread_status(cx).unwrap_or_default() {
ThreadStatus::Stopped => {
Indicator::dot().color(Color::Conflict).into_any_element()
}
_ => Indicator::dot().color(Color::Success).into_any_element(),
}
};
let trigger = h_flex()
.gap_2()
.when_some(session_state_indicator, |this, indicator| {
this.child(indicator)
})
.child(session_state_indicator)
.justify_between()
.child(
DebugPanel::dropdown_label(label)

View File

@@ -1,5 +1,5 @@
use collections::FxHashMap;
use language::LanguageRegistry;
use language::{LanguageRegistry, Point, Selection};
use std::{
borrow::Cow,
ops::Not,
@@ -8,20 +8,21 @@ use std::{
time::Duration,
usize,
};
use tasks_ui::{TaskOverrides, TasksModal};
use dap::{
DapRegistry, DebugRequest, TelemetrySpawnLocation, adapters::DebugAdapterName, send_telemetry,
};
use editor::{Editor, EditorElement, EditorStyle};
use editor::{Anchor, Editor, EditorElement, EditorStyle, scroll::Autoscroll};
use fuzzy::{StringMatch, StringMatchCandidate};
use gpui::{
Animation, AnimationExt as _, App, AppContext, DismissEvent, Entity, EventEmitter, FocusHandle,
Focusable, Render, Subscription, TextStyle, Transformation, WeakEntity, percentage,
Focusable, KeyContext, Render, Subscription, TextStyle, Transformation, WeakEntity, percentage,
};
use picker::{Picker, PickerDelegate, highlighted_match_with_paths::HighlightedMatch};
use project::{ProjectPath, TaskContexts, TaskSourceKind, task_store::TaskStore};
use settings::Settings;
use task::{DebugScenario, LaunchRequest, ZedDebugConfig};
use task::{DebugScenario, LaunchRequest, RevealTarget, ZedDebugConfig};
use theme::ThemeSettings;
use ui::{
ActiveTheme, Button, ButtonCommon, ButtonSize, CheckboxWithLabel, Clickable, Color, Context,
@@ -37,7 +38,7 @@ use crate::{attach_modal::AttachModal, debugger_panel::DebugPanel};
enum SaveScenarioState {
Saving,
Saved(ProjectPath),
Saved((ProjectPath, SharedString)),
Failed(SharedString),
}
@@ -47,10 +48,11 @@ pub(super) struct NewSessionModal {
mode: NewSessionMode,
launch_picker: Entity<Picker<DebugScenarioDelegate>>,
attach_mode: Entity<AttachMode>,
custom_mode: Entity<CustomMode>,
configure_mode: Entity<ConfigureMode>,
task_mode: TaskMode,
debugger: Option<DebugAdapterName>,
save_scenario_state: Option<SaveScenarioState>,
_subscriptions: [Subscription; 2],
_subscriptions: [Subscription; 3],
}
fn suggested_label(request: &DebugRequest, debugger: &str) -> SharedString {
@@ -75,6 +77,8 @@ impl NewSessionModal {
pub(super) fn show(
workspace: &mut Workspace,
window: &mut Window,
mode: NewSessionMode,
reveal_target: Option<RevealTarget>,
cx: &mut Context<Workspace>,
) {
let Some(debug_panel) = workspace.panel::<DebugPanel>(cx) else {
@@ -84,20 +88,50 @@ impl NewSessionModal {
let languages = workspace.app_state().languages.clone();
cx.spawn_in(window, async move |workspace, cx| {
let task_contexts = workspace
.update_in(cx, |workspace, window, cx| {
tasks_ui::task_contexts(workspace, window, cx)
})?
.await;
let task_contexts = Arc::new(task_contexts);
workspace.update_in(cx, |workspace, window, cx| {
let workspace_handle = workspace.weak_handle();
workspace.toggle_modal(window, cx, |window, cx| {
let attach_mode = AttachMode::new(None, workspace_handle.clone(), window, cx);
let launch_picker = cx.new(|cx| {
Picker::uniform_list(
DebugScenarioDelegate::new(debug_panel.downgrade(), task_store),
window,
cx,
)
.modal(false)
let mut delegate =
DebugScenarioDelegate::new(debug_panel.downgrade(), task_store.clone());
delegate.task_contexts_loaded(task_contexts.clone(), languages, window, cx);
Picker::uniform_list(delegate, window, cx).modal(false)
});
let configure_mode = ConfigureMode::new(None, window, cx);
if let Some(active_cwd) = task_contexts
.active_context()
.and_then(|context| context.cwd.clone())
{
configure_mode.update(cx, |configure_mode, cx| {
configure_mode.load(active_cwd, window, cx);
});
}
let task_overrides = Some(TaskOverrides { reveal_target });
let task_mode = TaskMode {
task_modal: cx.new(|cx| {
TasksModal::new(
task_store.clone(),
task_contexts,
task_overrides,
false,
workspace_handle.clone(),
window,
cx,
)
}),
};
let _subscriptions = [
cx.subscribe(&launch_picker, |_, _, _, cx| {
cx.emit(DismissEvent);
@@ -108,52 +142,18 @@ impl NewSessionModal {
cx.emit(DismissEvent);
},
),
cx.subscribe(&task_mode.task_modal, |_, _, _: &DismissEvent, cx| {
cx.emit(DismissEvent)
}),
];
let custom_mode = CustomMode::new(None, window, cx);
cx.spawn_in(window, {
let workspace_handle = workspace_handle.clone();
async move |this, cx| {
let task_contexts = workspace_handle
.update_in(cx, |workspace, window, cx| {
tasks_ui::task_contexts(workspace, window, cx)
})?
.await;
this.update_in(cx, |this, window, cx| {
if let Some(active_cwd) = task_contexts
.active_context()
.and_then(|context| context.cwd.clone())
{
this.custom_mode.update(cx, |custom, cx| {
custom.load(active_cwd, window, cx);
});
this.debugger = None;
}
this.launch_picker.update(cx, |picker, cx| {
picker.delegate.task_contexts_loaded(
task_contexts,
languages,
window,
cx,
);
picker.refresh(window, cx);
cx.notify();
});
})
}
})
.detach();
Self {
launch_picker,
attach_mode,
custom_mode,
configure_mode,
task_mode,
debugger: None,
mode: NewSessionMode::Launch,
mode,
debug_panel: debug_panel.downgrade(),
workspace: workspace_handle,
save_scenario_state: None,
@@ -170,10 +170,17 @@ impl NewSessionModal {
fn render_mode(&mut self, window: &mut Window, cx: &mut Context<Self>) -> impl ui::IntoElement {
let dap_menu = self.adapter_drop_down_menu(window, cx);
match self.mode {
NewSessionMode::Task => self
.task_mode
.task_modal
.read(cx)
.picker
.clone()
.into_any_element(),
NewSessionMode::Attach => self.attach_mode.update(cx, |this, cx| {
this.clone().render(window, cx).into_any_element()
}),
NewSessionMode::Custom => self.custom_mode.update(cx, |this, cx| {
NewSessionMode::Configure => self.configure_mode.update(cx, |this, cx| {
this.clone().render(dap_menu, window, cx).into_any_element()
}),
NewSessionMode::Launch => v_flex()
@@ -185,16 +192,17 @@ impl NewSessionModal {
fn mode_focus_handle(&self, cx: &App) -> FocusHandle {
match self.mode {
NewSessionMode::Task => self.task_mode.task_modal.focus_handle(cx),
NewSessionMode::Attach => self.attach_mode.read(cx).attach_picker.focus_handle(cx),
NewSessionMode::Custom => self.custom_mode.read(cx).program.focus_handle(cx),
NewSessionMode::Configure => self.configure_mode.read(cx).program.focus_handle(cx),
NewSessionMode::Launch => self.launch_picker.focus_handle(cx),
}
}
fn debug_scenario(&self, debugger: &str, cx: &App) -> Option<DebugScenario> {
let request = match self.mode {
NewSessionMode::Custom => Some(DebugRequest::Launch(
self.custom_mode.read(cx).debug_request(cx),
NewSessionMode::Configure => Some(DebugRequest::Launch(
self.configure_mode.read(cx).debug_request(cx),
)),
NewSessionMode::Attach => Some(DebugRequest::Attach(
self.attach_mode.read(cx).debug_request(),
@@ -203,8 +211,8 @@ impl NewSessionModal {
}?;
let label = suggested_label(&request, debugger);
let stop_on_entry = if let NewSessionMode::Custom = &self.mode {
Some(self.custom_mode.read(cx).stop_on_entry.selected())
let stop_on_entry = if let NewSessionMode::Configure = &self.mode {
Some(self.configure_mode.read(cx).stop_on_entry.selected())
} else {
None
};
@@ -284,6 +292,177 @@ impl NewSessionModal {
self.launch_picker.read(cx).delegate.task_contexts.clone()
}
fn save_debug_scenario(&mut self, window: &mut Window, cx: &mut Context<Self>) {
let Some((save_scenario, scenario_label)) = self
.debugger
.as_ref()
.and_then(|debugger| self.debug_scenario(&debugger, cx))
.zip(self.task_contexts(cx).and_then(|tcx| tcx.worktree()))
.and_then(|(scenario, worktree_id)| {
self.debug_panel
.update(cx, |panel, cx| {
panel.save_scenario(&scenario, worktree_id, window, cx)
})
.ok()
.zip(Some(scenario.label.clone()))
})
else {
return;
};
self.save_scenario_state = Some(SaveScenarioState::Saving);
cx.spawn(async move |this, cx| {
let res = save_scenario.await;
this.update(cx, |this, _| match res {
Ok(saved_file) => {
this.save_scenario_state =
Some(SaveScenarioState::Saved((saved_file, scenario_label)))
}
Err(error) => {
this.save_scenario_state =
Some(SaveScenarioState::Failed(error.to_string().into()))
}
})
.ok();
cx.background_executor().timer(Duration::from_secs(3)).await;
this.update(cx, |this, _| this.save_scenario_state.take())
.ok();
})
.detach();
}
fn render_save_state(&self, cx: &mut Context<Self>) -> impl IntoElement {
let this_entity = cx.weak_entity().clone();
div().when_some(self.save_scenario_state.as_ref(), {
let this_entity = this_entity.clone();
move |this, save_state| match save_state {
SaveScenarioState::Saved((saved_path, scenario_label)) => this.child(
IconButton::new("new-session-modal-go-to-file", IconName::ArrowUpRight)
.icon_size(IconSize::Small)
.icon_color(Color::Muted)
.on_click({
let this_entity = this_entity.clone();
let saved_path = saved_path.clone();
let scenario_label = scenario_label.clone();
move |_, window, cx| {
window
.spawn(cx, {
let this_entity = this_entity.clone();
let saved_path = saved_path.clone();
let scenario_label = scenario_label.clone();
async move |cx| {
let editor = this_entity
.update_in(cx, |this, window, cx| {
this.workspace.update(cx, |workspace, cx| {
workspace.open_path(
saved_path.clone(),
None,
true,
window,
cx,
)
})
})??
.await?;
cx.update(|window, cx| {
if let Some(editor) = editor.act_as::<Editor>(cx) {
editor.update(cx, |editor, cx| {
let row = editor
.text(cx)
.lines()
.enumerate()
.find_map(|(row, text)| {
if text.contains(
scenario_label.as_ref(),
) {
Some(row)
} else {
None
}
})?;
let buffer = editor.buffer().read(cx);
let excerpt_id =
*buffer.excerpt_ids().first()?;
let snapshot = buffer
.as_singleton()?
.read(cx)
.snapshot();
let anchor = snapshot.anchor_before(
Point::new(row as u32, 0),
);
let anchor = Anchor {
buffer_id: anchor.buffer_id,
excerpt_id,
text_anchor: anchor,
diff_base_anchor: None,
};
editor.change_selections(
Some(Autoscroll::center()),
window,
cx,
|selections| {
let id =
selections.new_selection_id();
selections.select_anchors(
vec![Selection {
id,
start: anchor,
end: anchor,
reversed: false,
goal: language::SelectionGoal::None
}],
);
},
);
Some(())
});
}
})?;
this_entity
.update(cx, |_, cx| cx.emit(DismissEvent))
.ok();
anyhow::Ok(())
}
})
.detach();
}
}),
),
SaveScenarioState::Saving => this.child(
Icon::new(IconName::Spinner)
.size(IconSize::Small)
.color(Color::Muted)
.with_animation(
"Spinner",
Animation::new(Duration::from_secs(3)).repeat(),
|icon, delta| icon.transform(Transformation::rotate(percentage(delta))),
),
),
SaveScenarioState::Failed(error_msg) => this.child(
IconButton::new("Failed Scenario Saved", IconName::X)
.icon_size(IconSize::Small)
.icon_color(Color::Error)
.tooltip(ui::Tooltip::text(error_msg.clone())),
),
}
})
}
fn adapter_drop_down_menu(
&mut self,
window: &mut Window,
@@ -355,8 +534,9 @@ impl NewSessionModal {
static SELECT_DEBUGGER_LABEL: SharedString = SharedString::new_static("Select Debugger");
#[derive(Clone)]
enum NewSessionMode {
Custom,
pub(crate) enum NewSessionMode {
Task,
Configure,
Attach,
Launch,
}
@@ -364,9 +544,10 @@ enum NewSessionMode {
impl std::fmt::Display for NewSessionMode {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let mode = match self {
NewSessionMode::Launch => "Launch".to_owned(),
NewSessionMode::Attach => "Attach".to_owned(),
NewSessionMode::Custom => "Custom".to_owned(),
NewSessionMode::Task => "Run",
NewSessionMode::Launch => "Debug",
NewSessionMode::Attach => "Attach",
NewSessionMode::Configure => "Configure Debugger",
};
write!(f, "{}", mode)
@@ -423,41 +604,42 @@ impl Render for NewSessionModal {
window: &mut ui::Window,
cx: &mut ui::Context<Self>,
) -> impl ui::IntoElement {
let this = cx.weak_entity().clone();
v_flex()
.size_full()
.w(rems(34.))
.key_context("Pane")
.key_context({
let mut key_context = KeyContext::new_with_defaults();
key_context.add("Pane");
key_context.add("RunModal");
key_context
})
.elevation_3(cx)
.bg(cx.theme().colors().elevated_surface_background)
.on_action(cx.listener(|_, _: &menu::Cancel, _, cx| {
cx.emit(DismissEvent);
}))
.on_action(cx.listener(|this, _: &pane::ActivateNextItem, window, cx| {
this.mode = match this.mode {
NewSessionMode::Task => NewSessionMode::Launch,
NewSessionMode::Launch => NewSessionMode::Attach,
NewSessionMode::Attach => NewSessionMode::Configure,
NewSessionMode::Configure => NewSessionMode::Task,
};
this.mode_focus_handle(cx).focus(window);
}))
.on_action(
cx.listener(|this, _: &pane::ActivatePreviousItem, window, cx| {
this.mode = match this.mode {
NewSessionMode::Task => NewSessionMode::Configure,
NewSessionMode::Launch => NewSessionMode::Task,
NewSessionMode::Attach => NewSessionMode::Launch,
NewSessionMode::Launch => NewSessionMode::Attach,
_ => {
return;
}
NewSessionMode::Configure => NewSessionMode::Attach,
};
this.mode_focus_handle(cx).focus(window);
}),
)
.on_action(cx.listener(|this, _: &pane::ActivateNextItem, window, cx| {
this.mode = match this.mode {
NewSessionMode::Attach => NewSessionMode::Launch,
NewSessionMode::Launch => NewSessionMode::Attach,
_ => {
return;
}
};
this.mode_focus_handle(cx).focus(window);
}))
.child(
h_flex()
.w_full()
@@ -468,37 +650,73 @@ impl Render for NewSessionModal {
.justify_start()
.w_full()
.child(
ToggleButton::new("debugger-session-ui-picker-button", "Launch")
.size(ButtonSize::Default)
.style(ui::ButtonStyle::Subtle)
.toggle_state(matches!(self.mode, NewSessionMode::Launch))
.on_click(cx.listener(|this, _, window, cx| {
this.mode = NewSessionMode::Launch;
this.mode_focus_handle(cx).focus(window);
cx.notify();
}))
.first(),
ToggleButton::new(
"debugger-session-ui-tasks-button",
NewSessionMode::Task.to_string(),
)
.size(ButtonSize::Default)
.toggle_state(matches!(self.mode, NewSessionMode::Task))
.style(ui::ButtonStyle::Subtle)
.on_click(cx.listener(|this, _, window, cx| {
this.mode = NewSessionMode::Task;
this.mode_focus_handle(cx).focus(window);
cx.notify();
}))
.first(),
)
.child(
ToggleButton::new("debugger-session-ui-attach-button", "Attach")
.size(ButtonSize::Default)
.toggle_state(matches!(self.mode, NewSessionMode::Attach))
.style(ui::ButtonStyle::Subtle)
.on_click(cx.listener(|this, _, window, cx| {
this.mode = NewSessionMode::Attach;
ToggleButton::new(
"debugger-session-ui-launch-button",
NewSessionMode::Launch.to_string(),
)
.size(ButtonSize::Default)
.style(ui::ButtonStyle::Subtle)
.toggle_state(matches!(self.mode, NewSessionMode::Launch))
.on_click(cx.listener(|this, _, window, cx| {
this.mode = NewSessionMode::Launch;
this.mode_focus_handle(cx).focus(window);
cx.notify();
}))
.middle(),
)
.child(
ToggleButton::new(
"debugger-session-ui-attach-button",
NewSessionMode::Attach.to_string(),
)
.size(ButtonSize::Default)
.toggle_state(matches!(self.mode, NewSessionMode::Attach))
.style(ui::ButtonStyle::Subtle)
.on_click(cx.listener(|this, _, window, cx| {
this.mode = NewSessionMode::Attach;
if let Some(debugger) = this.debugger.as_ref() {
Self::update_attach_picker(
&this.attach_mode,
&debugger,
window,
cx,
);
}
this.mode_focus_handle(cx).focus(window);
cx.notify();
}))
.last(),
if let Some(debugger) = this.debugger.as_ref() {
Self::update_attach_picker(
&this.attach_mode,
&debugger,
window,
cx,
);
}
this.mode_focus_handle(cx).focus(window);
cx.notify();
}))
.middle(),
)
.child(
ToggleButton::new(
"debugger-session-ui-custom-button",
NewSessionMode::Configure.to_string(),
)
.size(ButtonSize::Default)
.toggle_state(matches!(self.mode, NewSessionMode::Configure))
.style(ui::ButtonStyle::Subtle)
.on_click(cx.listener(|this, _, window, cx| {
this.mode = NewSessionMode::Configure;
this.mode_focus_handle(cx).focus(window);
cx.notify();
}))
.last(),
),
)
.justify_between()
@@ -506,210 +724,83 @@ impl Render for NewSessionModal {
.border_b_1(),
)
.child(v_flex().child(self.render_mode(window, cx)))
.child(
h_flex()
.map(|el| {
let container = h_flex()
.justify_between()
.gap_2()
.p_2()
.border_color(cx.theme().colors().border_variant)
.border_t_1()
.w_full()
.child(match self.mode {
NewSessionMode::Attach => {
div().child(self.adapter_drop_down_menu(window, cx))
}
NewSessionMode::Launch => div().child(
Button::new("new-session-modal-custom", "Custom").on_click({
let this = cx.weak_entity();
move |_, window, cx| {
this.update(cx, |this, cx| {
this.mode = NewSessionMode::Custom;
this.mode_focus_handle(cx).focus(window);
})
.ok();
}
}),
),
NewSessionMode::Custom => h_flex()
.w_full();
match self.mode {
NewSessionMode::Configure => el.child(
container
.child(
Button::new("new-session-modal-back", "Save to .zed/debug.json...")
h_flex()
.child(
Button::new(
"new-session-modal-back",
"Save to .zed/debug.json...",
)
.on_click(cx.listener(|this, _, window, cx| {
this.save_debug_scenario(window, cx);
}))
.disabled(
self.debugger.is_none()
|| self
.configure_mode
.read(cx)
.program
.read(cx)
.is_empty(cx)
|| self.save_scenario_state.is_some(),
),
)
.child(self.render_save_state(cx)),
)
.child(
Button::new("debugger-spawn", "Start")
.on_click(cx.listener(|this, _, window, cx| {
let Some(save_scenario) = this
.debugger
.as_ref()
.and_then(|debugger| this.debug_scenario(&debugger, cx))
.zip(
this.task_contexts(cx)
.and_then(|tcx| tcx.worktree()),
)
.and_then(|(scenario, worktree_id)| {
this.debug_panel
.update(cx, |panel, cx| {
panel.save_scenario(
&scenario,
worktree_id,
window,
cx,
)
})
.ok()
})
else {
return;
};
this.save_scenario_state = Some(SaveScenarioState::Saving);
cx.spawn(async move |this, cx| {
let res = save_scenario.await;
this.update(cx, |this, _| match res {
Ok(saved_file) => {
this.save_scenario_state =
Some(SaveScenarioState::Saved(saved_file))
}
Err(error) => {
this.save_scenario_state =
Some(SaveScenarioState::Failed(
error.to_string().into(),
))
}
})
.ok();
cx.background_executor()
.timer(Duration::from_secs(2))
.await;
this.update(cx, |this, _| {
this.save_scenario_state.take()
})
.ok();
})
.detach();
this.start_new_session(window, cx)
}))
.disabled(
self.debugger.is_none()
|| self
.custom_mode
.configure_mode
.read(cx)
.program
.read(cx)
.is_empty(cx)
|| self.save_scenario_state.is_some(),
.is_empty(cx),
),
)
.when_some(self.save_scenario_state.as_ref(), {
let this_entity = this.clone();
move |this, save_state| match save_state {
SaveScenarioState::Saved(saved_path) => this.child(
IconButton::new(
"new-session-modal-go-to-file",
IconName::ArrowUpRight,
)
.icon_size(IconSize::Small)
.icon_color(Color::Muted)
.on_click({
let this_entity = this_entity.clone();
let saved_path = saved_path.clone();
move |_, window, cx| {
window
.spawn(cx, {
let this_entity = this_entity.clone();
let saved_path = saved_path.clone();
async move |cx| {
this_entity
.update_in(
cx,
|this, window, cx| {
this.workspace.update(
cx,
|workspace, cx| {
workspace.open_path(
saved_path
.clone(),
None,
true,
window,
cx,
)
},
)
},
)??
.await?;
this_entity
.update(cx, |_, cx| {
cx.emit(DismissEvent)
})
.ok();
anyhow::Ok(())
}
})
.detach();
}
}),
),
SaveScenarioState::Saving => this.child(
Icon::new(IconName::Spinner)
.size(IconSize::Small)
.color(Color::Muted)
.with_animation(
"Spinner",
Animation::new(Duration::from_secs(3)).repeat(),
|icon, delta| {
icon.transform(Transformation::rotate(
percentage(delta),
))
},
),
),
SaveScenarioState::Failed(error_msg) => this.child(
IconButton::new("Failed Scenario Saved", IconName::X)
.icon_size(IconSize::Small)
.icon_color(Color::Error)
.tooltip(ui::Tooltip::text(error_msg.clone())),
),
}
}),
})
.child(
Button::new("debugger-spawn", "Start")
.on_click(cx.listener(|this, _, window, cx| match &this.mode {
NewSessionMode::Launch => {
this.launch_picker.update(cx, |picker, cx| {
picker.delegate.confirm(true, window, cx)
})
}
_ => this.start_new_session(window, cx),
}))
.disabled(match self.mode {
NewSessionMode::Launch => {
!self.launch_picker.read(cx).delegate.matches.is_empty()
}
NewSessionMode::Attach => {
self.debugger.is_none()
|| self
.attach_mode
.read(cx)
.attach_picker
.read(cx)
.picker
.read(cx)
.delegate
.match_count()
== 0
}
NewSessionMode::Custom => {
self.debugger.is_none()
|| self.custom_mode.read(cx).program.read(cx).is_empty(cx)
}
}),
),
),
)
NewSessionMode::Attach => el.child(
container
.child(div().child(self.adapter_drop_down_menu(window, cx)))
.child(
Button::new("debugger-spawn", "Start")
.on_click(cx.listener(|this, _, window, cx| {
this.start_new_session(window, cx)
}))
.disabled(
self.debugger.is_none()
|| self
.attach_mode
.read(cx)
.attach_picker
.read(cx)
.picker
.read(cx)
.delegate
.match_count()
== 0,
),
),
),
NewSessionMode::Launch => el,
NewSessionMode::Task => el,
}
})
}
}
@@ -732,13 +823,13 @@ impl RenderOnce for AttachMode {
}
#[derive(Clone)]
pub(super) struct CustomMode {
pub(super) struct ConfigureMode {
program: Entity<Editor>,
cwd: Entity<Editor>,
stop_on_entry: ToggleState,
}
impl CustomMode {
impl ConfigureMode {
pub(super) fn new(
past_launch_config: Option<LaunchRequest>,
window: &mut Window,
@@ -898,6 +989,11 @@ impl AttachMode {
}
}
#[derive(Clone)]
pub(super) struct TaskMode {
pub(super) task_modal: Entity<TasksModal>,
}
pub(super) struct DebugScenarioDelegate {
task_store: Entity<TaskStore>,
candidates: Vec<(Option<TaskSourceKind>, DebugScenario)>,
@@ -953,12 +1049,12 @@ impl DebugScenarioDelegate {
pub fn task_contexts_loaded(
&mut self,
task_contexts: TaskContexts,
task_contexts: Arc<TaskContexts>,
languages: Arc<LanguageRegistry>,
_window: &mut Window,
cx: &mut Context<Picker<Self>>,
) {
self.task_contexts = Some(Arc::new(task_contexts));
self.task_contexts = Some(task_contexts);
let (recent, scenarios) = self
.task_store
@@ -1163,26 +1259,37 @@ pub(crate) fn resolve_path(path: &mut String) {
}
#[cfg(test)]
mod tests {
use paths::home_dir;
impl NewSessionModal {
pub(crate) fn set_configure(
&mut self,
program: impl AsRef<str>,
cwd: impl AsRef<str>,
stop_on_entry: bool,
window: &mut Window,
cx: &mut Context<Self>,
) {
self.mode = NewSessionMode::Configure;
self.debugger = Some(dap::adapters::DebugAdapterName("fake-adapter".into()));
#[test]
fn test_normalize_paths() {
let sep = std::path::MAIN_SEPARATOR;
let home = home_dir().to_string_lossy().to_string();
let resolve_path = |path: &str| -> String {
let mut path = path.to_string();
super::resolve_path(&mut path);
path
};
self.configure_mode.update(cx, |configure, cx| {
configure.program.update(cx, |editor, cx| {
editor.clear(window, cx);
editor.set_text(program.as_ref(), window, cx);
});
assert_eq!(resolve_path("bin"), format!("bin"));
assert_eq!(resolve_path(&format!("{sep}foo")), format!("{sep}foo"));
assert_eq!(resolve_path(""), format!(""));
assert_eq!(
resolve_path(&format!("~{sep}blah")),
format!("{home}{sep}blah")
);
assert_eq!(resolve_path("~"), home);
configure.cwd.update(cx, |editor, cx| {
editor.clear(window, cx);
editor.set_text(cwd.as_ref(), window, cx);
});
configure.stop_on_entry = match stop_on_entry {
true => ToggleState::Selected,
_ => ToggleState::Unselected,
}
})
}
pub(crate) fn save_scenario(&mut self, window: &mut Window, cx: &mut Context<Self>) {
self.save_debug_scenario(window, cx);
}
}

View File

@@ -547,6 +547,10 @@ impl RunningState {
.for_each(|value| Self::substitute_variables_in_config(value, context));
}
serde_json::Value::String(s) => {
// Some built-in zed tasks wrap their arguments in quotes as they might contain spaces.
if s.starts_with("\"$ZED_") && s.ends_with('"') {
*s = s[1..s.len() - 1].to_string();
}
if let Some(substituted) = substitute_variables_in_str(&s, context) {
*s = substituted;
}
@@ -571,6 +575,10 @@ impl RunningState {
.for_each(|value| Self::relativlize_paths(None, value, context));
}
serde_json::Value::String(s) if key == Some("program") || key == Some("cwd") => {
// Some built-in zed tasks wrap their arguments in quotes as they might contain spaces.
if s.starts_with("\"$ZED_") && s.ends_with('"') {
*s = s[1..s.len() - 1].to_string();
}
resolve_path(s);
if let Some(substituted) = substitute_variables_in_str(&s, context) {
@@ -866,6 +874,7 @@ impl RunningState {
args,
..task.resolved.clone()
};
let terminal = project
.update_in(cx, |project, window, cx| {
project.create_terminal(
@@ -910,12 +919,6 @@ impl RunningState {
};
if config_is_valid {
// Ok(DebugTaskDefinition {
// label,
// adapter: DebugAdapterName(adapter),
// config,
// tcp_connection,
// })
} else if let Some((task, locator_name)) = build_output {
let locator_name =
locator_name.context("Could not find a valid locator for a build task")?;
@@ -934,7 +937,7 @@ impl RunningState {
let scenario = dap_registry
.adapter(&adapter)
.ok_or_else(|| anyhow!("{}: is not a valid adapter name", &adapter))
.context(format!("{}: is not a valid adapter name", &adapter))
.map(|adapter| adapter.config_from_zed_format(zed_config))??;
config = scenario.config;
Self::substitute_variables_in_config(&mut config, &task_context);

View File

@@ -110,7 +110,7 @@ impl Console {
}
fn is_running(&self, cx: &Context<Self>) -> bool {
self.session.read(cx).is_local()
self.session.read(cx).is_running()
}
fn handle_stack_frame_list_events(

View File

@@ -250,9 +250,6 @@ impl StackFrameList {
let Some(abs_path) = Self::abs_path_from_stack_frame(&stack_frame) else {
return Task::ready(Err(anyhow!("Project path not found")));
};
if abs_path.starts_with("<node_internals>") {
return Task::ready(Ok(()));
}
let row = stack_frame.line.saturating_sub(1) as u32;
cx.emit(StackFrameListEvent::SelectedStackFrameChanged(
stack_frame_id,
@@ -345,6 +342,7 @@ impl StackFrameList {
s.path
.as_deref()
.map(|path| Arc::<Path>::from(Path::new(path)))
.filter(|path| path.is_absolute())
})
}

View File

@@ -25,7 +25,6 @@ mod inline_values;
#[cfg(test)]
mod module_list;
#[cfg(test)]
#[cfg(not(windows))]
mod new_session_modal;
#[cfg(test)]
mod persistence;

View File

@@ -1,14 +1,15 @@
use dap::DapRegistry;
use gpui::{BackgroundExecutor, TestAppContext, VisualTestContext};
use project::{FakeFs, Project};
use project::{FakeFs, Fs, Project};
use serde_json::json;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering};
use task::{DebugScenario, TaskContext, VariableName};
use task::{DebugRequest, DebugScenario, LaunchRequest, TaskContext, VariableName, ZedDebugConfig};
use util::path;
use crate::new_session_modal::NewSessionMode;
use crate::tests::{init_test, init_test_workspace};
// todo(tasks) figure out why task replacement is broken on windows
#[gpui::test]
async fn test_debug_session_substitutes_variables_and_relativizes_paths(
executor: BackgroundExecutor,
@@ -29,10 +30,9 @@ async fn test_debug_session_substitutes_variables_and_relativizes_paths(
let workspace = init_test_workspace(&project, cx).await;
let cx = &mut VisualTestContext::from_window(*workspace, cx);
// Set up task variables to simulate a real environment
let test_variables = vec![(
VariableName::WorktreeRoot,
"/test/worktree/path".to_string(),
path!("/test/worktree/path").to_string(),
)]
.into_iter()
.collect();
@@ -45,33 +45,35 @@ async fn test_debug_session_substitutes_variables_and_relativizes_paths(
let home_dir = paths::home_dir();
let sep = std::path::MAIN_SEPARATOR;
// Test cases for different path formats
let test_cases: Vec<(Arc<String>, Arc<String>)> = vec![
let test_cases: Vec<(&'static str, &'static str)> = vec![
// Absolute path - should not be relativized
(
Arc::from(format!("{0}absolute{0}path{0}to{0}program", sep)),
Arc::from(format!("{0}absolute{0}path{0}to{0}program", sep)),
path!("/absolute/path/to/program"),
path!("/absolute/path/to/program"),
),
// Relative path - should be prefixed with worktree root
(
Arc::from(format!(".{0}src{0}program", sep)),
Arc::from(format!("{0}test{0}worktree{0}path{0}src{0}program", sep)),
format!(".{0}src{0}program", std::path::MAIN_SEPARATOR).leak(),
path!("/test/worktree/path/src/program"),
),
// Home directory path - should be prefixed with worktree root
// Home directory path - should be expanded to full home directory path
(
Arc::from(format!("~{0}src{0}program", sep)),
Arc::from(format!(
"{1}{0}src{0}program",
sep,
home_dir.to_string_lossy()
)),
format!("~{0}src{0}program", std::path::MAIN_SEPARATOR).leak(),
home_dir
.join("src")
.join("program")
.to_string_lossy()
.to_string()
.leak(),
),
// Path with $ZED_WORKTREE_ROOT - should be substituted without double appending
(
Arc::from(format!("$ZED_WORKTREE_ROOT{0}src{0}program", sep)),
Arc::from(format!("{0}test{0}worktree{0}path{0}src{0}program", sep)),
format!(
"$ZED_WORKTREE_ROOT{0}src{0}program",
std::path::MAIN_SEPARATOR
)
.leak(),
path!("/test/worktree/path/src/program"),
),
];
@@ -80,44 +82,38 @@ async fn test_debug_session_substitutes_variables_and_relativizes_paths(
for (input_path, expected_path) in test_cases {
let _subscription = project::debugger::test::intercept_debug_sessions(cx, {
let called_launch = called_launch.clone();
let input_path = input_path.clone();
let expected_path = expected_path.clone();
move |client| {
client.on_request::<dap::requests::Launch, _>({
let called_launch = called_launch.clone();
let input_path = input_path.clone();
let expected_path = expected_path.clone();
move |_, args| {
let config = args.raw.as_object().unwrap();
// Verify the program path was substituted correctly
assert_eq!(
config["program"].as_str().unwrap(),
expected_path.as_str(),
expected_path,
"Program path was not correctly substituted for input: {}",
input_path.as_str()
input_path
);
// Verify the cwd path was substituted correctly
assert_eq!(
config["cwd"].as_str().unwrap(),
expected_path.as_str(),
expected_path,
"CWD path was not correctly substituted for input: {}",
input_path.as_str()
input_path
);
// Verify that otherField was substituted but not relativized
// It should still have $ZED_WORKTREE_ROOT substituted if present
let expected_other_field = if input_path.contains("$ZED_WORKTREE_ROOT") {
input_path.replace("$ZED_WORKTREE_ROOT", "/test/worktree/path")
input_path
.replace("$ZED_WORKTREE_ROOT", &path!("/test/worktree/path"))
.to_owned()
} else {
input_path.to_string()
};
assert_eq!(
config["otherField"].as_str().unwrap(),
expected_other_field,
&expected_other_field,
"Other field was incorrectly modified for input: {}",
input_path
);
@@ -155,3 +151,199 @@ async fn test_debug_session_substitutes_variables_and_relativizes_paths(
called_launch.store(false, Ordering::SeqCst);
}
}
#[gpui::test]
async fn test_save_debug_scenario_to_file(executor: BackgroundExecutor, cx: &mut TestAppContext) {
init_test(cx);
let fs = FakeFs::new(executor.clone());
fs.insert_tree(
path!("/project"),
json!({
"main.rs": "fn main() {}"
}),
)
.await;
let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
let workspace = init_test_workspace(&project, cx).await;
let cx = &mut VisualTestContext::from_window(*workspace, cx);
workspace
.update(cx, |workspace, window, cx| {
crate::new_session_modal::NewSessionModal::show(
workspace,
window,
NewSessionMode::Launch,
None,
cx,
);
})
.unwrap();
cx.run_until_parked();
let modal = workspace
.update(cx, |workspace, _, cx| {
workspace.active_modal::<crate::new_session_modal::NewSessionModal>(cx)
})
.unwrap()
.expect("Modal should be active");
modal.update_in(cx, |modal, window, cx| {
modal.set_configure("/project/main", "/project", false, window, cx);
modal.save_scenario(window, cx);
});
cx.executor().run_until_parked();
let debug_json_content = fs
.load(path!("/project/.zed/debug.json").as_ref())
.await
.expect("debug.json should exist");
let expected_content = vec![
"[",
" {",
r#" "adapter": "fake-adapter","#,
r#" "label": "main (fake-adapter)","#,
r#" "request": "launch","#,
r#" "program": "/project/main","#,
r#" "cwd": "/project","#,
r#" "args": [],"#,
r#" "env": {}"#,
" }",
"]",
];
let actual_lines: Vec<&str> = debug_json_content.lines().collect();
pretty_assertions::assert_eq!(expected_content, actual_lines);
modal.update_in(cx, |modal, window, cx| {
modal.set_configure("/project/other", "/project", true, window, cx);
modal.save_scenario(window, cx);
});
cx.executor().run_until_parked();
let debug_json_content = fs
.load(path!("/project/.zed/debug.json").as_ref())
.await
.expect("debug.json should exist after second save");
let expected_content = vec![
"[",
" {",
r#" "adapter": "fake-adapter","#,
r#" "label": "main (fake-adapter)","#,
r#" "request": "launch","#,
r#" "program": "/project/main","#,
r#" "cwd": "/project","#,
r#" "args": [],"#,
r#" "env": {}"#,
" },",
" {",
r#" "adapter": "fake-adapter","#,
r#" "label": "other (fake-adapter)","#,
r#" "request": "launch","#,
r#" "program": "/project/other","#,
r#" "cwd": "/project","#,
r#" "args": [],"#,
r#" "env": {}"#,
" }",
"]",
];
let actual_lines: Vec<&str> = debug_json_content.lines().collect();
pretty_assertions::assert_eq!(expected_content, actual_lines);
}
#[gpui::test]
async fn test_dap_adapter_config_conversion_and_validation(cx: &mut TestAppContext) {
init_test(cx);
let mut expected_adapters = vec![
"CodeLLDB",
"Debugpy",
"PHP",
"JavaScript",
"Ruby",
"Delve",
"GDB",
"fake-adapter",
];
let adapter_names = cx.update(|cx| {
let registry = DapRegistry::global(cx);
registry.enumerate_adapters()
});
let zed_config = ZedDebugConfig {
label: "test_debug_session".into(),
adapter: "test_adapter".into(),
request: DebugRequest::Launch(LaunchRequest {
program: "test_program".into(),
cwd: None,
args: vec![],
env: Default::default(),
}),
stop_on_entry: Some(true),
};
for adapter_name in adapter_names {
let adapter_str = adapter_name.to_string();
if let Some(pos) = expected_adapters.iter().position(|&x| x == adapter_str) {
expected_adapters.remove(pos);
}
let adapter = cx
.update(|cx| {
let registry = DapRegistry::global(cx);
registry.adapter(adapter_name.as_ref())
})
.unwrap_or_else(|| panic!("Adapter {} should exist", adapter_name));
let mut adapter_specific_config = zed_config.clone();
adapter_specific_config.adapter = adapter_name.to_string().into();
let debug_scenario = adapter
.config_from_zed_format(adapter_specific_config)
.unwrap_or_else(|_| {
panic!(
"Adapter {} should successfully convert from Zed format",
adapter_name
)
});
assert!(
debug_scenario.config.is_object(),
"Adapter {} should produce a JSON object for config",
adapter_name
);
let request_type = adapter
.validate_config(&debug_scenario.config)
.unwrap_or_else(|_| {
panic!(
"Adapter {} should validate the config successfully",
adapter_name
)
});
match request_type {
dap::StartDebuggingRequestArgumentsRequest::Launch => {}
dap::StartDebuggingRequestArgumentsRequest::Attach => {
panic!(
"Expected Launch request but got Attach for adapter {}",
adapter_name
);
}
}
}
assert!(
expected_adapters.is_empty(),
"The following expected adapters were not found in the registry: {:?}",
expected_adapters
);
}

View File

@@ -125,7 +125,7 @@ fn find_binding(os: &str, action: &str) -> Option<String> {
// Find the binding in reverse order, as the last binding takes precedence.
keymap.sections().rev().find_map(|section| {
section.bindings().rev().find_map(|(keystroke, a)| {
if a.to_string() == action {
if name_for_action(a.to_string()) == action {
Some(keystroke.to_string())
} else {
None
@@ -134,6 +134,36 @@ fn find_binding(os: &str, action: &str) -> Option<String> {
})
}
/// Removes any configurable options from the stringified action if existing,
/// ensuring that only the actual action name is returned. If the action consists
/// only of a string and nothing else, the string is returned as-is.
///
/// Example:
///
/// This will return the action name unmodified.
///
/// ```
/// let action_as_str = "assistant::Assist";
/// let action_name = name_for_action(action_as_str);
/// assert_eq!(action_name, "assistant::Assist");
/// ```
///
/// This will return the action name with any trailing options removed.
///
///
/// ```
/// let action_as_str = "\"editor::ToggleComments\", {\"advance_downwards\":false}";
/// let action_name = name_for_action(action_as_str);
/// assert_eq!(action_name, "editor::ToggleComments");
/// ```
fn name_for_action(action_as_str: String) -> String {
action_as_str
.split(",")
.next()
.map(|name| name.trim_matches('"').to_string())
.unwrap_or(action_as_str)
}
fn load_keymap(asset_path: &str) -> Result<KeymapFile> {
let content = util::asset_str::<settings::SettingsAssets>(asset_path);
KeymapFile::parse(content.as_ref())

View File

@@ -82,6 +82,7 @@ tree-sitter-rust = { workspace = true, optional = true }
tree-sitter-typescript = { workspace = true, optional = true }
tree-sitter-python = { workspace = true, optional = true }
unicode-segmentation.workspace = true
unicode-script.workspace = true
unindent = { workspace = true, optional = true }
ui.workspace = true
url.workspace = true
@@ -97,6 +98,7 @@ gpui = { workspace = true, features = ["test-support"] }
language = { workspace = true, features = ["test-support"] }
languages = {workspace = true, features = ["test-support"] }
lsp = { workspace = true, features = ["test-support"] }
markdown = { workspace = true, features = ["test-support"] }
multi_buffer = { workspace = true, features = ["test-support"] }
project = { workspace = true, features = ["test-support"] }
release_channel.workspace = true

View File

@@ -4,8 +4,9 @@ use gpui::{
Size, StrikethroughStyle, StyledText, UniformListScrollHandle, div, px, uniform_list,
};
use gpui::{AsyncWindowContext, WeakEntity};
use language::Buffer;
use itertools::Itertools;
use language::CodeLabel;
use language::{Buffer, LanguageName, LanguageRegistry};
use markdown::{Markdown, MarkdownElement};
use multi_buffer::{Anchor, ExcerptId};
use ordered_float::OrderedFloat;
@@ -15,6 +16,8 @@ use project::{CodeAction, Completion, TaskSourceKind};
use task::DebugScenario;
use task::TaskContext;
use std::collections::VecDeque;
use std::sync::Arc;
use std::{
cell::RefCell,
cmp::{Reverse, min},
@@ -41,6 +44,25 @@ pub const MENU_ASIDE_X_PADDING: Pixels = px(16.);
pub const MENU_ASIDE_MIN_WIDTH: Pixels = px(260.);
pub const MENU_ASIDE_MAX_WIDTH: Pixels = px(500.);
// Constants for the markdown cache. The purpose of this cache is to reduce flickering due to
// documentation not yet being parsed.
//
// The size of the cache is set to the number of items fetched around the current selection plus one
// for the current selection and another to avoid cases where and adjacent selection exits the
// cache. The only current benefit of a larger cache would be doing less markdown parsing when the
// selection revisits items.
//
// One future benefit of a larger cache would be reducing flicker on backspace. This would require
// not recreating the menu on every change, by not re-querying the language server when
// `is_incomplete = false`.
const MARKDOWN_CACHE_MAX_SIZE: usize = MARKDOWN_CACHE_BEFORE_ITEMS + MARKDOWN_CACHE_AFTER_ITEMS + 2;
const MARKDOWN_CACHE_BEFORE_ITEMS: usize = 2;
const MARKDOWN_CACHE_AFTER_ITEMS: usize = 2;
// Number of items beyond the visible items to resolve documentation.
const RESOLVE_BEFORE_ITEMS: usize = 4;
const RESOLVE_AFTER_ITEMS: usize = 4;
pub enum CodeContextMenu {
Completions(CompletionsMenu),
CodeActions(CodeActionsMenu),
@@ -148,13 +170,12 @@ impl CodeContextMenu {
pub fn render_aside(
&mut self,
editor: &Editor,
max_size: Size<Pixels>,
window: &mut Window,
cx: &mut Context<Editor>,
) -> Option<AnyElement> {
match self {
CodeContextMenu::Completions(menu) => menu.render_aside(editor, max_size, window, cx),
CodeContextMenu::Completions(menu) => menu.render_aside(max_size, window, cx),
CodeContextMenu::CodeActions(_) => None,
}
}
@@ -162,7 +183,7 @@ impl CodeContextMenu {
pub fn focused(&self, window: &mut Window, cx: &mut Context<Editor>) -> bool {
match self {
CodeContextMenu::Completions(completions_menu) => completions_menu
.markdown_element
.get_or_create_entry_markdown(completions_menu.selected_item, cx)
.as_ref()
.is_some_and(|markdown| markdown.focus_handle(cx).contains_focused(window, cx)),
CodeContextMenu::CodeActions(_) => false,
@@ -176,7 +197,7 @@ pub enum ContextMenuOrigin {
QuickActionBar,
}
#[derive(Clone, Debug)]
#[derive(Clone)]
pub struct CompletionsMenu {
pub id: CompletionId,
sort_completions: bool,
@@ -191,7 +212,9 @@ pub struct CompletionsMenu {
show_completion_documentation: bool,
pub(super) ignore_completion_provider: bool,
last_rendered_range: Rc<RefCell<Option<Range<usize>>>>,
markdown_element: Option<Entity<Markdown>>,
markdown_cache: Rc<RefCell<VecDeque<(usize, Entity<Markdown>)>>>,
language_registry: Option<Arc<LanguageRegistry>>,
language: Option<LanguageName>,
snippet_sort_order: SnippetSortOrder,
}
@@ -205,6 +228,9 @@ impl CompletionsMenu {
buffer: Entity<Buffer>,
completions: Box<[Completion]>,
snippet_sort_order: SnippetSortOrder,
language_registry: Option<Arc<LanguageRegistry>>,
language: Option<LanguageName>,
cx: &mut Context<Editor>,
) -> Self {
let match_candidates = completions
.iter()
@@ -212,7 +238,7 @@ impl CompletionsMenu {
.map(|(id, completion)| StringMatchCandidate::new(id, &completion.label.filter_text()))
.collect();
Self {
let completions_menu = Self {
id,
sort_completions,
initial_position,
@@ -226,9 +252,15 @@ impl CompletionsMenu {
scroll_handle: UniformListScrollHandle::new(),
resolve_completions: true,
last_rendered_range: RefCell::new(None).into(),
markdown_element: None,
markdown_cache: RefCell::new(VecDeque::with_capacity(MARKDOWN_CACHE_MAX_SIZE)).into(),
language_registry,
language,
snippet_sort_order,
}
};
completions_menu.start_markdown_parse_for_nearby_entries(cx);
completions_menu
}
pub fn new_snippet_choices(
@@ -286,7 +318,9 @@ impl CompletionsMenu {
show_completion_documentation: false,
ignore_completion_provider: false,
last_rendered_range: RefCell::new(None).into(),
markdown_element: None,
markdown_cache: RefCell::new(VecDeque::new()).into(),
language_registry: None,
language: None,
snippet_sort_order,
}
}
@@ -359,6 +393,7 @@ impl CompletionsMenu {
self.scroll_handle
.scroll_to_item(self.selected_item, ScrollStrategy::Top);
self.resolve_visible_completions(provider, cx);
self.start_markdown_parse_for_nearby_entries(cx);
if let Some(provider) = provider {
self.handle_selection_changed(provider, window, cx);
}
@@ -433,11 +468,10 @@ impl CompletionsMenu {
// Expand the range to resolve more completions than are predicted to be visible, to reduce
// jank on navigation.
const EXTRA_TO_RESOLVE: usize = 4;
let entry_indices = util::iterate_expanded_and_wrapped_usize_range(
let entry_indices = util::expanded_and_wrapped_usize_range(
entry_range.clone(),
EXTRA_TO_RESOLVE,
EXTRA_TO_RESOLVE,
RESOLVE_BEFORE_ITEMS,
RESOLVE_AFTER_ITEMS,
entries.len(),
);
@@ -467,14 +501,120 @@ impl CompletionsMenu {
cx,
);
let completion_id = self.id;
cx.spawn(async move |editor, cx| {
if let Some(true) = resolve_task.await.log_err() {
editor.update(cx, |_, cx| cx.notify()).ok();
editor
.update(cx, |editor, cx| {
// `resolve_completions` modified state affecting display.
cx.notify();
editor.with_completions_menu_matching_id(
completion_id,
|| (),
|this| this.start_markdown_parse_for_nearby_entries(cx),
);
})
.ok();
}
})
.detach();
}
fn start_markdown_parse_for_nearby_entries(&self, cx: &mut Context<Editor>) {
// Enqueue parse tasks of nearer items first.
//
// TODO: This means that the nearer items will actually be further back in the cache, which
// is not ideal. In practice this is fine because `get_or_create_markdown` moves the current
// selection to the front (when `is_render = true`).
let entry_indices = util::wrapped_usize_outward_from(
self.selected_item,
MARKDOWN_CACHE_BEFORE_ITEMS,
MARKDOWN_CACHE_AFTER_ITEMS,
self.entries.borrow().len(),
);
for index in entry_indices {
self.get_or_create_entry_markdown(index, cx);
}
}
fn get_or_create_entry_markdown(
&self,
index: usize,
cx: &mut Context<Editor>,
) -> Option<Entity<Markdown>> {
let entries = self.entries.borrow();
if index >= entries.len() {
return None;
}
let candidate_id = entries[index].candidate_id;
match &self.completions.borrow()[candidate_id].documentation {
Some(CompletionDocumentation::MultiLineMarkdown(source)) if !source.is_empty() => Some(
self.get_or_create_markdown(candidate_id, source.clone(), false, cx)
.1,
),
Some(_) => None,
_ => None,
}
}
fn get_or_create_markdown(
&self,
candidate_id: usize,
source: SharedString,
is_render: bool,
cx: &mut Context<Editor>,
) -> (bool, Entity<Markdown>) {
let mut markdown_cache = self.markdown_cache.borrow_mut();
if let Some((cache_index, (_, markdown))) = markdown_cache
.iter()
.find_position(|(id, _)| *id == candidate_id)
{
let markdown = if is_render && cache_index != 0 {
// Move the current selection's cache entry to the front.
markdown_cache.rotate_right(1);
let cache_len = markdown_cache.len();
markdown_cache.swap(0, (cache_index + 1) % cache_len);
&markdown_cache[0].1
} else {
markdown
};
let is_parsing = markdown.update(cx, |markdown, cx| {
// `reset` is called as it's possible for documentation to change due to resolve
// requests. It does nothing if `source` is unchanged.
markdown.reset(source, cx);
markdown.is_parsing()
});
return (is_parsing, markdown.clone());
}
if markdown_cache.len() < MARKDOWN_CACHE_MAX_SIZE {
let markdown = cx.new(|cx| {
Markdown::new(
source,
self.language_registry.clone(),
self.language.clone(),
cx,
)
});
// Handles redraw when the markdown is done parsing. The current render is for a
// deferred draw, and so without this did not redraw when `markdown` notified.
cx.observe(&markdown, |_, _, cx| cx.notify()).detach();
markdown_cache.push_front((candidate_id, markdown.clone()));
(true, markdown)
} else {
debug_assert_eq!(markdown_cache.capacity(), MARKDOWN_CACHE_MAX_SIZE);
// Moves the last cache entry to the start. The ring buffer is full, so this does no
// copying and just shifts indexes.
markdown_cache.rotate_right(1);
markdown_cache[0].0 = candidate_id;
let markdown = &markdown_cache[0].1;
markdown.update(cx, |markdown, cx| markdown.reset(source, cx));
(true, markdown.clone())
}
}
pub fn visible(&self) -> bool {
!self.entries.borrow().is_empty()
}
@@ -625,7 +765,6 @@ impl CompletionsMenu {
fn render_aside(
&mut self,
editor: &Editor,
max_size: Size<Pixels>,
window: &mut Window,
cx: &mut Context<Editor>,
@@ -644,33 +783,14 @@ impl CompletionsMenu {
plain_text: Some(text),
..
} => div().child(text.clone()),
CompletionDocumentation::MultiLineMarkdown(parsed) if !parsed.is_empty() => {
let markdown = self.markdown_element.get_or_insert_with(|| {
let markdown = cx.new(|cx| {
let languages = editor
.workspace
.as_ref()
.and_then(|(workspace, _)| workspace.upgrade())
.map(|workspace| workspace.read(cx).app_state().languages.clone());
let language = editor
.language_at(self.initial_position, cx)
.map(|l| l.name().to_proto());
Markdown::new(SharedString::default(), languages, language, cx)
});
// Handles redraw when the markdown is done parsing. The current render is for a
// deferred draw and so was not getting redrawn when `markdown` notified.
cx.observe(&markdown, |_, _, cx| cx.notify()).detach();
markdown
});
let is_parsing = markdown.update(cx, |markdown, cx| {
markdown.reset(parsed.clone(), cx);
markdown.is_parsing()
});
CompletionDocumentation::MultiLineMarkdown(source) if !source.is_empty() => {
let (is_parsing, markdown) =
self.get_or_create_markdown(mat.candidate_id, source.clone(), true, cx);
if is_parsing {
return None;
}
div().child(
MarkdownElement::new(markdown.clone(), hover_markdown_style(window, cx))
MarkdownElement::new(markdown, hover_markdown_style(window, cx))
.code_block_renderer(markdown::CodeBlockRenderer::Default {
copy_button: false,
copy_button_on_hover: false,
@@ -882,13 +1002,7 @@ impl CompletionsMenu {
// another opened. `provider.selection_changed` should not be called in this case.
let this_menu_still_active = editor
.read_with(cx, |editor, _cx| {
if let Some(CodeContextMenu::Completions(completions_menu)) =
editor.context_menu.borrow().as_ref()
{
completions_menu.id == self.id
} else {
false
}
editor.with_completions_menu_matching_id(self.id, || false, |_| true)
})
.unwrap_or(false);
if this_menu_still_active {

View File

@@ -201,7 +201,7 @@ use ui::{
ButtonSize, ButtonStyle, ContextMenu, Disclosure, IconButton, IconButtonShape, IconName,
IconSize, Indicator, Key, Tooltip, h_flex, prelude::*,
};
use util::{RangeExt, ResultExt, TryFutureExt, maybe, post_inc, wrap_with_prefix};
use util::{RangeExt, ResultExt, TryFutureExt, maybe, post_inc};
use workspace::{
CollaboratorId, Item as WorkspaceItem, ItemId, ItemNavHistory, OpenInTerminal, OpenTerminal,
RestoreOnStartupBehavior, SERIALIZATION_THROTTLE_TIME, SplitDirection, TabBarSettings, Toast,
@@ -936,6 +936,8 @@ pub struct Editor {
select_next_state: Option<SelectNextState>,
select_prev_state: Option<SelectNextState>,
selection_history: SelectionHistory,
defer_selection_effects: bool,
deferred_selection_effects_state: Option<DeferredSelectionEffectsState>,
autoclose_regions: Vec<AutocloseRegion>,
snippet_stack: InvalidationStack<SnippetState>,
select_syntax_node_history: SelectSyntaxNodeHistory,
@@ -1195,6 +1197,14 @@ impl Default for SelectionHistoryMode {
}
}
struct DeferredSelectionEffectsState {
changed: bool,
show_completions: bool,
autoscroll: Option<Autoscroll>,
old_cursor_position: Anchor,
history_entry: SelectionHistoryEntry,
}
#[derive(Default)]
struct SelectionHistory {
#[allow(clippy::type_complexity)]
@@ -1670,6 +1680,13 @@ impl Editor {
editor
.refresh_inlay_hints(InlayHintRefreshReason::RefreshRequested, cx);
}
project::Event::LanguageServerAdded(..)
| project::Event::LanguageServerRemoved(..) => {
if editor.tasks_update_task.is_none() {
editor.tasks_update_task =
Some(editor.refresh_runnables(window, cx));
}
}
project::Event::SnippetEdit(id, snippet_edits) => {
if let Some(buffer) = editor.buffer.read(cx).buffer(*id) {
let focus_handle = editor.focus_handle(cx);
@@ -1784,6 +1801,8 @@ impl Editor {
select_next_state: None,
select_prev_state: None,
selection_history: SelectionHistory::default(),
defer_selection_effects: false,
deferred_selection_effects_state: None,
autoclose_regions: Vec::new(),
snippet_stack: InvalidationStack::default(),
select_syntax_node_history: SelectSyntaxNodeHistory::default(),
@@ -2947,6 +2966,9 @@ impl Editor {
Subscription::join(other_subscription, this_subscription)
}
/// Changes selections using the provided mutation function. Changes to `self.selections` occur
/// immediately, but when run within `transact` or `with_selection_effects_deferred` other
/// effects of selection change occur at the end of the transaction.
pub fn change_selections<R>(
&mut self,
autoscroll: Option<Autoscroll>,
@@ -2954,39 +2976,105 @@ impl Editor {
cx: &mut Context<Self>,
change: impl FnOnce(&mut MutableSelectionsCollection<'_>) -> R,
) -> R {
self.change_selections_inner(autoscroll, true, window, cx, change)
self.change_selections_inner(true, autoscroll, window, cx, change)
}
fn change_selections_inner<R>(
pub(crate) fn change_selections_without_showing_completions<R>(
&mut self,
autoscroll: Option<Autoscroll>,
request_completions: bool,
window: &mut Window,
cx: &mut Context<Self>,
change: impl FnOnce(&mut MutableSelectionsCollection<'_>) -> R,
) -> R {
let old_cursor_position = self.selections.newest_anchor().head();
self.push_to_selection_history();
self.change_selections_inner(false, autoscroll, window, cx, change)
}
fn change_selections_inner<R>(
&mut self,
show_completions: bool,
autoscroll: Option<Autoscroll>,
window: &mut Window,
cx: &mut Context<Self>,
change: impl FnOnce(&mut MutableSelectionsCollection<'_>) -> R,
) -> R {
if let Some(state) = &mut self.deferred_selection_effects_state {
state.autoscroll = autoscroll.or(state.autoscroll);
state.show_completions = show_completions;
let (changed, result) = self.selections.change_with(cx, change);
state.changed |= changed;
return result;
}
let mut state = DeferredSelectionEffectsState {
changed: false,
show_completions,
autoscroll,
old_cursor_position: self.selections.newest_anchor().head(),
history_entry: SelectionHistoryEntry {
selections: self.selections.disjoint_anchors(),
select_next_state: self.select_next_state.clone(),
select_prev_state: self.select_prev_state.clone(),
add_selections_state: self.add_selections_state.clone(),
},
};
let (changed, result) = self.selections.change_with(cx, change);
state.changed = state.changed || changed;
if self.defer_selection_effects {
self.deferred_selection_effects_state = Some(state);
} else {
self.apply_selection_effects(state, window, cx);
}
result
}
if changed {
if let Some(autoscroll) = autoscroll {
/// Defers the effects of selection change, so that the effects of multiple calls to
/// `change_selections` are applied at the end. This way these intermediate states aren't added
/// to selection history and the state of popovers based on selection position aren't
/// erroneously updated.
pub fn with_selection_effects_deferred<R>(
&mut self,
window: &mut Window,
cx: &mut Context<Self>,
update: impl FnOnce(&mut Self, &mut Window, &mut Context<Self>) -> R,
) -> R {
let already_deferred = self.defer_selection_effects;
self.defer_selection_effects = true;
let result = update(self, window, cx);
if !already_deferred {
self.defer_selection_effects = false;
if let Some(state) = self.deferred_selection_effects_state.take() {
self.apply_selection_effects(state, window, cx);
}
}
result
}
fn apply_selection_effects(
&mut self,
state: DeferredSelectionEffectsState,
window: &mut Window,
cx: &mut Context<Self>,
) {
if state.changed {
self.selection_history.push(state.history_entry);
if let Some(autoscroll) = state.autoscroll {
self.request_autoscroll(autoscroll, cx);
}
self.selections_did_change(true, &old_cursor_position, request_completions, window, cx);
if self.should_open_signature_help_automatically(
let old_cursor_position = &state.old_cursor_position;
self.selections_did_change(
true,
&old_cursor_position,
self.signature_help_state.backspace_pressed(),
state.show_completions,
window,
cx,
) {
);
if self.should_open_signature_help_automatically(&old_cursor_position, cx) {
self.show_signature_help(&ShowSignatureHelp, window, cx);
}
self.signature_help_state.set_backspace_pressed(false);
}
result
}
pub fn edit<I, S, T>(&mut self, edits: I, cx: &mut Context<Self>)
@@ -3870,9 +3958,12 @@ impl Editor {
}
let had_active_inline_completion = this.has_active_inline_completion();
this.change_selections_inner(Some(Autoscroll::fit()), false, window, cx, |s| {
s.select(new_selections)
});
this.change_selections_without_showing_completions(
Some(Autoscroll::fit()),
window,
cx,
|s| s.select(new_selections),
);
if !bracket_inserted {
if let Some(on_type_format_task) =
@@ -4987,14 +5078,12 @@ impl Editor {
(buffer_position..buffer_position, None)
};
let completion_settings = language_settings(
buffer_snapshot
.language_at(buffer_position)
.map(|language| language.name()),
buffer_snapshot.file(),
cx,
)
.completions;
let language = buffer_snapshot
.language_at(buffer_position)
.map(|language| language.name());
let completion_settings =
language_settings(language.clone(), buffer_snapshot.file(), cx).completions;
// The document can be large, so stay in reasonable bounds when searching for words,
// otherwise completion pop-up might be slow to appear.
@@ -5106,16 +5195,26 @@ impl Editor {
let menu = if completions.is_empty() {
None
} else {
let mut menu = CompletionsMenu::new(
id,
sort_completions,
show_completion_documentation,
ignore_completion_provider,
position,
buffer.clone(),
completions.into(),
snippet_sort_order,
);
let mut menu = editor.update(cx, |editor, cx| {
let languages = editor
.workspace
.as_ref()
.and_then(|(workspace, _)| workspace.upgrade())
.map(|workspace| workspace.read(cx).app_state().languages.clone());
CompletionsMenu::new(
id,
sort_completions,
show_completion_documentation,
ignore_completion_provider,
position,
buffer.clone(),
completions.into(),
snippet_sort_order,
languages,
language,
cx,
)
})?;
menu.filter(
if filter_completions {
@@ -5190,6 +5289,22 @@ impl Editor {
}
}
pub fn with_completions_menu_matching_id<R>(
&self,
id: CompletionId,
on_absent: impl FnOnce() -> R,
on_match: impl FnOnce(&mut CompletionsMenu) -> R,
) -> R {
let mut context_menu = self.context_menu.borrow_mut();
let Some(CodeContextMenu::Completions(completions_menu)) = &mut *context_menu else {
return on_absent();
};
if completions_menu.id != id {
return on_absent();
}
on_match(completions_menu)
}
pub fn confirm_completion(
&mut self,
action: &ConfirmCompletion,
@@ -8686,7 +8801,7 @@ impl Editor {
) -> Option<AnyElement> {
self.context_menu.borrow_mut().as_mut().and_then(|menu| {
if menu.visible() {
menu.render_aside(self, max_size, window, cx)
menu.render_aside(max_size, window, cx)
} else {
None
}
@@ -9002,7 +9117,6 @@ impl Editor {
}
}
this.signature_help_state.set_backspace_pressed(true);
this.change_selections(Some(Autoscroll::fit()), window, cx, |s| {
s.select(selections)
});
@@ -12724,7 +12838,6 @@ impl Editor {
) -> Result<()> {
self.hide_mouse_cursor(&HideMouseCursorOrigin::MovementAction);
self.push_to_selection_history();
let display_map = self.display_map.update(cx, |map, cx| map.snapshot(cx));
self.select_next_match_internal(&display_map, false, None, window, cx)?;
@@ -12777,7 +12890,6 @@ impl Editor {
cx: &mut Context<Self>,
) -> Result<()> {
self.hide_mouse_cursor(&HideMouseCursorOrigin::MovementAction);
self.push_to_selection_history();
let display_map = self.display_map.update(cx, |map, cx| map.snapshot(cx));
self.select_next_match_internal(
&display_map,
@@ -12796,7 +12908,6 @@ impl Editor {
cx: &mut Context<Self>,
) -> Result<()> {
self.hide_mouse_cursor(&HideMouseCursorOrigin::MovementAction);
self.push_to_selection_history();
let display_map = self.display_map.update(cx, |map, cx| map.snapshot(cx));
let buffer = &display_map.buffer_snapshot;
let mut selections = self.selections.all::<usize>(cx);
@@ -13519,6 +13630,7 @@ impl Editor {
}
let project = self.project.as_ref().map(Entity::downgrade);
let task_sources = self.lsp_task_sources(cx);
let multi_buffer = self.buffer.downgrade();
cx.spawn_in(window, async move |editor, cx| {
cx.background_executor().timer(UPDATE_DEBOUNCE).await;
let Some(project) = project.and_then(|p| p.upgrade()) else {
@@ -13602,7 +13714,19 @@ impl Editor {
return;
};
let rows = Self::runnable_rows(project, display_snapshot, new_rows, cx.clone());
let Ok(prefer_lsp) = multi_buffer.update(cx, |buffer, cx| {
buffer.language_settings(cx).tasks.prefer_lsp
}) else {
return;
};
let rows = Self::runnable_rows(
project,
display_snapshot,
prefer_lsp && !lsp_tasks_by_rows.is_empty(),
new_rows,
cx.clone(),
);
editor
.update(cx, |editor, _| {
editor.clear_tasks();
@@ -13630,15 +13754,21 @@ impl Editor {
fn runnable_rows(
project: Entity<Project>,
snapshot: DisplaySnapshot,
prefer_lsp: bool,
runnable_ranges: Vec<RunnableRange>,
mut cx: AsyncWindowContext,
) -> Vec<((BufferId, BufferRow), RunnableTasks)> {
runnable_ranges
.into_iter()
.filter_map(|mut runnable| {
let tasks = cx
let mut tasks = cx
.update(|_, cx| Self::templates_with_tags(&project, &mut runnable.runnable, cx))
.ok()?;
if prefer_lsp {
tasks.retain(|(task_kind, _)| {
!matches!(task_kind, TaskSourceKind::Language { .. })
});
}
if tasks.is_empty() {
return None;
}
@@ -14971,7 +15101,7 @@ impl Editor {
text_style = text_style.highlight(highlight_style);
}
div()
.block_mouse_down()
.block_mouse_except_scroll()
.pl(cx.anchor_x)
.child(EditorElement::new(
&rename_editor,
@@ -15647,24 +15777,17 @@ impl Editor {
self.selections_did_change(false, &old_cursor_position, true, window, cx);
}
fn push_to_selection_history(&mut self) {
self.selection_history.push(SelectionHistoryEntry {
selections: self.selections.disjoint_anchors(),
select_next_state: self.select_next_state.clone(),
select_prev_state: self.select_prev_state.clone(),
add_selections_state: self.add_selections_state.clone(),
});
}
pub fn transact(
&mut self,
window: &mut Window,
cx: &mut Context<Self>,
update: impl FnOnce(&mut Self, &mut Window, &mut Context<Self>),
) -> Option<TransactionId> {
self.start_transaction_at(Instant::now(), window, cx);
update(self, window, cx);
self.end_transaction_at(Instant::now(), cx)
self.with_selection_effects_deferred(window, cx, |this, window, cx| {
this.start_transaction_at(Instant::now(), window, cx);
update(this, window, cx);
this.end_transaction_at(Instant::now(), cx)
})
}
pub fn start_transaction_at(
@@ -18585,16 +18708,20 @@ impl Editor {
}
let minimap_settings = EditorSettings::get_global(cx).minimap;
if self.minimap_visibility.settings_visibility() != minimap_settings.minimap_enabled() {
self.set_minimap_visibility(
MinimapVisibility::for_mode(self.mode(), cx),
window,
cx,
);
} else if let Some(minimap_entity) = self.minimap.as_ref() {
minimap_entity.update(cx, |minimap_editor, cx| {
minimap_editor.update_minimap_configuration(minimap_settings, cx)
})
if self.minimap_visibility != MinimapVisibility::Disabled {
if self.minimap_visibility.settings_visibility()
!= minimap_settings.minimap_enabled()
{
self.set_minimap_visibility(
MinimapVisibility::for_mode(self.mode(), cx),
window,
cx,
);
} else if let Some(minimap_entity) = self.minimap.as_ref() {
minimap_entity.update(cx, |minimap_editor, cx| {
minimap_editor.update_minimap_configuration(minimap_settings, cx)
})
}
}
}
@@ -19563,6 +19690,347 @@ fn update_uncommitted_diff_for_buffer(
})
}
fn char_len_with_expanded_tabs(offset: usize, text: &str, tab_size: NonZeroU32) -> usize {
let tab_size = tab_size.get() as usize;
let mut width = offset;
for ch in text.chars() {
width += if ch == '\t' {
tab_size - (width % tab_size)
} else {
1
};
}
width - offset
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_string_size_with_expanded_tabs() {
let nz = |val| NonZeroU32::new(val).unwrap();
assert_eq!(char_len_with_expanded_tabs(0, "", nz(4)), 0);
assert_eq!(char_len_with_expanded_tabs(0, "hello", nz(4)), 5);
assert_eq!(char_len_with_expanded_tabs(0, "\thello", nz(4)), 9);
assert_eq!(char_len_with_expanded_tabs(0, "abc\tab", nz(4)), 6);
assert_eq!(char_len_with_expanded_tabs(0, "hello\t", nz(4)), 8);
assert_eq!(char_len_with_expanded_tabs(0, "\t\t", nz(8)), 16);
assert_eq!(char_len_with_expanded_tabs(0, "x\t", nz(8)), 8);
assert_eq!(char_len_with_expanded_tabs(7, "x\t", nz(8)), 9);
}
}
/// Tokenizes a string into runs of text that should stick together, or that is whitespace.
struct WordBreakingTokenizer<'a> {
input: &'a str,
}
impl<'a> WordBreakingTokenizer<'a> {
fn new(input: &'a str) -> Self {
Self { input }
}
}
fn is_char_ideographic(ch: char) -> bool {
use unicode_script::Script::*;
use unicode_script::UnicodeScript;
matches!(ch.script(), Han | Tangut | Yi)
}
fn is_grapheme_ideographic(text: &str) -> bool {
text.chars().any(is_char_ideographic)
}
fn is_grapheme_whitespace(text: &str) -> bool {
text.chars().any(|x| x.is_whitespace())
}
fn should_stay_with_preceding_ideograph(text: &str) -> bool {
text.chars().next().map_or(false, |ch| {
matches!(ch, '。' | '、' | '' | '' | '' | '' | '' | '…')
})
}
#[derive(PartialEq, Eq, Debug, Clone, Copy)]
enum WordBreakToken<'a> {
Word { token: &'a str, grapheme_len: usize },
InlineWhitespace { token: &'a str, grapheme_len: usize },
Newline,
}
impl<'a> Iterator for WordBreakingTokenizer<'a> {
/// Yields a span, the count of graphemes in the token, and whether it was
/// whitespace. Note that it also breaks at word boundaries.
type Item = WordBreakToken<'a>;
fn next(&mut self) -> Option<Self::Item> {
use unicode_segmentation::UnicodeSegmentation;
if self.input.is_empty() {
return None;
}
let mut iter = self.input.graphemes(true).peekable();
let mut offset = 0;
let mut grapheme_len = 0;
if let Some(first_grapheme) = iter.next() {
let is_newline = first_grapheme == "\n";
let is_whitespace = is_grapheme_whitespace(first_grapheme);
offset += first_grapheme.len();
grapheme_len += 1;
if is_grapheme_ideographic(first_grapheme) && !is_whitespace {
if let Some(grapheme) = iter.peek().copied() {
if should_stay_with_preceding_ideograph(grapheme) {
offset += grapheme.len();
grapheme_len += 1;
}
}
} else {
let mut words = self.input[offset..].split_word_bound_indices().peekable();
let mut next_word_bound = words.peek().copied();
if next_word_bound.map_or(false, |(i, _)| i == 0) {
next_word_bound = words.next();
}
while let Some(grapheme) = iter.peek().copied() {
if next_word_bound.map_or(false, |(i, _)| i == offset) {
break;
};
if is_grapheme_whitespace(grapheme) != is_whitespace
|| (grapheme == "\n") != is_newline
{
break;
};
offset += grapheme.len();
grapheme_len += 1;
iter.next();
}
}
let token = &self.input[..offset];
self.input = &self.input[offset..];
if token == "\n" {
Some(WordBreakToken::Newline)
} else if is_whitespace {
Some(WordBreakToken::InlineWhitespace {
token,
grapheme_len,
})
} else {
Some(WordBreakToken::Word {
token,
grapheme_len,
})
}
} else {
None
}
}
}
#[test]
fn test_word_breaking_tokenizer() {
let tests: &[(&str, &[WordBreakToken<'static>])] = &[
("", &[]),
(" ", &[whitespace(" ", 2)]),
("Ʒ", &[word("Ʒ", 1)]),
("Ǽ", &[word("Ǽ", 1)]),
("", &[word("", 1)]),
("⋑⋑", &[word("⋑⋑", 2)]),
(
"原理,进而",
&[word("", 1), word("理,", 2), word("", 1), word("", 1)],
),
(
"hello world",
&[word("hello", 5), whitespace(" ", 1), word("world", 5)],
),
(
"hello, world",
&[word("hello,", 6), whitespace(" ", 1), word("world", 5)],
),
(
" hello world",
&[
whitespace(" ", 2),
word("hello", 5),
whitespace(" ", 1),
word("world", 5),
],
),
(
"这是什么 \n 钢笔",
&[
word("", 1),
word("", 1),
word("", 1),
word("", 1),
whitespace(" ", 1),
newline(),
whitespace(" ", 1),
word("", 1),
word("", 1),
],
),
("mutton", &[whitespace("", 1), word("mutton", 6)]),
];
fn word(token: &'static str, grapheme_len: usize) -> WordBreakToken<'static> {
WordBreakToken::Word {
token,
grapheme_len,
}
}
fn whitespace(token: &'static str, grapheme_len: usize) -> WordBreakToken<'static> {
WordBreakToken::InlineWhitespace {
token,
grapheme_len,
}
}
fn newline() -> WordBreakToken<'static> {
WordBreakToken::Newline
}
for (input, result) in tests {
assert_eq!(
WordBreakingTokenizer::new(input)
.collect::<Vec<_>>()
.as_slice(),
*result,
);
}
}
fn wrap_with_prefix(
line_prefix: String,
unwrapped_text: String,
wrap_column: usize,
tab_size: NonZeroU32,
preserve_existing_whitespace: bool,
) -> String {
let line_prefix_len = char_len_with_expanded_tabs(0, &line_prefix, tab_size);
let mut wrapped_text = String::new();
let mut current_line = line_prefix.clone();
let tokenizer = WordBreakingTokenizer::new(&unwrapped_text);
let mut current_line_len = line_prefix_len;
let mut in_whitespace = false;
for token in tokenizer {
let have_preceding_whitespace = in_whitespace;
match token {
WordBreakToken::Word {
token,
grapheme_len,
} => {
in_whitespace = false;
if current_line_len + grapheme_len > wrap_column
&& current_line_len != line_prefix_len
{
wrapped_text.push_str(current_line.trim_end());
wrapped_text.push('\n');
current_line.truncate(line_prefix.len());
current_line_len = line_prefix_len;
}
current_line.push_str(token);
current_line_len += grapheme_len;
}
WordBreakToken::InlineWhitespace {
mut token,
mut grapheme_len,
} => {
in_whitespace = true;
if have_preceding_whitespace && !preserve_existing_whitespace {
continue;
}
if !preserve_existing_whitespace {
token = " ";
grapheme_len = 1;
}
if current_line_len + grapheme_len > wrap_column {
wrapped_text.push_str(current_line.trim_end());
wrapped_text.push('\n');
current_line.truncate(line_prefix.len());
current_line_len = line_prefix_len;
} else if current_line_len != line_prefix_len || preserve_existing_whitespace {
current_line.push_str(token);
current_line_len += grapheme_len;
}
}
WordBreakToken::Newline => {
in_whitespace = true;
if preserve_existing_whitespace {
wrapped_text.push_str(current_line.trim_end());
wrapped_text.push('\n');
current_line.truncate(line_prefix.len());
current_line_len = line_prefix_len;
} else if have_preceding_whitespace {
continue;
} else if current_line_len + 1 > wrap_column && current_line_len != line_prefix_len
{
wrapped_text.push_str(current_line.trim_end());
wrapped_text.push('\n');
current_line.truncate(line_prefix.len());
current_line_len = line_prefix_len;
} else if current_line_len != line_prefix_len {
current_line.push(' ');
current_line_len += 1;
}
}
}
}
if !current_line.is_empty() {
wrapped_text.push_str(&current_line);
}
wrapped_text
}
#[test]
fn test_wrap_with_prefix() {
assert_eq!(
wrap_with_prefix(
"# ".to_string(),
"abcdefg".to_string(),
4,
NonZeroU32::new(4).unwrap(),
false,
),
"# abcdefg"
);
assert_eq!(
wrap_with_prefix(
"".to_string(),
"\thello world".to_string(),
8,
NonZeroU32::new(4).unwrap(),
false,
),
"hello\nworld"
);
assert_eq!(
wrap_with_prefix(
"// ".to_string(),
"xx \nyy zz aa bb cc".to_string(),
12,
NonZeroU32::new(4).unwrap(),
false,
),
"// xx yy zz\n// aa bb cc"
);
assert_eq!(
wrap_with_prefix(
String::new(),
"这是什么 \n 钢笔".to_string(),
3,
NonZeroU32::new(4).unwrap(),
false,
),
"这是什\n么 钢\n"
);
}
pub trait CollaborationHub {
fn collaborators<'a>(&self, cx: &'a App) -> &'a HashMap<PeerId, Collaborator>;
fn user_participant_indices<'a>(&self, cx: &'a App) -> &'a HashMap<u64, ParticipantIndex>;
@@ -21512,7 +21980,7 @@ fn render_diff_hunk_controls(
.rounded_b_lg()
.bg(cx.theme().colors().editor_background)
.gap_1()
.stop_mouse_events_except_scroll()
.block_mouse_except_scroll()
.shadow_md()
.child(if status.has_secondary_hunk() {
Button::new(("stage", row as u64), "Stage")

View File

@@ -9111,11 +9111,10 @@ async fn test_range_format_during_save(cx: &mut TestAppContext) {
lsp::Url::from_file_path(path!("/file.rs")).unwrap()
);
assert_eq!(params.options.tab_size, 8);
Ok(Some(vec![]))
Ok(Some(Vec::new()))
})
.next()
.await;
cx.executor().start_waiting();
save.await;
}
@@ -16769,9 +16768,9 @@ fn indent_guide(buffer_id: BufferId, start_row: u32, end_row: u32, depth: u32) -
async fn test_indent_guide_single_line(cx: &mut TestAppContext) {
let (buffer_id, mut cx) = setup_indent_guides_editor(
&"
fn main() {
let a = 1;
}"
fn main() {
let a = 1;
}"
.unindent(),
cx,
)
@@ -16784,10 +16783,10 @@ async fn test_indent_guide_single_line(cx: &mut TestAppContext) {
async fn test_indent_guide_simple_block(cx: &mut TestAppContext) {
let (buffer_id, mut cx) = setup_indent_guides_editor(
&"
fn main() {
let a = 1;
let b = 2;
}"
fn main() {
let a = 1;
let b = 2;
}"
.unindent(),
cx,
)
@@ -16800,14 +16799,14 @@ async fn test_indent_guide_simple_block(cx: &mut TestAppContext) {
async fn test_indent_guide_nested(cx: &mut TestAppContext) {
let (buffer_id, mut cx) = setup_indent_guides_editor(
&"
fn main() {
let a = 1;
if a == 3 {
let b = 2;
} else {
let c = 3;
}
}"
fn main() {
let a = 1;
if a == 3 {
let b = 2;
} else {
let c = 3;
}
}"
.unindent(),
cx,
)
@@ -16829,11 +16828,11 @@ async fn test_indent_guide_nested(cx: &mut TestAppContext) {
async fn test_indent_guide_tab(cx: &mut TestAppContext) {
let (buffer_id, mut cx) = setup_indent_guides_editor(
&"
fn main() {
let a = 1;
let b = 2;
let c = 3;
}"
fn main() {
let a = 1;
let b = 2;
let c = 3;
}"
.unindent(),
cx,
)
@@ -16963,6 +16962,72 @@ async fn test_indent_guide_ends_off_screen(cx: &mut TestAppContext) {
);
}
#[gpui::test]
async fn test_indent_guide_with_folds(cx: &mut TestAppContext) {
let (buffer_id, mut cx) = setup_indent_guides_editor(
&"
fn main() {
if a {
b(
c,
d,
)
} else {
e(
f
)
}
}"
.unindent(),
cx,
)
.await;
assert_indent_guides(
0..11,
vec![
indent_guide(buffer_id, 1, 10, 0),
indent_guide(buffer_id, 2, 5, 1),
indent_guide(buffer_id, 7, 9, 1),
indent_guide(buffer_id, 3, 4, 2),
indent_guide(buffer_id, 8, 8, 2),
],
None,
&mut cx,
);
cx.update_editor(|editor, window, cx| {
editor.fold_at(MultiBufferRow(2), window, cx);
assert_eq!(
editor.display_text(cx),
"
fn main() {
if a {
b(⋯
)
} else {
e(
f
)
}
}"
.unindent()
);
});
assert_indent_guides(
0..11,
vec![
indent_guide(buffer_id, 1, 10, 0),
indent_guide(buffer_id, 2, 5, 1),
indent_guide(buffer_id, 7, 9, 1),
indent_guide(buffer_id, 8, 8, 2),
],
None,
&mut cx,
);
}
#[gpui::test]
async fn test_indent_guide_without_brackets(cx: &mut TestAppContext) {
let (buffer_id, mut cx) = setup_indent_guides_editor(

View File

@@ -42,13 +42,13 @@ use git::{
use gpui::{
Action, Along, AnyElement, App, AppContext, AvailableSpace, Axis as ScrollbarAxis, BorderStyle,
Bounds, ClickEvent, ContentMask, Context, Corner, Corners, CursorStyle, DispatchPhase, Edges,
Element, ElementInputHandler, Entity, Focusable as _, FontId, GlobalElementId, Hitbox, Hsla,
InteractiveElement, IntoElement, IsZero, Keystroke, Length, ModifiersChangedEvent, MouseButton,
MouseDownEvent, MouseMoveEvent, MouseUpEvent, PaintQuad, ParentElement, Pixels, ScrollDelta,
ScrollHandle, ScrollWheelEvent, ShapedLine, SharedString, Size, StatefulInteractiveElement,
Style, Styled, TextRun, TextStyleRefinement, WeakEntity, Window, anchored, deferred, div, fill,
linear_color_stop, linear_gradient, outline, point, px, quad, relative, size, solid_background,
transparent_black,
Element, ElementInputHandler, Entity, Focusable as _, FontId, GlobalElementId, Hitbox,
HitboxBehavior, Hsla, InteractiveElement, IntoElement, IsZero, Keystroke, Length,
ModifiersChangedEvent, MouseButton, MouseDownEvent, MouseMoveEvent, MouseUpEvent, PaintQuad,
ParentElement, Pixels, ScrollDelta, ScrollHandle, ScrollWheelEvent, ShapedLine, SharedString,
Size, StatefulInteractiveElement, Style, Styled, TextRun, TextStyleRefinement, WeakEntity,
Window, anchored, deferred, div, fill, linear_color_stop, linear_gradient, outline, point, px,
quad, relative, size, solid_background, transparent_black,
};
use itertools::Itertools;
use language::language_settings::{
@@ -1512,6 +1512,17 @@ impl EditorElement {
ShowScrollbar::Never => return None,
};
// The horizontal scrollbar is usually slightly offset to align nicely with
// indent guides. However, this offset is not needed if indent guides are
// disabled for the current editor.
let content_offset = self
.editor
.read(cx)
.show_indent_guides
.is_none_or(|should_show| should_show)
.then_some(content_offset)
.unwrap_or_default();
Some(EditorScrollbars::from_scrollbar_axes(
ScrollbarAxes {
horizontal: scrollbar_settings.axes.horizontal
@@ -1609,7 +1620,7 @@ impl EditorElement {
);
let layout = ScrollbarLayout::for_minimap(
window.insert_hitbox(minimap_bounds, false),
window.insert_hitbox(minimap_bounds, HitboxBehavior::Normal),
visible_editor_lines,
total_editor_lines,
minimap_line_height,
@@ -1780,7 +1791,7 @@ impl EditorElement {
if matches!(hunk, DisplayDiffHunk::Unfolded { .. }) {
let hunk_bounds =
Self::diff_hunk_bounds(snapshot, line_height, gutter_hitbox.bounds, hunk);
*hitbox = Some(window.insert_hitbox(hunk_bounds, true));
*hitbox = Some(window.insert_hitbox(hunk_bounds, HitboxBehavior::BlockMouse));
}
}
}
@@ -2872,7 +2883,7 @@ impl EditorElement {
let hitbox = line_origin.map(|line_origin| {
window.insert_hitbox(
Bounds::new(line_origin, size(shaped_line.width, line_height)),
false,
HitboxBehavior::Normal,
)
});
#[cfg(test)]
@@ -6360,7 +6371,7 @@ impl EditorElement {
}
};
if phase == DispatchPhase::Bubble && hitbox.is_hovered(window) {
if phase == DispatchPhase::Bubble && hitbox.should_handle_scroll(window) {
delta = delta.coalesce(event.delta);
editor.update(cx, |editor, cx| {
let position_map: &PositionMap = &position_map;
@@ -7607,7 +7618,10 @@ impl Element for EditorElement {
editor.gutter_dimensions = gutter_dimensions;
editor.set_visible_line_count(bounds.size.height / line_height, window, cx);
if matches!(editor.mode, EditorMode::Minimap { .. }) {
if matches!(
editor.mode,
EditorMode::AutoHeight { .. } | EditorMode::Minimap { .. }
) {
snapshot
} else {
let wrap_width_for = |column: u32| (column as f32 * em_advance).ceil();
@@ -7637,15 +7651,17 @@ impl Element for EditorElement {
.map(|(guide, active)| (self.column_pixels(*guide, window, cx), *active))
.collect::<SmallVec<[_; 2]>>();
let hitbox = window.insert_hitbox(bounds, false);
let gutter_hitbox =
window.insert_hitbox(gutter_bounds(bounds, gutter_dimensions), false);
let hitbox = window.insert_hitbox(bounds, HitboxBehavior::Normal);
let gutter_hitbox = window.insert_hitbox(
gutter_bounds(bounds, gutter_dimensions),
HitboxBehavior::Normal,
);
let text_hitbox = window.insert_hitbox(
Bounds {
origin: gutter_hitbox.top_right(),
size: size(text_width, bounds.size.height),
},
false,
HitboxBehavior::Normal,
);
let content_origin = text_hitbox.origin + content_offset;
@@ -8866,7 +8882,7 @@ impl EditorScrollbars {
})
.map(|(viewport_size, scroll_range)| {
ScrollbarLayout::new(
window.insert_hitbox(scrollbar_bounds_for(axis), false),
window.insert_hitbox(scrollbar_bounds_for(axis), HitboxBehavior::Normal),
viewport_size,
scroll_range,
glyph_grid_cell.along(axis),
@@ -9626,7 +9642,6 @@ fn compute_auto_height_layout(
let font_size = style.text.font_size.to_pixels(window.rem_size());
let line_height = style.text.line_height_in_pixels(window.rem_size());
let em_width = window.text_system().em_width(font_id, font_size).unwrap();
let em_advance = window.text_system().em_advance(font_id, font_size).unwrap();
let mut snapshot = editor.snapshot(window, cx);
let gutter_dimensions = snapshot
@@ -9643,18 +9658,10 @@ fn compute_auto_height_layout(
let overscroll = size(em_width, px(0.));
let editor_width = text_width - gutter_dimensions.margin - overscroll.width - em_width;
let content_offset = point(gutter_dimensions.margin, Pixels::ZERO);
let editor_content_width = editor_width - content_offset.x;
let wrap_width_for = |column: u32| (column as f32 * em_advance).ceil();
let wrap_width = match editor.soft_wrap_mode(cx) {
SoftWrap::GitDiff => None,
SoftWrap::None => Some(wrap_width_for(MAX_LINE_LEN as u32 / 2)),
SoftWrap::EditorWidth => Some(editor_content_width),
SoftWrap::Column(column) => Some(wrap_width_for(column)),
SoftWrap::Bounded(column) => Some(editor_content_width.min(wrap_width_for(column))),
};
if editor.set_wrap_width(wrap_width, cx) {
snapshot = editor.snapshot(window, cx);
if !matches!(editor.soft_wrap_mode(cx), SoftWrap::None) {
if editor.set_wrap_width(Some(editor_width), cx) {
snapshot = editor.snapshot(window, cx);
}
}
let scroll_height = (snapshot.max_point().row().next_row().0 as f32) * line_height;

View File

@@ -583,13 +583,6 @@ async fn parse_blocks(
language: Option<Arc<Language>>,
cx: &mut AsyncWindowContext,
) -> Option<Entity<Markdown>> {
let fallback_language_name = if let Some(ref l) = language {
let l = Arc::clone(l);
Some(l.lsp_id().clone())
} else {
None
};
let combined_text = blocks
.iter()
.map(|block| match &block.kind {
@@ -607,7 +600,7 @@ async fn parse_blocks(
Markdown::new(
combined_text.into(),
Some(language_registry.clone()),
fallback_language_name,
language.map(|language| language.name()),
cx,
)
})
@@ -1057,7 +1050,9 @@ mod tests {
for (range, event) in slice.iter() {
match event {
MarkdownEvent::SubstitutedText(parsed) => rendered_text.push_str(parsed),
MarkdownEvent::SubstitutedText(parsed) => {
rendered_text.push_str(parsed.as_str())
}
MarkdownEvent::Text | MarkdownEvent::Code => {
rendered_text.push_str(&text[range.clone()])
}

View File

@@ -1,9 +1,9 @@
use std::{ops::Range, time::Duration};
use std::{cmp::Ordering, ops::Range, time::Duration};
use collections::HashSet;
use gpui::{App, AppContext as _, Context, Task, Window};
use language::language_settings::language_settings;
use multi_buffer::{IndentGuide, MultiBufferRow};
use multi_buffer::{IndentGuide, MultiBufferRow, ToPoint};
use text::{LineIndent, Point};
use util::ResultExt;
@@ -154,12 +154,28 @@ pub fn indent_guides_in_range(
snapshot: &DisplaySnapshot,
cx: &App,
) -> Vec<IndentGuide> {
let start_anchor = snapshot
let start_offset = snapshot
.buffer_snapshot
.anchor_before(Point::new(visible_buffer_range.start.0, 0));
let end_anchor = snapshot
.point_to_offset(Point::new(visible_buffer_range.start.0, 0));
let end_offset = snapshot
.buffer_snapshot
.anchor_after(Point::new(visible_buffer_range.end.0, 0));
.point_to_offset(Point::new(visible_buffer_range.end.0, 0));
let start_anchor = snapshot.buffer_snapshot.anchor_before(start_offset);
let end_anchor = snapshot.buffer_snapshot.anchor_after(end_offset);
let mut fold_ranges = Vec::<Range<Point>>::new();
let mut folds = snapshot.folds_in_range(start_offset..end_offset).peekable();
while let Some(fold) = folds.next() {
let start = fold.range.start.to_point(&snapshot.buffer_snapshot);
let end = fold.range.end.to_point(&snapshot.buffer_snapshot);
if let Some(last_range) = fold_ranges.last_mut() {
if last_range.end >= start {
last_range.end = last_range.end.max(end);
continue;
}
}
fold_ranges.push(start..end);
}
snapshot
.buffer_snapshot
@@ -169,15 +185,19 @@ pub fn indent_guides_in_range(
return false;
}
let start = MultiBufferRow(indent_guide.start_row.0.saturating_sub(1));
// Filter out indent guides that are inside a fold
// All indent guides that are starting "offscreen" have a start value of the first visible row minus one
// Therefore checking if a line is folded at first visible row minus one causes the other indent guides that are not related to the fold to disappear as well
let is_folded = snapshot.is_line_folded(start);
let line_indent = snapshot.line_indent_for_buffer_row(start);
let contained_in_fold =
line_indent.len(indent_guide.tab_size) <= indent_guide.indent_level();
!(is_folded && contained_in_fold)
let has_containing_fold = fold_ranges
.binary_search_by(|fold_range| {
if fold_range.start >= Point::new(indent_guide.start_row.0, 0) {
Ordering::Greater
} else if fold_range.end < Point::new(indent_guide.end_row.0, 0) {
Ordering::Less
} else {
Ordering::Equal
}
})
.is_ok();
!has_containing_fold
})
.collect()
}

View File

@@ -600,7 +600,7 @@ pub(crate) fn handle_from(
})
.collect::<Vec<_>>();
this.update_in(cx, |this, window, cx| {
this.change_selections_inner(None, false, window, cx, |s| {
this.change_selections_without_showing_completions(None, window, cx, |s| {
s.select(base_selections);
});
})

View File

@@ -1,4 +1,5 @@
use std::sync::Arc;
use std::time::Duration;
use crate::Editor;
use collections::HashMap;
@@ -16,6 +17,7 @@ use project::LocationLink;
use project::Project;
use project::TaskSourceKind;
use project::lsp_store::lsp_ext_command::GetLspRunnables;
use smol::future::FutureExt as _;
use smol::stream::StreamExt;
use task::ResolvedTask;
use task::TaskContext;
@@ -130,44 +132,58 @@ pub fn lsp_tasks(
.collect::<FuturesUnordered<_>>();
cx.spawn(async move |cx| {
let mut lsp_tasks = Vec::new();
while let Some(server_to_query) = lsp_task_sources.next().await {
if let Some((server_id, buffers)) = server_to_query {
let source_kind = TaskSourceKind::Lsp(server_id);
let id_base = source_kind.to_id_base();
let mut new_lsp_tasks = Vec::new();
for buffer in buffers {
let lsp_buffer_context = lsp_task_context(&project, &buffer, cx)
.await
.unwrap_or_default();
cx.spawn(async move |cx| {
let mut lsp_tasks = Vec::new();
while let Some(server_to_query) = lsp_task_sources.next().await {
if let Some((server_id, buffers)) = server_to_query {
let source_kind = TaskSourceKind::Lsp(server_id);
let id_base = source_kind.to_id_base();
let mut new_lsp_tasks = Vec::new();
for buffer in buffers {
let lsp_buffer_context = lsp_task_context(&project, &buffer, cx)
.await
.unwrap_or_default();
if let Ok(runnables_task) = project.update(cx, |project, cx| {
let buffer_id = buffer.read(cx).remote_id();
project.request_lsp(
buffer,
LanguageServerToQuery::Other(server_id),
GetLspRunnables {
buffer_id,
position: for_position,
},
cx,
)
}) {
if let Some(new_runnables) = runnables_task.await.log_err() {
new_lsp_tasks.extend(new_runnables.runnables.into_iter().filter_map(
|(location, runnable)| {
let resolved_task =
runnable.resolve_task(&id_base, &lsp_buffer_context)?;
Some((location, resolved_task))
if let Ok(runnables_task) = project.update(cx, |project, cx| {
let buffer_id = buffer.read(cx).remote_id();
project.request_lsp(
buffer,
LanguageServerToQuery::Other(server_id),
GetLspRunnables {
buffer_id,
position: for_position,
},
));
cx,
)
}) {
if let Some(new_runnables) = runnables_task.await.log_err() {
new_lsp_tasks.extend(
new_runnables.runnables.into_iter().filter_map(
|(location, runnable)| {
let resolved_task = runnable
.resolve_task(&id_base, &lsp_buffer_context)?;
Some((location, resolved_task))
},
),
);
}
}
}
lsp_tasks.push((source_kind, new_lsp_tasks));
}
lsp_tasks.push((source_kind, new_lsp_tasks));
}
}
lsp_tasks
lsp_tasks
})
.race({
// `lsp::LSP_REQUEST_TIMEOUT` is larger than we want for the modal to open fast
let timer = cx.background_executor().timer(Duration::from_millis(200));
async move {
timer.await;
log::info!("Timed out waiting for LSP tasks");
Vec::new()
}
})
.await
})
}

Some files were not shown because too many files have changed in this diff Show More