Compare commits

..

71 Commits

Author SHA1 Message Date
Cole Miller
fbd27148d5 fix 2025-08-08 01:29:07 -04:00
Samuel
d6022dc87c emmet: Enable in Vue.js files (#35599)
Resolves part of #34337

Actually I need also to add:

```
"languages": {
    "Vue.js": {
      "language_servers": [
        "vue-language-server",
        "emmet-language-server",
        "..."
      ]
    }
  },
```

not sure how to resolve fully, happy to continue only little guidance
needed.

Release Notes:

- allow emmet in Vue.js files

---------

Co-authored-by: Marshall Bowers <git@maxdeviant.com>
2025-08-08 02:50:54 +00:00
Dan Wood
0dd480d475 Add spread operator to the @operator list for ECMAScript languages (#35360)
Previously, this was the one thing that could not be styled properly in
ecmascript languages in the zed config, because it was not able to be
targeted.

Now, it is added alongside other operators. This has been tested and
works as expected.

Release Notes:

- N/A
2025-08-08 01:58:26 +00:00
Phoenix Himself
34fc2fd9d0 Treat Arduino files as C++ (#35467)
Closes https://github.com/zed-industries/zed/discussions/35466

Release Notes:

- N/A
2025-08-08 01:54:42 +00:00
Neo Nie
00701b5e99 git_hosting_providers: Extract Bitbucket pull request number (#34584)
git: Extract Bitbucket pull request number

Release Notes:

- git: Extract Bitbucket pull request number

---------

Co-authored-by: Peter Tripp <peter@zed.dev>
2025-08-08 01:39:32 +00:00
Abdelhakim Qbaich
cdfb3348ea git: Make inline blame padding configurable (#33631)
Just like with diagnostics, adding a configurable padding to inline
blame

Release Notes:

- Added configurable padding to inline blame

---------

Co-authored-by: Cole Miller <cole@zed.dev>
Co-authored-by: Peter Tripp <petertripp@gmail.com>
2025-08-08 01:35:07 +00:00
Mikayla Maki
35cd1b9ae1 filter out comments in deploy helper env vars (#35847)
Turns out a `.sh` file isn't actually a shell script :(

Release Notes:

- N/A
2025-08-07 18:01:46 -07:00
smit
bd402fdc7d editor: Fix Follow Agent unexpectedly stopping during edits (#35845)
Closes #34881

For horizontal scroll, we weren't keeping track of the `local` bool, so
whenever the agent tries to autoscroll horizontally, it would be seen as
a user scroll event resulting in unfollow.

Release Notes:

- Fixed an issue where the Follow Agent could unexpectedly stop
following during edits.
2025-08-08 06:17:37 +05:30
Mikayla Maki
c7d641ecb8 Revert "chore: Bump Rust to 1.89 (#35788)" (#35843)
This reverts commit efba2cbfd3.

Unfortunately, the Docker image for 1.89 has not shown up yet. Once it
has, we should re-land this.

Release Notes:

- N/A
2025-08-07 23:55:15 +00:00
Agus Zubiaga
3d662ee282 agent2: Port read_file tool (#35840)
Ports the read_file tool from `assistant_tools` to `agent2`. 

Note: Image support not implemented.

Release Notes:

- N/A
2025-08-07 20:46:47 -03:00
Richard Feldman
7d4d8b8398 Add GPT-5 support through OpenAI API (#35822)
(This PR does not add GPT-5 to Zed Pro, but rather adds access if you're
using your own OpenAI API key.)

<img width="772" height="333" alt="Screenshot 2025-08-07 at 2 23 18 PM"
src="https://github.com/user-attachments/assets/42e75082-118a-4737-89b6-a740ae33b169"
/>

---

**NOTE:** If your API key is not through a verified organization, you
may see this error:

<img width="549" height="253" alt="Screenshot 2025-08-07 at 2 04 54 PM"
src="https://github.com/user-attachments/assets/d0b6d739-9c39-4af3-88d7-0c9609b0e6ba"
/>

Even if your org is verified, you still may not have access to GPT-5, in
which case you could see this error:

<img width="543" height="98" alt="Screenshot 2025-08-07 at 2 09 18 PM"
src="https://github.com/user-attachments/assets/e3ed31e3-2a11-4f07-8f3c-5b410fbe4540"
/>

One way to test if you're in this situation is to visit
https://platform.openai.com/chat/edit?models=gpt-5 and see if you get
the same "you don't have access to GPT-5" error on OpenAI's official
playground. It looks like this:

<img width="581" height="196" alt="Screenshot 2025-08-07 at 2 15 25 PM"
src="https://github.com/user-attachments/assets/ea1454ca-3c10-4703-8126-c02cb92a34f2"
/>

Release Notes:

- Added GPT-5, as well as its mini and nano variants. To use this, you
need to have an OpenAI API key configured via the `OPENAI_API_KEY`
environment variable.
2025-08-07 23:35:41 +00:00
Agus Zubiaga
6912dc8399 Fix CC tool state on cancel (#35763)
When we stop the generation, CC tells us the tool completed, but it was
actually cancelled.

Release Notes:

- N/A
2025-08-07 20:26:19 -03:00
Peter Tripp
952e3713d7 ci: Switch to Namespace (#35835)
Follow-up to:
- https://github.com/zed-industries/zed/pull/35826

Release Notes:

- N/A
2025-08-07 23:16:25 +00:00
Mikayla Maki
913e9adf90 Move timing fields into span (#35833)
Release Notes:

- N/A
2025-08-07 23:07:33 +00:00
Marshall Bowers
50482a6bc2 language_model: Refresh the LLM token upon receiving a UserUpdated message from Cloud (#35839)
This PR makes it so we refresh the LLM token upon receiving a
`UserUpdated` message from Cloud over the WebSocket connection.

Release Notes:

- N/A
2025-08-07 23:00:45 +00:00
Marshall Bowers
d110459ef8 collab_ui: Show signed-out state when not connected to Collab (#35832)
This PR updates signed-out state of the Collab panel to show when not
connected to Collab, as opposed to just when the user is signed-out.

Release Notes:

- N/A
2025-08-07 22:29:59 +00:00
Cole Miller
d693f02c63 Settings: fix release channel settings not being respected (#35838)
Typo in #35756 

Release Notes:

- N/A
2025-08-07 22:27:50 +00:00
Ben Brandt
90fa921756 Wire up find_path tool in agent2 (#35799)
Release Notes:

- N/A

---------

Co-authored-by: Antonio Scandurra <me@as-cii.com>
2025-08-07 22:21:26 +00:00
Marshall Bowers
11efa32fa7 client: Only connect to Collab automatically for Zed staff (#35827)
This PR makes it so that only Zed staff connect to Collab automatically.

Anyone else can connect to Collab manually when they want to collaborate
(but this is not required for using Zed's LLM features).

Release Notes:

- N/A

---------

Co-authored-by: Richard <richard@zed.dev>
2025-08-07 22:14:25 +00:00
Cole Miller
e6dc6faccf Don't insert resource links for @mentions that have been removed from the message editor (#35831)
Release Notes:

- N/A
2025-08-07 22:10:29 +00:00
Danilo Leal
070f7dbe1a onboarding: Add fast-follow adjustments (#35814)
Release Notes:

- N/A
2025-08-07 19:01:52 -03:00
Marshall Bowers
106d4cfce9 client: Re-fetch the authenticated user when receiving a UserUpdated message from Cloud (#35807)
This PR wires up handling for the new `UserUpdated` message coming from
Cloud over the WebSocket connection.

When we receive this message we will refresh the authenticated user.

Release Notes:

- N/A

Co-authored-by: Richard <richard@zed.dev>
2025-08-07 21:44:53 +00:00
Cole Miller
a1080a0411 Update diff editor font size when agent_font_size setting changes (#35834)
Release Notes:

- N/A
2025-08-07 21:31:30 +00:00
Peter Tripp
7679db99ac ci: Switch from BuildJet to GitHub runners (#35826)
In response to an ongoing BuildJet outage, consider migrating CI to
GitHub hosted runners.

Also includes revert of (causing flaky tests):
- https://github.com/zed-industries/zed/pull/35741

Downsides:
- Cost (2x)
- Force migration to Ubuntu 22.04 from 20.04 will bump our glibc minimum
from 2.31 to 2.35. Which would break RHEL 9.x (glibc 2.34), Ubuntu 20.04
(EOL) and derivatives.

Release Notes:

- N/A
2025-08-07 16:59:11 -04:00
Fabian Bergström
9ade399756 workspace: Don't update platform window title if title has not changed (#34753)
Closes #34749 #34715

Release Notes:

- Fixed window title X event spam
2025-08-07 22:13:51 +03:00
mcwindy
e8db429d24 project_panel: Add file comparison function, supports selecting files for comparison (#35255)
Closes https://github.com/zed-industries/zed/discussions/35010
Closes https://github.com/zed-industries/zed/issues/17100
Closes https://github.com/zed-industries/zed/issues/4523

Release Notes:

- Added file comparison function in project panel

---------

Co-authored-by: Kirill Bulatov <kirill@zed.dev>
2025-08-07 21:34:12 +03:00
Kirill Bulatov
53b69d29c5 Actually update remote collab capabilities (#35809)
Follow-up of https://github.com/zed-industries/zed/pull/35682

Release Notes:

- N/A
2025-08-07 13:58:33 -04:00
Julia Ryan
e2e147ab0e Add OS specific settings (#35756)
Release Notes:

- Settings can now be configured per operating system with the new
top-level fields: `"macos"`/`"windows"`/`"linux"`. These will override
user level settings, but are lower precedence than _release channel_
settings.
2025-08-07 10:52:54 -07:00
Marshall Bowers
fa2ff3ce1c collab: Increase DATABASE_MAX_CONNECTIONS for Collab server (#35818)
This PR increases the `DATABASE_MAX_CONNECTIONS` limit for the Collab
server to 850 (up from 250).

Release Notes:

- N/A

Co-authored-by: Nathan <nathan@zed.dev>
Co-authored-by: Mikayla <mikayla@zed.dev>
2025-08-07 10:26:08 -07:00
Piotr Osiewicz
c1d1d1cff6 chore: Bump to taffy 0.9 (#35802)
Co-authored-by: Anthony Eid <hello@anthonyeid.me>
Co-authored-by: Lukas Wirth <lukas@zed.dev>
Co-authored-by: Ben Kunkle <ben@zed.dev>

Release Notes:

- N/A

Co-authored-by: Anthony Eid <hello@anthonyeid.me>
Co-authored-by: Lukas Wirth <lukas@zed.dev>
Co-authored-by: Ben Kunkle <ben@zed.dev>
2025-08-07 15:43:37 +00:00
Piotr Osiewicz
efba2cbfd3 chore: Bump Rust to 1.89 (#35788)
Release Notes:

- N/A

---------

Co-authored-by: Julia Ryan <juliaryan3.14@gmail.com>
2025-08-07 15:32:06 +00:00
Raphael Lüthy
2234220618 completions: Add subtle/eager behavior to Supermaven and Copilot (#35548)
This pull request introduces changes to improve the behavior and
consistency of multiple completion providers
(`CopilotCompletionProvider`, `SupermavenCompletionProvider`) and their
integration with UI elements like menus and inline completion buttons.
It now allows to see the prediction with the completion menu open whilst
pressing `opt` and also enables the subtle/eager setting that was
introduced with zeta.

Edit: I managed to get the preview working with correct icons!
<img width="909" height="232" alt="image"
src="https://github.com/user-attachments/assets/65800e67-4bc4-40f8-be78-806fcfe74ad9"
/>
<img width="1460" height="318" alt="CleanShot 2025-08-04 at 01 36 31@2x"
src="https://github.com/user-attachments/assets/15651405-720f-465f-a13c-c7470817810a"
/>

Correct icons are also displayed:
<img width="244" height="96" alt="image"
src="https://github.com/user-attachments/assets/0b8a687f-73e3-452d-aefb-784c52831b73"
/>


Edit2: I added some comments, would be very happy to receive feedback
(still learning rust)

Release Notes:

- Added Subtle and Eager edit prediction modes to Copilot and Supermaven
2025-08-07 18:27:29 +03:00
Piotr Osiewicz
dd840e4b27 editor: Fix multi-buffer headers spilling over at narrow widths (#35800)
Release Notes:

- N/A
2025-08-07 15:01:22 +00:00
Danilo Leal
262365ca24 keymap editor: Refine how we display matching keystrokes (#35796)
| Before | After |
|--------|--------|
| <img width="1092" height="528" alt="CleanShot 2025-08-07 at 10  54
42@2x"
src="https://github.com/user-attachments/assets/8b0a3b50-e1d1-4763-824c-2b419df430fc"
/> | <img width="1096" height="580" alt="CleanShot 2025-08-07 at 11  29
47@2x"
src="https://github.com/user-attachments/assets/bd484655-90a6-46fe-91ef-c9c8d2ab93bc"
/> |

Release Notes:

- N/A
2025-08-07 11:50:11 -03:00
localcc
90fa06dd61 Fix file unlocking after closing the workspace (#35741)
Release Notes:

- Fixed folders being locked after closing them in zed
2025-08-07 16:47:19 +02:00
Kirill Bulatov
740686b883 Batch diagnostics updates (#35794)
Diagnostics updates were programmed in Zed based off the r-a LSP push
diagnostics, with all related updates happening per file.

https://github.com/zed-industries/zed/pull/19230 and especially
https://github.com/zed-industries/zed/pull/32269 brought in pull
diagnostics that could produce results for thousands files
simultaneously.

It was noted and addressed on the local side in
https://github.com/zed-industries/zed/pull/34022 but the remote side was
still not adjusted properly.

This PR 

* removes redundant diagnostics pull updates on remote clients, as
buffer diagnostics are updated via buffer sync operations separately
* batches all diagnostics-related updates and proto messages, so
multiple diagnostic summaries (per file) could be sent at once,
specifically, 1 (potentially large) diagnostics summary update instead
of N*10^3 small ones.

Buffer updates are still sent per buffer and not updated, as happening
separately and not offending the collab traffic that much.

Release Notes:

- Improved diagnostics performance in the collaborative mode
2025-08-07 14:45:41 +00:00
Danilo Leal
a5c25e0366 agent: Improve end of trial card display (#35789)
Now rendering the backdrop behind the card to clean up the UI, bring
focus to the card's content, and direct the user to act on it, either by
ignoring it or upgrading.

<img width="500" height="1242" alt="CleanShot 2025-08-07 at 10  30
58@2x"
src="https://github.com/user-attachments/assets/8c6b9c34-eb22-4f01-b3fa-158ac78b7439"
/>

Release Notes:

- N/A
2025-08-07 10:53:15 -03:00
Tongue_chaude
305c653c62 Add icons for Puppet files (#35778)
Release Notes:

- Added icon for Puppet (.pp) files

Actually puppet icons are available in the extension here :
<https://github.com/AlexandarY/zed-puppet/tree/main/icon_themes>

---------

Co-authored-by: Danilo Leal <daniloleal09@gmail.com>
2025-08-07 13:51:29 +00:00
Danilo Leal
e227b5ac30 onboarding: Add young account treatment to AI upsell card (#35785)
Release Notes:

- N/A
2025-08-07 10:42:46 -03:00
Antonio Scandurra
03876d076e Add system prompt and tool permission to agent2 (#35781)
Release Notes:

- N/A

---------

Co-authored-by: Ben Brandt <benjamin.j.brandt@gmail.com>
Co-authored-by: Max Brunsfeld <maxbrunsfeld@gmail.com>
2025-08-07 13:40:12 +00:00
Lukas Wirth
4dbd24d75f Reduce amount of allocations in RustLsp label handling (#35786)
There can be a lot of completions after all


Release Notes:

- N/A
2025-08-07 13:24:29 +00:00
Kirill Bulatov
c397027ec2 Add release_channel into the span fields list (#35783)
Follow-up of https://github.com/zed-industries/zed/pull/35729

Release Notes:

- N/A

Co-authored-by: Marshall Bowers <marshall@zed.dev>
2025-08-07 12:56:10 +00:00
Lukas Wirth
f5f837d39a languages: Fix rust completions not having proper detail labels (#35772)
rust-analyzer changed the format here a bit some months ago which
partially broke our nice detailed highlighted completion labels. The
brings that back while also cleaning up the code a bit.

Also fixes a bug where disabling rust-analyzers snippet callable
completions would fully break them.

Release Notes:

- N/A
2025-08-07 10:38:58 +00:00
smit
5b1b3c51d4 language_models: Fix high memory consumption while using Agent Panel (#35764)
Closes #31108

The `num_tokens_from_messages` method we use from `tiktoken-rs` creates
new BPE every time that method is called. This creation of BPE is
expensive as well as has some underlying issue that keeps memory from
releasing once the method is finished, specifically noticeable on Linux.
This leads to a gradual increase in memory every time that method is
called in my case around +50MB on each call. We call this method with
debounce every time user types in Agent Panel to calculate tokens. This
can add up really fast.

This PR lands quick fix, while I/maintainers figure out underlying
issue. See upstream discussion:
https://github.com/zurawiki/tiktoken-rs/issues/39.

Here on fork https://github.com/zed-industries/tiktoken-rs/pull/1,
instead of creating BPE instances every time that method is called, we
use singleton BPE instances instead. So, whatever memory it is holding
on to, at least that is only once per model.

Before: Increase of 700MB+ on extensive use

On init:
<img width="500" alt="prev-init"
src="https://github.com/user-attachments/assets/70da7c44-60cb-477b-84aa-7dd579baa3da"
/>
First message:
<img width="500" alt="prev-first-call"
src="https://github.com/user-attachments/assets/599ffc48-3ad3-4729-b94c-6d88493afdbf"
/>
Extensive use:
<img width="500" alt="prev-extensive-use"
src="https://github.com/user-attachments/assets/e0e6b688-6412-486d-8b2e-7216c6b62470"
/>

After: Increase of 50MB+ on extensive use
On init:
<img width="500" alt="now-init"
src="https://github.com/user-attachments/assets/11a2cd9c-20b0-47ae-be02-07ff876e68ad"
/>
First message:
<img width="500" alt="now-first-call"
src="https://github.com/user-attachments/assets/ef505f8d-cd31-49cd-b6bb-7da3f0838fa7"
/>
Extensive use: 
<img width="500" alt="now-extensive-use"
src="https://github.com/user-attachments/assets/513cb85a-a00b-4f11-8666-69103a9eb2b8"
/>

Release Notes:

- Fixed issue where Agent Panel would cause high memory consumption over
prolonged use.
2025-08-07 11:39:27 +05:30
Gregor
b4a441f12f Add UnwrapSyntaxNode action (#31421)
Remake of #8967

> Hey there,
> 
> I have started relying on this action, that I've also put into VSCode
as [an extension](https://github.com/Gregoor/soy). On some level I don't
know how people code (cope?) without it:
> 
> Release Notes:
> 
> * Added UnwrapSyntaxNode action
> 
>
https://github.com/zed-industries/zed/assets/4051932/d74c98c0-96d8-4075-9b63-cea55bea42f6
> 
> Since I had to put it into Zed anyway to make it my daily driver, I
thought I'd also check here if there's an interest in shipping it by
default (that would ofc also personally make my life better, not having
to maintain my personal fork and all).
> 
> If there is interest, I'd be happy to make any changes to make this
more mergeable. Two TODOs on my mind are:
> 
> * unwrap multiple into single (e.g. `fn(≤a≥, b)` to `fn(≤a≥)`)
> * multi-cursor
> * syntax awareness, i.e. only unwrap if it does not break syntax (I
added [a coarse version of that for my VSC
extension](https://github.com/Gregoor/soy/blob/main/src/actions/unwrap.ts#L29))
> 
> Somewhat off-topic: I was happy to see that you're
[also](https://github.com/Gregoor/soy/blob/main/src/actions/unwrap.test.ts)
using rare special chars in test code to denote cursor positions.


Release Notes:

- Added UnwrapSyntaxNode action

---------

Co-authored-by: Peter Tripp <peter@zed.dev>
2025-08-07 07:52:22 +03:00
Anthony Eid
f1e69f6311 gpui: Impl Default for ClickEvent (#35751)
While default for ClickEvent shouldn't be used much this is helpful for
other projects using gpui besides Zed. Mainly because the orphan rule
prevents those projects from implementing their own default trait

cc: @huacnlee 

Release Notes:

- N/A
2025-08-06 23:24:37 -04:00
Agus Zubiaga
bd1c26cb5b Fix interrupting ACP threads and CC cancellation (#35752)
Fixes a bug where generation wouldn't continue after interrupting the
agent, and improves CC cancellation so we don't display "[Request
interrupted by user]"

Release Notes:

- N/A

---------

Co-authored-by: Cole Miller <cole@zed.dev>
2025-08-06 22:55:17 -03:00
Richard Feldman
1907b16fe6 Establish WebSocket connection to Cloud (#35734)
This PR adds a new WebSocket connection to Cloud.

This connection will be used to push down notifications from the server
to the client.

Release Notes:

- N/A

---------

Co-authored-by: Marshall Bowers <git@maxdeviant.com>
2025-08-07 01:28:41 +00:00
Max Brunsfeld
c595a7576d Fix git hunk staging on windows (#35755)
We were failing to flush the Git process's `stdin` before dropping it.

Release Notes:

- N/A
2025-08-06 17:30:36 -07:00
Danilo Leal
8e290b446e thread view: Add UI refinements (#35754)
More notably around how we render tool calls. Nothing too drastic,
though.

Release Notes:

- N/A
2025-08-06 20:31:11 -03:00
Marshall Bowers
58392b9c13 cloud_api_types: Add types for WebSocket protocol (#35753)
This PR adds types for the Cloud WebSocket protocol to the
`cloud_api_types` crate.

Release Notes:

- N/A
2025-08-06 23:20:04 +00:00
Cole Miller
9358690337 Fix flicker when agent plan updates (#35739)
Currently, when the agent updates its plan, there are a few frames where
the text after `Current:` in the plan summary is blank, causing a
flicker. This is because we treat that field as markdown, and the
`MarkdownElement` renders as blank until the raw text has finished
parsing in the background.

This PR fixes the flicker by changing `Markdown::new_text` to
optimistically render the source as a single `MarkdownEvent::Text` span
until background parsing has finished.

Release Notes:

- N/A

Co-authored-by: Agus Zubiaga <agus@zed.dev>
2025-08-06 22:38:00 +00:00
Anthony Eid
3ea90e397b debugger: Filter out debug scenarios with invalid Adapters from debug picker (#35744)
I also removed a debug assertion that wasn't true when a debug session
was restarting through a request, because there wasn't a booting task
Zed needed to run before the session.

I renamed SessionState::Building to SessionState::Booting as well,
because building implies that we're building code while booting the
session covers more cases and is more accurate.

Release Notes:

- debugger: Filter out more invalid debug configurations from the debug
picker

Co-authored-by: Remco Smits <djsmits12@gmail.com>
2025-08-06 18:10:17 -04:00
Peter Tripp
a5dd8d0052 Recognize pixi.lock as YAML (#35747)
Release Notes:

- N/A
2025-08-06 15:56:11 -04:00
Agus Zubiaga
250c51bb20 Fix syntax highlighting in ACP diffs (#35748)
Release Notes:

- N/A
2025-08-06 19:53:45 +00:00
Anthony Eid
010441e23b debugger: Show run to cursor in editor's context menu (#35745)
This also fixed a bug where evaluate selected text was an available
option when the selected debug session was terminated.


Release Notes:

- debugger: add Run to Cursor back to Editor's context menu

Co-authored-by: Remco Smits <djsmits12@gmail.com>
2025-08-06 15:45:22 -04:00
Peter Tripp
f9038f6189 Add key contexts for Pickers (#35665)
Closes: https://github.com/zed-industries/zed/issues/35430

Added:
- Workspace > CommandPalette
- Workspace > GitBranchSelector
- Workspace > GitRepositorySelector
- Workspace > RecentProjects
- Workspace > LanguageSelector
- Workspace > IconThemeSelector
- Workspace > ThemeSelector

Release Notes:

- Added new keymap contexts for various Pickers - CommandPalette,
GitBranchSelector, GitRepositorySelector, RecentProjects,
LanguageSelector, IconThemeSelector, ThemeSelector
2025-08-06 15:28:18 -04:00
xdBronch
a80da784b7 lsp: Advertise support for markdown in completion documentation (#35727)
Release Notes:

- N/A
2025-08-06 21:42:29 +03:00
Piotr Osiewicz
fb1f9d1212 lsp: Correctly serialize errors for LSP requests + improve handling of unrecognized methods (#35738)
We used to not respond at all to requests that we didn't have a handler
for, which is yuck. It may have left the language server waiting for the
response for no good reason. The other (worse) finding is that we did
not have a full definition of an Error type for LSP, which made it so
that a spec-compliant language server would fail to deserialize our
response (with an error). This then could lead to all sorts of
funkiness, including hangs and crashes on the language server's part.

Co-authored-by: Lukas <lukas@zed.dev>
Co-authored-by: Remco Smits <djsmits12@gmail.com>

Co-authored-by: Anthony Eid <hello@anthonyeid.me>

Closes #ISSUE

Release Notes:

- Improved reporting of errors to language servers, which should improve
the stability of LSPs ran by Zed.

---------

Co-authored-by: Lukas <lukas@zed.dev>
Co-authored-by: Remco Smits <djsmits12@gmail.com>
Co-authored-by: Anthony Eid <hello@anthonyeid.me>
2025-08-06 18:27:48 +00:00
Mikayla Maki
794098e5c9 Update instructions for local collaboration (#35689)
Release Notes:

- N/A
2025-08-06 11:10:28 -07:00
Marshall Bowers
b08e26df60 collab: Remove unused StripeBilling methods (#35740)
This PR removes some unused methods from the `StripeBilling` object.

Release Notes:

- N/A
2025-08-06 17:42:12 +00:00
Marshall Bowers
740597492b collab: Remove Stripe events polling (#35736)
This PR removes the Stripe event polling from Collab, as it has been
moved to Cloud.

Release Notes:

- N/A
2025-08-06 16:53:43 +00:00
Ben Kunkle
ebda6b8a94 keymap_ui: Show matching bindings (#35732)
Closes #ISSUE

Adds a bit of text in the keybind editing modal when there are existing
keystrokes with the same key, with the ability for the user to click the
text and have the keymap editor search be updated to show only bindings
with those keystrokes

Release Notes:

- Keymap Editor: Added a warning to the keybind editing modal when
existing bindings have the same keystrokes. Clicking the warning will
close the modal and show bindings with the entered keystrokes in the
keymap editor. This behavior was previously possible with the
`keymap_editor::ShowMatchingKeybinds` action in the Keymap Editor, and
is now present in the keybind editing modal as well.
2025-08-06 12:16:05 -04:00
Kirill Bulatov
55b4df4d9f Add a way to distinguish metrics by Zed's release channel (#35729)
Release Notes:

- N/A

---------

Co-authored-by: Oleksiy Syvokon <oleksiy.syvokon@gmail.com>
2025-08-06 18:47:44 +03:00
Umesh Yadav
b8e8fbd8e6 ollama: Add support for gpt-oss (#35648)
There is a know bug when calling tool discussion:
https://discord.com/channels/1128867683291627614/1402385744038858853
I have raised the issue with ollama team and they are currently fixing
it.

Release Notes:

- ollama: Add support for gpt-oss
2025-08-06 10:44:15 -04:00
Agus Zubiaga
33f198fef1 Thread view scrollbar (#35655)
This also adds a convenient `Scrollbar:auto_hide` function so that we
don't have to handle that at the callsite.

Release Notes:

- N/A

---------

Co-authored-by: David Kleingeld <davidsk@zed.dev>
2025-08-06 14:01:34 +00:00
Peter Tripp
3c602fecbf docs: Cleanup tool use documentation (#35725)
Remove redundant documentation about tool use.

Release Notes:

- N/A
2025-08-06 09:59:13 -04:00
Agus Zubiaga
334bdd0efc Fix acp thread entry width (#35723)
Release Notes:

- N/A
2025-08-06 13:39:55 +00:00
Agus Zubiaga
69dc870828 Fix CC todo tool parsing (#35721)
It looks like the TODO tool call no longer requires a priority.

Release Notes:

- N/A
2025-08-06 13:27:11 +00:00
Agus Zubiaga
22fa41e9c0 Handle CC thinking (#35722)
Release Notes:

- N/A
2025-08-06 13:20:53 +00:00
Joseph T. Lyons
7e790f52c8 Bump Zed to v0.200 (#35719)
🎉

Release Notes:

-N/A
2025-08-06 13:11:46 +00:00
166 changed files with 6129 additions and 3113 deletions

View File

@@ -24,9 +24,6 @@ self-hosted-runner:
- namespace-profile-8x16-ubuntu-2204
- namespace-profile-16x32-ubuntu-2204
- namespace-profile-32x64-ubuntu-2204
# Namespace Limited Preview
- namespace-profile-8x16-ubuntu-2004-arm-m4
- namespace-profile-8x32-ubuntu-2004-arm-m4
# Self Hosted Runners
- self-mini-macos
- self-32vcpu-windows-2022

View File

@@ -511,8 +511,8 @@ jobs:
runs-on:
- self-mini-macos
if: |
( startsWith(github.ref, 'refs/tags/v')
|| contains(github.event.pull_request.labels.*.name, 'run-bundling') )
startsWith(github.ref, 'refs/tags/v')
|| contains(github.event.pull_request.labels.*.name, 'run-bundling')
needs: [macos_tests]
env:
MACOS_CERTIFICATE: ${{ secrets.MACOS_CERTIFICATE }}
@@ -599,8 +599,8 @@ jobs:
runs-on:
- namespace-profile-16x32-ubuntu-2004 # ubuntu 20.04 for minimal glibc
if: |
( startsWith(github.ref, 'refs/tags/v')
|| contains(github.event.pull_request.labels.*.name, 'run-bundling') )
startsWith(github.ref, 'refs/tags/v')
|| contains(github.event.pull_request.labels.*.name, 'run-bundling')
needs: [linux_tests]
steps:
- name: Checkout repo
@@ -650,7 +650,7 @@ jobs:
timeout-minutes: 60
name: Linux arm64 release bundle
runs-on:
- namespace-profile-8x32-ubuntu-2004-arm-m4 # ubuntu 20.04 for minimal glibc
- namespace-profile-32x64-ubuntu-2004-arm # ubuntu 20.04 for minimal glibc
if: |
startsWith(github.ref, 'refs/tags/v')
|| contains(github.event.pull_request.labels.*.name, 'run-bundling')
@@ -703,8 +703,10 @@ jobs:
timeout-minutes: 60
runs-on: github-8vcpu-ubuntu-2404
if: |
false && ( startsWith(github.ref, 'refs/tags/v')
|| contains(github.event.pull_request.labels.*.name, 'run-bundling') )
false && (
startsWith(github.ref, 'refs/tags/v')
|| contains(github.event.pull_request.labels.*.name, 'run-bundling')
)
needs: [linux_tests]
name: Build Zed on FreeBSD
steps:

View File

@@ -137,12 +137,14 @@ jobs:
export ZED_SERVICE_NAME=collab
export ZED_LOAD_BALANCER_SIZE_UNIT=$ZED_COLLAB_LOAD_BALANCER_SIZE_UNIT
export DATABASE_MAX_CONNECTIONS=850
envsubst < crates/collab/k8s/collab.template.yml | kubectl apply -f -
kubectl -n "$ZED_KUBE_NAMESPACE" rollout status deployment/$ZED_SERVICE_NAME --watch
echo "deployed ${ZED_SERVICE_NAME} to ${ZED_KUBE_NAMESPACE}"
export ZED_SERVICE_NAME=api
export ZED_LOAD_BALANCER_SIZE_UNIT=$ZED_API_LOAD_BALANCER_SIZE_UNIT
export DATABASE_MAX_CONNECTIONS=60
envsubst < crates/collab/k8s/collab.template.yml | kubectl apply -f -
kubectl -n "$ZED_KUBE_NAMESPACE" rollout status deployment/$ZED_SERVICE_NAME --watch
echo "deployed ${ZED_SERVICE_NAME} to ${ZED_KUBE_NAMESPACE}"

View File

@@ -168,7 +168,7 @@ jobs:
name: Create a Linux *.tar.gz bundle for ARM
if: github.repository_owner == 'zed-industries'
runs-on:
- namespace-profile-8x32-ubuntu-2004-arm-m4 # ubuntu 20.04 for minimal glibc
- namespace-profile-32x64-ubuntu-2004-arm # ubuntu 20.04 for minimal glibc
needs: tests
steps:
- name: Checkout repo

23
Cargo.lock generated
View File

@@ -138,9 +138,9 @@ dependencies = [
[[package]]
name = "agent-client-protocol"
version = "0.0.20"
version = "0.0.23"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "12dbfec3d27680337ed9d3064eecafe97acf0b0f190148bb4e29d96707c9e403"
checksum = "3fad72b7b8ee4331b3a4c8d43c107e982a4725564b4ee658ae5c4e79d2b486e8"
dependencies = [
"anyhow",
"futures 0.3.31",
@@ -159,6 +159,7 @@ dependencies = [
"agent-client-protocol",
"agent_servers",
"anyhow",
"assistant_tool",
"client",
"clock",
"cloud_llm_client",
@@ -171,10 +172,14 @@ dependencies = [
"gpui_tokio",
"handlebars 4.5.0",
"indoc",
"itertools 0.14.0",
"language",
"language_model",
"language_models",
"log",
"pretty_assertions",
"project",
"prompt_store",
"reqwest_client",
"rust-embed",
"schemars",
@@ -185,6 +190,7 @@ dependencies = [
"ui",
"util",
"uuid",
"watch",
"workspace-hack",
"worktree",
]
@@ -7502,9 +7508,9 @@ dependencies = [
[[package]]
name = "grid"
version = "0.17.0"
version = "0.18.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "71b01d27060ad58be4663b9e4ac9e2d4806918e8876af8912afbddd1a91d5eaa"
checksum = "12101ecc8225ea6d675bc70263074eab6169079621c2186fe0c66590b2df9681"
[[package]]
name = "group"
@@ -11184,7 +11190,6 @@ dependencies = [
"anyhow",
"futures 0.3.31",
"http_client",
"log",
"schemars",
"serde",
"serde_json",
@@ -12634,6 +12639,7 @@ dependencies = [
"editor",
"file_icons",
"git",
"git_ui",
"gpui",
"indexmap",
"language",
@@ -16217,9 +16223,9 @@ dependencies = [
[[package]]
name = "taffy"
version = "0.8.3"
version = "0.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7aaef0ac998e6527d6d0d5582f7e43953bb17221ac75bb8eb2fcc2db3396db1c"
checksum = "a13e5d13f79d558b5d353a98072ca8ca0e99da429467804de959aa8c83c9a004"
dependencies = [
"arrayvec",
"grid",
@@ -20461,7 +20467,7 @@ dependencies = [
[[package]]
name = "zed"
version = "0.199.8"
version = "0.200.0"
dependencies = [
"activity_indicator",
"agent",
@@ -20864,7 +20870,6 @@ dependencies = [
"menu",
"postage",
"project",
"rand 0.8.5",
"regex",
"release_channel",
"reqwest_client",

View File

@@ -425,7 +425,7 @@ zlog_settings = { path = "crates/zlog_settings" }
#
agentic-coding-protocol = "0.0.10"
agent-client-protocol = "0.0.20"
agent-client-protocol = { version = "0.0.23" }
aho-corasick = "1.1"
alacritty_terminal = { git = "https://github.com/zed-industries/alacritty.git", branch = "add-hush-login-flag" }
any_vec = "0.14"

View File

@@ -1,3 +1,4 @@
collab: RUST_LOG=${RUST_LOG:-info} cargo run --package=collab serve all
cloud: cd ../cloud; cargo make dev
livekit: livekit-server --dev
blob_store: ./script/run-local-minio

View File

@@ -0,0 +1 @@
<svg xmlns="http://www.w3.org/2000/svg" width="17" height="16" fill="none"><path fill="#000" d="M12.5 9.639V6.354H9.683L8.18 5.007V2.5H4.5v3.285h2.817l1.51 1.347v1.736l-1.51 1.347H4.5V13.5h3.681v-2.507l1.51-1.347H12.5v-.007ZM5.727 3.595h1.227V4.69H5.727V3.595Zm1.227 8.803H5.727v-1.095h1.227v1.095Z"/></svg>

After

Width:  |  Height:  |  Size: 308 B

View File

Before

Width:  |  Height:  |  Size: 776 B

After

Width:  |  Height:  |  Size: 776 B

File diff suppressed because one or more lines are too long

Before

Width:  |  Height:  |  Size: 6.4 KiB

File diff suppressed because one or more lines are too long

Before

Width:  |  Height:  |  Size: 2.5 KiB

After

Width:  |  Height:  |  Size: 5.3 KiB

File diff suppressed because one or more lines are too long

After

Width:  |  Height:  |  Size: 5.5 KiB

View File

@@ -332,7 +332,9 @@
"enter": "agent::Chat",
"up": "agent::PreviousHistoryMessage",
"down": "agent::NextHistoryMessage",
"shift-ctrl-r": "agent::OpenAgentDiff"
"shift-ctrl-r": "agent::OpenAgentDiff",
"ctrl-shift-y": "agent::KeepAll",
"ctrl-shift-n": "agent::RejectAll"
}
},
{
@@ -846,6 +848,7 @@
"ctrl-delete": ["project_panel::Delete", { "skip_prompt": false }],
"alt-ctrl-r": "project_panel::RevealInFileManager",
"ctrl-shift-enter": "project_panel::OpenWithSystem",
"alt-d": "project_panel::CompareMarkedFiles",
"shift-find": "project_panel::NewSearchInDirectory",
"ctrl-alt-shift-f": "project_panel::NewSearchInDirectory",
"shift-down": "menu::SelectNext",
@@ -1100,6 +1103,13 @@
"ctrl-enter": "menu::Confirm"
}
},
{
"context": "OnboardingAiConfigurationModal",
"use_key_equivalents": true,
"bindings": {
"escape": "menu::Cancel"
}
},
{
"context": "Diagnostics",
"use_key_equivalents": true,
@@ -1176,7 +1186,8 @@
"ctrl-1": "onboarding::ActivateBasicsPage",
"ctrl-2": "onboarding::ActivateEditingPage",
"ctrl-3": "onboarding::ActivateAISetupPage",
"ctrl-escape": "onboarding::Finish"
"ctrl-escape": "onboarding::Finish",
"alt-tab": "onboarding::SignIn"
}
}
]

View File

@@ -384,7 +384,9 @@
"enter": "agent::Chat",
"up": "agent::PreviousHistoryMessage",
"down": "agent::NextHistoryMessage",
"shift-ctrl-r": "agent::OpenAgentDiff"
"shift-ctrl-r": "agent::OpenAgentDiff",
"cmd-shift-y": "agent::KeepAll",
"cmd-shift-n": "agent::RejectAll"
}
},
{
@@ -905,6 +907,7 @@
"cmd-delete": ["project_panel::Delete", { "skip_prompt": false }],
"alt-cmd-r": "project_panel::RevealInFileManager",
"ctrl-shift-enter": "project_panel::OpenWithSystem",
"alt-d": "project_panel::CompareMarkedFiles",
"cmd-alt-backspace": ["project_panel::Delete", { "skip_prompt": false }],
"cmd-alt-shift-f": "project_panel::NewSearchInDirectory",
"shift-down": "menu::SelectNext",
@@ -1202,6 +1205,13 @@
"cmd-enter": "menu::Confirm"
}
},
{
"context": "OnboardingAiConfigurationModal",
"use_key_equivalents": true,
"bindings": {
"escape": "menu::Cancel"
}
},
{
"context": "Diagnostics",
"use_key_equivalents": true,
@@ -1278,7 +1288,8 @@
"cmd-1": "onboarding::ActivateBasicsPage",
"cmd-2": "onboarding::ActivateEditingPage",
"cmd-3": "onboarding::ActivateAISetupPage",
"cmd-escape": "onboarding::Finish"
"cmd-escape": "onboarding::Finish",
"alt-tab": "onboarding::SignIn"
}
}
]

View File

@@ -813,6 +813,7 @@
"p": "project_panel::Open",
"x": "project_panel::RevealInFileManager",
"s": "project_panel::OpenWithSystem",
"z d": "project_panel::CompareMarkedFiles",
"] c": "project_panel::SelectNextGitEntry",
"[ c": "project_panel::SelectPrevGitEntry",
"] d": "project_panel::SelectNextDiagnostic",

View File

@@ -1171,6 +1171,9 @@
// Sets a delay after which the inline blame information is shown.
// Delay is restarted with every cursor movement.
"delay_ms": 0,
// The amount of padding between the end of the source line and the start
// of the inline blame in units of em widths.
"padding": 7,
// Whether or not to display the git commit summary on the same line.
"show_commit_summary": false,
// The minimum column number to show the inline blame information at

View File

@@ -166,6 +166,7 @@ pub struct ToolCall {
pub status: ToolCallStatus,
pub locations: Vec<acp::ToolCallLocation>,
pub raw_input: Option<serde_json::Value>,
pub raw_output: Option<serde_json::Value>,
}
impl ToolCall {
@@ -194,6 +195,7 @@ impl ToolCall {
locations: tool_call.locations,
status,
raw_input: tool_call.raw_input,
raw_output: tool_call.raw_output,
}
}
@@ -210,6 +212,7 @@ impl ToolCall {
content,
locations,
raw_input,
raw_output,
} = fields;
if let Some(kind) = kind {
@@ -221,7 +224,9 @@ impl ToolCall {
}
if let Some(title) = title {
self.label = cx.new(|cx| Markdown::new_text(title.into(), cx));
self.label.update(cx, |label, cx| {
label.replace(title, cx);
});
}
if let Some(content) = content {
@@ -238,6 +243,10 @@ impl ToolCall {
if let Some(raw_input) = raw_input {
self.raw_input = Some(raw_input);
}
if let Some(raw_output) = raw_output {
self.raw_output = Some(raw_output);
}
}
pub fn diffs(&self) -> impl Iterator<Item = &Diff> {
@@ -411,8 +420,6 @@ impl ToolCallContent {
pub struct Diff {
pub multibuffer: Entity<MultiBuffer>,
pub path: PathBuf,
pub new_buffer: Entity<Buffer>,
pub old_buffer: Entity<Buffer>,
_task: Task<Result<()>>,
}
@@ -433,23 +440,34 @@ impl Diff {
let new_buffer = cx.new(|cx| Buffer::local(new_text, cx));
let old_buffer = cx.new(|cx| Buffer::local(old_text.unwrap_or("".into()), cx));
let new_buffer_snapshot = new_buffer.read(cx).text_snapshot();
let old_buffer_snapshot = old_buffer.read(cx).snapshot();
let buffer_diff = cx.new(|cx| BufferDiff::new(&new_buffer_snapshot, cx));
let diff_task = buffer_diff.update(cx, |diff, cx| {
diff.set_base_text(
old_buffer_snapshot,
Some(language_registry.clone()),
new_buffer_snapshot,
cx,
)
});
let task = cx.spawn({
let multibuffer = multibuffer.clone();
let path = path.clone();
let new_buffer = new_buffer.clone();
async move |cx| {
diff_task.await?;
let language = language_registry
.language_for_file_path(&path)
.await
.log_err();
new_buffer.update(cx, |buffer, cx| buffer.set_language(language.clone(), cx))?;
let old_buffer_snapshot = old_buffer.update(cx, |buffer, cx| {
buffer.set_language(language, cx);
buffer.snapshot()
})?;
buffer_diff
.update(cx, |diff, cx| {
diff.set_base_text(
old_buffer_snapshot,
Some(language_registry),
new_buffer_snapshot,
cx,
)
})?
.await?;
multibuffer
.update(cx, |multibuffer, cx| {
@@ -468,18 +486,10 @@ impl Diff {
editor::DEFAULT_MULTIBUFFER_CONTEXT,
cx,
);
multibuffer.add_diff(buffer_diff.clone(), cx);
multibuffer.add_diff(buffer_diff, cx);
})
.log_err();
if let Some(language) = language_registry
.language_for_file_path(&path)
.await
.log_err()
{
new_buffer.update(cx, |buffer, cx| buffer.set_language(Some(language), cx))?;
}
anyhow::Ok(())
}
});
@@ -487,8 +497,6 @@ impl Diff {
Self {
multibuffer,
path,
new_buffer,
old_buffer,
_task: task,
}
}
@@ -557,7 +565,7 @@ pub struct PlanEntry {
impl PlanEntry {
pub fn from_acp(entry: acp::PlanEntry, cx: &mut App) -> Self {
Self {
content: cx.new(|cx| Markdown::new_text(entry.content.into(), cx)),
content: cx.new(|cx| Markdown::new(entry.content.into(), None, None, cx)),
priority: entry.priority,
status: entry.status,
}
@@ -661,7 +669,7 @@ impl AcpThread {
}
pub fn status(&self) -> ThreadStatus {
if self.send_task.is_some() {
if self.send_task.as_ref().map_or(false, |t| !t.is_finished()) {
if self.waiting_for_tool_confirmation() {
ThreadStatus::WaitingForToolConfirmation
} else {
@@ -896,7 +904,7 @@ impl AcpThread {
});
}
pub fn request_tool_call_permission(
pub fn request_tool_call_authorization(
&mut self,
tool_call: acp::ToolCall,
options: Vec<acp::PermissionOption>,
@@ -971,13 +979,26 @@ impl AcpThread {
}
pub fn update_plan(&mut self, request: acp::Plan, cx: &mut Context<Self>) {
self.plan = Plan {
entries: request
.entries
.into_iter()
.map(|entry| PlanEntry::from_acp(entry, cx))
.collect(),
};
let new_entries_len = request.entries.len();
let mut new_entries = request.entries.into_iter();
// Reuse existing markdown to prevent flickering
for (old, new) in self.plan.entries.iter_mut().zip(new_entries.by_ref()) {
let PlanEntry {
content,
priority,
status,
} = old;
content.update(cx, |old, cx| {
old.replace(new.content, cx);
});
*priority = new.priority;
*status = new.status;
}
for new in new_entries {
self.plan.entries.push(PlanEntry::from_acp(new, cx))
}
self.plan.entries.truncate(new_entries_len);
cx.notify();
}
@@ -1038,8 +1059,8 @@ impl AcpThread {
)
})?
.await;
tx.send(result).log_err();
this.update(cx, |this, _cx| this.send_task.take())?;
anyhow::Ok(())
}
.await
@@ -1525,6 +1546,7 @@ mod tests {
content: vec![],
locations: vec![],
raw_input: None,
raw_output: None,
}),
cx,
)
@@ -1637,6 +1659,7 @@ mod tests {
}],
locations: vec![],
raw_input: None,
raw_output: None,
}),
cx,
)

View File

@@ -16,6 +16,7 @@ acp_thread.workspace = true
agent-client-protocol.workspace = true
agent_servers.workspace = true
anyhow.workspace = true
assistant_tool.workspace = true
cloud_llm_client.workspace = true
collections.workspace = true
fs.workspace = true
@@ -23,10 +24,13 @@ futures.workspace = true
gpui.workspace = true
handlebars = { workspace = true, features = ["rust-embed"] }
indoc.workspace = true
itertools.workspace = true
language.workspace = true
language_model.workspace = true
language_models.workspace = true
log.workspace = true
project.workspace = true
prompt_store.workspace = true
rust-embed.workspace = true
schemars.workspace = true
serde.workspace = true
@@ -36,7 +40,7 @@ smol.workspace = true
ui.workspace = true
util.workspace = true
uuid.workspace = true
worktree.workspace = true
watch.workspace = true
workspace-hack.workspace = true
[dev-dependencies]
@@ -47,8 +51,10 @@ env_logger.workspace = true
fs = { workspace = true, "features" = ["test-support"] }
gpui = { workspace = true, "features" = ["test-support"] }
gpui_tokio.workspace = true
language = { workspace = true, "features" = ["test-support"] }
language_model = { workspace = true, "features" = ["test-support"] }
project = { workspace = true, "features" = ["test-support"] }
reqwest_client.workspace = true
settings = { workspace = true, "features" = ["test-support"] }
worktree = { workspace = true, "features" = ["test-support"] }
pretty_assertions.workspace = true

View File

@@ -1,16 +1,39 @@
use crate::{templates::Templates, AgentResponseEvent, Thread};
use crate::{FindPathTool, ReadFileTool, ThinkingTool, ToolCallAuthorization};
use acp_thread::ModelSelector;
use agent_client_protocol as acp;
use anyhow::{anyhow, Result};
use futures::StreamExt;
use gpui::{App, AppContext, AsyncApp, Entity, Subscription, Task, WeakEntity};
use anyhow::{anyhow, Context as _, Result};
use futures::{future, StreamExt};
use gpui::{
App, AppContext, AsyncApp, Context, Entity, SharedString, Subscription, Task, WeakEntity,
};
use language_model::{LanguageModel, LanguageModelRegistry};
use project::Project;
use project::{Project, ProjectItem, ProjectPath, Worktree};
use prompt_store::{
ProjectContext, PromptId, PromptStore, RulesFileContext, UserRulesContext, WorktreeContext,
};
use std::cell::RefCell;
use std::collections::HashMap;
use std::path::Path;
use std::rc::Rc;
use std::sync::Arc;
use util::ResultExt;
use crate::{templates::Templates, AgentResponseEvent, Thread};
const RULES_FILE_NAMES: [&'static str; 9] = [
".rules",
".cursorrules",
".windsurfrules",
".clinerules",
".github/copilot-instructions.md",
"CLAUDE.md",
"AGENT.md",
"AGENTS.md",
"GEMINI.md",
];
pub struct RulesLoadingError {
pub message: SharedString,
}
/// Holds both the internal Thread and the AcpThread for a session
struct Session {
@@ -24,17 +47,247 @@ struct Session {
pub struct NativeAgent {
/// Session ID -> Session mapping
sessions: HashMap<acp::SessionId, Session>,
/// Shared project context for all threads
project_context: Rc<RefCell<ProjectContext>>,
project_context_needs_refresh: watch::Sender<()>,
_maintain_project_context: Task<Result<()>>,
/// Shared templates for all threads
templates: Arc<Templates>,
project: Entity<Project>,
prompt_store: Option<Entity<PromptStore>>,
_subscriptions: Vec<Subscription>,
}
impl NativeAgent {
pub fn new(templates: Arc<Templates>) -> Self {
pub async fn new(
project: Entity<Project>,
templates: Arc<Templates>,
prompt_store: Option<Entity<PromptStore>>,
cx: &mut AsyncApp,
) -> Result<Entity<NativeAgent>> {
log::info!("Creating new NativeAgent");
Self {
sessions: HashMap::new(),
templates,
let project_context = cx
.update(|cx| Self::build_project_context(&project, prompt_store.as_ref(), cx))?
.await;
cx.new(|cx| {
let mut subscriptions = vec![cx.subscribe(&project, Self::handle_project_event)];
if let Some(prompt_store) = prompt_store.as_ref() {
subscriptions.push(cx.subscribe(prompt_store, Self::handle_prompts_updated_event))
}
let (project_context_needs_refresh_tx, project_context_needs_refresh_rx) =
watch::channel(());
Self {
sessions: HashMap::new(),
project_context: Rc::new(RefCell::new(project_context)),
project_context_needs_refresh: project_context_needs_refresh_tx,
_maintain_project_context: cx.spawn(async move |this, cx| {
Self::maintain_project_context(this, project_context_needs_refresh_rx, cx).await
}),
templates,
project,
prompt_store,
_subscriptions: subscriptions,
}
})
}
async fn maintain_project_context(
this: WeakEntity<Self>,
mut needs_refresh: watch::Receiver<()>,
cx: &mut AsyncApp,
) -> Result<()> {
while needs_refresh.changed().await.is_ok() {
let project_context = this
.update(cx, |this, cx| {
Self::build_project_context(&this.project, this.prompt_store.as_ref(), cx)
})?
.await;
this.update(cx, |this, _| this.project_context.replace(project_context))?;
}
Ok(())
}
fn build_project_context(
project: &Entity<Project>,
prompt_store: Option<&Entity<PromptStore>>,
cx: &mut App,
) -> Task<ProjectContext> {
let worktrees = project.read(cx).visible_worktrees(cx).collect::<Vec<_>>();
let worktree_tasks = worktrees
.into_iter()
.map(|worktree| {
Self::load_worktree_info_for_system_prompt(worktree, project.clone(), cx)
})
.collect::<Vec<_>>();
let default_user_rules_task = if let Some(prompt_store) = prompt_store.as_ref() {
prompt_store.read_with(cx, |prompt_store, cx| {
let prompts = prompt_store.default_prompt_metadata();
let load_tasks = prompts.into_iter().map(|prompt_metadata| {
let contents = prompt_store.load(prompt_metadata.id, cx);
async move { (contents.await, prompt_metadata) }
});
cx.background_spawn(future::join_all(load_tasks))
})
} else {
Task::ready(vec![])
};
cx.spawn(async move |_cx| {
let (worktrees, default_user_rules) =
future::join(future::join_all(worktree_tasks), default_user_rules_task).await;
let worktrees = worktrees
.into_iter()
.map(|(worktree, _rules_error)| {
// TODO: show error message
// if let Some(rules_error) = rules_error {
// this.update(cx, |_, cx| cx.emit(rules_error)).ok();
// }
worktree
})
.collect::<Vec<_>>();
let default_user_rules = default_user_rules
.into_iter()
.flat_map(|(contents, prompt_metadata)| match contents {
Ok(contents) => Some(UserRulesContext {
uuid: match prompt_metadata.id {
PromptId::User { uuid } => uuid,
PromptId::EditWorkflow => return None,
},
title: prompt_metadata.title.map(|title| title.to_string()),
contents,
}),
Err(_err) => {
// TODO: show error message
// this.update(cx, |_, cx| {
// cx.emit(RulesLoadingError {
// message: format!("{err:?}").into(),
// });
// })
// .ok();
None
}
})
.collect::<Vec<_>>();
ProjectContext::new(worktrees, default_user_rules)
})
}
fn load_worktree_info_for_system_prompt(
worktree: Entity<Worktree>,
project: Entity<Project>,
cx: &mut App,
) -> Task<(WorktreeContext, Option<RulesLoadingError>)> {
let tree = worktree.read(cx);
let root_name = tree.root_name().into();
let abs_path = tree.abs_path();
let mut context = WorktreeContext {
root_name,
abs_path,
rules_file: None,
};
let rules_task = Self::load_worktree_rules_file(worktree, project, cx);
let Some(rules_task) = rules_task else {
return Task::ready((context, None));
};
cx.spawn(async move |_| {
let (rules_file, rules_file_error) = match rules_task.await {
Ok(rules_file) => (Some(rules_file), None),
Err(err) => (
None,
Some(RulesLoadingError {
message: format!("{err}").into(),
}),
),
};
context.rules_file = rules_file;
(context, rules_file_error)
})
}
fn load_worktree_rules_file(
worktree: Entity<Worktree>,
project: Entity<Project>,
cx: &mut App,
) -> Option<Task<Result<RulesFileContext>>> {
let worktree = worktree.read(cx);
let worktree_id = worktree.id();
let selected_rules_file = RULES_FILE_NAMES
.into_iter()
.filter_map(|name| {
worktree
.entry_for_path(name)
.filter(|entry| entry.is_file())
.map(|entry| entry.path.clone())
})
.next();
// Note that Cline supports `.clinerules` being a directory, but that is not currently
// supported. This doesn't seem to occur often in GitHub repositories.
selected_rules_file.map(|path_in_worktree| {
let project_path = ProjectPath {
worktree_id,
path: path_in_worktree.clone(),
};
let buffer_task =
project.update(cx, |project, cx| project.open_buffer(project_path, cx));
let rope_task = cx.spawn(async move |cx| {
buffer_task.await?.read_with(cx, |buffer, cx| {
let project_entry_id = buffer.entry_id(cx).context("buffer has no file")?;
anyhow::Ok((project_entry_id, buffer.as_rope().clone()))
})?
});
// Build a string from the rope on a background thread.
cx.background_spawn(async move {
let (project_entry_id, rope) = rope_task.await?;
anyhow::Ok(RulesFileContext {
path_in_worktree,
text: rope.to_string().trim().to_string(),
project_entry_id: project_entry_id.to_usize(),
})
})
})
}
fn handle_project_event(
&mut self,
_project: Entity<Project>,
event: &project::Event,
_cx: &mut Context<Self>,
) {
match event {
project::Event::WorktreeAdded(_) | project::Event::WorktreeRemoved(_) => {
self.project_context_needs_refresh.send(()).ok();
}
project::Event::WorktreeUpdatedEntries(_, items) => {
if items.iter().any(|(path, _, _)| {
RULES_FILE_NAMES
.iter()
.any(|name| path.as_ref() == Path::new(name))
}) {
self.project_context_needs_refresh.send(()).ok();
}
}
_ => {}
}
}
fn handle_prompts_updated_event(
&mut self,
_prompt_store: Entity<PromptStore>,
_event: &prompt_store::PromptsUpdatedEvent,
_cx: &mut Context<Self>,
) {
self.project_context_needs_refresh.send(()).ok();
}
}
@@ -120,8 +373,21 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
cx.spawn(async move |cx| {
log::debug!("Starting thread creation in async context");
// Generate session ID
let session_id = acp::SessionId(uuid::Uuid::new_v4().to_string().into());
log::info!("Created session with ID: {}", session_id);
// Create AcpThread
let acp_thread = cx.update(|cx| {
cx.new(|cx| {
acp_thread::AcpThread::new("agent2", self.clone(), project.clone(), session_id.clone(), cx)
})
})?;
let action_log = cx.update(|cx| acp_thread.read(cx).action_log().clone())?;
// Create Thread
let (session_id, thread) = agent.update(
let thread = agent.update(
cx,
|agent, cx: &mut gpui::Context<NativeAgent>| -> Result<_> {
// Fetch default model from registry settings
@@ -146,22 +412,18 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
anyhow!("No default model configured. Please configure a default model in settings.")
})?;
let thread = cx.new(|_| Thread::new(project.clone(), agent.templates.clone(), default_model));
let thread = cx.new(|_| {
let mut thread = Thread::new(project.clone(), agent.project_context.clone(), action_log.clone(), agent.templates.clone(), default_model);
thread.add_tool(ThinkingTool);
thread.add_tool(FindPathTool::new(project.clone()));
thread.add_tool(ReadFileTool::new(project.clone(), action_log));
thread
});
// Generate session ID
let session_id = acp::SessionId(uuid::Uuid::new_v4().to_string().into());
log::info!("Created session with ID: {}", session_id);
Ok((session_id, thread))
Ok(thread)
},
)??;
// Create AcpThread
let acp_thread = cx.update(|cx| {
cx.new(|cx| {
acp_thread::AcpThread::new("agent2", self.clone(), project, session_id.clone(), cx)
})
})?;
// Store the session
agent.update(cx, |agent, cx| {
agent.sessions.insert(
@@ -264,6 +526,28 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
)
})??;
}
AgentResponseEvent::ToolCallAuthorization(ToolCallAuthorization {
tool_call,
options,
response,
}) => {
let recv = acp_thread.update(cx, |thread, cx| {
thread.request_tool_call_authorization(tool_call, options, cx)
})?;
cx.background_spawn(async move {
if let Some(option) = recv
.await
.context("authorization sender was dropped")
.log_err()
{
response
.send(option)
.map(|_| anyhow!("authorization receiver was dropped"))
.log_err();
}
})
.detach();
}
AgentResponseEvent::ToolCall(tool_call) => {
acp_thread.update(cx, |thread, cx| {
thread.handle_session_update(
@@ -343,3 +627,77 @@ fn convert_prompt_to_message(blocks: Vec<acp::ContentBlock>) -> String {
message
}
#[cfg(test)]
mod tests {
use super::*;
use fs::FakeFs;
use gpui::TestAppContext;
use serde_json::json;
use settings::SettingsStore;
#[gpui::test]
async fn test_maintaining_project_context(cx: &mut TestAppContext) {
init_test(cx);
let fs = FakeFs::new(cx.executor());
fs.insert_tree(
"/",
json!({
"a": {}
}),
)
.await;
let project = Project::test(fs.clone(), [], cx).await;
let agent = NativeAgent::new(project.clone(), Templates::new(), None, &mut cx.to_async())
.await
.unwrap();
agent.read_with(cx, |agent, _| {
assert_eq!(agent.project_context.borrow().worktrees, vec![])
});
let worktree = project
.update(cx, |project, cx| project.create_worktree("/a", true, cx))
.await
.unwrap();
cx.run_until_parked();
agent.read_with(cx, |agent, _| {
assert_eq!(
agent.project_context.borrow().worktrees,
vec![WorktreeContext {
root_name: "a".into(),
abs_path: Path::new("/a").into(),
rules_file: None
}]
)
});
// Creating `/a/.rules` updates the project context.
fs.insert_file("/a/.rules", Vec::new()).await;
cx.run_until_parked();
agent.read_with(cx, |agent, cx| {
let rules_entry = worktree.read(cx).entry_for_path(".rules").unwrap();
assert_eq!(
agent.project_context.borrow().worktrees,
vec![WorktreeContext {
root_name: "a".into(),
abs_path: Path::new("/a").into(),
rules_file: Some(RulesFileContext {
path_in_worktree: Path::new(".rules").into(),
text: "".into(),
project_entry_id: rules_entry.id.to_usize()
})
}]
)
});
}
fn init_test(cx: &mut TestAppContext) {
env_logger::try_init().ok();
cx.update(|cx| {
let settings_store = SettingsStore::test(cx);
cx.set_global(settings_store);
Project::init_settings(cx);
language::init(cx);
});
}
}

View File

@@ -1,6 +1,5 @@
mod agent;
mod native_agent_server;
mod prompts;
mod templates;
mod thread;
mod tools;
@@ -11,3 +10,4 @@ mod tests;
pub use agent::*;
pub use native_agent_server::NativeAgentServer;
pub use thread::*;
pub use tools::*;

View File

@@ -3,8 +3,9 @@ use std::rc::Rc;
use agent_servers::AgentServer;
use anyhow::Result;
use gpui::{App, AppContext, Entity, Task};
use gpui::{App, Entity, Task};
use project::Project;
use prompt_store::PromptStore;
use crate::{templates::Templates, NativeAgent, NativeAgentConnection};
@@ -32,21 +33,22 @@ impl AgentServer for NativeAgentServer {
fn connect(
&self,
_root_dir: &Path,
_project: &Entity<Project>,
project: &Entity<Project>,
cx: &mut App,
) -> Task<Result<Rc<dyn acp_thread::AgentConnection>>> {
log::info!(
"NativeAgentServer::connect called for path: {:?}",
_root_dir
);
let project = project.clone();
let prompt_store = PromptStore::global(cx);
cx.spawn(async move |cx| {
log::debug!("Creating templates for native agent");
// Create templates (you might want to load these from files or resources)
let templates = Templates::new();
let prompt_store = prompt_store.await?;
// Create the native agent
log::debug!("Creating native agent entity");
let agent = cx.update(|cx| cx.new(|_| NativeAgent::new(templates)))?;
let agent = NativeAgent::new(project, templates, Some(prompt_store), cx).await?;
// Create the connection wrapper
let connection = NativeAgentConnection(agent);

View File

@@ -1,35 +0,0 @@
use crate::{
templates::{BaseTemplate, Template, Templates, WorktreeData},
thread::Prompt,
};
use anyhow::Result;
use gpui::{App, Entity};
use project::Project;
pub struct BasePrompt {
project: Entity<Project>,
}
impl BasePrompt {
pub fn new(project: Entity<Project>) -> Self {
Self { project }
}
}
impl Prompt for BasePrompt {
fn render(&self, templates: &Templates, cx: &App) -> Result<String> {
BaseTemplate {
os: std::env::consts::OS.to_string(),
shell: util::get_system_shell(),
worktrees: self
.project
.read(cx)
.worktrees(cx)
.map(|worktree| WorktreeData {
root_name: worktree.read(cx).root_name().to_string(),
})
.collect(),
}
.render(templates)
}
}

View File

@@ -1,9 +1,9 @@
use std::sync::Arc;
use anyhow::Result;
use gpui::SharedString;
use handlebars::Handlebars;
use rust_embed::RustEmbed;
use serde::Serialize;
use std::sync::Arc;
#[derive(RustEmbed)]
#[folder = "src/templates"]
@@ -15,6 +15,8 @@ pub struct Templates(Handlebars<'static>);
impl Templates {
pub fn new() -> Arc<Self> {
let mut handlebars = Handlebars::new();
handlebars.set_strict_mode(true);
handlebars.register_helper("contains", Box::new(contains));
handlebars.register_embed_templates::<Assets>().unwrap();
Arc::new(Self(handlebars))
}
@@ -32,26 +34,54 @@ pub trait Template: Sized {
}
#[derive(Serialize)]
pub struct BaseTemplate {
pub os: String,
pub shell: String,
pub worktrees: Vec<WorktreeData>,
pub struct SystemPromptTemplate<'a> {
#[serde(flatten)]
pub project: &'a prompt_store::ProjectContext,
pub available_tools: Vec<SharedString>,
}
impl Template for BaseTemplate {
const TEMPLATE_NAME: &'static str = "base.hbs";
impl Template for SystemPromptTemplate<'_> {
const TEMPLATE_NAME: &'static str = "system_prompt.hbs";
}
#[derive(Serialize)]
pub struct WorktreeData {
pub root_name: String,
/// Handlebars helper for checking if an item is in a list
fn contains(
h: &handlebars::Helper,
_: &handlebars::Handlebars,
_: &handlebars::Context,
_: &mut handlebars::RenderContext,
out: &mut dyn handlebars::Output,
) -> handlebars::HelperResult {
let list = h
.param(0)
.and_then(|v| v.value().as_array())
.ok_or_else(|| {
handlebars::RenderError::new("contains: missing or invalid list parameter")
})?;
let query = h.param(1).map(|v| v.value()).ok_or_else(|| {
handlebars::RenderError::new("contains: missing or invalid query parameter")
})?;
if list.contains(&query) {
out.write("true")?;
}
Ok(())
}
#[derive(Serialize)]
pub struct GlobTemplate {
pub project_roots: String,
}
#[cfg(test)]
mod tests {
use super::*;
impl Template for GlobTemplate {
const TEMPLATE_NAME: &'static str = "glob.hbs";
#[test]
fn test_system_prompt_template() {
let project = prompt_store::ProjectContext::default();
let template = SystemPromptTemplate {
project: &project,
available_tools: vec!["echo".into()],
};
let templates = Templates::new();
let rendered = template.render(&templates).unwrap();
assert!(rendered.contains("## Fixing Diagnostics"));
}
}

View File

@@ -1,56 +0,0 @@
You are a highly skilled software engineer with extensive knowledge in many programming languages, frameworks, design patterns, and best practices.
## Communication
1. Be conversational but professional.
2. Refer to the USER in the second person and yourself in the first person.
3. Format your responses in markdown. Use backticks to format file, directory, function, and class names.
4. NEVER lie or make things up.
5. Refrain from apologizing all the time when results are unexpected. Instead, just try your best to proceed or explain the circumstances to the user without apologizing.
## Tool Use
1. Make sure to adhere to the tools schema.
2. Provide every required argument.
3. DO NOT use tools to access items that are already available in the context section.
4. Use only the tools that are currently available.
5. DO NOT use a tool that is not available just because it appears in the conversation. This means the user turned it off.
## Searching and Reading
If you are unsure how to fulfill the user's request, gather more information with tool calls and/or clarifying questions.
If appropriate, use tool calls to explore the current project, which contains the following root directories:
{{#each worktrees}}
- `{{root_name}}`
{{/each}}
- When providing paths to tools, the path should always begin with a path that starts with a project root directory listed above.
- When looking for symbols in the project, prefer the `grep` tool.
- As you learn about the structure of the project, use that information to scope `grep` searches to targeted subtrees of the project.
- Bias towards not asking the user for help if you can find the answer yourself.
## Fixing Diagnostics
1. Make 1-2 attempts at fixing diagnostics, then defer to the user.
2. Never simplify code you've written just to solve diagnostics. Complete, mostly correct code is more valuable than perfect code that doesn't solve the problem.
## Debugging
When debugging, only make code changes if you are certain that you can solve the problem.
Otherwise, follow debugging best practices:
1. Address the root cause instead of the symptoms.
2. Add descriptive logging statements and error messages to track variable and code state.
3. Add test functions and statements to isolate the problem.
## Calling External APIs
1. Unless explicitly requested by the user, use the best suited external APIs and packages to solve the task. There is no need to ask the user for permission.
2. When selecting which version of an API or package to use, choose one that is compatible with the user's dependency management file. If no such file exists or if the package is not present, use the latest version that is in your training data.
3. If an external API requires an API Key, be sure to point this out to the user. Adhere to best security practices (e.g. DO NOT hardcode an API key in a place where it can be exposed)
## System Information
Operating System: {{os}}
Default Shell: {{shell}}

View File

@@ -1,8 +0,0 @@
Find paths on disk with glob patterns.
Assume that all glob patterns are matched in a project directory with the following entries.
{{project_roots}}
When searching with patterns that begin with literal path components, e.g. `foo/bar/**/*.rs`, be
sure to anchor them with one of the directories listed above.

View File

@@ -0,0 +1,178 @@
You are a highly skilled software engineer with extensive knowledge in many programming languages, frameworks, design patterns, and best practices.
## Communication
1. Be conversational but professional.
2. Refer to the user in the second person and yourself in the first person.
3. Format your responses in markdown. Use backticks to format file, directory, function, and class names.
4. NEVER lie or make things up.
5. Refrain from apologizing all the time when results are unexpected. Instead, just try your best to proceed or explain the circumstances to the user without apologizing.
{{#if (gt (len available_tools) 0)}}
## Tool Use
1. Make sure to adhere to the tools schema.
2. Provide every required argument.
3. DO NOT use tools to access items that are already available in the context section.
4. Use only the tools that are currently available.
5. DO NOT use a tool that is not available just because it appears in the conversation. This means the user turned it off.
6. NEVER run commands that don't terminate on their own such as web servers (like `npm run start`, `npm run dev`, `python -m http.server`, etc) or file watchers.
7. Avoid HTML entity escaping - use plain characters instead.
## Searching and Reading
If you are unsure how to fulfill the user's request, gather more information with tool calls and/or clarifying questions.
If appropriate, use tool calls to explore the current project, which contains the following root directories:
{{#each worktrees}}
- `{{abs_path}}`
{{/each}}
- Bias towards not asking the user for help if you can find the answer yourself.
- When providing paths to tools, the path should always start with the name of a project root directory listed above.
- Before you read or edit a file, you must first find the full path. DO NOT ever guess a file path!
{{# if (contains available_tools 'grep') }}
- When looking for symbols in the project, prefer the `grep` tool.
- As you learn about the structure of the project, use that information to scope `grep` searches to targeted subtrees of the project.
- The user might specify a partial file path. If you don't know the full path, use `find_path` (not `grep`) before you read the file.
{{/if}}
{{else}}
You are being tasked with providing a response, but you have no ability to use tools or to read or write any aspect of the user's system (other than any context the user might have provided to you).
As such, if you need the user to perform any actions for you, you must request them explicitly. Bias towards giving a response to the best of your ability, and then making requests for the user to take action (e.g. to give you more context) only optionally.
The one exception to this is if the user references something you don't know about - for example, the name of a source code file, function, type, or other piece of code that you have no awareness of. In this case, you MUST NOT MAKE SOMETHING UP, or assume you know what that thing is or how it works. Instead, you must ask the user for clarification rather than giving a response.
{{/if}}
## Code Block Formatting
Whenever you mention a code block, you MUST use ONLY use the following format:
```path/to/Something.blah#L123-456
(code goes here)
```
The `#L123-456` means the line number range 123 through 456, and the path/to/Something.blah
is a path in the project. (If there is no valid path in the project, then you can use
/dev/null/path.extension for its path.) This is the ONLY valid way to format code blocks, because the Markdown parser
does not understand the more common ```language syntax, or bare ``` blocks. It only
understands this path-based syntax, and if the path is missing, then it will error and you will have to do it over again.
Just to be really clear about this, if you ever find yourself writing three backticks followed by a language name, STOP!
You have made a mistake. You can only ever put paths after triple backticks!
<example>
Based on all the information I've gathered, here's a summary of how this system works:
1. The README file is loaded into the system.
2. The system finds the first two headers, including everything in between. In this case, that would be:
```path/to/README.md#L8-12
# First Header
This is the info under the first header.
## Sub-header
```
3. Then the system finds the last header in the README:
```path/to/README.md#L27-29
## Last Header
This is the last header in the README.
```
4. Finally, it passes this information on to the next process.
</example>
<example>
In Markdown, hash marks signify headings. For example:
```/dev/null/example.md#L1-3
# Level 1 heading
## Level 2 heading
### Level 3 heading
```
</example>
Here are examples of ways you must never render code blocks:
<bad_example_do_not_do_this>
In Markdown, hash marks signify headings. For example:
```
# Level 1 heading
## Level 2 heading
### Level 3 heading
```
</bad_example_do_not_do_this>
This example is unacceptable because it does not include the path.
<bad_example_do_not_do_this>
In Markdown, hash marks signify headings. For example:
```markdown
# Level 1 heading
## Level 2 heading
### Level 3 heading
```
</bad_example_do_not_do_this>
This example is unacceptable because it has the language instead of the path.
<bad_example_do_not_do_this>
In Markdown, hash marks signify headings. For example:
# Level 1 heading
## Level 2 heading
### Level 3 heading
</bad_example_do_not_do_this>
This example is unacceptable because it uses indentation to mark the code block
instead of backticks with a path.
<bad_example_do_not_do_this>
In Markdown, hash marks signify headings. For example:
```markdown
/dev/null/example.md#L1-3
# Level 1 heading
## Level 2 heading
### Level 3 heading
```
</bad_example_do_not_do_this>
This example is unacceptable because the path is in the wrong place. The path must be directly after the opening backticks.
{{#if (gt (len available_tools) 0)}}
## Fixing Diagnostics
1. Make 1-2 attempts at fixing diagnostics, then defer to the user.
2. Never simplify code you've written just to solve diagnostics. Complete, mostly correct code is more valuable than perfect code that doesn't solve the problem.
## Debugging
When debugging, only make code changes if you are certain that you can solve the problem.
Otherwise, follow debugging best practices:
1. Address the root cause instead of the symptoms.
2. Add descriptive logging statements and error messages to track variable and code state.
3. Add test functions and statements to isolate the problem.
{{/if}}
## Calling External APIs
1. Unless explicitly requested by the user, use the best suited external APIs and packages to solve the task. There is no need to ask the user for permission.
2. When selecting which version of an API or package to use, choose one that is compatible with the user's dependency management file(s). If no such file exists or if the package is not present, use the latest version that is in your training data.
3. If an external API requires an API Key, be sure to point this out to the user. Adhere to best security practices (e.g. DO NOT hardcode an API key in a place where it can be exposed)
## System Information
Operating System: {{os}}
Default Shell: {{shell}}
{{#if (or has_rules has_user_rules)}}
## User's Custom Instructions
The following additional instructions are provided by the user, and should be followed to the best of your ability{{#if (gt (len available_tools) 0)}} without interfering with the tool use guidelines{{/if}}.
{{#if has_rules}}
There are project rules that apply to these root directories:
{{#each worktrees}}
{{#if rules_file}}
`{{root_name}}/{{rules_file.path_in_worktree}}`:
``````
{{{rules_file.text}}}
``````
{{/if}}
{{/each}}
{{/if}}
{{#if has_user_rules}}
The user has specified the following rules that should be applied:
{{#each user_rules}}
{{#if title}}
Rules title: {{title}}
{{/if}}
``````
{{contents}}}
``````
{{/each}}
{{/if}}
{{/if}}

View File

@@ -1,30 +1,34 @@
use super::*;
use crate::templates::Templates;
use acp_thread::AgentConnection;
use agent_client_protocol as acp;
use agent_client_protocol::{self as acp};
use anyhow::Result;
use assistant_tool::ActionLog;
use client::{Client, UserStore};
use fs::FakeFs;
use futures::channel::mpsc::UnboundedReceiver;
use gpui::{http_client::FakeHttpClient, AppContext, Entity, Task, TestAppContext};
use indoc::indoc;
use language_model::{
fake_provider::FakeLanguageModel, LanguageModel, LanguageModelCompletionError,
LanguageModelCompletionEvent, LanguageModelId, LanguageModelRegistry, MessageContent,
StopReason,
LanguageModelCompletionEvent, LanguageModelId, LanguageModelRegistry, LanguageModelToolResult,
LanguageModelToolUse, MessageContent, Role, StopReason,
};
use project::Project;
use prompt_store::ProjectContext;
use reqwest_client::ReqwestClient;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use serde_json::json;
use smol::stream::StreamExt;
use std::{path::Path, rc::Rc, sync::Arc, time::Duration};
use std::{cell::RefCell, path::Path, rc::Rc, sync::Arc, time::Duration};
use util::path;
mod test_tools;
use test_tools::*;
#[gpui::test]
#[ignore = "temporarily disabled until it can be run on CI"]
#[ignore = "can't run on CI yet"]
async fn test_echo(cx: &mut TestAppContext) {
let ThreadTest { model, thread, .. } = setup(cx, TestModel::Sonnet4).await;
@@ -44,7 +48,7 @@ async fn test_echo(cx: &mut TestAppContext) {
}
#[gpui::test]
#[ignore = "temporarily disabled until it can be run on CI"]
#[ignore = "can't run on CI yet"]
async fn test_thinking(cx: &mut TestAppContext) {
let ThreadTest { model, thread, .. } = setup(cx, TestModel::Sonnet4Thinking).await;
@@ -77,7 +81,46 @@ async fn test_thinking(cx: &mut TestAppContext) {
}
#[gpui::test]
#[ignore = "temporarily disabled until it can be run on CI"]
async fn test_system_prompt(cx: &mut TestAppContext) {
let ThreadTest {
model,
thread,
project_context,
..
} = setup(cx, TestModel::Fake).await;
let fake_model = model.as_fake();
project_context.borrow_mut().shell = "test-shell".into();
thread.update(cx, |thread, _| thread.add_tool(EchoTool));
thread.update(cx, |thread, cx| thread.send(model.clone(), "abc", cx));
cx.run_until_parked();
let mut pending_completions = fake_model.pending_completions();
assert_eq!(
pending_completions.len(),
1,
"unexpected pending completions: {:?}",
pending_completions
);
let pending_completion = pending_completions.pop().unwrap();
assert_eq!(pending_completion.messages[0].role, Role::System);
let system_message = &pending_completion.messages[0];
let system_prompt = system_message.content[0].to_str().unwrap();
assert!(
system_prompt.contains("test-shell"),
"unexpected system message: {:?}",
system_message
);
assert!(
system_prompt.contains("## Fixing Diagnostics"),
"unexpected system message: {:?}",
system_message
);
}
#[gpui::test]
#[ignore = "can't run on CI yet"]
async fn test_basic_tool_calls(cx: &mut TestAppContext) {
let ThreadTest { model, thread, .. } = setup(cx, TestModel::Sonnet4).await;
@@ -127,7 +170,7 @@ async fn test_basic_tool_calls(cx: &mut TestAppContext) {
}
#[gpui::test]
#[ignore = "temporarily disabled until it can be run on CI"]
#[ignore = "can't run on CI yet"]
async fn test_streaming_tool_calls(cx: &mut TestAppContext) {
let ThreadTest { model, thread, .. } = setup(cx, TestModel::Sonnet4).await;
@@ -175,7 +218,161 @@ async fn test_streaming_tool_calls(cx: &mut TestAppContext) {
}
#[gpui::test]
#[ignore = "temporarily disabled until it can be run on CI"]
async fn test_tool_authorization(cx: &mut TestAppContext) {
let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
let fake_model = model.as_fake();
let mut events = thread.update(cx, |thread, cx| {
thread.add_tool(ToolRequiringPermission);
thread.send(model.clone(), "abc", cx)
});
cx.run_until_parked();
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
LanguageModelToolUse {
id: "tool_id_1".into(),
name: ToolRequiringPermission.name().into(),
raw_input: "{}".into(),
input: json!({}),
is_input_complete: true,
},
));
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
LanguageModelToolUse {
id: "tool_id_2".into(),
name: ToolRequiringPermission.name().into(),
raw_input: "{}".into(),
input: json!({}),
is_input_complete: true,
},
));
fake_model.end_last_completion_stream();
let tool_call_auth_1 = next_tool_call_authorization(&mut events).await;
let tool_call_auth_2 = next_tool_call_authorization(&mut events).await;
// Approve the first
tool_call_auth_1
.response
.send(tool_call_auth_1.options[1].id.clone())
.unwrap();
cx.run_until_parked();
// Reject the second
tool_call_auth_2
.response
.send(tool_call_auth_1.options[2].id.clone())
.unwrap();
cx.run_until_parked();
let completion = fake_model.pending_completions().pop().unwrap();
let message = completion.messages.last().unwrap();
assert_eq!(
message.content,
vec![
MessageContent::ToolResult(LanguageModelToolResult {
tool_use_id: tool_call_auth_1.tool_call.id.0.to_string().into(),
tool_name: ToolRequiringPermission.name().into(),
is_error: false,
content: "Allowed".into(),
output: None
}),
MessageContent::ToolResult(LanguageModelToolResult {
tool_use_id: tool_call_auth_2.tool_call.id.0.to_string().into(),
tool_name: ToolRequiringPermission.name().into(),
is_error: true,
content: "Permission to run tool denied by user".into(),
output: None
})
]
);
}
#[gpui::test]
async fn test_tool_hallucination(cx: &mut TestAppContext) {
let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
let fake_model = model.as_fake();
let mut events = thread.update(cx, |thread, cx| thread.send(model.clone(), "abc", cx));
cx.run_until_parked();
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
LanguageModelToolUse {
id: "tool_id_1".into(),
name: "nonexistent_tool".into(),
raw_input: "{}".into(),
input: json!({}),
is_input_complete: true,
},
));
fake_model.end_last_completion_stream();
let tool_call = expect_tool_call(&mut events).await;
assert_eq!(tool_call.title, "nonexistent_tool");
assert_eq!(tool_call.status, acp::ToolCallStatus::Pending);
let update = expect_tool_call_update(&mut events).await;
assert_eq!(update.fields.status, Some(acp::ToolCallStatus::Failed));
}
async fn expect_tool_call(
events: &mut UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>>,
) -> acp::ToolCall {
let event = events
.next()
.await
.expect("no tool call authorization event received")
.unwrap();
match event {
AgentResponseEvent::ToolCall(tool_call) => return tool_call,
event => {
panic!("Unexpected event {event:?}");
}
}
}
async fn expect_tool_call_update(
events: &mut UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>>,
) -> acp::ToolCallUpdate {
let event = events
.next()
.await
.expect("no tool call authorization event received")
.unwrap();
match event {
AgentResponseEvent::ToolCallUpdate(tool_call_update) => return tool_call_update,
event => {
panic!("Unexpected event {event:?}");
}
}
}
async fn next_tool_call_authorization(
events: &mut UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>>,
) -> ToolCallAuthorization {
loop {
let event = events
.next()
.await
.expect("no tool call authorization event received")
.unwrap();
if let AgentResponseEvent::ToolCallAuthorization(tool_call_authorization) = event {
let permission_kinds = tool_call_authorization
.options
.iter()
.map(|o| o.kind)
.collect::<Vec<_>>();
assert_eq!(
permission_kinds,
vec![
acp::PermissionOptionKind::AllowAlways,
acp::PermissionOptionKind::AllowOnce,
acp::PermissionOptionKind::RejectOnce,
]
);
return tool_call_authorization;
}
}
}
#[gpui::test]
#[ignore = "can't run on CI yet"]
async fn test_concurrent_tool_calls(cx: &mut TestAppContext) {
let ThreadTest { model, thread, .. } = setup(cx, TestModel::Sonnet4).await;
@@ -214,7 +411,7 @@ async fn test_concurrent_tool_calls(cx: &mut TestAppContext) {
}
#[gpui::test]
#[ignore = "temporarily disabled until it can be run on CI"]
#[ignore = "can't run on CI yet"]
async fn test_cancellation(cx: &mut TestAppContext) {
let ThreadTest { model, thread, .. } = setup(cx, TestModel::Sonnet4).await;
@@ -281,12 +478,10 @@ async fn test_cancellation(cx: &mut TestAppContext) {
#[gpui::test]
async fn test_refusal(cx: &mut TestAppContext) {
let fake_model = Arc::new(FakeLanguageModel::default());
let ThreadTest { thread, .. } = setup(cx, TestModel::Fake(fake_model.clone())).await;
let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
let fake_model = model.as_fake();
let events = thread.update(cx, |thread, cx| {
thread.send(fake_model.clone(), "Hello", cx)
});
let events = thread.update(cx, |thread, cx| thread.send(model.clone(), "Hello", cx));
cx.run_until_parked();
thread.read_with(cx, |thread, _| {
assert_eq!(
@@ -343,8 +538,16 @@ async fn test_agent_connection(cx: &mut TestAppContext) {
});
cx.executor().forbid_parking();
// Create a project for new_thread
let fake_fs = cx.update(|cx| fs::FakeFs::new(cx.background_executor().clone()));
fake_fs.insert_tree(path!("/test"), json!({})).await;
let project = Project::test(fake_fs, [Path::new("/test")], cx).await;
let cwd = Path::new("/test");
// Create agent and connection
let agent = cx.new(|_| NativeAgent::new(templates.clone()));
let agent = NativeAgent::new(project.clone(), templates.clone(), None, &mut cx.to_async())
.await
.unwrap();
let connection = NativeAgentConnection(agent.clone());
// Test model_selector returns Some
@@ -366,12 +569,7 @@ async fn test_agent_connection(cx: &mut TestAppContext) {
assert!(!listed_models.is_empty(), "should have at least one model");
assert_eq!(listed_models[0].id().0, "fake");
// Create a project for new_thread
let fake_fs = cx.update(|cx| fs::FakeFs::new(cx.background_executor().clone()));
let project = Project::test(fake_fs, [Path::new("/test")], cx).await;
// Create a thread using new_thread
let cwd = Path::new("/test");
let connection_rc = Rc::new(connection.clone());
let acp_thread = cx
.update(|cx| {
@@ -441,6 +639,77 @@ async fn test_agent_connection(cx: &mut TestAppContext) {
);
}
#[gpui::test]
async fn test_tool_updates_to_completion(cx: &mut TestAppContext) {
let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
thread.update(cx, |thread, _cx| thread.add_tool(ThinkingTool));
let fake_model = model.as_fake();
let mut events = thread.update(cx, |thread, cx| thread.send(model.clone(), "Think", cx));
cx.run_until_parked();
let input = json!({ "content": "Thinking hard!" });
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
LanguageModelToolUse {
id: "1".into(),
name: ThinkingTool.name().into(),
raw_input: input.to_string(),
input,
is_input_complete: true,
},
));
fake_model.end_last_completion_stream();
cx.run_until_parked();
let tool_call = expect_tool_call(&mut events).await;
assert_eq!(
tool_call,
acp::ToolCall {
id: acp::ToolCallId("1".into()),
title: "Thinking".into(),
kind: acp::ToolKind::Think,
status: acp::ToolCallStatus::Pending,
content: vec![],
locations: vec![],
raw_input: Some(json!({ "content": "Thinking hard!" })),
raw_output: None,
}
);
let update = expect_tool_call_update(&mut events).await;
assert_eq!(
update,
acp::ToolCallUpdate {
id: acp::ToolCallId("1".into()),
fields: acp::ToolCallUpdateFields {
status: Some(acp::ToolCallStatus::InProgress,),
..Default::default()
},
}
);
let update = expect_tool_call_update(&mut events).await;
assert_eq!(
update,
acp::ToolCallUpdate {
id: acp::ToolCallId("1".into()),
fields: acp::ToolCallUpdateFields {
content: Some(vec!["Thinking hard!".into()]),
..Default::default()
},
}
);
let update = expect_tool_call_update(&mut events).await;
assert_eq!(
update,
acp::ToolCallUpdate {
id: acp::ToolCallId("1".into()),
fields: acp::ToolCallUpdateFields {
status: Some(acp::ToolCallStatus::Completed),
..Default::default()
},
}
);
}
/// Filters out the stop events for asserting against in tests
fn stop_events(
result_events: Vec<Result<AgentResponseEvent, LanguageModelCompletionError>>,
@@ -457,12 +726,13 @@ fn stop_events(
struct ThreadTest {
model: Arc<dyn LanguageModel>,
thread: Entity<Thread>,
project_context: Rc<RefCell<ProjectContext>>,
}
enum TestModel {
Sonnet4,
Sonnet4Thinking,
Fake(Arc<FakeLanguageModel>),
Fake,
}
impl TestModel {
@@ -470,7 +740,7 @@ impl TestModel {
match self {
TestModel::Sonnet4 => LanguageModelId("claude-sonnet-4-latest".into()),
TestModel::Sonnet4Thinking => LanguageModelId("claude-sonnet-4-thinking-latest".into()),
TestModel::Fake(fake_model) => fake_model.id(),
TestModel::Fake => unreachable!(),
}
}
}
@@ -499,8 +769,8 @@ async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest {
language_model::init(client.clone(), cx);
language_models::init(user_store.clone(), client.clone(), cx);
if let TestModel::Fake(model) = model {
Task::ready(model as Arc<_>)
if let TestModel::Fake = model {
Task::ready(Arc::new(FakeLanguageModel::default()) as Arc<_>)
} else {
let model_id = model.id();
let models = LanguageModelRegistry::read_global(cx);
@@ -520,9 +790,22 @@ async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest {
})
.await;
let thread = cx.new(|_| Thread::new(project, templates, model.clone()));
ThreadTest { model, thread }
let project_context = Rc::new(RefCell::new(ProjectContext::default()));
let action_log = cx.new(|_| ActionLog::new(project.clone()));
let thread = cx.new(|_| {
Thread::new(
project,
project_context.clone(),
action_log,
templates,
model.clone(),
)
});
ThreadTest {
model,
thread,
project_context,
}
}
#[cfg(test)]

View File

@@ -19,7 +19,20 @@ impl AgentTool for EchoTool {
"echo".into()
}
fn run(self: Arc<Self>, input: Self::Input, _cx: &mut App) -> Task<Result<String>> {
fn kind(&self) -> acp::ToolKind {
acp::ToolKind::Other
}
fn initial_title(&self, _: Self::Input) -> SharedString {
"Echo".into()
}
fn run(
self: Arc<Self>,
input: Self::Input,
_event_stream: ToolCallEventStream,
_cx: &mut App,
) -> Task<Result<String>> {
Task::ready(Ok(input.text))
}
}
@@ -40,7 +53,20 @@ impl AgentTool for DelayTool {
"delay".into()
}
fn run(self: Arc<Self>, input: Self::Input, cx: &mut App) -> Task<Result<String>>
fn initial_title(&self, input: Self::Input) -> SharedString {
format!("Delay {}ms", input.ms).into()
}
fn kind(&self) -> acp::ToolKind {
acp::ToolKind::Other
}
fn run(
self: Arc<Self>,
input: Self::Input,
_event_stream: ToolCallEventStream,
cx: &mut App,
) -> Task<Result<String>>
where
Self: Sized,
{
@@ -51,6 +77,43 @@ impl AgentTool for DelayTool {
}
}
#[derive(JsonSchema, Serialize, Deserialize)]
pub struct ToolRequiringPermissionInput {}
pub struct ToolRequiringPermission;
impl AgentTool for ToolRequiringPermission {
type Input = ToolRequiringPermissionInput;
fn name(&self) -> SharedString {
"tool_requiring_permission".into()
}
fn kind(&self) -> acp::ToolKind {
acp::ToolKind::Other
}
fn initial_title(&self, _input: Self::Input) -> SharedString {
"This tool requires permission".into()
}
fn run(
self: Arc<Self>,
input: Self::Input,
event_stream: ToolCallEventStream,
cx: &mut App,
) -> Task<Result<String>>
where
Self: Sized,
{
let auth_check = self.authorize(input, event_stream);
cx.foreground_executor().spawn(async move {
auth_check.await?;
Ok("Allowed".to_string())
})
}
}
#[derive(JsonSchema, Serialize, Deserialize)]
pub struct InfiniteToolInput {}
@@ -63,7 +126,20 @@ impl AgentTool for InfiniteTool {
"infinite".into()
}
fn run(self: Arc<Self>, _input: Self::Input, cx: &mut App) -> Task<Result<String>> {
fn kind(&self) -> acp::ToolKind {
acp::ToolKind::Other
}
fn initial_title(&self, _input: Self::Input) -> SharedString {
"This is the tool that never ends... it just goes on and on my friends!".into()
}
fn run(
self: Arc<Self>,
_input: Self::Input,
_event_stream: ToolCallEventStream,
cx: &mut App,
) -> Task<Result<String>> {
cx.foreground_executor().spawn(async move {
future::pending::<()>().await;
unreachable!()
@@ -100,7 +176,20 @@ impl AgentTool for WordListTool {
"word_list".into()
}
fn run(self: Arc<Self>, _input: Self::Input, _cx: &mut App) -> Task<Result<String>> {
fn initial_title(&self, _input: Self::Input) -> SharedString {
"List of random words".into()
}
fn kind(&self) -> acp::ToolKind {
acp::ToolKind::Other
}
fn run(
self: Arc<Self>,
_input: Self::Input,
_event_stream: ToolCallEventStream,
_cx: &mut App,
) -> Task<Result<String>> {
Task::ready(Ok("ok".to_string()))
}
}

View File

@@ -1,22 +1,27 @@
use crate::{prompts::BasePrompt, templates::Templates};
use crate::templates::{SystemPromptTemplate, Template, Templates};
use agent_client_protocol as acp;
use anyhow::{anyhow, Result};
use anyhow::{anyhow, Context as _, Result};
use assistant_tool::{adapt_schema_to_format, ActionLog};
use cloud_llm_client::{CompletionIntent, CompletionMode};
use collections::HashMap;
use futures::{channel::mpsc, stream::FuturesUnordered};
use gpui::{App, Context, Entity, ImageFormat, SharedString, Task};
use futures::{
channel::{mpsc, oneshot},
stream::FuturesUnordered,
};
use gpui::{App, Context, Entity, SharedString, Task};
use language_model::{
LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelImage,
LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
LanguageModelRequest, LanguageModelRequestMessage, LanguageModelRequestTool,
LanguageModelToolResult, LanguageModelToolResultContent, LanguageModelToolSchemaFormat,
LanguageModelToolUse, LanguageModelToolUseId, MessageContent, Role, StopReason,
};
use log;
use project::Project;
use prompt_store::ProjectContext;
use schemars::{JsonSchema, Schema};
use serde::Deserialize;
use serde::{Deserialize, Serialize};
use smol::stream::StreamExt;
use std::{collections::BTreeMap, fmt::Write, sync::Arc};
use std::{cell::RefCell, collections::BTreeMap, fmt::Write, future::Future, rc::Rc, sync::Arc};
use util::{markdown::MarkdownCodeBlock, ResultExt};
#[derive(Debug, Clone)]
@@ -97,11 +102,15 @@ pub enum AgentResponseEvent {
Thinking(String),
ToolCall(acp::ToolCall),
ToolCallUpdate(acp::ToolCallUpdate),
ToolCallAuthorization(ToolCallAuthorization),
Stop(acp::StopReason),
}
pub trait Prompt {
fn render(&self, prompts: &Templates, cx: &App) -> Result<String>;
#[derive(Debug)]
pub struct ToolCallAuthorization {
pub tool_call: acp::ToolCall,
pub options: Vec<acp::PermissionOption>,
pub response: oneshot::Sender<acp::PermissionOptionId>,
}
pub struct Thread {
@@ -112,28 +121,31 @@ pub struct Thread {
/// we run tools, report their results.
running_turn: Option<Task<()>>,
pending_tool_uses: HashMap<LanguageModelToolUseId, LanguageModelToolUse>,
system_prompts: Vec<Arc<dyn Prompt>>,
tools: BTreeMap<SharedString, Arc<dyn AnyAgentTool>>,
project_context: Rc<RefCell<ProjectContext>>,
templates: Arc<Templates>,
pub selected_model: Arc<dyn LanguageModel>,
// action_log: Entity<ActionLog>,
action_log: Entity<ActionLog>,
}
impl Thread {
pub fn new(
project: Entity<Project>,
_project: Entity<Project>,
project_context: Rc<RefCell<ProjectContext>>,
action_log: Entity<ActionLog>,
templates: Arc<Templates>,
default_model: Arc<dyn LanguageModel>,
) -> Self {
Self {
messages: Vec::new(),
completion_mode: CompletionMode::Normal,
system_prompts: vec![Arc::new(BasePrompt::new(project))],
running_turn: None,
pending_tool_uses: HashMap::default(),
tools: BTreeMap::default(),
project_context,
templates,
selected_model: default_model,
action_log,
}
}
@@ -188,6 +200,7 @@ impl Thread {
cx.notify();
let (events_tx, events_rx) =
mpsc::unbounded::<Result<AgentResponseEvent, LanguageModelCompletionError>>();
let event_stream = AgentResponseEventStream(events_tx);
let user_message_ix = self.messages.len();
self.messages.push(AgentMessage {
@@ -222,12 +235,7 @@ impl Thread {
while let Some(event) = events.next().await {
match event {
Ok(LanguageModelCompletionEvent::Stop(reason)) => {
if let Some(reason) = to_acp_stop_reason(reason) {
events_tx
.unbounded_send(Ok(AgentResponseEvent::Stop(reason)))
.ok();
}
event_stream.send_stop(reason);
if reason == StopReason::Refusal {
thread.update(cx, |thread, _cx| {
thread.messages.truncate(user_message_ix);
@@ -240,14 +248,16 @@ impl Thread {
thread
.update(cx, |thread, cx| {
tool_uses.extend(thread.handle_streamed_completion_event(
event, &events_tx, cx,
event,
&event_stream,
cx,
));
})
.ok();
}
Err(error) => {
log::error!("Error in completion stream: {:?}", error);
events_tx.unbounded_send(Err(error)).ok();
event_stream.send_error(error);
break;
}
}
@@ -266,11 +276,17 @@ impl Thread {
while let Some(tool_result) = tool_uses.next().await {
log::info!("Tool finished {:?}", tool_result);
events_tx
.unbounded_send(Ok(AgentResponseEvent::ToolCallUpdate(
to_acp_tool_call_update(&tool_result),
)))
.ok();
event_stream.send_tool_call_update(
&tool_result.tool_use_id,
acp::ToolCallUpdateFields {
status: Some(if tool_result.is_error {
acp::ToolCallStatus::Failed
} else {
acp::ToolCallStatus::Completed
}),
..Default::default()
},
);
thread
.update(cx, |thread, _cx| {
thread.pending_tool_uses.remove(&tool_result.tool_use_id);
@@ -291,7 +307,7 @@ impl Thread {
if let Err(error) = turn_result {
log::error!("Turn execution failed: {:?}", error);
events_tx.unbounded_send(Err(error)).ok();
event_stream.send_error(error);
} else {
log::info!("Turn execution completed successfully");
}
@@ -299,24 +315,24 @@ impl Thread {
events_rx
}
pub fn build_system_message(&self, cx: &App) -> Option<AgentMessage> {
pub fn action_log(&self) -> &Entity<ActionLog> {
&self.action_log
}
pub fn build_system_message(&self) -> AgentMessage {
log::debug!("Building system message");
let mut system_message = AgentMessage {
role: Role::System,
content: Vec::new(),
};
for prompt in &self.system_prompts {
if let Some(rendered_prompt) = prompt.render(&self.templates, cx).log_err() {
system_message
.content
.push(MessageContent::Text(rendered_prompt));
}
let prompt = SystemPromptTemplate {
project: &self.project_context.borrow(),
available_tools: self.tools.keys().cloned().collect(),
}
.render(&self.templates)
.context("failed to build system prompt")
.expect("Invalid template");
log::debug!("System message built");
AgentMessage {
role: Role::System,
content: vec![prompt.into()],
}
let result = (!system_message.content.is_empty()).then_some(system_message);
log::debug!("System message built: {}", result.is_some());
result
}
/// A helper method that's called on every streamed completion event.
@@ -325,7 +341,7 @@ impl Thread {
fn handle_streamed_completion_event(
&mut self,
event: LanguageModelCompletionEvent,
events_tx: &mpsc::UnboundedSender<Result<AgentResponseEvent, LanguageModelCompletionError>>,
event_stream: &AgentResponseEventStream,
cx: &mut Context<Self>,
) -> Option<Task<LanguageModelToolResult>> {
log::trace!("Handling streamed completion event: {:?}", event);
@@ -338,13 +354,13 @@ impl Thread {
content: Vec::new(),
});
}
Text(new_text) => self.handle_text_event(new_text, events_tx, cx),
Text(new_text) => self.handle_text_event(new_text, event_stream, cx),
Thinking { text, signature } => {
self.handle_thinking_event(text, signature, events_tx, cx)
self.handle_thinking_event(text, signature, event_stream, cx)
}
RedactedThinking { data } => self.handle_redacted_thinking_event(data, cx),
ToolUse(tool_use) => {
return self.handle_tool_use_event(tool_use, events_tx, cx);
return self.handle_tool_use_event(tool_use, event_stream, cx);
}
ToolUseJsonParseError {
id,
@@ -369,12 +385,10 @@ impl Thread {
fn handle_text_event(
&mut self,
new_text: String,
events_tx: &mpsc::UnboundedSender<Result<AgentResponseEvent, LanguageModelCompletionError>>,
events_stream: &AgentResponseEventStream,
cx: &mut Context<Self>,
) {
events_tx
.unbounded_send(Ok(AgentResponseEvent::Text(new_text.clone())))
.ok();
events_stream.send_text(&new_text);
let last_message = self.last_assistant_message();
if let Some(MessageContent::Text(text)) = last_message.content.last_mut() {
@@ -390,12 +404,10 @@ impl Thread {
&mut self,
new_text: String,
new_signature: Option<String>,
events_tx: &mpsc::UnboundedSender<Result<AgentResponseEvent, LanguageModelCompletionError>>,
event_stream: &AgentResponseEventStream,
cx: &mut Context<Self>,
) {
events_tx
.unbounded_send(Ok(AgentResponseEvent::Thinking(new_text.clone())))
.ok();
event_stream.send_thinking(&new_text);
let last_message = self.last_assistant_message();
if let Some(MessageContent::Thinking { text, signature }) = last_message.content.last_mut()
@@ -423,11 +435,13 @@ impl Thread {
fn handle_tool_use_event(
&mut self,
tool_use: LanguageModelToolUse,
events_tx: &mpsc::UnboundedSender<Result<AgentResponseEvent, LanguageModelCompletionError>>,
event_stream: &AgentResponseEventStream,
cx: &mut Context<Self>,
) -> Option<Task<LanguageModelToolResult>> {
cx.notify();
let tool = self.tools.get(tool_use.name.as_ref()).cloned();
self.pending_tool_uses
.insert(tool_use.id.clone(), tool_use.clone());
let last_message = self.last_assistant_message();
@@ -445,82 +459,74 @@ impl Thread {
true
}
});
if push_new_tool_use {
events_tx
.unbounded_send(Ok(AgentResponseEvent::ToolCall(acp::ToolCall {
id: acp::ToolCallId(tool_use.id.to_string().into()),
title: tool_use.name.to_string(),
kind: acp::ToolKind::Other,
status: acp::ToolCallStatus::Pending,
content: vec![],
locations: vec![],
raw_input: Some(tool_use.input.clone()),
})))
.ok();
event_stream.send_tool_call(tool.as_ref(), &tool_use);
last_message
.content
.push(MessageContent::ToolUse(tool_use.clone()));
} else {
events_tx
.unbounded_send(Ok(AgentResponseEvent::ToolCallUpdate(
acp::ToolCallUpdate {
id: acp::ToolCallId(tool_use.id.to_string().into()),
fields: acp::ToolCallUpdateFields {
raw_input: Some(tool_use.input.clone()),
..Default::default()
},
},
)))
.ok();
event_stream.send_tool_call_update(
&tool_use.id,
acp::ToolCallUpdateFields {
raw_input: Some(tool_use.input.clone()),
..Default::default()
},
);
}
if !tool_use.is_input_complete {
return None;
}
if let Some(tool) = self.tools.get(tool_use.name.as_ref()) {
events_tx
.unbounded_send(Ok(AgentResponseEvent::ToolCallUpdate(
acp::ToolCallUpdate {
id: acp::ToolCallId(tool_use.id.to_string().into()),
fields: acp::ToolCallUpdateFields {
status: Some(acp::ToolCallStatus::InProgress),
..Default::default()
},
},
)))
.ok();
let pending_tool_result = tool.clone().run(tool_use.input, cx);
Some(cx.foreground_executor().spawn(async move {
match pending_tool_result.await {
Ok(tool_output) => LanguageModelToolResult {
tool_use_id: tool_use.id,
tool_name: tool_use.name,
is_error: false,
content: LanguageModelToolResultContent::Text(Arc::from(tool_output)),
output: None,
},
Err(error) => LanguageModelToolResult {
tool_use_id: tool_use.id,
tool_name: tool_use.name,
is_error: true,
content: LanguageModelToolResultContent::Text(Arc::from(error.to_string())),
output: None,
},
}
}))
} else {
let Some(tool) = tool else {
let content = format!("No tool named {} exists", tool_use.name);
Some(Task::ready(LanguageModelToolResult {
return Some(Task::ready(LanguageModelToolResult {
content: LanguageModelToolResultContent::Text(Arc::from(content)),
tool_use_id: tool_use.id,
tool_name: tool_use.name,
is_error: true,
output: None,
}))
}
}));
};
let tool_result = self.run_tool(tool, tool_use.clone(), event_stream.clone(), cx);
Some(cx.foreground_executor().spawn(async move {
match tool_result.await {
Ok(tool_output) => LanguageModelToolResult {
tool_use_id: tool_use.id,
tool_name: tool_use.name,
is_error: false,
content: LanguageModelToolResultContent::Text(Arc::from(tool_output)),
output: None,
},
Err(error) => LanguageModelToolResult {
tool_use_id: tool_use.id,
tool_name: tool_use.name,
is_error: true,
content: LanguageModelToolResultContent::Text(Arc::from(error.to_string())),
output: None,
},
}
}))
}
fn run_tool(
&self,
tool: Arc<dyn AnyAgentTool>,
tool_use: LanguageModelToolUse,
event_stream: AgentResponseEventStream,
cx: &mut Context<Self>,
) -> Task<Result<String>> {
cx.spawn(async move |_this, cx| {
let tool_event_stream = ToolCallEventStream::new(tool_use.id, event_stream);
tool_event_stream.send_update(acp::ToolCallUpdateFields {
status: Some(acp::ToolCallStatus::InProgress),
..Default::default()
});
cx.update(|cx| tool.run(tool_use.input, tool_event_stream, cx))?
.await
})
}
fn handle_tool_use_json_parse_error_event(
@@ -575,7 +581,7 @@ impl Thread {
log::debug!("Completion intent: {:?}", completion_intent);
log::debug!("Completion mode: {:?}", self.completion_mode);
let messages = self.build_request_messages(cx);
let messages = self.build_request_messages();
log::info!("Request will include {} messages", messages.len());
let tools: Vec<LanguageModelRequestTool> = self
@@ -588,7 +594,7 @@ impl Thread {
name: tool_name,
description: tool.description(cx).to_string(),
input_schema: tool
.input_schema(LanguageModelToolSchemaFormat::JsonSchema)
.input_schema(self.selected_model.tool_input_format())
.log_err()?,
})
})
@@ -613,14 +619,13 @@ impl Thread {
request
}
fn build_request_messages(&self, cx: &App) -> Vec<LanguageModelRequestMessage> {
fn build_request_messages(&self) -> Vec<LanguageModelRequestMessage> {
log::trace!(
"Building request messages from {} thread messages",
self.messages.len()
);
let messages = self
.build_system_message(cx)
let messages = Some(self.build_system_message())
.iter()
.chain(self.messages.iter())
.map(|message| {
@@ -656,9 +661,10 @@ pub trait AgentTool
where
Self: 'static + Sized,
{
type Input: for<'de> Deserialize<'de> + JsonSchema;
type Input: for<'de> Deserialize<'de> + Serialize + JsonSchema;
fn name(&self) -> SharedString;
fn description(&self, _cx: &mut App) -> SharedString {
let schema = schemars::schema_for!(Self::Input);
SharedString::new(
@@ -669,13 +675,33 @@ where
)
}
fn kind(&self) -> acp::ToolKind;
/// The initial tool title to display. Can be updated during the tool run.
fn initial_title(&self, input: Self::Input) -> SharedString;
/// Returns the JSON schema that describes the tool's input.
fn input_schema(&self, _format: LanguageModelToolSchemaFormat) -> Schema {
fn input_schema(&self) -> Schema {
schemars::schema_for!(Self::Input)
}
/// Allows the tool to authorize a given tool call with the user if necessary
fn authorize(
&self,
input: Self::Input,
event_stream: ToolCallEventStream,
) -> impl use<Self> + Future<Output = Result<()>> {
let json_input = serde_json::json!(&input);
event_stream.authorize(self.initial_title(input).into(), self.kind(), json_input)
}
/// Runs the tool with the provided input.
fn run(self: Arc<Self>, input: Self::Input, cx: &mut App) -> Task<Result<String>>;
fn run(
self: Arc<Self>,
input: Self::Input,
event_stream: ToolCallEventStream,
cx: &mut App,
) -> Task<Result<String>>;
fn erase(self) -> Arc<dyn AnyAgentTool> {
Arc::new(Erased(Arc::new(self)))
@@ -687,8 +713,15 @@ pub struct Erased<T>(T);
pub trait AnyAgentTool {
fn name(&self) -> SharedString;
fn description(&self, cx: &mut App) -> SharedString;
fn kind(&self) -> acp::ToolKind;
fn initial_title(&self, input: serde_json::Value) -> Result<SharedString>;
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value>;
fn run(self: Arc<Self>, input: serde_json::Value, cx: &mut App) -> Task<Result<String>>;
fn run(
self: Arc<Self>,
input: serde_json::Value,
event_stream: ToolCallEventStream,
cx: &mut App,
) -> Task<Result<String>>;
}
impl<T> AnyAgentTool for Erased<Arc<T>>
@@ -703,52 +736,220 @@ where
self.0.description(cx)
}
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
Ok(serde_json::to_value(self.0.input_schema(format))?)
fn kind(&self) -> agent_client_protocol::ToolKind {
self.0.kind()
}
fn run(self: Arc<Self>, input: serde_json::Value, cx: &mut App) -> Task<Result<String>> {
fn initial_title(&self, input: serde_json::Value) -> Result<SharedString> {
let parsed_input = serde_json::from_value(input)?;
Ok(self.0.initial_title(parsed_input))
}
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
let mut json = serde_json::to_value(self.0.input_schema())?;
adapt_schema_to_format(&mut json, format)?;
Ok(json)
}
fn run(
self: Arc<Self>,
input: serde_json::Value,
event_stream: ToolCallEventStream,
cx: &mut App,
) -> Task<Result<String>> {
let parsed_input: Result<T::Input> = serde_json::from_value(input).map_err(Into::into);
match parsed_input {
Ok(input) => self.0.clone().run(input, cx),
Ok(input) => self.0.clone().run(input, event_stream, cx),
Err(error) => Task::ready(Err(anyhow!(error))),
}
}
}
fn to_acp_stop_reason(reason: StopReason) -> Option<acp::StopReason> {
match reason {
StopReason::EndTurn => Some(acp::StopReason::EndTurn),
StopReason::MaxTokens => Some(acp::StopReason::MaxTokens),
StopReason::Refusal => Some(acp::StopReason::Refusal),
StopReason::ToolUse => None,
#[derive(Clone)]
struct AgentResponseEventStream(
mpsc::UnboundedSender<Result<AgentResponseEvent, LanguageModelCompletionError>>,
);
impl AgentResponseEventStream {
fn send_text(&self, text: &str) {
self.0
.unbounded_send(Ok(AgentResponseEvent::Text(text.to_string())))
.ok();
}
fn send_thinking(&self, text: &str) {
self.0
.unbounded_send(Ok(AgentResponseEvent::Thinking(text.to_string())))
.ok();
}
fn authorize_tool_call(
&self,
id: &LanguageModelToolUseId,
title: String,
kind: acp::ToolKind,
input: serde_json::Value,
) -> impl use<> + Future<Output = Result<()>> {
let (response_tx, response_rx) = oneshot::channel();
self.0
.unbounded_send(Ok(AgentResponseEvent::ToolCallAuthorization(
ToolCallAuthorization {
tool_call: Self::initial_tool_call(id, title, kind, input),
options: vec![
acp::PermissionOption {
id: acp::PermissionOptionId("always_allow".into()),
name: "Always Allow".into(),
kind: acp::PermissionOptionKind::AllowAlways,
},
acp::PermissionOption {
id: acp::PermissionOptionId("allow".into()),
name: "Allow".into(),
kind: acp::PermissionOptionKind::AllowOnce,
},
acp::PermissionOption {
id: acp::PermissionOptionId("deny".into()),
name: "Deny".into(),
kind: acp::PermissionOptionKind::RejectOnce,
},
],
response: response_tx,
},
)))
.ok();
async move {
match response_rx.await?.0.as_ref() {
"allow" | "always_allow" => Ok(()),
_ => Err(anyhow!("Permission to run tool denied by user")),
}
}
}
fn send_tool_call(
&self,
tool: Option<&Arc<dyn AnyAgentTool>>,
tool_use: &LanguageModelToolUse,
) {
self.0
.unbounded_send(Ok(AgentResponseEvent::ToolCall(Self::initial_tool_call(
&tool_use.id,
tool.and_then(|t| t.initial_title(tool_use.input.clone()).ok())
.map(|i| i.into())
.unwrap_or_else(|| tool_use.name.to_string()),
tool.map(|t| t.kind()).unwrap_or(acp::ToolKind::Other),
tool_use.input.clone(),
))))
.ok();
}
fn initial_tool_call(
id: &LanguageModelToolUseId,
title: String,
kind: acp::ToolKind,
input: serde_json::Value,
) -> acp::ToolCall {
acp::ToolCall {
id: acp::ToolCallId(id.to_string().into()),
title,
kind,
status: acp::ToolCallStatus::Pending,
content: vec![],
locations: vec![],
raw_input: Some(input),
raw_output: None,
}
}
fn send_tool_call_update(
&self,
tool_use_id: &LanguageModelToolUseId,
fields: acp::ToolCallUpdateFields,
) {
self.0
.unbounded_send(Ok(AgentResponseEvent::ToolCallUpdate(
acp::ToolCallUpdate {
id: acp::ToolCallId(tool_use_id.to_string().into()),
fields,
},
)))
.ok();
}
fn send_stop(&self, reason: StopReason) {
match reason {
StopReason::EndTurn => {
self.0
.unbounded_send(Ok(AgentResponseEvent::Stop(acp::StopReason::EndTurn)))
.ok();
}
StopReason::MaxTokens => {
self.0
.unbounded_send(Ok(AgentResponseEvent::Stop(acp::StopReason::MaxTokens)))
.ok();
}
StopReason::Refusal => {
self.0
.unbounded_send(Ok(AgentResponseEvent::Stop(acp::StopReason::Refusal)))
.ok();
}
StopReason::ToolUse => {}
}
}
fn send_error(&self, error: LanguageModelCompletionError) {
self.0.unbounded_send(Err(error)).ok();
}
}
fn to_acp_tool_call_update(tool_result: &LanguageModelToolResult) -> acp::ToolCallUpdate {
let status = if tool_result.is_error {
acp::ToolCallStatus::Failed
} else {
acp::ToolCallStatus::Completed
};
let content = match &tool_result.content {
LanguageModelToolResultContent::Text(text) => text.to_string().into(),
LanguageModelToolResultContent::Image(LanguageModelImage { source, .. }) => {
acp::ToolCallContent::Content {
content: acp::ContentBlock::Image(acp::ImageContent {
annotations: None,
data: source.to_string(),
mime_type: ImageFormat::Png.mime_type().to_string(),
}),
}
#[derive(Clone)]
pub struct ToolCallEventStream {
tool_use_id: LanguageModelToolUseId,
stream: AgentResponseEventStream,
}
impl ToolCallEventStream {
fn new(tool_use_id: LanguageModelToolUseId, stream: AgentResponseEventStream) -> Self {
Self {
tool_use_id,
stream,
}
};
acp::ToolCallUpdate {
id: acp::ToolCallId(tool_result.tool_use_id.to_string().into()),
fields: acp::ToolCallUpdateFields {
status: Some(status),
content: Some(vec![content]),
..Default::default()
},
}
pub fn send_update(&self, fields: acp::ToolCallUpdateFields) {
self.stream.send_tool_call_update(&self.tool_use_id, fields);
}
pub fn authorize(
&self,
title: String,
kind: acp::ToolKind,
input: serde_json::Value,
) -> impl use<> + Future<Output = Result<()>> {
self.stream
.authorize_tool_call(&self.tool_use_id, title, kind, input)
}
}
#[cfg(test)]
pub struct TestToolCallEventStream {
stream: ToolCallEventStream,
_events_rx: mpsc::UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>>,
}
#[cfg(test)]
impl TestToolCallEventStream {
pub fn new() -> Self {
let (events_tx, events_rx) =
mpsc::unbounded::<Result<AgentResponseEvent, LanguageModelCompletionError>>();
let stream = ToolCallEventStream::new("test".into(), AgentResponseEventStream(events_tx));
Self {
stream,
_events_rx: events_rx,
}
}
pub fn stream(&self) -> ToolCallEventStream {
self.stream.clone()
}
}

View File

@@ -1 +1,7 @@
mod glob;
mod find_path_tool;
mod read_file_tool;
mod thinking_tool;
pub use find_path_tool::*;
pub use read_file_tool::*;
pub use thinking_tool::*;

View File

@@ -0,0 +1,231 @@
use agent_client_protocol as acp;
use anyhow::{anyhow, Result};
use gpui::{App, AppContext, Entity, SharedString, Task};
use project::Project;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use std::fmt::Write;
use std::{cmp, path::PathBuf, sync::Arc};
use util::paths::PathMatcher;
use crate::{AgentTool, ToolCallEventStream};
/// Fast file path pattern matching tool that works with any codebase size
///
/// - Supports glob patterns like "**/*.js" or "src/**/*.ts"
/// - Returns matching file paths sorted alphabetically
/// - Prefer the `grep` tool to this tool when searching for symbols unless you have specific information about paths.
/// - Use this tool when you need to find files by name patterns
/// - Results are paginated with 50 matches per page. Use the optional 'offset' parameter to request subsequent pages.
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct FindPathToolInput {
/// The glob to match against every path in the project.
///
/// <example>
/// If the project has the following root directories:
///
/// - directory1/a/something.txt
/// - directory2/a/things.txt
/// - directory3/a/other.txt
///
/// You can get back the first two paths by providing a glob of "*thing*.txt"
/// </example>
pub glob: String,
/// Optional starting position for paginated results (0-based).
/// When not provided, starts from the beginning.
#[serde(default)]
pub offset: usize,
}
#[derive(Debug, Serialize, Deserialize)]
struct FindPathToolOutput {
paths: Vec<PathBuf>,
}
const RESULTS_PER_PAGE: usize = 50;
pub struct FindPathTool {
project: Entity<Project>,
}
impl FindPathTool {
pub fn new(project: Entity<Project>) -> Self {
Self { project }
}
}
impl AgentTool for FindPathTool {
type Input = FindPathToolInput;
fn name(&self) -> SharedString {
"find_path".into()
}
fn kind(&self) -> acp::ToolKind {
acp::ToolKind::Search
}
fn initial_title(&self, input: Self::Input) -> SharedString {
format!("Find paths matching “`{}`”", input.glob).into()
}
fn run(
self: Arc<Self>,
input: Self::Input,
event_stream: ToolCallEventStream,
cx: &mut App,
) -> Task<Result<String>> {
let search_paths_task = search_paths(&input.glob, self.project.clone(), cx);
cx.background_spawn(async move {
let matches = search_paths_task.await?;
let paginated_matches: &[PathBuf] = &matches[cmp::min(input.offset, matches.len())
..cmp::min(input.offset + RESULTS_PER_PAGE, matches.len())];
event_stream.send_update(acp::ToolCallUpdateFields {
title: Some(if paginated_matches.len() == 0 {
"No matches".into()
} else if paginated_matches.len() == 1 {
"1 match".into()
} else {
format!("{} matches", paginated_matches.len())
}),
content: Some(
paginated_matches
.iter()
.map(|path| acp::ToolCallContent::Content {
content: acp::ContentBlock::ResourceLink(acp::ResourceLink {
uri: format!("file://{}", path.display()),
name: path.to_string_lossy().into(),
annotations: None,
description: None,
mime_type: None,
size: None,
title: None,
}),
})
.collect(),
),
raw_output: Some(serde_json::json!({
"paths": &matches,
})),
..Default::default()
});
if matches.is_empty() {
Ok("No matches found".into())
} else {
let mut message = format!("Found {} total matches.", matches.len());
if matches.len() > RESULTS_PER_PAGE {
write!(
&mut message,
"\nShowing results {}-{} (provide 'offset' parameter for more results):",
input.offset + 1,
input.offset + paginated_matches.len()
)
.unwrap();
}
for mat in matches.iter().skip(input.offset).take(RESULTS_PER_PAGE) {
write!(&mut message, "\n{}", mat.display()).unwrap();
}
Ok(message)
}
})
}
}
fn search_paths(glob: &str, project: Entity<Project>, cx: &mut App) -> Task<Result<Vec<PathBuf>>> {
let path_matcher = match PathMatcher::new([
// Sometimes models try to search for "". In this case, return all paths in the project.
if glob.is_empty() { "*" } else { glob },
]) {
Ok(matcher) => matcher,
Err(err) => return Task::ready(Err(anyhow!("Invalid glob: {err}"))),
};
let snapshots: Vec<_> = project
.read(cx)
.worktrees(cx)
.map(|worktree| worktree.read(cx).snapshot())
.collect();
cx.background_spawn(async move {
Ok(snapshots
.iter()
.flat_map(|snapshot| {
let root_name = PathBuf::from(snapshot.root_name());
snapshot
.entries(false, 0)
.map(move |entry| root_name.join(&entry.path))
.filter(|path| path_matcher.is_match(&path))
})
.collect())
})
}
#[cfg(test)]
mod test {
use super::*;
use gpui::TestAppContext;
use project::{FakeFs, Project};
use settings::SettingsStore;
use util::path;
#[gpui::test]
async fn test_find_path_tool(cx: &mut TestAppContext) {
init_test(cx);
let fs = FakeFs::new(cx.executor());
fs.insert_tree(
"/root",
serde_json::json!({
"apple": {
"banana": {
"carrot": "1",
},
"bandana": {
"carbonara": "2",
},
"endive": "3"
}
}),
)
.await;
let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
let matches = cx
.update(|cx| search_paths("root/**/car*", project.clone(), cx))
.await
.unwrap();
assert_eq!(
matches,
&[
PathBuf::from("root/apple/banana/carrot"),
PathBuf::from("root/apple/bandana/carbonara")
]
);
let matches = cx
.update(|cx| search_paths("**/car*", project.clone(), cx))
.await
.unwrap();
assert_eq!(
matches,
&[
PathBuf::from("root/apple/banana/carrot"),
PathBuf::from("root/apple/bandana/carbonara")
]
);
}
fn init_test(cx: &mut TestAppContext) {
cx.update(|cx| {
let settings_store = SettingsStore::test(cx);
cx.set_global(settings_store);
language::init(cx);
Project::init_settings(cx);
});
}
}

View File

@@ -1,76 +0,0 @@
use anyhow::{anyhow, Result};
use gpui::{App, AppContext, Entity, SharedString, Task};
use project::Project;
use schemars::JsonSchema;
use serde::Deserialize;
use std::{path::PathBuf, sync::Arc};
use util::paths::PathMatcher;
use worktree::Snapshot as WorktreeSnapshot;
use crate::{
templates::{GlobTemplate, Template, Templates},
thread::AgentTool,
};
// Description is dynamic, see `fn description` below
#[derive(Deserialize, JsonSchema)]
struct GlobInput {
/// A POSIX glob pattern
glob: SharedString,
}
struct GlobTool {
project: Entity<Project>,
templates: Arc<Templates>,
}
impl AgentTool for GlobTool {
type Input = GlobInput;
fn name(&self) -> SharedString {
"glob".into()
}
fn description(&self, cx: &mut App) -> SharedString {
let project_roots = self
.project
.read(cx)
.worktrees(cx)
.map(|worktree| worktree.read(cx).root_name().into())
.collect::<Vec<String>>()
.join("\n");
GlobTemplate { project_roots }
.render(&self.templates)
.expect("template failed to render")
.into()
}
fn run(self: Arc<Self>, input: Self::Input, cx: &mut App) -> Task<Result<String>> {
let path_matcher = match PathMatcher::new([&input.glob]) {
Ok(matcher) => matcher,
Err(error) => return Task::ready(Err(anyhow!(error))),
};
let snapshots: Vec<WorktreeSnapshot> = self
.project
.read(cx)
.worktrees(cx)
.map(|worktree| worktree.read(cx).snapshot())
.collect();
cx.background_spawn(async move {
let paths = snapshots.iter().flat_map(|snapshot| {
let root_name = PathBuf::from(snapshot.root_name());
snapshot
.entries(false, 0)
.map(move |entry| root_name.join(&entry.path))
.filter(|path| path_matcher.is_match(&path))
});
let output = paths
.map(|path| format!("{}\n", path.display()))
.collect::<String>();
Ok(output)
})
}
}

View File

@@ -0,0 +1,970 @@
use agent_client_protocol::{self as acp};
use anyhow::{anyhow, Result};
use assistant_tool::{outline, ActionLog};
use gpui::{Entity, Task};
use indoc::formatdoc;
use language::{Anchor, Point};
use project::{AgentLocation, Project, WorktreeSettings};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use settings::Settings;
use std::sync::Arc;
use ui::{App, SharedString};
use crate::{AgentTool, ToolCallEventStream};
/// Reads the content of the given file in the project.
///
/// - Never attempt to read a path that hasn't been previously mentioned.
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct ReadFileToolInput {
/// The relative path of the file to read.
///
/// This path should never be absolute, and the first component
/// of the path should always be a root directory in a project.
///
/// <example>
/// If the project has the following root directories:
///
/// - /a/b/directory1
/// - /c/d/directory2
///
/// If you want to access `file.txt` in `directory1`, you should use the path `directory1/file.txt`.
/// If you want to access `file.txt` in `directory2`, you should use the path `directory2/file.txt`.
/// </example>
pub path: String,
/// Optional line number to start reading on (1-based index)
#[serde(default)]
pub start_line: Option<u32>,
/// Optional line number to end reading on (1-based index, inclusive)
#[serde(default)]
pub end_line: Option<u32>,
}
pub struct ReadFileTool {
project: Entity<Project>,
action_log: Entity<ActionLog>,
}
impl ReadFileTool {
pub fn new(project: Entity<Project>, action_log: Entity<ActionLog>) -> Self {
Self {
project,
action_log,
}
}
}
impl AgentTool for ReadFileTool {
type Input = ReadFileToolInput;
fn name(&self) -> SharedString {
"read_file".into()
}
fn kind(&self) -> acp::ToolKind {
acp::ToolKind::Read
}
fn initial_title(&self, input: Self::Input) -> SharedString {
let path = &input.path;
match (input.start_line, input.end_line) {
(Some(start), Some(end)) => {
format!(
"[Read file `{}` (lines {}-{})](@selection:{}:({}-{}))",
path, start, end, path, start, end
)
}
(Some(start), None) => {
format!(
"[Read file `{}` (from line {})](@selection:{}:({}-{}))",
path, start, path, start, start
)
}
_ => format!("[Read file `{}`](@file:{})", path, path),
}
.into()
}
fn run(
self: Arc<Self>,
input: Self::Input,
event_stream: ToolCallEventStream,
cx: &mut App,
) -> Task<Result<String>> {
let Some(project_path) = self.project.read(cx).find_project_path(&input.path, cx) else {
return Task::ready(Err(anyhow!("Path {} not found in project", &input.path)));
};
// Error out if this path is either excluded or private in global settings
let global_settings = WorktreeSettings::get_global(cx);
if global_settings.is_path_excluded(&project_path.path) {
return Task::ready(Err(anyhow!(
"Cannot read file because its path matches the global `file_scan_exclusions` setting: {}",
&input.path
)));
}
if global_settings.is_path_private(&project_path.path) {
return Task::ready(Err(anyhow!(
"Cannot read file because its path matches the global `private_files` setting: {}",
&input.path
)));
}
// Error out if this path is either excluded or private in worktree settings
let worktree_settings = WorktreeSettings::get(Some((&project_path).into()), cx);
if worktree_settings.is_path_excluded(&project_path.path) {
return Task::ready(Err(anyhow!(
"Cannot read file because its path matches the worktree `file_scan_exclusions` setting: {}",
&input.path
)));
}
if worktree_settings.is_path_private(&project_path.path) {
return Task::ready(Err(anyhow!(
"Cannot read file because its path matches the worktree `private_files` setting: {}",
&input.path
)));
}
let file_path = input.path.clone();
event_stream.send_update(acp::ToolCallUpdateFields {
locations: Some(vec![acp::ToolCallLocation {
path: project_path.path.to_path_buf(),
line: input.start_line,
// TODO (tracked): use full range
}]),
..Default::default()
});
// TODO (tracked): images
// if image_store::is_image_file(&self.project, &project_path, cx) {
// let model = &self.thread.read(cx).selected_model;
// if !model.supports_images() {
// return Task::ready(Err(anyhow!(
// "Attempted to read an image, but Zed doesn't currently support sending images to {}.",
// model.name().0
// )))
// .into();
// }
// return cx.spawn(async move |cx| -> Result<ToolResultOutput> {
// let image_entity: Entity<ImageItem> = cx
// .update(|cx| {
// self.project.update(cx, |project, cx| {
// project.open_image(project_path.clone(), cx)
// })
// })?
// .await?;
// let image =
// image_entity.read_with(cx, |image_item, _| Arc::clone(&image_item.image))?;
// let language_model_image = cx
// .update(|cx| LanguageModelImage::from_image(image, cx))?
// .await
// .context("processing image")?;
// Ok(ToolResultOutput {
// content: ToolResultContent::Image(language_model_image),
// output: None,
// })
// });
// }
//
let project = self.project.clone();
let action_log = self.action_log.clone();
cx.spawn(async move |cx| {
let buffer = cx
.update(|cx| {
project.update(cx, |project, cx| project.open_buffer(project_path, cx))
})?
.await?;
if buffer.read_with(cx, |buffer, _| {
buffer
.file()
.as_ref()
.map_or(true, |file| !file.disk_state().exists())
})? {
anyhow::bail!("{file_path} not found");
}
project.update(cx, |project, cx| {
project.set_agent_location(
Some(AgentLocation {
buffer: buffer.downgrade(),
position: Anchor::MIN,
}),
cx,
);
})?;
// Check if specific line ranges are provided
if input.start_line.is_some() || input.end_line.is_some() {
let mut anchor = None;
let result = buffer.read_with(cx, |buffer, _cx| {
let text = buffer.text();
// .max(1) because despite instructions to be 1-indexed, sometimes the model passes 0.
let start = input.start_line.unwrap_or(1).max(1);
let start_row = start - 1;
if start_row <= buffer.max_point().row {
let column = buffer.line_indent_for_row(start_row).raw_len();
anchor = Some(buffer.anchor_before(Point::new(start_row, column)));
}
let lines = text.split('\n').skip(start_row as usize);
if let Some(end) = input.end_line {
let count = end.saturating_sub(start).saturating_add(1); // Ensure at least 1 line
itertools::intersperse(lines.take(count as usize), "\n").collect::<String>()
} else {
itertools::intersperse(lines, "\n").collect::<String>()
}
})?;
action_log.update(cx, |log, cx| {
log.buffer_read(buffer.clone(), cx);
})?;
if let Some(anchor) = anchor {
project.update(cx, |project, cx| {
project.set_agent_location(
Some(AgentLocation {
buffer: buffer.downgrade(),
position: anchor,
}),
cx,
);
})?;
}
Ok(result)
} else {
// No line ranges specified, so check file size to see if it's too big.
let file_size = buffer.read_with(cx, |buffer, _cx| buffer.text().len())?;
if file_size <= outline::AUTO_OUTLINE_SIZE {
// File is small enough, so return its contents.
let result = buffer.read_with(cx, |buffer, _cx| buffer.text())?;
action_log.update(cx, |log, cx| {
log.buffer_read(buffer, cx);
})?;
Ok(result)
} else {
// File is too big, so return the outline
// and a suggestion to read again with line numbers.
let outline =
outline::file_outline(project, file_path, action_log, None, cx).await?;
Ok(formatdoc! {"
This file was too big to read all at once.
Here is an outline of its symbols:
{outline}
Using the line numbers in this outline, you can call this tool again
while specifying the start_line and end_line fields to see the
implementations of symbols in the outline.
Alternatively, you can fall back to the `grep` tool (if available)
to search the file for specific content."
})
}
}
})
}
}
#[cfg(test)]
mod test {
use crate::TestToolCallEventStream;
use super::*;
use gpui::{AppContext, TestAppContext, UpdateGlobal as _};
use language::{tree_sitter_rust, Language, LanguageConfig, LanguageMatcher};
use project::{FakeFs, Project};
use serde_json::json;
use settings::SettingsStore;
use util::path;
#[gpui::test]
async fn test_read_nonexistent_file(cx: &mut TestAppContext) {
init_test(cx);
let fs = FakeFs::new(cx.executor());
fs.insert_tree(path!("/root"), json!({})).await;
let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
let action_log = cx.new(|_| ActionLog::new(project.clone()));
let tool = Arc::new(ReadFileTool::new(project, action_log));
let event_stream = TestToolCallEventStream::new();
let result = cx
.update(|cx| {
let input = ReadFileToolInput {
path: "root/nonexistent_file.txt".to_string(),
start_line: None,
end_line: None,
};
tool.run(input, event_stream.stream(), cx)
})
.await;
assert_eq!(
result.unwrap_err().to_string(),
"root/nonexistent_file.txt not found"
);
}
#[gpui::test]
async fn test_read_small_file(cx: &mut TestAppContext) {
init_test(cx);
let fs = FakeFs::new(cx.executor());
fs.insert_tree(
path!("/root"),
json!({
"small_file.txt": "This is a small file content"
}),
)
.await;
let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
let action_log = cx.new(|_| ActionLog::new(project.clone()));
let tool = Arc::new(ReadFileTool::new(project, action_log));
let event_stream = TestToolCallEventStream::new();
let result = cx
.update(|cx| {
let input = ReadFileToolInput {
path: "root/small_file.txt".into(),
start_line: None,
end_line: None,
};
tool.run(input, event_stream.stream(), cx)
})
.await;
assert_eq!(result.unwrap(), "This is a small file content");
}
#[gpui::test]
async fn test_read_large_file(cx: &mut TestAppContext) {
init_test(cx);
let fs = FakeFs::new(cx.executor());
fs.insert_tree(
path!("/root"),
json!({
"large_file.rs": (0..1000).map(|i| format!("struct Test{} {{\n a: u32,\n b: usize,\n}}", i)).collect::<Vec<_>>().join("\n")
}),
)
.await;
let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
let language_registry = project.read_with(cx, |project, _| project.languages().clone());
language_registry.add(Arc::new(rust_lang()));
let action_log = cx.new(|_| ActionLog::new(project.clone()));
let tool = Arc::new(ReadFileTool::new(project, action_log));
let event_stream = TestToolCallEventStream::new();
let content = cx
.update(|cx| {
let input = ReadFileToolInput {
path: "root/large_file.rs".into(),
start_line: None,
end_line: None,
};
tool.clone().run(input, event_stream.stream(), cx)
})
.await
.unwrap();
assert_eq!(
content.lines().skip(4).take(6).collect::<Vec<_>>(),
vec![
"struct Test0 [L1-4]",
" a [L2]",
" b [L3]",
"struct Test1 [L5-8]",
" a [L6]",
" b [L7]",
]
);
let result = cx
.update(|cx| {
let input = ReadFileToolInput {
path: "root/large_file.rs".into(),
start_line: None,
end_line: None,
};
tool.run(input, event_stream.stream(), cx)
})
.await;
let content = result.unwrap();
let expected_content = (0..1000)
.flat_map(|i| {
vec![
format!("struct Test{} [L{}-{}]", i, i * 4 + 1, i * 4 + 4),
format!(" a [L{}]", i * 4 + 2),
format!(" b [L{}]", i * 4 + 3),
]
})
.collect::<Vec<_>>();
pretty_assertions::assert_eq!(
content
.lines()
.skip(4)
.take(expected_content.len())
.collect::<Vec<_>>(),
expected_content
);
}
#[gpui::test]
async fn test_read_file_with_line_range(cx: &mut TestAppContext) {
init_test(cx);
let fs = FakeFs::new(cx.executor());
fs.insert_tree(
path!("/root"),
json!({
"multiline.txt": "Line 1\nLine 2\nLine 3\nLine 4\nLine 5"
}),
)
.await;
let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
let action_log = cx.new(|_| ActionLog::new(project.clone()));
let tool = Arc::new(ReadFileTool::new(project, action_log));
let event_stream = TestToolCallEventStream::new();
let result = cx
.update(|cx| {
let input = ReadFileToolInput {
path: "root/multiline.txt".to_string(),
start_line: Some(2),
end_line: Some(4),
};
tool.run(input, event_stream.stream(), cx)
})
.await;
assert_eq!(result.unwrap(), "Line 2\nLine 3\nLine 4");
}
#[gpui::test]
async fn test_read_file_line_range_edge_cases(cx: &mut TestAppContext) {
init_test(cx);
let fs = FakeFs::new(cx.executor());
fs.insert_tree(
path!("/root"),
json!({
"multiline.txt": "Line 1\nLine 2\nLine 3\nLine 4\nLine 5"
}),
)
.await;
let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
let action_log = cx.new(|_| ActionLog::new(project.clone()));
let tool = Arc::new(ReadFileTool::new(project, action_log));
let event_stream = TestToolCallEventStream::new();
// start_line of 0 should be treated as 1
let result = cx
.update(|cx| {
let input = ReadFileToolInput {
path: "root/multiline.txt".to_string(),
start_line: Some(0),
end_line: Some(2),
};
tool.clone().run(input, event_stream.stream(), cx)
})
.await;
assert_eq!(result.unwrap(), "Line 1\nLine 2");
// end_line of 0 should result in at least 1 line
let result = cx
.update(|cx| {
let input = ReadFileToolInput {
path: "root/multiline.txt".to_string(),
start_line: Some(1),
end_line: Some(0),
};
tool.clone().run(input, event_stream.stream(), cx)
})
.await;
assert_eq!(result.unwrap(), "Line 1");
// when start_line > end_line, should still return at least 1 line
let result = cx
.update(|cx| {
let input = ReadFileToolInput {
path: "root/multiline.txt".to_string(),
start_line: Some(3),
end_line: Some(2),
};
tool.clone().run(input, event_stream.stream(), cx)
})
.await;
assert_eq!(result.unwrap(), "Line 3");
}
fn init_test(cx: &mut TestAppContext) {
cx.update(|cx| {
let settings_store = SettingsStore::test(cx);
cx.set_global(settings_store);
language::init(cx);
Project::init_settings(cx);
});
}
fn rust_lang() -> Language {
Language::new(
LanguageConfig {
name: "Rust".into(),
matcher: LanguageMatcher {
path_suffixes: vec!["rs".to_string()],
..Default::default()
},
..Default::default()
},
Some(tree_sitter_rust::LANGUAGE.into()),
)
.with_outline_query(
r#"
(line_comment) @annotation
(struct_item
"struct" @context
name: (_) @name) @item
(enum_item
"enum" @context
name: (_) @name) @item
(enum_variant
name: (_) @name) @item
(field_declaration
name: (_) @name) @item
(impl_item
"impl" @context
trait: (_)? @name
"for"? @context
type: (_) @name
body: (_ "{" (_)* "}")) @item
(function_item
"fn" @context
name: (_) @name) @item
(mod_item
"mod" @context
name: (_) @name) @item
"#,
)
.unwrap()
}
#[gpui::test]
async fn test_read_file_security(cx: &mut TestAppContext) {
init_test(cx);
let fs = FakeFs::new(cx.executor());
fs.insert_tree(
path!("/"),
json!({
"project_root": {
"allowed_file.txt": "This file is in the project",
".mysecrets": "SECRET_KEY=abc123",
".secretdir": {
"config": "special configuration"
},
".mymetadata": "custom metadata",
"subdir": {
"normal_file.txt": "Normal file content",
"special.privatekey": "private key content",
"data.mysensitive": "sensitive data"
}
},
"outside_project": {
"sensitive_file.txt": "This file is outside the project"
}
}),
)
.await;
cx.update(|cx| {
use gpui::UpdateGlobal;
use project::WorktreeSettings;
use settings::SettingsStore;
SettingsStore::update_global(cx, |store, cx| {
store.update_user_settings::<WorktreeSettings>(cx, |settings| {
settings.file_scan_exclusions = Some(vec![
"**/.secretdir".to_string(),
"**/.mymetadata".to_string(),
]);
settings.private_files = Some(vec![
"**/.mysecrets".to_string(),
"**/*.privatekey".to_string(),
"**/*.mysensitive".to_string(),
]);
});
});
});
let project = Project::test(fs.clone(), [path!("/project_root").as_ref()], cx).await;
let action_log = cx.new(|_| ActionLog::new(project.clone()));
let tool = Arc::new(ReadFileTool::new(project, action_log));
let event_stream = TestToolCallEventStream::new();
// Reading a file outside the project worktree should fail
let result = cx
.update(|cx| {
let input = ReadFileToolInput {
path: "/outside_project/sensitive_file.txt".to_string(),
start_line: None,
end_line: None,
};
tool.clone().run(input, event_stream.stream(), cx)
})
.await;
assert!(
result.is_err(),
"read_file_tool should error when attempting to read an absolute path outside a worktree"
);
// Reading a file within the project should succeed
let result = cx
.update(|cx| {
let input = ReadFileToolInput {
path: "project_root/allowed_file.txt".to_string(),
start_line: None,
end_line: None,
};
tool.clone().run(input, event_stream.stream(), cx)
})
.await;
assert!(
result.is_ok(),
"read_file_tool should be able to read files inside worktrees"
);
// Reading files that match file_scan_exclusions should fail
let result = cx
.update(|cx| {
let input = ReadFileToolInput {
path: "project_root/.secretdir/config".to_string(),
start_line: None,
end_line: None,
};
tool.clone().run(input, event_stream.stream(), cx)
})
.await;
assert!(
result.is_err(),
"read_file_tool should error when attempting to read files in .secretdir (file_scan_exclusions)"
);
let result = cx
.update(|cx| {
let input = ReadFileToolInput {
path: "project_root/.mymetadata".to_string(),
start_line: None,
end_line: None,
};
tool.clone().run(input, event_stream.stream(), cx)
})
.await;
assert!(
result.is_err(),
"read_file_tool should error when attempting to read .mymetadata files (file_scan_exclusions)"
);
// Reading private files should fail
let result = cx
.update(|cx| {
let input = ReadFileToolInput {
path: "project_root/.mysecrets".to_string(),
start_line: None,
end_line: None,
};
tool.clone().run(input, event_stream.stream(), cx)
})
.await;
assert!(
result.is_err(),
"read_file_tool should error when attempting to read .mysecrets (private_files)"
);
let result = cx
.update(|cx| {
let input = ReadFileToolInput {
path: "project_root/subdir/special.privatekey".to_string(),
start_line: None,
end_line: None,
};
tool.clone().run(input, event_stream.stream(), cx)
})
.await;
assert!(
result.is_err(),
"read_file_tool should error when attempting to read .privatekey files (private_files)"
);
let result = cx
.update(|cx| {
let input = ReadFileToolInput {
path: "project_root/subdir/data.mysensitive".to_string(),
start_line: None,
end_line: None,
};
tool.clone().run(input, event_stream.stream(), cx)
})
.await;
assert!(
result.is_err(),
"read_file_tool should error when attempting to read .mysensitive files (private_files)"
);
// Reading a normal file should still work, even with private_files configured
let result = cx
.update(|cx| {
let input = ReadFileToolInput {
path: "project_root/subdir/normal_file.txt".to_string(),
start_line: None,
end_line: None,
};
tool.clone().run(input, event_stream.stream(), cx)
})
.await;
assert!(result.is_ok(), "Should be able to read normal files");
assert_eq!(result.unwrap(), "Normal file content");
// Path traversal attempts with .. should fail
let result = cx
.update(|cx| {
let input = ReadFileToolInput {
path: "project_root/../outside_project/sensitive_file.txt".to_string(),
start_line: None,
end_line: None,
};
tool.run(input, event_stream.stream(), cx)
})
.await;
assert!(
result.is_err(),
"read_file_tool should error when attempting to read a relative path that resolves to outside a worktree"
);
}
#[gpui::test]
async fn test_read_file_with_multiple_worktree_settings(cx: &mut TestAppContext) {
init_test(cx);
let fs = FakeFs::new(cx.executor());
// Create first worktree with its own private_files setting
fs.insert_tree(
path!("/worktree1"),
json!({
"src": {
"main.rs": "fn main() { println!(\"Hello from worktree1\"); }",
"secret.rs": "const API_KEY: &str = \"secret_key_1\";",
"config.toml": "[database]\nurl = \"postgres://localhost/db1\""
},
"tests": {
"test.rs": "mod tests { fn test_it() {} }",
"fixture.sql": "CREATE TABLE users (id INT, name VARCHAR(255));"
},
".zed": {
"settings.json": r#"{
"file_scan_exclusions": ["**/fixture.*"],
"private_files": ["**/secret.rs", "**/config.toml"]
}"#
}
}),
)
.await;
// Create second worktree with different private_files setting
fs.insert_tree(
path!("/worktree2"),
json!({
"lib": {
"public.js": "export function greet() { return 'Hello from worktree2'; }",
"private.js": "const SECRET_TOKEN = \"private_token_2\";",
"data.json": "{\"api_key\": \"json_secret_key\"}"
},
"docs": {
"README.md": "# Public Documentation",
"internal.md": "# Internal Secrets and Configuration"
},
".zed": {
"settings.json": r#"{
"file_scan_exclusions": ["**/internal.*"],
"private_files": ["**/private.js", "**/data.json"]
}"#
}
}),
)
.await;
// Set global settings
cx.update(|cx| {
SettingsStore::update_global(cx, |store, cx| {
store.update_user_settings::<WorktreeSettings>(cx, |settings| {
settings.file_scan_exclusions =
Some(vec!["**/.git".to_string(), "**/node_modules".to_string()]);
settings.private_files = Some(vec!["**/.env".to_string()]);
});
});
});
let project = Project::test(
fs.clone(),
[path!("/worktree1").as_ref(), path!("/worktree2").as_ref()],
cx,
)
.await;
let action_log = cx.new(|_| ActionLog::new(project.clone()));
let tool = Arc::new(ReadFileTool::new(project.clone(), action_log.clone()));
let event_stream = TestToolCallEventStream::new();
// Test reading allowed files in worktree1
let result = cx
.update(|cx| {
let input = ReadFileToolInput {
path: "worktree1/src/main.rs".to_string(),
start_line: None,
end_line: None,
};
tool.clone().run(input, event_stream.stream(), cx)
})
.await
.unwrap();
assert_eq!(result, "fn main() { println!(\"Hello from worktree1\"); }");
// Test reading private file in worktree1 should fail
let result = cx
.update(|cx| {
let input = ReadFileToolInput {
path: "worktree1/src/secret.rs".to_string(),
start_line: None,
end_line: None,
};
tool.clone().run(input, event_stream.stream(), cx)
})
.await;
assert!(result.is_err());
assert!(
result
.unwrap_err()
.to_string()
.contains("worktree `private_files` setting"),
"Error should mention worktree private_files setting"
);
// Test reading excluded file in worktree1 should fail
let result = cx
.update(|cx| {
let input = ReadFileToolInput {
path: "worktree1/tests/fixture.sql".to_string(),
start_line: None,
end_line: None,
};
tool.clone().run(input, event_stream.stream(), cx)
})
.await;
assert!(result.is_err());
assert!(
result
.unwrap_err()
.to_string()
.contains("worktree `file_scan_exclusions` setting"),
"Error should mention worktree file_scan_exclusions setting"
);
// Test reading allowed files in worktree2
let result = cx
.update(|cx| {
let input = ReadFileToolInput {
path: "worktree2/lib/public.js".to_string(),
start_line: None,
end_line: None,
};
tool.clone().run(input, event_stream.stream(), cx)
})
.await
.unwrap();
assert_eq!(
result,
"export function greet() { return 'Hello from worktree2'; }"
);
// Test reading private file in worktree2 should fail
let result = cx
.update(|cx| {
let input = ReadFileToolInput {
path: "worktree2/lib/private.js".to_string(),
start_line: None,
end_line: None,
};
tool.clone().run(input, event_stream.stream(), cx)
})
.await;
assert!(result.is_err());
assert!(
result
.unwrap_err()
.to_string()
.contains("worktree `private_files` setting"),
"Error should mention worktree private_files setting"
);
// Test reading excluded file in worktree2 should fail
let result = cx
.update(|cx| {
let input = ReadFileToolInput {
path: "worktree2/docs/internal.md".to_string(),
start_line: None,
end_line: None,
};
tool.clone().run(input, event_stream.stream(), cx)
})
.await;
assert!(result.is_err());
assert!(
result
.unwrap_err()
.to_string()
.contains("worktree `file_scan_exclusions` setting"),
"Error should mention worktree file_scan_exclusions setting"
);
// Test that files allowed in one worktree but not in another are handled correctly
// (e.g., config.toml is private in worktree1 but doesn't exist in worktree2)
let result = cx
.update(|cx| {
let input = ReadFileToolInput {
path: "worktree1/src/config.toml".to_string(),
start_line: None,
end_line: None,
};
tool.clone().run(input, event_stream.stream(), cx)
})
.await;
assert!(result.is_err());
assert!(
result
.unwrap_err()
.to_string()
.contains("worktree `private_files` setting"),
"Config.toml should be blocked by worktree1's private_files setting"
);
}
}

View File

@@ -0,0 +1,48 @@
use agent_client_protocol as acp;
use anyhow::Result;
use gpui::{App, SharedString, Task};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use crate::{AgentTool, ToolCallEventStream};
/// A tool for thinking through problems, brainstorming ideas, or planning without executing any actions.
/// Use this tool when you need to work through complex problems, develop strategies, or outline approaches before taking action.
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct ThinkingToolInput {
/// Content to think about. This should be a description of what to think about or
/// a problem to solve.
content: String,
}
pub struct ThinkingTool;
impl AgentTool for ThinkingTool {
type Input = ThinkingToolInput;
fn name(&self) -> SharedString {
"thinking".into()
}
fn kind(&self) -> acp::ToolKind {
acp::ToolKind::Think
}
fn initial_title(&self, _input: Self::Input) -> SharedString {
"Thinking".into()
}
fn run(
self: Arc<Self>,
input: Self::Input,
event_stream: ToolCallEventStream,
_cx: &mut App,
) -> Task<Result<String>> {
event_stream.send_update(acp::ToolCallUpdateFields {
content: Some(vec![input.content.into()]),
..Default::default()
});
Task::ready(Ok("Finished thinking.".to_string()))
}
}

View File

@@ -135,7 +135,7 @@ impl acp_old::Client for OldAcpClientDelegate {
let response = cx
.update(|cx| {
self.thread.borrow().update(cx, |thread, cx| {
thread.request_tool_call_permission(tool_call, acp_options, cx)
thread.request_tool_call_authorization(tool_call, acp_options, cx)
})
})?
.context("Failed to update thread")?
@@ -280,6 +280,7 @@ fn into_new_tool_call(id: acp::ToolCallId, request: acp_old::PushToolCallParams)
.map(into_new_tool_call_location)
.collect(),
raw_input: None,
raw_output: None,
}
}

View File

@@ -210,7 +210,7 @@ impl acp::Client for ClientDelegate {
.context("Failed to get session")?
.thread
.update(cx, |thread, cx| {
thread.request_tool_call_permission(arguments.tool_call, arguments.options, cx)
thread.request_tool_call_authorization(arguments.tool_call, arguments.options, cx)
})?;
let result = rx.await;

View File

@@ -24,7 +24,7 @@ use futures::{
};
use gpui::{App, AppContext, AsyncApp, Entity, Task, WeakEntity};
use serde::{Deserialize, Serialize};
use util::ResultExt;
use util::{ResultExt, debug_panic};
use crate::claude::mcp_server::{ClaudeZedMcpServer, McpConfig};
use crate::claude::tools::ClaudeTool;
@@ -153,16 +153,17 @@ impl AgentConnection for ClaudeAgentConnection {
})
.detach();
let end_turn_tx = Rc::new(RefCell::new(None));
let turn_state = Rc::new(RefCell::new(TurnState::None));
let handler_task = cx.spawn({
let end_turn_tx = end_turn_tx.clone();
let turn_state = turn_state.clone();
let mut thread_rx = thread_rx.clone();
async move |cx| {
while let Some(message) = incoming_message_rx.next().await {
ClaudeAgentSession::handle_message(
thread_rx.clone(),
message,
end_turn_tx.clone(),
turn_state.clone(),
cx,
)
.await
@@ -188,7 +189,7 @@ impl AgentConnection for ClaudeAgentConnection {
let session = ClaudeAgentSession {
outgoing_tx,
end_turn_tx,
turn_state,
_handler_task: handler_task,
_mcp_server: Some(permission_mcp_server),
};
@@ -220,8 +221,8 @@ impl AgentConnection for ClaudeAgentConnection {
)));
};
let (tx, rx) = oneshot::channel();
session.end_turn_tx.borrow_mut().replace(tx);
let (end_tx, end_rx) = oneshot::channel();
session.turn_state.replace(TurnState::InProgress { end_tx });
let mut content = String::new();
for chunk in params.prompt {
@@ -255,7 +256,7 @@ impl AgentConnection for ClaudeAgentConnection {
return Task::ready(Err(anyhow!(err)));
}
cx.foreground_executor().spawn(async move { rx.await? })
cx.foreground_executor().spawn(async move { end_rx.await? })
}
fn cancel(&self, session_id: &acp::SessionId, _cx: &mut App) {
@@ -265,18 +266,27 @@ impl AgentConnection for ClaudeAgentConnection {
return;
};
let request_id = new_request_id();
let turn_state = session.turn_state.take();
let TurnState::InProgress { end_tx } = turn_state else {
// Already cancelled or idle, put it back
session.turn_state.replace(turn_state);
return;
};
session.turn_state.replace(TurnState::CancelRequested {
end_tx,
request_id: request_id.clone(),
});
session
.outgoing_tx
.unbounded_send(SdkMessage::new_interrupt_message())
.unbounded_send(SdkMessage::ControlRequest {
request_id,
request: ControlRequest::Interrupt,
})
.log_err();
if let Some(end_turn_tx) = session.end_turn_tx.borrow_mut().take() {
end_turn_tx
.send(Ok(acp::PromptResponse {
stop_reason: acp::StopReason::Cancelled,
}))
.ok();
}
}
}
@@ -338,26 +348,139 @@ fn spawn_claude(
struct ClaudeAgentSession {
outgoing_tx: UnboundedSender<SdkMessage>,
end_turn_tx: Rc<RefCell<Option<oneshot::Sender<Result<acp::PromptResponse>>>>>,
turn_state: Rc<RefCell<TurnState>>,
_mcp_server: Option<ClaudeZedMcpServer>,
_handler_task: Task<()>,
}
#[derive(Debug, Default)]
enum TurnState {
#[default]
None,
InProgress {
end_tx: oneshot::Sender<Result<acp::PromptResponse>>,
},
CancelRequested {
end_tx: oneshot::Sender<Result<acp::PromptResponse>>,
request_id: String,
},
CancelConfirmed {
end_tx: oneshot::Sender<Result<acp::PromptResponse>>,
},
}
impl TurnState {
fn is_cancelled(&self) -> bool {
matches!(self, TurnState::CancelConfirmed { .. })
}
fn end_tx(self) -> Option<oneshot::Sender<Result<acp::PromptResponse>>> {
match self {
TurnState::None => None,
TurnState::InProgress { end_tx, .. } => Some(end_tx),
TurnState::CancelRequested { end_tx, .. } => Some(end_tx),
TurnState::CancelConfirmed { end_tx } => Some(end_tx),
}
}
fn confirm_cancellation(self, id: &str) -> Self {
match self {
TurnState::CancelRequested { request_id, end_tx } if request_id == id => {
TurnState::CancelConfirmed { end_tx }
}
_ => self,
}
}
}
impl ClaudeAgentSession {
async fn handle_message(
mut thread_rx: watch::Receiver<WeakEntity<AcpThread>>,
message: SdkMessage,
end_turn_tx: Rc<RefCell<Option<oneshot::Sender<Result<acp::PromptResponse>>>>>,
turn_state: Rc<RefCell<TurnState>>,
cx: &mut AsyncApp,
) {
match message {
// we should only be sending these out, they don't need to be in the thread
SdkMessage::ControlRequest { .. } => {}
SdkMessage::Assistant {
SdkMessage::User {
message,
session_id: _,
} => {
let Some(thread) = thread_rx
.recv()
.await
.log_err()
.and_then(|entity| entity.upgrade())
else {
log::error!("Received an SDK message but thread is gone");
return;
};
for chunk in message.content.chunks() {
match chunk {
ContentChunk::Text { text } | ContentChunk::UntaggedText(text) => {
if !turn_state.borrow().is_cancelled() {
thread
.update(cx, |thread, cx| {
thread.push_user_content_block(text.into(), cx)
})
.log_err();
}
}
ContentChunk::ToolResult {
content,
tool_use_id,
} => {
let content = content.to_string();
thread
.update(cx, |thread, cx| {
thread.update_tool_call(
acp::ToolCallUpdate {
id: acp::ToolCallId(tool_use_id.into()),
fields: acp::ToolCallUpdateFields {
status: if turn_state.borrow().is_cancelled() {
// Do not set to completed if turn was cancelled
None
} else {
Some(acp::ToolCallStatus::Completed)
},
content: (!content.is_empty())
.then(|| vec![content.into()]),
..Default::default()
},
},
cx,
)
})
.log_err();
}
ContentChunk::Thinking { .. }
| ContentChunk::RedactedThinking
| ContentChunk::ToolUse { .. } => {
debug_panic!(
"Should not get {:?} with role: assistant. should we handle this?",
chunk
);
}
ContentChunk::Image
| ContentChunk::Document
| ContentChunk::WebSearchToolResult => {
thread
.update(cx, |thread, cx| {
thread.push_assistant_content_block(
format!("Unsupported content: {:?}", chunk).into(),
false,
cx,
)
})
.log_err();
}
}
}
}
| SdkMessage::User {
SdkMessage::Assistant {
message,
session_id: _,
} => {
@@ -380,6 +503,24 @@ impl ClaudeAgentSession {
})
.log_err();
}
ContentChunk::Thinking { thinking } => {
thread
.update(cx, |thread, cx| {
thread.push_assistant_content_block(thinking.into(), true, cx)
})
.log_err();
}
ContentChunk::RedactedThinking => {
thread
.update(cx, |thread, cx| {
thread.push_assistant_content_block(
"[REDACTED]".into(),
true,
cx,
)
})
.log_err();
}
ContentChunk::ToolUse { id, name, input } => {
let claude_tool = ClaudeTool::infer(&name, input);
@@ -405,33 +546,12 @@ impl ClaudeAgentSession {
})
.log_err();
}
ContentChunk::ToolResult {
content,
tool_use_id,
} => {
let content = content.to_string();
thread
.update(cx, |thread, cx| {
thread.update_tool_call(
acp::ToolCallUpdate {
id: acp::ToolCallId(tool_use_id.into()),
fields: acp::ToolCallUpdateFields {
status: Some(acp::ToolCallStatus::Completed),
content: (!content.is_empty())
.then(|| vec![content.into()]),
..Default::default()
},
},
cx,
)
})
.log_err();
ContentChunk::ToolResult { .. } | ContentChunk::WebSearchToolResult => {
debug_panic!(
"Should not get tool results with role: assistant. should we handle this?"
);
}
ContentChunk::Image
| ContentChunk::Document
| ContentChunk::Thinking
| ContentChunk::RedactedThinking
| ContentChunk::WebSearchToolResult => {
ContentChunk::Image | ContentChunk::Document => {
thread
.update(cx, |thread, cx| {
thread.push_assistant_content_block(
@@ -451,27 +571,41 @@ impl ClaudeAgentSession {
result,
..
} => {
if let Some(end_turn_tx) = end_turn_tx.borrow_mut().take() {
if is_error || subtype == ResultErrorType::ErrorDuringExecution {
end_turn_tx
.send(Err(anyhow!(
"Error: {}",
result.unwrap_or_else(|| subtype.to_string())
)))
.ok();
} else {
let stop_reason = match subtype {
ResultErrorType::Success => acp::StopReason::EndTurn,
ResultErrorType::ErrorMaxTurns => acp::StopReason::MaxTurnRequests,
ResultErrorType::ErrorDuringExecution => unreachable!(),
};
end_turn_tx
.send(Ok(acp::PromptResponse { stop_reason }))
.ok();
}
let turn_state = turn_state.take();
let was_cancelled = turn_state.is_cancelled();
let Some(end_turn_tx) = turn_state.end_tx() else {
debug_panic!("Received `SdkMessage::Result` but there wasn't an active turn");
return;
};
if is_error || (!was_cancelled && subtype == ResultErrorType::ErrorDuringExecution)
{
end_turn_tx
.send(Err(anyhow!(
"Error: {}",
result.unwrap_or_else(|| subtype.to_string())
)))
.ok();
} else {
let stop_reason = match subtype {
ResultErrorType::Success => acp::StopReason::EndTurn,
ResultErrorType::ErrorMaxTurns => acp::StopReason::MaxTurnRequests,
ResultErrorType::ErrorDuringExecution => acp::StopReason::Cancelled,
};
end_turn_tx
.send(Ok(acp::PromptResponse { stop_reason }))
.ok();
}
}
SdkMessage::System { .. } | SdkMessage::ControlResponse { .. } => {}
SdkMessage::ControlResponse { response } => {
if matches!(response.subtype, ResultErrorType::Success) {
let new_state = turn_state.take().confirm_cancellation(&response.request_id);
turn_state.replace(new_state);
} else {
log::error!("Control response error: {:?}", response);
}
}
SdkMessage::System { .. } => {}
}
}
@@ -580,11 +714,13 @@ enum ContentChunk {
content: Content,
tool_use_id: String,
},
Thinking {
thinking: String,
},
RedactedThinking,
// TODO
Image,
Document,
Thinking,
RedactedThinking,
WebSearchToolResult,
#[serde(untagged)]
UntaggedText(String),
@@ -594,12 +730,12 @@ impl Display for ContentChunk {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
ContentChunk::Text { text } => write!(f, "{}", text),
ContentChunk::Thinking { thinking } => write!(f, "Thinking: {}", thinking),
ContentChunk::RedactedThinking => write!(f, "Thinking: [REDACTED]"),
ContentChunk::UntaggedText(text) => write!(f, "{}", text),
ContentChunk::ToolResult { content, .. } => write!(f, "{}", content),
ContentChunk::Image
| ContentChunk::Document
| ContentChunk::Thinking
| ContentChunk::RedactedThinking
| ContentChunk::ToolUse { .. }
| ContentChunk::WebSearchToolResult => {
write!(f, "\n{:?}\n", &self)
@@ -710,22 +846,15 @@ impl Display for ResultErrorType {
}
}
impl SdkMessage {
fn new_interrupt_message() -> Self {
use rand::Rng;
// In the Claude Code TS SDK they just generate a random 12 character string,
// `Math.random().toString(36).substring(2, 15)`
let request_id = rand::thread_rng()
.sample_iter(&rand::distributions::Alphanumeric)
.take(12)
.map(char::from)
.collect();
Self::ControlRequest {
request_id,
request: ControlRequest::Interrupt,
}
}
fn new_request_id() -> String {
use rand::Rng;
// In the Claude Code TS SDK they just generate a random 12 character string,
// `Math.random().toString(36).substring(2, 15)`
rand::thread_rng()
.sample_iter(&rand::distributions::Alphanumeric)
.take(12)
.map(char::from)
.collect()
}
#[derive(Debug, Clone, Serialize, Deserialize)]
@@ -746,6 +875,8 @@ enum PermissionMode {
#[cfg(test)]
pub(crate) mod tests {
use super::*;
use crate::e2e_tests;
use gpui::TestAppContext;
use serde_json::json;
crate::common_e2e_tests!(ClaudeCode, allow_option_id = "allow");
@@ -758,6 +889,68 @@ pub(crate) mod tests {
}
}
#[gpui::test]
#[cfg_attr(not(feature = "e2e"), ignore)]
async fn test_todo_plan(cx: &mut TestAppContext) {
let fs = e2e_tests::init_test(cx).await;
let project = Project::test(fs, [], cx).await;
let thread =
e2e_tests::new_test_thread(ClaudeCode, project.clone(), "/private/tmp", cx).await;
thread
.update(cx, |thread, cx| {
thread.send_raw(
"Create a todo plan for initializing a new React app. I'll follow it myself, do not execute on it.",
cx,
)
})
.await
.unwrap();
let mut entries_len = 0;
thread.read_with(cx, |thread, _| {
entries_len = thread.plan().entries.len();
assert!(thread.plan().entries.len() > 0, "Empty plan");
});
thread
.update(cx, |thread, cx| {
thread.send_raw(
"Mark the first entry status as in progress without acting on it.",
cx,
)
})
.await
.unwrap();
thread.read_with(cx, |thread, _| {
assert!(matches!(
thread.plan().entries[0].status,
acp::PlanEntryStatus::InProgress
));
assert_eq!(thread.plan().entries.len(), entries_len);
});
thread
.update(cx, |thread, cx| {
thread.send_raw(
"Now mark the first entry as completed without acting on it.",
cx,
)
})
.await
.unwrap();
thread.read_with(cx, |thread, _| {
assert!(matches!(
thread.plan().entries[0].status,
acp::PlanEntryStatus::Completed
));
assert_eq!(thread.plan().entries.len(), entries_len);
});
}
#[test]
fn test_deserialize_content_untagged_text() {
let json = json!("Hello, world!");

View File

@@ -153,7 +153,7 @@ impl McpServerTool for PermissionTool {
let chosen_option = thread
.update(cx, |thread, cx| {
thread.request_tool_call_permission(
thread.request_tool_call_authorization(
claude_tool.as_acp(tool_call_id),
vec![
acp::PermissionOption {

View File

@@ -143,25 +143,6 @@ impl ClaudeTool {
Self::Grep(Some(params)) => vec![format!("`{params}`").into()],
Self::WebFetch(Some(params)) => vec![params.prompt.clone().into()],
Self::WebSearch(Some(params)) => vec![params.to_string().into()],
Self::TodoWrite(Some(params)) => vec![
params
.todos
.iter()
.map(|todo| {
format!(
"- {} {}: {}",
match todo.status {
TodoStatus::Completed => "",
TodoStatus::InProgress => "🚧",
TodoStatus::Pending => "",
},
todo.priority,
todo.content
)
})
.join("\n")
.into(),
],
Self::ExitPlanMode(Some(params)) => vec![params.plan.clone().into()],
Self::Edit(Some(params)) => vec![acp::ToolCallContent::Diff {
diff: acp::Diff {
@@ -193,6 +174,10 @@ impl ClaudeTool {
})
.unwrap_or_default()
}
Self::TodoWrite(Some(_)) => {
// These are mapped to plan updates later
vec![]
}
Self::Task(None)
| Self::NotebookRead(None)
| Self::NotebookEdit(None)
@@ -312,6 +297,7 @@ impl ClaudeTool {
content: self.content(),
locations: self.locations(),
raw_input: None,
raw_output: None,
}
}
}
@@ -488,10 +474,11 @@ impl std::fmt::Display for GrepToolParams {
}
}
#[derive(Deserialize, Serialize, JsonSchema, strum::Display, Debug)]
#[derive(Default, Deserialize, Serialize, JsonSchema, strum::Display, Debug)]
#[serde(rename_all = "snake_case")]
pub enum TodoPriority {
High,
#[default]
Medium,
Low,
}
@@ -526,14 +513,13 @@ impl Into<acp::PlanEntryStatus> for TodoStatus {
#[derive(Deserialize, Serialize, JsonSchema, Debug)]
pub struct Todo {
/// Unique identifier
pub id: String,
/// Task description
pub content: String,
/// Priority level of the todo
pub priority: TodoPriority,
/// Current status of the todo
pub status: TodoStatus,
/// Priority level of the todo
#[serde(default)]
pub priority: TodoPriority,
}
impl Into<acp::PlanEntry> for Todo {

View File

@@ -246,7 +246,7 @@ pub async fn test_cancel(server: impl AgentServer + 'static, cx: &mut TestAppCon
let project = Project::test(fs, [path!("/private/tmp").as_ref()], cx).await;
let thread = new_test_thread(server, project.clone(), "/private/tmp", cx).await;
let full_turn = thread.update(cx, |thread, cx| {
let _ = thread.update(cx, |thread, cx| {
thread.send_raw(
r#"Run exactly `touch hello.txt && echo "Hello, world!" | tee hello.txt` in the terminal."#,
cx,
@@ -285,9 +285,8 @@ pub async fn test_cancel(server: impl AgentServer + 'static, cx: &mut TestAppCon
id.clone()
});
let _ = thread.update(cx, |thread, cx| thread.cancel(cx));
full_turn.await.unwrap();
thread.read_with(cx, |thread, _| {
thread.update(cx, |thread, cx| thread.cancel(cx)).await;
thread.read_with(cx, |thread, _cx| {
let AgentThreadEntry::ToolCall(ToolCall {
status: ToolCallStatus::Canceled,
..

View File

@@ -45,6 +45,11 @@ impl<T> MessageHistory<T> {
None
})
}
#[cfg(test)]
pub fn items(&self) -> &[T] {
&self.items
}
}
#[cfg(test)]
mod tests {

View File

@@ -21,20 +21,22 @@ use editor::{
use file_icons::FileIcons;
use gpui::{
Action, Animation, AnimationExt, App, BorderStyle, EdgesRefinement, Empty, Entity, EntityId,
FocusHandle, Focusable, Hsla, Length, ListOffset, ListState, PlatformDisplay, SharedString,
StyleRefinement, Subscription, Task, TextStyle, TextStyleRefinement, Transformation,
UnderlineStyle, WeakEntity, Window, WindowHandle, div, linear_color_stop, linear_gradient,
list, percentage, point, prelude::*, pulsating_between,
FocusHandle, Focusable, Hsla, Length, ListOffset, ListState, MouseButton, PlatformDisplay,
SharedString, Stateful, StyleRefinement, Subscription, Task, TextStyle, TextStyleRefinement,
Transformation, UnderlineStyle, WeakEntity, Window, WindowHandle, div, linear_color_stop,
linear_gradient, list, percentage, point, prelude::*, pulsating_between,
};
use language::language_settings::SoftWrap;
use language::{Buffer, Language};
use markdown::{HeadingLevelStyles, Markdown, MarkdownElement, MarkdownStyle};
use parking_lot::Mutex;
use project::Project;
use settings::Settings as _;
use settings::{Settings as _, SettingsStore};
use text::{Anchor, BufferSnapshot};
use theme::ThemeSettings;
use ui::{Disclosure, Divider, DividerColor, KeyBinding, Tooltip, prelude::*};
use ui::{
Disclosure, Divider, DividerColor, KeyBinding, Scrollbar, ScrollbarState, Tooltip, prelude::*,
};
use util::ResultExt;
use workspace::{CollaboratorId, Workspace};
use zed_actions::agent::{Chat, NextHistoryMessage, PreviousHistoryMessage};
@@ -69,6 +71,7 @@ pub struct AcpThreadView {
notification_subscriptions: HashMap<WindowHandle<AgentNotification>, Vec<Subscription>>,
last_error: Option<Entity<Markdown>>,
list_state: ListState,
scrollbar_state: ScrollbarState,
auth_task: Option<Task<()>>,
expanded_tool_calls: HashSet<acp::ToolCallId>,
expanded_thinking_blocks: HashSet<(usize, usize)>,
@@ -77,6 +80,7 @@ pub struct AcpThreadView {
editor_expanded: bool,
message_history: Rc<RefCell<MessageHistory<Vec<acp::ContentBlock>>>>,
_cancel_task: Option<Task<()>>,
_subscriptions: [Subscription; 1],
}
enum ThreadState {
@@ -175,6 +179,8 @@ impl AcpThreadView {
let list_state = ListState::new(0, gpui::ListAlignment::Bottom, px(2048.0));
let subscription = cx.observe_global_in::<SettingsStore>(window, Self::settings_changed);
Self {
agent: agent.clone(),
workspace: workspace.clone(),
@@ -187,7 +193,8 @@ impl AcpThreadView {
notifications: Vec::new(),
notification_subscriptions: HashMap::default(),
diff_editors: Default::default(),
list_state: list_state,
list_state: list_state.clone(),
scrollbar_state: ScrollbarState::new(list_state).parent_entity(&cx.entity()),
last_error: None,
auth_task: None,
expanded_tool_calls: HashSet::default(),
@@ -196,6 +203,7 @@ impl AcpThreadView {
plan_expanded: false,
editor_expanded: false,
message_history,
_subscriptions: [subscription],
_cancel_task: None,
}
}
@@ -373,6 +381,11 @@ impl AcpThreadView {
editor.display_map.update(cx, |map, cx| {
let snapshot = map.snapshot(cx);
for (crease_id, crease) in snapshot.crease_snapshot.creases() {
// Skip creases that have been edited out of the message buffer.
if !crease.range().start.is_valid(&snapshot.buffer_snapshot) {
continue;
}
if let Some(project_path) =
self.mention_set.lock().path_for_crease_id(crease_id)
{
@@ -700,15 +713,7 @@ impl AcpThreadView {
editor.set_show_code_actions(false, cx);
editor.set_show_git_diff_gutter(false, cx);
editor.set_expand_all_diff_hunks(cx);
editor.set_text_style_refinement(TextStyleRefinement {
font_size: Some(
TextSize::Small
.rems(cx)
.to_pixels(ThemeSettings::get_global(cx).agent_font_size(cx))
.into(),
),
..Default::default()
});
editor.set_text_style_refinement(diff_editor_text_style_refinement(cx));
editor
});
let entity_id = multibuffer.entity_id();
@@ -854,6 +859,7 @@ impl AcpThreadView {
.into_any()
}
AgentThreadEntry::ToolCall(tool_call) => div()
.w_full()
.py_1p5()
.px_5()
.child(self.render_tool_call(index, tool_call, window, cx))
@@ -866,6 +872,7 @@ impl AcpThreadView {
let is_generating = matches!(thread.read(cx).status(), ThreadStatus::Generating);
if index == total_entries - 1 && !is_generating {
v_flex()
.w_full()
.child(primary)
.child(self.render_thread_controls(cx))
.into_any_element()
@@ -898,6 +905,7 @@ impl AcpThreadView {
cx: &Context<Self>,
) -> AnyElement {
let header_id = SharedString::from(format!("thinking-block-header-{}", entry_ix));
let card_header_id = SharedString::from("inner-card-header");
let key = (entry_ix, chunk_ix);
let is_open = self.expanded_thinking_blocks.contains(&key);
@@ -905,41 +913,53 @@ impl AcpThreadView {
.child(
h_flex()
.id(header_id)
.group("disclosure-header")
.group(&card_header_id)
.relative()
.w_full()
.justify_between()
.gap_1p5()
.opacity(0.8)
.hover(|style| style.opacity(1.))
.child(
h_flex()
.gap_1p5()
.child(
Icon::new(IconName::ToolBulb)
.size(IconSize::Small)
.color(Color::Muted),
)
.size_4()
.justify_center()
.child(
div()
.text_size(self.tool_name_font_size())
.child("Thinking"),
.group_hover(&card_header_id, |s| s.invisible().w_0())
.child(
Icon::new(IconName::ToolThink)
.size(IconSize::Small)
.color(Color::Muted),
),
)
.child(
h_flex()
.absolute()
.inset_0()
.invisible()
.justify_center()
.group_hover(&card_header_id, |s| s.visible())
.child(
Disclosure::new(("expand", entry_ix), is_open)
.opened_icon(IconName::ChevronUp)
.closed_icon(IconName::ChevronRight)
.on_click(cx.listener({
move |this, _event, _window, cx| {
if is_open {
this.expanded_thinking_blocks.remove(&key);
} else {
this.expanded_thinking_blocks.insert(key);
}
cx.notify();
}
})),
),
),
)
.child(
div().visible_on_hover("disclosure-header").child(
Disclosure::new("thinking-disclosure", is_open)
.opened_icon(IconName::ChevronUp)
.closed_icon(IconName::ChevronDown)
.on_click(cx.listener({
move |this, _event, _window, cx| {
if is_open {
this.expanded_thinking_blocks.remove(&key);
} else {
this.expanded_thinking_blocks.insert(key);
}
cx.notify();
}
})),
),
div()
.text_size(self.tool_name_font_size())
.child("Thinking"),
)
.on_click(cx.listener({
move |this, _event, _window, cx| {
@@ -970,6 +990,67 @@ impl AcpThreadView {
.into_any_element()
}
fn render_tool_call_icon(
&self,
group_name: SharedString,
entry_ix: usize,
is_collapsible: bool,
is_open: bool,
tool_call: &ToolCall,
cx: &Context<Self>,
) -> Div {
let tool_icon = Icon::new(match tool_call.kind {
acp::ToolKind::Read => IconName::ToolRead,
acp::ToolKind::Edit => IconName::ToolPencil,
acp::ToolKind::Delete => IconName::ToolDeleteFile,
acp::ToolKind::Move => IconName::ArrowRightLeft,
acp::ToolKind::Search => IconName::ToolSearch,
acp::ToolKind::Execute => IconName::ToolTerminal,
acp::ToolKind::Think => IconName::ToolThink,
acp::ToolKind::Fetch => IconName::ToolWeb,
acp::ToolKind::Other => IconName::ToolHammer,
})
.size(IconSize::Small)
.color(Color::Muted);
if is_collapsible {
h_flex()
.size_4()
.justify_center()
.child(
div()
.group_hover(&group_name, |s| s.invisible().w_0())
.child(tool_icon),
)
.child(
h_flex()
.absolute()
.inset_0()
.invisible()
.justify_center()
.group_hover(&group_name, |s| s.visible())
.child(
Disclosure::new(("expand", entry_ix), is_open)
.opened_icon(IconName::ChevronUp)
.closed_icon(IconName::ChevronRight)
.on_click(cx.listener({
let id = tool_call.id.clone();
move |this: &mut Self, _, _, cx: &mut Context<Self>| {
if is_open {
this.expanded_tool_calls.remove(&id);
} else {
this.expanded_tool_calls.insert(id.clone());
}
cx.notify();
}
})),
),
)
} else {
div().child(tool_icon)
}
}
fn render_tool_call(
&self,
entry_ix: usize,
@@ -977,7 +1058,8 @@ impl AcpThreadView {
window: &Window,
cx: &Context<Self>,
) -> Div {
let header_id = SharedString::from(format!("tool-call-header-{}", entry_ix));
let header_id = SharedString::from(format!("outer-tool-call-header-{}", entry_ix));
let card_header_id = SharedString::from("inner-tool-call-header");
let status_icon = match &tool_call.status {
ToolCallStatus::Allowed {
@@ -1026,6 +1108,21 @@ impl AcpThreadView {
let is_collapsible = !tool_call.content.is_empty() && !needs_confirmation;
let is_open = !is_collapsible || self.expanded_tool_calls.contains(&tool_call.id);
let gradient_color = cx.theme().colors().panel_background;
let gradient_overlay = {
div()
.absolute()
.top_0()
.right_0()
.w_12()
.h_full()
.bg(linear_gradient(
90.,
linear_color_stop(gradient_color, 1.),
linear_color_stop(gradient_color.opacity(0.2), 0.),
))
};
v_flex()
.when(needs_confirmation, |this| {
this.rounded_lg()
@@ -1042,43 +1139,38 @@ impl AcpThreadView {
.justify_between()
.map(|this| {
if needs_confirmation {
this.px_2()
this.pl_2()
.pr_1()
.py_1()
.rounded_t_md()
.bg(self.tool_card_header_bg(cx))
.border_b_1()
.border_color(self.tool_card_border_color(cx))
.bg(self.tool_card_header_bg(cx))
} else {
this.opacity(0.8).hover(|style| style.opacity(1.))
}
})
.child(
h_flex()
.id("tool-call-header")
.overflow_x_scroll()
.group(&card_header_id)
.relative()
.w_full()
.map(|this| {
if needs_confirmation {
this.text_xs()
if tool_call.locations.len() == 1 {
this.gap_0()
} else {
this.text_size(self.tool_name_font_size())
this.gap_1p5()
}
})
.gap_1p5()
.child(
Icon::new(match tool_call.kind {
acp::ToolKind::Read => IconName::ToolRead,
acp::ToolKind::Edit => IconName::ToolPencil,
acp::ToolKind::Delete => IconName::ToolDeleteFile,
acp::ToolKind::Move => IconName::ArrowRightLeft,
acp::ToolKind::Search => IconName::ToolSearch,
acp::ToolKind::Execute => IconName::ToolTerminal,
acp::ToolKind::Think => IconName::ToolBulb,
acp::ToolKind::Fetch => IconName::ToolWeb,
acp::ToolKind::Other => IconName::ToolHammer,
})
.size(IconSize::Small)
.color(Color::Muted),
)
.text_size(self.tool_name_font_size())
.child(self.render_tool_call_icon(
card_header_id,
entry_ix,
is_collapsible,
is_open,
tool_call,
cx,
))
.child(if tool_call.locations.len() == 1 {
let name = tool_call.locations[0]
.path
@@ -1089,13 +1181,11 @@ impl AcpThreadView {
h_flex()
.id(("open-tool-call-location", entry_ix))
.child(name)
.w_full()
.max_w_full()
.pr_1()
.gap_0p5()
.cursor_pointer()
.px_1p5()
.rounded_sm()
.overflow_x_scroll()
.opacity(0.8)
.hover(|label| {
label.opacity(1.).bg(cx
@@ -1104,53 +1194,49 @@ impl AcpThreadView {
.element_hover
.opacity(0.5))
})
.child(name)
.tooltip(Tooltip::text("Jump to File"))
.on_click(cx.listener(move |this, _, window, cx| {
this.open_tool_call_location(entry_ix, 0, window, cx);
}))
.into_any_element()
} else {
self.render_markdown(
tool_call.label.clone(),
default_markdown_style(needs_confirmation, window, cx),
)
.into_any()
h_flex()
.id("non-card-label-container")
.w_full()
.relative()
.overflow_hidden()
.child(
h_flex()
.id("non-card-label")
.pr_8()
.w_full()
.overflow_x_scroll()
.child(self.render_markdown(
tool_call.label.clone(),
default_markdown_style(
needs_confirmation,
window,
cx,
),
)),
)
.child(gradient_overlay)
.on_click(cx.listener({
let id = tool_call.id.clone();
move |this: &mut Self, _, _, cx: &mut Context<Self>| {
if is_open {
this.expanded_tool_calls.remove(&id);
} else {
this.expanded_tool_calls.insert(id.clone());
}
cx.notify();
}
}))
.into_any()
}),
)
.child(
h_flex()
.gap_0p5()
.when(is_collapsible, |this| {
this.child(
Disclosure::new(("expand", entry_ix), is_open)
.opened_icon(IconName::ChevronUp)
.closed_icon(IconName::ChevronDown)
.on_click(cx.listener({
let id = tool_call.id.clone();
move |this: &mut Self, _, _, cx: &mut Context<Self>| {
if is_open {
this.expanded_tool_calls.remove(&id);
} else {
this.expanded_tool_calls.insert(id.clone());
}
cx.notify();
}
})),
)
})
.children(status_icon),
)
.on_click(cx.listener({
let id = tool_call.id.clone();
move |this: &mut Self, _, _, cx: &mut Context<Self>| {
if is_open {
this.expanded_tool_calls.remove(&id);
} else {
this.expanded_tool_calls.insert(id.clone());
}
cx.notify();
}
})),
.children(status_icon),
)
.when(is_open, |this| {
this.child(
@@ -1244,8 +1330,7 @@ impl AcpThreadView {
cx: &Context<Self>,
) -> Div {
h_flex()
.py_1p5()
.px_1p5()
.p_1p5()
.gap_1()
.justify_end()
.when(!empty_content, |this| {
@@ -1271,6 +1356,7 @@ impl AcpThreadView {
})
.icon_position(IconPosition::Start)
.icon_size(IconSize::XSmall)
.label_size(LabelSize::Small)
.on_click(cx.listener({
let tool_call_id = tool_call_id.clone();
let option_id = option.id.clone();
@@ -1520,7 +1606,7 @@ impl AcpThreadView {
})
})
.when(!changed_buffers.is_empty(), |this| {
this.child(Divider::horizontal())
this.child(Divider::horizontal().color(DividerColor::Border))
.child(self.render_edits_summary(
action_log,
&changed_buffers,
@@ -1550,6 +1636,7 @@ impl AcpThreadView {
{
h_flex()
.w_full()
.cursor_default()
.gap_1()
.text_xs()
.text_color(cx.theme().colors().text_muted)
@@ -1579,7 +1666,7 @@ impl AcpThreadView {
let status_label = if stats.pending == 0 {
"All Done".to_string()
} else if stats.completed == 0 {
format!("{}", plan.entries.len())
format!("{} Tasks", plan.entries.len())
} else {
format!("{}/{}", stats.completed, plan.entries.len())
};
@@ -1693,7 +1780,6 @@ impl AcpThreadView {
.child(
h_flex()
.id("edits-container")
.cursor_pointer()
.w_full()
.gap_1()
.child(Disclosure::new("edits-disclosure", expanded))
@@ -2468,6 +2554,7 @@ impl AcpThreadView {
}));
h_flex()
.w_full()
.mr_1()
.pb_2()
.px(RESPONSE_PADDING_X)
@@ -2478,6 +2565,48 @@ impl AcpThreadView {
.child(open_as_markdown)
.child(scroll_to_top)
}
fn render_vertical_scrollbar(&self, cx: &mut Context<Self>) -> Stateful<Div> {
div()
.id("acp-thread-scrollbar")
.occlude()
.on_mouse_move(cx.listener(|_, _, _, cx| {
cx.notify();
cx.stop_propagation()
}))
.on_hover(|_, _, cx| {
cx.stop_propagation();
})
.on_any_mouse_down(|_, _, cx| {
cx.stop_propagation();
})
.on_mouse_up(
MouseButton::Left,
cx.listener(|_, _, _, cx| {
cx.stop_propagation();
}),
)
.on_scroll_wheel(cx.listener(|_, _, _, cx| {
cx.notify();
}))
.h_full()
.absolute()
.right_1()
.top_1()
.bottom_0()
.w(px(12.))
.cursor_default()
.children(Scrollbar::vertical(self.scrollbar_state.clone()).map(|s| s.auto_hide(cx)))
}
fn settings_changed(&mut self, _window: &mut Window, cx: &mut Context<Self>) {
for diff_editor in self.diff_editors.values() {
diff_editor.update(cx, |diff_editor, cx| {
diff_editor.set_text_style_refinement(diff_editor_text_style_refinement(cx));
cx.notify();
})
}
}
}
impl Focusable for AcpThreadView {
@@ -2552,6 +2681,7 @@ impl Render for AcpThreadView {
.flex_grow()
.into_any(),
)
.child(self.render_vertical_scrollbar(cx))
.children(match thread_clone.read(cx).status() {
ThreadStatus::Idle | ThreadStatus::WaitingForToolConfirmation => {
None
@@ -2754,6 +2884,18 @@ fn plan_label_markdown_style(
}
}
fn diff_editor_text_style_refinement(cx: &mut App) -> TextStyleRefinement {
TextStyleRefinement {
font_size: Some(
TextSize::Small
.rems(cx)
.to_pixels(ThemeSettings::get_global(cx).agent_font_size(cx))
.into(),
),
..Default::default()
}
}
#[cfg(test)]
mod tests {
use agent_client_protocol::SessionId;
@@ -2761,8 +2903,12 @@ mod tests {
use fs::FakeFs;
use futures::future::try_join_all;
use gpui::{SemanticVersion, TestAppContext, VisualTestContext};
use lsp::{CompletionContext, CompletionTriggerKind};
use project::CompletionIntent;
use rand::Rng;
use serde_json::json;
use settings::SettingsStore;
use util::path;
use super::*;
@@ -2842,6 +2988,7 @@ mod tests {
content: vec!["hi".into()],
locations: vec![],
raw_input: None,
raw_output: None,
};
let connection = StubAgentConnection::new(vec![acp::SessionUpdate::ToolCall(tool_call)])
.with_permission_requests(HashMap::from_iter([(
@@ -2874,6 +3021,109 @@ mod tests {
);
}
#[gpui::test]
async fn test_crease_removal(cx: &mut TestAppContext) {
init_test(cx);
let fs = FakeFs::new(cx.executor());
fs.insert_tree("/project", json!({"file": ""})).await;
let project = Project::test(fs, [Path::new(path!("/project"))], cx).await;
let agent = StubAgentServer::default();
let (workspace, cx) =
cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));
let thread_view = cx.update(|window, cx| {
cx.new(|cx| {
AcpThreadView::new(
Rc::new(agent),
workspace.downgrade(),
project,
Rc::new(RefCell::new(MessageHistory::default())),
1,
None,
window,
cx,
)
})
});
cx.run_until_parked();
let message_editor = cx.read(|cx| thread_view.read(cx).message_editor.clone());
let excerpt_id = message_editor.update(cx, |editor, cx| {
editor
.buffer()
.read(cx)
.excerpt_ids()
.into_iter()
.next()
.unwrap()
});
let completions = message_editor.update_in(cx, |editor, window, cx| {
editor.set_text("Hello @", window, cx);
let buffer = editor.buffer().read(cx).as_singleton().unwrap();
let completion_provider = editor.completion_provider().unwrap();
completion_provider.completions(
excerpt_id,
&buffer,
Anchor::MAX,
CompletionContext {
trigger_kind: CompletionTriggerKind::TRIGGER_CHARACTER,
trigger_character: Some("@".into()),
},
window,
cx,
)
});
let [_, completion]: [_; 2] = completions
.await
.unwrap()
.into_iter()
.flat_map(|response| response.completions)
.collect::<Vec<_>>()
.try_into()
.unwrap();
message_editor.update_in(cx, |editor, window, cx| {
let snapshot = editor.buffer().read(cx).snapshot(cx);
let start = snapshot
.anchor_in_excerpt(excerpt_id, completion.replace_range.start)
.unwrap();
let end = snapshot
.anchor_in_excerpt(excerpt_id, completion.replace_range.end)
.unwrap();
editor.edit([(start..end, completion.new_text)], cx);
(completion.confirm.unwrap())(CompletionIntent::Complete, window, cx);
});
cx.run_until_parked();
// Backspace over the inserted crease (and the following space).
message_editor.update_in(cx, |editor, window, cx| {
editor.backspace(&Default::default(), window, cx);
editor.backspace(&Default::default(), window, cx);
});
thread_view.update_in(cx, |thread_view, window, cx| {
thread_view.chat(&Chat, window, cx);
});
cx.run_until_parked();
let content = thread_view.update_in(cx, |thread_view, _window, _cx| {
thread_view
.message_history
.borrow()
.items()
.iter()
.flatten()
.cloned()
.collect::<Vec<_>>()
});
// We don't send a resource link for the deleted crease.
pretty_assertions::assert_matches!(content.as_slice(), [acp::ContentBlock::Text { .. }]);
}
async fn setup_thread_view(
agent: impl AgentServer + 'static,
cx: &mut TestAppContext,
@@ -3027,7 +3277,7 @@ mod tests {
let task = cx.spawn(async move |cx| {
if let Some((tool_call, options)) = permission_request {
let permission = thread.update(cx, |thread, cx| {
thread.request_tool_call_permission(
thread.request_tool_call_authorization(
tool_call.clone(),
options.clone(),
cx,

View File

@@ -69,8 +69,6 @@ pub struct ActiveThread {
messages: Vec<MessageId>,
list_state: ListState,
scrollbar_state: ScrollbarState,
show_scrollbar: bool,
hide_scrollbar_task: Option<Task<()>>,
rendered_messages_by_id: HashMap<MessageId, RenderedMessage>,
rendered_tool_uses: HashMap<LanguageModelToolUseId, RenderedToolUse>,
editing_message: Option<(MessageId, EditingMessageState)>,
@@ -805,9 +803,7 @@ impl ActiveThread {
expanded_thinking_segments: HashMap::default(),
expanded_code_blocks: HashMap::default(),
list_state: list_state.clone(),
scrollbar_state: ScrollbarState::new(list_state),
show_scrollbar: false,
hide_scrollbar_task: None,
scrollbar_state: ScrollbarState::new(list_state).parent_entity(&cx.entity()),
editing_message: None,
last_error: None,
copied_code_block_ids: HashSet::default(),
@@ -2628,7 +2624,7 @@ impl ActiveThread {
h_flex()
.gap_1p5()
.child(
Icon::new(IconName::ToolBulb)
Icon::new(IconName::ToolThink)
.size(IconSize::Small)
.color(Color::Muted),
)
@@ -3502,60 +3498,37 @@ impl ActiveThread {
}
}
fn render_vertical_scrollbar(&self, cx: &mut Context<Self>) -> Option<Stateful<Div>> {
if !self.show_scrollbar && !self.scrollbar_state.is_dragging() {
return None;
}
Some(
div()
.occlude()
.id("active-thread-scrollbar")
.on_mouse_move(cx.listener(|_, _, _, cx| {
cx.notify();
cx.stop_propagation()
}))
.on_hover(|_, _, cx| {
fn render_vertical_scrollbar(&self, cx: &mut Context<Self>) -> Stateful<Div> {
div()
.occlude()
.id("active-thread-scrollbar")
.on_mouse_move(cx.listener(|_, _, _, cx| {
cx.notify();
cx.stop_propagation()
}))
.on_hover(|_, _, cx| {
cx.stop_propagation();
})
.on_any_mouse_down(|_, _, cx| {
cx.stop_propagation();
})
.on_mouse_up(
MouseButton::Left,
cx.listener(|_, _, _, cx| {
cx.stop_propagation();
})
.on_any_mouse_down(|_, _, cx| {
cx.stop_propagation();
})
.on_mouse_up(
MouseButton::Left,
cx.listener(|_, _, _, cx| {
cx.stop_propagation();
}),
)
.on_scroll_wheel(cx.listener(|_, _, _, cx| {
cx.notify();
}))
.h_full()
.absolute()
.right_1()
.top_1()
.bottom_0()
.w(px(12.))
.cursor_default()
.children(Scrollbar::vertical(self.scrollbar_state.clone())),
)
}
fn hide_scrollbar_later(&mut self, cx: &mut Context<Self>) {
const SCROLLBAR_SHOW_INTERVAL: Duration = Duration::from_secs(1);
self.hide_scrollbar_task = Some(cx.spawn(async move |thread, cx| {
cx.background_executor()
.timer(SCROLLBAR_SHOW_INTERVAL)
.await;
thread
.update(cx, |thread, cx| {
if !thread.scrollbar_state.is_dragging() {
thread.show_scrollbar = false;
cx.notify();
}
})
.log_err();
}))
}),
)
.on_scroll_wheel(cx.listener(|_, _, _, cx| {
cx.notify();
}))
.h_full()
.absolute()
.right_1()
.top_1()
.bottom_0()
.w(px(12.))
.cursor_default()
.children(Scrollbar::vertical(self.scrollbar_state.clone()).map(|s| s.auto_hide(cx)))
}
pub fn is_codeblock_expanded(&self, message_id: MessageId, ix: usize) -> bool {
@@ -3596,26 +3569,8 @@ impl Render for ActiveThread {
.size_full()
.relative()
.bg(cx.theme().colors().panel_background)
.on_mouse_move(cx.listener(|this, _, _, cx| {
this.show_scrollbar = true;
this.hide_scrollbar_later(cx);
cx.notify();
}))
.on_scroll_wheel(cx.listener(|this, _, _, cx| {
this.show_scrollbar = true;
this.hide_scrollbar_later(cx);
cx.notify();
}))
.on_mouse_up(
MouseButton::Left,
cx.listener(|this, _, _, cx| {
this.hide_scrollbar_later(cx);
}),
)
.child(list(self.list_state.clone(), cx.processor(Self::render_message)).flex_grow())
.when_some(self.render_vertical_scrollbar(cx), |this, scrollbar| {
this.child(scrollbar)
})
.child(self.render_vertical_scrollbar(cx))
}
}

View File

@@ -2343,6 +2343,16 @@ impl AgentPanel {
)
}
fn render_backdrop(&self, cx: &mut Context<Self>) -> impl IntoElement {
div()
.size_full()
.absolute()
.inset_0()
.bg(cx.theme().colors().panel_background)
.opacity(0.8)
.block_mouse_except_scroll()
}
fn render_trial_end_upsell(
&self,
_window: &mut Window,
@@ -2352,15 +2362,24 @@ impl AgentPanel {
return None;
}
Some(EndTrialUpsell::new(Arc::new({
let this = cx.entity();
move |_, cx| {
this.update(cx, |_this, cx| {
TrialEndUpsell::set_dismissed(true, cx);
cx.notify();
});
}
})))
Some(
v_flex()
.absolute()
.inset_0()
.size_full()
.bg(cx.theme().colors().panel_background)
.opacity(0.85)
.block_mouse_except_scroll()
.child(EndTrialUpsell::new(Arc::new({
let this = cx.entity();
move |_, cx| {
this.update(cx, |_this, cx| {
TrialEndUpsell::set_dismissed(true, cx);
cx.notify();
});
}
}))),
)
}
fn render_empty_state_section_header(
@@ -3210,9 +3229,10 @@ impl Render for AgentPanel {
// - Scrolling in all views works as expected
// - Files can be dropped into the panel
let content = v_flex()
.key_context(self.key_context())
.justify_between()
.relative()
.size_full()
.justify_between()
.key_context(self.key_context())
.on_action(cx.listener(Self::cancel))
.on_action(cx.listener(|this, action: &NewThread, window, cx| {
this.new_thread(action, window, cx);
@@ -3255,14 +3275,12 @@ impl Render for AgentPanel {
.on_action(cx.listener(Self::toggle_burn_mode))
.child(self.render_toolbar(window, cx))
.children(self.render_onboarding(window, cx))
.children(self.render_trial_end_upsell(window, cx))
.map(|parent| match &self.active_view {
ActiveView::Thread {
thread,
message_editor,
..
} => parent
.relative()
.child(
if thread.read(cx).is_empty() && !self.should_render_onboarding(cx) {
self.render_thread_empty_state(window, cx)
@@ -3299,21 +3317,10 @@ impl Render for AgentPanel {
})
.child(h_flex().relative().child(message_editor.clone()).when(
!LanguageModelRegistry::read_global(cx).has_authenticated_provider(cx),
|this| {
this.child(
div()
.size_full()
.absolute()
.inset_0()
.bg(cx.theme().colors().panel_background)
.opacity(0.8)
.block_mouse_except_scroll(),
)
},
|this| this.child(self.render_backdrop(cx)),
))
.child(self.render_drag_target(cx)),
ActiveView::ExternalAgentThread { thread_view, .. } => parent
.relative()
.child(thread_view.clone())
.child(self.render_drag_target(cx)),
ActiveView::History => parent.child(self.history.clone()),
@@ -3352,7 +3359,8 @@ impl Render for AgentPanel {
))
}
ActiveView::Configuration => parent.children(self.configuration.clone()),
});
})
.children(self.render_trial_end_upsell(window, cx));
match self.active_view.which_font_size_used() {
WhichFontSize::AgentFont => {

View File

@@ -1,26 +1,33 @@
use std::{sync::Arc, time::Duration};
use client::{Client, zed_urls};
use client::{Client, UserStore, zed_urls};
use cloud_llm_client::Plan;
use gpui::{
Animation, AnimationExt, AnyElement, App, IntoElement, RenderOnce, Transformation, Window,
percentage,
Animation, AnimationExt, AnyElement, App, Entity, IntoElement, RenderOnce, Transformation,
Window, percentage,
};
use ui::{Divider, Vector, VectorName, prelude::*};
use crate::{SignInStatus, plan_definitions::PlanDefinitions};
use crate::{SignInStatus, YoungAccountBanner, plan_definitions::PlanDefinitions};
#[derive(IntoElement, RegisterComponent)]
pub struct AiUpsellCard {
pub sign_in_status: SignInStatus,
pub sign_in: Arc<dyn Fn(&mut Window, &mut App)>,
pub account_too_young: bool,
pub user_plan: Option<Plan>,
pub tab_index: Option<isize>,
}
impl AiUpsellCard {
pub fn new(client: Arc<Client>, user_plan: Option<Plan>) -> Self {
pub fn new(
client: Arc<Client>,
user_store: &Entity<UserStore>,
user_plan: Option<Plan>,
cx: &mut App,
) -> Self {
let status = *client.status().borrow();
let store = user_store.read(cx);
Self {
user_plan,
@@ -32,6 +39,7 @@ impl AiUpsellCard {
})
.detach_and_log_err(cx);
}),
account_too_young: store.account_too_young(),
tab_index: None,
}
}
@@ -40,6 +48,7 @@ impl AiUpsellCard {
impl RenderOnce for AiUpsellCard {
fn render(self, _window: &mut Window, cx: &mut App) -> impl IntoElement {
let plan_definitions = PlanDefinitions;
let young_account_banner = YoungAccountBanner;
let pro_section = v_flex()
.flex_grow()
@@ -128,7 +137,7 @@ impl RenderOnce for AiUpsellCard {
.size(rems_from_px(72.))
.child(
Vector::new(
VectorName::CertifiedUserStamp,
VectorName::ProUserStamp,
rems_from_px(72.),
rems_from_px(72.),
)
@@ -158,36 +167,70 @@ impl RenderOnce for AiUpsellCard {
SignInStatus::SignedIn => match self.user_plan {
None | Some(Plan::ZedFree) => card
.child(Label::new("Try Zed AI").size(LabelSize::Large))
.child(
div()
.max_w_3_4()
.mb_2()
.child(Label::new(description).color(Color::Muted)),
)
.child(plans_section)
.child(
footer_container
.child(
Button::new("start_trial", "Start 14-day Free Pro Trial")
.full_width()
.style(ButtonStyle::Tinted(ui::TintColor::Accent))
.when_some(self.tab_index, |this, tab_index| {
this.tab_index(tab_index)
})
.on_click(move |_, _window, cx| {
telemetry::event!(
"Start Trial Clicked",
state = "post-sign-in"
);
cx.open_url(&zed_urls::start_trial_url(cx))
}),
.map(|this| {
if self.account_too_young {
this.child(young_account_banner).child(
v_flex()
.mt_2()
.gap_1()
.child(
h_flex()
.gap_2()
.child(
Label::new("Pro")
.size(LabelSize::Small)
.color(Color::Accent)
.buffer_font(cx),
)
.child(Divider::horizontal()),
)
.child(plan_definitions.pro_plan(true))
.child(
Button::new("pro", "Get Started")
.full_width()
.style(ButtonStyle::Tinted(ui::TintColor::Accent))
.on_click(move |_, _window, cx| {
telemetry::event!(
"Upgrade To Pro Clicked",
state = "young-account"
);
cx.open_url(&zed_urls::upgrade_to_zed_pro_url(cx))
}),
),
)
} else {
this.child(
div()
.max_w_3_4()
.mb_2()
.child(Label::new(description).color(Color::Muted)),
)
.child(plans_section)
.child(
Label::new("No credit card required")
.size(LabelSize::Small)
.color(Color::Muted),
),
),
footer_container
.child(
Button::new("start_trial", "Start 14-day Free Pro Trial")
.full_width()
.style(ButtonStyle::Tinted(ui::TintColor::Accent))
.when_some(self.tab_index, |this, tab_index| {
this.tab_index(tab_index)
})
.on_click(move |_, _window, cx| {
telemetry::event!(
"Start Trial Clicked",
state = "post-sign-in"
);
cx.open_url(&zed_urls::start_trial_url(cx))
}),
)
.child(
Label::new("No credit card required")
.size(LabelSize::Small)
.color(Color::Muted),
),
)
}
}),
Some(Plan::ZedProTrial) => card
.child(pro_trial_stamp)
.child(Label::new("You're in the Zed Pro Trial").size(LabelSize::Large))
@@ -255,48 +298,68 @@ impl Component for AiUpsellCard {
Some(
v_flex()
.gap_4()
.children(vec![example_group(vec![
single_example(
"Signed Out State",
AiUpsellCard {
sign_in_status: SignInStatus::SignedOut,
sign_in: Arc::new(|_, _| {}),
user_plan: None,
tab_index: Some(0),
}
.into_any_element(),
),
single_example(
"Free Plan",
AiUpsellCard {
sign_in_status: SignInStatus::SignedIn,
sign_in: Arc::new(|_, _| {}),
user_plan: Some(Plan::ZedFree),
tab_index: Some(1),
}
.into_any_element(),
),
single_example(
"Pro Trial",
AiUpsellCard {
sign_in_status: SignInStatus::SignedIn,
sign_in: Arc::new(|_, _| {}),
user_plan: Some(Plan::ZedProTrial),
tab_index: Some(1),
}
.into_any_element(),
),
single_example(
"Pro Plan",
AiUpsellCard {
sign_in_status: SignInStatus::SignedIn,
sign_in: Arc::new(|_, _| {}),
user_plan: Some(Plan::ZedPro),
tab_index: Some(1),
}
.into_any_element(),
),
])])
.items_center()
.max_w_4_5()
.child(single_example(
"Signed Out State",
AiUpsellCard {
sign_in_status: SignInStatus::SignedOut,
sign_in: Arc::new(|_, _| {}),
account_too_young: false,
user_plan: None,
tab_index: Some(0),
}
.into_any_element(),
))
.child(example_group_with_title(
"Signed In States",
vec![
single_example(
"Free Plan",
AiUpsellCard {
sign_in_status: SignInStatus::SignedIn,
sign_in: Arc::new(|_, _| {}),
account_too_young: false,
user_plan: Some(Plan::ZedFree),
tab_index: Some(1),
}
.into_any_element(),
),
single_example(
"Free Plan but Young Account",
AiUpsellCard {
sign_in_status: SignInStatus::SignedIn,
sign_in: Arc::new(|_, _| {}),
account_too_young: true,
user_plan: Some(Plan::ZedFree),
tab_index: Some(1),
}
.into_any_element(),
),
single_example(
"Pro Trial",
AiUpsellCard {
sign_in_status: SignInStatus::SignedIn,
sign_in: Arc::new(|_, _| {}),
account_too_young: false,
user_plan: Some(Plan::ZedProTrial),
tab_index: Some(1),
}
.into_any_element(),
),
single_example(
"Pro Plan",
AiUpsellCard {
sign_in_status: SignInStatus::SignedIn,
sign_in: Arc::new(|_, _| {}),
account_too_young: false,
user_plan: Some(Plan::ZedPro),
tab_index: Some(1),
}
.into_any_element(),
),
],
))
.into_any_element(),
)
}

View File

@@ -36,13 +36,12 @@ use crate::delete_path_tool::DeletePathTool;
use crate::diagnostics_tool::DiagnosticsTool;
use crate::edit_file_tool::EditFileTool;
use crate::fetch_tool::FetchTool;
use crate::find_path_tool::FindPathTool;
use crate::list_directory_tool::ListDirectoryTool;
use crate::now_tool::NowTool;
use crate::thinking_tool::ThinkingTool;
pub use edit_file_tool::{EditFileMode, EditFileToolInput};
pub use find_path_tool::FindPathToolInput;
pub use find_path_tool::*;
pub use grep_tool::{GrepTool, GrepToolInput};
pub use open_tool::OpenTool;
pub use project_notifications_tool::ProjectNotificationsTool;

View File

@@ -37,7 +37,7 @@ impl Tool for ThinkingTool {
}
fn icon(&self) -> IconName {
IconName::ToolBulb
IconName::ToolThink
}
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {

View File

@@ -134,10 +134,15 @@ impl Settings for AutoUpdateSetting {
type FileContent = Option<AutoUpdateSettingContent>;
fn load(sources: SettingsSources<Self::FileContent>, _: &mut App) -> Result<Self> {
let auto_update = [sources.server, sources.release_channel, sources.user]
.into_iter()
.find_map(|value| value.copied().flatten())
.unwrap_or(sources.default.ok_or_else(Self::missing_default)?);
let auto_update = [
sources.server,
sources.release_channel,
sources.operating_system,
sources.user,
]
.into_iter()
.find_map(|value| value.copied().flatten())
.unwrap_or(sources.default.ok_or_else(Self::missing_default)?);
Ok(Self(auto_update.0))
}

View File

@@ -226,35 +226,17 @@ impl UserStore {
match status {
Status::Authenticated | Status::Connected { .. } => {
if let Some(user_id) = client.user_id() {
let response = client
.cloud_client()
.get_authenticated_user()
.await
.log_err();
let current_user_and_response = if let Some(response) = response {
let user = Arc::new(User {
id: user_id,
github_login: response.user.github_login.clone().into(),
avatar_uri: response.user.avatar_url.clone().into(),
name: response.user.name.clone(),
});
Some((user, response))
} else {
None
};
current_user_tx
.send(
current_user_and_response
.as_ref()
.map(|(user, _)| user.clone()),
)
.await
.ok();
let response = client.cloud_client().get_authenticated_user().await;
let mut current_user = None;
cx.update(|cx| {
if let Some((user, response)) = current_user_and_response {
if let Some(response) = response.log_err() {
let user = Arc::new(User {
id: user_id,
github_login: response.user.github_login.clone().into(),
avatar_uri: response.user.avatar_url.clone().into(),
name: response.user.name.clone(),
});
current_user = Some(user.clone());
this.update(cx, |this, cx| {
this.by_github_login
.insert(user.github_login.clone(), user_id);
@@ -265,6 +247,7 @@ impl UserStore {
anyhow::Ok(())
}
})??;
current_user_tx.send(current_user).await.ok();
this.update(cx, |_, cx| cx.notify())?;
}

View File

@@ -2,5 +2,6 @@ ZED_ENVIRONMENT=production
RUST_LOG=info
INVITE_LINK_PREFIX=https://zed.dev/invites/
AUTO_JOIN_CHANNEL_ID=283
DATABASE_MAX_CONNECTIONS=250
# Set DATABASE_MAX_CONNECTIONS max connections in the `deploy_collab.yml`:
# https://github.com/zed-industries/zed/blob/main/.github/workflows/deploy_collab.yml
LLM_DATABASE_MAX_CONNECTIONS=25

View File

@@ -1,477 +1,10 @@
use anyhow::{Context as _, bail};
use chrono::{DateTime, Utc};
use sea_orm::ActiveValue;
use std::{sync::Arc, time::Duration};
use stripe::{CancellationDetailsReason, EventObject, EventType, ListEvents, SubscriptionStatus};
use util::ResultExt;
use std::sync::Arc;
use stripe::SubscriptionStatus;
use crate::AppState;
use crate::db::billing_subscription::{
StripeCancellationReason, StripeSubscriptionStatus, SubscriptionKind,
};
use crate::db::{
CreateBillingCustomerParams, CreateBillingSubscriptionParams, CreateProcessedStripeEventParams,
UpdateBillingCustomerParams, UpdateBillingSubscriptionParams, billing_customer,
};
use crate::rpc::{ResultExt as _, Server};
use crate::stripe_client::{
StripeCancellationDetailsReason, StripeClient, StripeCustomerId, StripeSubscription,
StripeSubscriptionId,
};
/// The amount of time we wait in between each poll of Stripe events.
///
/// This value should strike a balance between:
/// 1. Being short enough that we update quickly when something in Stripe changes
/// 2. Being long enough that we don't eat into our rate limits.
///
/// As a point of reference, the Sequin folks say they have this at **500ms**:
///
/// > We poll the Stripe /events endpoint every 500ms per account
/// >
/// > — https://blog.sequinstream.com/events-not-webhooks/
const POLL_EVENTS_INTERVAL: Duration = Duration::from_secs(5);
/// The maximum number of events to return per page.
///
/// We set this to 100 (the max) so we have to make fewer requests to Stripe.
///
/// > Limit can range between 1 and 100, and the default is 10.
const EVENTS_LIMIT_PER_PAGE: u64 = 100;
/// The number of pages consisting entirely of already-processed events that we
/// will see before we stop retrieving events.
///
/// This is used to prevent over-fetching the Stripe events API for events we've
/// already seen and processed.
const NUMBER_OF_ALREADY_PROCESSED_PAGES_BEFORE_WE_STOP: usize = 4;
/// Polls the Stripe events API periodically to reconcile the records in our
/// database with the data in Stripe.
pub fn poll_stripe_events_periodically(app: Arc<AppState>, rpc_server: Arc<Server>) {
let Some(real_stripe_client) = app.real_stripe_client.clone() else {
log::warn!("failed to retrieve Stripe client");
return;
};
let Some(stripe_client) = app.stripe_client.clone() else {
log::warn!("failed to retrieve Stripe client");
return;
};
let executor = app.executor.clone();
executor.spawn_detached({
let executor = executor.clone();
async move {
loop {
poll_stripe_events(&app, &rpc_server, &stripe_client, &real_stripe_client)
.await
.log_err();
executor.sleep(POLL_EVENTS_INTERVAL).await;
}
}
});
}
async fn poll_stripe_events(
app: &Arc<AppState>,
rpc_server: &Arc<Server>,
stripe_client: &Arc<dyn StripeClient>,
real_stripe_client: &stripe::Client,
) -> anyhow::Result<()> {
let feature_flags = app.db.list_feature_flags().await?;
let sync_events_using_cloud = feature_flags
.iter()
.any(|flag| flag.flag == "cloud-stripe-events-polling" && flag.enabled_for_all);
if sync_events_using_cloud {
return Ok(());
}
fn event_type_to_string(event_type: EventType) -> String {
// Calling `to_string` on `stripe::EventType` members gives us a quoted string,
// so we need to unquote it.
event_type.to_string().trim_matches('"').to_string()
}
let event_types = [
EventType::CustomerCreated,
EventType::CustomerUpdated,
EventType::CustomerSubscriptionCreated,
EventType::CustomerSubscriptionUpdated,
EventType::CustomerSubscriptionPaused,
EventType::CustomerSubscriptionResumed,
EventType::CustomerSubscriptionDeleted,
]
.into_iter()
.map(event_type_to_string)
.collect::<Vec<_>>();
let mut pages_of_already_processed_events = 0;
let mut unprocessed_events = Vec::new();
log::info!(
"Stripe events: starting retrieval for {}",
event_types.join(", ")
);
let mut params = ListEvents::new();
params.types = Some(event_types.clone());
params.limit = Some(EVENTS_LIMIT_PER_PAGE);
let mut event_pages = stripe::Event::list(&real_stripe_client, &params)
.await?
.paginate(params);
loop {
let processed_event_ids = {
let event_ids = event_pages
.page
.data
.iter()
.map(|event| event.id.as_str())
.collect::<Vec<_>>();
app.db
.get_processed_stripe_events_by_event_ids(&event_ids)
.await?
.into_iter()
.map(|event| event.stripe_event_id)
.collect::<Vec<_>>()
};
let mut processed_events_in_page = 0;
let events_in_page = event_pages.page.data.len();
for event in &event_pages.page.data {
if processed_event_ids.contains(&event.id.to_string()) {
processed_events_in_page += 1;
log::debug!("Stripe events: already processed '{}', skipping", event.id);
} else {
unprocessed_events.push(event.clone());
}
}
if processed_events_in_page == events_in_page {
pages_of_already_processed_events += 1;
}
if event_pages.page.has_more {
if pages_of_already_processed_events >= NUMBER_OF_ALREADY_PROCESSED_PAGES_BEFORE_WE_STOP
{
log::info!(
"Stripe events: stopping, saw {pages_of_already_processed_events} pages of already-processed events"
);
break;
} else {
log::info!("Stripe events: retrieving next page");
event_pages = event_pages.next(&real_stripe_client).await?;
}
} else {
break;
}
}
log::info!("Stripe events: unprocessed {}", unprocessed_events.len());
// Sort all of the unprocessed events in ascending order, so we can handle them in the order they occurred.
unprocessed_events.sort_by(|a, b| a.created.cmp(&b.created).then_with(|| a.id.cmp(&b.id)));
for event in unprocessed_events {
let event_id = event.id.clone();
let processed_event_params = CreateProcessedStripeEventParams {
stripe_event_id: event.id.to_string(),
stripe_event_type: event_type_to_string(event.type_),
stripe_event_created_timestamp: event.created,
};
// If the event has happened too far in the past, we don't want to
// process it and risk overwriting other more-recent updates.
//
// 1 day was chosen arbitrarily. This could be made longer or shorter.
let one_day = Duration::from_secs(24 * 60 * 60);
let a_day_ago = Utc::now() - one_day;
if a_day_ago.timestamp() > event.created {
log::info!(
"Stripe events: event '{}' is more than {one_day:?} old, marking as processed",
event_id
);
app.db
.create_processed_stripe_event(&processed_event_params)
.await?;
continue;
}
let process_result = match event.type_ {
EventType::CustomerCreated | EventType::CustomerUpdated => {
handle_customer_event(app, real_stripe_client, event).await
}
EventType::CustomerSubscriptionCreated
| EventType::CustomerSubscriptionUpdated
| EventType::CustomerSubscriptionPaused
| EventType::CustomerSubscriptionResumed
| EventType::CustomerSubscriptionDeleted => {
handle_customer_subscription_event(app, rpc_server, stripe_client, event).await
}
_ => Ok(()),
};
if let Some(()) = process_result
.with_context(|| format!("failed to process event {event_id} successfully"))
.log_err()
{
app.db
.create_processed_stripe_event(&processed_event_params)
.await?;
}
}
Ok(())
}
async fn handle_customer_event(
app: &Arc<AppState>,
_stripe_client: &stripe::Client,
event: stripe::Event,
) -> anyhow::Result<()> {
let EventObject::Customer(customer) = event.data.object else {
bail!("unexpected event payload for {}", event.id);
};
log::info!("handling Stripe {} event: {}", event.type_, event.id);
let Some(email) = customer.email else {
log::info!("Stripe customer has no email: skipping");
return Ok(());
};
let Some(user) = app.db.get_user_by_email(&email).await? else {
log::info!("no user found for email: skipping");
return Ok(());
};
if let Some(existing_customer) = app
.db
.get_billing_customer_by_stripe_customer_id(&customer.id)
.await?
{
app.db
.update_billing_customer(
existing_customer.id,
&UpdateBillingCustomerParams {
// For now we just leave the information as-is, as it is not
// likely to change.
..Default::default()
},
)
.await?;
} else {
app.db
.create_billing_customer(&CreateBillingCustomerParams {
user_id: user.id,
stripe_customer_id: customer.id.to_string(),
})
.await?;
}
Ok(())
}
async fn sync_subscription(
app: &Arc<AppState>,
stripe_client: &Arc<dyn StripeClient>,
subscription: StripeSubscription,
) -> anyhow::Result<billing_customer::Model> {
let subscription_kind = if let Some(stripe_billing) = &app.stripe_billing {
stripe_billing
.determine_subscription_kind(&subscription)
.await
} else {
None
};
let billing_customer =
find_or_create_billing_customer(app, stripe_client.as_ref(), &subscription.customer)
.await?
.context("billing customer not found")?;
if let Some(SubscriptionKind::ZedProTrial) = subscription_kind {
if subscription.status == SubscriptionStatus::Trialing {
let current_period_start =
DateTime::from_timestamp(subscription.current_period_start, 0)
.context("No trial subscription period start")?;
app.db
.update_billing_customer(
billing_customer.id,
&UpdateBillingCustomerParams {
trial_started_at: ActiveValue::set(Some(current_period_start.naive_utc())),
..Default::default()
},
)
.await?;
}
}
let was_canceled_due_to_payment_failure = subscription.status == SubscriptionStatus::Canceled
&& subscription
.cancellation_details
.as_ref()
.and_then(|details| details.reason)
.map_or(false, |reason| {
reason == StripeCancellationDetailsReason::PaymentFailed
});
if was_canceled_due_to_payment_failure {
app.db
.update_billing_customer(
billing_customer.id,
&UpdateBillingCustomerParams {
has_overdue_invoices: ActiveValue::set(true),
..Default::default()
},
)
.await?;
}
if let Some(existing_subscription) = app
.db
.get_billing_subscription_by_stripe_subscription_id(subscription.id.0.as_ref())
.await?
{
app.db
.update_billing_subscription(
existing_subscription.id,
&UpdateBillingSubscriptionParams {
billing_customer_id: ActiveValue::set(billing_customer.id),
kind: ActiveValue::set(subscription_kind),
stripe_subscription_id: ActiveValue::set(subscription.id.to_string()),
stripe_subscription_status: ActiveValue::set(subscription.status.into()),
stripe_cancel_at: ActiveValue::set(
subscription
.cancel_at
.and_then(|cancel_at| DateTime::from_timestamp(cancel_at, 0))
.map(|time| time.naive_utc()),
),
stripe_cancellation_reason: ActiveValue::set(
subscription
.cancellation_details
.and_then(|details| details.reason)
.map(|reason| reason.into()),
),
stripe_current_period_start: ActiveValue::set(Some(
subscription.current_period_start,
)),
stripe_current_period_end: ActiveValue::set(Some(
subscription.current_period_end,
)),
},
)
.await?;
} else {
if let Some(existing_subscription) = app
.db
.get_active_billing_subscription(billing_customer.user_id)
.await?
{
if existing_subscription.kind == Some(SubscriptionKind::ZedFree)
&& subscription_kind == Some(SubscriptionKind::ZedProTrial)
{
let stripe_subscription_id = StripeSubscriptionId(
existing_subscription.stripe_subscription_id.clone().into(),
);
stripe_client
.cancel_subscription(&stripe_subscription_id)
.await?;
} else {
// If the user already has an active billing subscription, ignore the
// event and return an `Ok` to signal that it was processed
// successfully.
//
// There is the possibility that this could cause us to not create a
// subscription in the following scenario:
//
// 1. User has an active subscription A
// 2. User cancels subscription A
// 3. User creates a new subscription B
// 4. We process the new subscription B before the cancellation of subscription A
// 5. User ends up with no subscriptions
//
// In theory this situation shouldn't arise as we try to process the events in the order they occur.
log::info!(
"user {user_id} already has an active subscription, skipping creation of subscription {subscription_id}",
user_id = billing_customer.user_id,
subscription_id = subscription.id
);
return Ok(billing_customer);
}
}
app.db
.create_billing_subscription(&CreateBillingSubscriptionParams {
billing_customer_id: billing_customer.id,
kind: subscription_kind,
stripe_subscription_id: subscription.id.to_string(),
stripe_subscription_status: subscription.status.into(),
stripe_cancellation_reason: subscription
.cancellation_details
.and_then(|details| details.reason)
.map(|reason| reason.into()),
stripe_current_period_start: Some(subscription.current_period_start),
stripe_current_period_end: Some(subscription.current_period_end),
})
.await?;
}
if let Some(stripe_billing) = app.stripe_billing.as_ref() {
if subscription.status == SubscriptionStatus::Canceled
|| subscription.status == SubscriptionStatus::Paused
{
let already_has_active_billing_subscription = app
.db
.has_active_billing_subscription(billing_customer.user_id)
.await?;
if !already_has_active_billing_subscription {
let stripe_customer_id =
StripeCustomerId(billing_customer.stripe_customer_id.clone().into());
stripe_billing
.subscribe_to_zed_free(stripe_customer_id)
.await?;
}
}
}
Ok(billing_customer)
}
async fn handle_customer_subscription_event(
app: &Arc<AppState>,
rpc_server: &Arc<Server>,
stripe_client: &Arc<dyn StripeClient>,
event: stripe::Event,
) -> anyhow::Result<()> {
let EventObject::Subscription(subscription) = event.data.object else {
bail!("unexpected event payload for {}", event.id);
};
log::info!("handling Stripe {} event: {}", event.type_, event.id);
let billing_customer = sync_subscription(app, stripe_client, subscription.into()).await?;
// When the user's subscription changes, push down any changes to their plan.
rpc_server
.update_plan_for_user_legacy(billing_customer.user_id)
.await
.trace_err();
// When the user's subscription changes, we want to refresh their LLM tokens
// to either grant/revoke access.
rpc_server
.refresh_llm_tokens_for_user(billing_customer.user_id)
.await;
Ok(())
}
use crate::db::billing_subscription::StripeSubscriptionStatus;
use crate::db::{CreateBillingCustomerParams, billing_customer};
use crate::stripe_client::{StripeClient, StripeCustomerId};
impl From<SubscriptionStatus> for StripeSubscriptionStatus {
fn from(value: SubscriptionStatus) -> Self {
@@ -488,16 +21,6 @@ impl From<SubscriptionStatus> for StripeSubscriptionStatus {
}
}
impl From<CancellationDetailsReason> for StripeCancellationReason {
fn from(value: CancellationDetailsReason) -> Self {
match value {
CancellationDetailsReason::CancellationRequested => Self::CancellationRequested,
CancellationDetailsReason::PaymentDisputed => Self::PaymentDisputed,
CancellationDetailsReason::PaymentFailed => Self::PaymentFailed,
}
}
}
/// Finds or creates a billing customer using the provided customer.
pub async fn find_or_create_billing_customer(
app: &Arc<AppState>,

View File

@@ -7,6 +7,7 @@ use axum::{
routing::get,
};
use collab::ServiceMode;
use collab::api::CloudflareIpCountryHeader;
use collab::llm::db::LlmDatabase;
use collab::migrations::run_database_migrations;
@@ -15,7 +16,6 @@ use collab::{
AppState, Config, Result, api::fetch_extensions_from_blob_store_periodically, db, env,
executor::Executor, rpc::ResultExt,
};
use collab::{ServiceMode, api::billing::poll_stripe_events_periodically};
use db::Database;
use std::{
env::args,
@@ -119,8 +119,6 @@ async fn main() -> Result<()> {
let rpc_server = collab::rpc::Server::new(epoch, state.clone());
rpc_server.start().await?;
poll_stripe_events_periodically(state.clone(), rpc_server.clone());
app = app
.merge(collab::api::routes(rpc_server.clone()))
.merge(collab::rpc::routes(rpc_server.clone()));

View File

@@ -41,9 +41,11 @@ use chrono::Utc;
use collections::{HashMap, HashSet};
pub use connection_pool::{ConnectionPool, ZedVersion};
use core::fmt::{self, Debug, Formatter};
use futures::TryFutureExt as _;
use reqwest_client::ReqwestClient;
use rpc::proto::{MultiLspQuery, split_repository_update};
use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
use tracing::Span;
use futures::{
FutureExt, SinkExt, StreamExt, TryStreamExt, channel::oneshot, future::BoxFuture,
@@ -94,8 +96,13 @@ const MAX_CONCURRENT_CONNECTIONS: usize = 512;
static CONCURRENT_CONNECTIONS: AtomicUsize = AtomicUsize::new(0);
const TOTAL_DURATION_MS: &str = "total_duration_ms";
const PROCESSING_DURATION_MS: &str = "processing_duration_ms";
const QUEUE_DURATION_MS: &str = "queue_duration_ms";
const HOST_WAITING_MS: &str = "host_waiting_ms";
type MessageHandler =
Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session, Span) -> BoxFuture<'static, ()>>;
pub struct ConnectionGuard;
@@ -163,6 +170,42 @@ impl Principal {
}
}
#[derive(Clone)]
struct MessageContext {
session: Session,
span: tracing::Span,
}
impl Deref for MessageContext {
type Target = Session;
fn deref(&self) -> &Self::Target {
&self.session
}
}
impl MessageContext {
pub fn forward_request<T: RequestMessage>(
&self,
receiver_id: ConnectionId,
request: T,
) -> impl Future<Output = anyhow::Result<T::Response>> {
let request_start_time = Instant::now();
let span = self.span.clone();
tracing::info!("start forwarding request");
self.peer
.forward_request(self.connection_id, receiver_id, request)
.inspect(move |_| {
span.record(
HOST_WAITING_MS,
request_start_time.elapsed().as_micros() as f64 / 1000.0,
);
})
.inspect_err(|_| tracing::error!("error forwarding request"))
.inspect_ok(|_| tracing::info!("finished forwarding request"))
}
}
#[derive(Clone)]
struct Session {
principal: Principal,
@@ -646,40 +689,37 @@ impl Server {
fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
where
F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
F: 'static + Send + Sync + Fn(TypedEnvelope<M>, MessageContext) -> Fut,
Fut: 'static + Send + Future<Output = Result<()>>,
M: EnvelopedMessage,
{
let prev_handler = self.handlers.insert(
TypeId::of::<M>(),
Box::new(move |envelope, session| {
Box::new(move |envelope, session, span| {
let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
let received_at = envelope.received_at;
tracing::info!("message received");
let start_time = Instant::now();
let future = (handler)(*envelope, session);
let future = (handler)(
*envelope,
MessageContext {
session,
span: span.clone(),
},
);
async move {
let result = future.await;
let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
let queue_duration_ms = total_duration_ms - processing_duration_ms;
span.record(TOTAL_DURATION_MS, total_duration_ms);
span.record(PROCESSING_DURATION_MS, processing_duration_ms);
span.record(QUEUE_DURATION_MS, queue_duration_ms);
match result {
Err(error) => {
tracing::error!(
?error,
total_duration_ms,
processing_duration_ms,
queue_duration_ms,
"error handling message"
)
tracing::error!(?error, "error handling message")
}
Ok(()) => tracing::info!(
total_duration_ms,
processing_duration_ms,
queue_duration_ms,
"finished handling message"
),
Ok(()) => tracing::info!("finished handling message"),
}
}
.boxed()
@@ -693,7 +733,7 @@ impl Server {
fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
where
F: 'static + Send + Sync + Fn(M, Session) -> Fut,
F: 'static + Send + Sync + Fn(M, MessageContext) -> Fut,
Fut: 'static + Send + Future<Output = Result<()>>,
M: EnvelopedMessage,
{
@@ -703,7 +743,7 @@ impl Server {
fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
where
F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
F: 'static + Send + Sync + Fn(M, Response<M>, MessageContext) -> Fut,
Fut: Send + Future<Output = Result<()>>,
M: RequestMessage,
{
@@ -746,6 +786,7 @@ impl Server {
address: String,
principal: Principal,
zed_version: ZedVersion,
release_channel: Option<String>,
user_agent: Option<String>,
geoip_country_code: Option<String>,
system_id: Option<String>,
@@ -760,12 +801,16 @@ impl Server {
login=field::Empty,
impersonator=field::Empty,
user_agent=field::Empty,
geoip_country_code=field::Empty
geoip_country_code=field::Empty,
release_channel=field::Empty,
);
principal.update_span(&span);
if let Some(user_agent) = user_agent {
span.record("user_agent", user_agent);
}
if let Some(release_channel) = release_channel {
span.record("release_channel", release_channel);
}
if let Some(country_code) = geoip_country_code.as_ref() {
span.record("geoip_country_code", country_code);
@@ -884,12 +929,16 @@ impl Server {
login=field::Empty,
impersonator=field::Empty,
multi_lsp_query_request=field::Empty,
{ TOTAL_DURATION_MS }=field::Empty,
{ PROCESSING_DURATION_MS }=field::Empty,
{ QUEUE_DURATION_MS }=field::Empty,
{ HOST_WAITING_MS }=field::Empty
);
principal.update_span(&span);
let span_enter = span.enter();
if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
let is_background = message.is_background();
let handle_message = (handler)(message, session.clone());
let handle_message = (handler)(message, session.clone(), span.clone());
drop(span_enter);
let handle_message = async move {
@@ -1181,6 +1230,35 @@ impl Header for AppVersionHeader {
}
}
#[derive(Debug)]
pub struct ReleaseChannelHeader(String);
impl Header for ReleaseChannelHeader {
fn name() -> &'static HeaderName {
static ZED_RELEASE_CHANNEL: OnceLock<HeaderName> = OnceLock::new();
ZED_RELEASE_CHANNEL.get_or_init(|| HeaderName::from_static("x-zed-release-channel"))
}
fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
where
Self: Sized,
I: Iterator<Item = &'i axum::http::HeaderValue>,
{
Ok(Self(
values
.next()
.ok_or_else(axum::headers::Error::invalid)?
.to_str()
.map_err(|_| axum::headers::Error::invalid())?
.to_owned(),
))
}
fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
values.extend([self.0.parse().unwrap()]);
}
}
pub fn routes(server: Arc<Server>) -> Router<(), Body> {
Router::new()
.route("/rpc", get(handle_websocket_request))
@@ -1196,6 +1274,7 @@ pub fn routes(server: Arc<Server>) -> Router<(), Body> {
pub async fn handle_websocket_request(
TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
app_version_header: Option<TypedHeader<AppVersionHeader>>,
release_channel_header: Option<TypedHeader<ReleaseChannelHeader>>,
ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
Extension(server): Extension<Arc<Server>>,
Extension(principal): Extension<Principal>,
@@ -1220,6 +1299,8 @@ pub async fn handle_websocket_request(
.into_response();
};
let release_channel = release_channel_header.map(|header| header.0.0);
if !version.can_collaborate() {
return (
StatusCode::UPGRADE_REQUIRED,
@@ -1255,6 +1336,7 @@ pub async fn handle_websocket_request(
socket_address,
principal,
version,
release_channel,
user_agent.map(|header| header.to_string()),
country_code_header.map(|header| header.to_string()),
system_id_header.map(|header| header.to_string()),
@@ -1348,7 +1430,11 @@ async fn connection_lost(
}
/// Acknowledges a ping from a client, used to keep the connection alive.
async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
async fn ping(
_: proto::Ping,
response: Response<proto::Ping>,
_session: MessageContext,
) -> Result<()> {
response.send(proto::Ack {})?;
Ok(())
}
@@ -1357,7 +1443,7 @@ async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session
async fn create_room(
_request: proto::CreateRoom,
response: Response<proto::CreateRoom>,
session: Session,
session: MessageContext,
) -> Result<()> {
let livekit_room = nanoid::nanoid!(30);
@@ -1397,7 +1483,7 @@ async fn create_room(
async fn join_room(
request: proto::JoinRoom,
response: Response<proto::JoinRoom>,
session: Session,
session: MessageContext,
) -> Result<()> {
let room_id = RoomId::from_proto(request.id);
@@ -1464,7 +1550,7 @@ async fn join_room(
async fn rejoin_room(
request: proto::RejoinRoom,
response: Response<proto::RejoinRoom>,
session: Session,
session: MessageContext,
) -> Result<()> {
let room;
let channel;
@@ -1592,15 +1678,15 @@ fn notify_rejoined_projects(
}
// Stream this worktree's diagnostics.
for summary in worktree.diagnostic_summaries {
session.peer.send(
session.connection_id,
proto::UpdateDiagnosticSummary {
project_id: project.id.to_proto(),
worktree_id: worktree.id,
summary: Some(summary),
},
)?;
let mut worktree_diagnostics = worktree.diagnostic_summaries.into_iter();
if let Some(summary) = worktree_diagnostics.next() {
let message = proto::UpdateDiagnosticSummary {
project_id: project.id.to_proto(),
worktree_id: worktree.id,
summary: Some(summary),
more_summaries: worktree_diagnostics.collect(),
};
session.peer.send(session.connection_id, message)?;
}
for settings_file in worktree.settings_files {
@@ -1641,7 +1727,7 @@ fn notify_rejoined_projects(
async fn leave_room(
_: proto::LeaveRoom,
response: Response<proto::LeaveRoom>,
session: Session,
session: MessageContext,
) -> Result<()> {
leave_room_for_session(&session, session.connection_id).await?;
response.send(proto::Ack {})?;
@@ -1652,7 +1738,7 @@ async fn leave_room(
async fn set_room_participant_role(
request: proto::SetRoomParticipantRole,
response: Response<proto::SetRoomParticipantRole>,
session: Session,
session: MessageContext,
) -> Result<()> {
let user_id = UserId::from_proto(request.user_id);
let role = ChannelRole::from(request.role());
@@ -1700,7 +1786,7 @@ async fn set_room_participant_role(
async fn call(
request: proto::Call,
response: Response<proto::Call>,
session: Session,
session: MessageContext,
) -> Result<()> {
let room_id = RoomId::from_proto(request.room_id);
let calling_user_id = session.user_id();
@@ -1769,7 +1855,7 @@ async fn call(
async fn cancel_call(
request: proto::CancelCall,
response: Response<proto::CancelCall>,
session: Session,
session: MessageContext,
) -> Result<()> {
let called_user_id = UserId::from_proto(request.called_user_id);
let room_id = RoomId::from_proto(request.room_id);
@@ -1804,7 +1890,7 @@ async fn cancel_call(
}
/// Decline an incoming call.
async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
async fn decline_call(message: proto::DeclineCall, session: MessageContext) -> Result<()> {
let room_id = RoomId::from_proto(message.room_id);
{
let room = session
@@ -1839,7 +1925,7 @@ async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<(
async fn update_participant_location(
request: proto::UpdateParticipantLocation,
response: Response<proto::UpdateParticipantLocation>,
session: Session,
session: MessageContext,
) -> Result<()> {
let room_id = RoomId::from_proto(request.room_id);
let location = request.location.context("invalid location")?;
@@ -1858,7 +1944,7 @@ async fn update_participant_location(
async fn share_project(
request: proto::ShareProject,
response: Response<proto::ShareProject>,
session: Session,
session: MessageContext,
) -> Result<()> {
let (project_id, room) = &*session
.db()
@@ -1879,7 +1965,7 @@ async fn share_project(
}
/// Unshare a project from the room.
async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
async fn unshare_project(message: proto::UnshareProject, session: MessageContext) -> Result<()> {
let project_id = ProjectId::from_proto(message.project_id);
unshare_project_internal(project_id, session.connection_id, &session).await
}
@@ -1926,7 +2012,7 @@ async fn unshare_project_internal(
async fn join_project(
request: proto::JoinProject,
response: Response<proto::JoinProject>,
session: Session,
session: MessageContext,
) -> Result<()> {
let project_id = ProjectId::from_proto(request.project_id);
@@ -2022,15 +2108,15 @@ async fn join_project(
}
// Stream this worktree's diagnostics.
for summary in worktree.diagnostic_summaries {
session.peer.send(
session.connection_id,
proto::UpdateDiagnosticSummary {
project_id: project_id.to_proto(),
worktree_id: worktree.id,
summary: Some(summary),
},
)?;
let mut worktree_diagnostics = worktree.diagnostic_summaries.into_iter();
if let Some(summary) = worktree_diagnostics.next() {
let message = proto::UpdateDiagnosticSummary {
project_id: project.id.to_proto(),
worktree_id: worktree.id,
summary: Some(summary),
more_summaries: worktree_diagnostics.collect(),
};
session.peer.send(session.connection_id, message)?;
}
for settings_file in worktree.settings_files {
@@ -2073,7 +2159,7 @@ async fn join_project(
}
/// Leave someone elses shared project.
async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
async fn leave_project(request: proto::LeaveProject, session: MessageContext) -> Result<()> {
let sender_id = session.connection_id;
let project_id = ProjectId::from_proto(request.project_id);
let db = session.db().await;
@@ -2096,7 +2182,7 @@ async fn leave_project(request: proto::LeaveProject, session: Session) -> Result
async fn update_project(
request: proto::UpdateProject,
response: Response<proto::UpdateProject>,
session: Session,
session: MessageContext,
) -> Result<()> {
let project_id = ProjectId::from_proto(request.project_id);
let (room, guest_connection_ids) = &*session
@@ -2125,7 +2211,7 @@ async fn update_project(
async fn update_worktree(
request: proto::UpdateWorktree,
response: Response<proto::UpdateWorktree>,
session: Session,
session: MessageContext,
) -> Result<()> {
let guest_connection_ids = session
.db()
@@ -2149,7 +2235,7 @@ async fn update_worktree(
async fn update_repository(
request: proto::UpdateRepository,
response: Response<proto::UpdateRepository>,
session: Session,
session: MessageContext,
) -> Result<()> {
let guest_connection_ids = session
.db()
@@ -2173,7 +2259,7 @@ async fn update_repository(
async fn remove_repository(
request: proto::RemoveRepository,
response: Response<proto::RemoveRepository>,
session: Session,
session: MessageContext,
) -> Result<()> {
let guest_connection_ids = session
.db()
@@ -2197,7 +2283,7 @@ async fn remove_repository(
/// Updates other participants with changes to the diagnostics
async fn update_diagnostic_summary(
message: proto::UpdateDiagnosticSummary,
session: Session,
session: MessageContext,
) -> Result<()> {
let guest_connection_ids = session
.db()
@@ -2221,7 +2307,7 @@ async fn update_diagnostic_summary(
/// Updates other participants with changes to the worktree settings
async fn update_worktree_settings(
message: proto::UpdateWorktreeSettings,
session: Session,
session: MessageContext,
) -> Result<()> {
let guest_connection_ids = session
.db()
@@ -2245,7 +2331,7 @@ async fn update_worktree_settings(
/// Notify other participants that a language server has started.
async fn start_language_server(
request: proto::StartLanguageServer,
session: Session,
session: MessageContext,
) -> Result<()> {
let guest_connection_ids = session
.db()
@@ -2268,7 +2354,7 @@ async fn start_language_server(
/// Notify other participants that a language server has changed.
async fn update_language_server(
request: proto::UpdateLanguageServer,
session: Session,
session: MessageContext,
) -> Result<()> {
let project_id = ProjectId::from_proto(request.project_id);
let db = session.db().await;
@@ -2301,7 +2387,7 @@ async fn update_language_server(
async fn forward_read_only_project_request<T>(
request: T,
response: Response<T>,
session: Session,
session: MessageContext,
) -> Result<()>
where
T: EntityMessage + RequestMessage,
@@ -2312,10 +2398,7 @@ where
.await
.host_for_read_only_project_request(project_id, session.connection_id)
.await?;
let payload = session
.peer
.forward_request(session.connection_id, host_connection_id, request)
.await?;
let payload = session.forward_request(host_connection_id, request).await?;
response.send(payload)?;
Ok(())
}
@@ -2325,7 +2408,7 @@ where
async fn forward_mutating_project_request<T>(
request: T,
response: Response<T>,
session: Session,
session: MessageContext,
) -> Result<()>
where
T: EntityMessage + RequestMessage,
@@ -2337,10 +2420,7 @@ where
.await
.host_for_mutating_project_request(project_id, session.connection_id)
.await?;
let payload = session
.peer
.forward_request(session.connection_id, host_connection_id, request)
.await?;
let payload = session.forward_request(host_connection_id, request).await?;
response.send(payload)?;
Ok(())
}
@@ -2348,7 +2428,7 @@ where
async fn multi_lsp_query(
request: MultiLspQuery,
response: Response<MultiLspQuery>,
session: Session,
session: MessageContext,
) -> Result<()> {
tracing::Span::current().record("multi_lsp_query_request", request.request_str());
tracing::info!("multi_lsp_query message received");
@@ -2358,7 +2438,7 @@ async fn multi_lsp_query(
/// Notify other participants that a new buffer has been created
async fn create_buffer_for_peer(
request: proto::CreateBufferForPeer,
session: Session,
session: MessageContext,
) -> Result<()> {
session
.db()
@@ -2380,7 +2460,7 @@ async fn create_buffer_for_peer(
async fn update_buffer(
request: proto::UpdateBuffer,
response: Response<proto::UpdateBuffer>,
session: Session,
session: MessageContext,
) -> Result<()> {
let project_id = ProjectId::from_proto(request.project_id);
let mut capability = Capability::ReadOnly;
@@ -2415,17 +2495,14 @@ async fn update_buffer(
};
if host != session.connection_id {
session
.peer
.forward_request(session.connection_id, host, request.clone())
.await?;
session.forward_request(host, request.clone()).await?;
}
response.send(proto::Ack {})?;
Ok(())
}
async fn update_context(message: proto::UpdateContext, session: Session) -> Result<()> {
async fn update_context(message: proto::UpdateContext, session: MessageContext) -> Result<()> {
let project_id = ProjectId::from_proto(message.project_id);
let operation = message.operation.as_ref().context("invalid operation")?;
@@ -2470,7 +2547,7 @@ async fn update_context(message: proto::UpdateContext, session: Session) -> Resu
/// Notify other participants that a project has been updated.
async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
request: T,
session: Session,
session: MessageContext,
) -> Result<()> {
let project_id = ProjectId::from_proto(request.remote_entity_id());
let project_connection_ids = session
@@ -2495,7 +2572,7 @@ async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProj
async fn follow(
request: proto::Follow,
response: Response<proto::Follow>,
session: Session,
session: MessageContext,
) -> Result<()> {
let room_id = RoomId::from_proto(request.room_id);
let project_id = request.project_id.map(ProjectId::from_proto);
@@ -2508,10 +2585,7 @@ async fn follow(
.check_room_participants(room_id, leader_id, session.connection_id)
.await?;
let response_payload = session
.peer
.forward_request(session.connection_id, leader_id, request)
.await?;
let response_payload = session.forward_request(leader_id, request).await?;
response.send(response_payload)?;
if let Some(project_id) = project_id {
@@ -2527,7 +2601,7 @@ async fn follow(
}
/// Stop following another user in a call.
async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
async fn unfollow(request: proto::Unfollow, session: MessageContext) -> Result<()> {
let room_id = RoomId::from_proto(request.room_id);
let project_id = request.project_id.map(ProjectId::from_proto);
let leader_id = request.leader_id.context("invalid leader id")?.into();
@@ -2556,7 +2630,7 @@ async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
}
/// Notify everyone following you of your current location.
async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
async fn update_followers(request: proto::UpdateFollowers, session: MessageContext) -> Result<()> {
let room_id = RoomId::from_proto(request.room_id);
let database = session.db.lock().await;
@@ -2591,7 +2665,7 @@ async fn update_followers(request: proto::UpdateFollowers, session: Session) ->
async fn get_users(
request: proto::GetUsers,
response: Response<proto::GetUsers>,
session: Session,
session: MessageContext,
) -> Result<()> {
let user_ids = request
.user_ids
@@ -2619,7 +2693,7 @@ async fn get_users(
async fn fuzzy_search_users(
request: proto::FuzzySearchUsers,
response: Response<proto::FuzzySearchUsers>,
session: Session,
session: MessageContext,
) -> Result<()> {
let query = request.query;
let users = match query.len() {
@@ -2651,7 +2725,7 @@ async fn fuzzy_search_users(
async fn request_contact(
request: proto::RequestContact,
response: Response<proto::RequestContact>,
session: Session,
session: MessageContext,
) -> Result<()> {
let requester_id = session.user_id();
let responder_id = UserId::from_proto(request.responder_id);
@@ -2698,7 +2772,7 @@ async fn request_contact(
async fn respond_to_contact_request(
request: proto::RespondToContactRequest,
response: Response<proto::RespondToContactRequest>,
session: Session,
session: MessageContext,
) -> Result<()> {
let responder_id = session.user_id();
let requester_id = UserId::from_proto(request.requester_id);
@@ -2756,7 +2830,7 @@ async fn respond_to_contact_request(
async fn remove_contact(
request: proto::RemoveContact,
response: Response<proto::RemoveContact>,
session: Session,
session: MessageContext,
) -> Result<()> {
let requester_id = session.user_id();
let responder_id = UserId::from_proto(request.user_id);
@@ -3015,7 +3089,10 @@ async fn update_user_plan(session: &Session) -> Result<()> {
Ok(())
}
async fn subscribe_to_channels(_: proto::SubscribeToChannels, session: Session) -> Result<()> {
async fn subscribe_to_channels(
_: proto::SubscribeToChannels,
session: MessageContext,
) -> Result<()> {
subscribe_user_to_channels(session.user_id(), &session).await?;
Ok(())
}
@@ -3041,7 +3118,7 @@ async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Resul
async fn create_channel(
request: proto::CreateChannel,
response: Response<proto::CreateChannel>,
session: Session,
session: MessageContext,
) -> Result<()> {
let db = session.db().await;
@@ -3096,7 +3173,7 @@ async fn create_channel(
async fn delete_channel(
request: proto::DeleteChannel,
response: Response<proto::DeleteChannel>,
session: Session,
session: MessageContext,
) -> Result<()> {
let db = session.db().await;
@@ -3124,7 +3201,7 @@ async fn delete_channel(
async fn invite_channel_member(
request: proto::InviteChannelMember,
response: Response<proto::InviteChannelMember>,
session: Session,
session: MessageContext,
) -> Result<()> {
let db = session.db().await;
let channel_id = ChannelId::from_proto(request.channel_id);
@@ -3161,7 +3238,7 @@ async fn invite_channel_member(
async fn remove_channel_member(
request: proto::RemoveChannelMember,
response: Response<proto::RemoveChannelMember>,
session: Session,
session: MessageContext,
) -> Result<()> {
let db = session.db().await;
let channel_id = ChannelId::from_proto(request.channel_id);
@@ -3205,7 +3282,7 @@ async fn remove_channel_member(
async fn set_channel_visibility(
request: proto::SetChannelVisibility,
response: Response<proto::SetChannelVisibility>,
session: Session,
session: MessageContext,
) -> Result<()> {
let db = session.db().await;
let channel_id = ChannelId::from_proto(request.channel_id);
@@ -3250,7 +3327,7 @@ async fn set_channel_visibility(
async fn set_channel_member_role(
request: proto::SetChannelMemberRole,
response: Response<proto::SetChannelMemberRole>,
session: Session,
session: MessageContext,
) -> Result<()> {
let db = session.db().await;
let channel_id = ChannelId::from_proto(request.channel_id);
@@ -3298,7 +3375,7 @@ async fn set_channel_member_role(
async fn rename_channel(
request: proto::RenameChannel,
response: Response<proto::RenameChannel>,
session: Session,
session: MessageContext,
) -> Result<()> {
let db = session.db().await;
let channel_id = ChannelId::from_proto(request.channel_id);
@@ -3330,7 +3407,7 @@ async fn rename_channel(
async fn move_channel(
request: proto::MoveChannel,
response: Response<proto::MoveChannel>,
session: Session,
session: MessageContext,
) -> Result<()> {
let channel_id = ChannelId::from_proto(request.channel_id);
let to = ChannelId::from_proto(request.to);
@@ -3372,7 +3449,7 @@ async fn move_channel(
async fn reorder_channel(
request: proto::ReorderChannel,
response: Response<proto::ReorderChannel>,
session: Session,
session: MessageContext,
) -> Result<()> {
let channel_id = ChannelId::from_proto(request.channel_id);
let direction = request.direction();
@@ -3418,7 +3495,7 @@ async fn reorder_channel(
async fn get_channel_members(
request: proto::GetChannelMembers,
response: Response<proto::GetChannelMembers>,
session: Session,
session: MessageContext,
) -> Result<()> {
let db = session.db().await;
let channel_id = ChannelId::from_proto(request.channel_id);
@@ -3438,7 +3515,7 @@ async fn get_channel_members(
async fn respond_to_channel_invite(
request: proto::RespondToChannelInvite,
response: Response<proto::RespondToChannelInvite>,
session: Session,
session: MessageContext,
) -> Result<()> {
let db = session.db().await;
let channel_id = ChannelId::from_proto(request.channel_id);
@@ -3479,7 +3556,7 @@ async fn respond_to_channel_invite(
async fn join_channel(
request: proto::JoinChannel,
response: Response<proto::JoinChannel>,
session: Session,
session: MessageContext,
) -> Result<()> {
let channel_id = ChannelId::from_proto(request.channel_id);
join_channel_internal(channel_id, Box::new(response), session).await
@@ -3502,7 +3579,7 @@ impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
async fn join_channel_internal(
channel_id: ChannelId,
response: Box<impl JoinChannelInternalResponse>,
session: Session,
session: MessageContext,
) -> Result<()> {
let joined_room = {
let mut db = session.db().await;
@@ -3597,7 +3674,7 @@ async fn join_channel_internal(
async fn join_channel_buffer(
request: proto::JoinChannelBuffer,
response: Response<proto::JoinChannelBuffer>,
session: Session,
session: MessageContext,
) -> Result<()> {
let db = session.db().await;
let channel_id = ChannelId::from_proto(request.channel_id);
@@ -3628,7 +3705,7 @@ async fn join_channel_buffer(
/// Edit the channel notes
async fn update_channel_buffer(
request: proto::UpdateChannelBuffer,
session: Session,
session: MessageContext,
) -> Result<()> {
let db = session.db().await;
let channel_id = ChannelId::from_proto(request.channel_id);
@@ -3680,7 +3757,7 @@ async fn update_channel_buffer(
async fn rejoin_channel_buffers(
request: proto::RejoinChannelBuffers,
response: Response<proto::RejoinChannelBuffers>,
session: Session,
session: MessageContext,
) -> Result<()> {
let db = session.db().await;
let buffers = db
@@ -3715,7 +3792,7 @@ async fn rejoin_channel_buffers(
async fn leave_channel_buffer(
request: proto::LeaveChannelBuffer,
response: Response<proto::LeaveChannelBuffer>,
session: Session,
session: MessageContext,
) -> Result<()> {
let db = session.db().await;
let channel_id = ChannelId::from_proto(request.channel_id);
@@ -3777,7 +3854,7 @@ fn send_notifications(
async fn send_channel_message(
request: proto::SendChannelMessage,
response: Response<proto::SendChannelMessage>,
session: Session,
session: MessageContext,
) -> Result<()> {
// Validate the message body.
let body = request.body.trim().to_string();
@@ -3870,7 +3947,7 @@ async fn send_channel_message(
async fn remove_channel_message(
request: proto::RemoveChannelMessage,
response: Response<proto::RemoveChannelMessage>,
session: Session,
session: MessageContext,
) -> Result<()> {
let channel_id = ChannelId::from_proto(request.channel_id);
let message_id = MessageId::from_proto(request.message_id);
@@ -3905,7 +3982,7 @@ async fn remove_channel_message(
async fn update_channel_message(
request: proto::UpdateChannelMessage,
response: Response<proto::UpdateChannelMessage>,
session: Session,
session: MessageContext,
) -> Result<()> {
let channel_id = ChannelId::from_proto(request.channel_id);
let message_id = MessageId::from_proto(request.message_id);
@@ -3989,7 +4066,7 @@ async fn update_channel_message(
/// Mark a channel message as read
async fn acknowledge_channel_message(
request: proto::AckChannelMessage,
session: Session,
session: MessageContext,
) -> Result<()> {
let channel_id = ChannelId::from_proto(request.channel_id);
let message_id = MessageId::from_proto(request.message_id);
@@ -4009,7 +4086,7 @@ async fn acknowledge_channel_message(
/// Mark a buffer version as synced
async fn acknowledge_buffer_version(
request: proto::AckBufferOperation,
session: Session,
session: MessageContext,
) -> Result<()> {
let buffer_id = BufferId::from_proto(request.buffer_id);
session
@@ -4029,7 +4106,7 @@ async fn acknowledge_buffer_version(
async fn get_supermaven_api_key(
_request: proto::GetSupermavenApiKey,
response: Response<proto::GetSupermavenApiKey>,
session: Session,
session: MessageContext,
) -> Result<()> {
let user_id: String = session.user_id().to_string();
if !session.is_staff() {
@@ -4058,7 +4135,7 @@ async fn get_supermaven_api_key(
async fn join_channel_chat(
request: proto::JoinChannelChat,
response: Response<proto::JoinChannelChat>,
session: Session,
session: MessageContext,
) -> Result<()> {
let channel_id = ChannelId::from_proto(request.channel_id);
@@ -4076,7 +4153,10 @@ async fn join_channel_chat(
}
/// Stop receiving chat updates for a channel
async fn leave_channel_chat(request: proto::LeaveChannelChat, session: Session) -> Result<()> {
async fn leave_channel_chat(
request: proto::LeaveChannelChat,
session: MessageContext,
) -> Result<()> {
let channel_id = ChannelId::from_proto(request.channel_id);
session
.db()
@@ -4090,7 +4170,7 @@ async fn leave_channel_chat(request: proto::LeaveChannelChat, session: Session)
async fn get_channel_messages(
request: proto::GetChannelMessages,
response: Response<proto::GetChannelMessages>,
session: Session,
session: MessageContext,
) -> Result<()> {
let channel_id = ChannelId::from_proto(request.channel_id);
let messages = session
@@ -4114,7 +4194,7 @@ async fn get_channel_messages(
async fn get_channel_messages_by_id(
request: proto::GetChannelMessagesById,
response: Response<proto::GetChannelMessagesById>,
session: Session,
session: MessageContext,
) -> Result<()> {
let message_ids = request
.message_ids
@@ -4137,7 +4217,7 @@ async fn get_channel_messages_by_id(
async fn get_notifications(
request: proto::GetNotifications,
response: Response<proto::GetNotifications>,
session: Session,
session: MessageContext,
) -> Result<()> {
let notifications = session
.db()
@@ -4159,7 +4239,7 @@ async fn get_notifications(
async fn mark_notification_as_read(
request: proto::MarkNotificationRead,
response: Response<proto::MarkNotificationRead>,
session: Session,
session: MessageContext,
) -> Result<()> {
let database = &session.db().await;
let notifications = database
@@ -4181,7 +4261,7 @@ async fn mark_notification_as_read(
async fn get_private_user_info(
_request: proto::GetPrivateUserInfo,
response: Response<proto::GetPrivateUserInfo>,
session: Session,
session: MessageContext,
) -> Result<()> {
let db = session.db().await;
@@ -4205,7 +4285,7 @@ async fn get_private_user_info(
async fn accept_terms_of_service(
_request: proto::AcceptTermsOfService,
response: Response<proto::AcceptTermsOfService>,
session: Session,
session: MessageContext,
) -> Result<()> {
let db = session.db().await;
@@ -4229,7 +4309,7 @@ async fn accept_terms_of_service(
async fn get_llm_api_token(
_request: proto::GetLlmToken,
response: Response<proto::GetLlmToken>,
session: Session,
session: MessageContext,
) -> Result<()> {
let db = session.db().await;

View File

@@ -1,21 +1,15 @@
use std::sync::Arc;
use anyhow::anyhow;
use chrono::Utc;
use collections::HashMap;
use stripe::SubscriptionStatus;
use tokio::sync::RwLock;
use uuid::Uuid;
use crate::Result;
use crate::db::billing_subscription::SubscriptionKind;
use crate::stripe_client::{
RealStripeClient, StripeAutomaticTax, StripeClient, StripeCreateMeterEventParams,
StripeCreateMeterEventPayload, StripeCreateSubscriptionItems, StripeCreateSubscriptionParams,
StripeCustomerId, StripePrice, StripePriceId, StripeSubscription, StripeSubscriptionId,
StripeSubscriptionTrialSettings, StripeSubscriptionTrialSettingsEndBehavior,
StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod, UpdateSubscriptionItems,
UpdateSubscriptionParams,
RealStripeClient, StripeAutomaticTax, StripeClient, StripeCreateSubscriptionItems,
StripeCreateSubscriptionParams, StripeCustomerId, StripePrice, StripePriceId,
StripeSubscription,
};
pub struct StripeBilling {
@@ -94,30 +88,6 @@ impl StripeBilling {
.ok_or_else(|| crate::Error::Internal(anyhow!("no price found for {lookup_key:?}")))
}
pub async fn determine_subscription_kind(
&self,
subscription: &StripeSubscription,
) -> Option<SubscriptionKind> {
let zed_pro_price_id = self.zed_pro_price_id().await.ok()?;
let zed_free_price_id = self.zed_free_price_id().await.ok()?;
subscription.items.iter().find_map(|item| {
let price = item.price.as_ref()?;
if price.id == zed_pro_price_id {
Some(if subscription.status == SubscriptionStatus::Trialing {
SubscriptionKind::ZedProTrial
} else {
SubscriptionKind::ZedPro
})
} else if price.id == zed_free_price_id {
Some(SubscriptionKind::ZedFree)
} else {
None
}
})
}
/// Returns the Stripe customer associated with the provided email address, or creates a new customer, if one does
/// not already exist.
///
@@ -150,65 +120,6 @@ impl StripeBilling {
Ok(customer_id)
}
pub async fn subscribe_to_price(
&self,
subscription_id: &StripeSubscriptionId,
price: &StripePrice,
) -> Result<()> {
let subscription = self.client.get_subscription(subscription_id).await?;
if subscription_contains_price(&subscription, &price.id) {
return Ok(());
}
const BILLING_THRESHOLD_IN_CENTS: i64 = 20 * 100;
let price_per_unit = price.unit_amount.unwrap_or_default();
let _units_for_billing_threshold = BILLING_THRESHOLD_IN_CENTS / price_per_unit;
self.client
.update_subscription(
subscription_id,
UpdateSubscriptionParams {
items: Some(vec![UpdateSubscriptionItems {
price: Some(price.id.clone()),
}]),
trial_settings: Some(StripeSubscriptionTrialSettings {
end_behavior: StripeSubscriptionTrialSettingsEndBehavior {
missing_payment_method: StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::Cancel
},
}),
},
)
.await?;
Ok(())
}
pub async fn bill_model_request_usage(
&self,
customer_id: &StripeCustomerId,
event_name: &str,
requests: i32,
) -> Result<()> {
let timestamp = Utc::now().timestamp();
let idempotency_key = Uuid::new_v4();
self.client
.create_meter_event(StripeCreateMeterEventParams {
identifier: &format!("model_requests/{}", idempotency_key),
event_name,
payload: StripeCreateMeterEventPayload {
value: requests as u64,
stripe_customer_id: customer_id,
},
timestamp: Some(timestamp),
})
.await?;
Ok(())
}
pub async fn subscribe_to_zed_free(
&self,
customer_id: StripeCustomerId,
@@ -243,14 +154,3 @@ impl StripeBilling {
Ok(subscription)
}
}
fn subscription_contains_price(
subscription: &StripeSubscription,
price_id: &StripePriceId,
) -> bool {
subscription.items.iter().any(|item| {
item.price
.as_ref()
.map_or(false, |price| price.id == *price_id)
})
}

View File

@@ -3101,9 +3101,7 @@ async fn test_git_blame_is_forwarded(cx_a: &mut TestAppContext, cx_b: &mut TestA
// Turn inline-blame-off by default so no state is transferred without us explicitly doing so
let inline_blame_off_settings = Some(InlineBlameSettings {
enabled: false,
delay_ms: None,
min_column: None,
show_commit_summary: false,
..Default::default()
});
cx_a.update(|cx| {
SettingsStore::update_global(cx, |store, cx| {

View File

@@ -1,14 +1,9 @@
use std::sync::Arc;
use chrono::{Duration, Utc};
use pretty_assertions::assert_eq;
use crate::stripe_billing::StripeBilling;
use crate::stripe_client::{
FakeStripeClient, StripeCustomerId, StripeMeter, StripeMeterId, StripePrice, StripePriceId,
StripePriceRecurring, StripeSubscription, StripeSubscriptionId, StripeSubscriptionItem,
StripeSubscriptionItemId, UpdateSubscriptionItems,
};
use crate::stripe_client::{FakeStripeClient, StripePrice, StripePriceId, StripePriceRecurring};
fn make_stripe_billing() -> (StripeBilling, Arc<FakeStripeClient>) {
let stripe_client = Arc::new(FakeStripeClient::new());
@@ -21,24 +16,6 @@ fn make_stripe_billing() -> (StripeBilling, Arc<FakeStripeClient>) {
async fn test_initialize() {
let (stripe_billing, stripe_client) = make_stripe_billing();
// Add test meters
let meter1 = StripeMeter {
id: StripeMeterId("meter_1".into()),
event_name: "event_1".to_string(),
};
let meter2 = StripeMeter {
id: StripeMeterId("meter_2".into()),
event_name: "event_2".to_string(),
};
stripe_client
.meters
.lock()
.insert(meter1.id.clone(), meter1);
stripe_client
.meters
.lock()
.insert(meter2.id.clone(), meter2);
// Add test prices
let price1 = StripePrice {
id: StripePriceId("price_1".into()),
@@ -144,217 +121,3 @@ async fn test_find_or_create_customer_by_email() {
assert_eq!(customer.email.as_deref(), Some(email));
}
}
#[gpui::test]
async fn test_subscribe_to_price() {
let (stripe_billing, stripe_client) = make_stripe_billing();
let price = StripePrice {
id: StripePriceId("price_test".into()),
unit_amount: Some(2000),
lookup_key: Some("test-price".to_string()),
recurring: None,
};
stripe_client
.prices
.lock()
.insert(price.id.clone(), price.clone());
let now = Utc::now();
let subscription = StripeSubscription {
id: StripeSubscriptionId("sub_test".into()),
customer: StripeCustomerId("cus_test".into()),
status: stripe::SubscriptionStatus::Active,
current_period_start: now.timestamp(),
current_period_end: (now + Duration::days(30)).timestamp(),
items: vec![],
cancel_at: None,
cancellation_details: None,
};
stripe_client
.subscriptions
.lock()
.insert(subscription.id.clone(), subscription.clone());
stripe_billing
.subscribe_to_price(&subscription.id, &price)
.await
.unwrap();
let update_subscription_calls = stripe_client
.update_subscription_calls
.lock()
.iter()
.map(|(id, params)| (id.clone(), params.clone()))
.collect::<Vec<_>>();
assert_eq!(update_subscription_calls.len(), 1);
assert_eq!(update_subscription_calls[0].0, subscription.id);
assert_eq!(
update_subscription_calls[0].1.items,
Some(vec![UpdateSubscriptionItems {
price: Some(price.id.clone())
}])
);
// Subscribing to a price that is already on the subscription is a no-op.
{
let now = Utc::now();
let subscription = StripeSubscription {
id: StripeSubscriptionId("sub_test".into()),
customer: StripeCustomerId("cus_test".into()),
status: stripe::SubscriptionStatus::Active,
current_period_start: now.timestamp(),
current_period_end: (now + Duration::days(30)).timestamp(),
items: vec![StripeSubscriptionItem {
id: StripeSubscriptionItemId("si_test".into()),
price: Some(price.clone()),
}],
cancel_at: None,
cancellation_details: None,
};
stripe_client
.subscriptions
.lock()
.insert(subscription.id.clone(), subscription.clone());
stripe_billing
.subscribe_to_price(&subscription.id, &price)
.await
.unwrap();
assert_eq!(stripe_client.update_subscription_calls.lock().len(), 1);
}
}
#[gpui::test]
async fn test_subscribe_to_zed_free() {
let (stripe_billing, stripe_client) = make_stripe_billing();
let zed_pro_price = StripePrice {
id: StripePriceId("price_1".into()),
unit_amount: Some(0),
lookup_key: Some("zed-pro".to_string()),
recurring: None,
};
stripe_client
.prices
.lock()
.insert(zed_pro_price.id.clone(), zed_pro_price.clone());
let zed_free_price = StripePrice {
id: StripePriceId("price_2".into()),
unit_amount: Some(0),
lookup_key: Some("zed-free".to_string()),
recurring: None,
};
stripe_client
.prices
.lock()
.insert(zed_free_price.id.clone(), zed_free_price.clone());
stripe_billing.initialize().await.unwrap();
// Customer is subscribed to Zed Free when not already subscribed to a plan.
{
let customer_id = StripeCustomerId("cus_no_plan".into());
let subscription = stripe_billing
.subscribe_to_zed_free(customer_id)
.await
.unwrap();
assert_eq!(subscription.items[0].price.as_ref(), Some(&zed_free_price));
}
// Customer is not subscribed to Zed Free when they already have an active subscription.
{
let customer_id = StripeCustomerId("cus_active_subscription".into());
let now = Utc::now();
let existing_subscription = StripeSubscription {
id: StripeSubscriptionId("sub_existing_active".into()),
customer: customer_id.clone(),
status: stripe::SubscriptionStatus::Active,
current_period_start: now.timestamp(),
current_period_end: (now + Duration::days(30)).timestamp(),
items: vec![StripeSubscriptionItem {
id: StripeSubscriptionItemId("si_test".into()),
price: Some(zed_pro_price.clone()),
}],
cancel_at: None,
cancellation_details: None,
};
stripe_client.subscriptions.lock().insert(
existing_subscription.id.clone(),
existing_subscription.clone(),
);
let subscription = stripe_billing
.subscribe_to_zed_free(customer_id)
.await
.unwrap();
assert_eq!(subscription, existing_subscription);
}
// Customer is not subscribed to Zed Free when they already have a trial subscription.
{
let customer_id = StripeCustomerId("cus_trial_subscription".into());
let now = Utc::now();
let existing_subscription = StripeSubscription {
id: StripeSubscriptionId("sub_existing_trial".into()),
customer: customer_id.clone(),
status: stripe::SubscriptionStatus::Trialing,
current_period_start: now.timestamp(),
current_period_end: (now + Duration::days(14)).timestamp(),
items: vec![StripeSubscriptionItem {
id: StripeSubscriptionItemId("si_test".into()),
price: Some(zed_pro_price.clone()),
}],
cancel_at: None,
cancellation_details: None,
};
stripe_client.subscriptions.lock().insert(
existing_subscription.id.clone(),
existing_subscription.clone(),
);
let subscription = stripe_billing
.subscribe_to_zed_free(customer_id)
.await
.unwrap();
assert_eq!(subscription, existing_subscription);
}
}
#[gpui::test]
async fn test_bill_model_request_usage() {
let (stripe_billing, stripe_client) = make_stripe_billing();
let customer_id = StripeCustomerId("cus_test".into());
stripe_billing
.bill_model_request_usage(&customer_id, "some_model/requests", 73)
.await
.unwrap();
let create_meter_event_calls = stripe_client
.create_meter_event_calls
.lock()
.iter()
.cloned()
.collect::<Vec<_>>();
assert_eq!(create_meter_event_calls.len(), 1);
assert!(
create_meter_event_calls[0]
.identifier
.starts_with("model_requests/")
);
assert_eq!(create_meter_event_calls[0].stripe_customer_id, customer_id);
assert_eq!(
create_meter_event_calls[0].event_name.as_ref(),
"some_model/requests"
);
assert_eq!(create_meter_event_calls[0].value, 73);
}

View File

@@ -297,6 +297,7 @@ impl TestServer {
client_name,
Principal::User(user),
ZedVersion(SemanticVersion::new(1, 0, 0)),
Some("test".to_string()),
None,
None,
None,

View File

@@ -136,7 +136,10 @@ impl Focusable for CommandPalette {
impl Render for CommandPalette {
fn render(&mut self, _window: &mut Window, _cx: &mut Context<Self>) -> impl IntoElement {
v_flex().w(rems(34.)).child(self.picker.clone())
v_flex()
.key_context("CommandPalette")
.w(rems(34.))
.child(self.picker.clone())
}
}

View File

@@ -21,7 +21,7 @@ use language::{
point_from_lsp, point_to_lsp,
};
use lsp::{LanguageServer, LanguageServerBinary, LanguageServerId, LanguageServerName};
use node_runtime::{NodeRuntime, VersionCheck};
use node_runtime::NodeRuntime;
use parking_lot::Mutex;
use project::DisableAiSettings;
use request::StatusNotification;
@@ -1169,8 +1169,9 @@ async fn get_copilot_lsp(fs: Arc<dyn Fs>, node_runtime: NodeRuntime) -> anyhow::
const SERVER_PATH: &str =
"node_modules/@github/copilot-language-server/dist/language-server.js";
// pinning it: https://github.com/zed-industries/zed/issues/36093
const PINNED_VERSION: &str = "1.354";
let latest_version = node_runtime
.npm_package_latest_version(PACKAGE_NAME)
.await?;
let server_path = paths::copilot_dir().join(SERVER_PATH);
fs.create_dir(paths::copilot_dir()).await?;
@@ -1180,13 +1181,12 @@ async fn get_copilot_lsp(fs: Arc<dyn Fs>, node_runtime: NodeRuntime) -> anyhow::
PACKAGE_NAME,
&server_path,
paths::copilot_dir(),
&PINNED_VERSION,
VersionCheck::VersionMismatch,
&latest_version,
)
.await;
if should_install {
node_runtime
.npm_install_packages(paths::copilot_dir(), &[(PACKAGE_NAME, &PINNED_VERSION)])
.npm_install_packages(paths::copilot_dir(), &[(PACKAGE_NAME, &latest_version)])
.await?;
}

View File

@@ -58,11 +58,19 @@ impl EditPredictionProvider for CopilotCompletionProvider {
}
fn show_completions_in_menu() -> bool {
true
}
fn show_tab_accept_marker() -> bool {
true
}
fn supports_jump_to_edit() -> bool {
false
}
fn is_refreshing(&self) -> bool {
self.pending_refresh.is_some()
self.pending_refresh.is_some() && self.completions.is_empty()
}
fn is_enabled(
@@ -343,8 +351,8 @@ mod tests {
executor.advance_clock(COPILOT_DEBOUNCE_TIMEOUT);
cx.update_editor(|editor, window, cx| {
assert!(editor.context_menu_visible());
assert!(!editor.has_active_edit_prediction());
// Since we have both, the copilot suggestion is not shown inline
assert!(editor.has_active_edit_prediction());
// Since we have both, the copilot suggestion is existing but does not show up as ghost text
assert_eq!(editor.text(cx), "one.\ntwo\nthree\n");
assert_eq!(editor.display_text(cx), "one.\ntwo\nthree\n");
@@ -934,8 +942,9 @@ mod tests {
executor.advance_clock(COPILOT_DEBOUNCE_TIMEOUT);
cx.update_editor(|editor, _, cx| {
assert!(editor.context_menu_visible());
assert!(!editor.has_active_edit_prediction(),);
assert!(editor.has_active_edit_prediction());
assert_eq!(editor.text(cx), "one\ntwo.\nthree\n");
assert_eq!(editor.display_text(cx), "one\ntwo.\nthree\n");
});
}
@@ -1077,8 +1086,6 @@ mod tests {
vec![complete_from_marker.clone(), replace_range_marker.clone()],
);
let complete_from_position =
cx.to_lsp(marked_ranges.remove(&complete_from_marker).unwrap()[0].start);
let replace_range =
cx.to_lsp_range(marked_ranges.remove(&replace_range_marker).unwrap()[0].clone());
@@ -1087,10 +1094,6 @@ mod tests {
let completions = completions.clone();
async move {
assert_eq!(params.text_document_position.text_document.uri, url.clone());
assert_eq!(
params.text_document_position.position,
complete_from_position
);
Ok(Some(lsp::CompletionResponse::Array(
completions
.iter()

View File

@@ -74,6 +74,12 @@ impl Borrow<str> for DebugAdapterName {
}
}
impl Borrow<SharedString> for DebugAdapterName {
fn borrow(&self) -> &SharedString {
&self.0
}
}
impl std::fmt::Display for DebugAdapterName {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
std::fmt::Display::fmt(&self.0, f)

View File

@@ -87,7 +87,7 @@ impl DapRegistry {
self.0.read().adapters.get(name).cloned()
}
pub fn enumerate_adapters(&self) -> Vec<DebugAdapterName> {
pub fn enumerate_adapters<B: FromIterator<DebugAdapterName>>(&self) -> B {
self.0.read().adapters.keys().cloned().collect()
}
}

View File

@@ -152,9 +152,6 @@ impl PythonDebugAdapter {
maybe!(async move {
let response = latest_release.filter(|response| response.status().is_success())?;
let download_dir = debug_adapters_dir().join(Self::ADAPTER_NAME);
std::fs::create_dir_all(&download_dir).ok()?;
let mut output = String::new();
response
.into_body()

View File

@@ -300,7 +300,7 @@ impl DebugPanel {
});
session.update(cx, |session, _| match &mut session.mode {
SessionState::Building(state_task) => {
SessionState::Booting(state_task) => {
*state_task = Some(boot_task);
}
SessionState::Running(_) => {

View File

@@ -299,59 +299,76 @@ pub fn init(cx: &mut App) {
else {
return;
};
let session = active_session
.read(cx)
.running_state
.read(cx)
.session()
.read(cx);
if session.is_terminated() {
return;
}
let editor = cx.entity().downgrade();
window.on_action(TypeId::of::<editor::actions::RunToCursor>(), {
let editor = editor.clone();
let active_session = active_session.clone();
move |_, phase, _, cx| {
if phase != DispatchPhase::Bubble {
return;
}
maybe!({
let (buffer, position, _) = editor
.update(cx, |editor, cx| {
let cursor_point: language::Point =
editor.selections.newest(cx).head();
editor
.buffer()
.read(cx)
.point_to_buffer_point(cursor_point, cx)
})
.ok()??;
window.on_action_when(
session.any_stopped_thread(),
TypeId::of::<editor::actions::RunToCursor>(),
{
let editor = editor.clone();
let active_session = active_session.clone();
move |_, phase, _, cx| {
if phase != DispatchPhase::Bubble {
return;
}
maybe!({
let (buffer, position, _) = editor
.update(cx, |editor, cx| {
let cursor_point: language::Point =
editor.selections.newest(cx).head();
let path =
editor
.buffer()
.read(cx)
.point_to_buffer_point(cursor_point, cx)
})
.ok()??;
let path =
debugger::breakpoint_store::BreakpointStore::abs_path_from_buffer(
&buffer, cx,
)?;
let source_breakpoint = SourceBreakpoint {
row: position.row,
path,
message: None,
condition: None,
hit_condition: None,
state: debugger::breakpoint_store::BreakpointState::Enabled,
};
let source_breakpoint = SourceBreakpoint {
row: position.row,
path,
message: None,
condition: None,
hit_condition: None,
state: debugger::breakpoint_store::BreakpointState::Enabled,
};
active_session.update(cx, |session, cx| {
session.running_state().update(cx, |state, cx| {
if let Some(thread_id) = state.selected_thread_id() {
state.session().update(cx, |session, cx| {
session.run_to_position(
source_breakpoint,
thread_id,
cx,
);
})
}
active_session.update(cx, |session, cx| {
session.running_state().update(cx, |state, cx| {
if let Some(thread_id) = state.selected_thread_id() {
state.session().update(cx, |session, cx| {
session.run_to_position(
source_breakpoint,
thread_id,
cx,
);
})
}
});
});
});
Some(())
});
}
});
Some(())
});
}
},
);
window.on_action(
TypeId::of::<editor::actions::EvaluateSelectedText>(),

View File

@@ -1,5 +1,5 @@
use anyhow::{Context as _, bail};
use collections::{FxHashMap, HashMap};
use collections::{FxHashMap, HashMap, HashSet};
use language::LanguageRegistry;
use std::{
borrow::Cow,
@@ -450,7 +450,7 @@ impl NewProcessModal {
.and_then(|buffer| buffer.read(cx).language())
.cloned();
let mut available_adapters = workspace
let mut available_adapters: Vec<_> = workspace
.update(cx, |_, cx| DapRegistry::global(cx).enumerate_adapters())
.unwrap_or_default();
if let Some(language) = active_buffer_language {
@@ -1054,6 +1054,9 @@ impl DebugDelegate {
})
})
});
let valid_adapters: HashSet<_> = cx.global::<DapRegistry>().enumerate_adapters();
cx.spawn(async move |this, cx| {
let (recent, scenarios) = if let Some(task) = task {
task.await
@@ -1094,6 +1097,7 @@ impl DebugDelegate {
} => !(hide_vscode && dir.ends_with(".vscode")),
_ => true,
})
.filter(|(_, scenario)| valid_adapters.contains(&scenario.adapter))
.map(|(kind, scenario)| {
let (language, scenario) =
Self::get_scenario_kind(&languages, &dap_registry, scenario);

View File

@@ -1651,7 +1651,7 @@ impl RunningState {
let is_building = self.session.update(cx, |session, cx| {
session.shutdown(cx).detach();
matches!(session.mode, session::SessionState::Building(_))
matches!(session.mode, session::SessionState::Booting(_))
});
if is_building {

View File

@@ -29,7 +29,6 @@ use ui::{
Scrollbar, ScrollbarState, SharedString, StatefulInteractiveElement, Styled, Toggleable,
Tooltip, Window, div, h_flex, px, v_flex,
};
use util::ResultExt;
use workspace::Workspace;
use zed_actions::{ToggleEnableBreakpoint, UnsetBreakpoint};
@@ -56,8 +55,6 @@ pub(crate) struct BreakpointList {
scrollbar_state: ScrollbarState,
breakpoints: Vec<BreakpointEntry>,
session: Option<Entity<Session>>,
hide_scrollbar_task: Option<Task<()>>,
show_scrollbar: bool,
focus_handle: FocusHandle,
scroll_handle: UniformListScrollHandle,
selected_ix: Option<usize>,
@@ -103,8 +100,6 @@ impl BreakpointList {
worktree_store,
scrollbar_state,
breakpoints: Default::default(),
hide_scrollbar_task: None,
show_scrollbar: false,
workspace,
session,
focus_handle,
@@ -565,21 +560,6 @@ impl BreakpointList {
Ok(())
}
fn hide_scrollbar(&mut self, window: &mut Window, cx: &mut Context<Self>) {
const SCROLLBAR_SHOW_INTERVAL: Duration = Duration::from_secs(1);
self.hide_scrollbar_task = Some(cx.spawn_in(window, async move |panel, cx| {
cx.background_executor()
.timer(SCROLLBAR_SHOW_INTERVAL)
.await;
panel
.update(cx, |panel, cx| {
panel.show_scrollbar = false;
cx.notify();
})
.log_err();
}))
}
fn render_list(&mut self, cx: &mut Context<Self>) -> impl IntoElement {
let selected_ix = self.selected_ix;
let focus_handle = self.focus_handle.clone();
@@ -614,43 +594,39 @@ impl BreakpointList {
.flex_grow()
}
fn render_vertical_scrollbar(&self, cx: &mut Context<Self>) -> Option<Stateful<Div>> {
if !(self.show_scrollbar || self.scrollbar_state.is_dragging()) {
return None;
}
Some(
div()
.occlude()
.id("breakpoint-list-vertical-scrollbar")
.on_mouse_move(cx.listener(|_, _, _, cx| {
cx.notify();
cx.stop_propagation()
}))
.on_hover(|_, _, cx| {
fn render_vertical_scrollbar(&self, cx: &mut Context<Self>) -> Stateful<Div> {
div()
.occlude()
.id("breakpoint-list-vertical-scrollbar")
.on_mouse_move(cx.listener(|_, _, _, cx| {
cx.notify();
cx.stop_propagation()
}))
.on_hover(|_, _, cx| {
cx.stop_propagation();
})
.on_any_mouse_down(|_, _, cx| {
cx.stop_propagation();
})
.on_mouse_up(
MouseButton::Left,
cx.listener(|_, _, _, cx| {
cx.stop_propagation();
})
.on_any_mouse_down(|_, _, cx| {
cx.stop_propagation();
})
.on_mouse_up(
MouseButton::Left,
cx.listener(|_, _, _, cx| {
cx.stop_propagation();
}),
)
.on_scroll_wheel(cx.listener(|_, _, _, cx| {
cx.notify();
}))
.h_full()
.absolute()
.right_1()
.top_1()
.bottom_0()
.w(px(12.))
.cursor_default()
.children(Scrollbar::vertical(self.scrollbar_state.clone())),
)
}),
)
.on_scroll_wheel(cx.listener(|_, _, _, cx| {
cx.notify();
}))
.h_full()
.absolute()
.right_1()
.top_1()
.bottom_0()
.w(px(12.))
.cursor_default()
.children(Scrollbar::vertical(self.scrollbar_state.clone()).map(|s| s.auto_hide(cx)))
}
pub(crate) fn render_control_strip(&self) -> AnyElement {
let selection_kind = self.selection_kind();
let focus_handle = self.focus_handle.clone();
@@ -819,15 +795,6 @@ impl Render for BreakpointList {
.id("breakpoint-list")
.key_context("BreakpointList")
.track_focus(&self.focus_handle)
.on_hover(cx.listener(|this, hovered, window, cx| {
if *hovered {
this.show_scrollbar = true;
this.hide_scrollbar_task.take();
cx.notify();
} else if !this.focus_handle.contains_focused(window, cx) {
this.hide_scrollbar(window, cx);
}
}))
.on_action(cx.listener(Self::select_next))
.on_action(cx.listener(Self::select_previous))
.on_action(cx.listener(Self::select_first))
@@ -844,7 +811,7 @@ impl Render for BreakpointList {
v_flex()
.size_full()
.child(self.render_list(cx))
.children(self.render_vertical_scrollbar(cx)),
.child(self.render_vertical_scrollbar(cx)),
)
.when_some(self.strip_mode, |this, _| {
this.child(Divider::horizontal()).child(

View File

@@ -23,7 +23,6 @@ use ui::{
ParentElement, Pixels, PopoverMenuHandle, Render, Scrollbar, ScrollbarState, SharedString,
StatefulInteractiveElement, Styled, TextSize, Tooltip, Window, div, h_flex, px, v_flex,
};
use util::ResultExt;
use workspace::Workspace;
use crate::{ToggleDataBreakpoint, session::running::stack_frame_list::StackFrameList};
@@ -34,9 +33,7 @@ pub(crate) struct MemoryView {
workspace: WeakEntity<Workspace>,
scroll_handle: UniformListScrollHandle,
scroll_state: ScrollbarState,
show_scrollbar: bool,
stack_frame_list: WeakEntity<StackFrameList>,
hide_scrollbar_task: Option<Task<()>>,
focus_handle: FocusHandle,
view_state: ViewState,
query_editor: Entity<Editor>,
@@ -150,8 +147,6 @@ impl MemoryView {
scroll_state,
scroll_handle,
stack_frame_list,
show_scrollbar: false,
hide_scrollbar_task: None,
focus_handle: cx.focus_handle(),
view_state,
query_editor,
@@ -168,61 +163,42 @@ impl MemoryView {
.detach();
this
}
fn hide_scrollbar(&mut self, window: &mut Window, cx: &mut Context<Self>) {
const SCROLLBAR_SHOW_INTERVAL: Duration = Duration::from_secs(1);
self.hide_scrollbar_task = Some(cx.spawn_in(window, async move |panel, cx| {
cx.background_executor()
.timer(SCROLLBAR_SHOW_INTERVAL)
.await;
panel
.update(cx, |panel, cx| {
panel.show_scrollbar = false;
cx.notify();
})
.log_err();
}))
}
fn render_vertical_scrollbar(&self, cx: &mut Context<Self>) -> Option<Stateful<Div>> {
if !(self.show_scrollbar || self.scroll_state.is_dragging()) {
return None;
}
Some(
div()
.occlude()
.id("memory-view-vertical-scrollbar")
.on_drag_move(cx.listener(|this, evt, _, cx| {
let did_handle = this.handle_scroll_drag(evt);
cx.notify();
if did_handle {
cx.stop_propagation()
}
}))
.on_drag(ScrollbarDragging, |_, _, _, cx| cx.new(|_| Empty))
.on_hover(|_, _, cx| {
fn render_vertical_scrollbar(&self, cx: &mut Context<Self>) -> Stateful<Div> {
div()
.occlude()
.id("memory-view-vertical-scrollbar")
.on_drag_move(cx.listener(|this, evt, _, cx| {
let did_handle = this.handle_scroll_drag(evt);
cx.notify();
if did_handle {
cx.stop_propagation()
}
}))
.on_drag(ScrollbarDragging, |_, _, _, cx| cx.new(|_| Empty))
.on_hover(|_, _, cx| {
cx.stop_propagation();
})
.on_any_mouse_down(|_, _, cx| {
cx.stop_propagation();
})
.on_mouse_up(
MouseButton::Left,
cx.listener(|_, _, _, cx| {
cx.stop_propagation();
})
.on_any_mouse_down(|_, _, cx| {
cx.stop_propagation();
})
.on_mouse_up(
MouseButton::Left,
cx.listener(|_, _, _, cx| {
cx.stop_propagation();
}),
)
.on_scroll_wheel(cx.listener(|_, _, _, cx| {
cx.notify();
}))
.h_full()
.absolute()
.right_1()
.top_1()
.bottom_0()
.w(px(12.))
.cursor_default()
.children(Scrollbar::vertical(self.scroll_state.clone())),
)
}),
)
.on_scroll_wheel(cx.listener(|_, _, _, cx| {
cx.notify();
}))
.h_full()
.absolute()
.right_1()
.top_1()
.bottom_0()
.w(px(12.))
.cursor_default()
.children(Scrollbar::vertical(self.scroll_state.clone()).map(|s| s.auto_hide(cx)))
}
fn render_memory(&self, cx: &mut Context<Self>) -> UniformList {
@@ -920,15 +896,6 @@ impl Render for MemoryView {
.on_action(cx.listener(Self::page_up))
.size_full()
.track_focus(&self.focus_handle)
.on_hover(cx.listener(|this, hovered, window, cx| {
if *hovered {
this.show_scrollbar = true;
this.hide_scrollbar_task.take();
cx.notify();
} else if !this.focus_handle.contains_focused(window, cx) {
this.hide_scrollbar(window, cx);
}
}))
.child(
h_flex()
.w_full()
@@ -978,7 +945,7 @@ impl Render for MemoryView {
)
.with_priority(1)
}))
.children(self.render_vertical_scrollbar(cx)),
.child(self.render_vertical_scrollbar(cx)),
)
}
}

View File

@@ -298,7 +298,7 @@ async fn test_dap_adapter_config_conversion_and_validation(cx: &mut TestAppConte
let adapter_names = cx.update(|cx| {
let registry = DapRegistry::global(cx);
registry.enumerate_adapters()
registry.enumerate_adapters::<Vec<_>>()
});
let zed_config = ZedDebugConfig {

View File

@@ -177,9 +177,9 @@ impl ProjectDiagnosticsEditor {
}
project::Event::DiagnosticsUpdated {
language_server_id,
path,
paths,
} => {
this.paths_to_update.insert(path.clone());
this.paths_to_update.extend(paths.clone());
let project = project.clone();
this.diagnostic_summary_update = cx.spawn(async move |this, cx| {
cx.background_executor()
@@ -193,9 +193,9 @@ impl ProjectDiagnosticsEditor {
cx.emit(EditorEvent::TitleChanged);
if this.editor.focus_handle(cx).contains_focused(window, cx) || this.focus_handle.contains_focused(window, cx) {
log::debug!("diagnostics updated for server {language_server_id}, path {path:?}. recording change");
log::debug!("diagnostics updated for server {language_server_id}, paths {paths:?}. recording change");
} else {
log::debug!("diagnostics updated for server {language_server_id}, path {path:?}. updating excerpts");
log::debug!("diagnostics updated for server {language_server_id}, paths {paths:?}. updating excerpts");
this.update_stale_excerpts(window, cx);
}
}

View File

@@ -876,7 +876,7 @@ async fn test_random_diagnostics_with_inlays(cx: &mut TestAppContext, mut rng: S
vec![Inlay::edit_prediction(
post_inc(&mut next_inlay_id),
snapshot.buffer_snapshot.anchor_before(position),
format!("Test inlay {next_inlay_id}"),
Rope::from_iter(["Test inlay ", "next_inlay_id"]),
)],
cx,
);

View File

@@ -61,6 +61,10 @@ pub trait EditPredictionProvider: 'static + Sized {
fn show_tab_accept_marker() -> bool {
false
}
fn supports_jump_to_edit() -> bool {
true
}
fn data_collection_state(&self, _cx: &App) -> DataCollectionState {
DataCollectionState::Unsupported
}
@@ -116,6 +120,7 @@ pub trait EditPredictionProviderHandle {
) -> bool;
fn show_completions_in_menu(&self) -> bool;
fn show_tab_accept_marker(&self) -> bool;
fn supports_jump_to_edit(&self) -> bool;
fn data_collection_state(&self, cx: &App) -> DataCollectionState;
fn usage(&self, cx: &App) -> Option<EditPredictionUsage>;
fn toggle_data_collection(&self, cx: &mut App);
@@ -166,6 +171,10 @@ where
T::show_tab_accept_marker()
}
fn supports_jump_to_edit(&self) -> bool {
T::supports_jump_to_edit()
}
fn data_collection_state(&self, cx: &App) -> DataCollectionState {
self.read(cx).data_collection_state(cx)
}

View File

@@ -491,7 +491,12 @@ impl EditPredictionButton {
let subtle_mode = matches!(current_mode, EditPredictionsMode::Subtle);
let eager_mode = matches!(current_mode, EditPredictionsMode::Eager);
if matches!(provider, EditPredictionProvider::Zed) {
if matches!(
provider,
EditPredictionProvider::Zed
| EditPredictionProvider::Copilot
| EditPredictionProvider::Supermaven
) {
menu = menu
.separator()
.header("Display Modes")

View File

@@ -745,5 +745,6 @@ actions!(
UniqueLinesCaseInsensitive,
/// Removes duplicate lines (case-sensitive).
UniqueLinesCaseSensitive,
UnwrapSyntaxNode
]
);

View File

@@ -48,16 +48,16 @@ pub struct Inlay {
impl Inlay {
pub fn hint(id: usize, position: Anchor, hint: &project::InlayHint) -> Self {
let mut text = hint.text();
if hint.padding_right && !text.ends_with(' ') {
text.push(' ');
if hint.padding_right && text.chars_at(text.len().saturating_sub(1)).next() != Some(' ') {
text.push(" ");
}
if hint.padding_left && !text.starts_with(' ') {
text.insert(0, ' ');
if hint.padding_left && text.chars_at(0).next() != Some(' ') {
text.push_front(" ");
}
Self {
id: InlayId::Hint(id),
position,
text: text.into(),
text,
color: None,
}
}
@@ -737,13 +737,13 @@ impl InlayMap {
Inlay::mock_hint(
post_inc(next_inlay_id),
snapshot.buffer.anchor_at(position, bias),
text.clone(),
&text,
)
} else {
Inlay::edit_prediction(
post_inc(next_inlay_id),
snapshot.buffer.anchor_at(position, bias),
text.clone(),
&text,
)
};
let inlay_id = next_inlay.id;
@@ -1694,7 +1694,7 @@ mod tests {
(offset, inlay.clone())
})
.collect::<Vec<_>>();
let mut expected_text = Rope::from(buffer_snapshot.text());
let mut expected_text = Rope::from(&buffer_snapshot.text());
for (offset, inlay) in inlays.iter().rev() {
expected_text.replace(*offset..*offset, &inlay.text.to_string());
}

View File

@@ -228,6 +228,49 @@ async fn test_edit_prediction_invalidation_range(cx: &mut gpui::TestAppContext)
});
}
#[gpui::test]
async fn test_edit_prediction_jump_disabled_for_non_zed_providers(cx: &mut gpui::TestAppContext) {
init_test(cx, |_| {});
let mut cx = EditorTestContext::new(cx).await;
let provider = cx.new(|_| FakeNonZedEditPredictionProvider::default());
assign_editor_completion_provider_non_zed(provider.clone(), &mut cx);
// Cursor is 2+ lines above the proposed edit
cx.set_state(indoc! {"
line 0
line ˇ1
line 2
line 3
line
"});
propose_edits_non_zed(
&provider,
vec![(Point::new(4, 3)..Point::new(4, 3), " 4")],
&mut cx,
);
cx.update_editor(|editor, window, cx| editor.update_visible_edit_prediction(window, cx));
// For non-Zed providers, there should be no move completion (jump functionality disabled)
cx.editor(|editor, _, _| {
if let Some(completion_state) = &editor.active_edit_prediction {
// Should be an Edit prediction, not a Move prediction
match &completion_state.completion {
EditPrediction::Edit { .. } => {
// This is expected for non-Zed providers
}
EditPrediction::Move { .. } => {
panic!(
"Non-Zed providers should not show Move predictions (jump functionality)"
);
}
}
}
});
}
fn assert_editor_active_edit_completion(
cx: &mut EditorTestContext,
assert: impl FnOnce(MultiBufferSnapshot, &Vec<(Range<Anchor>, String)>),
@@ -301,6 +344,37 @@ fn assign_editor_completion_provider(
})
}
fn propose_edits_non_zed<T: ToOffset>(
provider: &Entity<FakeNonZedEditPredictionProvider>,
edits: Vec<(Range<T>, &str)>,
cx: &mut EditorTestContext,
) {
let snapshot = cx.buffer_snapshot();
let edits = edits.into_iter().map(|(range, text)| {
let range = snapshot.anchor_after(range.start)..snapshot.anchor_before(range.end);
(range, text.into())
});
cx.update(|_, cx| {
provider.update(cx, |provider, _| {
provider.set_edit_prediction(Some(edit_prediction::EditPrediction {
id: None,
edits: edits.collect(),
edit_preview: None,
}))
})
});
}
fn assign_editor_completion_provider_non_zed(
provider: Entity<FakeNonZedEditPredictionProvider>,
cx: &mut EditorTestContext,
) {
cx.update_editor(|editor, window, cx| {
editor.set_edit_prediction_provider(Some(provider), window, cx);
})
}
#[derive(Default, Clone)]
pub struct FakeEditPredictionProvider {
pub completion: Option<edit_prediction::EditPrediction>,
@@ -325,6 +399,84 @@ impl EditPredictionProvider for FakeEditPredictionProvider {
false
}
fn supports_jump_to_edit() -> bool {
true
}
fn is_enabled(
&self,
_buffer: &gpui::Entity<language::Buffer>,
_cursor_position: language::Anchor,
_cx: &gpui::App,
) -> bool {
true
}
fn is_refreshing(&self) -> bool {
false
}
fn refresh(
&mut self,
_project: Option<Entity<Project>>,
_buffer: gpui::Entity<language::Buffer>,
_cursor_position: language::Anchor,
_debounce: bool,
_cx: &mut gpui::Context<Self>,
) {
}
fn cycle(
&mut self,
_buffer: gpui::Entity<language::Buffer>,
_cursor_position: language::Anchor,
_direction: edit_prediction::Direction,
_cx: &mut gpui::Context<Self>,
) {
}
fn accept(&mut self, _cx: &mut gpui::Context<Self>) {}
fn discard(&mut self, _cx: &mut gpui::Context<Self>) {}
fn suggest<'a>(
&mut self,
_buffer: &gpui::Entity<language::Buffer>,
_cursor_position: language::Anchor,
_cx: &mut gpui::Context<Self>,
) -> Option<edit_prediction::EditPrediction> {
self.completion.clone()
}
}
#[derive(Default, Clone)]
pub struct FakeNonZedEditPredictionProvider {
pub completion: Option<edit_prediction::EditPrediction>,
}
impl FakeNonZedEditPredictionProvider {
pub fn set_edit_prediction(&mut self, completion: Option<edit_prediction::EditPrediction>) {
self.completion = completion;
}
}
impl EditPredictionProvider for FakeNonZedEditPredictionProvider {
fn name() -> &'static str {
"fake-non-zed-provider"
}
fn display_name() -> &'static str {
"Fake Non-Zed Provider"
}
fn show_completions_in_menu() -> bool {
false
}
fn supports_jump_to_edit() -> bool {
false
}
fn is_enabled(
&self,
_buffer: &gpui::Entity<language::Buffer>,

View File

@@ -2705,6 +2705,11 @@ impl Editor {
self.completion_provider = provider;
}
#[cfg(any(test, feature = "test-support"))]
pub fn completion_provider(&self) -> Option<Rc<dyn CompletionProvider>> {
self.completion_provider.clone()
}
pub fn semantics_provider(&self) -> Option<Rc<dyn SemanticsProvider>> {
self.semantics_provider.clone()
}
@@ -7760,8 +7765,14 @@ impl Editor {
} else {
None
};
let is_move =
move_invalidation_row_range.is_some() || self.edit_predictions_hidden_for_vim_mode;
let supports_jump = self
.edit_prediction_provider
.as_ref()
.map(|provider| provider.provider.supports_jump_to_edit())
.unwrap_or(true);
let is_move = supports_jump
&& (move_invalidation_row_range.is_some() || self.edit_predictions_hidden_for_vim_mode);
let completion = if is_move {
invalidation_row_range =
move_invalidation_row_range.unwrap_or(edit_start_row..edit_end_row);
@@ -8799,8 +8810,12 @@ impl Editor {
return None;
}
let highlighted_edits =
crate::edit_prediction_edit_text(&snapshot, edits, edit_preview.as_ref()?, false, cx);
let highlighted_edits = if let Some(edit_preview) = edit_preview.as_ref() {
crate::edit_prediction_edit_text(&snapshot, edits, edit_preview, false, cx)
} else {
// Fallback for providers without edit_preview
crate::edit_prediction_fallback_text(edits, cx)
};
let styled_text = highlighted_edits.to_styled_text(&style.text);
let line_count = highlighted_edits.text.lines().count();
@@ -9068,6 +9083,18 @@ impl Editor {
let editor_bg_color = cx.theme().colors().editor_background;
editor_bg_color.blend(accent_color.opacity(0.6))
}
fn get_prediction_provider_icon_name(
provider: &Option<RegisteredEditPredictionProvider>,
) -> IconName {
match provider {
Some(provider) => match provider.provider.name() {
"copilot" => IconName::Copilot,
"supermaven" => IconName::Supermaven,
_ => IconName::ZedPredict,
},
None => IconName::ZedPredict,
}
}
fn render_edit_prediction_cursor_popover(
&self,
@@ -9080,6 +9107,7 @@ impl Editor {
cx: &mut Context<Editor>,
) -> Option<AnyElement> {
let provider = self.edit_prediction_provider.as_ref()?;
let provider_icon = Self::get_prediction_provider_icon_name(&self.edit_prediction_provider);
if provider.provider.needs_terms_acceptance(cx) {
return Some(
@@ -9106,7 +9134,7 @@ impl Editor {
h_flex()
.flex_1()
.gap_2()
.child(Icon::new(IconName::ZedPredict))
.child(Icon::new(provider_icon))
.child(Label::new("Accept Terms of Service"))
.child(div().w_full())
.child(
@@ -9122,12 +9150,8 @@ impl Editor {
let is_refreshing = provider.provider.is_refreshing(cx);
fn pending_completion_container() -> Div {
h_flex()
.h_full()
.flex_1()
.gap_2()
.child(Icon::new(IconName::ZedPredict))
fn pending_completion_container(icon: IconName) -> Div {
h_flex().h_full().flex_1().gap_2().child(Icon::new(icon))
}
let completion = match &self.active_edit_prediction {
@@ -9157,7 +9181,7 @@ impl Editor {
Icon::new(IconName::ZedPredictUp)
}
}
EditPrediction::Edit { .. } => Icon::new(IconName::ZedPredict),
EditPrediction::Edit { .. } => Icon::new(provider_icon),
}))
.child(
h_flex()
@@ -9224,15 +9248,15 @@ impl Editor {
cx,
)?,
None => {
pending_completion_container().child(Label::new("...").size(LabelSize::Small))
}
None => pending_completion_container(provider_icon)
.child(Label::new("...").size(LabelSize::Small)),
},
None => pending_completion_container().child(Label::new("No Prediction")),
None => pending_completion_container(provider_icon)
.child(Label::new("...").size(LabelSize::Small)),
};
let completion = if is_refreshing {
let completion = if is_refreshing || self.active_edit_prediction.is_none() {
completion
.with_animation(
"loading-completion",
@@ -9332,23 +9356,35 @@ impl Editor {
.child(Icon::new(arrow).color(Color::Muted).size(IconSize::Small))
}
let supports_jump = self
.edit_prediction_provider
.as_ref()
.map(|provider| provider.provider.supports_jump_to_edit())
.unwrap_or(true);
match &completion.completion {
EditPrediction::Move {
target, snapshot, ..
} => Some(
h_flex()
.px_2()
.gap_2()
.flex_1()
.child(
if target.text_anchor.to_point(&snapshot).row > cursor_point.row {
Icon::new(IconName::ZedPredictDown)
} else {
Icon::new(IconName::ZedPredictUp)
},
)
.child(Label::new("Jump to Edit")),
),
} => {
if !supports_jump {
return None;
}
Some(
h_flex()
.px_2()
.gap_2()
.flex_1()
.child(
if target.text_anchor.to_point(&snapshot).row > cursor_point.row {
Icon::new(IconName::ZedPredictDown)
} else {
Icon::new(IconName::ZedPredictUp)
},
)
.child(Label::new("Jump to Edit")),
)
}
EditPrediction::Edit {
edits,
@@ -9358,14 +9394,13 @@ impl Editor {
} => {
let first_edit_row = edits.first()?.0.start.text_anchor.to_point(&snapshot).row;
let (highlighted_edits, has_more_lines) = crate::edit_prediction_edit_text(
&snapshot,
&edits,
edit_preview.as_ref()?,
true,
cx,
)
.first_line_preview();
let (highlighted_edits, has_more_lines) =
if let Some(edit_preview) = edit_preview.as_ref() {
crate::edit_prediction_edit_text(&snapshot, &edits, edit_preview, true, cx)
.first_line_preview()
} else {
crate::edit_prediction_fallback_text(&edits, cx).first_line_preview()
};
let styled_text = gpui::StyledText::new(highlighted_edits.text)
.with_default_highlights(&style.text, highlighted_edits.highlights);
@@ -9376,11 +9411,13 @@ impl Editor {
.child(styled_text)
.when(has_more_lines, |parent| parent.child(""));
let left = if first_edit_row != cursor_point.row {
let left = if supports_jump && first_edit_row != cursor_point.row {
render_relative_row_jump("", cursor_point.row, first_edit_row)
.into_any_element()
} else {
Icon::new(IconName::ZedPredict).into_any_element()
let icon_name =
Editor::get_prediction_provider_icon_name(&self.edit_prediction_provider);
Icon::new(icon_name).into_any_element()
};
Some(
@@ -14711,6 +14748,81 @@ impl Editor {
}
}
pub fn unwrap_syntax_node(
&mut self,
_: &UnwrapSyntaxNode,
window: &mut Window,
cx: &mut Context<Self>,
) {
self.hide_mouse_cursor(HideMouseCursorOrigin::MovementAction, cx);
let buffer = self.buffer.read(cx).snapshot(cx);
let old_selections: Box<[_]> = self.selections.all::<usize>(cx).into();
let edits = old_selections
.iter()
// only consider the first selection for now
.take(1)
.map(|selection| {
// Only requires two branches once if-let-chains stabilize (#53667)
let selection_range = if !selection.is_empty() {
selection.range()
} else if let Some((_, ancestor_range)) =
buffer.syntax_ancestor(selection.start..selection.end)
{
match ancestor_range {
MultiOrSingleBufferOffsetRange::Single(range) => range,
MultiOrSingleBufferOffsetRange::Multi(range) => range,
}
} else {
selection.range()
};
let mut new_range = selection_range.clone();
while let Some((_, ancestor_range)) = buffer.syntax_ancestor(new_range.clone()) {
new_range = match ancestor_range {
MultiOrSingleBufferOffsetRange::Single(range) => range,
MultiOrSingleBufferOffsetRange::Multi(range) => range,
};
if new_range.start < selection_range.start
|| new_range.end > selection_range.end
{
break;
}
}
(selection, selection_range, new_range)
})
.collect::<Vec<_>>();
self.transact(window, cx, |editor, window, cx| {
for (_, child, parent) in &edits {
let text = buffer.text_for_range(child.clone()).collect::<String>();
editor.replace_text_in_range(Some(parent.clone()), &text, window, cx);
}
editor.change_selections(
SelectionEffects::scroll(Autoscroll::fit()),
window,
cx,
|s| {
s.select(
edits
.iter()
.map(|(s, old, new)| Selection {
id: s.id,
start: new.start,
end: new.start + old.len(),
goal: SelectionGoal::None,
reversed: s.reversed,
})
.collect(),
);
},
);
});
}
fn refresh_runnables(&mut self, window: &mut Window, cx: &mut Context<Self>) -> Task<()> {
if !EditorSettings::get_global(cx).gutter.runnables {
self.clear_tasks();
@@ -23195,6 +23307,33 @@ fn edit_prediction_edit_text(
edit_preview.highlight_edits(current_snapshot, &edits, include_deletions, cx)
}
fn edit_prediction_fallback_text(edits: &[(Range<Anchor>, String)], cx: &App) -> HighlightedText {
// Fallback for providers that don't provide edit_preview (like Copilot/Supermaven)
// Just show the raw edit text with basic styling
let mut text = String::new();
let mut highlights = Vec::new();
let insertion_highlight_style = HighlightStyle {
color: Some(cx.theme().colors().text),
..Default::default()
};
for (_, edit_text) in edits {
let start_offset = text.len();
text.push_str(edit_text);
let end_offset = text.len();
if start_offset < end_offset {
highlights.push((start_offset..end_offset, insertion_highlight_style));
}
}
HighlightedText {
text: text.into(),
highlights,
}
}
pub fn diagnostic_style(severity: lsp::DiagnosticSeverity, colors: &StatusColors) -> Hsla {
match severity {
lsp::DiagnosticSeverity::ERROR => colors.error,

View File

@@ -7969,6 +7969,38 @@ async fn test_select_larger_smaller_syntax_node_for_string(cx: &mut TestAppConte
});
}
#[gpui::test]
async fn test_unwrap_syntax_node(cx: &mut gpui::TestAppContext) {
init_test(cx, |_| {});
let mut cx = EditorTestContext::new(cx).await;
let language = Arc::new(Language::new(
LanguageConfig::default(),
Some(tree_sitter_rust::LANGUAGE.into()),
));
cx.update_buffer(|buffer, cx| {
buffer.set_language(Some(language), cx);
});
cx.set_state(
&r#"
use mod1::mod2::{«mod3ˇ», mod4};
"#
.unindent(),
);
cx.update_editor(|editor, window, cx| {
editor.unwrap_syntax_node(&UnwrapSyntaxNode, window, cx);
});
cx.assert_editor_state(
&r#"
use mod1::mod2::«mod3ˇ»;
"#
.unindent(),
);
}
#[gpui::test]
async fn test_fold_function_bodies(cx: &mut TestAppContext) {
init_test(cx, |_| {});

View File

@@ -86,8 +86,6 @@ use util::post_inc;
use util::{RangeExt, ResultExt, debug_panic};
use workspace::{CollaboratorId, Workspace, item::Item, notifications::NotifyTaskExt};
const INLINE_BLAME_PADDING_EM_WIDTHS: f32 = 7.;
/// Determines what kinds of highlights should be applied to a lines background.
#[derive(Clone, Copy, Default)]
struct LineHighlightSpec {
@@ -357,6 +355,7 @@ impl EditorElement {
register_action(editor, window, Editor::toggle_comments);
register_action(editor, window, Editor::select_larger_syntax_node);
register_action(editor, window, Editor::select_smaller_syntax_node);
register_action(editor, window, Editor::unwrap_syntax_node);
register_action(editor, window, Editor::select_enclosing_symbol);
register_action(editor, window, Editor::move_to_enclosing_bracket);
register_action(editor, window, Editor::undo_selection);
@@ -2427,10 +2426,13 @@ impl EditorElement {
let editor = self.editor.read(cx);
let blame = editor.blame.clone()?;
let padding = {
const INLINE_BLAME_PADDING_EM_WIDTHS: f32 = 6.;
const INLINE_ACCEPT_SUGGESTION_EM_WIDTHS: f32 = 14.;
let mut padding = INLINE_BLAME_PADDING_EM_WIDTHS;
let mut padding = ProjectSettings::get_global(cx)
.git
.inline_blame
.unwrap_or_default()
.padding as f32;
if let Some(edit_prediction) = editor.active_edit_prediction.as_ref() {
match &edit_prediction.completion {
@@ -2468,7 +2470,7 @@ impl EditorElement {
let min_column_in_pixels = ProjectSettings::get_global(cx)
.git
.inline_blame
.and_then(|settings| settings.min_column)
.map(|settings| settings.min_column)
.map(|col| self.column_pixels(col as usize, window))
.unwrap_or(px(0.));
let min_start = content_origin.x - scroll_pixel_position.x + min_column_in_pixels;
@@ -3681,6 +3683,7 @@ impl EditorElement {
.id("path header block")
.size_full()
.justify_between()
.overflow_hidden()
.child(
h_flex()
.gap_2()
@@ -8363,7 +8366,13 @@ impl Element for EditorElement {
})
.flatten()?;
let mut element = render_inline_blame_entry(blame_entry, &style, cx)?;
let inline_blame_padding = INLINE_BLAME_PADDING_EM_WIDTHS * em_advance;
let inline_blame_padding = ProjectSettings::get_global(cx)
.git
.inline_blame
.unwrap_or_default()
.padding
as f32
* em_advance;
Some(
element
.layout_as_root(AvailableSpace::min_size(), window, cx)

View File

@@ -3546,7 +3546,7 @@ pub mod tests {
let excerpt_hints = excerpt_hints.read();
for id in &excerpt_hints.ordered_hints {
let hint = &excerpt_hints.hints_by_id[id];
let mut label = hint.text();
let mut label = hint.text().to_string();
if hint.padding_left {
label.insert(0, ' ');
}

View File

@@ -1,8 +1,8 @@
use crate::{
Copy, CopyAndTrim, CopyPermalinkToLine, Cut, DisplayPoint, DisplaySnapshot, Editor,
EvaluateSelectedText, FindAllReferences, GoToDeclaration, GoToDefinition, GoToImplementation,
GoToTypeDefinition, Paste, Rename, RevealInFileManager, SelectMode, SelectionEffects,
SelectionExt, ToDisplayPoint, ToggleCodeActions,
GoToTypeDefinition, Paste, Rename, RevealInFileManager, RunToCursor, SelectMode,
SelectionEffects, SelectionExt, ToDisplayPoint, ToggleCodeActions,
actions::{Format, FormatSelections},
selections_collection::SelectionsCollection,
};
@@ -200,15 +200,21 @@ pub fn deploy_context_menu(
});
let evaluate_selection = window.is_action_available(&EvaluateSelectedText, cx);
let run_to_cursor = window.is_action_available(&RunToCursor, cx);
ui::ContextMenu::build(window, cx, |menu, _window, _cx| {
let builder = menu
.on_blur_subscription(Subscription::new(|| {}))
.when(evaluate_selection && has_selections, |builder| {
builder
.action("Evaluate Selection", Box::new(EvaluateSelectedText))
.separator()
.when(run_to_cursor, |builder| {
builder.action("Run to Cursor", Box::new(RunToCursor))
})
.when(evaluate_selection && has_selections, |builder| {
builder.action("Evaluate Selection", Box::new(EvaluateSelectedText))
})
.when(
run_to_cursor || (evaluate_selection && has_selections),
|builder| builder.separator(),
)
.action("Go to Definition", Box::new(GoToDefinition))
.action("Go to Declaration", Box::new(GoToDeclaration))
.action("Go to Type Definition", Box::new(GoToTypeDefinition))

View File

@@ -191,7 +191,7 @@ impl Editor {
if let Some(language) = language {
for signature in &mut signature_help.signatures {
let text = Rope::from(signature.label.to_string());
let text = Rope::from(signature.label.as_ref());
let highlights = language
.highlight_text(&text, 0..signature.label.len())
.into_iter()

View File

@@ -846,14 +846,12 @@ impl GitRepository for RealGitRepository {
.stdin(Stdio::piped())
.stdout(Stdio::piped())
.spawn()?;
child
.stdin
.take()
.unwrap()
.write_all(content.as_bytes())
.await?;
let mut stdin = child.stdin.take().unwrap();
stdin.write_all(content.as_bytes()).await?;
stdin.flush().await?;
drop(stdin);
let output = child.output().await?.stdout;
let sha = String::from_utf8(output)?;
let sha = str::from_utf8(&output)?.trim();
log::debug!("indexing SHA: {sha}, path {path:?}");
@@ -871,6 +869,7 @@ impl GitRepository for RealGitRepository {
String::from_utf8_lossy(&output.stderr)
);
} else {
log::debug!("removing path {path:?} from the index");
let output = new_smol_command(&git_binary_path)
.current_dir(&working_directory)
.envs(env.iter())
@@ -921,6 +920,7 @@ impl GitRepository for RealGitRepository {
for rev in &revs {
write!(&mut stdin, "{rev}\n")?;
}
stdin.flush()?;
drop(stdin);
let output = process.wait_with_output()?;

View File

@@ -1,12 +1,22 @@
use std::str::FromStr;
use std::sync::LazyLock;
use regex::Regex;
use url::Url;
use git::{
BuildCommitPermalinkParams, BuildPermalinkParams, GitHostingProvider, ParsedGitRemote,
RemoteUrl,
PullRequest, RemoteUrl,
};
fn pull_request_regex() -> &'static Regex {
static PULL_REQUEST_REGEX: LazyLock<Regex> = LazyLock::new(|| {
// This matches Bitbucket PR reference pattern: (pull request #xxx)
Regex::new(r"\(pull request #(\d+)\)").unwrap()
});
&PULL_REQUEST_REGEX
}
pub struct Bitbucket {
name: String,
base_url: Url,
@@ -96,6 +106,22 @@ impl GitHostingProvider for Bitbucket {
);
permalink
}
fn extract_pull_request(&self, remote: &ParsedGitRemote, message: &str) -> Option<PullRequest> {
// Check first line of commit message for PR references
let first_line = message.lines().next()?;
// Try to match against our PR patterns
let capture = pull_request_regex().captures(first_line)?;
let number = capture.get(1)?.as_str().parse::<u32>().ok()?;
// Construct the PR URL in Bitbucket format
let mut url = self.base_url();
let path = format!("/{}/{}/pull-requests/{}", remote.owner, remote.repo, number);
url.set_path(&path);
Some(PullRequest { number, url })
}
}
#[cfg(test)]
@@ -203,4 +229,34 @@ mod tests {
"https://bitbucket.org/zed-industries/zed/src/f00b4r/main.rs#lines-24:48";
assert_eq!(permalink.to_string(), expected_url.to_string())
}
#[test]
fn test_bitbucket_pull_requests() {
use indoc::indoc;
let remote = ParsedGitRemote {
owner: "zed-industries".into(),
repo: "zed".into(),
};
let bitbucket = Bitbucket::public_instance();
// Test message without PR reference
let message = "This does not contain a pull request";
assert!(bitbucket.extract_pull_request(&remote, message).is_none());
// Pull request number at end of first line
let message = indoc! {r#"
Merged in feature-branch (pull request #123)
Some detailed description of the changes.
"#};
let pr = bitbucket.extract_pull_request(&remote, message).unwrap();
assert_eq!(pr.number, 123);
assert_eq!(
pr.url.as_str(),
"https://bitbucket.org/zed-industries/zed/pull-requests/123"
);
}
}

View File

@@ -180,6 +180,7 @@ impl Focusable for BranchList {
impl Render for BranchList {
fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
v_flex()
.key_context("GitBranchSelector")
.w(self.width)
.on_modifiers_changed(cx.listener(Self::handle_modifiers_changed))
.child(self.picker.clone())

View File

@@ -109,7 +109,10 @@ impl Focusable for RepositorySelector {
impl Render for RepositorySelector {
fn render(&mut self, _window: &mut Window, _cx: &mut Context<Self>) -> impl IntoElement {
div().w(self.width).child(self.picker.clone())
div()
.key_context("GitRepositorySelector")
.w(self.width)
.child(self.picker.clone())
}
}

View File

@@ -308,10 +308,14 @@ impl Settings for LineIndicatorFormat {
type FileContent = Option<LineIndicatorFormatContent>;
fn load(sources: SettingsSources<Self::FileContent>, _: &mut App) -> anyhow::Result<Self> {
let format = [sources.release_channel, sources.user]
.into_iter()
.find_map(|value| value.copied().flatten())
.unwrap_or(sources.default.ok_or_else(Self::missing_default)?);
let format = [
sources.release_channel,
sources.operating_system,
sources.user,
]
.into_iter()
.find_map(|value| value.copied().flatten())
.unwrap_or(sources.default.ok_or_else(Self::missing_default)?);
Ok(format.0)
}

View File

@@ -121,7 +121,7 @@ smallvec.workspace = true
smol.workspace = true
strum.workspace = true
sum_tree.workspace = true
taffy = "=0.8.3"
taffy = "=0.9.0"
thiserror.workspace = true
util.workspace = true
uuid.workspace = true

View File

@@ -79,6 +79,14 @@ impl<T> Task<T> {
Task(TaskState::Spawned(task)) => task.detach(),
}
}
/// Whether the task has run to completion.
pub fn is_finished(&self) -> bool {
match self {
Task(TaskState::Ready(_)) => true,
Task(TaskState::Spawned(task)) => task.is_finished(),
}
}
}
impl<E, T> Task<Result<T, E>>

View File

@@ -150,7 +150,7 @@ pub struct MouseClickEvent {
}
/// A click event that was generated by a keyboard button being pressed and released.
#[derive(Clone, Debug)]
#[derive(Clone, Debug, Default)]
pub struct KeyboardClickEvent {
/// The keyboard button that was pressed to trigger the click.
pub button: KeyboardButton,
@@ -168,6 +168,12 @@ pub enum ClickEvent {
Keyboard(KeyboardClickEvent),
}
impl Default for ClickEvent {
fn default() -> Self {
ClickEvent::Keyboard(KeyboardClickEvent::default())
}
}
impl ClickEvent {
/// Returns the modifiers that were held during the click event
///
@@ -256,9 +262,10 @@ impl ClickEvent {
}
/// An enum representing the keyboard button that was pressed for a click event.
#[derive(Hash, PartialEq, Eq, Copy, Clone, Debug)]
#[derive(Hash, PartialEq, Eq, Copy, Clone, Debug, Default)]
pub enum KeyboardButton {
/// Enter key was clicked
#[default]
Enter,
/// Space key was clicked
Space,

View File

@@ -4248,6 +4248,25 @@ impl Window {
.on_action(action_type, Rc::new(listener));
}
/// Register an action listener on the window for the next frame if the condition is true.
/// The type of action is determined by the first parameter of the given listener.
/// When the next frame is rendered the listener will be cleared.
///
/// This is a fairly low-level method, so prefer using action handlers on elements unless you have
/// a specific need to register a global listener.
pub fn on_action_when(
&mut self,
condition: bool,
action_type: TypeId,
listener: impl Fn(&dyn Any, DispatchPhase, &mut Window, &mut App) + 'static,
) {
if condition {
self.next_frame
.dispatch_tree
.on_action(action_type, Rc::new(listener));
}
}
/// Read information about the GPU backing this window.
/// Currently returns None on Mac and Windows.
pub fn gpu_specs(&self) -> Option<GpuSpecs> {

View File

@@ -261,7 +261,6 @@ pub enum IconName {
TodoComplete,
TodoPending,
TodoProgress,
ToolBulb,
ToolCopy,
ToolDeleteFile,
ToolDiagnostics,
@@ -273,6 +272,7 @@ pub enum IconName {
ToolRegex,
ToolSearch,
ToolTerminal,
ToolThink,
ToolWeb,
Trash,
Triangle,

View File

@@ -941,8 +941,6 @@ impl LanguageModel for CloudLanguageModel {
request,
model.id(),
model.supports_parallel_tool_calls(),
model.supports_prompt_cache_key(),
None,
None,
);
let llm_api_token = self.llm_api_token.clone();

View File

@@ -14,7 +14,7 @@ use language_model::{
RateLimiter, Role, StopReason, TokenUsage,
};
use menu;
use open_ai::{ImageUrl, Model, ReasoningEffort, ResponseStreamEvent, stream_completion};
use open_ai::{ImageUrl, Model, ResponseStreamEvent, stream_completion};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use settings::{Settings, SettingsStore};
@@ -45,7 +45,6 @@ pub struct AvailableModel {
pub max_tokens: u64,
pub max_output_tokens: Option<u64>,
pub max_completion_tokens: Option<u64>,
pub reasoning_effort: Option<ReasoningEffort>,
}
pub struct OpenAiLanguageModelProvider {
@@ -214,7 +213,6 @@ impl LanguageModelProvider for OpenAiLanguageModelProvider {
max_tokens: model.max_tokens,
max_output_tokens: model.max_output_tokens,
max_completion_tokens: model.max_completion_tokens,
reasoning_effort: model.reasoning_effort.clone(),
},
);
}
@@ -303,25 +301,7 @@ impl LanguageModel for OpenAiLanguageModel {
}
fn supports_images(&self) -> bool {
use open_ai::Model;
match &self.model {
Model::FourOmni
| Model::FourOmniMini
| Model::FourPointOne
| Model::FourPointOneMini
| Model::FourPointOneNano
| Model::Five
| Model::FiveMini
| Model::FiveNano
| Model::O1
| Model::O3
| Model::O4Mini => true,
Model::ThreePointFiveTurbo
| Model::Four
| Model::FourTurbo
| Model::O3Mini
| Model::Custom { .. } => false,
}
false
}
fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
@@ -370,9 +350,7 @@ impl LanguageModel for OpenAiLanguageModel {
request,
self.model.id(),
self.model.supports_parallel_tool_calls(),
self.model.supports_prompt_cache_key(),
self.max_output_tokens(),
self.model.reasoning_effort(),
);
let completions = self.stream_completion(request, cx);
async move {
@@ -387,9 +365,7 @@ pub fn into_open_ai(
request: LanguageModelRequest,
model_id: &str,
supports_parallel_tool_calls: bool,
supports_prompt_cache_key: bool,
max_output_tokens: Option<u64>,
reasoning_effort: Option<ReasoningEffort>,
) -> open_ai::Request {
let stream = !model_id.starts_with("o1-");
@@ -479,11 +455,6 @@ pub fn into_open_ai(
} else {
None
},
prompt_cache_key: if supports_prompt_cache_key {
request.thread_id
} else {
None
},
tools: request
.tools
.into_iter()
@@ -500,7 +471,6 @@ pub fn into_open_ai(
LanguageModelToolChoice::Any => open_ai::ToolChoice::Required,
LanguageModelToolChoice::None => open_ai::ToolChoice::None,
}),
reasoning_effort,
}
}

View File

@@ -355,16 +355,7 @@ impl LanguageModel for OpenAiCompatibleLanguageModel {
LanguageModelCompletionError,
>,
> {
let supports_parallel_tool_call = true;
let supports_prompt_cache_key = false;
let request = into_open_ai(
request,
&self.model.name,
supports_parallel_tool_call,
supports_prompt_cache_key,
self.max_output_tokens(),
None,
);
let request = into_open_ai(request, &self.model.name, true, self.max_output_tokens());
let completions = self.stream_completion(request, cx);
async move {
let mapper = OpenAiEventMapper::new();

View File

@@ -355,9 +355,7 @@ impl LanguageModel for VercelLanguageModel {
request,
self.model.id(),
self.model.supports_parallel_tool_calls(),
self.model.supports_prompt_cache_key(),
self.max_output_tokens(),
None,
);
let completions = self.stream_completion(request, cx);
async move {

View File

@@ -359,9 +359,7 @@ impl LanguageModel for XAiLanguageModel {
request,
self.model.id(),
self.model.supports_parallel_tool_calls(),
self.model.supports_prompt_cache_key(),
self.max_output_tokens(),
None,
);
let completions = self.stream_completion(request, cx);
async move {

View File

@@ -86,7 +86,10 @@ impl LanguageSelector {
impl Render for LanguageSelector {
fn render(&mut self, _window: &mut Window, _cx: &mut Context<Self>) -> impl IntoElement {
v_flex().w(rems(34.)).child(self.picker.clone())
v_flex()
.key_context("LanguageSelector")
.w(rems(34.))
.child(self.picker.clone())
}
}

View File

@@ -75,9 +75,6 @@ impl super::LspAdapter for CLspAdapter {
&*version.downcast::<GitHubLspBinaryVersion>().unwrap();
let version_dir = container_dir.join(format!("clangd_{name}"));
let binary_path = version_dir.join("bin/clangd");
let expected_digest = digest
.as_ref()
.and_then(|digest| digest.strip_prefix("sha256:"));
let binary = LanguageServerBinary {
path: binary_path.clone(),
@@ -102,9 +99,7 @@ impl super::LspAdapter for CLspAdapter {
log::warn!("Unable to run {binary_path:?} asset, redownloading: {err}",)
})
};
if let (Some(actual_digest), Some(expected_digest)) =
(&metadata.digest, expected_digest)
{
if let (Some(actual_digest), Some(expected_digest)) = (&metadata.digest, digest) {
if actual_digest == expected_digest {
if validity_check().await.is_ok() {
return Ok(binary);

Some files were not shown because too many files have changed in this diff Show More